Compare commits

..

50 Commits

Author SHA1 Message Date
openhands
373a94a839 Add MockGitHubService for real HTTP-based testing
- Create MockGitHubService: in-process FastAPI server that implements
  GitHub API endpoints
- Server tracks state (comments, reactions, API calls) for verification
- No more MagicMock complexity - real HTTP calls via PyGithub
- Fixture patches Github base_url to point to mock server
- Simplified test using service.assert_comment_sent() etc.

Co-authored-by: openhands <openhands@all-hands.dev>
2026-02-25 17:42:27 +00:00
openhands
e50bb44441 Simplify V1 GitHub Resolver test to single E2E test
Reduced to exactly ONE test that verifies the complete webhook flow:
1. Receives GitHub webhook payload
2. Routes to V1 path (v1_enabled=True)
3. Starts agent server via start_app_conversation
4. Verifies 'I'm on it' message is sent
5. Verifies eyes reaction is added

TestLLM is available for injection when running real agent server.
For this test, we mock start_app_conversation to simulate agent behavior.

Run with:
  cd enterprise
  PYTHONPATH='.:' poetry run pytest tests/integration/v1_github_resolver -v

Co-authored-by: openhands <openhands@all-hands.dev>
2026-02-25 17:28:55 +00:00
openhands
894922e335 Add real agent server tests with OpenHands DB fixture
This update adds:
1. test_real_agent_server_with_openhands_db - verifies full flow with both
   enterprise and OpenHands databases set up

2. openhands_db fixture - creates async OpenHands tables for
   app_conversation_start_task and conversation_metadata

3. Fixed test isolation issue by noting --forked flag requirement

Two-tier testing strategy:
- Tier 1 (CI): Mocks ProcessSandbox, verifies correct flow path
- Tier 2 (Staging): Full E2E with real ProcessSandbox + TestLLM injection

Tests now verify:
1. V1 path is correctly selected
2. _create_v1_conversation is called
3. start_app_conversation is accessed (agent server creation)
4. 'I'm on it' message is sent
5. Eyes reaction is added
6. Callback processor sends summary

Run tests with:
  cd enterprise
  PYTHONPATH='.:' poetry run pytest tests/integration/v1_github_resolver -v --forked

Co-authored-by: openhands <openhands@all-hands.dev>
2026-02-25 16:59:22 +00:00
openhands
4370e9534f Add test for V1 conversation creation path
This update adds:
1. TestV1WebhookFlowWithRealAgentServer class that verifies:
   - V1 path is correctly selected
   - _create_v1_conversation is called
   - Real agent server would be started (requires full database setup)

2. Documentation on requirements for full E2E testing:
   - Both enterprise and openhands databases
   - Tables: app_conversation_start_task, conversation_metadata
   - ProcessSandbox working directory and port allocation

Tests verify the core flow:
- Webhook → V1 detection → agent server creation → 'I'm on it' message
- Callback processor → summary request → GitHub post

NOTE: Tests require --forked flag for proper isolation due to
module-level caching affecting test state.

Run tests with:
  cd enterprise
  PYTHONPATH='.:' poetry run pytest tests/integration/v1_github_resolver -v --forked

Co-authored-by: openhands <openhands@all-hands.dev>
2026-02-25 16:52:52 +00:00
openhands
b4bdb887be Enhance V1 GitHub Resolver tests to verify agent server creation and messages
This update enhances the integration tests to verify:
1. Agent server is created via start_app_conversation
2. 'I'm on it' message is sent to GitHub
3. Agent summary is posted back via callback processor
4. Eyes reaction is added to acknowledge the request

Changes:
- Add GithubServiceImpl mock to avoid real GitHub API calls
- Add TestLLM implementation for trajectory-based testing
- Fix template directory path for jinja2 templates
- Add AppConversationStartTask with required fields
- Fix mock for reactions via get_comment().create_reaction()

All 4 tests now pass:
- test_webhook_triggers_start_app_conversation
- test_v1_callback_processor_sends_summary
- test_signature_creation
- test_issue_comment_payload_structure

Co-authored-by: openhands <openhands@all-hands.dev>
2026-02-25 04:38:21 +00:00
openhands
3c5972876b Add minimal integration test framework for V1 GitHub Resolver
This PR adds a minimal Tier 1 integration test for the V1 GitHub Resolver
webhook flow. The test verifies that:

- Webhook payload with @openhands mention is correctly detected
- User lookup via Keycloak returns the correct user ID
- V1 conversation creation path is triggered when enabled
- V0 path is NOT called when V1 is enabled

Key components:
- Database fixtures with SQLite in-memory database
- Session maker patching across all importing modules
- Seeded test data (user, org, auth tokens, GitHub installation)
- Mocks for Keycloak, GitHub API, and conversation creation

Co-authored-by: openhands <openhands@all-hands.dev>
2026-02-25 04:16:20 +00:00
aivong-openhands
0f1ad46a47 Fix CVE-2025-62727: Update starlette to 0.49.1 (#13016)
Co-authored-by: OpenHands CVE Fix Bot <openhands@all-hands.dev>
Co-authored-by: Ray Myers <ray.myers@gmail.com>
2026-02-24 10:55:32 -06:00
sp.wack
5367bef43a fix: detect team/org-level budget errors in error banner (#13003) 2026-02-24 20:55:11 +04:00
Tim O'Farrell
3afeccfe7f fix: prevent token refresh deadlock with double-checked locking and timeouts (#13020)
Co-authored-by: openhands <openhands@all-hands.dev>
2026-02-24 08:13:57 -07:00
Tim O'Farrell
0677c035ff Optimize get_sandbox_by_session_api_key with hash lookup (#13019)
Co-authored-by: openhands <openhands@all-hands.dev>
2026-02-24 13:55:21 +00:00
Hiep Le
68165b52d9 feat(backend): add pagination and email filtering for organization members (#12999) 2026-02-24 16:02:24 +07:00
Dream
dcc8217317 feat(frontend): add mutateWithToast utility for standardized mutation toast handling (#12433)
Co-authored-by: OpenHands Bot <contact@all-hands.dev>
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-authored-by: hieptl <hieptl.developer@gmail.com>
2026-02-24 15:06:44 +07:00
jpelletier1
d1410949ff Experiment - Add 'Add Team Members' button to Avatar menu in SaaS mode (#12647)
Co-authored-by: openhands <openhands@all-hands.dev>
2026-02-23 23:06:57 +04:00
Tim O'Farrell
a6c0d80fe1 Fix: Logout on 401 error in useGitUser; downgrade provider error to warning (#12935)
Co-authored-by: openhands <openhands@all-hands.dev>
2026-02-23 10:15:54 -07:00
Tim O'Farrell
0efb1db85d Bumped SDK to 1.11.5 (#13002) 2026-02-23 09:31:31 -07:00
Hiep Le
8e0f74c92c fix(backend): ensure members are removed from the corresponding litellm team when removed from an organization (#12996) 2026-02-23 18:45:31 +07:00
Hiep Le
6e1ba3d836 fix(backend): update current_org_id when removing a member from an organization (#12995) 2026-02-23 18:21:37 +07:00
Hiep Le
0ec97893d1 fix(backend): unable to delete an organization after inviting at least one member (#12993) 2026-02-23 18:21:10 +07:00
Tim O'Farrell
ddb809bc43 Add webhook endpoint authentication bypass and admin context unfiltered data access (#12956)
Co-authored-by: openhands <openhands@all-hands.dev>
2026-02-23 09:28:49 +00:00
Alona
872f2b87f2 fix: add retry logic with exponential backoff to send_welcome_email (#12450)
Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: Rohit Malhotra <rohitvinodmalhotra@gmail.com>
2026-02-20 20:42:00 +00:00
Graham Neubig
ee86005a3a Align PR review workflow with software-agent-sdk (#12963)
Co-authored-by: openhands <openhands@all-hands.dev>
2026-02-20 21:02:32 +01:00
Graham Neubig
d4aa30580b Migrate PR review workflow to use extensions action (#12917)
Co-authored-by: openhands <openhands@all-hands.dev>
2026-02-20 18:44:02 +00:00
Tim O'Farrell
2f0e879129 Fix session_maker to accept kwargs for backward compatibility (#12960)
Co-authored-by: openhands <openhands@all-hands.dev>
2026-02-20 10:56:44 -07:00
sp.wack
3bc2ef954e fix(backend): config values (#12944) 2026-02-20 17:53:35 +04:00
Ray Myers
32ab2a24c6 Remove enterprise-preview job and workflow (#12350)
Co-authored-by: openhands <openhands@all-hands.dev>
2026-02-20 03:36:14 +01:00
Engel Nyst
a6e148d1e6 refactor: use consolidated pr-review action (#12801)
Co-authored-by: openhands <openhands@all-hands.dev>
2026-02-20 02:08:42 +01:00
Manrique Vargas
3fc977eddd fix(mcp): skip conversation link when conversation_id is None (#12941)
Signed-off-by: machov <mv1742@nyu.edu>
Co-authored-by: Rohit Malhotra <rohitvinodmalhotra@gmail.com>
2026-02-19 21:41:26 +00:00
John-Mason P. Shackelford
89a6890269 Fix URL encoding in Jira OAuth authorization URLs (#12399)
Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: Rohit Malhotra <rohitvinodmalhotra@gmail.com>
2026-02-19 21:40:29 +00:00
Hiep Le
8927ac2230 fix(backend): organization members now see correct shared credit balance (#12942) 2026-02-20 01:34:53 +07:00
Rohit Malhotra
f3429e33ca Fix Resend sync to respect deleted users (#12904)
Co-authored-by: openhands <openhands@all-hands.dev>
2026-02-19 17:43:15 +00:00
Tim O'Farrell
7cd219792b Add type hints and use model objects in api_keys.py endpoints (#12939)
Co-authored-by: openhands <openhands@all-hands.dev>
2026-02-19 08:40:46 -07:00
Hiep Le
2aabe2ed8c fix(backend): add organization filtering to V1 conversation queries (#12923) 2026-02-19 20:39:28 +07:00
Tim O'Farrell
731a9a813e More readable logs for local debugging (#12926) 2026-02-19 02:27:57 -07:00
Tim O'Farrell
123e556fed Added endpoint for readiness probe (#12927) 2026-02-19 02:27:35 -07:00
Chujiang
6676cae249 fix: add missing type hints and improve test logging (#12810)
Co-authored-by: Claude Sonnet 4.5 <noreply@anthropic.com>
2026-02-19 00:58:39 +01:00
Clay Arnold
fede37b496 fix: add claude-opus-4-6 to temperature/top_p guard (#12874)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-19 00:33:17 +01:00
Hiep Le
3bcd6f18df fix(backend): set user email fields from user_info during create_user (#12921) 2026-02-19 02:06:20 +07:00
Rohit Malhotra
0da18440c2 Mention free MiniMax usage and drop free credits (#12918)
Co-authored-by: openhands <openhands@all-hands.dev>
2026-02-18 13:54:05 -05:00
Hiep Le
ac76e10048 refactor(backend): include current_org_id in organization list response (#12915) 2026-02-18 20:35:40 +07:00
Hiep Le
b98bae8b5f refactor(backend): rename orgmemberresponse.role_name to role (#12914) 2026-02-18 20:23:07 +07:00
Tim O'Farrell
516721d1ee fix: add default uuid4 to event_callback_result primary key (#12908)
Co-authored-by: openhands <openhands@all-hands.dev>
2026-02-18 05:57:13 -07:00
Hiep Le
4d6f66ca28 feat: add user invitation logic (#12883)
Co-authored-by: openhands <openhands@all-hands.dev>
2026-02-18 13:24:19 +07:00
chuckbutkus
b18568da0b Feature/permission based authorization (#12906)
Co-authored-by: openhands <openhands@all-hands.dev>
2026-02-18 01:17:39 -05:00
mamoodi
83dd3c169c Release 1.4.0 (#12897) 2026-02-17 13:09:29 -05:00
Tim O'Farrell
35bddb14f1 fix: preserve import order in clean_proactive_convo_table.py (#12901)
Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: OpenHands Bot <contact@all-hands.dev>
2026-02-17 17:52:54 +00:00
Tim O'Farrell
e8425218e2 Remove alembic errors dumped into logs by cron jobs (#12900) 2026-02-17 17:22:54 +00:00
Rohit Malhotra
0a879fa781 Grant free credits after minimum purchase (#12899)
Co-authored-by: openhands <openhands@all-hands.dev>
2026-02-17 11:00:42 -05:00
Hiep Le
41e142bbab fix(backend): system prompt override (planning agent) (#12893) 2026-02-17 16:15:26 +07:00
Engel Nyst
b06b9eedac fix: wire suggested task prompts for V1 (#12787)
Co-authored-by: openhands <openhands@all-hands.dev>
2026-02-16 23:57:32 +01:00
Tim O'Farrell
a9afafa991 Default model for new users is minimax (#12889) 2026-02-16 12:24:30 -07:00
139 changed files with 12569 additions and 1476 deletions

View File

@@ -1,29 +0,0 @@
# Feature branch preview for enterprise code
name: Enterprise Preview
# Run on PRs labeled
on:
pull_request:
types: [labeled]
# Match ghcr-build.yml, but don't interrupt it.
concurrency:
group: ${{ github.workflow }}-${{ (github.head_ref && github.ref) || github.run_id }}
cancel-in-progress: false
jobs:
# This must happen for the PR Docker workflow when the label is present,
# and also if it's added after the fact. Thus, it exists in both places.
enterprise-preview:
name: Enterprise preview
if: github.event.label.name == 'deploy'
runs-on: blacksmith-4vcpu-ubuntu-2204
steps:
# This should match the version in ghcr-build.yml
- name: Trigger remote job
run: |
curl --fail-with-body -sS -X POST \
-H "Authorization: Bearer ${{ secrets.ALLHANDS_BOT_GITHUB_PAT }}" \
-H "Accept: application/vnd.github+json" \
-d "{\"ref\": \"main\", \"inputs\": {\"openhandsPrNumber\": \"${{ github.event.pull_request.number }}\", \"deployEnvironment\": \"feature\", \"enterpriseImageTag\": \"pr-${{ github.event.pull_request.number }}\" }}" \
https://api.github.com/repos/OpenHands/deploy/actions/workflows/deploy.yaml/dispatches

View File

@@ -240,21 +240,6 @@ jobs:
# Add build attestations for better security
sbom: true
enterprise-preview:
name: Enterprise preview
if: github.event_name == 'pull_request' && contains(github.event.pull_request.labels.*.name, 'deploy')
runs-on: blacksmith-4vcpu-ubuntu-2204
needs: [ghcr_build_enterprise]
steps:
# This should match the version in enterprise-preview.yml
- name: Trigger remote job
run: |
curl --fail-with-body -sS -X POST \
-H "Authorization: Bearer ${{ secrets.ALLHANDS_BOT_GITHUB_PAT }}" \
-H "Accept: application/vnd.github+json" \
-d "{\"ref\": \"main\", \"inputs\": {\"openhandsPrNumber\": \"${{ github.event.pull_request.number }}\", \"deployEnvironment\": \"feature\", \"enterpriseImageTag\": \"pr-${{ github.event.pull_request.number }}\" }}" \
https://api.github.com/repos/OpenHands/deploy/actions/workflows/deploy.yaml/dispatches
# "All Runtime Tests Passed" is a required job for PRs to merge
# We can remove this once the config changes
runtime_tests_check_success:

View File

@@ -2,16 +2,11 @@
name: PR Review by OpenHands
on:
# Use pull_request_target to allow fork PRs to access secrets when triggered by maintainers
# Security: This workflow runs when:
# 1. A new PR is opened (non-draft), OR
# 2. A draft PR is marked as ready for review, OR
# 3. A maintainer adds the 'review-this' label, OR
# 4. A maintainer requests openhands-agent or all-hands-bot as a reviewer
# Only users with write access can add labels or request reviews, ensuring security.
# The PR code is explicitly checked out for review, but secrets are only accessible
# because the workflow runs in the base repository context
pull_request_target:
# TEMPORARY MITIGATION (Clinejection hardening)
#
# We temporarily avoid `pull_request_target` here. We'll restore it after the PR review
# workflow is fully hardened for untrusted execution.
pull_request:
types: [opened, ready_for_review, labeled, review_requested]
permissions:
@@ -21,107 +16,33 @@ permissions:
jobs:
pr-review:
# Run when one of the following conditions is met:
# 1. A new non-draft PR is opened by a trusted contributor, OR
# 2. A draft PR is converted to ready for review by a trusted contributor, OR
# 3. 'review-this' label is added, OR
# 4. openhands-agent or all-hands-bot is requested as a reviewer
# Note: FIRST_TIME_CONTRIBUTOR PRs require manual trigger via label/reviewer request
# Note: fork PRs will not have access to repository secrets under `pull_request`.
# Skip forks to avoid noisy failures until we restore a hardened `pull_request_target` flow.
if: |
(github.event.action == 'opened' && github.event.pull_request.draft == false && github.event.pull_request.author_association != 'FIRST_TIME_CONTRIBUTOR') ||
(github.event.action == 'ready_for_review' && github.event.pull_request.author_association != 'FIRST_TIME_CONTRIBUTOR') ||
github.event.label.name == 'review-this' ||
github.event.requested_reviewer.login == 'openhands-agent' ||
github.event.requested_reviewer.login == 'all-hands-bot'
github.event.pull_request.head.repo.full_name == github.repository &&
(
(github.event.action == 'opened' && github.event.pull_request.draft == false) ||
github.event.action == 'ready_for_review' ||
(github.event.action == 'labeled' && github.event.label.name == 'review-this') ||
(
github.event.action == 'review_requested' &&
(
github.event.requested_reviewer.login == 'openhands-agent' ||
github.event.requested_reviewer.login == 'all-hands-bot'
)
)
)
concurrency:
group: pr-review-${{ github.event.pull_request.number }}
cancel-in-progress: true
runs-on: blacksmith-4vcpu-ubuntu-2404
env:
LLM_MODEL: litellm_proxy/claude-sonnet-4-5-20250929
LLM_BASE_URL: https://llm-proxy.app.all-hands.dev
# PR context will be automatically provided by the agent script
PR_NUMBER: ${{ github.event.pull_request.number }}
PR_TITLE: ${{ github.event.pull_request.title }}
PR_BODY: ${{ github.event.pull_request.body }}
PR_BASE_BRANCH: ${{ github.event.pull_request.base.ref }}
PR_HEAD_BRANCH: ${{ github.event.pull_request.head.ref }}
REPO_NAME: ${{ github.repository }}
runs-on: ubuntu-24.04
steps:
- name: Checkout software-agent-sdk repository
uses: actions/checkout@v5
- name: Run PR Review
uses: OpenHands/extensions/plugins/pr-review@main
with:
repository: OpenHands/software-agent-sdk
path: software-agent-sdk
- name: Checkout PR repository
uses: actions/checkout@v5
with:
# When using pull_request_target, explicitly checkout the PR branch
# This ensures we review the actual PR code (including fork PRs)
repository: ${{ github.event.pull_request.head.repo.full_name }}
ref: ${{ github.event.pull_request.head.ref }}
fetch-depth: 0
# Security: Don't persist credentials to prevent untrusted PR code from using them
persist-credentials: false
path: pr-repo
- name: Set up Python
uses: actions/setup-python@v6
with:
python-version: '3.13'
- name: Install uv
uses: astral-sh/setup-uv@v7
with:
enable-cache: true
- name: Install GitHub CLI
run: |
# Install GitHub CLI for posting review comments
sudo apt-get update
sudo apt-get install -y gh
- name: Install OpenHands dependencies
run: |
# Install OpenHands SDK and tools from local checkout
uv pip install --system ./software-agent-sdk/openhands-sdk ./software-agent-sdk/openhands-tools
- name: Check required configuration
env:
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
run: |
if [ -z "$LLM_API_KEY" ]; then
echo "Error: LLM_API_KEY secret is not set."
exit 1
fi
echo "PR Number: $PR_NUMBER"
echo "PR Title: $PR_TITLE"
echo "Repository: $REPO_NAME"
echo "LLM model: $LLM_MODEL"
if [ -n "$LLM_BASE_URL" ]; then
echo "LLM base URL: $LLM_BASE_URL"
fi
- name: Run PR review
env:
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
GITHUB_TOKEN: ${{ secrets.ALLHANDS_BOT_GITHUB_PAT }}
LMNR_PROJECT_API_KEY: ${{ secrets.LMNR_SKILLS_API_KEY }}
run: |
# Change to the PR repository directory so agent can analyze the code
cd pr-repo
# Run the PR review script from the software-agent-sdk checkout
uv run python ../software-agent-sdk/examples/03_github_workflows/02_pr_review/agent_script.py
- name: Upload logs as artifact
uses: actions/upload-artifact@v5
if: always()
with:
name: openhands-pr-review-logs
path: |
*.log
output/
retention-days: 7
llm-model: litellm_proxy/claude-sonnet-4-5-20250929
llm-base-url: https://llm-proxy.app.all-hands.dev
review-style: roasted
llm-api-key: ${{ secrets.LLM_API_KEY }}
github-token: ${{ secrets.ALLHANDS_BOT_GITHUB_PAT }}
lmnr-api-key: ${{ secrets.LMNR_SKILLS_API_KEY }}

View File

@@ -0,0 +1,85 @@
---
name: PR Review Evaluation
# This workflow evaluates how well PR review comments were addressed.
# It runs when a PR is closed to assess review effectiveness.
#
# Security note: pull_request_target is safe here because:
# 1. Only triggers on PR close (not on code changes)
# 2. Does not checkout PR code - only downloads artifacts from trusted workflow runs
# 3. Runs evaluation scripts from the extensions repo, not from the PR
on:
pull_request_target:
types: [closed]
permissions:
contents: read
pull-requests: read
jobs:
evaluate:
runs-on: ubuntu-24.04
env:
PR_NUMBER: ${{ github.event.pull_request.number }}
REPO_NAME: ${{ github.repository }}
PR_MERGED: ${{ github.event.pull_request.merged }}
steps:
- name: Download review trace artifact
id: download-trace
uses: dawidd6/action-download-artifact@v6
continue-on-error: true
with:
workflow: pr-review-by-openhands.yml
name: pr-review-trace-${{ github.event.pull_request.number }}
path: trace-info
search_artifacts: true
if_no_artifact_found: warn
- name: Check if trace file exists
id: check-trace
run: |
if [ -f "trace-info/laminar_trace_info.json" ]; then
echo "trace_exists=true" >> $GITHUB_OUTPUT
echo "Found trace file for PR #$PR_NUMBER"
else
echo "trace_exists=false" >> $GITHUB_OUTPUT
echo "No trace file found for PR #$PR_NUMBER - skipping evaluation"
fi
# Always checkout main branch for security - cannot test script changes in PRs
- name: Checkout extensions repository
if: steps.check-trace.outputs.trace_exists == 'true'
uses: actions/checkout@v5
with:
repository: OpenHands/extensions
path: extensions
- name: Set up Python
if: steps.check-trace.outputs.trace_exists == 'true'
uses: actions/setup-python@v6
with:
python-version: '3.12'
- name: Install dependencies
if: steps.check-trace.outputs.trace_exists == 'true'
run: pip install lmnr
- name: Run evaluation
if: steps.check-trace.outputs.trace_exists == 'true'
env:
# Script expects LMNR_PROJECT_API_KEY; org secret is named LMNR_SKILLS_API_KEY
LMNR_PROJECT_API_KEY: ${{ secrets.LMNR_SKILLS_API_KEY }}
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
python extensions/plugins/pr-review/scripts/evaluate_review.py \
--trace-file trace-info/laminar_trace_info.json
- name: Upload evaluation logs
uses: actions/upload-artifact@v5
if: always() && steps.check-trace.outputs.trace_exists == 'true'
with:
name: pr-review-evaluation-${{ github.event.pull_request.number }}
path: '*.log'
retention-days: 30

View File

@@ -54,7 +54,7 @@ The experience will be familiar to anyone who has used Devin or Jules.
### OpenHands Cloud
This is a deployment of OpenHands GUI, running on hosted infrastructure.
You can try it with a free $10 credit by [signing in with your GitHub or GitLab account](https://app.all-hands.dev).
You can try it for free using the Minimax model by [signing in with your GitHub or GitLab account](https://app.all-hands.dev).
OpenHands Cloud comes with source-available features and integrations:
- Integrations with Slack, Jira, and Linear

View File

@@ -28,9 +28,11 @@ class SaaSExperimentManager(ExperimentManager):
return agent
if EXPERIMENT_SYSTEM_PROMPT_EXPERIMENT:
agent = agent.model_copy(
update={'system_prompt_filename': 'system_prompt_long_horizon.j2'}
)
# Skip experiment for planning agents which require their specialized prompt
if agent.system_prompt_filename != 'system_prompt_planning.j2':
agent = agent.model_copy(
update={'system_prompt_filename': 'system_prompt_long_horizon.j2'}
)
return agent

View File

@@ -0,0 +1,37 @@
"""Add pending_free_credits flag to org table.
Revision ID: 093
Revises: 092
Create Date: 2025-02-17 00:00:00.000000
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = '093'
down_revision: Union[str, None] = '092'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Add pending_free_credits column to org table with default false.
# New orgs will have this set to TRUE at creation time.
# Existing orgs default to FALSE (not eligible - they already got $10 at signup).
op.add_column(
'org',
sa.Column(
'pending_free_credits',
sa.Boolean,
nullable=False,
server_default=sa.text('false'),
),
)
def downgrade() -> None:
op.drop_column('org', 'pending_free_credits')

View File

@@ -0,0 +1,110 @@
"""create org_invitation table
Revision ID: 094
Revises: 093
Create Date: 2026-02-18 00:00:00.000000
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = '094'
down_revision: Union[str, None] = '093'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Create org_invitation table
op.create_table(
'org_invitation',
sa.Column('id', sa.Integer, sa.Identity(), primary_key=True),
sa.Column('token', sa.String(64), nullable=False),
sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=False),
sa.Column('email', sa.String(255), nullable=False),
sa.Column('role_id', sa.Integer, nullable=False),
sa.Column('inviter_id', postgresql.UUID(as_uuid=True), nullable=False),
sa.Column(
'status',
sa.String(20),
nullable=False,
server_default=sa.text("'pending'"),
),
sa.Column(
'created_at',
sa.DateTime,
nullable=False,
server_default=sa.text('CURRENT_TIMESTAMP'),
),
sa.Column('expires_at', sa.DateTime, nullable=False),
sa.Column('accepted_at', sa.DateTime, nullable=True),
sa.Column('accepted_by_user_id', postgresql.UUID(as_uuid=True), nullable=True),
# Foreign key constraints
sa.ForeignKeyConstraint(
['org_id'],
['org.id'],
name='org_invitation_org_fkey',
ondelete='CASCADE',
),
sa.ForeignKeyConstraint(
['role_id'],
['role.id'],
name='org_invitation_role_fkey',
),
sa.ForeignKeyConstraint(
['inviter_id'],
['user.id'],
name='org_invitation_inviter_fkey',
),
sa.ForeignKeyConstraint(
['accepted_by_user_id'],
['user.id'],
name='org_invitation_accepter_fkey',
),
)
# Create indexes
op.create_index(
'ix_org_invitation_token',
'org_invitation',
['token'],
unique=True,
)
op.create_index(
'ix_org_invitation_org_id',
'org_invitation',
['org_id'],
)
op.create_index(
'ix_org_invitation_email',
'org_invitation',
['email'],
)
op.create_index(
'ix_org_invitation_status',
'org_invitation',
['status'],
)
# Composite index for checking pending invitations
op.create_index(
'ix_org_invitation_org_email_status',
'org_invitation',
['org_id', 'email', 'status'],
)
def downgrade() -> None:
# Drop indexes
op.drop_index('ix_org_invitation_org_email_status', table_name='org_invitation')
op.drop_index('ix_org_invitation_status', table_name='org_invitation')
op.drop_index('ix_org_invitation_email', table_name='org_invitation')
op.drop_index('ix_org_invitation_org_id', table_name='org_invitation')
op.drop_index('ix_org_invitation_token', table_name='org_invitation')
# Drop table
op.drop_table('org_invitation')

View File

@@ -0,0 +1,37 @@
"""Drop pending_free_credits column from org table.
Revision ID: 095
Revises: 094
Create Date: 2025-02-18 00:00:00.000000
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = '095'
down_revision: Union[str, None] = '094'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Drop the pending_free_credits column from org table.
# This column was used for tracking free credit eligibility but is no longer needed.
op.drop_column('org', 'pending_free_credits')
def downgrade() -> None:
# Re-add pending_free_credits column with default false.
op.add_column(
'org',
sa.Column(
'pending_free_credits',
sa.Boolean,
nullable=False,
server_default=sa.text('false'),
),
)

View File

@@ -0,0 +1,67 @@
"""Create resend_synced_users table.
Revision ID: 096
Revises: 095
Create Date: 2025-02-17 00:00:00.000000
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = '096'
down_revision: Union[str, None] = '095'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Create resend_synced_users table for tracking users synced to Resend audiences."""
op.create_table(
'resend_synced_users',
sa.Column(
'id',
sa.UUID(as_uuid=True),
nullable=False,
primary_key=True,
),
sa.Column('email', sa.String(), nullable=False),
sa.Column('audience_id', sa.String(), nullable=False),
sa.Column(
'synced_at',
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.text('CURRENT_TIMESTAMP'),
),
sa.Column('keycloak_user_id', sa.String(), nullable=True),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint(
'email', 'audience_id', name='uq_resend_synced_email_audience'
),
)
# Create index on email for fast lookups
op.create_index(
'ix_resend_synced_users_email',
'resend_synced_users',
['email'],
)
# Create index on audience_id for filtering by audience
op.create_index(
'ix_resend_synced_users_audience_id',
'resend_synced_users',
['audience_id'],
)
def downgrade() -> None:
"""Drop resend_synced_users table."""
op.drop_index(
'ix_resend_synced_users_audience_id', table_name='resend_synced_users'
)
op.drop_index('ix_resend_synced_users_email', table_name='resend_synced_users')
op.drop_table('resend_synced_users')

View File

@@ -0,0 +1,41 @@
"""Add session_api_key_hash to v1_remote_sandbox table
Revision ID: 097
Revises: 096
Create Date: 2025-02-24 00:00:00.000000
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = '097'
down_revision: Union[str, None] = '096'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Add session_api_key_hash column to v1_remote_sandbox table."""
op.add_column(
'v1_remote_sandbox',
sa.Column('session_api_key_hash', sa.String(), nullable=True),
)
op.create_index(
op.f('ix_v1_remote_sandbox_session_api_key_hash'),
'v1_remote_sandbox',
['session_api_key_hash'],
unique=False,
)
def downgrade() -> None:
"""Remove session_api_key_hash column from v1_remote_sandbox table."""
op.drop_index(
op.f('ix_v1_remote_sandbox_session_api_key_hash'),
table_name='v1_remote_sandbox',
)
op.drop_column('v1_remote_sandbox', 'session_api_key_hash')

26
enterprise/poetry.lock generated
View File

@@ -6102,14 +6102,14 @@ llama = ["llama-index (>=0.12.29,<0.13.0)", "llama-index-core (>=0.12.29,<0.13.0
[[package]]
name = "openhands-agent-server"
version = "1.11.4"
version = "1.11.5"
description = "OpenHands Agent Server - REST/WebSocket interface for OpenHands AI Agent"
optional = false
python-versions = ">=3.12"
groups = ["main"]
files = [
{file = "openhands_agent_server-1.11.4-py3-none-any.whl", hash = "sha256:739bdb774dbfcd23d6e87ee6ee32bc0999f22300037506b6dd33e9ea67fa5c2a"},
{file = "openhands_agent_server-1.11.4.tar.gz", hash = "sha256:41247f7022a046eb50ca3b552bc6d12bfa9776e1bd27d0989da91b9f7ac77ca2"},
{file = "openhands_agent_server-1.11.5-py3-none-any.whl", hash = "sha256:8bae7063f232791d58a5c31919f58b557f7cce60e6295773985c7dadc556cb9e"},
{file = "openhands_agent_server-1.11.5.tar.gz", hash = "sha256:b61366d727c61ab9b7fcd66faab53f230f8ef0928c1177a388d2c5c4be6ebbd0"},
]
[package.dependencies]
@@ -6126,7 +6126,7 @@ wsproto = ">=1.2.0"
[[package]]
name = "openhands-ai"
version = "1.3.0"
version = "1.4.0"
description = "OpenHands: Code Less, Make More"
optional = false
python-versions = "^3.12,<3.14"
@@ -6168,9 +6168,9 @@ memory-profiler = ">=0.61"
numpy = "*"
openai = "2.8"
openhands-aci = "0.3.2"
openhands-agent-server = "1.11.4"
openhands-sdk = "1.11.4"
openhands-tools = "1.11.4"
openhands-agent-server = "1.11.5"
openhands-sdk = "1.11.5"
openhands-tools = "1.11.5"
opentelemetry-api = ">=1.33.1"
opentelemetry-exporter-otlp-proto-grpc = ">=1.33.1"
pathspec = ">=0.12.1"
@@ -6225,14 +6225,14 @@ url = ".."
[[package]]
name = "openhands-sdk"
version = "1.11.4"
version = "1.11.5"
description = "OpenHands SDK - Core functionality for building AI agents"
optional = false
python-versions = ">=3.12"
groups = ["main"]
files = [
{file = "openhands_sdk-1.11.4-py3-none-any.whl", hash = "sha256:9f4607c5d94b56fbcd533207026ee892779dd50e29bce79277ff82454a4f76d5"},
{file = "openhands_sdk-1.11.4.tar.gz", hash = "sha256:4088744f6b8856eeab22d3bc17e47d1736ea7ced945c2fa126bd7d48c14bb313"},
{file = "openhands_sdk-1.11.5-py3-none-any.whl", hash = "sha256:f949cd540cbecc339d90fb0cca2a5f29e1b62566b82b5aee82ef40f259d14e60"},
{file = "openhands_sdk-1.11.5.tar.gz", hash = "sha256:dd6225876b7b8dbb6c608559f2718c3d0bf44d0bb741e990b185c6cdc5150c5a"},
]
[package.dependencies]
@@ -6253,14 +6253,14 @@ boto3 = ["boto3 (>=1.35.0)"]
[[package]]
name = "openhands-tools"
version = "1.11.4"
version = "1.11.5"
description = "OpenHands Tools - Runtime tools for AI agents"
optional = false
python-versions = ">=3.12"
groups = ["main"]
files = [
{file = "openhands_tools-1.11.4-py3-none-any.whl", hash = "sha256:efd721b73e87a0dac69171a76931363fa59fcde98107ca86081ee7bf0253673a"},
{file = "openhands_tools-1.11.4.tar.gz", hash = "sha256:80671b1ea8c85a5247a75ea2340ae31d76363e9c723b104699a9a77e66d2043c"},
{file = "openhands_tools-1.11.5-py3-none-any.whl", hash = "sha256:1e981e1e7f3544184fe946cee8eb6bd287010cdef77d83ebac945c9f42df3baf"},
{file = "openhands_tools-1.11.5.tar.gz", hash = "sha256:d7b1163f6505a51b07147e7d8972062c129ecc46571a71f28d5470355e06650e"},
]
[package.dependencies]

View File

@@ -38,6 +38,12 @@ from server.routes.integration.linear import linear_integration_router # noqa:
from server.routes.integration.slack import slack_router # noqa: E402
from server.routes.mcp_patch import patch_mcp_server # noqa: E402
from server.routes.oauth_device import oauth_device_router # noqa: E402
from server.routes.org_invitations import ( # noqa: E402
accept_router as invitation_accept_router,
)
from server.routes.org_invitations import ( # noqa: E402
invitation_router,
)
from server.routes.orgs import org_router # noqa: E402
from server.routes.readiness import readiness_router # noqa: E402
from server.routes.user import saas_user_router # noqa: E402
@@ -99,6 +105,8 @@ if GITLAB_APP_CLIENT_ID:
base_app.include_router(api_keys_router) # Add routes for API key management
base_app.include_router(org_router) # Add routes for organization management
base_app.include_router(invitation_router) # Add routes for org invitation management
base_app.include_router(invitation_accept_router) # Add route for accepting invitations
add_github_proxy_routes(base_app)
add_debugging_routes(
base_app

View File

@@ -38,3 +38,9 @@ class ExpiredError(AuthError):
"""Error when a token has expired (Usually the refresh token)"""
pass
class TokenRefreshError(AuthError):
"""Error when token refresh fails due to timeout or lock contention"""
pass

View File

@@ -1,21 +1,18 @@
"""
Permission-based authorization dependencies for API endpoints (SAAS mode).
Permission-based authorization dependencies for API endpoints.
This module provides FastAPI dependencies for checking user permissions
within organizations. It uses a permission-based authorization model where
roles (owner, admin, member) are mapped to specific permissions.
This is the SAAS/enterprise implementation that performs real authorization
checks against the database.
Permissions are defined in the Permission enum and mapped to roles via
ROLE_PERMISSIONS. This allows fine-grained access control while maintaining
the familiar role-based hierarchy.
Usage:
from server.auth.authorization import (
Permission,
require_permission,
require_org_role,
require_org_user,
require_org_admin,
require_org_owner,
)
@router.get('/{org_id}/settings')
@@ -25,6 +22,14 @@ Usage:
):
# Only users with VIEW_LLM_SETTINGS permission can access
...
@router.patch('/{org_id}/settings')
async def update_settings(
org_id: UUID,
user_id: str = Depends(require_permission(Permission.EDIT_LLM_SETTINGS)),
):
# Only users with EDIT_LLM_SETTINGS permission can access
...
"""
from enum import Enum
@@ -36,10 +41,50 @@ from storage.role import Role
from storage.role_store import RoleStore
from openhands.core.logger import openhands_logger as logger
from openhands.server.auth import Permission
from openhands.server.user_auth import get_user_id
class Permission(str, Enum):
"""Permissions that can be assigned to roles."""
# Secrets
MANAGE_SECRETS = 'manage_secrets'
# MCP
MANAGE_MCP = 'manage_mcp'
# Integrations
MANAGE_INTEGRATIONS = 'manage_integrations'
# Application Settings
MANAGE_APPLICATION_SETTINGS = 'manage_application_settings'
# API Keys
MANAGE_API_KEYS = 'manage_api_keys'
# LLM Settings
VIEW_LLM_SETTINGS = 'view_llm_settings'
EDIT_LLM_SETTINGS = 'edit_llm_settings'
# Billing
VIEW_BILLING = 'view_billing'
ADD_CREDITS = 'add_credits'
# Organization Members
INVITE_USER_TO_ORGANIZATION = 'invite_user_to_organization'
CHANGE_USER_ROLE_MEMBER = 'change_user_role:member'
CHANGE_USER_ROLE_ADMIN = 'change_user_role:admin'
CHANGE_USER_ROLE_OWNER = 'change_user_role:owner'
# Organization Management
VIEW_ORG_SETTINGS = 'view_org_settings'
CHANGE_ORGANIZATION_NAME = 'change_organization_name'
DELETE_ORGANIZATION = 'delete_organization'
# Temporary permissions until we finish the API updates.
EDIT_ORG_SETTINGS = 'edit_org_settings'
class RoleName(str, Enum):
"""Role names used in the system."""
@@ -67,6 +112,9 @@ ROLE_PERMISSIONS: dict[RoleName, frozenset[Permission]] = {
Permission.CHANGE_USER_ROLE_MEMBER,
Permission.CHANGE_USER_ROLE_ADMIN,
Permission.CHANGE_USER_ROLE_OWNER,
# Organization Management
Permission.VIEW_ORG_SETTINGS,
Permission.EDIT_ORG_SETTINGS,
# Organization Management (Owner only)
Permission.CHANGE_ORGANIZATION_NAME,
Permission.DELETE_ORGANIZATION,
@@ -88,6 +136,9 @@ ROLE_PERMISSIONS: dict[RoleName, frozenset[Permission]] = {
Permission.INVITE_USER_TO_ORGANIZATION,
Permission.CHANGE_USER_ROLE_MEMBER,
Permission.CHANGE_USER_ROLE_ADMIN,
# Organization Management
Permission.VIEW_ORG_SETTINGS,
Permission.EDIT_ORG_SETTINGS,
]
),
RoleName.MEMBER: frozenset(
@@ -98,42 +149,81 @@ ROLE_PERMISSIONS: dict[RoleName, frozenset[Permission]] = {
Permission.MANAGE_INTEGRATIONS,
Permission.MANAGE_APPLICATION_SETTINGS,
Permission.MANAGE_API_KEYS,
# LLM Settings (View only)
# Settings (View only)
Permission.VIEW_ORG_SETTINGS,
Permission.VIEW_LLM_SETTINGS,
]
),
}
def get_role_permissions(role_name: str) -> frozenset[Permission]:
"""Get the permissions for a role."""
try:
role_enum = RoleName(role_name)
return ROLE_PERMISSIONS.get(role_enum, frozenset())
except ValueError:
return frozenset()
def get_user_org_role(user_id: str, org_id: UUID) -> Role | None:
def get_user_org_role(user_id: str, org_id: UUID | None) -> Role | None:
"""
Get the user's role in an organization.
Get the user's role in an organization (synchronous version).
Args:
user_id: User ID (string that will be converted to UUID)
org_id: Organization ID
org_id: Organization ID, or None to use the user's current organization
Returns:
Role object if user is a member, None otherwise
"""
from uuid import UUID as parse_uuid
org_member = OrgMemberStore.get_org_member(org_id, parse_uuid(user_id))
if org_id is None:
org_member = OrgMemberStore.get_org_member_for_current_org(parse_uuid(user_id))
else:
org_member = OrgMemberStore.get_org_member(org_id, parse_uuid(user_id))
if not org_member:
return None
return RoleStore.get_role_by_id(org_member.role_id)
async def get_user_org_role_async(user_id: str, org_id: UUID | None) -> Role | None:
"""
Get the user's role in an organization (async version).
Args:
user_id: User ID (string that will be converted to UUID)
org_id: Organization ID, or None to use the user's current organization
Returns:
Role object if user is a member, None otherwise
"""
from uuid import UUID as parse_uuid
if org_id is None:
org_member = await OrgMemberStore.get_org_member_for_current_org_async(
parse_uuid(user_id)
)
else:
org_member = await OrgMemberStore.get_org_member_async(
org_id, parse_uuid(user_id)
)
if not org_member:
return None
return await RoleStore.get_role_by_id_async(org_member.role_id)
def get_role_permissions(role_name: str) -> frozenset[Permission]:
"""
Get the permissions for a role.
Args:
role_name: Name of the role
Returns:
Set of permissions for the role
"""
try:
role_enum = RoleName(role_name)
return ROLE_PERMISSIONS.get(role_enum, frozenset())
except ValueError:
return frozenset()
def has_permission(user_role: Role, permission: Permission) -> bool:
"""
Check if a role has a specific permission.
@@ -149,23 +239,6 @@ def has_permission(user_role: Role, permission: Permission) -> bool:
return permission in permissions
def has_required_role(user_role: Role, required_role: Role) -> bool:
"""
Check if user's role meets or exceeds the required role.
Uses role hierarchy based on rank where lower rank = higher position
(e.g., rank 1 owner > rank 2 admin > rank 3 user).
Args:
user_role: User's actual Role object
required_role: Minimum required Role object
Returns:
True if user has sufficient permissions
"""
return user_role.rank <= required_role.rank
def require_permission(permission: Permission):
"""
Factory function that creates a dependency to require a specific permission.
@@ -192,7 +265,7 @@ def require_permission(permission: Permission):
"""
async def permission_checker(
org_id: UUID,
org_id: UUID | None = None,
user_id: str | None = Depends(get_user_id),
) -> str:
if not user_id:
@@ -201,7 +274,7 @@ def require_permission(permission: Permission):
detail='User not authenticated',
)
user_role = get_user_org_role(user_id, org_id)
user_role = await get_user_org_role_async(user_id, org_id)
if not user_role:
logger.warning(
@@ -231,90 +304,3 @@ def require_permission(permission: Permission):
return user_id
return permission_checker
def require_org_role(required_role_name: str):
"""
Factory function that creates a dependency to require a minimum org role.
This creates a FastAPI dependency that:
1. Extracts org_id from the path parameter
2. Gets the authenticated user_id
3. Checks if the user has the required role in the organization
4. Returns the user_id if authorized, raises HTTPException otherwise
Role hierarchy is based on rank from the Role class, where
lower rank = higher position (e.g., rank 1 > rank 2 > rank 3).
Usage:
@router.get('/{org_id}/resource')
async def get_resource(
org_id: UUID,
user_id: str = Depends(require_org_role('user')),
):
...
Args:
required_role_name: Name of the minimum required role to access the endpoint
Returns:
Dependency function that validates role and returns user_id
"""
async def role_checker(
org_id: UUID,
user_id: str | None = Depends(get_user_id),
) -> str:
if not user_id:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail='User not authenticated',
)
user_role = get_user_org_role(user_id, org_id)
if not user_role:
logger.warning(
'User not a member of organization',
extra={'user_id': user_id, 'org_id': str(org_id)},
)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail='User is not a member of this organization',
)
required_role = RoleStore.get_role_by_name(required_role_name)
if not required_role:
logger.error(
'Required role not found in database',
extra={'required_role': required_role_name},
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail='Role configuration error',
)
if not has_required_role(user_role, required_role):
logger.warning(
'Insufficient role permissions',
extra={
'user_id': user_id,
'org_id': str(org_id),
'user_role': user_role.name,
'required_role': required_role_name,
},
)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f'Requires {required_role_name} role or higher',
)
return user_id
return role_checker
# Convenience dependencies for common role checks
require_org_user = require_org_role('user')
require_org_admin = require_org_role('admin')
require_org_owner = require_org_role('owner')

View File

@@ -49,6 +49,10 @@ from openhands.integrations.service_types import ProviderType
from openhands.server.types import SessionExpiredError
from openhands.utils.http_session import httpx_verify_option
# HTTP timeout for external IDP calls (in seconds)
# This prevents indefinite blocking if an IDP is slow or unresponsive
IDP_HTTP_TIMEOUT = 15.0
def _before_sleep_callback(retry_state: RetryCallState) -> None:
logger.info(f'Retry attempt {retry_state.attempt_number} for Keycloak operation')
@@ -202,7 +206,9 @@ class TokenManager:
access_token: str,
idp: ProviderType,
) -> dict[str, str | int]:
async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
async with httpx.AsyncClient(
verify=httpx_verify_option(), timeout=IDP_HTTP_TIMEOUT
) as client:
base_url = KEYCLOAK_SERVER_URL_EXT if self.external else KEYCLOAK_SERVER_URL
url = f'{base_url}/realms/{KEYCLOAK_REALM_NAME}/broker/{idp.value}/token'
headers = {
@@ -361,7 +367,9 @@ class TokenManager:
'refresh_token': refresh_token,
'grant_type': 'refresh_token',
}
async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
async with httpx.AsyncClient(
verify=httpx_verify_option(), timeout=IDP_HTTP_TIMEOUT
) as client:
response = await client.post(url, data=payload)
response.raise_for_status()
logger.info('Successfully refreshed GitHub token')
@@ -387,7 +395,9 @@ class TokenManager:
'refresh_token': refresh_token,
'grant_type': 'refresh_token',
}
async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
async with httpx.AsyncClient(
verify=httpx_verify_option(), timeout=IDP_HTTP_TIMEOUT
) as client:
response = await client.post(url, data=payload)
response.raise_for_status()
logger.info('Successfully refreshed GitLab token')
@@ -415,7 +425,9 @@ class TokenManager:
'refresh_token': refresh_token,
}
async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
async with httpx.AsyncClient(
verify=httpx_verify_option(), timeout=IDP_HTTP_TIMEOUT
) as client:
response = await client.post(url, data=data, headers=headers)
response.raise_for_status()
logger.info('Successfully refreshed Bitbucket token')

View File

@@ -30,7 +30,9 @@ PERSONAL_WORKSPACE_VERSION_TO_MODEL = {
2: 'claude-3-7-sonnet-20250219',
3: 'claude-sonnet-4-20250514',
4: 'claude-sonnet-4-20250514',
5: 'claude-opus-4-5-20251101',
# Minimax is now the default as it gives results close to claude in terms of quality
# but at a much lower price
5: 'minimax-m2.5',
}
LITELLM_DEFAULT_MODEL = os.getenv('LITELLM_DEFAULT_MODEL')
@@ -59,7 +61,6 @@ SUBSCRIPTION_PRICE_DATA = {
},
}
DEFAULT_INITIAL_BUDGET = float(os.environ.get('DEFAULT_INITIAL_BUDGET', '10'))
STRIPE_API_KEY = os.environ.get('STRIPE_API_KEY', None)
REQUIRE_PAYMENT = os.environ.get('REQUIRE_PAYMENT', '0') in ('1', 'true')

View File

@@ -51,6 +51,14 @@ def custom_json_serializer(obj, **kwargs):
obj['stack_info'] = format_stack(stack_info)
result = json.dumps(obj, **kwargs)
# Swap out newlines to make things easier to read. This will produce
# invalid json but means we can have similar logs in local development
# to production, making things easier to correlate. Obviously,
# LOG_JSON_FOR_CONSOLE should not be used in production environments.
if LOG_JSON_FOR_CONSOLE:
result = result.replace('\\n', '\n')
return result

View File

@@ -160,10 +160,10 @@ class SetAuthCookieMiddleware:
'/api/billing/customer-setup-success',
'/api/billing/stripe-webhook',
'/api/email/resend',
'/api/organizations/members/invite/accept',
'/oauth/device/authorize',
'/oauth/device/token',
'/api/v1/web-client/config',
'/api/v1/webhooks/secrets',
)
if path in ignore_paths:
return False
@@ -174,6 +174,10 @@ class SetAuthCookieMiddleware:
):
return False
# Webhooks access is controlled using separate API keys
if path.startswith('/api/v1/webhooks/'):
return False
is_mcp = path.startswith('/mcp')
is_api_route = path.startswith('/api')
return is_api_route or is_mcp

View File

@@ -2,6 +2,7 @@ from datetime import UTC, datetime
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel, field_validator
from storage.api_key import ApiKey
from storage.api_key_store import ApiKeyStore
from storage.lite_llm_manager import LiteLlmManager
from storage.org_member import OrgMember
@@ -135,9 +136,9 @@ class ApiKeyCreate(BaseModel):
class ApiKeyResponse(BaseModel):
id: int
name: str | None = None
created_at: str
last_used_at: str | None = None
expires_at: str | None = None
created_at: datetime
last_used_at: datetime | None = None
expires_at: datetime | None = None
class ApiKeyCreateResponse(ApiKeyResponse):
@@ -152,12 +153,29 @@ class ByorPermittedResponse(BaseModel):
permitted: bool
@api_router.get('/llm/byor/permitted', response_model=ByorPermittedResponse)
async def check_byor_permitted(user_id: str = Depends(get_user_id)):
class MessageResponse(BaseModel):
message: str
def api_key_to_response(key: ApiKey) -> ApiKeyResponse:
"""Convert an ApiKey model to an ApiKeyResponse."""
return ApiKeyResponse(
id=key.id,
name=key.name,
created_at=key.created_at,
last_used_at=key.last_used_at,
expires_at=key.expires_at,
)
@api_router.get('/llm/byor/permitted', tags=['Keys'])
async def check_byor_permitted(
user_id: str = Depends(get_user_id),
) -> ByorPermittedResponse:
"""Check if BYOR key export is permitted for the user's current org."""
try:
permitted = await OrgService.check_byor_export_enabled(user_id)
return {'permitted': permitted}
return ByorPermittedResponse(permitted=permitted)
except Exception as e:
logger.exception(
'Error checking BYOR export permission', extra={'error': str(e)}
@@ -168,8 +186,10 @@ async def check_byor_permitted(user_id: str = Depends(get_user_id)):
)
@api_router.post('', response_model=ApiKeyCreateResponse)
async def create_api_key(key_data: ApiKeyCreate, user_id: str = Depends(get_user_id)):
@api_router.post('', tags=['Keys'])
async def create_api_key(
key_data: ApiKeyCreate, user_id: str = Depends(get_user_id)
) -> ApiKeyCreateResponse:
"""Create a new API key for the authenticated user."""
try:
api_key = await api_key_store.create_api_key(
@@ -178,48 +198,29 @@ async def create_api_key(key_data: ApiKeyCreate, user_id: str = Depends(get_user
# Get the created key details
keys = await api_key_store.list_api_keys(user_id)
for key in keys:
if key['name'] == key_data.name:
return {
**key,
'key': api_key,
'created_at': (
key['created_at'].isoformat() if key['created_at'] else None
),
'last_used_at': (
key['last_used_at'].isoformat() if key['last_used_at'] else None
),
'expires_at': (
key['expires_at'].isoformat() if key['expires_at'] else None
),
}
if key.name == key_data.name:
return ApiKeyCreateResponse(
id=key.id,
name=key.name,
key=api_key,
created_at=key.created_at,
last_used_at=key.last_used_at,
expires_at=key.expires_at,
)
except Exception:
logger.exception('Error creating API key')
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail='Failed to create API key',
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail='Failed to create API key',
)
@api_router.get('', response_model=list[ApiKeyResponse])
async def list_api_keys(user_id: str = Depends(get_user_id)):
@api_router.get('', tags=['Keys'])
async def list_api_keys(user_id: str = Depends(get_user_id)) -> list[ApiKeyResponse]:
"""List all API keys for the authenticated user."""
try:
keys = await api_key_store.list_api_keys(user_id)
return [
{
**key,
'created_at': (
key['created_at'].isoformat() if key['created_at'] else None
),
'last_used_at': (
key['last_used_at'].isoformat() if key['last_used_at'] else None
),
'expires_at': (
key['expires_at'].isoformat() if key['expires_at'] else None
),
}
for key in keys
]
return [api_key_to_response(key) for key in keys]
except Exception:
logger.exception('Error listing API keys')
raise HTTPException(
@@ -228,8 +229,10 @@ async def list_api_keys(user_id: str = Depends(get_user_id)):
)
@api_router.delete('/{key_id}')
async def delete_api_key(key_id: int, user_id: str = Depends(get_user_id)):
@api_router.delete('/{key_id}', tags=['Keys'])
async def delete_api_key(
key_id: int, user_id: str = Depends(get_user_id)
) -> MessageResponse:
"""Delete an API key."""
try:
# First, verify the key belongs to the user
@@ -237,7 +240,7 @@ async def delete_api_key(key_id: int, user_id: str = Depends(get_user_id)):
key_to_delete = None
for key in keys:
if key['id'] == key_id:
if key.id == key_id:
key_to_delete = key
break
@@ -255,7 +258,7 @@ async def delete_api_key(key_id: int, user_id: str = Depends(get_user_id)):
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail='Failed to delete API key',
)
return {'message': 'API key deleted successfully'}
return MessageResponse(message='API key deleted successfully')
except HTTPException:
raise
except Exception:
@@ -266,8 +269,10 @@ async def delete_api_key(key_id: int, user_id: str = Depends(get_user_id)):
)
@api_router.get('/llm/byor', response_model=LlmApiKeyResponse)
async def get_llm_api_key_for_byor(user_id: str = Depends(get_user_id)):
@api_router.get('/llm/byor', tags=['Keys'])
async def get_llm_api_key_for_byor(
user_id: str = Depends(get_user_id),
) -> LlmApiKeyResponse:
"""Get the LLM API key for BYOR (Bring Your Own Runtime) for the authenticated user.
This endpoint validates that the key exists in LiteLLM before returning it.
@@ -290,7 +295,7 @@ async def get_llm_api_key_for_byor(user_id: str = Depends(get_user_id)):
# Validate that the key is actually registered in LiteLLM
is_valid = await LiteLlmManager.verify_key(byor_key, user_id)
if is_valid:
return {'key': byor_key}
return LlmApiKeyResponse(key=byor_key)
else:
# Key exists in DB but is invalid in LiteLLM - regenerate it
logger.warning(
@@ -315,7 +320,7 @@ async def get_llm_api_key_for_byor(user_id: str = Depends(get_user_id)):
'Successfully generated and stored new BYOR key',
extra={'user_id': user_id},
)
return {'key': key}
return LlmApiKeyResponse(key=key)
else:
logger.error(
'Failed to generate new BYOR LLM API key',
@@ -337,8 +342,10 @@ async def get_llm_api_key_for_byor(user_id: str = Depends(get_user_id)):
)
@api_router.post('/llm/byor/refresh', response_model=LlmApiKeyResponse)
async def refresh_llm_api_key_for_byor(user_id: str = Depends(get_user_id)):
@api_router.post('/llm/byor/refresh', tags=['Keys'])
async def refresh_llm_api_key_for_byor(
user_id: str = Depends(get_user_id),
) -> LlmApiKeyResponse:
"""Refresh the LLM API key for BYOR (Bring Your Own Runtime) for the authenticated user.
Returns 402 Payment Required if BYOR export is not enabled for the user's org.
@@ -391,7 +398,7 @@ async def refresh_llm_api_key_for_byor(user_id: str = Depends(get_user_id)):
'BYOR LLM API key refresh completed successfully',
extra={'user_id': user_id},
)
return {'key': key}
return LlmApiKeyResponse(key=key)
except HTTPException as he:
logger.error(
'HTTP exception during BYOR LLM API key refresh',

View File

@@ -5,6 +5,7 @@ import warnings
from datetime import datetime, timezone
from typing import Annotated, Literal, Optional
from urllib.parse import quote
from uuid import UUID as parse_uuid
import posthog
from fastapi import APIRouter, Header, HTTPException, Request, Response, status
@@ -26,6 +27,13 @@ from server.auth.token_manager import TokenManager
from server.config import sign_token
from server.constants import IS_FEATURE_ENV
from server.routes.event_webhook import _get_session_api_key, _get_user_id
from server.services.org_invitation_service import (
EmailMismatchError,
InvitationExpiredError,
InvitationInvalidError,
OrgInvitationService,
UserAlreadyMemberError,
)
from storage.database import session_maker
from storage.user import User
from storage.user_store import UserStore
@@ -104,22 +112,40 @@ def get_cookie_samesite(request: Request) -> Literal['lax', 'strict']:
)
def _extract_oauth_state(state: str | None) -> tuple[str, str | None, str | None]:
"""Extract redirect URL, reCAPTCHA token, and invitation token from OAuth state.
Returns:
Tuple of (redirect_url, recaptcha_token, invitation_token).
Tokens may be None.
"""
if not state:
return '', None, None
try:
# Try to decode as JSON (new format with reCAPTCHA and/or invitation)
state_data = json.loads(base64.urlsafe_b64decode(state.encode()).decode())
return (
state_data.get('redirect_url', ''),
state_data.get('recaptcha_token'),
state_data.get('invitation_token'),
)
except Exception:
# Old format - state is just the redirect URL
return state, None, None
# Keep alias for backward compatibility
def _extract_recaptcha_state(state: str | None) -> tuple[str, str | None]:
"""Extract redirect URL and reCAPTCHA token from OAuth state.
Deprecated: Use _extract_oauth_state instead.
Returns:
Tuple of (redirect_url, recaptcha_token). Token may be None.
"""
if not state:
return '', None
try:
# Try to decode as JSON (new format with reCAPTCHA)
state_data = json.loads(base64.urlsafe_b64decode(state.encode()).decode())
return state_data.get('redirect_url', ''), state_data.get('recaptcha_token')
except Exception:
# Old format - state is just the redirect URL
return state, None
redirect_url, recaptcha_token, _ = _extract_oauth_state(state)
return redirect_url, recaptcha_token
@oauth_router.get('/keycloak/callback')
@@ -130,8 +156,8 @@ async def keycloak_callback(
error: Optional[str] = None,
error_description: Optional[str] = None,
):
# Extract redirect URL and reCAPTCHA token from state
redirect_url, recaptcha_token = _extract_recaptcha_state(state)
# Extract redirect URL, reCAPTCHA token, and invitation token from state
redirect_url, recaptcha_token, invitation_token = _extract_oauth_state(state)
if not redirect_url:
redirect_url = str(request.base_url)
@@ -302,8 +328,13 @@ async def keycloak_callback(
from server.routes.email import verify_email
await verify_email(request=request, user_id=user_id, is_auth_flow=True)
redirect_url = f'{request.base_url}login?email_verification_required=true&user_id={user_id}'
response = RedirectResponse(redirect_url, status_code=302)
verification_redirect_url = f'{request.base_url}login?email_verification_required=true&user_id={user_id}'
# Preserve invitation token so it can be included in OAuth state after verification
if invitation_token:
verification_redirect_url = (
f'{verification_redirect_url}&invitation_token={invitation_token}'
)
response = RedirectResponse(verification_redirect_url, status_code=302)
return response
# default to github IDP for now.
@@ -381,14 +412,90 @@ async def keycloak_callback(
)
has_accepted_tos = user.accepted_tos is not None
# Process invitation token if present (after email verification but before TOS)
if invitation_token:
try:
logger.info(
'Processing invitation token during auth callback',
extra={
'user_id': user_id,
'invitation_token_prefix': invitation_token[:10] + '...',
},
)
await OrgInvitationService.accept_invitation(
invitation_token, parse_uuid(user_id)
)
logger.info(
'Invitation accepted during auth callback',
extra={'user_id': user_id},
)
except InvitationExpiredError:
logger.warning(
'Invitation expired during auth callback',
extra={'user_id': user_id},
)
# Add query param to redirect URL
if '?' in redirect_url:
redirect_url = f'{redirect_url}&invitation_expired=true'
else:
redirect_url = f'{redirect_url}?invitation_expired=true'
except InvitationInvalidError as e:
logger.warning(
'Invalid invitation during auth callback',
extra={'user_id': user_id, 'error': str(e)},
)
if '?' in redirect_url:
redirect_url = f'{redirect_url}&invitation_invalid=true'
else:
redirect_url = f'{redirect_url}?invitation_invalid=true'
except UserAlreadyMemberError:
logger.info(
'User already member during invitation acceptance',
extra={'user_id': user_id},
)
if '?' in redirect_url:
redirect_url = f'{redirect_url}&already_member=true'
else:
redirect_url = f'{redirect_url}?already_member=true'
except EmailMismatchError as e:
logger.warning(
'Email mismatch during auth callback invitation acceptance',
extra={'user_id': user_id, 'error': str(e)},
)
if '?' in redirect_url:
redirect_url = f'{redirect_url}&email_mismatch=true'
else:
redirect_url = f'{redirect_url}?email_mismatch=true'
except Exception as e:
logger.exception(
'Unexpected error processing invitation during auth callback',
extra={'user_id': user_id, 'error': str(e)},
)
# Don't fail the login if invitation processing fails
if '?' in redirect_url:
redirect_url = f'{redirect_url}&invitation_error=true'
else:
redirect_url = f'{redirect_url}?invitation_error=true'
# If the user hasn't accepted the TOS, redirect to the TOS page
if not has_accepted_tos:
encoded_redirect_url = quote(redirect_url, safe='')
tos_redirect_url = (
f'{request.base_url}accept-tos?redirect_url={encoded_redirect_url}'
)
if invitation_token:
tos_redirect_url = f'{tos_redirect_url}&invitation_success=true'
response = RedirectResponse(tos_redirect_url, status_code=302)
else:
if invitation_token:
redirect_url = f'{redirect_url}&invitation_success=true'
response = RedirectResponse(redirect_url, status_code=302)
set_response_cookie(

View File

@@ -9,9 +9,7 @@ from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi.responses import RedirectResponse
from integrations import stripe_service
from pydantic import BaseModel
from server.constants import (
STRIPE_API_KEY,
)
from server.constants import STRIPE_API_KEY
from server.logger import logger
from starlette.datastructures import URL
from storage.billing_session import BillingSession
@@ -95,9 +93,9 @@ async def get_credits(user_id: str = Depends(get_user_id)) -> GetCreditsResponse
user_team_info = await LiteLlmManager.get_user_team_info(
user_id, str(user.current_org_id)
)
# Update to use calculate_credits
spend = user_team_info.get('spend', 0)
max_budget = (user_team_info.get('litellm_budget_table') or {}).get('max_budget', 0)
max_budget, spend = LiteLlmManager.get_budget_from_team_info(
user_team_info, user_id, str(user.current_org_id)
)
credits = max(max_budget - spend, 0)
return GetCreditsResponse(credits=Decimal('{:.2f}'.format(credits)))
@@ -149,7 +147,7 @@ async def create_customer_setup_session(
customer=customer_info['customer_id'],
mode='setup',
payment_method_types=['card'],
success_url=f'{base_url}?free_credits=success',
success_url=f'{base_url}?setup=success',
cancel_url=f'{base_url}',
)
return CreateBillingSessionResponse(redirect_url=checkout_session.url)
@@ -251,9 +249,11 @@ async def success_callback(session_id: str, request: Request):
)
amount_subtotal = stripe_session.amount_subtotal or 0
add_credits = amount_subtotal / 100
max_budget = (user_team_info.get('litellm_budget_table') or {}).get(
'max_budget', 0
max_budget, _ = LiteLlmManager.get_budget_from_team_info(
user_team_info, billing_session.user_id, str(user.current_org_id)
)
org = session.query(Org).filter(Org.id == user.current_org_id).first()
new_max_budget = max_budget + add_credits
await LiteLlmManager.update_team_and_users_budget(
@@ -261,8 +261,6 @@ async def success_callback(session_id: str, request: Request):
)
# Enable BYOR export for the org now that they've purchased credits
# Update within the same session to avoid nested session issues
org = session.query(Org).filter(Org.id == user.current_org_id).first()
if org:
org.byor_export_enabled = True

View File

@@ -4,7 +4,7 @@ import json
import os
import re
import uuid
from urllib.parse import urlparse
from urllib.parse import urlencode, urlparse
import requests
from fastapi import APIRouter, BackgroundTasks, Header, HTTPException, Request, status
@@ -371,9 +371,7 @@ async def create_jira_workspace(request: Request, workspace_data: JiraWorkspaceC
'prompt': 'consent',
}
auth_url = (
f"{JIRA_AUTH_URL}?{'&'.join([f'{k}={v}' for k, v in auth_params.items()])}"
)
auth_url = f'{JIRA_AUTH_URL}?{urlencode(auth_params)}'
return JSONResponse(
content={
@@ -432,9 +430,7 @@ async def create_workspace_link(request: Request, link_data: JiraLinkCreate):
'response_type': 'code',
'prompt': 'consent',
}
auth_url = (
f"{JIRA_AUTH_URL}?{'&'.join([f'{k}={v}' for k, v in auth_params.items()])}"
)
auth_url = f'{JIRA_AUTH_URL}?{urlencode(auth_params)}'
return JSONResponse(
content={

View File

@@ -2,7 +2,7 @@ import json
import os
import re
import uuid
from urllib.parse import urlparse
from urllib.parse import urlencode, urlparse
import requests
from fastapi import (
@@ -316,7 +316,7 @@ async def create_jira_dc_workspace(
'response_type': 'code',
}
auth_url = f"{JIRA_DC_AUTH_URL}?{'&'.join([f'{k}={v}' for k, v in auth_params.items()])}"
auth_url = f'{JIRA_DC_AUTH_URL}?{urlencode(auth_params)}'
return JSONResponse(
content={
@@ -436,7 +436,7 @@ async def create_workspace_link(request: Request, link_data: JiraDcLinkCreate):
'state': state,
'response_type': 'code',
}
auth_url = f"{JIRA_DC_AUTH_URL}?{'&'.join([f'{k}={v}' for k, v in auth_params.items()])}"
auth_url = f'{JIRA_DC_AUTH_URL}?{urlencode(auth_params)}'
return JSONResponse(
content={

View File

@@ -0,0 +1,122 @@
"""
Pydantic models and custom exceptions for organization invitations.
"""
from pydantic import BaseModel, EmailStr
from storage.org_invitation import OrgInvitation
from storage.role_store import RoleStore
class InvitationError(Exception):
"""Base exception for invitation errors."""
pass
class InvitationAlreadyExistsError(InvitationError):
"""Raised when a pending invitation already exists for the email."""
def __init__(
self, message: str = 'A pending invitation already exists for this email'
):
super().__init__(message)
class UserAlreadyMemberError(InvitationError):
"""Raised when the user is already a member of the organization."""
def __init__(self, message: str = 'User is already a member of this organization'):
super().__init__(message)
class InvitationExpiredError(InvitationError):
"""Raised when the invitation has expired."""
def __init__(self, message: str = 'Invitation has expired'):
super().__init__(message)
class InvitationInvalidError(InvitationError):
"""Raised when the invitation is invalid or revoked."""
def __init__(self, message: str = 'Invitation is no longer valid'):
super().__init__(message)
class InsufficientPermissionError(InvitationError):
"""Raised when the user lacks permission to perform the action."""
def __init__(self, message: str = 'Insufficient permission'):
super().__init__(message)
class EmailMismatchError(InvitationError):
"""Raised when the accepting user's email doesn't match the invitation email."""
def __init__(self, message: str = 'Your email does not match the invitation'):
super().__init__(message)
class InvitationCreate(BaseModel):
"""Request model for creating invitation(s)."""
emails: list[EmailStr]
role: str = 'member' # Default to member role
class InvitationResponse(BaseModel):
"""Response model for invitation details."""
id: int
email: str
role: str
status: str
created_at: str
expires_at: str
inviter_email: str | None = None
@classmethod
def from_invitation(
cls,
invitation: OrgInvitation,
inviter_email: str | None = None,
) -> 'InvitationResponse':
"""Create an InvitationResponse from an OrgInvitation entity.
Args:
invitation: The invitation entity to convert
inviter_email: Optional email of the inviter
Returns:
InvitationResponse: The response model instance
"""
role_name = ''
if invitation.role:
role_name = invitation.role.name
elif invitation.role_id:
role = RoleStore.get_role_by_id(invitation.role_id)
role_name = role.name if role else ''
return cls(
id=invitation.id,
email=invitation.email,
role=role_name,
status=invitation.status,
created_at=invitation.created_at.isoformat(),
expires_at=invitation.expires_at.isoformat(),
inviter_email=inviter_email,
)
class InvitationFailure(BaseModel):
"""Response model for a failed invitation."""
email: str
error: str
class BatchInvitationResponse(BaseModel):
"""Response model for batch invitation creation."""
successful: list[InvitationResponse]
failed: list[InvitationFailure]

View File

@@ -0,0 +1,226 @@
"""API routes for organization invitations."""
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi.responses import RedirectResponse
from server.routes.org_invitation_models import (
BatchInvitationResponse,
EmailMismatchError,
InsufficientPermissionError,
InvitationCreate,
InvitationExpiredError,
InvitationFailure,
InvitationInvalidError,
InvitationResponse,
UserAlreadyMemberError,
)
from server.services.org_invitation_service import OrgInvitationService
from server.utils.rate_limit_utils import check_rate_limit_by_user_id
from openhands.core.logger import openhands_logger as logger
from openhands.server.user_auth import get_user_id
from openhands.server.user_auth.user_auth import get_user_auth
# Router for invitation operations on an organization (requires org_id)
invitation_router = APIRouter(prefix='/api/organizations/{org_id}/members')
# Router for accepting invitations (no org_id required)
accept_router = APIRouter(prefix='/api/organizations/members/invite')
@invitation_router.post(
'/invite',
response_model=BatchInvitationResponse,
status_code=status.HTTP_201_CREATED,
)
async def create_invitation(
org_id: UUID,
invitation_data: InvitationCreate,
request: Request,
user_id: str = Depends(get_user_id),
):
"""Create organization invitations for multiple email addresses.
Sends emails to invitees with secure links to join the organization.
Supports batch invitations - some may succeed while others fail.
Permission rules:
- Only owners and admins can create invitations
- Admins can only invite with 'member' or 'admin' role (not 'owner')
- Owners can invite with any role
Args:
org_id: Organization UUID
invitation_data: Invitation details (emails array, role)
request: FastAPI request
user_id: Authenticated user ID (from dependency)
Returns:
BatchInvitationResponse: Lists of successful and failed invitations
Raises:
HTTPException 400: Invalid role or organization not found
HTTPException 403: User lacks permission to invite
HTTPException 429: Rate limit exceeded
"""
# Rate limit: 10 invitations per minute per user (6 seconds between requests)
await check_rate_limit_by_user_id(
request=request,
key_prefix='org_invitation_create',
user_id=user_id,
user_rate_limit_seconds=6,
)
try:
successful, failed = await OrgInvitationService.create_invitations_batch(
org_id=org_id,
emails=[str(email) for email in invitation_data.emails],
role_name=invitation_data.role,
inviter_id=UUID(user_id),
)
logger.info(
'Batch organization invitations created',
extra={
'org_id': str(org_id),
'total_emails': len(invitation_data.emails),
'successful': len(successful),
'failed': len(failed),
'inviter_id': user_id,
},
)
return BatchInvitationResponse(
successful=[InvitationResponse.from_invitation(inv) for inv in successful],
failed=[
InvitationFailure(email=email, error=error) for email, error in failed
],
)
except InsufficientPermissionError as e:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=str(e),
)
except ValueError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e),
)
except Exception as e:
logger.exception(
'Unexpected error creating batch invitations',
extra={'org_id': str(org_id), 'error': str(e)},
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail='An unexpected error occurred',
)
@accept_router.get('/accept')
async def accept_invitation(
token: str,
request: Request,
):
"""Accept an organization invitation via token.
This endpoint is accessed via the link in the invitation email.
Flow:
1. If user is authenticated: Accept invitation directly and redirect to home
2. If user is not authenticated: Redirect to login page with invitation token
- Frontend stores token and includes it in OAuth state during login
- After authentication, keycloak_callback processes the invitation
Args:
token: The invitation token from the email link
request: FastAPI request
Returns:
RedirectResponse: Redirect to home page on success, or login page if not authenticated,
or home page with error query params on failure
"""
base_url = str(request.base_url).rstrip('/')
# Try to get user_id from auth (may not be authenticated)
user_id = None
try:
user_auth = await get_user_auth(request)
if user_auth:
user_id = await user_auth.get_user_id()
except Exception:
pass
if not user_id:
# User not authenticated - redirect to login page with invitation token
# Frontend will store the token and include it in OAuth state during login
logger.info(
'Invitation accept: redirecting unauthenticated user to login',
extra={'token_prefix': token[:10] + '...'},
)
login_url = f'{base_url}/login?invitation_token={token}'
return RedirectResponse(login_url, status_code=302)
# User is authenticated - process the invitation directly
try:
await OrgInvitationService.accept_invitation(token, UUID(user_id))
logger.info(
'Invitation accepted successfully',
extra={
'token_prefix': token[:10] + '...',
'user_id': user_id,
},
)
# Redirect to home page on success
return RedirectResponse(f'{base_url}/', status_code=302)
except InvitationExpiredError:
logger.warning(
'Invitation accept failed: expired',
extra={'token_prefix': token[:10] + '...', 'user_id': user_id},
)
return RedirectResponse(f'{base_url}/?invitation_expired=true', status_code=302)
except InvitationInvalidError as e:
logger.warning(
'Invitation accept failed: invalid',
extra={
'token_prefix': token[:10] + '...',
'user_id': user_id,
'error': str(e),
},
)
return RedirectResponse(f'{base_url}/?invitation_invalid=true', status_code=302)
except UserAlreadyMemberError:
logger.info(
'Invitation accept: user already member',
extra={'token_prefix': token[:10] + '...', 'user_id': user_id},
)
return RedirectResponse(f'{base_url}/?already_member=true', status_code=302)
except EmailMismatchError as e:
logger.warning(
'Invitation accept failed: email mismatch',
extra={
'token_prefix': token[:10] + '...',
'user_id': user_id,
'error': str(e),
},
)
return RedirectResponse(f'{base_url}/?email_mismatch=true', status_code=302)
except Exception as e:
logger.exception(
'Unexpected error accepting invitation',
extra={
'token_prefix': token[:10] + '...',
'user_id': user_id,
'error': str(e),
},
)
return RedirectResponse(f'{base_url}/?invitation_error=true', status_code=302)

View File

@@ -214,6 +214,7 @@ class OrgPage(BaseModel):
items: list[OrgResponse]
next_page_id: str | None = None
current_org_id: str | None = None
class OrgUpdate(BaseModel):
@@ -257,7 +258,7 @@ class OrgMemberResponse(BaseModel):
user_id: str
email: str | None
role_id: int
role_name: str
role: str
role_rank: int
status: str | None
@@ -266,7 +267,8 @@ class OrgMemberPage(BaseModel):
"""Paginated response for organization members."""
items: list[OrgMemberResponse]
next_page_id: str | None = None
current_page: int = 1
per_page: int = 10
class OrgMemberUpdate(BaseModel):

View File

@@ -3,9 +3,8 @@ from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, Query, status
from server.auth.authorization import (
require_org_admin,
require_org_owner,
require_org_user,
Permission,
require_permission,
)
from server.email_validation import get_admin_user_id
from server.routes.org_models import (
@@ -33,6 +32,7 @@ from server.routes.org_models import (
)
from server.services.org_member_service import OrgMemberService
from storage.org_service import OrgService
from storage.user_store import UserStore
from openhands.core.logger import openhands_logger as logger
from openhands.server.user_auth import get_user_id
@@ -79,6 +79,12 @@ async def list_user_orgs(
)
try:
# Fetch user to get current_org_id
user = await UserStore.get_user_by_id_async(user_id)
current_org_id = (
str(user.current_org_id) if user and user.current_org_id else None
)
# Fetch organizations from service layer
orgs, next_page_id = OrgService.get_user_orgs_paginated(
user_id=user_id,
@@ -100,7 +106,11 @@ async def list_user_orgs(
},
)
return OrgPage(items=org_responses, next_page_id=next_page_id)
return OrgPage(
items=org_responses,
next_page_id=next_page_id,
current_org_id=current_org_id,
)
except Exception as e:
logger.exception(
@@ -194,26 +204,26 @@ async def create_org(
@org_router.get('/{org_id}', response_model=OrgResponse, status_code=status.HTTP_200_OK)
async def get_org(
org_id: UUID,
user_id: str = Depends(require_org_user),
user_id: str = Depends(require_permission(Permission.VIEW_ORG_SETTINGS)),
) -> OrgResponse:
"""Get organization details by ID.
This endpoint allows authenticated users who are members of an organization
to retrieve its details. Only members of the organization can access this endpoint.
Requires user, admin, or owner role.
This endpoint retrieves details for a specific organization. Access requires
the VIEW_ORG_SETTINGS permission, which is granted to all organization members
(member, admin, and owner roles).
Args:
org_id: Organization ID (UUID)
user_id: Authenticated user ID (injected by dependency, requires org membership)
user_id: Authenticated user ID (injected by require_permission dependency)
Returns:
OrgResponse: The organization details
Raises:
HTTPException: 401 if user is not authenticated
HTTPException: 403 if user is not a member of the organization
HTTPException: 422 if org_id is not a valid UUID (handled by FastAPI)
HTTPException: 403 if user lacks VIEW_ORG_SETTINGS permission
HTTPException: 404 if organization not found
HTTPException: 422 if org_id is not a valid UUID (handled by FastAPI)
HTTPException: 500 if retrieval fails
"""
logger.info(
@@ -313,25 +323,24 @@ async def get_me(
@org_router.delete('/{org_id}', status_code=status.HTTP_200_OK)
async def delete_org(
org_id: UUID,
user_id: str = Depends(require_org_owner),
user_id: str = Depends(require_permission(Permission.DELETE_ORGANIZATION)),
) -> dict:
"""Delete an organization.
This endpoint allows authenticated organization owners to delete their organization.
All associated data including organization members, conversations, billing data,
and external LiteLLM team resources will be permanently removed.
Requires owner role.
This endpoint permanently deletes an organization and all associated data including
organization members, conversations, billing data, and external LiteLLM team resources.
Access requires the DELETE_ORGANIZATION permission, which is granted only to owners.
Args:
org_id: Organization ID to delete
user_id: Authenticated user ID (injected by dependency, requires owner role)
org_id: Organization ID to delete (UUID)
user_id: Authenticated user ID (injected by require_permission dependency)
Returns:
dict: Confirmation message with deleted organization details
Raises:
HTTPException: 401 if user is not authenticated
HTTPException: 403 if user is not an owner of the organization
HTTPException: 403 if user lacks DELETE_ORGANIZATION permission
HTTPException: 404 if organization not found
HTTPException: 500 if deletion fails
"""
@@ -424,25 +433,26 @@ async def delete_org(
async def update_org(
org_id: UUID,
update_data: OrgUpdate,
user_id: str = Depends(require_org_admin),
user_id: str = Depends(require_permission(Permission.EDIT_ORG_SETTINGS)),
) -> OrgResponse:
"""Update an existing organization.
This endpoint allows authenticated admins and owners to update organization settings.
Requires admin or owner role in the organization.
This endpoint updates organization settings. Access requires the EDIT_ORG_SETTINGS
permission, which is granted to admin and owner roles.
Args:
org_id: Organization ID to update (UUID validated by FastAPI)
org_id: Organization ID to update (UUID)
update_data: Organization update data
user_id: Authenticated user ID (injected by dependency, requires admin role)
user_id: Authenticated user ID (injected by require_permission dependency)
Returns:
OrgResponse: The updated organization details
Raises:
HTTPException: 401 if user is not authenticated
HTTPException: 403 if user is not an admin or owner of the organization
HTTPException: 403 if user lacks EDIT_ORG_SETTINGS permission
HTTPException: 404 if organization not found
HTTPException: 409 if organization name already exists
HTTPException: 422 if validation errors occur (handled by FastAPI)
HTTPException: 500 if update fails
"""
@@ -506,10 +516,10 @@ async def update_org(
@org_router.get('/{org_id}/members')
async def get_org_members(
org_id: str,
org_id: UUID,
page_id: Annotated[
str | None,
Query(title='Optional next_page_id from the previously returned page'),
Query(title='Optional page offset for pagination'),
] = None,
limit: Annotated[
int,
@@ -518,16 +528,48 @@ async def get_org_members(
gt=0,
lte=100,
),
] = 100,
current_user_id: str = Depends(get_user_id),
] = 10,
email: Annotated[
str | None,
Query(
title='Filter members by email (case-insensitive partial match)',
min_length=1,
max_length=255,
),
] = None,
user_id: str = Depends(require_permission(Permission.VIEW_ORG_SETTINGS)),
) -> OrgMemberPage:
"""Get all members of an organization with cursor-based pagination."""
"""Get all members of an organization with pagination and optional email filter.
This endpoint retrieves a paginated list of organization members. Access requires
the VIEW_ORG_SETTINGS permission, which is granted to all organization members
(member, admin, and owner roles).
Args:
org_id: Organization ID (UUID)
page_id: Optional page offset for pagination
limit: Maximum number of members to return (1-100, default 10)
email: Optional email filter (case-insensitive partial match)
user_id: Authenticated user ID (injected by require_permission dependency)
Returns:
OrgMemberPage: Paginated list of organization members with
current_page and per_page metadata. Use the /count endpoint
to get the total count separately.
Raises:
HTTPException: 401 if user is not authenticated
HTTPException: 403 if user lacks VIEW_ORG_SETTINGS permission
HTTPException: 400 if org_id or page_id format is invalid
HTTPException: 500 if retrieval fails
"""
try:
success, error_code, data = await OrgMemberService.get_org_members(
org_id=UUID(org_id),
current_user_id=UUID(current_user_id),
org_id=org_id,
current_user_id=UUID(user_id),
page_id=page_id,
limit=limit,
email_filter=email,
)
if not success:
@@ -570,9 +612,67 @@ async def get_org_members(
)
@org_router.get('/{org_id}/members/count')
async def get_org_members_count(
org_id: UUID,
email: Annotated[
str | None,
Query(
title='Filter members by email (case-insensitive partial match)',
min_length=1,
max_length=255,
),
] = None,
user_id: str = Depends(require_permission(Permission.VIEW_ORG_SETTINGS)),
) -> int:
"""Get count of organization members with optional email filter.
This endpoint returns the total count of organization members matching
the filter criteria. Access requires the VIEW_ORG_SETTINGS permission,
which is granted to all organization members (member, admin, and owner roles).
Args:
org_id: Organization ID (UUID)
email: Optional email filter (case-insensitive partial match)
user_id: Authenticated user ID (injected by require_permission dependency)
Returns:
int: Total count of organization members matching the filter
Raises:
HTTPException: 401 if user is not authenticated
HTTPException: 403 if user lacks VIEW_ORG_SETTINGS permission or is not a member
HTTPException: 400 if org_id format is invalid
HTTPException: 500 if retrieval fails
"""
try:
return await OrgMemberService.get_org_members_count(
org_id=org_id,
current_user_id=UUID(user_id),
email_filter=email,
)
except OrgMemberNotFoundError:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail='You are not a member of this organization',
)
except ValueError:
logger.exception('Invalid UUID format')
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail='Invalid organization ID format',
)
except Exception:
logger.exception('Error retrieving organization member count')
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail='Failed to retrieve member count',
)
@org_router.delete('/{org_id}/members/{user_id}')
async def remove_org_member(
org_id: str,
org_id: UUID,
user_id: str,
current_user_id: str = Depends(get_user_id),
):
@@ -586,7 +686,7 @@ async def remove_org_member(
"""
try:
success, error = await OrgMemberService.remove_org_member(
org_id=UUID(org_id),
org_id=org_id,
target_user_id=UUID(user_id),
current_user_id=UUID(current_user_id),
)
@@ -718,7 +818,7 @@ async def switch_org(
@org_router.patch('/{org_id}/members/{user_id}', response_model=OrgMemberResponse)
async def update_org_member(
org_id: str,
org_id: UUID,
user_id: str,
update_data: OrgMemberUpdate,
current_user_id: str = Depends(get_user_id),
@@ -735,7 +835,7 @@ async def update_org_member(
"""
try:
return await OrgMemberService.update_org_member(
org_id=UUID(org_id),
org_id=org_id,
target_user_id=UUID(user_id),
current_user_id=UUID(current_user_id),
update_data=update_data,

View File

@@ -0,0 +1,131 @@
"""Email service for sending transactional emails via Resend."""
import os
try:
import resend
RESEND_AVAILABLE = True
except ImportError:
RESEND_AVAILABLE = False
from openhands.core.logger import openhands_logger as logger
DEFAULT_FROM_EMAIL = 'OpenHands <no-reply@openhands.dev>'
DEFAULT_WEB_HOST = 'https://app.all-hands.dev'
class EmailService:
"""Service for sending transactional emails."""
@staticmethod
def _get_resend_client() -> bool:
"""Initialize and return the Resend client.
Returns:
bool: True if client is ready, False otherwise
"""
if not RESEND_AVAILABLE:
logger.warning('Resend library not installed, skipping email')
return False
resend_api_key = os.environ.get('RESEND_API_KEY')
if not resend_api_key:
logger.warning('RESEND_API_KEY not configured, skipping email')
return False
resend.api_key = resend_api_key
return True
@staticmethod
def send_invitation_email(
to_email: str,
org_name: str,
inviter_name: str,
role_name: str,
invitation_token: str,
invitation_id: int,
) -> None:
"""Send an organization invitation email.
Args:
to_email: Recipient's email address
org_name: Name of the organization
inviter_name: Display name of the person who sent the invite
role_name: Role being offered (e.g., 'member', 'admin')
invitation_token: The secure invitation token
invitation_id: The invitation ID for logging
"""
if not EmailService._get_resend_client():
return
# Build invitation URL
web_host = os.environ.get('WEB_HOST', DEFAULT_WEB_HOST)
invitation_url = f'{web_host}/api/organizations/members/invite/accept?token={invitation_token}'
from_email = os.environ.get('RESEND_FROM_EMAIL', DEFAULT_FROM_EMAIL)
params = {
'from': from_email,
'to': [to_email],
'subject': f"You're invited to join {org_name} on OpenHands",
'html': f"""
<div style="font-family: Arial, sans-serif; max-width: 600px; margin: 0 auto;">
<p>Hi,</p>
<p><strong>{inviter_name}</strong> has invited you to join <strong>{org_name}</strong> on OpenHands as a <strong>{role_name}</strong>.</p>
<p>Click the button below to accept the invitation:</p>
<p style="margin: 30px 0;">
<a href="{invitation_url}"
style="background-color: #c9b974; color: #0D0F11; padding: 8px 16px;
text-decoration: none; border-radius: 8px; display: inline-block;
font-size: 14px; font-weight: 600;">
Accept Invitation
</a>
</p>
<p style="color: #666; font-size: 14px;">
Or copy and paste this link into your browser:<br>
<a href="{invitation_url}" style="color: #c9b974; font-weight: 600;">{invitation_url}</a>
</p>
<p style="color: #666; font-size: 14px;">
This invitation will expire in 7 days.
</p>
<p style="color: #666; font-size: 14px;">
If you weren't expecting this invitation, you can safely ignore this email.
</p>
<hr style="border: none; border-top: 1px solid #eee; margin: 30px 0;">
<p style="color: #999; font-size: 12px;">
Best,<br>
The OpenHands Team
</p>
</div>
""",
}
try:
response = resend.Emails.send(params)
logger.info(
'Invitation email sent',
extra={
'invitation_id': invitation_id,
'email': to_email,
'response_id': response.get('id') if response else None,
},
)
except Exception as e:
logger.error(
'Failed to send invitation email',
extra={
'invitation_id': invitation_id,
'email': to_email,
'error': str(e),
},
)
raise

View File

@@ -0,0 +1,397 @@
"""Service for managing organization invitations."""
import asyncio
from uuid import UUID
from server.auth.token_manager import TokenManager
from server.constants import ROLE_ADMIN, ROLE_OWNER
from server.routes.org_invitation_models import (
EmailMismatchError,
InsufficientPermissionError,
InvitationExpiredError,
InvitationInvalidError,
UserAlreadyMemberError,
)
from server.services.email_service import EmailService
from storage.org_invitation import OrgInvitation
from storage.org_invitation_store import OrgInvitationStore
from storage.org_member_store import OrgMemberStore
from storage.org_service import OrgService
from storage.org_store import OrgStore
from storage.role_store import RoleStore
from storage.user_store import UserStore
from openhands.core.logger import openhands_logger as logger
class OrgInvitationService:
"""Service for organization invitation operations."""
@staticmethod
async def create_invitation(
org_id: UUID,
email: str,
role_name: str,
inviter_id: UUID,
) -> OrgInvitation:
"""Create a new organization invitation.
This method:
1. Validates the organization exists
2. Validates this is not a personal workspace
3. Checks inviter has owner/admin role
4. Validates role assignment permissions
5. Checks if user is already a member
6. Creates the invitation
7. Sends the invitation email
Args:
org_id: Organization UUID
email: Invitee's email address
role_name: Role to assign on acceptance (owner, admin, member)
inviter_id: User ID of the person creating the invitation
Returns:
OrgInvitation: The created invitation
Raises:
ValueError: If organization or role not found
InsufficientPermissionError: If inviter lacks permission
UserAlreadyMemberError: If email is already a member
InvitationAlreadyExistsError: If pending invitation exists
"""
email = email.lower().strip()
logger.info(
'Creating organization invitation',
extra={
'org_id': str(org_id),
'email': email,
'role_name': role_name,
'inviter_id': str(inviter_id),
},
)
# Step 1: Validate organization exists
org = OrgStore.get_org_by_id(org_id)
if not org:
raise ValueError(f'Organization {org_id} not found')
# Step 2: Check this is not a personal workspace
# A personal workspace has org_id matching the user's id
if str(org_id) == str(inviter_id):
raise InsufficientPermissionError(
'Cannot invite users to a personal workspace'
)
# Step 3: Check inviter is a member and has permission
inviter_member = OrgMemberStore.get_org_member(org_id, inviter_id)
if not inviter_member:
raise InsufficientPermissionError(
'You are not a member of this organization'
)
inviter_role = RoleStore.get_role_by_id(inviter_member.role_id)
if not inviter_role or inviter_role.name not in [ROLE_OWNER, ROLE_ADMIN]:
raise InsufficientPermissionError('Only owners and admins can invite users')
# Step 4: Validate role assignment permissions
role_name_lower = role_name.lower()
if role_name_lower == ROLE_OWNER and inviter_role.name != ROLE_OWNER:
raise InsufficientPermissionError('Only owners can invite with owner role')
# Get the target role
target_role = RoleStore.get_role_by_name(role_name_lower)
if not target_role:
raise ValueError(f'Invalid role: {role_name}')
# Step 5: Check if user is already a member (by email)
existing_user = await UserStore.get_user_by_email_async(email)
if existing_user:
existing_member = OrgMemberStore.get_org_member(org_id, existing_user.id)
if existing_member:
raise UserAlreadyMemberError(
'User is already a member of this organization'
)
# Step 6: Create the invitation
invitation = await OrgInvitationStore.create_invitation(
org_id=org_id,
email=email,
role_id=target_role.id,
inviter_id=inviter_id,
)
# Step 7: Send invitation email
try:
# Get inviter info for the email
inviter_user = UserStore.get_user_by_id(str(inviter_member.user_id))
inviter_name = 'A team member'
if inviter_user and inviter_user.email:
inviter_name = inviter_user.email.split('@')[0]
EmailService.send_invitation_email(
to_email=email,
org_name=org.name,
inviter_name=inviter_name,
role_name=target_role.name,
invitation_token=invitation.token,
invitation_id=invitation.id,
)
except Exception as e:
logger.error(
'Failed to send invitation email',
extra={
'invitation_id': invitation.id,
'email': email,
'error': str(e),
},
)
# Don't fail the invitation creation if email fails
# The user can still access via direct link
return invitation
@staticmethod
async def create_invitations_batch(
org_id: UUID,
emails: list[str],
role_name: str,
inviter_id: UUID,
) -> tuple[list[OrgInvitation], list[tuple[str, str]]]:
"""Create multiple organization invitations concurrently.
Validates permissions once upfront, then creates invitations in parallel.
Args:
org_id: Organization UUID
emails: List of invitee email addresses
role_name: Role to assign on acceptance (owner, admin, member)
inviter_id: User ID of the person creating the invitations
Returns:
Tuple of (successful_invitations, failed_emails_with_errors)
Raises:
ValueError: If organization or role not found
InsufficientPermissionError: If inviter lacks permission
"""
logger.info(
'Creating batch organization invitations',
extra={
'org_id': str(org_id),
'email_count': len(emails),
'role_name': role_name,
'inviter_id': str(inviter_id),
},
)
# Step 1: Validate permissions upfront (shared for all emails)
org = OrgStore.get_org_by_id(org_id)
if not org:
raise ValueError(f'Organization {org_id} not found')
if str(org_id) == str(inviter_id):
raise InsufficientPermissionError(
'Cannot invite users to a personal workspace'
)
inviter_member = OrgMemberStore.get_org_member(org_id, inviter_id)
if not inviter_member:
raise InsufficientPermissionError(
'You are not a member of this organization'
)
inviter_role = RoleStore.get_role_by_id(inviter_member.role_id)
if not inviter_role or inviter_role.name not in [ROLE_OWNER, ROLE_ADMIN]:
raise InsufficientPermissionError('Only owners and admins can invite users')
role_name_lower = role_name.lower()
if role_name_lower == ROLE_OWNER and inviter_role.name != ROLE_OWNER:
raise InsufficientPermissionError('Only owners can invite with owner role')
target_role = RoleStore.get_role_by_name(role_name_lower)
if not target_role:
raise ValueError(f'Invalid role: {role_name}')
# Step 2: Create invitations concurrently
async def create_single(
email: str,
) -> tuple[str, OrgInvitation | None, str | None]:
"""Create single invitation, return (email, invitation, error)."""
try:
invitation = await OrgInvitationService.create_invitation(
org_id=org_id,
email=email,
role_name=role_name,
inviter_id=inviter_id,
)
return (email, invitation, None)
except (UserAlreadyMemberError, ValueError) as e:
return (email, None, str(e))
results = await asyncio.gather(*[create_single(email) for email in emails])
# Step 3: Separate successes and failures
successful: list[OrgInvitation] = []
failed: list[tuple[str, str]] = []
for email, invitation, error in results:
if invitation:
successful.append(invitation)
elif error:
failed.append((email, error))
logger.info(
'Batch invitation creation completed',
extra={
'org_id': str(org_id),
'successful': len(successful),
'failed': len(failed),
},
)
return successful, failed
@staticmethod
async def accept_invitation(token: str, user_id: UUID) -> OrgInvitation:
"""Accept an organization invitation.
This method:
1. Validates the token and invitation status
2. Checks expiration
3. Verifies user is not already a member
4. Creates LiteLLM integration
5. Adds user to the organization
6. Marks invitation as accepted
Args:
token: The invitation token
user_id: The user accepting the invitation
Returns:
OrgInvitation: The accepted invitation
Raises:
InvitationInvalidError: If token is invalid or invitation not pending
InvitationExpiredError: If invitation has expired
UserAlreadyMemberError: If user is already a member
"""
logger.info(
'Accepting organization invitation',
extra={
'token_prefix': token[:10] + '...' if len(token) > 10 else token,
'user_id': str(user_id),
},
)
# Step 1: Get and validate invitation
invitation = await OrgInvitationStore.get_invitation_by_token(token)
if not invitation:
raise InvitationInvalidError('Invalid invitation token')
if invitation.status != OrgInvitation.STATUS_PENDING:
if invitation.status == OrgInvitation.STATUS_ACCEPTED:
raise InvitationInvalidError('Invitation has already been accepted')
elif invitation.status == OrgInvitation.STATUS_REVOKED:
raise InvitationInvalidError('Invitation has been revoked')
else:
raise InvitationInvalidError('Invitation is no longer valid')
# Step 2: Check expiration
if OrgInvitationStore.is_token_expired(invitation):
await OrgInvitationStore.update_invitation_status(
invitation.id, OrgInvitation.STATUS_EXPIRED
)
raise InvitationExpiredError('Invitation has expired')
# Step 2.5: Verify user email matches invitation email
user = await UserStore.get_user_by_id_async(str(user_id))
if not user:
raise InvitationInvalidError('User not found')
user_email = user.email
# Fallback: fetch email from Keycloak if not in database (for existing users)
if not user_email:
token_manager = TokenManager()
user_info = await token_manager.get_user_info_from_user_id(str(user_id))
user_email = user_info.get('email') if user_info else None
if not user_email:
raise EmailMismatchError('Your account does not have an email address')
user_email = user_email.lower().strip()
invitation_email = invitation.email.lower().strip()
if user_email != invitation_email:
logger.warning(
'Email mismatch during invitation acceptance',
extra={
'user_id': str(user_id),
'user_email': user_email,
'invitation_email': invitation_email,
'invitation_id': invitation.id,
},
)
raise EmailMismatchError()
# Step 3: Check if user is already a member
existing_member = OrgMemberStore.get_org_member(invitation.org_id, user_id)
if existing_member:
raise UserAlreadyMemberError(
'You are already a member of this organization'
)
# Step 4: Create LiteLLM integration for the user in the new org
try:
settings = await OrgService.create_litellm_integration(
invitation.org_id, str(user_id)
)
except Exception as e:
logger.error(
'Failed to create LiteLLM integration for invitation acceptance',
extra={
'invitation_id': invitation.id,
'user_id': str(user_id),
'org_id': str(invitation.org_id),
'error': str(e),
},
)
raise InvitationInvalidError(
'Failed to set up organization access. Please try again.'
)
# Step 5: Add user to organization
from storage.org_member_store import OrgMemberStore as OMS
org_member_kwargs = OMS.get_kwargs_from_settings(settings)
# Don't override with org defaults - use invitation-specified role
org_member_kwargs.pop('llm_model', None)
org_member_kwargs.pop('llm_base_url', None)
OrgMemberStore.add_user_to_org(
org_id=invitation.org_id,
user_id=user_id,
role_id=invitation.role_id,
llm_api_key=settings.llm_api_key,
status='active',
)
# Step 6: Mark invitation as accepted
updated_invitation = await OrgInvitationStore.update_invitation_status(
invitation.id,
OrgInvitation.STATUS_ACCEPTED,
accepted_by_user_id=user_id,
)
logger.info(
'Organization invitation accepted',
extra={
'invitation_id': invitation.id,
'user_id': str(user_id),
'org_id': str(invitation.org_id),
'role_id': invitation.role_id,
},
)
return updated_invitation

View File

@@ -16,10 +16,12 @@ from server.routes.org_models import (
OrgMemberUpdate,
RoleNotFoundError,
)
from storage.lite_llm_manager import LiteLlmManager
from storage.org_member_store import OrgMemberStore
from storage.role_store import RoleStore
from storage.user_store import UserStore
from openhands.core.logger import openhands_logger as logger
from openhands.utils.async_utils import call_sync_from_async
@@ -65,10 +67,18 @@ class OrgMemberService:
org_id: UUID,
current_user_id: UUID,
page_id: str | None = None,
limit: int = 100,
limit: int = 10,
email_filter: str | None = None,
) -> tuple[bool, str | None, OrgMemberPage | None]:
"""Get organization members with authorization check.
Args:
org_id: Organization UUID.
current_user_id: Requesting user's UUID.
page_id: Offset encoded as string (e.g., "0", "10", "20").
limit: Items per page (default 10).
email_filter: Optional case-insensitive partial email match.
Returns:
Tuple of (success, error_code, data). If success is True, error_code is None.
"""
@@ -88,8 +98,11 @@ class OrgMemberService:
return False, 'invalid_page_id', None
# Call store to get paginated members
members, has_more = await OrgMemberStore.get_org_members_paginated(
org_id=org_id, offset=offset, limit=limit
members, _ = await OrgMemberStore.get_org_members_paginated(
org_id=org_id,
offset=offset,
limit=limit,
email_filter=email_filter,
)
# Transform data to response format
@@ -104,18 +117,53 @@ class OrgMemberService:
user_id=str(member.user_id),
email=user.email if user else None,
role_id=member.role_id,
role_name=role.name if role else '',
role=role.name if role else '',
role_rank=role.rank if role else 0,
status=member.status,
)
)
# Calculate next_page_id
next_page_id = None
if has_more:
next_page_id = str(offset + limit)
# Calculate current page (1-indexed)
current_page = (offset // limit) + 1
return True, None, OrgMemberPage(items=items, next_page_id=next_page_id)
return (
True,
None,
OrgMemberPage(
items=items,
current_page=current_page,
per_page=limit,
),
)
@staticmethod
async def get_org_members_count(
org_id: UUID,
current_user_id: UUID,
email_filter: str | None = None,
) -> int:
"""Get count of organization members with authorization check.
Args:
org_id: Organization UUID.
current_user_id: Requesting user's UUID.
email_filter: Optional case-insensitive partial email match.
Returns:
int: Count of organization members matching the filter.
Raises:
OrgMemberNotFoundError: If requesting user is not a member of the organization.
"""
# Verify current user is a member of the organization
requester_membership = OrgMemberStore.get_org_member(org_id, current_user_id)
if not requester_membership:
raise OrgMemberNotFoundError(str(org_id), str(current_user_id))
return await OrgMemberStore.get_org_members_count(
org_id=org_id,
email_filter=email_filter,
)
@staticmethod
async def remove_org_member(
@@ -168,9 +216,42 @@ class OrgMemberService:
if not success:
return False, 'removal_failed'
# Update user's current_org_id if it points to the org they were removed from
user = UserStore.get_user_by_id(str(target_user_id))
if user and user.current_org_id == org_id:
# Set current_org_id to personal workspace (org.id == user.id)
UserStore.update_current_org(str(target_user_id), target_user_id)
return True, None
return await call_sync_from_async(_remove_member)
success, error = await call_sync_from_async(_remove_member)
# If database removal succeeded, also remove from LiteLLM team
if success:
try:
await LiteLlmManager.remove_user_from_team(
str(target_user_id), str(org_id)
)
logger.info(
'Successfully removed user from LiteLLM team',
extra={
'user_id': str(target_user_id),
'org_id': str(org_id),
},
)
except Exception as e:
# Log but don't fail the operation - database removal already succeeded
# LiteLLM state will be eventually consistent
logger.warning(
'Failed to remove user from LiteLLM team',
extra={
'user_id': str(target_user_id),
'org_id': str(org_id),
'error': str(e),
},
)
return success, error
@staticmethod
async def update_org_member(
@@ -240,7 +321,7 @@ class OrgMemberService:
user_id=str(target_membership.user_id),
email=user.email if user else None,
role_id=target_membership.role_id,
role_name=target_role.name,
role=target_role.name,
role_rank=target_role.rank,
status=target_membership.status,
)
@@ -280,7 +361,7 @@ class OrgMemberService:
user_id=str(updated_member.user_id),
email=user.email if user else None,
role_id=updated_member.role_id,
role_name=new_role.name,
role=new_role.name,
role_rank=new_role.rank,
status=updated_member.status,
)

View File

@@ -22,11 +22,70 @@ from openhands.app_server.app_conversation.app_conversation_models import (
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
SQLAppConversationInfoService,
)
from openhands.app_server.errors import AuthError
from openhands.app_server.services.injector import InjectorState
from openhands.app_server.user.specifiy_user_context import ADMIN
class SaasSQLAppConversationInfoService(SQLAppConversationInfoService):
"""Extended SQLAppConversationInfoService with user-based filtering and SAAS metadata handling."""
"""Extended SQLAppConversationInfoService with user and organization-based filtering and SAAS metadata handling."""
async def _get_current_user(self) -> User | None:
"""Get the current user using the existing db_session.
Uses self.db_session to avoid opening a separate database session.
Returns:
User object or None if no user_id is available
"""
user_id_str = await self.user_context.get_user_id()
if not user_id_str:
return None
user_id_uuid = UUID(user_id_str)
result = await self.db_session.execute(
select(User).where(User.id == user_id_uuid)
)
return result.scalars().first()
async def _apply_user_and_org_filter(self, query):
"""Apply user_id and org_id filters to ensure conversation isolation.
Filters conversations by:
- user_id: Only show conversations belonging to the current user
- org_id: Only show conversations belonging to the user's current organization
Args:
query: SQLAlchemy query to apply filters to
Returns:
Query with user and organization filters applied
Raises:
AuthError: If no user_id is available (secure default: deny access)
"""
# For internal operations such as getting a conversation by session_api_key
# we need a mode that does not have filtering. The dependency `as_admin()`
# is used to enable it
if self.user_context == ADMIN:
return query
user_id_str = await self.user_context.get_user_id()
if not user_id_str:
# Secure default: no user means no access, not "show everything"
raise AuthError('User authentication required')
user_id_uuid = UUID(user_id_str)
query = query.where(StoredConversationMetadataSaas.user_id == user_id_uuid)
# Filter by organization ID to ensure conversations are isolated per organization
user = await self._get_current_user()
if user and user.current_org_id is not None:
query = query.where(
StoredConversationMetadataSaas.org_id == user.current_org_id
)
return query
async def _secure_select(self):
query = (
@@ -38,13 +97,7 @@ class SaasSQLAppConversationInfoService(SQLAppConversationInfoService):
)
.where(StoredConversationMetadata.conversation_version == 'V1')
)
user_id_str = await self.user_context.get_user_id()
if user_id_str:
user_id_uuid = UUID(user_id_str)
query = query.where(StoredConversationMetadataSaas.user_id == user_id_uuid)
return query
return await self._apply_user_and_org_filter(query)
async def _secure_select_with_saas_metadata(self):
"""Select query that includes SAAS metadata for retrieving user_id."""
@@ -57,13 +110,7 @@ class SaasSQLAppConversationInfoService(SQLAppConversationInfoService):
)
.where(StoredConversationMetadata.conversation_version == 'V1')
)
user_id_str = await self.user_context.get_user_id()
if user_id_str:
user_id_uuid = UUID(user_id_str)
query = query.where(StoredConversationMetadataSaas.user_id == user_id_uuid)
return query
return await self._apply_user_and_org_filter(query)
async def search_app_conversation_info(
self,
@@ -155,21 +202,16 @@ class SaasSQLAppConversationInfoService(SQLAppConversationInfoService):
"""Count conversations matching the given filters with SAAS metadata."""
query = (
select(func.count(StoredConversationMetadata.conversation_id))
.select_from(
StoredConversationMetadata.join(
StoredConversationMetadataSaas,
StoredConversationMetadata.conversation_id
== StoredConversationMetadataSaas.conversation_id,
)
.join(
StoredConversationMetadataSaas,
StoredConversationMetadata.conversation_id
== StoredConversationMetadataSaas.conversation_id,
)
.where(StoredConversationMetadata.conversation_version == 'V1')
)
# Apply user filtering
user_id_str = await self.user_context.get_user_id()
if user_id_str:
user_id_uuid = UUID(user_id_str)
query = query.where(StoredConversationMetadataSaas.user_id == user_id_uuid)
# Apply user and organization filtering
query = await self._apply_user_and_org_filter(query)
query = self._apply_filters_with_saas_metadata(
query=query,

View File

@@ -20,8 +20,10 @@ from storage.linear_workspace import LinearWorkspace
from storage.maintenance_task import MaintenanceTask, MaintenanceTaskStatus
from storage.openhands_pr import OpenhandsPR
from storage.org import Org
from storage.org_invitation import OrgInvitation
from storage.org_member import OrgMember
from storage.proactive_convos import ProactiveConversation
from storage.resend_synced_user import ResendSyncedUser
from storage.role import Role
from storage.slack_conversation import SlackConversation
from storage.slack_team import SlackTeam
@@ -65,8 +67,10 @@ __all__ = [
'MaintenanceTaskStatus',
'OpenhandsPR',
'Org',
'OrgInvitation',
'OrgMember',
'ProactiveConversation',
'ResendSyncedUser',
'Role',
'SlackConversation',
'SlackTeam',

View File

@@ -126,7 +126,7 @@ class ApiKeyStore:
return True
async def list_api_keys(self, user_id: str) -> list[dict]:
async def list_api_keys(self, user_id: str) -> list[ApiKey]:
"""List all API keys for a user."""
user = await UserStore.get_user_by_id_async(user_id)
org_id = user.current_org_id
@@ -134,24 +134,14 @@ class ApiKeyStore:
def _list_api_keys_from_db(self, user_id: str, org_id: str) -> list[ApiKey]:
with self.session_maker() as session:
keys = (
keys: list[ApiKey] = (
session.query(ApiKey)
.filter(ApiKey.user_id == user_id)
.filter(ApiKey.org_id == org_id)
.all()
)
return [
{
'id': key.id,
'name': key.name,
'created_at': key.created_at,
'last_used_at': key.last_used_at,
'expires_at': key.expires_at,
}
for key in keys
if 'MCP_API_KEY' != key.name
]
return [key for key in keys if key.name != 'MCP_API_KEY']
async def retrieve_mcp_api_key(self, user_id: str) -> str | None:
user = await UserStore.get_user_by_id_async(user_id)

View File

@@ -4,7 +4,9 @@ import time
from dataclasses import dataclass
from typing import Awaitable, Callable, Dict
from sqlalchemy import select, update
from server.auth.auth_error import TokenRefreshError
from sqlalchemy import select, text, update
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import sessionmaker
from storage.auth_tokens import AuthTokens
from storage.database import a_session_maker
@@ -12,6 +14,14 @@ from storage.database import a_session_maker
from openhands.core.logger import openhands_logger as logger
from openhands.integrations.service_types import ProviderType
# Time buffer (in seconds) before actual expiration to consider token expired
# This ensures tokens are refreshed before they actually expire. The
# github default is 8 hours, so 15 minutes leeway is ~3% of this.
ACCESS_TOKEN_EXPIRY_BUFFER = 900 # 15 minutes
# Database lock timeout to prevent indefinite blocking
LOCK_TIMEOUT_SECONDS = 5
@dataclass
class AuthTokenStore:
@@ -23,6 +33,31 @@ class AuthTokenStore:
def identity_provider_value(self) -> str:
return self.idp.value
def _is_token_expired(
self, access_token_expires_at: int, refresh_token_expires_at: int
) -> tuple[bool, bool]:
"""Check if access and refresh tokens are expired.
Args:
access_token_expires_at: Expiration time for access token (seconds since epoch)
refresh_token_expires_at: Expiration time for refresh token (seconds since epoch)
Returns:
Tuple of (access_expired, refresh_expired)
"""
current_time = int(time.time())
access_expired = (
False
if access_token_expires_at == 0
else access_token_expires_at < current_time + ACCESS_TOKEN_EXPIRY_BUFFER
)
refresh_expired = (
False
if refresh_token_expires_at == 0
else refresh_token_expires_at < current_time
)
return access_expired, refresh_expired
async def store_tokens(
self,
access_token: str,
@@ -73,87 +108,149 @@ class AuthTokenStore:
]
| None = None,
) -> Dict[str, str | int] | None:
"""
Load authentication tokens from the database and refresh them if necessary.
"""Load authentication tokens from the database and refresh them if necessary.
This method retrieves the current authentication tokens for the user and checks if they have expired.
It uses the provided `check_expiration_and_refresh` function to determine if the tokens need
to be refreshed and to refresh the tokens if needed.
This method uses a double-checked locking pattern to minimize lock contention:
1. First, check if the token is valid WITHOUT acquiring a lock (fast path)
2. If refresh is needed, acquire a lock with a timeout
3. Double-check if refresh is still needed (another request may have refreshed)
4. Perform the refresh if still needed
The method ensures that only one refresh operation is performed per refresh token by using a
row-level lock on the token record.
The method is designed to handle race conditions where multiple requests might attempt to refresh
the same token simultaneously, ensuring that only one refresh call occurs per refresh token.
The row-level lock ensures that only one refresh operation is performed per
refresh token, which is important because most IDPs invalidate the old refresh
token after it's used once.
Args:
check_expiration_and_refresh (Callable, optional): A function that checks if the tokens have expired
and attempts to refresh them. It should return a dictionary containing the new access_token, refresh_token,
and their respective expiration timestamps. If no refresh is needed, it should return `None`.
check_expiration_and_refresh: A function that checks if the tokens have
expired and attempts to refresh them. It should return a dictionary
containing the new access_token, refresh_token, and their respective
expiration timestamps. If no refresh is needed, it should return None.
Returns:
Dict[str, str | int] | None:
A dictionary containing the access_token, refresh_token, access_token_expires_at,
and refresh_token_expires_at. If no token record is found, returns `None`.
A dictionary containing the access_token, refresh_token,
access_token_expires_at, and refresh_token_expires_at.
If no token record is found, returns None.
Raises:
TokenRefreshError: If the lock cannot be acquired within the timeout
period. This typically means another request is holding the lock
for an extended period. Callers should handle this by returning
a 401 response to prompt the user to re-authenticate.
"""
# FAST PATH: Check without lock first to avoid unnecessary lock contention
async with self.a_session_maker() as session:
async with session.begin(): # Ensures transaction management
# Lock the row while we check if we need to refresh the tokens.
# There is a race condition where 2 or more calls can load tokens simultaneously.
# If it turns out the loaded tokens are expired, then there will be multiple
# refresh token calls with the same refresh token. Most IDPs only allow one refresh
# per refresh token. This lock ensure that only one refresh call occurs per refresh token
result = await session.execute(
select(AuthTokens)
.filter(
AuthTokens.keycloak_user_id == self.keycloak_user_id,
AuthTokens.identity_provider == self.identity_provider_value,
)
.with_for_update()
result = await session.execute(
select(AuthTokens).filter(
AuthTokens.keycloak_user_id == self.keycloak_user_id,
AuthTokens.identity_provider == self.identity_provider_value,
)
token_record = result.scalars().one_or_none()
)
token_record = result.scalars().one_or_none()
if not token_record:
return None
if not token_record:
return None
token_refresh = (
await check_expiration_and_refresh(
# Check if token needs refresh
access_expired, _ = self._is_token_expired(
token_record.access_token_expires_at,
token_record.refresh_token_expires_at,
)
# If token is still valid, return it without acquiring a lock
if not access_expired or check_expiration_and_refresh is None:
return {
'access_token': token_record.access_token,
'refresh_token': token_record.refresh_token,
'access_token_expires_at': token_record.access_token_expires_at,
'refresh_token_expires_at': token_record.refresh_token_expires_at,
}
# SLOW PATH: Token needs refresh, acquire lock
try:
async with self.a_session_maker() as session:
async with session.begin():
# Set a lock timeout to prevent indefinite blocking
# This ensures we don't hold connections forever if something goes wrong
await session.execute(
text(f"SET LOCAL lock_timeout = '{LOCK_TIMEOUT_SECONDS}s'")
)
# Acquire row-level lock to prevent concurrent refresh attempts
result = await session.execute(
select(AuthTokens)
.filter(
AuthTokens.keycloak_user_id == self.keycloak_user_id,
AuthTokens.identity_provider
== self.identity_provider_value,
)
.with_for_update()
)
token_record = result.scalars().one_or_none()
if not token_record:
return None
# Double-check: another request may have refreshed while we waited for the lock
access_expired, _ = self._is_token_expired(
token_record.access_token_expires_at,
token_record.refresh_token_expires_at,
)
if not access_expired:
# Token was refreshed by another request while we waited
logger.debug(
'Token was refreshed by another request while waiting for lock'
)
return {
'access_token': token_record.access_token,
'refresh_token': token_record.refresh_token,
'access_token_expires_at': token_record.access_token_expires_at,
'refresh_token_expires_at': token_record.refresh_token_expires_at,
}
# We're the one doing the refresh
token_refresh = await check_expiration_and_refresh(
self.idp,
token_record.refresh_token,
token_record.access_token_expires_at,
token_record.refresh_token_expires_at,
)
if check_expiration_and_refresh
else None
)
if token_refresh:
await session.execute(
update(AuthTokens)
.where(AuthTokens.id == token_record.id)
.values(
access_token=token_refresh['access_token'],
refresh_token=token_refresh['refresh_token'],
access_token_expires_at=token_refresh[
'access_token_expires_at'
],
refresh_token_expires_at=token_refresh[
'refresh_token_expires_at'
],
if token_refresh:
await session.execute(
update(AuthTokens)
.where(AuthTokens.id == token_record.id)
.values(
access_token=token_refresh['access_token'],
refresh_token=token_refresh['refresh_token'],
access_token_expires_at=token_refresh[
'access_token_expires_at'
],
refresh_token_expires_at=token_refresh[
'refresh_token_expires_at'
],
)
)
)
await session.commit()
await session.commit()
return (
token_refresh
if token_refresh
else {
'access_token': token_record.access_token,
'refresh_token': token_record.refresh_token,
'access_token_expires_at': token_record.access_token_expires_at,
'refresh_token_expires_at': token_record.refresh_token_expires_at,
}
)
return (
token_refresh
if token_refresh
else {
'access_token': token_record.access_token,
'refresh_token': token_record.refresh_token,
'access_token_expires_at': token_record.access_token_expires_at,
'refresh_token_expires_at': token_record.refresh_token_expires_at,
}
)
except OperationalError as e:
# Lock timeout - another request is holding the lock for too long
logger.warning(
f'Token refresh lock timeout for user {self.keycloak_user_id}: {e}'
)
raise TokenRefreshError(
'Unable to refresh token due to lock timeout. Please try again.'
) from e
async def is_access_token_valid(self) -> bool:
"""Check if the access token is still valid.
@@ -194,8 +291,8 @@ class AuthTokenStore:
"""Get an instance of the AuthTokenStore.
Args:
config: The application configuration
keycloak_user_id: The Keycloak user ID
idp: The identity provider type
Returns:
An instance of AuthTokenStore

View File

@@ -18,17 +18,17 @@ def _get_db_session_injector():
return _config.db_session
def session_maker():
def session_maker(**kwargs):
db_session_injector = _get_db_session_injector()
session_maker = db_session_injector.get_session_maker()
return session_maker()
factory = db_session_injector.get_session_maker()
return factory(**kwargs)
@contextlib.asynccontextmanager
async def a_session_maker():
async def a_session_maker(**kwargs):
db_session_injector = _get_db_session_injector()
a_session_maker = await db_session_injector.get_async_session_maker()
async with a_session_maker() as session:
factory = await db_session_injector.get_async_session_maker()
async with factory(**kwargs) as session:
yield session

View File

@@ -10,7 +10,6 @@ import httpx
from pydantic import SecretStr
from server.auth.token_manager import TokenManager
from server.constants import (
DEFAULT_INITIAL_BUDGET,
LITE_LLM_API_KEY,
LITE_LLM_API_URL,
LITE_LLM_TEAM_ID,
@@ -44,6 +43,34 @@ def get_byor_key_alias(keycloak_user_id: str, org_id: str) -> str:
class LiteLlmManager:
"""Manage LiteLLM interactions."""
@staticmethod
def get_budget_from_team_info(
user_team_info: dict | None, user_id: str, org_id: str
) -> tuple[float, float]:
"""Extract max_budget and spend from user team info.
For personal orgs (user_id == org_id), uses litellm_budget_table.max_budget.
For team orgs, uses max_budget_in_team (populated by get_user_team_info).
Args:
user_team_info: The response from get_user_team_info
user_id: The user's ID
org_id: The organization's ID
Returns:
Tuple of (max_budget, spend)
"""
if not user_team_info:
return 0, 0
spend = user_team_info.get('spend', 0)
if user_id == org_id:
max_budget = (user_team_info.get('litellm_budget_table') or {}).get(
'max_budget', 0
)
else:
max_budget = user_team_info.get('max_budget_in_team') or 0
return max_budget, spend
@staticmethod
async def create_entries(
org_id: str,
@@ -72,8 +99,33 @@ class LiteLlmManager:
'x-goog-api-key': LITE_LLM_API_KEY,
}
) as client:
# Check if team already exists and get its budget
# New users joining existing orgs should inherit the team's budget
team_budget = 0.0
try:
existing_team = await LiteLlmManager._get_team(client, org_id)
if existing_team:
team_info = existing_team.get('team_info', {})
team_budget = team_info.get('max_budget', 0.0) or 0.0
logger.info(
'LiteLlmManager:create_entries:existing_team_budget',
extra={
'org_id': org_id,
'user_id': keycloak_user_id,
'team_budget': team_budget,
},
)
except httpx.HTTPStatusError as e:
# Team doesn't exist yet (404) - this is expected for first user
if e.response.status_code != 404:
raise
logger.info(
'LiteLlmManager:create_entries:no_existing_team',
extra={'org_id': org_id, 'user_id': keycloak_user_id},
)
await LiteLlmManager._create_team(
client, keycloak_user_id, org_id, DEFAULT_INITIAL_BUDGET
client, keycloak_user_id, org_id, team_budget
)
if create_user:
@@ -82,7 +134,7 @@ class LiteLlmManager:
)
await LiteLlmManager._add_user_to_team(
client, keycloak_user_id, org_id, DEFAULT_INITIAL_BUDGET
client, keycloak_user_id, org_id, team_budget
)
key = await LiteLlmManager._generate_key(
@@ -894,21 +946,31 @@ class LiteLlmManager:
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
logger.warning('LiteLLM API configuration not found')
return None
team_info = await LiteLlmManager._get_team(client, team_id)
if not team_info:
team_response = await LiteLlmManager._get_team(client, team_id)
if not team_response:
return None
# Filter team_memberships based on team_id and keycloak_user_id
user_membership = next(
(
membership
for membership in team_info.get('team_memberships', [])
for membership in team_response.get('team_memberships', [])
if membership.get('user_id') == keycloak_user_id
and membership.get('team_id') == team_id
),
None,
)
if not user_membership:
return None
# For team orgs (user_id != team_id), include team-level budget info
# The team's max_budget and spend are shared across all members
if keycloak_user_id != team_id:
team_info = team_response.get('team_info', {})
user_membership['max_budget_in_team'] = team_info.get('max_budget')
user_membership['spend'] = team_info.get('spend', 0)
return user_membership
@staticmethod

View File

@@ -51,6 +51,9 @@ class Org(Base): # type: ignore
# Relationships
org_members = relationship('OrgMember', back_populates='org')
current_users = relationship('User', back_populates='current_org')
invitations = relationship(
'OrgInvitation', back_populates='org', passive_deletes=True
)
billing_sessions = relationship('BillingSession', back_populates='org')
stored_conversation_metadata_saas = relationship(
'StoredConversationMetadataSaas', back_populates='org'

View File

@@ -0,0 +1,59 @@
"""
SQLAlchemy model for Organization Invitation.
"""
from sqlalchemy import UUID, Column, DateTime, ForeignKey, Integer, String, text
from sqlalchemy.orm import relationship
from storage.base import Base
class OrgInvitation(Base): # type: ignore
"""Organization invitation model.
Represents an invitation for a user to join an organization.
Invitations are created by organization owners/admins and contain
a secure token that can be used to accept the invitation.
"""
__tablename__ = 'org_invitation'
id = Column(Integer, primary_key=True, autoincrement=True)
token = Column(String(64), nullable=False, unique=True, index=True)
org_id = Column(
UUID(as_uuid=True),
ForeignKey('org.id', ondelete='CASCADE'),
nullable=False,
index=True,
)
email = Column(String(255), nullable=False, index=True)
role_id = Column(Integer, ForeignKey('role.id'), nullable=False)
inviter_id = Column(UUID(as_uuid=True), ForeignKey('user.id'), nullable=False)
status = Column(
String(20),
nullable=False,
server_default=text("'pending'"),
)
created_at = Column(
DateTime,
nullable=False,
server_default=text('CURRENT_TIMESTAMP'),
)
expires_at = Column(DateTime, nullable=False)
accepted_at = Column(DateTime, nullable=True)
accepted_by_user_id = Column(
UUID(as_uuid=True),
ForeignKey('user.id'),
nullable=True,
)
# Relationships
org = relationship('Org', back_populates='invitations')
role = relationship('Role')
inviter = relationship('User', foreign_keys=[inviter_id])
accepted_by_user = relationship('User', foreign_keys=[accepted_by_user_id])
# Status constants
STATUS_PENDING = 'pending'
STATUS_ACCEPTED = 'accepted'
STATUS_REVOKED = 'revoked'
STATUS_EXPIRED = 'expired'

View File

@@ -0,0 +1,227 @@
"""
Store class for managing organization invitations.
"""
import secrets
import string
from datetime import datetime, timedelta
from typing import Optional
from uuid import UUID
from sqlalchemy import and_, select
from sqlalchemy.orm import joinedload
from storage.database import a_session_maker
from storage.org_invitation import OrgInvitation
from openhands.core.logger import openhands_logger as logger
# Invitation token configuration
INVITATION_TOKEN_PREFIX = 'inv-'
INVITATION_TOKEN_LENGTH = 48 # Total length will be 52 with prefix
DEFAULT_EXPIRATION_DAYS = 7
class OrgInvitationStore:
"""Store for managing organization invitations."""
@staticmethod
def generate_token(length: int = INVITATION_TOKEN_LENGTH) -> str:
"""Generate a secure invitation token.
Uses cryptographically secure random generation for tokens.
Pattern from api_key_store.py.
Args:
length: Length of the random part of the token
Returns:
str: Token with prefix (e.g., 'inv-aBcDeF123...')
"""
alphabet = string.ascii_letters + string.digits
random_part = ''.join(secrets.choice(alphabet) for _ in range(length))
return f'{INVITATION_TOKEN_PREFIX}{random_part}'
@staticmethod
async def create_invitation(
org_id: UUID,
email: str,
role_id: int,
inviter_id: UUID,
expiration_days: int = DEFAULT_EXPIRATION_DAYS,
) -> OrgInvitation:
"""Create a new organization invitation.
Args:
org_id: Organization UUID
email: Invitee's email address
role_id: Role ID to assign on acceptance
inviter_id: User ID of the person creating the invitation
expiration_days: Days until the invitation expires
Returns:
OrgInvitation: The created invitation record
"""
async with a_session_maker() as session:
token = OrgInvitationStore.generate_token()
# Use timezone-naive datetime for database compatibility
expires_at = datetime.utcnow() + timedelta(days=expiration_days)
invitation = OrgInvitation(
token=token,
org_id=org_id,
email=email.lower().strip(),
role_id=role_id,
inviter_id=inviter_id,
status=OrgInvitation.STATUS_PENDING,
expires_at=expires_at,
)
session.add(invitation)
await session.commit()
# Re-fetch with eagerly loaded relationships to avoid DetachedInstanceError
result = await session.execute(
select(OrgInvitation)
.options(joinedload(OrgInvitation.role))
.filter(OrgInvitation.id == invitation.id)
)
invitation = result.scalars().first()
logger.info(
'Created organization invitation',
extra={
'invitation_id': invitation.id,
'org_id': str(org_id),
'email': email,
'inviter_id': str(inviter_id),
'expires_at': expires_at.isoformat(),
},
)
return invitation
@staticmethod
async def get_invitation_by_token(token: str) -> Optional[OrgInvitation]:
"""Get an invitation by its token.
Args:
token: The invitation token
Returns:
OrgInvitation or None if not found
"""
async with a_session_maker() as session:
result = await session.execute(
select(OrgInvitation)
.options(joinedload(OrgInvitation.org), joinedload(OrgInvitation.role))
.filter(OrgInvitation.token == token)
)
return result.scalars().first()
@staticmethod
async def get_pending_invitation(
org_id: UUID, email: str
) -> Optional[OrgInvitation]:
"""Get a pending invitation for an email in an organization.
Args:
org_id: Organization UUID
email: Email address to check
Returns:
OrgInvitation or None if no pending invitation exists
"""
async with a_session_maker() as session:
result = await session.execute(
select(OrgInvitation).filter(
and_(
OrgInvitation.org_id == org_id,
OrgInvitation.email == email.lower().strip(),
OrgInvitation.status == OrgInvitation.STATUS_PENDING,
)
)
)
return result.scalars().first()
@staticmethod
async def update_invitation_status(
invitation_id: int,
status: str,
accepted_by_user_id: Optional[UUID] = None,
) -> Optional[OrgInvitation]:
"""Update an invitation's status.
Args:
invitation_id: The invitation ID
status: New status (pending, accepted, revoked, expired)
accepted_by_user_id: User ID who accepted (only for 'accepted' status)
Returns:
Updated OrgInvitation or None if not found
"""
async with a_session_maker() as session:
result = await session.execute(
select(OrgInvitation).filter(OrgInvitation.id == invitation_id)
)
invitation = result.scalars().first()
if not invitation:
return None
old_status = invitation.status
invitation.status = status
if status == OrgInvitation.STATUS_ACCEPTED and accepted_by_user_id:
# Use timezone-naive datetime for database compatibility
invitation.accepted_at = datetime.utcnow()
invitation.accepted_by_user_id = accepted_by_user_id
await session.commit()
await session.refresh(invitation)
logger.info(
'Updated invitation status',
extra={
'invitation_id': invitation_id,
'old_status': old_status,
'new_status': status,
'accepted_by_user_id': (
str(accepted_by_user_id) if accepted_by_user_id else None
),
},
)
return invitation
@staticmethod
def is_token_expired(invitation: OrgInvitation) -> bool:
"""Check if an invitation token has expired.
Args:
invitation: The invitation to check
Returns:
bool: True if expired, False otherwise
"""
# Use timezone-naive datetime for comparison (database stores without timezone)
now = datetime.utcnow()
return invitation.expires_at < now
@staticmethod
async def mark_expired_if_needed(invitation: OrgInvitation) -> bool:
"""Check if invitation is expired and update status if needed.
Args:
invitation: The invitation to check
Returns:
bool: True if invitation was marked as expired, False otherwise
"""
if (
invitation.status == OrgInvitation.STATUS_PENDING
and OrgInvitationStore.is_token_expired(invitation)
):
await OrgInvitationStore.update_invitation_status(
invitation.id, OrgInvitation.STATUS_EXPIRED
)
return True
return False

View File

@@ -5,10 +5,11 @@ Store class for managing organization-member relationships.
from typing import Optional
from uuid import UUID
from sqlalchemy import select
from sqlalchemy import func, select
from sqlalchemy.orm import joinedload
from storage.database import a_session_maker, session_maker
from storage.org_member import OrgMember
from storage.user import User
from storage.user_settings import UserSettings
from openhands.storage.data_models.settings import Settings
@@ -60,6 +61,51 @@ class OrgMemberStore:
)
return result.scalars().first()
@staticmethod
def get_org_member_for_current_org(user_id: UUID) -> Optional[OrgMember]:
"""Get the org member for a user's current organization.
Args:
user_id: The user's UUID.
Returns:
The OrgMember for the user's current organization, or None if not found.
"""
with session_maker() as session:
result = (
session.query(OrgMember)
.join(User, User.id == OrgMember.user_id)
.filter(
User.id == user_id,
OrgMember.org_id == User.current_org_id,
)
.first()
)
return result
@staticmethod
async def get_org_member_for_current_org_async(
user_id: UUID,
) -> Optional[OrgMember]:
"""Get the org member for a user's current organization (async version).
Args:
user_id: The user's UUID.
Returns:
The OrgMember for the user's current organization, or None if not found.
"""
async with a_session_maker() as session:
result = await session.execute(
select(OrgMember)
.join(User, User.id == OrgMember.user_id)
.filter(
User.id == user_id,
OrgMember.org_id == User.current_org_id,
)
)
return result.scalars().first()
@staticmethod
def get_user_orgs(user_id: UUID) -> list[OrgMember]:
"""Get all organizations for a user."""
@@ -137,14 +183,48 @@ class OrgMemberStore:
}
return kwargs
@staticmethod
async def get_org_members_count(
org_id: UUID,
email_filter: str | None = None,
) -> int:
"""Get total count of organization members, optionally filtered by email.
Args:
org_id: Organization UUID.
email_filter: Optional case-insensitive partial email match.
Returns:
Total count of matching members.
"""
async with a_session_maker() as session:
query = select(func.count(OrgMember.user_id)).filter(
OrgMember.org_id == org_id
)
if email_filter:
query = query.join(User, User.id == OrgMember.user_id).filter(
User.email.ilike(f'%{email_filter}%')
)
result = await session.execute(query)
return result.scalar() or 0
@staticmethod
async def get_org_members_paginated(
org_id: UUID,
offset: int = 0,
limit: int = 100,
email_filter: str | None = None,
) -> tuple[list[OrgMember], bool]:
"""Get paginated list of organization members with user and role info.
Args:
org_id: Organization UUID.
offset: Number of records to skip.
limit: Maximum number of records to return.
email_filter: Optional case-insensitive partial email match.
Returns:
Tuple of (members_list, has_more) where has_more indicates if there are more results.
"""
@@ -154,13 +234,18 @@ class OrgMemberStore:
query = (
select(OrgMember)
.options(joinedload(OrgMember.user), joinedload(OrgMember.role))
.join(User, User.id == OrgMember.user_id)
.filter(OrgMember.org_id == org_id)
.order_by(OrgMember.user_id)
.offset(offset)
.limit(limit + 1)
)
# Apply email filter if provided
if email_filter:
query = query.filter(User.email.ilike(f'%{email_filter}%'))
query = query.order_by(OrgMember.user_id).offset(offset).limit(limit + 1)
result = await session.execute(query)
members = list(result.scalars().all())
members = list(result.unique().scalars().all())
# Check if there are more results
has_more = len(members) > limit

View File

@@ -656,10 +656,9 @@ class OrgService:
)
return None
max_budget = (user_team_info.get('litellm_budget_table') or {}).get(
'max_budget', 0
max_budget, spend = LiteLlmManager.get_budget_from_team_info(
user_team_info, user_id, str(org_id)
)
spend = user_team_info.get('spend', 0)
credits = max(max_budget - spend, 0)
logger.debug(

View File

@@ -0,0 +1,35 @@
"""SQLAlchemy model for tracking users synced to Resend audiences."""
from datetime import UTC, datetime
from uuid import uuid4
from sqlalchemy import Column, DateTime, String, UniqueConstraint
from sqlalchemy.dialects.postgresql import UUID
from storage.base import Base
class ResendSyncedUser(Base): # type: ignore
"""Tracks users that have been synced to a Resend audience.
This table ensures that once a user is synced to a Resend audience,
they won't be re-added even if they are later deleted from the
Resend UI. This respects manual deletions/unsubscribes.
"""
__tablename__ = 'resend_synced_users'
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid4)
email = Column(String, nullable=False, index=True)
audience_id = Column(String, nullable=False, index=True)
synced_at = Column(
DateTime(timezone=True),
default=lambda: datetime.now(UTC),
nullable=False,
)
keycloak_user_id = Column(String, nullable=True)
__table_args__ = (
UniqueConstraint(
'email', 'audience_id', name='uq_resend_synced_email_audience'
),
)

View File

@@ -0,0 +1,125 @@
"""Store class for managing Resend synced users."""
from dataclasses import dataclass
from datetime import UTC, datetime
from typing import Optional, Set
from sqlalchemy import delete, select
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.orm import sessionmaker
from storage.resend_synced_user import ResendSyncedUser
@dataclass
class ResendSyncedUserStore:
"""Store for tracking users synced to Resend audiences."""
session_maker: sessionmaker
def is_user_synced(self, email: str, audience_id: str) -> bool:
"""Check if a user has been synced to a specific audience.
Args:
email: The email address to check.
audience_id: The Resend audience ID.
Returns:
True if the user has been synced, False otherwise.
"""
with self.session_maker() as session:
stmt = select(ResendSyncedUser).where(
ResendSyncedUser.email == email.lower(),
ResendSyncedUser.audience_id == audience_id,
)
result = session.execute(stmt).first()
return result is not None
def get_synced_emails_for_audience(self, audience_id: str) -> Set[str]:
"""Get all synced email addresses for a specific audience.
Args:
audience_id: The Resend audience ID.
Returns:
A set of lowercase email addresses that have been synced.
"""
with self.session_maker() as session:
stmt = select(ResendSyncedUser.email).where(
ResendSyncedUser.audience_id == audience_id,
)
result = session.execute(stmt).scalars().all()
return set(result)
def mark_user_synced(
self,
email: str,
audience_id: str,
keycloak_user_id: Optional[str] = None,
) -> ResendSyncedUser:
"""Mark a user as synced to a specific audience.
Uses upsert to handle race conditions - if the user is already
marked as synced, this is a no-op.
Args:
email: The email address of the user.
audience_id: The Resend audience ID.
keycloak_user_id: Optional Keycloak user ID.
Returns:
The ResendSyncedUser record.
Raises:
RuntimeError: If the record could not be created or retrieved.
"""
with self.session_maker() as session:
stmt = (
insert(ResendSyncedUser)
.values(
email=email.lower(),
audience_id=audience_id,
keycloak_user_id=keycloak_user_id,
synced_at=datetime.now(UTC),
)
.on_conflict_do_nothing(constraint='uq_resend_synced_email_audience')
.returning(ResendSyncedUser)
)
result = session.execute(stmt)
session.commit()
row = result.first()
if row:
return row[0]
# on_conflict_do_nothing triggered, fetch the existing record
existing = session.execute(
select(ResendSyncedUser).where(
ResendSyncedUser.email == email.lower(),
ResendSyncedUser.audience_id == audience_id,
)
).first()
if existing:
return existing[0]
raise RuntimeError(
f'Failed to create or retrieve synced user record for {email}'
)
def remove_synced_user(self, email: str, audience_id: str) -> bool:
"""Remove a user's synced status for a specific audience.
Args:
email: The email address of the user.
audience_id: The Resend audience ID.
Returns:
True if a record was deleted, False if no record existed.
"""
with self.session_maker() as session:
stmt = delete(ResendSyncedUser).where(
ResendSyncedUser.email == email.lower(),
ResendSyncedUser.audience_id == audience_id,
)
result = session.execute(stmt)
session.commit()
return result.rowcount > 0

View File

@@ -29,6 +29,20 @@ class RoleStore:
with session_maker() as session:
return session.query(Role).filter(Role.id == role_id).first()
@staticmethod
async def get_role_by_id_async(
role_id: int,
session: Optional[AsyncSession] = None,
) -> Optional[Role]:
"""Get role by ID (async version)."""
if session is not None:
result = await session.execute(select(Role).where(Role.id == role_id))
return result.scalars().first()
async with a_session_maker() as session:
result = await session.execute(select(Role).where(Role.id == role_id))
return result.scalars().first()
@staticmethod
def get_role_by_name(name: str) -> Optional[Role]:
"""Get role by name."""

View File

@@ -83,6 +83,8 @@ class UserStore:
role_id=role_id,
**user_kwargs,
)
user.email = user_info.get('email')
user.email_verified = user_info.get('email_verified')
session.add(user)
role = RoleStore.get_role_by_name('owner')
@@ -768,6 +770,30 @@ class UserStore:
finally:
await UserStore._release_user_creation_lock(user_id)
@staticmethod
async def get_user_by_email_async(email: str) -> Optional[User]:
"""Get user by email address (async version).
This method looks up a user by their email address. Note that email
addresses may not be unique across all users in rare cases.
Args:
email: The email address to search for
Returns:
User: The user with the matching email, or None if not found
"""
if not email:
return None
async with a_session_maker() as session:
result = await session.execute(
select(User)
.options(joinedload(User.org_members))
.filter(User.email == email.lower().strip())
)
return result.scalars().first()
@staticmethod
def list_users() -> list[User]:
"""List all users."""

View File

@@ -1,11 +1,19 @@
import asyncio
import asyncio # noqa: I001
from storage.proactive_conversation_store import ProactiveConversationStore
# This must be before the import of storage
# to set up logging and prevent alembic from
# running its mouth.
from openhands.core.logger import openhands_logger
from storage.proactive_conversation_store import (
ProactiveConversationStore,
)
OLDER_THAN = 30 # 30 minutes
async def main():
openhands_logger.info('clean_proactive_convo_table')
convo_store = ProactiveConversationStore()
await convo_store.clean_old_convos(older_than_minutes=OLDER_THAN)

View File

@@ -35,6 +35,7 @@ import resend
from keycloak.exceptions import KeycloakError
from resend.exceptions import ResendError
from server.auth.token_manager import get_keycloak_admin
from storage.resend_synced_user_store import ResendSyncedUserStore
from tenacity import (
retry,
retry_if_exception_type,
@@ -69,9 +70,6 @@ RATE_LIMIT = float(os.environ.get('RATE_LIMIT', '2')) # Requests per second
# Set up Resend API
resend.api_key = RESEND_API_KEY
print('resend module', resend)
print('has contacts', hasattr(resend, 'Contacts'))
class ResendSyncError(Exception):
"""Base exception for Resend sync errors."""
@@ -98,7 +96,7 @@ class ResendAPIError(ResendSyncError):
EMAIL_REGEX = re.compile(r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$')
def is_valid_email(email: str) -> bool:
def is_valid_email(email: Optional[str]) -> bool:
"""Validate an email address format.
This uses a regex pattern that matches most valid email addresses
@@ -106,10 +104,10 @@ def is_valid_email(email: str) -> bool:
does not accept (e.g., exclamation marks).
Args:
email: The email address to validate.
email: The email address to validate, or None.
Returns:
True if the email is valid, False otherwise.
True if the email is valid, False otherwise (including for None).
"""
if not email:
return False
@@ -199,8 +197,6 @@ def get_resend_contacts(audience_id: str) -> Dict[str, Dict[str, Any]]:
Raises:
ResendAPIError: If the API call fails.
"""
print('getting resend contacts')
print('has resend contacts', hasattr(resend, 'Contacts'))
try:
contacts = resend.Contacts.list(audience_id).get('data', [])
# Create a dictionary mapping email addresses to contact data for
@@ -255,6 +251,15 @@ def add_contact_to_resend(
raise
@retry(
stop=stop_after_attempt(MAX_RETRIES),
wait=wait_exponential(
multiplier=INITIAL_BACKOFF_SECONDS,
max=MAX_BACKOFF_SECONDS,
exp_base=BACKOFF_FACTOR,
),
retry=retry_if_exception_type(ResendError),
)
def send_welcome_email(
email: str,
first_name: Optional[str] = None,
@@ -271,7 +276,7 @@ def send_welcome_email(
The API response.
Raises:
ResendError: If the API call fails.
ResendError: If the API call fails after retries.
"""
try:
# Prepare the recipient name
@@ -317,8 +322,84 @@ def send_welcome_email(
raise
def _get_resend_synced_user_store() -> ResendSyncedUserStore:
"""Get the ResendSyncedUserStore instance.
This is separated into a function to allow for easier testing/mocking.
"""
from openhands.app_server.config import get_global_config
config = get_global_config()
db_session_injector = config.db_session
return ResendSyncedUserStore(session_maker=db_session_injector.get_session_maker())
def _backfill_existing_resend_contacts(
synced_user_store: ResendSyncedUserStore,
audience_id: str,
) -> int:
"""Backfill the synced_users table with contacts already in Resend.
This ensures that users who were added to Resend before the tracking
table existed are properly recorded, preventing duplicate welcome emails.
Args:
synced_user_store: The store for tracking synced users.
audience_id: The Resend audience ID.
Returns:
The number of contacts backfilled.
"""
logger.info('Starting backfill of existing Resend contacts...')
try:
resend_contacts = get_resend_contacts(audience_id)
logger.info(f'Found {len(resend_contacts)} contacts in Resend audience')
already_synced_emails = synced_user_store.get_synced_emails_for_audience(
audience_id
)
logger.info(
f'Found {len(already_synced_emails)} already synced emails in database'
)
backfilled_count = 0
for email in resend_contacts:
if email.lower() not in already_synced_emails:
synced_user_store.mark_user_synced(
email=email,
audience_id=audience_id,
keycloak_user_id=None, # We don't have this info during backfill
)
backfilled_count += 1
logger.debug(f'Backfilled existing Resend contact: {email}')
logger.info(
f'Backfill completed: {backfilled_count} contacts added to tracking'
)
return backfilled_count
except Exception:
logger.exception('Error during backfill of existing Resend contacts')
# Don't fail the entire sync if backfill fails - just log and continue
return 0
def sync_users_to_resend():
"""Sync users from Keycloak to Resend."""
"""Sync users from Keycloak to Resend.
This function syncs users from Keycloak to a Resend audience. It tracks
which users have been synced in the database to ensure that:
1. Users are only added once (even across multiple sync runs)
2. Users who are manually deleted from Resend are not re-added
The tracking is done via the resend_synced_users table, which records
each email/audience_id combination that has been synced.
On first run (or when new contacts exist in Resend), it will backfill
the tracking table with existing Resend contacts to avoid sending
duplicate welcome emails.
"""
# Check required environment variables
required_vars = {
'RESEND_API_KEY': RESEND_API_KEY,
@@ -344,28 +425,36 @@ def sync_users_to_resend():
)
try:
# Get the store for tracking synced users
synced_user_store = _get_resend_synced_user_store()
# Backfill existing Resend contacts into our tracking table
# This ensures users already in Resend don't get duplicate welcome emails
backfilled_count = _backfill_existing_resend_contacts(
synced_user_store, RESEND_AUDIENCE_ID
)
# Get the total number of users
total_users = get_total_keycloak_users()
logger.info(
f'Found {total_users} users in Keycloak realm {KEYCLOAK_REALM_NAME}'
)
# Get contacts from Resend
resend_contacts = get_resend_contacts(RESEND_AUDIENCE_ID)
logger.info(
f'Found {len(resend_contacts)} contacts in Resend audience '
f'{RESEND_AUDIENCE_ID}'
)
# Stats
stats = {
'total_users': total_users,
'existing_contacts': len(resend_contacts),
'backfilled_contacts': backfilled_count,
'already_synced': 0,
'added_contacts': 0,
'skipped_invalid_emails': 0,
'errors': 0,
}
synced_emails = synced_user_store.get_synced_emails_for_audience(
RESEND_AUDIENCE_ID
)
logger.info(f'Found {len(synced_emails)} already synced emails in database')
# Process users in batches
offset = 0
while offset < total_users:
@@ -378,8 +467,12 @@ def sync_users_to_resend():
continue
email = email.lower()
if email in resend_contacts:
logger.debug(f'User {email} already exists in Resend, skipping')
if email in synced_emails:
logger.debug(
f'User {email} was already synced to this audience, skipping'
)
stats['already_synced'] += 1
continue
# Validate email format before attempting to add to Resend
@@ -388,35 +481,51 @@ def sync_users_to_resend():
stats['skipped_invalid_emails'] += 1
continue
try:
first_name = user.get('first_name')
last_name = user.get('last_name')
first_name = user.get('first_name')
last_name = user.get('last_name')
keycloak_user_id = user.get('id')
# Add the contact to the Resend audience
# Mark as synced first (optimistic) to ensure consistency.
# If Resend API fails, we remove the record.
try:
synced_user_store.mark_user_synced(
email=email,
audience_id=RESEND_AUDIENCE_ID,
keycloak_user_id=keycloak_user_id,
)
except Exception:
logger.exception(f'Failed to mark user {email} as synced')
stats['errors'] += 1
continue
try:
add_contact_to_resend(
RESEND_AUDIENCE_ID, email, first_name, last_name
)
logger.info(f'Added user {email} to Resend')
stats['added_contacts'] += 1
# Sleep to respect rate limit after first API call
time.sleep(1 / RATE_LIMIT)
# Send a welcome email to the newly added contact
try:
send_welcome_email(email, first_name, last_name)
logger.info(f'Sent welcome email to {email}')
except Exception:
logger.exception(
f'Failed to send welcome email to {email}, but contact was added to audience'
)
# Continue with the sync process even if sending the welcome email fails
# Sleep to respect rate limit after second API call
time.sleep(1 / RATE_LIMIT)
except Exception:
logger.exception(f'Error adding user {email} to Resend')
synced_user_store.remove_synced_user(email, RESEND_AUDIENCE_ID)
stats['errors'] += 1
continue
synced_emails.add(email)
stats['added_contacts'] += 1
# Sleep to respect rate limit after first API call
time.sleep(1 / RATE_LIMIT)
# Send a welcome email to the newly added contact
try:
send_welcome_email(email, first_name, last_name)
logger.info(f'Sent welcome email to {email}')
except Exception:
logger.exception(
f'Failed to send welcome email to {email}, but contact was added to audience'
)
# Sleep to respect rate limit after second API call
time.sleep(1 / RATE_LIMIT)
offset += BATCH_SIZE

View File

View File

@@ -0,0 +1,358 @@
"""
Fixtures for V1 GitHub Resolver integration tests.
These tests run actual conversations with the ProcessSandboxService,
using TestLLM to replay pre-recorded trajectories.
"""
import hashlib
import hmac
import os
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
from uuid import UUID
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
# Set environment before importing app modules
os.environ.setdefault('RUNTIME', 'process')
os.environ.setdefault('ENABLE_V1_GITHUB_RESOLVER', 'true')
os.environ.setdefault('GITHUB_APP_CLIENT_ID', 'test-app-id')
os.environ.setdefault('GITHUB_APP_CLIENT_SECRET', 'test-app-secret')
os.environ.setdefault('GITHUB_APP_PRIVATE_KEY', 'test-private-key')
os.environ.setdefault('GITHUB_APP_WEBHOOK_SECRET', 'test-webhook-secret')
os.environ.setdefault('GITHUB_WEBHOOKS_ENABLED', '1')
os.environ.setdefault('HOST', 'localhost')
os.environ.setdefault('KEYCLOAK_URL', 'http://localhost:8080')
os.environ.setdefault('KEYCLOAK_REALM', 'test-realm')
os.environ.setdefault('KEYCLOAK_CLIENT_ID', 'test-client')
os.environ.setdefault('KEYCLOAK_CLIENT_SECRET', 'test-secret')
# Set the templates directory to the absolute path
_repo_root = Path(__file__).parent.parent.parent.parent.parent
_templates_dir = _repo_root / 'openhands' / 'integrations' / 'templates' / 'resolver'
os.environ.setdefault('OPENHANDS_RESOLVER_TEMPLATES_DIR', str(_templates_dir) + '/')
# Import storage models for database setup
# Note: Import ALL models to ensure tables are created
# NOTE: Imports must come after environment setup, hence noqa: E402
from server.constants import ORG_SETTINGS_VERSION # noqa: E402
from storage.auth_tokens import AuthTokens # noqa: E402
from storage.base import Base # noqa: E402
from storage.billing_session import BillingSession # noqa: E402, F401
from storage.conversation_work import ConversationWork # noqa: E402, F401
from storage.device_code import DeviceCode # noqa: E402, F401
from storage.feedback import Feedback # noqa: E402, F401
from storage.github_app_installation import GithubAppInstallation # noqa: E402
from storage.org import Org # noqa: E402
from storage.org_invitation import OrgInvitation # noqa: E402, F401
from storage.org_member import OrgMember # noqa: E402
from storage.role import Role # noqa: E402
from storage.stored_conversation_metadata import ( # noqa: E402
StoredConversationMetadata, # noqa: F401
)
from storage.stored_conversation_metadata_saas import ( # noqa: E402
StoredConversationMetadataSaas, # noqa: F401
)
from storage.stored_offline_token import StoredOfflineToken # noqa: E402
from storage.stripe_customer import StripeCustomer # noqa: E402, F401
from storage.user import User # noqa: E402
# Test constants
TEST_USER_UUID = UUID('11111111-1111-1111-1111-111111111111')
TEST_ORG_UUID = UUID('22222222-2222-2222-2222-222222222222')
TEST_KEYCLOAK_USER_ID = 'test-keycloak-user-id'
TEST_GITHUB_USER_ID = 12345
TEST_GITHUB_USERNAME = 'test-github-user'
TEST_WEBHOOK_SECRET = 'test-webhook-secret'
@pytest.fixture(scope='session')
def test_env():
"""Environment variables for testing."""
return {
'RUNTIME': 'process',
'ENABLE_V1_GITHUB_RESOLVER': 'true',
'GITHUB_APP_CLIENT_ID': 'test-app-id',
'GITHUB_APP_WEBHOOK_SECRET': TEST_WEBHOOK_SECRET,
'GITHUB_WEBHOOKS_ENABLED': '1',
}
@pytest.fixture
def engine():
"""Create an in-memory SQLite database engine for enterprise tables."""
engine = create_engine('sqlite:///:memory:')
Base.metadata.create_all(engine)
return engine
@pytest.fixture
def session_maker(engine):
"""Create a session maker bound to the test engine."""
return sessionmaker(bind=engine)
TEST_INSTALLATION_ID = 123456
@pytest.fixture
def seeded_db(session_maker):
"""Seed the database with test user data."""
now = datetime.now(tz=None) # Use naive datetime for SQLite compatibility
with session_maker() as session:
# Create role
session.add(Role(id=1, name='admin', rank=1))
# Create org with V1 enabled
session.add(
Org(
id=TEST_ORG_UUID,
name='test-org',
org_version=ORG_SETTINGS_VERSION,
enable_default_condenser=True,
enable_proactive_conversation_starters=False,
v1_enabled=True,
)
)
# Create user
session.add(
User(
id=TEST_USER_UUID,
current_org_id=TEST_ORG_UUID,
user_consents_to_analytics=True,
)
)
# Create org member with LLM API key
session.add(
OrgMember(
org_id=TEST_ORG_UUID,
user_id=TEST_USER_UUID,
role_id=1,
llm_api_key='test-llm-api-key',
status='active',
)
)
# Create offline token for Keycloak user
session.add(
StoredOfflineToken(
user_id=TEST_KEYCLOAK_USER_ID,
offline_token='test-offline-token',
created_at=now,
updated_at=now,
)
)
# Create auth tokens linking Keycloak user to GitHub
future_time = int((now + timedelta(hours=1)).timestamp())
session.add(
AuthTokens(
keycloak_user_id=TEST_KEYCLOAK_USER_ID,
identity_provider='github',
access_token='test-github-access-token',
refresh_token='test-github-refresh-token',
access_token_expires_at=future_time,
refresh_token_expires_at=future_time + 86400,
)
)
# Create GitHub app installation
session.add(
GithubAppInstallation(
installation_id=str(TEST_INSTALLATION_ID),
encrypted_token='test-encrypted-token',
created_at=now,
updated_at=now,
)
)
session.commit()
return session_maker
@pytest.fixture
def patched_session_maker(seeded_db):
"""Patch all imports of session_maker to use our test database.
This is necessary because the enterprise code imports session_maker
at module level from storage.database.
"""
patches = [
patch('storage.database.session_maker', seeded_db),
patch('integrations.github.github_view.session_maker', seeded_db),
patch('integrations.github.github_solvability.session_maker', seeded_db),
patch('server.auth.token_manager.session_maker', seeded_db),
patch('server.auth.saas_user_auth.session_maker', seeded_db),
patch('server.auth.domain_blocker.session_maker', seeded_db),
]
for p in patches:
p.start()
yield seeded_db
for p in patches:
p.stop()
@pytest.fixture
def mock_keycloak():
"""Mock Keycloak admin API to return our test user."""
async def mock_get_users(query: dict) -> list[dict]:
"""Mock user lookup by GitHub ID."""
q = query.get('q', '')
if f'github_id:{TEST_GITHUB_USER_ID}' in q:
return [{'id': TEST_KEYCLOAK_USER_ID, 'username': TEST_GITHUB_USERNAME}]
return []
mock_admin = MagicMock()
mock_admin.a_get_users = AsyncMock(side_effect=mock_get_users)
with patch('server.auth.token_manager.get_keycloak_admin', return_value=mock_admin):
yield mock_admin
@pytest.fixture
def mock_github_api():
"""Mock PyGithub API to capture posted comments (legacy - prefer mock_github_service)."""
captured_comments = []
captured_reactions = []
mock_issue = MagicMock()
mock_issue.create_comment = MagicMock(
side_effect=lambda body: captured_comments.append(body)
)
mock_issue.create_reaction = MagicMock(
side_effect=lambda reaction: captured_reactions.append(reaction)
)
mock_repo = MagicMock()
mock_repo.get_issue = MagicMock(return_value=mock_issue)
mock_github = MagicMock()
mock_github.get_repo = MagicMock(return_value=mock_repo)
with patch('github.Github', return_value=mock_github):
yield {
'github': mock_github,
'repo': mock_repo,
'issue': mock_issue,
'captured_comments': captured_comments,
'captured_reactions': captured_reactions,
}
@pytest.fixture
def mock_github_service():
"""
Real HTTP mock GitHub service.
This fixture starts an in-process HTTP server that implements the GitHub API.
PyGithub clients are patched to use this server instead of api.github.com.
Usage:
def test_something(mock_github_service):
# Configure the service
mock_github_service.configure_repo('owner/repo')
mock_github_service.configure_issue('owner/repo', 1)
# ... run your test code ...
# Verify
mock_github_service.assert_comment_sent("I'm on it")
mock_github_service.assert_reaction_added("eyes")
"""
from github import Github
from .mocks import MockGitHubService
# Create and start the mock service
service = MockGitHubService()
service.start()
# Patch Github to use our mock server
original_init = Github.__init__
def patched_init(self, *args, **kwargs):
kwargs['base_url'] = service.base_url
original_init(self, *args, **kwargs)
with patch.object(Github, '__init__', patched_init):
yield service
service.stop()
def create_webhook_signature(payload: bytes, secret: str) -> str:
"""Create a GitHub webhook signature."""
signature = hmac.new(
secret.encode('utf-8'), msg=payload, digestmod=hashlib.sha256
).hexdigest()
return f'sha256={signature}'
def create_issue_comment_payload(
issue_number: int = 1,
comment_body: str = '@openhands please fix this',
repo_name: str = 'test-owner/test-repo',
sender_id: int = TEST_GITHUB_USER_ID,
sender_login: str = TEST_GITHUB_USERNAME,
installation_id: int = 123456,
) -> dict[str, Any]:
"""Create a GitHub issue comment webhook payload."""
owner, repo = repo_name.split('/')
return {
'action': 'created',
'issue': {
'number': issue_number,
'title': 'Test Issue',
'body': 'This is a test issue',
'html_url': f'https://github.com/{repo_name}/issues/{issue_number}',
'user': {'login': sender_login, 'id': sender_id},
},
'comment': {
'id': 12345,
'body': comment_body,
'user': {'login': sender_login, 'id': sender_id},
'html_url': f'https://github.com/{repo_name}/issues/{issue_number}#issuecomment-12345',
},
'repository': {
'id': 12345678,
'name': repo,
'full_name': repo_name,
'private': False,
'html_url': f'https://github.com/{repo_name}',
'owner': {
'login': owner,
'id': 99999,
},
},
'sender': {'login': sender_login, 'id': sender_id},
'installation': {'id': installation_id},
}
@pytest.fixture
def issue_comment_payload():
"""Create a standard issue comment payload."""
return create_issue_comment_payload()
@pytest.fixture
def trajectory_path():
"""Path to trajectory files."""
return Path(__file__).parent / 'fixtures' / 'trajectories'
# Note: TestLLM injection will be handled separately as it requires
# the SDK to be installed and configured properly

View File

@@ -0,0 +1,16 @@
{
"name": "simple_finish",
"description": "Agent immediately finishes with a simple message",
"responses": [
{
"text": "",
"tool_calls": [
{
"id": "call_finish_1",
"name": "finish",
"arguments": "{\"message\": \"I have analyzed the issue and completed the task.\"}"
}
]
}
]
}

View File

@@ -0,0 +1,6 @@
"""Mocks for V1 GitHub Resolver integration tests."""
from .github_service import MockGitHubService
from .test_llm import TestLLM
__all__ = ['MockGitHubService', 'TestLLM']

View File

@@ -0,0 +1,485 @@
"""
Mock GitHub API Service for integration tests.
This module provides a real HTTP server that implements the GitHub API endpoints
used by the enterprise code. It tracks all API calls and state, allowing tests
to verify behavior without complex mocking.
Usage:
service = MockGitHubService()
service.start()
# Run test code that uses PyGithub...
# Verify
service.assert_comment_sent("I'm on it")
service.assert_reaction_added("eyes")
service.stop()
"""
import json
import socket
import threading
import time
from dataclasses import dataclass, field
from typing import Any
import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
@dataclass
class MockGitHubService:
"""In-process mock GitHub API server that tracks state."""
host: str = "127.0.0.1"
port: int = 0 # 0 = auto-assign available port
# Tracked state
comments: list = field(default_factory=list)
reactions: list = field(default_factory=list)
api_calls: list = field(default_factory=list)
# Configurable data
repos: dict = field(default_factory=dict)
issues: dict = field(default_factory=dict)
pull_requests: dict = field(default_factory=dict)
issue_comments: dict = field(default_factory=dict)
# Internal
_app: FastAPI | None = None
_server: uvicorn.Server | None = None
_thread: threading.Thread | None = None
_ready: threading.Event = field(default_factory=threading.Event)
def __post_init__(self):
if self.port == 0:
self.port = self._find_free_port()
def _find_free_port(self) -> int:
"""Find an available port."""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("", 0))
return s.getsockname()[1]
@property
def base_url(self) -> str:
"""URL to use as GitHub API base."""
return f"http://{self.host}:{self.port}"
def configure_repo(
self,
full_name: str,
repo_id: int = 1,
private: bool = False,
) -> "MockGitHubService":
"""Configure a repository."""
owner, name = full_name.split("/")
self.repos[full_name] = {
"id": repo_id,
"name": name,
"full_name": full_name,
"private": private,
"owner": {"login": owner, "id": 1},
}
return self
def configure_issue(
self,
full_name: str,
number: int,
title: str = "Test Issue",
body: str = "Test body",
user_login: str = "testuser",
) -> "MockGitHubService":
"""Configure an issue."""
key = f"{full_name}/{number}"
self.issues[key] = {
"id": number,
"number": number,
"title": title,
"body": body,
"user": {"login": user_login, "id": 1},
}
return self
def configure_pull_request(
self,
full_name: str,
number: int,
title: str = "Test PR",
body: str = "Test body",
user_login: str = "testuser",
) -> "MockGitHubService":
"""Configure a pull request."""
key = f"{full_name}/{number}"
self.pull_requests[key] = {
"id": number,
"number": number,
"title": title,
"body": body,
"user": {"login": user_login, "id": 1},
}
return self
def configure_comment(
self,
full_name: str,
comment_id: int,
body: str = "Test comment",
user_login: str = "testuser",
) -> "MockGitHubService":
"""Configure an issue comment."""
key = f"{full_name}/{comment_id}"
self.issue_comments[key] = {
"id": comment_id,
"body": body,
"user": {"login": user_login, "id": 1},
}
return self
def start(self) -> "MockGitHubService":
"""Start the mock server in a background thread."""
self._create_app()
config = uvicorn.Config(
self._app,
host=self.host,
port=self.port,
log_level="error",
access_log=False,
)
self._server = uvicorn.Server(config)
self._thread = threading.Thread(target=self._run, daemon=True)
self._thread.start()
# Wait for server to be ready by polling
import requests
start_time = time.time()
while time.time() - start_time < 10:
try:
resp = requests.get(f"{self.base_url}/_test/state", timeout=0.5)
if resp.status_code == 200:
self._ready.set()
break
except requests.exceptions.RequestException:
time.sleep(0.1)
if not self._ready.is_set():
raise RuntimeError("Mock GitHub server failed to start")
return self
def stop(self):
"""Stop the mock server."""
if self._server:
self._server.should_exit = True
if self._thread:
self._thread.join(timeout=5)
def reset(self):
"""Reset tracked state (but keep configuration)."""
self.comments.clear()
self.reactions.clear()
self.api_calls.clear()
def _run(self):
"""Run the server (called in background thread)."""
if self._server:
self._server.run()
def _create_app(self):
"""Create the FastAPI app with GitHub API endpoints."""
app = FastAPI()
service = self # Capture reference for closures
# Middleware to log all API calls
@app.middleware("http")
async def log_requests(request: Request, call_next):
body = None
if request.method in ("POST", "PUT", "PATCH"):
body = await request.body()
try:
body = json.loads(body)
except (json.JSONDecodeError, UnicodeDecodeError):
body = body.decode() if isinstance(body, bytes) else body
service.api_calls.append(
{
"method": request.method,
"path": request.url.path,
"body": body,
}
)
response = await call_next(request)
return response
# Repository endpoints
@app.get("/repos/{owner}/{repo}")
async def get_repo(owner: str, repo: str, request: Request):
full_name = f"{owner}/{repo}"
if full_name in service.repos:
data = service.repos[full_name].copy()
data["url"] = str(request.url)
return JSONResponse(data)
# Return default repo
return JSONResponse(
{
"id": 1,
"name": repo,
"full_name": full_name,
"url": str(request.url),
"owner": {"login": owner, "id": 1},
}
)
# Issue endpoints
@app.get("/repos/{owner}/{repo}/issues/{issue_number}")
async def get_issue(
owner: str, repo: str, issue_number: int, request: Request
):
key = f"{owner}/{repo}/{issue_number}"
if key in service.issues:
data = service.issues[key].copy()
data["url"] = str(request.url)
return JSONResponse(data)
return JSONResponse(
{
"id": issue_number,
"number": issue_number,
"title": "Test Issue",
"body": "Test body",
"url": str(request.url),
"user": {"login": "testuser", "id": 1},
}
)
# Pull request endpoints
@app.get("/repos/{owner}/{repo}/pulls/{pr_number}")
async def get_pull(owner: str, repo: str, pr_number: int, request: Request):
key = f"{owner}/{repo}/{pr_number}"
if key in service.pull_requests:
data = service.pull_requests[key].copy()
data["url"] = str(request.url)
return JSONResponse(data)
return JSONResponse(
{
"id": pr_number,
"number": pr_number,
"title": "Test PR",
"body": "Test body",
"url": str(request.url),
"user": {"login": "testuser", "id": 1},
}
)
# Comment endpoints
@app.get("/repos/{owner}/{repo}/issues/comments/{comment_id}")
async def get_issue_comment(
owner: str, repo: str, comment_id: int, request: Request
):
key = f"{owner}/{repo}/{comment_id}"
if key in service.issue_comments:
data = service.issue_comments[key].copy()
data["url"] = str(request.url)
return JSONResponse(data)
return JSONResponse(
{
"id": comment_id,
"body": "Test comment",
"url": str(request.url),
"user": {"login": "testuser", "id": 1},
}
)
@app.get("/repos/{owner}/{repo}/issues/{issue_number}/comments")
async def list_issue_comments(
owner: str, repo: str, issue_number: int, request: Request
):
# Return empty list by default
return JSONResponse([])
@app.post("/repos/{owner}/{repo}/issues/{issue_number}/comments")
async def create_issue_comment(
owner: str, repo: str, issue_number: int, request: Request
):
body = await request.json()
comment_data = {
"repo": f"{owner}/{repo}",
"issue_number": issue_number,
"body": body.get("body", ""),
}
service.comments.append(comment_data)
return JSONResponse(
{
"id": len(service.comments),
"body": body.get("body", ""),
"url": str(request.url),
"user": {"login": "bot", "id": 1},
},
status_code=201,
)
# Reaction endpoints
@app.post("/repos/{owner}/{repo}/issues/comments/{comment_id}/reactions")
async def create_comment_reaction(
owner: str, repo: str, comment_id: int, request: Request
):
body = await request.json()
reaction_data = {
"repo": f"{owner}/{repo}",
"comment_id": comment_id,
"content": body.get("content", ""),
}
service.reactions.append(reaction_data)
return JSONResponse(
{
"id": len(service.reactions),
"content": body.get("content", ""),
"url": str(request.url),
},
status_code=201,
)
@app.post("/repos/{owner}/{repo}/issues/{issue_number}/reactions")
async def create_issue_reaction(
owner: str, repo: str, issue_number: int, request: Request
):
body = await request.json()
reaction_data = {
"repo": f"{owner}/{repo}",
"issue_number": issue_number,
"content": body.get("content", ""),
}
service.reactions.append(reaction_data)
return JSONResponse(
{
"id": len(service.reactions),
"content": body.get("content", ""),
"url": str(request.url),
},
status_code=201,
)
# PR review comment reply
@app.post(
"/repos/{owner}/{repo}/pulls/{pr_number}/comments/{comment_id}/replies"
)
async def create_review_comment_reply(
owner: str, repo: str, pr_number: int, comment_id: int, request: Request
):
body = await request.json()
comment_data = {
"repo": f"{owner}/{repo}",
"pr_number": pr_number,
"reply_to_comment_id": comment_id,
"body": body.get("body", ""),
}
service.comments.append(comment_data)
return JSONResponse(
{
"id": len(service.comments),
"body": body.get("body", ""),
"url": str(request.url),
"user": {"login": "bot", "id": 1},
},
status_code=201,
)
# PR comment (not review comment)
@app.post("/repos/{owner}/{repo}/issues/{pr_number}/comments")
async def create_pr_comment(
owner: str, repo: str, pr_number: int, request: Request
):
# PRs can also receive issue comments
return await create_issue_comment(owner, repo, pr_number, request)
# Test endpoint to query state
@app.get("/_test/state")
async def get_state():
return JSONResponse(
{
"comments": service.comments,
"reactions": service.reactions,
"api_calls": service.api_calls,
}
)
@app.post("/_test/reset")
async def reset_state():
service.reset()
return JSONResponse({"status": "ok"})
self._app = app
# Assertion helpers
def get_comments(self) -> list[dict[str, Any]]:
"""Get all comments that were created."""
return self.comments.copy()
def get_reactions(self) -> list[dict[str, Any]]:
"""Get all reactions that were created."""
return self.reactions.copy()
def get_api_calls(self) -> list[dict[str, Any]]:
"""Get all API calls that were made."""
return self.api_calls.copy()
def assert_comment_sent(self, body_contains: str) -> dict[str, Any]:
"""Assert that a comment containing the given text was sent."""
for comment in self.comments:
if body_contains in comment.get("body", ""):
return comment
raise AssertionError(
f"No comment containing '{body_contains}' was sent.\n"
f"Comments sent: {[c.get('body', '')[:50] for c in self.comments]}"
)
def assert_reaction_added(self, content: str) -> dict[str, Any]:
"""Assert that a reaction with the given content was added."""
for reaction in self.reactions:
if reaction.get("content") == content:
return reaction
raise AssertionError(
f"No '{content}' reaction was added.\n"
f"Reactions: {[r.get('content') for r in self.reactions]}"
)
def assert_no_comments(self):
"""Assert that no comments were sent."""
if self.comments:
raise AssertionError(
f"Expected no comments, but {len(self.comments)} were sent:\n"
f"{[c.get('body', '')[:50] for c in self.comments]}"
)
def wait_for_comment(self, body_contains: str, timeout: float = 5.0) -> dict:
"""Wait for a comment containing the given text."""
start = time.time()
while time.time() - start < timeout:
for comment in self.comments:
if body_contains in comment.get("body", ""):
return comment
time.sleep(0.1)
raise TimeoutError(
f"Timed out waiting for comment containing '{body_contains}'"
)
def wait_for_reaction(self, content: str, timeout: float = 5.0) -> dict:
"""Wait for a reaction with the given content."""
start = time.time()
while time.time() - start < timeout:
for reaction in self.reactions:
if reaction.get("content") == content:
return reaction
time.sleep(0.1)
raise TimeoutError(f"Timed out waiting for '{content}' reaction")

View File

@@ -0,0 +1,217 @@
"""TestLLM - A mock LLM for testing V1 GitHub Resolver.
This is a simplified version of the TestLLM from openhands.sdk.testing
that returns scripted responses without making real LLM API calls.
"""
from collections import deque
from typing import Any, ClassVar, Sequence
from litellm.types.utils import Choices, ModelResponse
from litellm.types.utils import Message as LiteLLMMessage
from pydantic import ConfigDict, Field, PrivateAttr
from openhands.sdk.llm.llm import LLM
from openhands.sdk.llm.llm_response import LLMResponse
from openhands.sdk.llm.message import Message, TextContent
from openhands.sdk.llm.streaming import TokenCallbackType
from openhands.sdk.llm.utils.metrics import MetricsSnapshot, TokenUsage
from openhands.sdk.tool.tool import ToolDefinition
__all__ = ['TestLLM', 'TestLLMExhaustedError']
class TestLLMExhaustedError(Exception):
"""Raised when TestLLM has no more scripted responses."""
pass
class TestLLM(LLM):
"""A mock LLM for testing that returns scripted responses.
TestLLM is a real LLM subclass that can be used anywhere an LLM is accepted.
It returns pre-scripted responses without making any API calls.
"""
# Prevent pytest from collecting this class as a test
__test__ = False
model: str = Field(default='test-model')
_scripted_responses: deque[Message | Exception] = PrivateAttr(default_factory=deque)
_call_count: int = PrivateAttr(default=0)
model_config: ClassVar[ConfigDict] = ConfigDict(
extra='ignore', arbitrary_types_allowed=True
)
def __init__(self, **data: Any) -> None:
# Extract scripted_responses before calling super().__init__
scripted_responses = data.pop('scripted_responses', [])
super().__init__(**data)
self._scripted_responses = deque(list(scripted_responses))
self._call_count = 0
@classmethod
def from_messages(
cls,
messages: list[Message | Exception],
*,
model: str = 'test-model',
usage_id: str = 'test-llm',
**kwargs: Any,
) -> 'TestLLM':
"""Create a TestLLM with scripted responses.
Args:
messages: List of Message or Exception objects to return in order.
model: Model name (default: "test-model")
usage_id: Usage ID for metrics (default: "test-llm")
**kwargs: Additional LLM configuration options
Returns:
A TestLLM instance configured with the scripted responses.
"""
return cls(
model=model,
usage_id=usage_id,
scripted_responses=messages,
**kwargs,
)
def completion(
self,
messages: list[Message],
tools: Sequence[ToolDefinition] | None = None,
_return_metrics: bool = False,
add_security_risk_prediction: bool = False,
on_token: TokenCallbackType | None = None,
**kwargs: Any,
) -> LLMResponse:
"""Return the next scripted response.
Args:
messages: Input messages (ignored)
tools: Available tools (ignored)
_return_metrics: Whether to return metrics (ignored)
add_security_risk_prediction: Add security risk field (ignored)
on_token: Streaming callback (ignored)
**kwargs: Additional arguments (ignored)
Returns:
LLMResponse containing the next scripted message.
Raises:
TestLLMExhaustedError: When no more scripted responses are available.
"""
if not self._scripted_responses:
raise TestLLMExhaustedError(
f'TestLLM: no more scripted responses '
f'(exhausted after {self._call_count} calls)'
)
item = self._scripted_responses.popleft()
self._call_count += 1
# Raise scripted exceptions
if isinstance(item, Exception):
raise item
message = item
# Create a minimal ModelResponse for raw_response
raw_response = self._create_model_response(message)
return LLMResponse(
message=message,
metrics=self._zero_metrics(),
raw_response=raw_response,
)
def responses(
self,
messages: list[Message],
tools: Sequence[ToolDefinition] | None = None,
include: list[str] | None = None,
store: bool | None = None,
_return_metrics: bool = False,
add_security_risk_prediction: bool = False,
on_token: TokenCallbackType | None = None,
**kwargs: Any,
) -> LLMResponse:
"""Return the next scripted response (delegates to completion)."""
return self.completion(
messages=messages,
tools=tools,
_return_metrics=_return_metrics,
add_security_risk_prediction=add_security_risk_prediction,
on_token=on_token,
**kwargs,
)
def uses_responses_api(self) -> bool:
"""TestLLM always uses the completion path."""
return False
def _zero_metrics(self) -> MetricsSnapshot:
"""Return a zero-cost metrics snapshot."""
return MetricsSnapshot(
model_name=self.model,
accumulated_cost=0.0,
max_budget_per_task=None,
accumulated_token_usage=TokenUsage(
model=self.model,
prompt_tokens=0,
completion_tokens=0,
),
)
def _create_model_response(self, message: Message) -> ModelResponse:
"""Create a minimal ModelResponse from a Message."""
# Build the LiteLLM message dict
litellm_message_dict: dict[str, Any] = {
'role': message.role,
'content': self._content_to_string(message),
}
# Add tool_calls if present
if message.tool_calls:
litellm_message_dict['tool_calls'] = [
{
'id': tc.id,
'type': 'function',
'function': {
'name': tc.name,
'arguments': tc.arguments,
},
}
for tc in message.tool_calls
]
litellm_message = LiteLLMMessage(**litellm_message_dict)
return ModelResponse(
id=f'test-response-{self._call_count}',
choices=[Choices(message=litellm_message, index=0, finish_reason='stop')],
created=0,
model=self.model,
object='chat.completion',
)
def _content_to_string(self, message: Message) -> str:
"""Convert message content to a string."""
parts = []
for item in message.content:
if isinstance(item, TextContent):
parts.append(item.text)
return '\n'.join(parts)
@property
def remaining_responses(self) -> int:
"""Return the number of remaining scripted responses."""
return len(self._scripted_responses)
@property
def call_count(self) -> int:
"""Return the number of calls made to this TestLLM."""
return self._call_count

View File

@@ -0,0 +1,200 @@
"""
Integration test for V1 GitHub Resolver webhook flow.
This test verifies:
1. Webhook triggers agent server creation
2. "I'm on it" message is sent to GitHub
3. Eyes reaction is added to acknowledge the request
Uses MockGitHubService for real HTTP calls to a mock GitHub API.
"""
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from .conftest import (
TEST_GITHUB_USER_ID,
TEST_GITHUB_USERNAME,
create_issue_comment_payload,
)
class TestV1GitHubResolverE2E:
"""E2E test for V1 GitHub Resolver with MockGitHubService."""
@pytest.mark.asyncio
async def test_webhook_flow_with_mock_github_service(
self, patched_session_maker, mock_keycloak, mock_github_service
):
"""
E2E test: Webhook → Agent Server → Real HTTP calls to MockGitHubService.
This test:
1. Receives a GitHub webhook payload
2. Routes to V1 path (v1_enabled=True)
3. Starts agent server via start_app_conversation
4. Makes REAL HTTP calls to MockGitHubService
5. Verifies "I'm on it" and eyes reaction via service state
"""
from openhands.app_server.app_conversation.app_conversation_models import (
AppConversationStartTask,
AppConversationStartTaskStatus,
)
# Configure the mock GitHub service
mock_github_service.configure_repo('test-owner/test-repo')
mock_github_service.configure_issue(
'test-owner/test-repo',
number=1,
title='Test Issue',
body='This is a test issue',
)
mock_github_service.configure_comment(
'test-owner/test-repo',
comment_id=12345,
body='@openhands please fix this bug',
)
# Create webhook payload
payload = create_issue_comment_payload(
comment_body='@openhands please fix this bug',
sender_id=TEST_GITHUB_USER_ID,
sender_login=TEST_GITHUB_USERNAME,
)
# Track agent server start
agent_started = asyncio.Event()
captured_request = None
# Mock start_app_conversation to simulate agent server
async def mock_start_app_conversation(request):
from uuid import uuid4
nonlocal captured_request
captured_request = request
agent_started.set()
task_id = uuid4()
conv_id = uuid4()
yield AppConversationStartTask(
id=task_id,
created_by_user_id='test-user',
status=AppConversationStartTaskStatus.WORKING,
request=request,
)
await asyncio.sleep(0.1)
yield AppConversationStartTask(
id=task_id,
created_by_user_id='test-user',
status=AppConversationStartTaskStatus.READY,
app_conversation_id=conv_id,
request=request,
)
# Mock GithubServiceImpl (for fetching issue details)
mock_github_service_impl = MagicMock()
mock_github_service_impl.get_issue_or_pr_comments = AsyncMock(return_value=[])
mock_github_service_impl.get_issue_or_pr_title_and_body = AsyncMock(
return_value=('Test Issue', 'This is a test issue body')
)
mock_github_service_impl.get_review_thread_comments = AsyncMock(return_value=[])
# Mock app conversation service
mock_app_service = MagicMock()
mock_app_service.start_app_conversation = mock_start_app_conversation
with patch(
'integrations.github.github_view.get_user_v1_enabled_setting',
return_value=True,
), patch(
'integrations.github.github_view.get_app_conversation_service'
) as mock_get_service, patch(
'github.GithubIntegration'
) as mock_integration, patch(
'integrations.github.github_solvability.summarize_issue_solvability',
new_callable=AsyncMock,
return_value=None,
), patch(
'server.auth.token_manager.TokenManager.get_idp_token_from_idp_user_id',
new_callable=AsyncMock,
return_value='mock-token',
), patch(
'integrations.v1_utils.get_saas_user_auth',
new_callable=AsyncMock,
) as mock_saas_auth, patch(
'integrations.github.github_view.GithubServiceImpl',
return_value=mock_github_service_impl,
):
# Setup mock service context
mock_context = MagicMock()
mock_context.__aenter__ = AsyncMock(return_value=mock_app_service)
mock_context.__aexit__ = AsyncMock(return_value=None)
mock_get_service.return_value = mock_context
# Setup user auth
mock_user_auth = MagicMock()
mock_user_auth.get_provider_tokens = AsyncMock(
return_value={'github': 'mock-token'}
)
mock_saas_auth.return_value = mock_user_auth
# Setup GitHub integration
mock_token = MagicMock()
mock_token.token = 'test-installation-token'
mock_integration.return_value.get_access_token.return_value = mock_token
# Run the test
from integrations.github.github_manager import GithubManager
from integrations.models import Message, SourceType
from server.auth.token_manager import TokenManager
token_manager = TokenManager()
token_manager.load_org_token = MagicMock(return_value='mock-token')
data_collector = MagicMock()
data_collector.process_payload = MagicMock()
data_collector.fetch_issue_details = AsyncMock(
return_value={'description': 'Test', 'previous_comments': []}
)
data_collector.save_data = AsyncMock()
manager = GithubManager(token_manager, data_collector)
manager.github_integration = mock_integration.return_value
# Send webhook
message = Message(
source=SourceType.GITHUB,
message={
'payload': payload,
'installation': payload['installation']['id'],
},
)
await manager.receive_message(message)
# Wait for agent to start
await asyncio.wait_for(agent_started.wait(), timeout=10.0)
# Give time for GitHub API calls to complete
await asyncio.sleep(0.5)
# Verify via MockGitHubService state (no more async events!)
assert agent_started.is_set(), 'Agent server should start'
assert captured_request is not None
assert captured_request.selected_repository == 'test-owner/test-repo'
# Verify GitHub API calls were made
mock_github_service.assert_comment_sent("I'm on it")
mock_github_service.assert_reaction_added('eyes')
# Print verification
comments = mock_github_service.get_comments()
reactions = mock_github_service.get_reactions()
print('✅ Agent server started')
print(f'"I\'m on it" message sent: {comments[0]["body"][:60]}...')
print(f'✅ Eyes reaction added: {reactions[0]["content"]}')

View File

@@ -15,6 +15,7 @@ from storage.device_code import DeviceCode # noqa: F401
from storage.feedback import Feedback
from storage.github_app_installation import GithubAppInstallation
from storage.org import Org
from storage.org_invitation import OrgInvitation # noqa: F401
from storage.org_member import OrgMember
from storage.role import Role
from storage.stored_conversation_metadata import StoredConversationMetadata

View File

@@ -126,3 +126,24 @@ def test_run_agent_variant_tests_v1_calls_handler_and_sets_system_prompt(monkeyp
# Should be a different instance than the original (copied after handler runs)
assert result is not agent
assert result.system_prompt_filename == 'system_prompt_long_horizon.j2'
@patch('experiments.experiment_manager.ENABLE_EXPERIMENT_MANAGER', True)
@patch('experiments.experiment_manager.EXPERIMENT_SYSTEM_PROMPT_EXPERIMENT', True)
def test_run_agent_variant_tests_v1_preserves_planning_agent_system_prompt():
"""Planning agents should retain their specialized system prompt and not be overwritten by the experiment."""
# Arrange
planning_agent = make_agent().model_copy(
update={'system_prompt_filename': 'system_prompt_planning.j2'}
)
conv_id = uuid4()
# Act
result: Agent = SaaSExperimentManager.run_agent_variant_tests__v1(
user_id='user-planning',
conversation_id=conv_id,
agent=planning_agent,
)
# Assert
assert result.system_prompt_filename == 'system_prompt_planning.j2'

View File

@@ -6,6 +6,8 @@ import httpx
import pytest
from fastapi import HTTPException
from server.routes.api_keys import (
ByorPermittedResponse,
LlmApiKeyResponse,
check_byor_permitted,
delete_byor_key_from_litellm,
get_llm_api_key_for_byor,
@@ -203,7 +205,7 @@ class TestGetLlmApiKeyForByor:
result = await get_llm_api_key_for_byor(user_id=user_id)
# Assert
assert result == {'key': new_key}
assert result == LlmApiKeyResponse(key=new_key)
mock_check_enabled.assert_called_once_with(user_id)
mock_get_key.assert_called_once_with(user_id)
mock_generate_key.assert_called_once_with(user_id)
@@ -228,7 +230,7 @@ class TestGetLlmApiKeyForByor:
result = await get_llm_api_key_for_byor(user_id=user_id)
# Assert
assert result == {'key': existing_key}
assert result == LlmApiKeyResponse(key=existing_key)
mock_check_enabled.assert_called_once_with(user_id)
mock_get_key.assert_called_once_with(user_id)
mock_verify_key.assert_called_once_with(existing_key, user_id)
@@ -265,7 +267,7 @@ class TestGetLlmApiKeyForByor:
result = await get_llm_api_key_for_byor(user_id=user_id)
# Assert
assert result == {'key': new_key}
assert result == LlmApiKeyResponse(key=new_key)
mock_check_enabled.assert_called_once_with(user_id)
mock_get_key.assert_called_once_with(user_id)
mock_verify_key.assert_called_once_with(invalid_key, user_id)
@@ -305,7 +307,7 @@ class TestGetLlmApiKeyForByor:
result = await get_llm_api_key_for_byor(user_id=user_id)
# Assert
assert result == {'key': new_key}
assert result == LlmApiKeyResponse(key=new_key)
mock_check_enabled.assert_called_once_with(user_id)
mock_delete_key.assert_called_once_with(user_id, invalid_key)
mock_generate_key.assert_called_once_with(user_id)
@@ -478,7 +480,7 @@ class TestCheckByorPermitted:
result = await check_byor_permitted(user_id=user_id)
# Assert
assert result == {'permitted': True}
assert result == ByorPermittedResponse(permitted=True)
mock_check_enabled.assert_called_once_with(user_id)
@pytest.mark.asyncio
@@ -493,7 +495,7 @@ class TestCheckByorPermitted:
result = await check_byor_permitted(user_id=user_id)
# Assert
assert result == {'permitted': False}
assert result == ByorPermittedResponse(permitted=False)
mock_check_enabled.assert_called_once_with(user_id)
@pytest.mark.asyncio

View File

@@ -1220,3 +1220,60 @@ async def test_validate_workspace_update_permissions_no_current_link(mock_manage
result = await _validate_workspace_update_permissions('user1', 'test-workspace')
assert result == mock_workspace
# Tests for OAuth URL encoding
class TestJiraDcOAuthUrlEncoding:
"""Tests to verify OAuth authorization URLs are properly URL-encoded."""
@pytest.mark.asyncio
@patch('server.routes.integration.jira_dc.get_user_auth')
@patch('server.routes.integration.jira_dc.redis_client')
@patch('server.routes.integration.jira_dc.JIRA_DC_ENABLE_OAUTH', True)
async def test_create_jira_dc_workspace_url_encoding(
self, mock_redis, mock_get_auth, mock_request, mock_user_auth
):
"""Test that create_jira_dc_workspace properly URL-encodes the authorization URL."""
mock_get_auth.return_value = mock_user_auth
mock_redis.setex.return_value = True
workspace_data = JiraDcWorkspaceCreate(
workspace_name='test-workspace',
webhook_secret='secret',
svc_acc_email='svc@test.com',
svc_acc_api_key='key',
is_active=True,
)
response = await create_jira_dc_workspace(mock_request, workspace_data)
content = json.loads(response.body)
auth_url = content['authorizationUrl']
# Verify no raw spaces in the URL (spaces should be encoded as + or %20)
assert ' ' not in auth_url
# Verify scope parameter contains encoded scopes (+ is valid URL encoding for space)
assert 'scope=read%3Ame+read%3Ajira-user+read%3Ajira-work' in auth_url
# Verify redirect_uri is properly encoded
assert 'redirect_uri=https%3A%2F%2F' in auth_url
@pytest.mark.asyncio
@patch('server.routes.integration.jira_dc.get_user_auth')
@patch('server.routes.integration.jira_dc.redis_client')
@patch('server.routes.integration.jira_dc.JIRA_DC_ENABLE_OAUTH', True)
async def test_create_workspace_link_url_encoding(
self, mock_redis, mock_get_auth, mock_request, mock_user_auth
):
"""Test that create_workspace_link properly URL-encodes the authorization URL."""
mock_get_auth.return_value = mock_user_auth
mock_redis.setex.return_value = True
link_data = JiraDcLinkCreate(workspace_name='test-workspace')
response = await create_workspace_link(mock_request, link_data)
content = json.loads(response.body)
auth_url = content['authorizationUrl']
# Verify no raw spaces in the URL (spaces should be encoded as + or %20)
assert ' ' not in auth_url
# Verify scope parameter contains encoded scopes (+ is valid URL encoding for space)
assert 'scope=read%3Ame+read%3Ajira-user+read%3Ajira-work' in auth_url
# Verify redirect_uri is properly encoded
assert 'redirect_uri=https%3A%2F%2F' in auth_url

View File

@@ -1323,3 +1323,58 @@ async def test_validate_workspace_update_permissions_no_current_link(mock_manage
result = await _validate_workspace_update_permissions('user1', 'test-workspace')
assert result == mock_workspace
# Tests for OAuth URL encoding
class TestJiraOAuthUrlEncoding:
"""Tests to verify OAuth authorization URLs are properly URL-encoded."""
@pytest.mark.asyncio
@patch('server.routes.integration.jira.get_user_auth')
@patch('server.routes.integration.jira.redis_client')
async def test_create_jira_workspace_url_encoding(
self, mock_redis, mock_get_auth, mock_request, mock_user_auth
):
"""Test that create_jira_workspace properly URL-encodes the authorization URL."""
mock_get_auth.return_value = mock_user_auth
mock_redis.setex.return_value = True
workspace_data = JiraWorkspaceCreate(
workspace_name='test-workspace',
webhook_secret='secret',
svc_acc_email='svc@test.com',
svc_acc_api_key='key',
is_active=True,
)
response = await create_jira_workspace(mock_request, workspace_data)
content = json.loads(response.body)
auth_url = content['authorizationUrl']
# Verify no raw spaces in the URL (spaces should be encoded as + or %20)
assert ' ' not in auth_url
# Verify scope parameter contains encoded scopes (+ is valid URL encoding for space)
assert 'scope=read%3Ame+read%3Ajira-user+read%3Ajira-work' in auth_url
# Verify redirect_uri is properly encoded
assert 'redirect_uri=https%3A%2F%2F' in auth_url
@pytest.mark.asyncio
@patch('server.routes.integration.jira.get_user_auth')
@patch('server.routes.integration.jira.redis_client')
async def test_create_workspace_link_url_encoding(
self, mock_redis, mock_get_auth, mock_request, mock_user_auth
):
"""Test that create_workspace_link properly URL-encodes the authorization URL."""
mock_get_auth.return_value = mock_user_auth
mock_redis.setex.return_value = True
link_data = JiraLinkCreate(workspace_name='test-workspace')
response = await create_workspace_link(mock_request, link_data)
content = json.loads(response.body)
auth_url = content['authorizationUrl']
# Verify no raw spaces in the URL (spaces should be encoded as + or %20)
assert ' ' not in auth_url
# Verify scope parameter contains encoded scopes (+ is valid URL encoding for space)
assert 'scope=read%3Ame+read%3Ajira-user+read%3Ajira-work' in auth_url
# Verify redirect_uri is properly encoded
assert 'redirect_uri=https%3A%2F%2F' in auth_url

File diff suppressed because it is too large Load Diff

View File

@@ -175,11 +175,12 @@ class TestOrgMemberServiceGetOrgMembers:
assert data is not None
assert isinstance(data, OrgMemberPage)
assert len(data.items) == 1
assert data.next_page_id is None
assert data.current_page == 1
assert data.per_page == 100
assert data.items[0].user_id == str(current_user_id)
assert data.items[0].email == 'test@example.com'
assert data.items[0].role_id == 1
assert data.items[0].role_name == 'owner'
assert data.items[0].role == 'owner'
assert data.items[0].role_rank == 10
assert data.items[0].status == 'active'
@@ -282,9 +283,9 @@ class TestOrgMemberServiceGetOrgMembers:
# Assert
assert success is True
assert data is not None
assert data.next_page_id is None
assert data.current_page == 1
mock_get_paginated.assert_called_once_with(
org_id=org_id, offset=0, limit=100
org_id=org_id, offset=0, limit=100, email_filter=None
)
@pytest.mark.asyncio
@@ -316,9 +317,9 @@ class TestOrgMemberServiceGetOrgMembers:
# Assert
assert success is True
assert data is not None
assert data.next_page_id == '150' # offset (100) + limit (50)
assert data.current_page == 3 # offset (100) / limit (50) + 1
mock_get_paginated.assert_called_once_with(
org_id=org_id, offset=100, limit=50
org_id=org_id, offset=100, limit=50, email_filter=None
)
@pytest.mark.asyncio
@@ -350,7 +351,7 @@ class TestOrgMemberServiceGetOrgMembers:
# Assert
assert success is True
assert data is not None
assert data.next_page_id is None
assert data.current_page == 3
@pytest.mark.asyncio
async def test_empty_organization_no_members(
@@ -382,7 +383,6 @@ class TestOrgMemberServiceGetOrgMembers:
assert success is True
assert data is not None
assert len(data.items) == 0
assert data.next_page_id is None
@pytest.mark.asyncio
async def test_missing_user_relationship_handles_gracefully(
@@ -462,7 +462,7 @@ class TestOrgMemberServiceGetOrgMembers:
assert success is True
assert data is not None
assert len(data.items) == 1
assert data.items[0].role_name == ''
assert data.items[0].role == ''
assert data.items[0].role_rank == 0
@pytest.mark.asyncio
@@ -512,6 +512,156 @@ class TestOrgMemberServiceGetOrgMembers:
assert data is not None
assert len(data.items) == 2
@pytest.mark.asyncio
async def test_email_filter_passed_to_store(
self, org_id, current_user_id, mock_org_member, requester_membership_owner
):
"""Test that email filter is passed to store methods."""
# Arrange
with (
patch(
'server.services.org_member_service.OrgMemberStore.get_org_member'
) as mock_get_member,
patch(
'server.services.org_member_service.OrgMemberStore.get_org_members_paginated',
new_callable=AsyncMock,
) as mock_get_paginated,
):
mock_get_member.return_value = requester_membership_owner
mock_get_paginated.return_value = ([mock_org_member], False)
# Act
await OrgMemberService.get_org_members(
org_id=org_id,
current_user_id=current_user_id,
page_id=None,
limit=10,
email_filter='alice',
)
# Assert
mock_get_paginated.assert_called_once_with(
org_id=org_id, offset=0, limit=10, email_filter='alice'
)
@pytest.mark.asyncio
async def test_pagination_metadata_correct_for_page_2(
self, org_id, current_user_id, mock_org_member, requester_membership_owner
):
"""Test pagination metadata is correct for page 2."""
# Arrange
with (
patch(
'server.services.org_member_service.OrgMemberStore.get_org_member'
) as mock_get_member,
patch(
'server.services.org_member_service.OrgMemberStore.get_org_members_paginated',
new_callable=AsyncMock,
) as mock_get_paginated,
):
mock_get_member.return_value = requester_membership_owner
mock_get_paginated.return_value = ([mock_org_member], True)
# Act - Request page 2 (offset 10) with limit 10
success, error_code, data = await OrgMemberService.get_org_members(
org_id=org_id,
current_user_id=current_user_id,
page_id='10',
limit=10,
)
# Assert
assert success is True
assert data is not None
assert data.current_page == 2
assert data.per_page == 10
class TestOrgMemberServiceGetOrgMembersCount:
"""Test cases for OrgMemberService.get_org_members_count."""
@pytest.fixture
def requester_membership(self, org_id, current_user_id):
"""Create a mock requester membership."""
membership = MagicMock(spec=OrgMember)
membership.org_id = org_id
membership.user_id = current_user_id
membership.role_id = 1
return membership
@pytest.mark.asyncio
async def test_count_succeeds_returns_count(
self, org_id, current_user_id, requester_membership
):
"""Test that successful count returns the member count."""
# Arrange
with (
patch(
'server.services.org_member_service.OrgMemberStore.get_org_member'
) as mock_get_member,
patch(
'server.services.org_member_service.OrgMemberStore.get_org_members_count',
new_callable=AsyncMock,
) as mock_get_count,
):
mock_get_member.return_value = requester_membership
mock_get_count.return_value = 42
# Act
count = await OrgMemberService.get_org_members_count(
org_id=org_id,
current_user_id=current_user_id,
)
# Assert
assert count == 42
mock_get_count.assert_called_once_with(org_id=org_id, email_filter=None)
@pytest.mark.asyncio
async def test_count_with_email_filter(
self, org_id, current_user_id, requester_membership
):
"""Test that email filter is passed to store method."""
# Arrange
with (
patch(
'server.services.org_member_service.OrgMemberStore.get_org_member'
) as mock_get_member,
patch(
'server.services.org_member_service.OrgMemberStore.get_org_members_count',
new_callable=AsyncMock,
) as mock_get_count,
):
mock_get_member.return_value = requester_membership
mock_get_count.return_value = 5
# Act
count = await OrgMemberService.get_org_members_count(
org_id=org_id,
current_user_id=current_user_id,
email_filter='alice',
)
# Assert
assert count == 5
mock_get_count.assert_called_once_with(org_id=org_id, email_filter='alice')
@pytest.mark.asyncio
async def test_not_a_member_raises_error(self, org_id, current_user_id):
"""Test that non-member raises OrgMemberNotFoundError."""
# Arrange
with patch(
'server.services.org_member_service.OrgMemberStore.get_org_member'
) as mock_get_member:
mock_get_member.return_value = None
# Act & Assert
with pytest.raises(OrgMemberNotFoundError):
await OrgMemberService.get_org_members_count(
org_id=org_id,
current_user_id=current_user_id,
)
@pytest.fixture
def target_membership_owner(org_id, target_user_id, owner_role):
@@ -549,6 +699,9 @@ class TestOrgMemberServiceRemoveOrgMember:
patch(
'server.services.org_member_service.OrgMemberStore.remove_user_from_org'
) as mock_remove,
patch(
'server.services.org_member_service.UserStore.get_user_by_id'
) as mock_get_user,
):
mock_get_member.side_effect = [
requester_membership_owner,
@@ -556,6 +709,7 @@ class TestOrgMemberServiceRemoveOrgMember:
]
mock_get_role.side_effect = [owner_role, member_role]
mock_remove.return_value = True
mock_get_user.return_value = None
# Act
success, error = await OrgMemberService.remove_org_member(
@@ -590,6 +744,9 @@ class TestOrgMemberServiceRemoveOrgMember:
patch(
'server.services.org_member_service.OrgMemberStore.remove_user_from_org'
) as mock_remove,
patch(
'server.services.org_member_service.UserStore.get_user_by_id'
) as mock_get_user,
):
mock_get_member.side_effect = [
requester_membership_owner,
@@ -597,6 +754,7 @@ class TestOrgMemberServiceRemoveOrgMember:
]
mock_get_role.side_effect = [owner_role, admin_role]
mock_remove.return_value = True
mock_get_user.return_value = None
# Act
success, error = await OrgMemberService.remove_org_member(
@@ -630,6 +788,9 @@ class TestOrgMemberServiceRemoveOrgMember:
patch(
'server.services.org_member_service.OrgMemberStore.remove_user_from_org'
) as mock_remove,
patch(
'server.services.org_member_service.UserStore.get_user_by_id'
) as mock_get_user,
):
mock_get_member.side_effect = [
requester_membership_admin,
@@ -637,6 +798,7 @@ class TestOrgMemberServiceRemoveOrgMember:
]
mock_get_role.side_effect = [admin_role, member_role]
mock_remove.return_value = True
mock_get_user.return_value = None
# Act
success, error = await OrgMemberService.remove_org_member(
@@ -927,6 +1089,9 @@ class TestOrgMemberServiceRemoveOrgMember:
patch(
'server.services.org_member_service.OrgMemberStore.remove_user_from_org'
) as mock_remove,
patch(
'server.services.org_member_service.UserStore.get_user_by_id'
) as mock_get_user,
):
mock_get_member.side_effect = [
requester_membership_owner,
@@ -940,6 +1105,7 @@ class TestOrgMemberServiceRemoveOrgMember:
another_owner,
]
mock_remove.return_value = True
mock_get_user.return_value = None
# Act
success, error = await OrgMemberService.remove_org_member(
@@ -990,6 +1156,302 @@ class TestOrgMemberServiceRemoveOrgMember:
assert success is False
assert error == 'removal_failed'
@pytest.mark.asyncio
async def test_remove_member_updates_current_org_id_when_matching(
self,
org_id,
current_user_id,
target_user_id,
requester_membership_owner,
target_membership_user,
owner_role,
member_role,
):
"""Test that current_org_id is updated to personal workspace when it matches removed org."""
# Arrange
mock_user = MagicMock(spec=User)
mock_user.current_org_id = (
org_id # User's current org matches the org being removed
)
with (
patch(
'server.services.org_member_service.OrgMemberStore.get_org_member'
) as mock_get_member,
patch(
'server.services.org_member_service.RoleStore.get_role_by_id'
) as mock_get_role,
patch(
'server.services.org_member_service.OrgMemberStore.remove_user_from_org'
) as mock_remove,
patch(
'server.services.org_member_service.UserStore.get_user_by_id'
) as mock_get_user,
patch(
'server.services.org_member_service.UserStore.update_current_org'
) as mock_update_org,
):
mock_get_member.side_effect = [
requester_membership_owner,
target_membership_user,
]
mock_get_role.side_effect = [owner_role, member_role]
mock_remove.return_value = True
mock_get_user.return_value = mock_user
# Act
success, error = await OrgMemberService.remove_org_member(
org_id, target_user_id, current_user_id
)
# Assert
assert success is True
assert error is None
mock_update_org.assert_called_once_with(str(target_user_id), target_user_id)
@pytest.mark.asyncio
async def test_remove_member_does_not_update_current_org_id_when_not_matching(
self,
org_id,
current_user_id,
target_user_id,
requester_membership_owner,
target_membership_user,
owner_role,
member_role,
):
"""Test that current_org_id is NOT updated when it differs from removed org."""
# Arrange
different_org_id = uuid.uuid4()
mock_user = MagicMock(spec=User)
mock_user.current_org_id = different_org_id # User's current org is different
with (
patch(
'server.services.org_member_service.OrgMemberStore.get_org_member'
) as mock_get_member,
patch(
'server.services.org_member_service.RoleStore.get_role_by_id'
) as mock_get_role,
patch(
'server.services.org_member_service.OrgMemberStore.remove_user_from_org'
) as mock_remove,
patch(
'server.services.org_member_service.UserStore.get_user_by_id'
) as mock_get_user,
patch(
'server.services.org_member_service.UserStore.update_current_org'
) as mock_update_org,
):
mock_get_member.side_effect = [
requester_membership_owner,
target_membership_user,
]
mock_get_role.side_effect = [owner_role, member_role]
mock_remove.return_value = True
mock_get_user.return_value = mock_user
# Act
success, error = await OrgMemberService.remove_org_member(
org_id, target_user_id, current_user_id
)
# Assert
assert success is True
assert error is None
mock_update_org.assert_not_called()
@pytest.mark.asyncio
async def test_remove_member_succeeds_when_user_not_found_after_removal(
self,
org_id,
current_user_id,
target_user_id,
requester_membership_owner,
target_membership_user,
owner_role,
member_role,
):
"""Test that removal succeeds even if user lookup returns None after removal."""
# Arrange
with (
patch(
'server.services.org_member_service.OrgMemberStore.get_org_member'
) as mock_get_member,
patch(
'server.services.org_member_service.RoleStore.get_role_by_id'
) as mock_get_role,
patch(
'server.services.org_member_service.OrgMemberStore.remove_user_from_org'
) as mock_remove,
patch(
'server.services.org_member_service.UserStore.get_user_by_id'
) as mock_get_user,
patch(
'server.services.org_member_service.UserStore.update_current_org'
) as mock_update_org,
):
mock_get_member.side_effect = [
requester_membership_owner,
target_membership_user,
]
mock_get_role.side_effect = [owner_role, member_role]
mock_remove.return_value = True
mock_get_user.return_value = None # User not found
# Act
success, error = await OrgMemberService.remove_org_member(
org_id, target_user_id, current_user_id
)
# Assert
assert success is True
assert error is None
mock_update_org.assert_not_called()
@pytest.mark.asyncio
async def test_successful_removal_calls_litellm_remove_user_from_team(
self,
org_id,
current_user_id,
target_user_id,
requester_membership_owner,
target_membership_user,
owner_role,
member_role,
):
"""Test that LiteLLM remove_user_from_team is called after successful database removal."""
# Arrange
with (
patch(
'server.services.org_member_service.OrgMemberStore.get_org_member'
) as mock_get_member,
patch(
'server.services.org_member_service.RoleStore.get_role_by_id'
) as mock_get_role,
patch(
'server.services.org_member_service.OrgMemberStore.remove_user_from_org'
) as mock_remove,
patch(
'server.services.org_member_service.UserStore.get_user_by_id'
) as mock_get_user,
patch(
'server.services.org_member_service.LiteLlmManager.remove_user_from_team',
new_callable=AsyncMock,
) as mock_litellm_remove,
):
mock_get_member.side_effect = [
requester_membership_owner,
target_membership_user,
]
mock_get_role.side_effect = [owner_role, member_role]
mock_remove.return_value = True
mock_get_user.return_value = None
# Act
success, error = await OrgMemberService.remove_org_member(
org_id, target_user_id, current_user_id
)
# Assert
assert success is True
mock_litellm_remove.assert_called_once_with(
str(target_user_id), str(org_id)
)
@pytest.mark.asyncio
async def test_litellm_failure_does_not_fail_removal(
self,
org_id,
current_user_id,
target_user_id,
requester_membership_owner,
target_membership_user,
owner_role,
member_role,
):
"""Test that LiteLLM failure doesn't fail the overall removal operation."""
# Arrange
with (
patch(
'server.services.org_member_service.OrgMemberStore.get_org_member'
) as mock_get_member,
patch(
'server.services.org_member_service.RoleStore.get_role_by_id'
) as mock_get_role,
patch(
'server.services.org_member_service.OrgMemberStore.remove_user_from_org'
) as mock_remove,
patch(
'server.services.org_member_service.UserStore.get_user_by_id'
) as mock_get_user,
patch(
'server.services.org_member_service.LiteLlmManager.remove_user_from_team',
new_callable=AsyncMock,
) as mock_litellm_remove,
):
mock_get_member.side_effect = [
requester_membership_owner,
target_membership_user,
]
mock_get_role.side_effect = [owner_role, member_role]
mock_remove.return_value = True
mock_get_user.return_value = None
mock_litellm_remove.side_effect = Exception('LiteLLM API error')
# Act
success, error = await OrgMemberService.remove_org_member(
org_id, target_user_id, current_user_id
)
# Assert
assert success is True
assert error is None
@pytest.mark.asyncio
async def test_database_failure_skips_litellm_call(
self,
org_id,
current_user_id,
target_user_id,
requester_membership_owner,
target_membership_user,
owner_role,
member_role,
):
"""Test that LiteLLM is not called when database removal fails."""
# Arrange
with (
patch(
'server.services.org_member_service.OrgMemberStore.get_org_member'
) as mock_get_member,
patch(
'server.services.org_member_service.RoleStore.get_role_by_id'
) as mock_get_role,
patch(
'server.services.org_member_service.OrgMemberStore.remove_user_from_org'
) as mock_remove,
patch(
'server.services.org_member_service.LiteLlmManager.remove_user_from_team',
new_callable=AsyncMock,
) as mock_litellm_remove,
):
mock_get_member.side_effect = [
requester_membership_owner,
target_membership_user,
]
mock_get_role.side_effect = [owner_role, member_role]
mock_remove.return_value = False
# Act
success, error = await OrgMemberService.remove_org_member(
org_id, target_user_id, current_user_id
)
# Assert
assert success is False
mock_litellm_remove.assert_not_called()
class TestOrgMemberServiceCanRemoveMember:
"""Test cases for OrgMemberService._can_remove_member."""
@@ -1099,7 +1561,7 @@ class TestOrgMemberServiceUpdateOrgMember:
# Assert
assert isinstance(data, OrgMemberResponse)
assert data.role_name == 'admin'
assert data.role == 'admin'
assert data.role_rank == 20
mock_update.assert_called_once_with(org_id, target_user_id, admin_role.id)
@@ -1431,7 +1893,7 @@ class TestOrgMemberServiceUpdateOrgMember:
# Assert
assert data is not None
assert data.role_name == 'member'
assert data.role == 'member'
assert data.role_rank == 1000

View File

@@ -0,0 +1,661 @@
"""Unit tests for AuthTokenStore."""
import time
from contextlib import asynccontextmanager
from typing import Dict
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from server.auth.auth_error import TokenRefreshError
from sqlalchemy.exc import OperationalError
from storage.auth_token_store import (
ACCESS_TOKEN_EXPIRY_BUFFER,
LOCK_TIMEOUT_SECONDS,
AuthTokenStore,
)
from openhands.integrations.service_types import ProviderType
def create_mock_session():
"""Create a mock async session with properly configured context managers."""
session = AsyncMock()
# Create async context manager for begin()
@asynccontextmanager
async def begin_context():
yield
session.begin = begin_context
return session
def create_mock_session_maker(mock_session):
"""Create a mock async session maker."""
@asynccontextmanager
async def session_context():
yield mock_session
# Return a callable that returns the context manager
return lambda: session_context()
@pytest.fixture
def mock_session():
"""Create mock async session."""
return create_mock_session()
@pytest.fixture
def mock_session_maker(mock_session):
"""Create mock async session maker."""
return create_mock_session_maker(mock_session)
@pytest.fixture
def auth_token_store(mock_session_maker):
"""Create AuthTokenStore instance with mocked session maker."""
return AuthTokenStore(
keycloak_user_id='test-user-123',
idp=ProviderType.GITHUB,
a_session_maker=mock_session_maker,
)
class TestIsTokenExpired:
"""Tests for _is_token_expired method."""
def test_both_tokens_valid(self, auth_token_store):
"""Test when both tokens are valid (not expired)."""
current_time = int(time.time())
access_expires = current_time + ACCESS_TOKEN_EXPIRY_BUFFER + 1000
refresh_expires = current_time + 1000
access_expired, refresh_expired = auth_token_store._is_token_expired(
access_expires, refresh_expires
)
assert access_expired is False
assert refresh_expired is False
def test_access_token_expired(self, auth_token_store):
"""Test when access token is expired but within buffer."""
current_time = int(time.time())
# Access token expires within buffer period
access_expires = current_time + ACCESS_TOKEN_EXPIRY_BUFFER - 100
refresh_expires = current_time + 10000
access_expired, refresh_expired = auth_token_store._is_token_expired(
access_expires, refresh_expires
)
assert access_expired is True
assert refresh_expired is False
def test_refresh_token_expired(self, auth_token_store):
"""Test when refresh token is expired."""
current_time = int(time.time())
access_expires = current_time + ACCESS_TOKEN_EXPIRY_BUFFER + 1000
refresh_expires = current_time - 100 # Already expired
access_expired, refresh_expired = auth_token_store._is_token_expired(
access_expires, refresh_expires
)
assert access_expired is False
assert refresh_expired is True
def test_both_tokens_expired(self, auth_token_store):
"""Test when both tokens are expired."""
current_time = int(time.time())
access_expires = current_time - 100
refresh_expires = current_time - 100
access_expired, refresh_expired = auth_token_store._is_token_expired(
access_expires, refresh_expires
)
assert access_expired is True
assert refresh_expired is True
def test_zero_expiration_treated_as_never_expires(self, auth_token_store):
"""Test that 0 expiration time is treated as never expires."""
access_expired, refresh_expired = auth_token_store._is_token_expired(0, 0)
assert access_expired is False
assert refresh_expired is False
class TestLoadTokensFastPath:
"""Tests for load_tokens fast path (no lock needed)."""
@pytest.mark.asyncio
async def test_fast_path_token_not_found(
self, auth_token_store, mock_session_maker, mock_session
):
"""Test fast path returns None when no token record exists."""
mock_result = MagicMock()
mock_result.scalars.return_value.one_or_none.return_value = None
mock_session.execute = AsyncMock(return_value=mock_result)
result = await auth_token_store.load_tokens()
assert result is None
@pytest.mark.asyncio
async def test_fast_path_valid_token_no_refresh_needed(
self, auth_token_store, mock_session_maker, mock_session
):
"""Test fast path returns tokens when they are still valid."""
current_time = int(time.time())
mock_token = MagicMock()
mock_token.access_token = 'valid-access-token'
mock_token.refresh_token = 'valid-refresh-token'
mock_token.access_token_expires_at = (
current_time + ACCESS_TOKEN_EXPIRY_BUFFER + 1000
)
mock_token.refresh_token_expires_at = current_time + 10000
mock_result = MagicMock()
mock_result.scalars.return_value.one_or_none.return_value = mock_token
mock_session.execute = AsyncMock(return_value=mock_result)
result = await auth_token_store.load_tokens()
assert result is not None
assert result['access_token'] == 'valid-access-token'
assert result['refresh_token'] == 'valid-refresh-token'
@pytest.mark.asyncio
async def test_fast_path_no_refresh_callback_provided(
self, auth_token_store, mock_session_maker, mock_session
):
"""Test fast path returns existing tokens when no refresh callback is provided."""
current_time = int(time.time())
mock_token = MagicMock()
mock_token.access_token = 'expired-access-token'
mock_token.refresh_token = 'valid-refresh-token'
# Expired access token
mock_token.access_token_expires_at = current_time - 100
mock_token.refresh_token_expires_at = current_time + 10000
mock_result = MagicMock()
mock_result.scalars.return_value.one_or_none.return_value = mock_token
mock_session.execute = AsyncMock(return_value=mock_result)
result = await auth_token_store.load_tokens(check_expiration_and_refresh=None)
assert result is not None
assert result['access_token'] == 'expired-access-token'
class TestLoadTokensSlowPath:
"""Tests for load_tokens slow path (lock required for refresh)."""
@pytest.mark.asyncio
async def test_slow_path_successful_refresh(self):
"""Test slow path successfully refreshes expired tokens."""
current_time = int(time.time())
mock_session = create_mock_session()
# First call (fast path) - returns expired token
# Second call (slow path) - returns same token for update
expired_token = MagicMock()
expired_token.id = 1
expired_token.access_token = 'expired-access-token'
expired_token.refresh_token = 'valid-refresh-token'
expired_token.access_token_expires_at = current_time - 100 # Expired
expired_token.refresh_token_expires_at = current_time + 10000
mock_result = MagicMock()
mock_result.scalars.return_value.one_or_none.return_value = expired_token
mock_session.execute = AsyncMock(return_value=mock_result)
mock_session.commit = AsyncMock()
mock_session_maker = create_mock_session_maker(mock_session)
auth_store = AuthTokenStore(
keycloak_user_id='test-user-123',
idp=ProviderType.GITHUB,
a_session_maker=mock_session_maker,
)
async def mock_refresh(
idp: ProviderType, refresh_token: str, access_exp: int, refresh_exp: int
) -> Dict[str, str | int]:
return {
'access_token': 'new-access-token',
'refresh_token': 'new-refresh-token',
'access_token_expires_at': current_time + 3600,
'refresh_token_expires_at': current_time + 86400,
}
result = await auth_store.load_tokens(check_expiration_and_refresh=mock_refresh)
assert result is not None
assert result['access_token'] == 'new-access-token'
assert result['refresh_token'] == 'new-refresh-token'
@pytest.mark.asyncio
async def test_slow_path_double_check_avoids_refresh(self):
"""Test double-check locking: token was refreshed by another request."""
current_time = int(time.time())
mock_session = create_mock_session()
# Simulate scenario:
# 1. Fast path sees expired token
# 2. While waiting for lock, another request refreshes
# 3. Slow path sees fresh token, skips refresh
call_count = [0]
def create_token():
call_count[0] += 1
token = MagicMock()
token.id = 1
token.access_token = 'fresh-access-token'
token.refresh_token = 'fresh-refresh-token'
if call_count[0] == 1:
# First call (fast path) - expired
token.access_token_expires_at = current_time - 100
else:
# Second call (slow path) - already refreshed
token.access_token_expires_at = (
current_time + ACCESS_TOKEN_EXPIRY_BUFFER + 1000
)
token.refresh_token_expires_at = current_time + 86400
return token
mock_result = MagicMock()
mock_result.scalars.return_value.one_or_none.side_effect = (
lambda: create_token()
)
mock_session.execute = AsyncMock(return_value=mock_result)
mock_session.commit = AsyncMock()
mock_session_maker = create_mock_session_maker(mock_session)
auth_store = AuthTokenStore(
keycloak_user_id='test-user-123',
idp=ProviderType.GITHUB,
a_session_maker=mock_session_maker,
)
refresh_called = [False]
async def mock_refresh(
idp: ProviderType, refresh_token: str, access_exp: int, refresh_exp: int
) -> Dict[str, str | int]:
refresh_called[0] = True
return {
'access_token': 'should-not-be-used',
'refresh_token': 'should-not-be-used',
'access_token_expires_at': current_time + 3600,
'refresh_token_expires_at': current_time + 86400,
}
result = await auth_store.load_tokens(check_expiration_and_refresh=mock_refresh)
# The refresh callback should not be called because double-check
# found the token was already refreshed
assert result is not None
assert result['access_token'] == 'fresh-access-token'
@pytest.mark.asyncio
async def test_slow_path_token_not_found_after_lock(self):
"""Test slow path returns None if token record disappears after lock."""
current_time = int(time.time())
mock_session = create_mock_session()
# First call (fast path) - token exists but expired
# Second call (slow path with lock) - token no longer exists
call_count = [0]
def get_token():
call_count[0] += 1
if call_count[0] == 1:
token = MagicMock()
token.access_token_expires_at = current_time - 100 # Expired
token.refresh_token_expires_at = current_time + 10000
return token
return None
mock_result = MagicMock()
mock_result.scalars.return_value.one_or_none.side_effect = get_token
mock_session.execute = AsyncMock(return_value=mock_result)
mock_session_maker = create_mock_session_maker(mock_session)
auth_store = AuthTokenStore(
keycloak_user_id='test-user-123',
idp=ProviderType.GITHUB,
a_session_maker=mock_session_maker,
)
async def mock_refresh(*args) -> Dict[str, str | int]:
return {
'access_token': 'new-token',
'refresh_token': 'new-refresh',
'access_token_expires_at': current_time + 3600,
'refresh_token_expires_at': current_time + 86400,
}
result = await auth_store.load_tokens(check_expiration_and_refresh=mock_refresh)
assert result is None
class TestLoadTokensLockTimeout:
"""Tests for lock timeout handling."""
@pytest.mark.asyncio
async def test_lock_timeout_raises_token_refresh_error(self):
"""Test that lock timeout raises TokenRefreshError."""
current_time = int(time.time())
mock_session = create_mock_session()
# First call (fast path) - returns expired token
expired_token = MagicMock()
expired_token.access_token_expires_at = current_time - 100
expired_token.refresh_token_expires_at = current_time + 10000
mock_result = MagicMock()
mock_result.scalars.return_value.one_or_none.return_value = expired_token
# First execute for fast path succeeds
# Second execute (for slow path) raises OperationalError
call_count = [0]
async def execute_side_effect(*args, **kwargs):
call_count[0] += 1
if call_count[0] <= 1:
return mock_result
# Simulate lock timeout
raise OperationalError(
'canceling statement due to lock timeout', None, None
)
mock_session.execute = execute_side_effect
mock_session_maker = create_mock_session_maker(mock_session)
auth_store = AuthTokenStore(
keycloak_user_id='test-user-123',
idp=ProviderType.GITHUB,
a_session_maker=mock_session_maker,
)
async def mock_refresh(*args) -> Dict[str, str | int]:
return {
'access_token': 'new-token',
'refresh_token': 'new-refresh',
'access_token_expires_at': current_time + 3600,
'refresh_token_expires_at': current_time + 86400,
}
with pytest.raises(TokenRefreshError) as exc_info:
await auth_store.load_tokens(check_expiration_and_refresh=mock_refresh)
assert 'lock timeout' in str(exc_info.value).lower()
@pytest.mark.asyncio
async def test_lock_timeout_preserves_original_exception(self):
"""Test that TokenRefreshError preserves the original OperationalError."""
current_time = int(time.time())
mock_session = create_mock_session()
expired_token = MagicMock()
expired_token.access_token_expires_at = current_time - 100
expired_token.refresh_token_expires_at = current_time + 10000
mock_result = MagicMock()
mock_result.scalars.return_value.one_or_none.return_value = expired_token
original_error = OperationalError(
'canceling statement due to lock timeout', None, None
)
call_count = [0]
async def execute_side_effect(*args, **kwargs):
call_count[0] += 1
if call_count[0] <= 1:
return mock_result
raise original_error
mock_session.execute = execute_side_effect
mock_session_maker = create_mock_session_maker(mock_session)
auth_store = AuthTokenStore(
keycloak_user_id='test-user-123',
idp=ProviderType.GITHUB,
a_session_maker=mock_session_maker,
)
async def mock_refresh(*args) -> Dict[str, str | int]:
return {
'access_token': 'new-token',
'refresh_token': 'new-refresh',
'access_token_expires_at': current_time + 3600,
'refresh_token_expires_at': current_time + 86400,
}
with pytest.raises(TokenRefreshError) as exc_info:
await auth_store.load_tokens(check_expiration_and_refresh=mock_refresh)
# Verify the original exception is chained
assert exc_info.value.__cause__ is original_error
class TestLoadTokensRefreshCallbackBehavior:
"""Tests for refresh callback return values."""
@pytest.mark.asyncio
async def test_refresh_callback_returns_none(self):
"""Test behavior when refresh callback returns None (no refresh performed)."""
current_time = int(time.time())
mock_session = create_mock_session()
expired_token = MagicMock()
expired_token.id = 1
expired_token.access_token = 'old-access-token'
expired_token.refresh_token = 'old-refresh-token'
expired_token.access_token_expires_at = current_time - 100 # Expired
expired_token.refresh_token_expires_at = current_time + 10000
mock_result = MagicMock()
mock_result.scalars.return_value.one_or_none.return_value = expired_token
mock_session.execute = AsyncMock(return_value=mock_result)
mock_session.commit = AsyncMock()
mock_session_maker = create_mock_session_maker(mock_session)
auth_store = AuthTokenStore(
keycloak_user_id='test-user-123',
idp=ProviderType.GITHUB,
a_session_maker=mock_session_maker,
)
async def mock_refresh_returns_none(
idp: ProviderType, refresh_token: str, access_exp: int, refresh_exp: int
) -> Dict[str, str | int] | None:
return None
result = await auth_store.load_tokens(
check_expiration_and_refresh=mock_refresh_returns_none
)
# Should return the old tokens when refresh returns None
assert result is not None
assert result['access_token'] == 'old-access-token'
assert result['refresh_token'] == 'old-refresh-token'
class TestStoreTokens:
"""Tests for store_tokens method."""
@pytest.mark.asyncio
async def test_store_tokens_creates_new_record(self):
"""Test storing tokens when no existing record."""
mock_session = create_mock_session()
mock_result = MagicMock()
mock_result.scalars.return_value.first.return_value = None
mock_session.execute = AsyncMock(return_value=mock_result)
mock_session.add = MagicMock()
mock_session.commit = AsyncMock()
mock_session_maker = create_mock_session_maker(mock_session)
auth_store = AuthTokenStore(
keycloak_user_id='test-user-123',
idp=ProviderType.GITHUB,
a_session_maker=mock_session_maker,
)
await auth_store.store_tokens(
access_token='new-access-token',
refresh_token='new-refresh-token',
access_token_expires_at=1234567890,
refresh_token_expires_at=1234657890,
)
mock_session.add.assert_called_once()
@pytest.mark.asyncio
async def test_store_tokens_updates_existing_record(self):
"""Test storing tokens updates existing record."""
mock_session = create_mock_session()
existing_token = MagicMock()
existing_token.access_token = 'old-access'
mock_result = MagicMock()
mock_result.scalars.return_value.first.return_value = existing_token
mock_session.execute = AsyncMock(return_value=mock_result)
mock_session.commit = AsyncMock()
mock_session_maker = create_mock_session_maker(mock_session)
auth_store = AuthTokenStore(
keycloak_user_id='test-user-123',
idp=ProviderType.GITHUB,
a_session_maker=mock_session_maker,
)
await auth_store.store_tokens(
access_token='new-access-token',
refresh_token='new-refresh-token',
access_token_expires_at=1234567890,
refresh_token_expires_at=1234657890,
)
assert existing_token.access_token == 'new-access-token'
assert existing_token.refresh_token == 'new-refresh-token'
class TestIsAccessTokenValid:
"""Tests for is_access_token_valid method."""
@pytest.mark.asyncio
async def test_is_access_token_valid_returns_false_when_no_tokens(
self, auth_token_store, mock_session_maker, mock_session
):
"""Test returns False when no tokens found."""
mock_result = MagicMock()
mock_result.scalars.return_value.one_or_none.return_value = None
mock_session.execute = AsyncMock(return_value=mock_result)
result = await auth_token_store.is_access_token_valid()
assert result is False
@pytest.mark.asyncio
async def test_is_access_token_valid_returns_true_for_valid_token(
self, auth_token_store, mock_session_maker, mock_session
):
"""Test returns True when token is valid."""
current_time = int(time.time())
mock_token = MagicMock()
mock_token.access_token = 'valid-access'
mock_token.refresh_token = 'valid-refresh'
mock_token.access_token_expires_at = current_time + 1000
mock_token.refresh_token_expires_at = current_time + 10000
mock_result = MagicMock()
mock_result.scalars.return_value.one_or_none.return_value = mock_token
mock_session.execute = AsyncMock(return_value=mock_result)
result = await auth_token_store.is_access_token_valid()
assert result is True
@pytest.mark.asyncio
async def test_is_access_token_valid_returns_false_for_expired_token(
self, auth_token_store, mock_session_maker, mock_session
):
"""Test returns False when token is expired."""
current_time = int(time.time())
mock_token = MagicMock()
mock_token.access_token = 'expired-access'
mock_token.refresh_token = 'valid-refresh'
mock_token.access_token_expires_at = current_time - 100 # Expired
mock_token.refresh_token_expires_at = current_time + 10000
mock_result = MagicMock()
mock_result.scalars.return_value.one_or_none.return_value = mock_token
mock_session.execute = AsyncMock(return_value=mock_result)
result = await auth_token_store.is_access_token_valid()
assert result is False
class TestGetInstance:
"""Tests for get_instance class method."""
@pytest.mark.asyncio
async def test_get_instance_creates_auth_token_store(self):
"""Test get_instance creates an AuthTokenStore with correct params."""
with patch('storage.auth_token_store.a_session_maker') as mock_a_session_maker:
store = await AuthTokenStore.get_instance(
keycloak_user_id='user-123', idp=ProviderType.GITHUB
)
assert store.keycloak_user_id == 'user-123'
assert store.idp == ProviderType.GITHUB
assert store.a_session_maker is mock_a_session_maker
class TestIdentityProviderValue:
"""Tests for identity_provider_value property."""
def test_identity_provider_value_returns_idp_value(self, auth_token_store):
"""Test that identity_provider_value returns the enum value."""
assert auth_token_store.identity_provider_value == ProviderType.GITHUB.value
def test_identity_provider_value_for_different_providers(self):
"""Test identity_provider_value for different providers."""
for provider in [
ProviderType.GITHUB,
ProviderType.GITLAB,
ProviderType.BITBUCKET,
]:
store = AuthTokenStore(
keycloak_user_id='test-user',
idp=provider,
a_session_maker=MagicMock(),
)
assert store.identity_provider_value == provider.value
class TestConstants:
"""Tests for module constants."""
def test_access_token_expiry_buffer_value(self):
"""Test ACCESS_TOKEN_EXPIRY_BUFFER is set to 15 minutes."""
assert ACCESS_TOKEN_EXPIRY_BUFFER == 900
def test_lock_timeout_seconds_value(self):
"""Test LOCK_TIMEOUT_SECONDS is set to 5 seconds."""
assert LOCK_TIMEOUT_SECONDS == 5

View File

@@ -0,0 +1,99 @@
"""Tests for the enterprise storage.database module.
These tests verify that the session_maker function properly forwards
keyword arguments to the underlying session maker for backward compatibility.
"""
from unittest.mock import MagicMock, patch
class TestSessionMaker:
"""Test cases for the session_maker function."""
@patch('enterprise.storage.database._get_db_session_injector')
def test_session_maker_without_args(self, mock_get_injector):
"""Test that session_maker works without any arguments."""
from enterprise.storage.database import session_maker
# Set up mock
mock_injector = MagicMock()
mock_inner_session_maker = MagicMock()
mock_session = MagicMock()
mock_inner_session_maker.return_value = mock_session
mock_injector.get_session_maker.return_value = mock_inner_session_maker
mock_get_injector.return_value = mock_injector
# Call session_maker without arguments
result = session_maker()
# Verify the inner session maker was called without arguments
mock_inner_session_maker.assert_called_once_with()
assert result == mock_session
@patch('enterprise.storage.database._get_db_session_injector')
def test_session_maker_with_expire_on_commit_false(self, mock_get_injector):
"""Test that session_maker accepts expire_on_commit keyword argument.
This is a critical backward compatibility test - the session_maker
must accept keyword arguments like expire_on_commit=False which is
used in slack.py and potentially other integration modules.
"""
from enterprise.storage.database import session_maker
# Set up mock
mock_injector = MagicMock()
mock_inner_session_maker = MagicMock()
mock_session = MagicMock()
mock_inner_session_maker.return_value = mock_session
mock_injector.get_session_maker.return_value = mock_inner_session_maker
mock_get_injector.return_value = mock_injector
# Call session_maker with expire_on_commit=False
# This is the exact call pattern used in slack.py line 242
result = session_maker(expire_on_commit=False)
# Verify the inner session maker was called with the keyword argument
mock_inner_session_maker.assert_called_once_with(expire_on_commit=False)
assert result == mock_session
@patch('enterprise.storage.database._get_db_session_injector')
def test_session_maker_with_multiple_kwargs(self, mock_get_injector):
"""Test that session_maker passes through multiple keyword arguments."""
from enterprise.storage.database import session_maker
# Set up mock
mock_injector = MagicMock()
mock_inner_session_maker = MagicMock()
mock_session = MagicMock()
mock_inner_session_maker.return_value = mock_session
mock_injector.get_session_maker.return_value = mock_inner_session_maker
mock_get_injector.return_value = mock_injector
# Call with multiple kwargs
result = session_maker(
expire_on_commit=False, autoflush=False, autocommit=False
)
# Verify all kwargs were passed through
mock_inner_session_maker.assert_called_once_with(
expire_on_commit=False, autoflush=False, autocommit=False
)
assert result == mock_session
@patch('enterprise.storage.database._get_db_session_injector')
def test_session_maker_returns_correct_session(self, mock_get_injector):
"""Test that session_maker returns the session from the inner session maker."""
from enterprise.storage.database import session_maker
# Set up mock
mock_injector = MagicMock()
mock_inner_session_maker = MagicMock()
mock_session = MagicMock()
mock_inner_session_maker.return_value = mock_session
mock_injector.get_session_maker.return_value = mock_inner_session_maker
mock_get_injector.return_value = mock_injector
result = session_maker()
# Verify the returned session is from the inner session maker
assert result is mock_session

View File

@@ -0,0 +1,158 @@
"""Unit tests for ResendSyncedUserStore."""
from unittest.mock import MagicMock
import pytest
# Import directly from the module files to avoid loading all of storage/__init__.py
# which has many dependencies
from storage.resend_synced_user import ResendSyncedUser
from storage.resend_synced_user_store import ResendSyncedUserStore
@pytest.fixture
def mock_session():
"""Mock database session."""
session = MagicMock()
return session
@pytest.fixture
def mock_session_maker(mock_session):
"""Mock session maker."""
session_maker = MagicMock()
session_maker.return_value.__enter__.return_value = mock_session
session_maker.return_value.__exit__.return_value = None
return session_maker
@pytest.fixture
def store(mock_session_maker):
"""Create ResendSyncedUserStore instance."""
return ResendSyncedUserStore(session_maker=mock_session_maker)
class TestResendSyncedUserStore:
"""Test cases for ResendSyncedUserStore."""
def test_is_user_synced_returns_true_when_exists(self, store, mock_session):
"""Test is_user_synced returns True when user exists in database."""
email = 'test@example.com'
audience_id = 'test-audience-123'
mock_row = MagicMock()
mock_session.execute.return_value.first.return_value = mock_row
result = store.is_user_synced(email, audience_id)
assert result is True
mock_session.execute.assert_called_once()
def test_is_user_synced_returns_false_when_not_exists(self, store, mock_session):
"""Test is_user_synced returns False when user doesn't exist."""
email = 'test@example.com'
audience_id = 'test-audience-123'
mock_session.execute.return_value.first.return_value = None
result = store.is_user_synced(email, audience_id)
assert result is False
def test_is_user_synced_normalizes_email_to_lowercase(self, store, mock_session):
"""Test that is_user_synced normalizes email to lowercase."""
email = 'TEST@EXAMPLE.COM'
audience_id = 'test-audience-123'
mock_session.execute.return_value.first.return_value = None
store.is_user_synced(email, audience_id)
# Verify the query was called (we can't easily check the exact SQL)
mock_session.execute.assert_called_once()
def test_mark_user_synced_creates_new_record(self, store, mock_session):
"""Test that mark_user_synced creates a new record."""
email = 'test@example.com'
audience_id = 'test-audience-123'
keycloak_user_id = 'kc-user-123'
mock_synced_user = MagicMock(spec=ResendSyncedUser)
mock_result = MagicMock()
mock_result.first.return_value = (mock_synced_user,)
mock_session.execute.return_value = mock_result
result = store.mark_user_synced(email, audience_id, keycloak_user_id)
assert result == mock_synced_user
mock_session.execute.assert_called_once()
mock_session.commit.assert_called_once()
def test_mark_user_synced_handles_existing_record(self, store, mock_session):
"""Test that mark_user_synced handles conflict (existing record)."""
email = 'test@example.com'
audience_id = 'test-audience-123'
# First execute (insert) returns None (conflict occurred)
# Second execute (select existing) returns the record
mock_existing_user = MagicMock(spec=ResendSyncedUser)
mock_result_insert = MagicMock()
mock_result_insert.first.return_value = None
mock_result_select = MagicMock()
mock_result_select.first.return_value = (mock_existing_user,)
mock_session.execute.side_effect = [mock_result_insert, mock_result_select]
result = store.mark_user_synced(email, audience_id)
assert result == mock_existing_user
assert mock_session.execute.call_count == 2
mock_session.commit.assert_called_once()
def test_mark_user_synced_normalizes_email_to_lowercase(self, store, mock_session):
"""Test that mark_user_synced normalizes email to lowercase."""
email = 'TEST@EXAMPLE.COM'
audience_id = 'test-audience-123'
mock_synced_user = MagicMock(spec=ResendSyncedUser)
mock_result = MagicMock()
mock_result.first.return_value = (mock_synced_user,)
mock_session.execute.return_value = mock_result
store.mark_user_synced(email, audience_id)
# Verify execute was called (the email normalization happens in the SQL)
mock_session.execute.assert_called_once()
mock_session.commit.assert_called_once()
def test_mark_user_synced_without_keycloak_user_id(self, store, mock_session):
"""Test that mark_user_synced works without keycloak_user_id."""
email = 'test@example.com'
audience_id = 'test-audience-123'
mock_synced_user = MagicMock(spec=ResendSyncedUser)
mock_result = MagicMock()
mock_result.first.return_value = (mock_synced_user,)
mock_session.execute.return_value = mock_result
result = store.mark_user_synced(email, audience_id)
assert result == mock_synced_user
mock_session.execute.assert_called_once()
class TestResendSyncedUser:
"""Test cases for ResendSyncedUser model."""
def test_model_has_required_fields(self):
"""Test that the model has all required fields."""
assert hasattr(ResendSyncedUser, 'id')
assert hasattr(ResendSyncedUser, 'email')
assert hasattr(ResendSyncedUser, 'audience_id')
assert hasattr(ResendSyncedUser, 'synced_at')
assert hasattr(ResendSyncedUser, 'keycloak_user_id')
def test_model_table_name(self):
"""Test the model's table name."""
assert ResendSyncedUser.__tablename__ == 'resend_synced_users'

View File

@@ -10,8 +10,12 @@ from unittest.mock import AsyncMock, MagicMock
from uuid import UUID, uuid4
import pytest
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.pool import StaticPool
from storage.base import Base
from storage.org import Org
from storage.user import User
from enterprise.server.utils.saas_app_conversation_info_injector import (
SaasSQLAppConversationInfoService,
@@ -20,10 +24,15 @@ from openhands.app_server.app_conversation.app_conversation_models import (
AppConversationInfo,
)
from openhands.app_server.user.specifiy_user_context import SpecifyUserContext
from openhands.app_server.utils.sql_utils import Base
from openhands.integrations.service_types import ProviderType
from openhands.storage.data_models.conversation_metadata import ConversationTrigger
# Test UUIDs
USER1_ID = UUID('a1111111-1111-1111-1111-111111111111')
USER2_ID = UUID('b2222222-2222-2222-2222-222222222222')
ORG1_ID = UUID('c1111111-1111-1111-1111-111111111111')
ORG2_ID = UUID('d2222222-2222-2222-2222-222222222222')
@pytest.fixture
async def async_engine():
@@ -55,6 +64,41 @@ async def async_session(async_engine) -> AsyncGenerator[AsyncSession, None]:
yield db_session
@pytest.fixture
async def async_session_with_users(async_engine) -> AsyncGenerator[AsyncSession, None]:
"""Create an async session with pre-populated Org and User rows for testing."""
async_session_maker = async_sessionmaker(
async_engine, class_=AsyncSession, expire_on_commit=False
)
async with async_session_maker() as db_session:
# Insert Orgs first (required for User foreign key)
org1 = Org(
id=ORG1_ID,
name='test-org-1',
enable_default_condenser=True,
enable_proactive_conversation_starters=True,
)
org2 = Org(
id=ORG2_ID,
name='test-org-2',
enable_default_condenser=True,
enable_proactive_conversation_starters=True,
)
db_session.add(org1)
db_session.add(org2)
await db_session.flush()
# Insert Users
user1 = User(id=USER1_ID, current_org_id=ORG1_ID)
user2 = User(id=USER2_ID, current_org_id=ORG2_ID)
db_session.add(user1)
db_session.add(user2)
await db_session.commit()
yield db_session
@pytest.fixture
def service(async_session) -> SaasSQLAppConversationInfoService:
"""Create a SQLAppConversationInfoService instance for testing."""
@@ -178,15 +222,26 @@ class TestSaasSQLAppConversationInfoService:
assert user1_id != user2_id
@pytest.mark.asyncio
async def test_secure_select_includes_user_filtering(
async def test_secure_select_includes_user_and_org_filtering(
self,
saas_service_user1: SaasSQLAppConversationInfoService,
async_session_with_users: AsyncSession,
):
"""Test that _secure_select method includes user filtering."""
# This test verifies that the _secure_select method exists and can be called
# The actual SQL generation is tested implicitly through integration
query = await saas_service_user1._secure_select()
assert query is not None
"""Test that _secure_select method includes both user_id and org_id filtering."""
service = SaasSQLAppConversationInfoService(
db_session=async_session_with_users,
user_context=SpecifyUserContext(user_id=str(USER1_ID)),
)
query = await service._secure_select()
# Convert query to string to verify filters are present
query_str = str(query.compile(compile_kwargs={'literal_binds': True}))
# Verify user_id filter is present
assert str(USER1_ID) in query_str or str(USER1_ID).replace('-', '') in query_str
# Verify org_id filter is present (user1 is in org1)
assert str(ORG1_ID) in query_str or str(ORG1_ID).replace('-', '') in query_str
@pytest.mark.asyncio
async def test_to_info_with_user_id_functionality(
@@ -241,100 +296,32 @@ class TestSaasSQLAppConversationInfoService:
assert result.sandbox_id == 'test-sandbox'
@pytest.mark.asyncio
async def test_user_isolation(
async def test_user_isolation_different_users(
self,
async_session: AsyncSession,
multiple_conversation_infos: list[AppConversationInfo],
async_session_with_users: AsyncSession,
):
"""Test that user isolation works correctly."""
from unittest.mock import MagicMock
from storage.user import User
# Mock the database session execute method to return mock users
# This mock intercepts User queries and returns a mock user object
# with user_id and org_id the same as the user_id_uuid from the query
original_execute = async_session.execute
async def mock_execute(query):
query_str = str(query)
# Check if this is a User query
if '"user"' in query_str.lower() and '"user".id' in query_str.lower():
# Extract the UUID from the query parameters
# The query will have bound parameters, we need to get the UUID value
if hasattr(query, 'compile'):
try:
compiled = query.compile(compile_kwargs={'literal_binds': True})
query_with_params = str(compiled)
# Extract UUID from the query string
import re
# Try both formats: with dashes and without dashes
uuid_pattern_with_dashes = r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}'
uuid_pattern_without_dashes = r'[a-f0-9]{32}'
uuid_match = re.search(
uuid_pattern_with_dashes, query_with_params
)
if not uuid_match:
uuid_match = re.search(
uuid_pattern_without_dashes, query_with_params
)
if uuid_match:
user_id_str = uuid_match.group(0)
# If the UUID doesn't have dashes, add them
if len(user_id_str) == 32 and '-' not in user_id_str:
# Convert from 'a1111111111111111111111111111111' to 'a1111111-1111-1111-1111-111111111111'
user_id_str = f'{user_id_str[:8]}-{user_id_str[8:12]}-{user_id_str[12:16]}-{user_id_str[16:20]}-{user_id_str[20:]}'
user_id_uuid = UUID(user_id_str)
# Create a mock user with user_id and org_id the same as user_id_uuid
mock_user = MagicMock(spec=User)
mock_user.id = user_id_uuid
mock_user.current_org_id = user_id_uuid
# Create a mock result
mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = mock_user
return mock_result
except Exception:
# If there's any error in parsing, fall back to original execute
pass
# For all other queries, use the original execute method
return await original_execute(query)
# Apply the mock
async_session.execute = mock_execute
"""Test that different users cannot see each other's conversations."""
# Create services for different users
user1_service = SaasSQLAppConversationInfoService(
db_session=async_session,
user_context=SpecifyUserContext(
user_id='a1111111-1111-1111-1111-111111111111'
),
db_session=async_session_with_users,
user_context=SpecifyUserContext(user_id=str(USER1_ID)),
)
user2_service = SaasSQLAppConversationInfoService(
db_session=async_session,
user_context=SpecifyUserContext(
user_id='b2222222-2222-2222-2222-222222222222'
),
db_session=async_session_with_users,
user_context=SpecifyUserContext(user_id=str(USER2_ID)),
)
# Create conversations for different users
user1_info = AppConversationInfo(
id=uuid4(),
created_by_user_id='a1111111-1111-1111-1111-111111111111',
created_by_user_id=str(USER1_ID),
sandbox_id='sandbox_user1',
title='User 1 Conversation',
)
user2_info = AppConversationInfo(
id=uuid4(),
created_by_user_id='b2222222-2222-2222-2222-222222222222',
created_by_user_id=str(USER2_ID),
sandbox_id='sandbox_user2',
title='User 2 Conversation',
)
@@ -346,18 +333,12 @@ class TestSaasSQLAppConversationInfoService:
# User 1 should only see their conversation
user1_page = await user1_service.search_app_conversation_info()
assert len(user1_page.items) == 1
assert (
user1_page.items[0].created_by_user_id
== 'a1111111-1111-1111-1111-111111111111'
)
assert user1_page.items[0].created_by_user_id == str(USER1_ID)
# User 2 should only see their conversation
user2_page = await user2_service.search_app_conversation_info()
assert len(user2_page.items) == 1
assert (
user2_page.items[0].created_by_user_id
== 'b2222222-2222-2222-2222-222222222222'
)
assert user2_page.items[0].created_by_user_id == str(USER2_ID)
# User 1 should not be able to get user 2's conversation
user2_from_user1 = await user1_service.get_app_conversation_info(user2_info.id)
@@ -366,3 +347,319 @@ class TestSaasSQLAppConversationInfoService:
# User 2 should not be able to get user 1's conversation
user1_from_user2 = await user2_service.get_app_conversation_info(user1_info.id)
assert user1_from_user2 is None
@pytest.mark.asyncio
async def test_same_user_org_switching_isolation(
self,
async_session_with_users: AsyncSession,
):
"""Test that the same user switching orgs cannot see conversations from other orgs.
This tests the actual bug scenario: a user creates a conversation in org1,
then switches to org2, and should NOT see org1's conversations.
"""
# Create service for user1 in org1
user1_service_org1 = SaasSQLAppConversationInfoService(
db_session=async_session_with_users,
user_context=SpecifyUserContext(user_id=str(USER1_ID)),
)
# Create a conversation while user is in org1
conv_in_org1 = AppConversationInfo(
id=uuid4(),
created_by_user_id=str(USER1_ID),
sandbox_id='sandbox_org1',
title='Conversation in Org 1',
)
await user1_service_org1.save_app_conversation_info(conv_in_org1)
# Verify user can see the conversation in org1
page_in_org1 = await user1_service_org1.search_app_conversation_info()
assert len(page_in_org1.items) == 1
assert page_in_org1.items[0].title == 'Conversation in Org 1'
# Simulate user switching to org2 by updating current_org_id using ORM
result = await async_session_with_users.execute(
select(User).where(User.id == USER1_ID)
)
user_to_update = result.scalars().first()
user_to_update.current_org_id = ORG2_ID
await async_session_with_users.commit()
# Clear SQLAlchemy's identity map cache to simulate a new request
async_session_with_users.expire_all()
# Create new service instance (simulating a new request after org switch)
user1_service_org2 = SaasSQLAppConversationInfoService(
db_session=async_session_with_users,
user_context=SpecifyUserContext(user_id=str(USER1_ID)),
)
# User should NOT see org1's conversations after switching to org2
page_in_org2 = await user1_service_org2.search_app_conversation_info()
assert (
len(page_in_org2.items) == 0
), 'User should not see conversations from org1 after switching to org2'
# User should not be able to get the specific conversation from org1
conv_from_org2 = await user1_service_org2.get_app_conversation_info(
conv_in_org1.id
)
assert (
conv_from_org2 is None
), 'User should not be able to access org1 conversation from org2'
# Now create a conversation in org2
conv_in_org2 = AppConversationInfo(
id=uuid4(),
created_by_user_id=str(USER1_ID),
sandbox_id='sandbox_org2',
title='Conversation in Org 2',
)
await user1_service_org2.save_app_conversation_info(conv_in_org2)
# User should only see org2's conversation
page_in_org2_after = await user1_service_org2.search_app_conversation_info()
assert len(page_in_org2_after.items) == 1
assert page_in_org2_after.items[0].title == 'Conversation in Org 2'
# Switch back to org1 and verify isolation works both ways
result = await async_session_with_users.execute(
select(User).where(User.id == USER1_ID)
)
user_to_update = result.scalars().first()
user_to_update.current_org_id = ORG1_ID
await async_session_with_users.commit()
async_session_with_users.expire_all()
user1_service_back_to_org1 = SaasSQLAppConversationInfoService(
db_session=async_session_with_users,
user_context=SpecifyUserContext(user_id=str(USER1_ID)),
)
# User should only see org1's conversation now
page_back_in_org1 = (
await user1_service_back_to_org1.search_app_conversation_info()
)
assert len(page_back_in_org1.items) == 1
assert page_back_in_org1.items[0].title == 'Conversation in Org 1'
@pytest.mark.asyncio
async def test_count_respects_org_isolation(
self,
async_session_with_users: AsyncSession,
):
"""Test that count_app_conversation_info respects org isolation."""
# Create service for user1 in org1
user1_service = SaasSQLAppConversationInfoService(
db_session=async_session_with_users,
user_context=SpecifyUserContext(user_id=str(USER1_ID)),
)
# Create conversations in org1
for i in range(3):
conv = AppConversationInfo(
id=uuid4(),
created_by_user_id=str(USER1_ID),
sandbox_id=f'sandbox_org1_{i}',
title=f'Org1 Conversation {i}',
)
await user1_service.save_app_conversation_info(conv)
# Count should be 3
count_org1 = await user1_service.count_app_conversation_info()
assert count_org1 == 3
# Switch to org2 using ORM
result = await async_session_with_users.execute(
select(User).where(User.id == USER1_ID)
)
user_to_update = result.scalars().first()
user_to_update.current_org_id = ORG2_ID
await async_session_with_users.commit()
async_session_with_users.expire_all()
user1_service_org2 = SaasSQLAppConversationInfoService(
db_session=async_session_with_users,
user_context=SpecifyUserContext(user_id=str(USER1_ID)),
)
# Count should be 0 in org2
count_org2 = await user1_service_org2.count_app_conversation_info()
assert count_org2 == 0
class TestSaasSQLAppConversationInfoServiceAdminContext:
"""Test suite for SaasSQLAppConversationInfoService with ADMIN context."""
@pytest.mark.asyncio
async def test_admin_context_returns_unfiltered_data(
self,
async_session_with_users: AsyncSession,
):
"""Test that ADMIN context returns unfiltered data (no user/org filtering)."""
# Create conversations for different users
user1_service = SaasSQLAppConversationInfoService(
db_session=async_session_with_users,
user_context=SpecifyUserContext(user_id=str(USER1_ID)),
)
# Create conversations for user1 in org1
for i in range(3):
conv = AppConversationInfo(
id=uuid4(),
created_by_user_id=str(USER1_ID),
sandbox_id=f'sandbox_user1_{i}',
title=f'User1 Conversation {i}',
)
await user1_service.save_app_conversation_info(conv)
# Now create an ADMIN service
from openhands.app_server.user.specifiy_user_context import ADMIN
admin_service = SaasSQLAppConversationInfoService(
db_session=async_session_with_users,
user_context=ADMIN,
)
# ADMIN should see ALL conversations (unfiltered)
admin_page = await admin_service.search_app_conversation_info()
assert (
len(admin_page.items) == 3
), 'ADMIN context should see all conversations without filtering'
# ADMIN count should return total count (3)
admin_count = await admin_service.count_app_conversation_info()
assert (
admin_count == 3
), 'ADMIN context should count all conversations without filtering'
@pytest.mark.asyncio
async def test_admin_context_can_access_any_conversation(
self,
async_session_with_users: AsyncSession,
):
"""Test that ADMIN context can access any conversation regardless of owner."""
from openhands.app_server.user.specifiy_user_context import ADMIN
# Create a conversation as user1
user1_service = SaasSQLAppConversationInfoService(
db_session=async_session_with_users,
user_context=SpecifyUserContext(user_id=str(USER1_ID)),
)
conv = AppConversationInfo(
id=uuid4(),
created_by_user_id=str(USER1_ID),
sandbox_id='sandbox_user1',
title='User1 Private Conversation',
)
await user1_service.save_app_conversation_info(conv)
# Create a service as user2 in org2 - should not see user1's conversation
user2_service = SaasSQLAppConversationInfoService(
db_session=async_session_with_users,
user_context=SpecifyUserContext(user_id=str(USER2_ID)),
)
user2_page = await user2_service.search_app_conversation_info()
assert len(user2_page.items) == 0, 'User2 should not see User1 conversation'
# But ADMIN should see ALL conversations including user1's
admin_service = SaasSQLAppConversationInfoService(
db_session=async_session_with_users,
user_context=ADMIN,
)
admin_page = await admin_service.search_app_conversation_info()
assert len(admin_page.items) == 1
assert admin_page.items[0].id == conv.id
# ADMIN should also be able to get specific conversation by ID
admin_get_conv = await admin_service.get_app_conversation_info(conv.id)
assert admin_get_conv is not None
assert admin_get_conv.id == conv.id
@pytest.mark.asyncio
async def test_secure_select_admin_bypasses_filtering(
self,
async_session_with_users: AsyncSession,
):
"""Test that _secure_select returns unfiltered query for ADMIN context."""
from openhands.app_server.user.specifiy_user_context import ADMIN
# Create an ADMIN service
admin_service = SaasSQLAppConversationInfoService(
db_session=async_session_with_users,
user_context=ADMIN,
)
# Get the secure select query
query = await admin_service._secure_select()
# Convert query to string to verify NO filters are present
query_str = str(query.compile(compile_kwargs={'literal_binds': True}))
# For ADMIN, there should be no user_id or org_id filtering
# The query should not contain filters for user_id or org_id
assert str(USER1_ID) not in query_str.replace(
'-', ''
), 'ADMIN context should not filter by user_id'
assert str(USER2_ID) not in query_str.replace(
'-', ''
), 'ADMIN context should not filter by user_id'
@pytest.mark.asyncio
async def test_regular_user_context_filters_correctly(
self,
async_session_with_users: AsyncSession,
):
"""Test that regular user context properly filters data (control test)."""
from openhands.app_server.user.specifiy_user_context import ADMIN
# Create conversations for different users
user1_service = SaasSQLAppConversationInfoService(
db_session=async_session_with_users,
user_context=SpecifyUserContext(user_id=str(USER1_ID)),
)
# Create 3 conversations for user1
for i in range(3):
conv = AppConversationInfo(
id=uuid4(),
created_by_user_id=str(USER1_ID),
sandbox_id=f'sandbox_user1_{i}',
title=f'User1 Conversation {i}',
)
await user1_service.save_app_conversation_info(conv)
# Create 2 conversations for user2
user2_service = SaasSQLAppConversationInfoService(
db_session=async_session_with_users,
user_context=SpecifyUserContext(user_id=str(USER2_ID)),
)
for i in range(2):
conv = AppConversationInfo(
id=uuid4(),
created_by_user_id=str(USER2_ID),
sandbox_id=f'sandbox_user2_{i}',
title=f'User2 Conversation {i}',
)
await user2_service.save_app_conversation_info(conv)
# User1 should only see their 3 conversations
user1_page = await user1_service.search_app_conversation_info()
assert len(user1_page.items) == 3
# User2 should only see their 2 conversations
user2_page = await user2_service.search_app_conversation_info()
assert len(user2_page.items) == 2
# But ADMIN should see all 5 conversations
admin_service = SaasSQLAppConversationInfoService(
db_session=async_session_with_users,
user_context=ADMIN,
)
admin_page = await admin_service.search_app_conversation_info()
assert len(admin_page.items) == 5

View File

@@ -1,6 +1,25 @@
"""Tests for resend_keycloak email validation."""
"""Tests for Resend Keycloak sync functionality."""
from sync.resend_keycloak import is_valid_email
import os
from unittest.mock import MagicMock, patch
import pytest
from resend.exceptions import ResendError
from tenacity import RetryError
# Set required environment variables before importing the module
# that reads them at import time
os.environ['RESEND_API_KEY'] = 'test_api_key'
os.environ['RESEND_AUDIENCE_ID'] = 'test_audience_id'
os.environ['KEYCLOAK_SERVER_URL'] = 'http://localhost:8080'
os.environ['KEYCLOAK_REALM_NAME'] = 'test_realm'
os.environ['KEYCLOAK_ADMIN_PASSWORD'] = 'test_password'
from enterprise.sync.resend_keycloak import ( # noqa: E402
add_contact_to_resend,
is_valid_email,
send_welcome_email,
)
class TestIsValidEmail:
@@ -115,3 +134,134 @@ class TestIsValidEmail:
"""Test that validation works for uppercase emails."""
assert is_valid_email('USER@EXAMPLE.COM') is True
assert is_valid_email('User@Example.Com') is True
class TestSendWelcomeEmail:
"""Tests for send_welcome_email function."""
@patch('enterprise.sync.resend_keycloak.resend.Emails.send')
def test_send_welcome_email_success(self, mock_send: MagicMock) -> None:
"""Test successful welcome email sending."""
mock_send.return_value = {'id': 'email_123'}
result = send_welcome_email(
email='test@example.com',
first_name='John',
last_name='Doe',
)
assert result == {'id': 'email_123'}
mock_send.assert_called_once()
call_args = mock_send.call_args[0][0]
assert call_args['to'] == ['test@example.com']
assert call_args['subject'] == 'Welcome to OpenHands Cloud'
assert 'Hi John Doe,' in call_args['html']
@patch('enterprise.sync.resend_keycloak.resend.Emails.send')
def test_send_welcome_email_retries_on_rate_limit(
self, mock_send: MagicMock
) -> None:
"""Test that send_welcome_email retries on rate limit errors."""
# First two calls raise rate limit error, third succeeds
mock_send.side_effect = [
ResendError(
code=429,
message='Too many requests',
error_type='rate_limit_exceeded',
suggested_action='',
),
ResendError(
code=429,
message='Too many requests',
error_type='rate_limit_exceeded',
suggested_action='',
),
{'id': 'email_123'},
]
result = send_welcome_email(
email='test@example.com',
first_name='John',
last_name='Doe',
)
assert result == {'id': 'email_123'}
assert mock_send.call_count == 3
@patch('enterprise.sync.resend_keycloak.resend.Emails.send')
def test_send_welcome_email_fails_after_max_retries(
self, mock_send: MagicMock
) -> None:
"""Test that send_welcome_email fails after max retries."""
# All calls raise rate limit error
mock_send.side_effect = ResendError(
code=429,
message='Too many requests',
error_type='rate_limit_exceeded',
suggested_action='',
)
# Tenacity wraps the final exception in RetryError
with pytest.raises(RetryError):
send_welcome_email(
email='test@example.com',
first_name='John',
last_name='Doe',
)
# Default MAX_RETRIES is 3
assert mock_send.call_count == 3
@patch('enterprise.sync.resend_keycloak.resend.Emails.send')
def test_send_welcome_email_no_name(self, mock_send: MagicMock) -> None:
"""Test welcome email with no name provided."""
mock_send.return_value = {'id': 'email_123'}
result = send_welcome_email(email='test@example.com')
assert result == {'id': 'email_123'}
call_args = mock_send.call_args[0][0]
assert 'Hi there,' in call_args['html']
class TestAddContactToResend:
"""Tests for add_contact_to_resend function."""
@patch('enterprise.sync.resend_keycloak.resend.Contacts.create')
def test_add_contact_to_resend_success(self, mock_create: MagicMock) -> None:
"""Test successful contact addition."""
mock_create.return_value = {'id': 'contact_123'}
result = add_contact_to_resend(
audience_id='test_audience',
email='test@example.com',
first_name='John',
last_name='Doe',
)
assert result == {'id': 'contact_123'}
mock_create.assert_called_once()
@patch('enterprise.sync.resend_keycloak.resend.Contacts.create')
def test_add_contact_to_resend_retries_on_rate_limit(
self, mock_create: MagicMock
) -> None:
"""Test that add_contact_to_resend retries on rate limit errors."""
# First call raises rate limit error, second succeeds
mock_create.side_effect = [
ResendError(
code=429,
message='Too many requests',
error_type='rate_limit_exceeded',
suggested_action='',
),
{'id': 'contact_123'},
]
result = add_contact_to_resend(
audience_id='test_audience',
email='test@example.com',
)
assert result == {'id': 'contact_123'}
assert mock_create.call_count == 2

View File

@@ -265,17 +265,17 @@ async def test_list_api_keys(
# Verify
mock_get_user.assert_called_once_with(user_id)
assert len(result) == 2
assert result[0]['id'] == 1
assert result[0]['name'] == 'Key 1'
assert result[0]['created_at'] == now
assert result[0]['last_used_at'] == now
assert result[0]['expires_at'] == now + timedelta(days=30)
assert result[0].id == 1
assert result[0].name == 'Key 1'
assert result[0].created_at == now
assert result[0].last_used_at == now
assert result[0].expires_at == now + timedelta(days=30)
assert result[1]['id'] == 2
assert result[1]['name'] == 'Key 2'
assert result[1]['created_at'] == now
assert result[1]['last_used_at'] is None
assert result[1]['expires_at'] is None
assert result[1].id == 2
assert result[1].name == 'Key 2'
assert result[1].created_at == now
assert result[1].last_used_at is None
assert result[1].expires_at is None
@pytest.mark.asyncio

View File

@@ -0,0 +1,181 @@
"""Tests for auth callback invitation acceptance - EmailMismatchError handling."""
import pytest
class TestAuthCallbackInvitationEmailMismatch:
"""Test cases for EmailMismatchError handling during auth callback."""
@pytest.fixture
def mock_redirect_url(self):
"""Base redirect URL."""
return 'https://app.example.com/'
@pytest.fixture
def mock_user_id(self):
"""Mock user ID."""
return '87654321-4321-8765-4321-876543218765'
def test_email_mismatch_appends_to_url_without_query_params(
self, mock_redirect_url, mock_user_id
):
"""Test that email_mismatch=true is appended correctly when URL has no query params."""
from server.routes.org_invitation_models import EmailMismatchError
# Simulate the logic from auth.py
redirect_url = mock_redirect_url
try:
raise EmailMismatchError('Your email does not match the invitation')
except EmailMismatchError:
if '?' in redirect_url:
redirect_url = f'{redirect_url}&email_mismatch=true'
else:
redirect_url = f'{redirect_url}?email_mismatch=true'
assert redirect_url == 'https://app.example.com/?email_mismatch=true'
def test_email_mismatch_appends_to_url_with_query_params(self, mock_user_id):
"""Test that email_mismatch=true is appended correctly when URL has existing query params."""
from server.routes.org_invitation_models import EmailMismatchError
redirect_url = 'https://app.example.com/?other_param=value'
try:
raise EmailMismatchError()
except EmailMismatchError:
if '?' in redirect_url:
redirect_url = f'{redirect_url}&email_mismatch=true'
else:
redirect_url = f'{redirect_url}?email_mismatch=true'
assert (
redirect_url
== 'https://app.example.com/?other_param=value&email_mismatch=true'
)
def test_email_mismatch_error_has_default_message(self):
"""Test that EmailMismatchError has the default message."""
from server.routes.org_invitation_models import EmailMismatchError
error = EmailMismatchError()
assert str(error) == 'Your email does not match the invitation'
def test_email_mismatch_error_accepts_custom_message(self):
"""Test that EmailMismatchError accepts a custom message."""
from server.routes.org_invitation_models import EmailMismatchError
custom_message = 'Custom error message'
error = EmailMismatchError(custom_message)
assert str(error) == custom_message
def test_email_mismatch_error_is_invitation_error(self):
"""Test that EmailMismatchError inherits from InvitationError."""
from server.routes.org_invitation_models import (
EmailMismatchError,
InvitationError,
)
error = EmailMismatchError()
assert isinstance(error, InvitationError)
class TestInvitationTokenInOAuthState:
"""Test cases for invitation token handling in OAuth state."""
def test_invitation_token_included_in_oauth_state(self):
"""Test that invitation token is included in OAuth state data."""
import base64
import json
# Simulate building OAuth state with invitation token
state_data = {
'redirect_url': 'https://app.example.com/',
'invitation_token': 'inv-test-token-12345',
}
encoded_state = base64.b64encode(json.dumps(state_data).encode()).decode()
decoded_data = json.loads(base64.b64decode(encoded_state))
assert decoded_data['invitation_token'] == 'inv-test-token-12345'
assert decoded_data['redirect_url'] == 'https://app.example.com/'
def test_invitation_token_extracted_from_oauth_state(self):
"""Test that invitation token can be extracted from OAuth state."""
import base64
import json
state_data = {
'redirect_url': 'https://app.example.com/',
'invitation_token': 'inv-test-token-12345',
}
encoded_state = base64.b64encode(json.dumps(state_data).encode()).decode()
# Simulate decoding in callback
decoded_state = json.loads(base64.b64decode(encoded_state))
invitation_token = decoded_state.get('invitation_token')
assert invitation_token == 'inv-test-token-12345'
def test_oauth_state_without_invitation_token(self):
"""Test that OAuth state works without invitation token."""
import base64
import json
state_data = {
'redirect_url': 'https://app.example.com/',
}
encoded_state = base64.b64encode(json.dumps(state_data).encode()).decode()
decoded_data = json.loads(base64.b64decode(encoded_state))
assert 'invitation_token' not in decoded_data
assert decoded_data['redirect_url'] == 'https://app.example.com/'
class TestAuthCallbackInvitationErrors:
"""Test cases for various invitation error scenarios in auth callback."""
def test_invitation_expired_appends_flag(self):
"""Test that invitation_expired=true is appended for expired invitations."""
from server.routes.org_invitation_models import InvitationExpiredError
redirect_url = 'https://app.example.com/'
try:
raise InvitationExpiredError()
except InvitationExpiredError:
if '?' in redirect_url:
redirect_url = f'{redirect_url}&invitation_expired=true'
else:
redirect_url = f'{redirect_url}?invitation_expired=true'
assert redirect_url == 'https://app.example.com/?invitation_expired=true'
def test_invitation_invalid_appends_flag(self):
"""Test that invitation_invalid=true is appended for invalid invitations."""
from server.routes.org_invitation_models import InvitationInvalidError
redirect_url = 'https://app.example.com/'
try:
raise InvitationInvalidError()
except InvitationInvalidError:
if '?' in redirect_url:
redirect_url = f'{redirect_url}&invitation_invalid=true'
else:
redirect_url = f'{redirect_url}?invitation_invalid=true'
assert redirect_url == 'https://app.example.com/?invitation_invalid=true'
def test_already_member_appends_flag(self):
"""Test that already_member=true is appended when user is already a member."""
from server.routes.org_invitation_models import UserAlreadyMemberError
redirect_url = 'https://app.example.com/'
try:
raise UserAlreadyMemberError()
except UserAlreadyMemberError:
if '?' in redirect_url:
redirect_url = f'{redirect_url}&already_member=true'
else:
redirect_url = f'{redirect_url}?already_member=true'
assert redirect_url == 'https://app.example.com/?already_member=true'

View File

@@ -284,3 +284,85 @@ async def test_middleware_ignores_email_resend_path_no_tos_check(
assert result == mock_response
mock_call_next.assert_called_once_with(mock_request)
# Should not raise TosNotAcceptedError for this path
@pytest.mark.asyncio
async def test_middleware_skips_webhook_endpoints(
middleware, mock_request, mock_response
):
"""Test middleware skips webhook endpoints (/api/v1/webhooks/*) and doesn't require auth."""
# Test various webhook paths
webhook_paths = [
'/api/v1/webhooks/events',
'/api/v1/webhooks/events/123',
'/api/v1/webhooks/stats',
'/api/v1/webhooks/parent-conversation',
]
for path in webhook_paths:
mock_request.cookies = {}
mock_request.url = MagicMock()
mock_request.url.hostname = 'localhost'
mock_request.url.path = path
mock_call_next = AsyncMock(return_value=mock_response)
# Act
result = await middleware(mock_request, mock_call_next)
# Assert - middleware should skip auth check and call next
assert result == mock_response
mock_call_next.assert_called_once_with(mock_request)
@pytest.mark.asyncio
async def test_middleware_skips_webhook_secrets_endpoint(
middleware, mock_request, mock_response
):
"""Test middleware skips the old /api/v1/webhooks/secrets endpoint."""
# This was explicitly in ignore_paths but is now handled by the prefix check
mock_request.cookies = {}
mock_request.url = MagicMock()
mock_request.url.hostname = 'localhost'
mock_request.url.path = '/api/v1/webhooks/secrets'
mock_call_next = AsyncMock(return_value=mock_response)
# Act
result = await middleware(mock_request, mock_call_next)
# Assert - middleware should skip auth check and call next
assert result == mock_response
mock_call_next.assert_called_once_with(mock_request)
@pytest.mark.asyncio
async def test_middleware_does_not_skip_similar_non_webhook_paths(
middleware, mock_response
):
"""Test middleware does NOT skip paths that start with /api/v1/webhook (without 's')."""
# These paths should still be processed by the middleware (not skipped)
# They start with /api so _should_attach returns True, and since there's no auth,
# middleware should return 401 response (it catches NoCredentialsError internally)
non_webhook_paths = [
'/api/v1/webhook/events',
'/api/v1/webhook/something',
]
for path in non_webhook_paths:
# Create a fresh mock request for each test
mock_request = MagicMock(spec=Request)
mock_request.cookies = {}
mock_request.url = MagicMock()
mock_request.url.hostname = 'localhost'
mock_request.url.path = path
mock_request.headers = MagicMock()
mock_request.headers.get = MagicMock(side_effect=lambda k: None)
# Since these paths start with /api, _should_attach returns True
# Since there's no auth, middleware catches NoCredentialsError and returns 401
mock_call_next = AsyncMock()
result = await middleware(mock_request, mock_call_next)
# Should return a 401 response, not raise an exception
assert result.status_code == status.HTTP_401_UNAUTHORIZED
# Should NOT call next for non-webhook paths when auth is missing
mock_call_next.assert_not_called()

View File

@@ -1,188 +1,331 @@
"""
Unit tests for role-based authorization (authorization.py).
Unit tests for permission-based authorization (authorization.py).
Tests the FastAPI dependencies that validate user roles within organizations.
Tests the FastAPI dependencies that validate user permissions within organizations.
"""
from unittest.mock import MagicMock, patch
from unittest.mock import AsyncMock, MagicMock, patch
from uuid import uuid4
import pytest
from fastapi import HTTPException
from server.auth.authorization import (
ROLE_HIERARCHY,
OrgRole,
ROLE_PERMISSIONS,
Permission,
RoleName,
get_role_permissions,
get_user_org_role,
has_required_role,
require_org_admin,
require_org_owner,
require_org_role,
require_org_user,
has_permission,
require_permission,
)
# =============================================================================
# Tests for OrgRole enum
# Tests for Permission enum
# =============================================================================
class TestOrgRole:
"""Tests for OrgRole enum."""
class TestPermission:
"""Tests for Permission enum."""
def test_org_role_values(self):
def test_permission_values(self):
"""
GIVEN: OrgRole enum
WHEN: Accessing role values
THEN: All expected roles exist with correct string values
GIVEN: Permission enum
WHEN: Accessing permission values
THEN: All expected permissions exist with correct string values
"""
assert OrgRole.OWNER.value == 'owner'
assert OrgRole.ADMIN.value == 'admin'
assert OrgRole.USER.value == 'user'
assert Permission.MANAGE_SECRETS.value == 'manage_secrets'
assert Permission.MANAGE_MCP.value == 'manage_mcp'
assert Permission.MANAGE_INTEGRATIONS.value == 'manage_integrations'
assert (
Permission.MANAGE_APPLICATION_SETTINGS.value
== 'manage_application_settings'
)
assert Permission.MANAGE_API_KEYS.value == 'manage_api_keys'
assert Permission.VIEW_LLM_SETTINGS.value == 'view_llm_settings'
assert Permission.EDIT_LLM_SETTINGS.value == 'edit_llm_settings'
assert Permission.VIEW_BILLING.value == 'view_billing'
assert Permission.ADD_CREDITS.value == 'add_credits'
assert (
Permission.INVITE_USER_TO_ORGANIZATION.value
== 'invite_user_to_organization'
)
assert Permission.CHANGE_USER_ROLE_MEMBER.value == 'change_user_role:member'
assert Permission.CHANGE_USER_ROLE_ADMIN.value == 'change_user_role:admin'
assert Permission.CHANGE_USER_ROLE_OWNER.value == 'change_user_role:owner'
assert Permission.VIEW_ORG_SETTINGS.value == 'view_org_settings'
assert Permission.CHANGE_ORGANIZATION_NAME.value == 'change_organization_name'
assert Permission.DELETE_ORGANIZATION.value == 'delete_organization'
def test_org_role_from_string(self):
def test_permission_from_string(self):
"""
GIVEN: Valid role string
WHEN: Creating OrgRole from string
GIVEN: Valid permission string
WHEN: Creating Permission from string
THEN: Correct enum value is returned
"""
assert OrgRole('owner') == OrgRole.OWNER
assert OrgRole('admin') == OrgRole.ADMIN
assert OrgRole('user') == OrgRole.USER
assert Permission('manage_secrets') == Permission.MANAGE_SECRETS
assert Permission('view_llm_settings') == Permission.VIEW_LLM_SETTINGS
assert Permission('delete_organization') == Permission.DELETE_ORGANIZATION
def test_org_role_invalid_string(self):
def test_permission_invalid_string(self):
"""
GIVEN: Invalid role string
WHEN: Creating OrgRole from string
GIVEN: Invalid permission string
WHEN: Creating Permission from string
THEN: ValueError is raised
"""
with pytest.raises(ValueError):
OrgRole('invalid_role')
Permission('invalid_permission')
# =============================================================================
# Tests for role hierarchy
# Tests for RoleName enum
# =============================================================================
class TestRoleHierarchy:
"""Tests for role hierarchy constants."""
class TestRoleName:
"""Tests for RoleName enum."""
def test_owner_highest_rank(self):
def test_role_name_values(self):
"""
GIVEN: Role hierarchy
WHEN: Comparing role ranks
THEN: Owner has highest rank
GIVEN: RoleName enum
WHEN: Accessing role name values
THEN: All expected roles exist with correct string values
"""
assert ROLE_HIERARCHY[OrgRole.OWNER] > ROLE_HIERARCHY[OrgRole.ADMIN]
assert ROLE_HIERARCHY[OrgRole.OWNER] > ROLE_HIERARCHY[OrgRole.USER]
assert RoleName.OWNER.value == 'owner'
assert RoleName.ADMIN.value == 'admin'
assert RoleName.MEMBER.value == 'member'
def test_admin_middle_rank(self):
def test_role_name_from_string(self):
"""
GIVEN: Role hierarchy
WHEN: Comparing role ranks
THEN: Admin is between owner and user
GIVEN: Valid role name string
WHEN: Creating RoleName from string
THEN: Correct enum value is returned
"""
assert ROLE_HIERARCHY[OrgRole.ADMIN] > ROLE_HIERARCHY[OrgRole.USER]
assert ROLE_HIERARCHY[OrgRole.ADMIN] < ROLE_HIERARCHY[OrgRole.OWNER]
assert RoleName('owner') == RoleName.OWNER
assert RoleName('admin') == RoleName.ADMIN
assert RoleName('member') == RoleName.MEMBER
def test_user_lowest_rank(self):
def test_role_name_invalid_string(self):
"""
GIVEN: Role hierarchy
WHEN: Comparing role ranks
THEN: User has lowest rank
GIVEN: Invalid role name string
WHEN: Creating RoleName from string
THEN: ValueError is raised
"""
assert ROLE_HIERARCHY[OrgRole.USER] < ROLE_HIERARCHY[OrgRole.ADMIN]
assert ROLE_HIERARCHY[OrgRole.USER] < ROLE_HIERARCHY[OrgRole.OWNER]
with pytest.raises(ValueError):
RoleName('invalid_role')
# =============================================================================
# Tests for has_required_role function
# Tests for ROLE_PERMISSIONS mapping
# =============================================================================
class TestHasRequiredRole:
"""Tests for has_required_role function."""
class TestRolePermissions:
"""Tests for role permission mappings."""
def test_owner_has_owner_role(self):
def test_owner_has_all_permissions(self):
"""
GIVEN: ROLE_PERMISSIONS mapping
WHEN: Checking owner permissions
THEN: Owner has all permissions including owner-only permissions
"""
owner_perms = ROLE_PERMISSIONS[RoleName.OWNER]
assert Permission.MANAGE_SECRETS in owner_perms
assert Permission.MANAGE_MCP in owner_perms
assert Permission.VIEW_LLM_SETTINGS in owner_perms
assert Permission.EDIT_LLM_SETTINGS in owner_perms
assert Permission.VIEW_BILLING in owner_perms
assert Permission.ADD_CREDITS in owner_perms
assert Permission.INVITE_USER_TO_ORGANIZATION in owner_perms
assert Permission.CHANGE_USER_ROLE_MEMBER in owner_perms
assert Permission.CHANGE_USER_ROLE_ADMIN in owner_perms
assert Permission.CHANGE_USER_ROLE_OWNER in owner_perms
assert Permission.CHANGE_ORGANIZATION_NAME in owner_perms
assert Permission.DELETE_ORGANIZATION in owner_perms
def test_admin_has_admin_permissions(self):
"""
GIVEN: ROLE_PERMISSIONS mapping
WHEN: Checking admin permissions
THEN: Admin has admin permissions but not owner-only permissions
"""
admin_perms = ROLE_PERMISSIONS[RoleName.ADMIN]
assert Permission.MANAGE_SECRETS in admin_perms
assert Permission.MANAGE_MCP in admin_perms
assert Permission.VIEW_LLM_SETTINGS in admin_perms
assert Permission.EDIT_LLM_SETTINGS in admin_perms
assert Permission.VIEW_BILLING in admin_perms
assert Permission.ADD_CREDITS in admin_perms
assert Permission.INVITE_USER_TO_ORGANIZATION in admin_perms
assert Permission.CHANGE_USER_ROLE_MEMBER in admin_perms
assert Permission.CHANGE_USER_ROLE_ADMIN in admin_perms
# Admin should NOT have owner-only permissions
assert Permission.CHANGE_USER_ROLE_OWNER not in admin_perms
assert Permission.CHANGE_ORGANIZATION_NAME not in admin_perms
assert Permission.DELETE_ORGANIZATION not in admin_perms
def test_member_has_limited_permissions(self):
"""
GIVEN: ROLE_PERMISSIONS mapping
WHEN: Checking member permissions
THEN: Member has limited permissions
"""
member_perms = ROLE_PERMISSIONS[RoleName.MEMBER]
# Member has basic settings permissions
assert Permission.MANAGE_SECRETS in member_perms
assert Permission.MANAGE_MCP in member_perms
assert Permission.MANAGE_INTEGRATIONS in member_perms
assert Permission.MANAGE_APPLICATION_SETTINGS in member_perms
assert Permission.MANAGE_API_KEYS in member_perms
assert Permission.VIEW_LLM_SETTINGS in member_perms
assert Permission.VIEW_ORG_SETTINGS in member_perms
# Member should NOT have admin/owner permissions
assert Permission.EDIT_LLM_SETTINGS not in member_perms
assert Permission.VIEW_BILLING not in member_perms
assert Permission.ADD_CREDITS not in member_perms
assert Permission.INVITE_USER_TO_ORGANIZATION not in member_perms
assert Permission.CHANGE_USER_ROLE_MEMBER not in member_perms
assert Permission.CHANGE_USER_ROLE_ADMIN not in member_perms
assert Permission.CHANGE_USER_ROLE_OWNER not in member_perms
assert Permission.CHANGE_ORGANIZATION_NAME not in member_perms
assert Permission.DELETE_ORGANIZATION not in member_perms
# =============================================================================
# Tests for get_role_permissions function
# =============================================================================
class TestGetRolePermissions:
"""Tests for get_role_permissions function."""
def test_get_owner_permissions(self):
"""
GIVEN: Role name 'owner'
WHEN: get_role_permissions is called
THEN: Owner permissions are returned
"""
perms = get_role_permissions('owner')
assert Permission.DELETE_ORGANIZATION in perms
assert Permission.CHANGE_ORGANIZATION_NAME in perms
def test_get_admin_permissions(self):
"""
GIVEN: Role name 'admin'
WHEN: get_role_permissions is called
THEN: Admin permissions are returned
"""
perms = get_role_permissions('admin')
assert Permission.EDIT_LLM_SETTINGS in perms
assert Permission.DELETE_ORGANIZATION not in perms
def test_get_member_permissions(self):
"""
GIVEN: Role name 'member'
WHEN: get_role_permissions is called
THEN: Member permissions are returned
"""
perms = get_role_permissions('member')
assert Permission.VIEW_LLM_SETTINGS in perms
assert Permission.EDIT_LLM_SETTINGS not in perms
def test_get_invalid_role_permissions(self):
"""
GIVEN: Invalid role name
WHEN: get_role_permissions is called
THEN: Empty frozenset is returned
"""
perms = get_role_permissions('invalid_role')
assert perms == frozenset()
# =============================================================================
# Tests for has_permission function
# =============================================================================
class TestHasPermission:
"""Tests for has_permission function."""
def test_owner_has_delete_organization_permission(self):
"""
GIVEN: User with owner role
WHEN: Checking for owner requirement
WHEN: Checking for DELETE_ORGANIZATION permission
THEN: Returns True
"""
assert has_required_role('owner', OrgRole.OWNER) is True
mock_role = MagicMock()
mock_role.name = 'owner'
assert has_permission(mock_role, Permission.DELETE_ORGANIZATION) is True
def test_owner_has_admin_role(self):
def test_owner_has_view_llm_settings_permission(self):
"""
GIVEN: User with owner role
WHEN: Checking for admin requirement
THEN: Returns True (owner > admin)
"""
assert has_required_role('owner', OrgRole.ADMIN) is True
def test_owner_has_user_role(self):
"""
GIVEN: User with owner role
WHEN: Checking for user requirement
THEN: Returns True (owner > user)
"""
assert has_required_role('owner', OrgRole.USER) is True
def test_admin_has_admin_role(self):
"""
GIVEN: User with admin role
WHEN: Checking for admin requirement
WHEN: Checking for VIEW_LLM_SETTINGS permission
THEN: Returns True
"""
assert has_required_role('admin', OrgRole.ADMIN) is True
mock_role = MagicMock()
mock_role.name = 'owner'
assert has_permission(mock_role, Permission.VIEW_LLM_SETTINGS) is True
def test_admin_has_user_role(self):
def test_admin_has_edit_llm_settings_permission(self):
"""
GIVEN: User with admin role
WHEN: Checking for user requirement
THEN: Returns True (admin > user)
"""
assert has_required_role('admin', OrgRole.USER) is True
def test_admin_lacks_owner_role(self):
"""
GIVEN: User with admin role
WHEN: Checking for owner requirement
THEN: Returns False (admin < owner)
"""
assert has_required_role('admin', OrgRole.OWNER) is False
def test_user_has_user_role(self):
"""
GIVEN: User with user role
WHEN: Checking for user requirement
WHEN: Checking for EDIT_LLM_SETTINGS permission
THEN: Returns True
"""
assert has_required_role('user', OrgRole.USER) is True
mock_role = MagicMock()
mock_role.name = 'admin'
assert has_permission(mock_role, Permission.EDIT_LLM_SETTINGS) is True
def test_user_lacks_admin_role(self):
def test_admin_lacks_delete_organization_permission(self):
"""
GIVEN: User with user role
WHEN: Checking for admin requirement
THEN: Returns False (user < admin)
"""
assert has_required_role('user', OrgRole.ADMIN) is False
def test_user_lacks_owner_role(self):
"""
GIVEN: User with user role
WHEN: Checking for owner requirement
THEN: Returns False (user < owner)
"""
assert has_required_role('user', OrgRole.OWNER) is False
def test_invalid_role_returns_false(self):
"""
GIVEN: Invalid role string
WHEN: Checking for any requirement
GIVEN: User with admin role
WHEN: Checking for DELETE_ORGANIZATION permission
THEN: Returns False
"""
assert has_required_role('invalid_role', OrgRole.USER) is False
assert has_required_role('invalid_role', OrgRole.ADMIN) is False
assert has_required_role('invalid_role', OrgRole.OWNER) is False
mock_role = MagicMock()
mock_role.name = 'admin'
assert has_permission(mock_role, Permission.DELETE_ORGANIZATION) is False
def test_member_has_view_llm_settings_permission(self):
"""
GIVEN: User with member role
WHEN: Checking for VIEW_LLM_SETTINGS permission
THEN: Returns True
"""
mock_role = MagicMock()
mock_role.name = 'member'
assert has_permission(mock_role, Permission.VIEW_LLM_SETTINGS) is True
def test_member_lacks_edit_llm_settings_permission(self):
"""
GIVEN: User with member role
WHEN: Checking for EDIT_LLM_SETTINGS permission
THEN: Returns False
"""
mock_role = MagicMock()
mock_role.name = 'member'
assert has_permission(mock_role, Permission.EDIT_LLM_SETTINGS) is False
def test_member_lacks_delete_organization_permission(self):
"""
GIVEN: User with member role
WHEN: Checking for DELETE_ORGANIZATION permission
THEN: Returns False
"""
mock_role = MagicMock()
mock_role.name = 'member'
assert has_permission(mock_role, Permission.DELETE_ORGANIZATION) is False
def test_invalid_role_has_no_permissions(self):
"""
GIVEN: User with invalid role
WHEN: Checking for any permission
THEN: Returns False
"""
mock_role = MagicMock()
mock_role.name = 'invalid_role'
assert has_permission(mock_role, Permission.VIEW_LLM_SETTINGS) is False
assert has_permission(mock_role, Permission.DELETE_ORGANIZATION) is False
# =============================================================================
@@ -197,7 +340,7 @@ class TestGetUserOrgRole:
"""
GIVEN: User is a member of organization with role
WHEN: get_user_org_role is called
THEN: Role name is returned
THEN: Role object is returned
"""
user_id = str(uuid4())
org_id = uuid4()
@@ -219,7 +362,7 @@ class TestGetUserOrgRole:
),
):
result = get_user_org_role(user_id, org_id)
assert result == 'admin'
assert result == mock_role
def test_returns_none_when_not_member(self):
"""
@@ -237,70 +380,95 @@ class TestGetUserOrgRole:
result = get_user_org_role(user_id, org_id)
assert result is None
def test_returns_none_when_role_not_found(self):
def test_returns_role_when_org_id_is_none(self):
"""
GIVEN: User is member but role not found
WHEN: get_user_org_role is called
THEN: None is returned
GIVEN: User with a current organization
WHEN: get_user_org_role is called with org_id=None
THEN: Role object is returned using get_org_member_for_current_org
"""
user_id = str(uuid4())
org_id = uuid4()
mock_org_member = MagicMock()
mock_org_member.role_id = 999 # Non-existent role
mock_org_member.role_id = 1
mock_role = MagicMock()
mock_role.name = 'admin'
with (
patch(
'server.auth.authorization.OrgMemberStore.get_org_member',
'server.auth.authorization.OrgMemberStore.get_org_member_for_current_org',
return_value=mock_org_member,
),
) as mock_get_current,
patch(
'server.auth.authorization.OrgMemberStore.get_org_member',
) as mock_get_org_member,
patch(
'server.auth.authorization.RoleStore.get_role_by_id',
return_value=None,
return_value=mock_role,
),
):
result = get_user_org_role(user_id, org_id)
result = get_user_org_role(user_id, None)
assert result == mock_role
mock_get_current.assert_called_once()
mock_get_org_member.assert_not_called()
def test_returns_none_when_org_id_is_none_and_no_current_org(self):
"""
GIVEN: User with no current organization membership
WHEN: get_user_org_role is called with org_id=None
THEN: None is returned
"""
user_id = str(uuid4())
with patch(
'server.auth.authorization.OrgMemberStore.get_org_member_for_current_org',
return_value=None,
):
result = get_user_org_role(user_id, None)
assert result is None
# =============================================================================
# Tests for require_org_role dependency
# Tests for require_permission dependency
# =============================================================================
class TestRequireOrgRole:
"""Tests for require_org_role dependency factory."""
class TestRequirePermission:
"""Tests for require_permission dependency factory."""
@pytest.mark.asyncio
async def test_returns_user_id_when_authorized(self):
"""
GIVEN: User with sufficient role
WHEN: Role checker is called
GIVEN: User with required permission
WHEN: Permission checker is called
THEN: User ID is returned
"""
user_id = str(uuid4())
org_id = uuid4()
mock_role = MagicMock()
mock_role.name = 'admin'
with patch(
'server.auth.authorization.get_user_org_role',
return_value='admin',
'server.auth.authorization.get_user_org_role_async',
AsyncMock(return_value=mock_role),
):
role_checker = require_org_role(OrgRole.USER)
result = await role_checker(org_id=org_id, user_id=user_id)
permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS)
result = await permission_checker(org_id=org_id, user_id=user_id)
assert result == user_id
@pytest.mark.asyncio
async def test_raises_401_when_not_authenticated(self):
"""
GIVEN: No user ID (not authenticated)
WHEN: Role checker is called
WHEN: Permission checker is called
THEN: 401 Unauthorized is raised
"""
org_id = uuid4()
role_checker = require_org_role(OrgRole.USER)
permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS)
with pytest.raises(HTTPException) as exc_info:
await role_checker(org_id=org_id, user_id=None)
await permission_checker(org_id=org_id, user_id=None)
assert exc_info.value.status_code == 401
assert 'not authenticated' in exc_info.value.detail.lower()
@@ -309,183 +477,280 @@ class TestRequireOrgRole:
async def test_raises_403_when_not_member(self):
"""
GIVEN: User is not a member of organization
WHEN: Role checker is called
WHEN: Permission checker is called
THEN: 403 Forbidden is raised
"""
user_id = str(uuid4())
org_id = uuid4()
with patch(
'server.auth.authorization.get_user_org_role',
return_value=None,
'server.auth.authorization.get_user_org_role_async',
AsyncMock(return_value=None),
):
role_checker = require_org_role(OrgRole.USER)
permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS)
with pytest.raises(HTTPException) as exc_info:
await role_checker(org_id=org_id, user_id=user_id)
await permission_checker(org_id=org_id, user_id=user_id)
assert exc_info.value.status_code == 403
assert 'not a member' in exc_info.value.detail.lower()
@pytest.mark.asyncio
async def test_raises_403_when_insufficient_role(self):
async def test_raises_403_when_insufficient_permission(self):
"""
GIVEN: User with insufficient role
WHEN: Role checker is called
GIVEN: User without required permission
WHEN: Permission checker is called
THEN: 403 Forbidden is raised
"""
user_id = str(uuid4())
org_id = uuid4()
mock_role = MagicMock()
mock_role.name = 'member'
with patch(
'server.auth.authorization.get_user_org_role',
return_value='user',
'server.auth.authorization.get_user_org_role_async',
AsyncMock(return_value=mock_role),
):
role_checker = require_org_role(OrgRole.ADMIN)
permission_checker = require_permission(Permission.DELETE_ORGANIZATION)
with pytest.raises(HTTPException) as exc_info:
await role_checker(org_id=org_id, user_id=user_id)
await permission_checker(org_id=org_id, user_id=user_id)
assert exc_info.value.status_code == 403
assert 'admin' in exc_info.value.detail.lower()
assert 'delete_organization' in exc_info.value.detail.lower()
@pytest.mark.asyncio
async def test_owner_satisfies_admin_requirement(self):
async def test_owner_can_delete_organization(self):
"""
GIVEN: User with owner role
WHEN: Admin role is required
THEN: User ID is returned (owner > admin)
WHEN: DELETE_ORGANIZATION permission is required
THEN: User ID is returned
"""
user_id = str(uuid4())
org_id = uuid4()
mock_role = MagicMock()
mock_role.name = 'owner'
with patch(
'server.auth.authorization.get_user_org_role',
return_value='owner',
'server.auth.authorization.get_user_org_role_async',
AsyncMock(return_value=mock_role),
):
role_checker = require_org_role(OrgRole.ADMIN)
result = await role_checker(org_id=org_id, user_id=user_id)
permission_checker = require_permission(Permission.DELETE_ORGANIZATION)
result = await permission_checker(org_id=org_id, user_id=user_id)
assert result == user_id
@pytest.mark.asyncio
async def test_logs_warning_on_insufficient_role(self):
async def test_admin_cannot_delete_organization(self):
"""
GIVEN: User with insufficient role
WHEN: Role checker is called
GIVEN: User with admin role
WHEN: DELETE_ORGANIZATION permission is required
THEN: 403 Forbidden is raised
"""
user_id = str(uuid4())
org_id = uuid4()
mock_role = MagicMock()
mock_role.name = 'admin'
with patch(
'server.auth.authorization.get_user_org_role_async',
AsyncMock(return_value=mock_role),
):
permission_checker = require_permission(Permission.DELETE_ORGANIZATION)
with pytest.raises(HTTPException) as exc_info:
await permission_checker(org_id=org_id, user_id=user_id)
assert exc_info.value.status_code == 403
@pytest.mark.asyncio
async def test_logs_warning_on_insufficient_permission(self):
"""
GIVEN: User without required permission
WHEN: Permission checker is called
THEN: Warning is logged with details
"""
user_id = str(uuid4())
org_id = uuid4()
mock_role = MagicMock()
mock_role.name = 'member'
with (
patch(
'server.auth.authorization.get_user_org_role',
return_value='user',
'server.auth.authorization.get_user_org_role_async',
AsyncMock(return_value=mock_role),
),
patch('server.auth.authorization.logger') as mock_logger,
):
role_checker = require_org_role(OrgRole.OWNER)
permission_checker = require_permission(Permission.DELETE_ORGANIZATION)
with pytest.raises(HTTPException):
await role_checker(org_id=org_id, user_id=user_id)
await permission_checker(org_id=org_id, user_id=user_id)
mock_logger.warning.assert_called()
call_args = mock_logger.warning.call_args
assert call_args[1]['extra']['user_id'] == user_id
assert call_args[1]['extra']['user_role'] == 'user'
assert call_args[1]['extra']['required_role'] == 'owner'
# =============================================================================
# Tests for convenience dependencies
# =============================================================================
class TestConvenienceDependencies:
"""Tests for pre-configured convenience dependencies."""
assert call_args[1]['extra']['user_role'] == 'member'
assert call_args[1]['extra']['required_permission'] == 'delete_organization'
@pytest.mark.asyncio
async def test_require_org_user_allows_user(self):
async def test_returns_user_id_when_org_id_is_none(self):
"""
GIVEN: User with user role
WHEN: require_org_user is used
GIVEN: User with required permission in their current org
WHEN: Permission checker is called with org_id=None
THEN: User ID is returned
"""
user_id = str(uuid4())
mock_role = MagicMock()
mock_role.name = 'admin'
with patch(
'server.auth.authorization.get_user_org_role_async',
AsyncMock(return_value=mock_role),
) as mock_get_role:
permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS)
result = await permission_checker(org_id=None, user_id=user_id)
assert result == user_id
mock_get_role.assert_called_once_with(user_id, None)
@pytest.mark.asyncio
async def test_raises_403_when_org_id_is_none_and_not_member(self):
"""
GIVEN: User not a member of their current organization
WHEN: Permission checker is called with org_id=None
THEN: HTTPException with 403 status is raised
"""
user_id = str(uuid4())
with patch(
'server.auth.authorization.get_user_org_role_async',
AsyncMock(return_value=None),
):
permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS)
with pytest.raises(HTTPException) as exc_info:
await permission_checker(org_id=None, user_id=user_id)
assert exc_info.value.status_code == 403
assert 'not a member' in exc_info.value.detail
# =============================================================================
# Tests for permission-based access control scenarios
# =============================================================================
class TestPermissionScenarios:
"""Tests for real-world permission scenarios."""
@pytest.mark.asyncio
async def test_member_can_manage_secrets(self):
"""
GIVEN: User with member role
WHEN: MANAGE_SECRETS permission is required
THEN: User ID is returned
"""
user_id = str(uuid4())
org_id = uuid4()
mock_role = MagicMock()
mock_role.name = 'member'
with patch(
'server.auth.authorization.get_user_org_role',
return_value='user',
'server.auth.authorization.get_user_org_role_async',
AsyncMock(return_value=mock_role),
):
result = await require_org_user(org_id=org_id, user_id=user_id)
permission_checker = require_permission(Permission.MANAGE_SECRETS)
result = await permission_checker(org_id=org_id, user_id=user_id)
assert result == user_id
@pytest.mark.asyncio
async def test_require_org_admin_allows_admin(self):
async def test_member_cannot_invite_users(self):
"""
GIVEN: User with admin role
WHEN: require_org_admin is used
THEN: User ID is returned
"""
user_id = str(uuid4())
org_id = uuid4()
with patch(
'server.auth.authorization.get_user_org_role',
return_value='admin',
):
result = await require_org_admin(org_id=org_id, user_id=user_id)
assert result == user_id
@pytest.mark.asyncio
async def test_require_org_admin_rejects_user(self):
"""
GIVEN: User with user role
WHEN: require_org_admin is used
GIVEN: User with member role
WHEN: INVITE_USER_TO_ORGANIZATION permission is required
THEN: 403 Forbidden is raised
"""
user_id = str(uuid4())
org_id = uuid4()
mock_role = MagicMock()
mock_role.name = 'member'
with patch(
'server.auth.authorization.get_user_org_role',
return_value='user',
'server.auth.authorization.get_user_org_role_async',
AsyncMock(return_value=mock_role),
):
permission_checker = require_permission(
Permission.INVITE_USER_TO_ORGANIZATION
)
with pytest.raises(HTTPException) as exc_info:
await require_org_admin(org_id=org_id, user_id=user_id)
await permission_checker(org_id=org_id, user_id=user_id)
assert exc_info.value.status_code == 403
@pytest.mark.asyncio
async def test_require_org_owner_allows_owner(self):
async def test_admin_can_invite_users(self):
"""
GIVEN: User with admin role
WHEN: INVITE_USER_TO_ORGANIZATION permission is required
THEN: User ID is returned
"""
user_id = str(uuid4())
org_id = uuid4()
mock_role = MagicMock()
mock_role.name = 'admin'
with patch(
'server.auth.authorization.get_user_org_role_async',
AsyncMock(return_value=mock_role),
):
permission_checker = require_permission(
Permission.INVITE_USER_TO_ORGANIZATION
)
result = await permission_checker(org_id=org_id, user_id=user_id)
assert result == user_id
@pytest.mark.asyncio
async def test_admin_cannot_change_owner_role(self):
"""
GIVEN: User with admin role
WHEN: CHANGE_USER_ROLE_OWNER permission is required
THEN: 403 Forbidden is raised
"""
user_id = str(uuid4())
org_id = uuid4()
mock_role = MagicMock()
mock_role.name = 'admin'
with patch(
'server.auth.authorization.get_user_org_role_async',
AsyncMock(return_value=mock_role),
):
permission_checker = require_permission(Permission.CHANGE_USER_ROLE_OWNER)
with pytest.raises(HTTPException) as exc_info:
await permission_checker(org_id=org_id, user_id=user_id)
assert exc_info.value.status_code == 403
@pytest.mark.asyncio
async def test_owner_can_change_owner_role(self):
"""
GIVEN: User with owner role
WHEN: require_org_owner is used
WHEN: CHANGE_USER_ROLE_OWNER permission is required
THEN: User ID is returned
"""
user_id = str(uuid4())
org_id = uuid4()
mock_role = MagicMock()
mock_role.name = 'owner'
with patch(
'server.auth.authorization.get_user_org_role',
return_value='owner',
'server.auth.authorization.get_user_org_role_async',
AsyncMock(return_value=mock_role),
):
result = await require_org_owner(org_id=org_id, user_id=user_id)
permission_checker = require_permission(Permission.CHANGE_USER_ROLE_OWNER)
result = await permission_checker(org_id=org_id, user_id=user_id)
assert result == user_id
@pytest.mark.asyncio
async def test_require_org_owner_rejects_admin(self):
"""
GIVEN: User with admin role
WHEN: require_org_owner is used
THEN: 403 Forbidden is raised
"""
user_id = str(uuid4())
org_id = uuid4()
with patch(
'server.auth.authorization.get_user_org_role',
return_value='admin',
):
with pytest.raises(HTTPException) as exc_info:
await require_org_owner(org_id=org_id, user_id=user_id)
assert exc_info.value.status_code == 403

View File

@@ -101,7 +101,7 @@ async def test_get_credits_success():
json={
'user_info': {
'spend': 25.50,
'litellm_budget_table': {'max_budget': 100.00},
'max_budget_in_team': 100.00,
}
},
request=MagicMock(),
@@ -121,7 +121,7 @@ async def test_get_credits_success():
'storage.lite_llm_manager.LiteLlmManager.get_user_team_info',
return_value={
'spend': 25.50,
'litellm_budget_table': {'max_budget': 100.00},
'max_budget_in_team': 100.00,
},
),
):
@@ -313,7 +313,7 @@ async def test_success_callback_success():
'storage.lite_llm_manager.LiteLlmManager.get_user_team_info',
return_value={
'spend': 25.50,
'litellm_budget_table': {'max_budget': 100.00},
'max_budget_in_team': 100.00,
},
),
patch(
@@ -349,7 +349,7 @@ async def test_success_callback_success():
# Verify LiteLLM API calls
mock_update_budget.assert_called_once_with(
'mock_org_id',
125.0, # 100 + (25.00 from Stripe)
125.0, # 100 + 25.00
)
# Verify BYOR export is enabled for the org (updated in same session)
@@ -402,6 +402,68 @@ async def test_success_callback_lite_llm_error():
mock_db_session.commit.assert_not_called()
@pytest.mark.asyncio
async def test_success_callback_lite_llm_update_budget_error_rollback():
"""Test that database changes are not committed when update_team_and_users_budget fails.
This test verifies that if LiteLlmManager.update_team_and_users_budget raises an exception,
the database transaction rolls back.
"""
mock_request = Request(scope={'type': 'http'})
mock_request._base_url = URL('http://test.com/')
mock_billing_session = MagicMock()
mock_billing_session.status = 'in_progress'
mock_billing_session.user_id = 'mock_user'
mock_org = MagicMock()
with (
patch('server.routes.billing.session_maker') as mock_session_maker,
patch('stripe.checkout.Session.retrieve') as mock_stripe_retrieve,
patch(
'storage.user_store.UserStore.get_user_by_id_async',
new_callable=AsyncMock,
return_value=MagicMock(current_org_id='mock_org_id'),
),
patch(
'storage.lite_llm_manager.LiteLlmManager.get_user_team_info',
return_value={
'spend': 0,
'max_budget_in_team': 0,
},
),
patch(
'storage.lite_llm_manager.LiteLlmManager.update_team_and_users_budget',
side_effect=Exception('LiteLLM API Error'),
),
):
mock_db_session = MagicMock()
mock_query_chain_billing = MagicMock()
mock_query_chain_billing.filter.return_value.filter.return_value.first.return_value = mock_billing_session
mock_query_chain_org = MagicMock()
mock_query_chain_org.filter.return_value.first.return_value = mock_org
mock_db_session.query.side_effect = [
mock_query_chain_billing,
mock_query_chain_org,
]
mock_session_maker.return_value.__enter__.return_value = mock_db_session
mock_stripe_retrieve.return_value = MagicMock(
status='complete',
amount_subtotal=1000, # $10
customer='mock_customer_id',
)
with pytest.raises(Exception, match='LiteLLM API Error'):
await success_callback('test_session_id', mock_request)
# Verify no database commit occurred - the transaction should roll back
assert mock_billing_session.status == 'in_progress'
mock_db_session.merge.assert_not_called()
mock_db_session.commit.assert_not_called()
@pytest.mark.asyncio
async def test_cancel_callback_session_not_found():
"""Test cancel callback when billing session is not found."""
@@ -517,6 +579,6 @@ async def test_create_customer_setup_session_success():
customer='mock-customer-id',
mode='setup',
payment_method_types=['card'],
success_url='https://test.com/?free_credits=success',
success_url='https://test.com/?setup=success',
cancel_url='https://test.com/',
)

View File

@@ -48,7 +48,7 @@ async def test_create_customer_setup_session_uses_customer_id():
customer=customer_id,
mode='setup',
payment_method_types=['card'],
success_url=f'{request.base_url}?free_credits=success',
success_url=f'{request.base_url}?setup=success',
cancel_url=f'{request.base_url}',
)

View File

@@ -0,0 +1,192 @@
"""Tests for email service."""
import os
from unittest.mock import MagicMock, patch
from server.services.email_service import (
DEFAULT_WEB_HOST,
EmailService,
)
class TestEmailServiceInvitationUrl:
"""Test cases for invitation URL generation."""
def test_invitation_url_uses_correct_endpoint(self):
"""Test that invitation URL points to the correct API endpoint."""
mock_response = MagicMock()
mock_response.get.return_value = 'test-email-id'
with (
patch.dict(os.environ, {'RESEND_API_KEY': 'test-key'}),
patch('server.services.email_service.RESEND_AVAILABLE', True),
patch('server.services.email_service.resend') as mock_resend,
):
mock_resend.Emails.send.return_value = mock_response
EmailService.send_invitation_email(
to_email='test@example.com',
org_name='Test Org',
inviter_name='Inviter',
role_name='member',
invitation_token='inv-test-token-12345',
invitation_id=1,
)
# Get the call arguments
call_args = mock_resend.Emails.send.call_args
email_params = call_args[0][0]
# Verify the URL in the email HTML contains the correct endpoint
assert (
'/api/organizations/members/invite/accept?token='
in email_params['html']
)
assert 'inv-test-token-12345' in email_params['html']
def test_invitation_url_uses_web_host_env_var(self):
"""Test that invitation URL uses WEB_HOST environment variable."""
custom_host = 'https://custom.example.com'
mock_response = MagicMock()
mock_response.get.return_value = 'test-email-id'
with (
patch.dict(
os.environ,
{'RESEND_API_KEY': 'test-key', 'WEB_HOST': custom_host},
),
patch('server.services.email_service.RESEND_AVAILABLE', True),
patch('server.services.email_service.resend') as mock_resend,
):
mock_resend.Emails.send.return_value = mock_response
EmailService.send_invitation_email(
to_email='test@example.com',
org_name='Test Org',
inviter_name='Inviter',
role_name='member',
invitation_token='inv-test-token-12345',
invitation_id=1,
)
call_args = mock_resend.Emails.send.call_args
email_params = call_args[0][0]
expected_url = f'{custom_host}/api/organizations/members/invite/accept?token=inv-test-token-12345'
assert expected_url in email_params['html']
def test_invitation_url_uses_default_host_when_env_not_set(self):
"""Test that invitation URL falls back to DEFAULT_WEB_HOST when env not set."""
mock_response = MagicMock()
mock_response.get.return_value = 'test-email-id'
env_without_web_host = {'RESEND_API_KEY': 'test-key'}
# Remove WEB_HOST if it exists
env_without_web_host.pop('WEB_HOST', None)
with (
patch.dict(os.environ, env_without_web_host, clear=True),
patch('server.services.email_service.RESEND_AVAILABLE', True),
patch('server.services.email_service.resend') as mock_resend,
):
# Clear WEB_HOST from the environment
os.environ.pop('WEB_HOST', None)
mock_resend.Emails.send.return_value = mock_response
EmailService.send_invitation_email(
to_email='test@example.com',
org_name='Test Org',
inviter_name='Inviter',
role_name='member',
invitation_token='inv-test-token-12345',
invitation_id=1,
)
call_args = mock_resend.Emails.send.call_args
email_params = call_args[0][0]
expected_url = f'{DEFAULT_WEB_HOST}/api/organizations/members/invite/accept?token=inv-test-token-12345'
assert expected_url in email_params['html']
class TestEmailServiceGetResendClient:
"""Test cases for Resend client initialization."""
def test_get_resend_client_returns_false_when_resend_not_available(self):
"""Test that _get_resend_client returns False when resend is not installed."""
with patch('server.services.email_service.RESEND_AVAILABLE', False):
result = EmailService._get_resend_client()
assert result is False
def test_get_resend_client_returns_false_when_api_key_not_configured(self):
"""Test that _get_resend_client returns False when API key is missing."""
with (
patch('server.services.email_service.RESEND_AVAILABLE', True),
patch.dict(os.environ, {}, clear=True),
):
os.environ.pop('RESEND_API_KEY', None)
result = EmailService._get_resend_client()
assert result is False
def test_get_resend_client_returns_true_when_configured(self):
"""Test that _get_resend_client returns True when properly configured."""
with (
patch.dict(os.environ, {'RESEND_API_KEY': 'test-key'}),
patch('server.services.email_service.RESEND_AVAILABLE', True),
patch('server.services.email_service.resend') as mock_resend,
):
result = EmailService._get_resend_client()
assert result is True
assert mock_resend.api_key == 'test-key'
class TestEmailServiceSendInvitationEmail:
"""Test cases for send_invitation_email method."""
def test_send_invitation_email_skips_when_client_not_ready(self):
"""Test that email sending is skipped when client is not ready."""
with patch.object(
EmailService, '_get_resend_client', return_value=False
) as mock_get_client:
# Should not raise, just return early
EmailService.send_invitation_email(
to_email='test@example.com',
org_name='Test Org',
inviter_name='Inviter',
role_name='member',
invitation_token='inv-test-token',
invitation_id=1,
)
mock_get_client.assert_called_once()
def test_send_invitation_email_includes_all_required_info(self):
"""Test that invitation email includes org name, inviter name, and role."""
mock_response = MagicMock()
mock_response.get.return_value = 'test-email-id'
with (
patch.dict(os.environ, {'RESEND_API_KEY': 'test-key'}),
patch('server.services.email_service.RESEND_AVAILABLE', True),
patch('server.services.email_service.resend') as mock_resend,
):
mock_resend.Emails.send.return_value = mock_response
EmailService.send_invitation_email(
to_email='test@example.com',
org_name='Acme Corp',
inviter_name='John Doe',
role_name='admin',
invitation_token='inv-test-token-12345',
invitation_id=42,
)
call_args = mock_resend.Emails.send.call_args
email_params = call_args[0][0]
# Verify email content
assert email_params['to'] == ['test@example.com']
assert 'Acme Corp' in email_params['subject']
assert 'John Doe' in email_params['html']
assert 'Acme Corp' in email_params['html']
assert 'admin' in email_params['html']

View File

@@ -142,44 +142,192 @@ class TestLiteLlmManager:
@pytest.mark.asyncio
async def test_create_entries_cloud_deployment(self, mock_settings, mock_response):
"""Test create_entries in cloud deployment mode."""
with patch.dict(os.environ, {'LOCAL_DEPLOYMENT': ''}):
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
with patch(
'storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'
):
with patch(
'storage.lite_llm_manager.TokenManager'
) as mock_token_manager:
mock_token_manager.return_value.get_user_info_from_user_id = (
AsyncMock(return_value={'email': 'test@example.com'})
)
mock_404_response = MagicMock()
mock_404_response.status_code = 404
mock_404_response.is_success = False
with patch('httpx.AsyncClient') as mock_client_class:
mock_client = AsyncMock()
mock_client_class.return_value.__aenter__.return_value = (
mock_client
)
mock_client.post.return_value = mock_response
mock_token_manager = MagicMock()
mock_token_manager.return_value.get_user_info_from_user_id = AsyncMock(
return_value={'email': 'test@example.com'}
)
result = await LiteLlmManager.create_entries(
'test-org-id',
'test-user-id',
mock_settings,
create_user=False,
)
mock_client = AsyncMock()
mock_client.get.return_value = mock_404_response
mock_client.get.return_value.raise_for_status.side_effect = (
httpx.HTTPStatusError(
message='Not Found', request=MagicMock(), response=mock_404_response
)
)
mock_client.post.return_value = mock_response
assert result is not None
assert result.agent == 'CodeActAgent'
assert result.llm_model == get_default_litellm_model()
assert (
result.llm_api_key.get_secret_value() == 'test-api-key'
)
assert result.llm_base_url == 'http://test.com'
mock_client_class = MagicMock()
mock_client_class.return_value.__aenter__.return_value = mock_client
# Verify API calls were made
assert (
mock_client.post.call_count == 3
) # create_team, create_user, add_user_to_team, generate_key
with (
patch.dict(os.environ, {'LOCAL_DEPLOYMENT': ''}),
patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'),
patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'),
patch('storage.lite_llm_manager.TokenManager', mock_token_manager),
patch('httpx.AsyncClient', mock_client_class),
):
result = await LiteLlmManager.create_entries(
'test-org-id', 'test-user-id', mock_settings, create_user=False
)
assert result is not None
assert result.agent == 'CodeActAgent'
assert result.llm_model == get_default_litellm_model()
assert result.llm_api_key.get_secret_value() == 'test-api-key'
assert result.llm_base_url == 'http://test.com'
# Verify API calls were made (get_team + 3 posts)
assert mock_client.get.call_count == 1 # get_team
assert (
mock_client.post.call_count == 3
) # create_team, add_user_to_team, generate_key
@pytest.mark.asyncio
async def test_create_entries_inherits_existing_team_budget(
self, mock_settings, mock_response
):
"""Test that create_entries inherits budget from existing team."""
mock_team_response = MagicMock()
mock_team_response.is_success = True
mock_team_response.status_code = 200
mock_team_response.json.return_value = {
'team_info': {'max_budget': 30.0, 'spend': 5.0},
'team_memberships': [],
}
mock_team_response.raise_for_status = MagicMock()
mock_token_manager = MagicMock()
mock_token_manager.return_value.get_user_info_from_user_id = AsyncMock(
return_value={'email': 'test@example.com'}
)
mock_client = AsyncMock()
mock_client.get.return_value = mock_team_response
mock_client.post.return_value = mock_response
mock_client_class = MagicMock()
mock_client_class.return_value.__aenter__.return_value = mock_client
with (
patch.dict(os.environ, {'LOCAL_DEPLOYMENT': ''}),
patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'),
patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'),
patch('storage.lite_llm_manager.TokenManager', mock_token_manager),
patch('httpx.AsyncClient', mock_client_class),
):
result = await LiteLlmManager.create_entries(
'test-org-id', 'test-user-id', mock_settings, create_user=False
)
assert result is not None
# Verify _get_team was called first
mock_client.get.assert_called_once()
get_call_url = mock_client.get.call_args[0][0]
assert 'team/info' in get_call_url
assert 'test-org-id' in get_call_url
# Verify _create_team was called with inherited budget (30.0)
create_team_call = mock_client.post.call_args_list[0]
assert 'team/new' in create_team_call[0][0]
assert create_team_call[1]['json']['max_budget'] == 30.0
# Verify _add_user_to_team was called with inherited budget (30.0)
add_user_call = mock_client.post.call_args_list[1]
assert 'team/member_add' in add_user_call[0][0]
assert add_user_call[1]['json']['max_budget_in_team'] == 30.0
@pytest.mark.asyncio
async def test_create_entries_new_org_uses_zero_budget(
self, mock_settings, mock_response
):
"""Test that create_entries uses budget=0 for new org (team doesn't exist)."""
mock_404_response = MagicMock()
mock_404_response.status_code = 404
mock_404_response.is_success = False
mock_token_manager = MagicMock()
mock_token_manager.return_value.get_user_info_from_user_id = AsyncMock(
return_value={'email': 'test@example.com'}
)
mock_client = AsyncMock()
mock_client.get.return_value = mock_404_response
mock_client.get.return_value.raise_for_status.side_effect = (
httpx.HTTPStatusError(
message='Not Found', request=MagicMock(), response=mock_404_response
)
)
mock_client.post.return_value = mock_response
mock_client_class = MagicMock()
mock_client_class.return_value.__aenter__.return_value = mock_client
with (
patch.dict(os.environ, {'LOCAL_DEPLOYMENT': ''}),
patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'),
patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'),
patch('storage.lite_llm_manager.TokenManager', mock_token_manager),
patch('httpx.AsyncClient', mock_client_class),
):
result = await LiteLlmManager.create_entries(
'test-org-id', 'test-user-id', mock_settings, create_user=False
)
assert result is not None
# Verify _create_team was called with budget=0
create_team_call = mock_client.post.call_args_list[0]
assert 'team/new' in create_team_call[0][0]
assert create_team_call[1]['json']['max_budget'] == 0.0
# Verify _add_user_to_team was called with budget=0
add_user_call = mock_client.post.call_args_list[1]
assert 'team/member_add' in add_user_call[0][0]
assert add_user_call[1]['json']['max_budget_in_team'] == 0.0
@pytest.mark.asyncio
async def test_create_entries_propagates_non_404_errors(self, mock_settings):
"""Test that create_entries propagates non-404 errors from _get_team."""
mock_500_response = MagicMock()
mock_500_response.status_code = 500
mock_500_response.is_success = False
mock_token_manager = MagicMock()
mock_token_manager.return_value.get_user_info_from_user_id = AsyncMock(
return_value={'email': 'test@example.com'}
)
mock_client = AsyncMock()
mock_client.get.return_value = mock_500_response
mock_client.get.return_value.raise_for_status.side_effect = (
httpx.HTTPStatusError(
message='Internal Server Error',
request=MagicMock(),
response=mock_500_response,
)
)
mock_client_class = MagicMock()
mock_client_class.return_value.__aenter__.return_value = mock_client
with (
patch.dict(os.environ, {'LOCAL_DEPLOYMENT': ''}),
patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'),
patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'),
patch('storage.lite_llm_manager.TokenManager', mock_token_manager),
patch('httpx.AsyncClient', mock_client_class),
):
with pytest.raises(httpx.HTTPStatusError) as exc_info:
await LiteLlmManager.create_entries(
'test-org-id', 'test-user-id', mock_settings, create_user=False
)
assert exc_info.value.response.status_code == 500
@pytest.mark.asyncio
async def test_migrate_entries_missing_config(self, mock_user_settings):

View File

@@ -0,0 +1,464 @@
"""Tests for organization invitation service - email validation."""
from unittest.mock import AsyncMock, MagicMock, patch
from uuid import UUID
import pytest
from server.routes.org_invitation_models import (
EmailMismatchError,
)
from server.services.org_invitation_service import OrgInvitationService
from storage.org_invitation import OrgInvitation
class TestAcceptInvitationEmailValidation:
"""Test cases for email validation during invitation acceptance."""
@pytest.fixture
def mock_invitation(self):
"""Create a mock invitation with pending status."""
invitation = MagicMock(spec=OrgInvitation)
invitation.id = 1
invitation.email = 'alice@example.com'
invitation.status = OrgInvitation.STATUS_PENDING
invitation.org_id = UUID('12345678-1234-5678-1234-567812345678')
invitation.role_id = 1
return invitation
@pytest.fixture
def mock_user(self):
"""Create a mock user with email."""
user = MagicMock()
user.id = UUID('87654321-4321-8765-4321-876543218765')
user.email = 'alice@example.com'
return user
@pytest.mark.asyncio
async def test_accept_invitation_email_matches(self, mock_invitation, mock_user):
"""Test that invitation is accepted when user email matches invitation email."""
# Arrange
user_id = mock_user.id
token = 'inv-test-token-12345'
with patch.object(
OrgInvitationService, 'accept_invitation', new_callable=AsyncMock
) as mock_accept:
mock_accept.return_value = mock_invitation
# Act
await OrgInvitationService.accept_invitation(token, user_id)
# Assert
mock_accept.assert_called_once_with(token, user_id)
@pytest.mark.asyncio
async def test_accept_invitation_email_mismatch_raises_error(
self, mock_invitation, mock_user
):
"""Test that EmailMismatchError is raised when emails don't match."""
# Arrange
user_id = mock_user.id
token = 'inv-test-token-12345'
mock_user.email = 'bob@example.com' # Different email
with (
patch(
'server.services.org_invitation_service.OrgInvitationStore.get_invitation_by_token',
new_callable=AsyncMock,
) as mock_get_invitation,
patch(
'server.services.org_invitation_service.OrgInvitationStore.is_token_expired'
) as mock_is_expired,
patch(
'server.services.org_invitation_service.UserStore.get_user_by_id_async',
new_callable=AsyncMock,
) as mock_get_user,
):
mock_get_invitation.return_value = mock_invitation
mock_is_expired.return_value = False
mock_get_user.return_value = mock_user
# Act & Assert
with pytest.raises(EmailMismatchError):
await OrgInvitationService.accept_invitation(token, user_id)
@pytest.mark.asyncio
async def test_accept_invitation_user_no_email_keycloak_fallback_matches(
self, mock_invitation
):
"""Test that Keycloak email is used when user has no email in database."""
# Arrange
user_id = UUID('87654321-4321-8765-4321-876543218765')
token = 'inv-test-token-12345'
mock_user = MagicMock()
mock_user.id = user_id
mock_user.email = None # No email in database
mock_keycloak_user_info = {'email': 'alice@example.com'} # Email from Keycloak
with (
patch(
'server.services.org_invitation_service.OrgInvitationStore.get_invitation_by_token',
new_callable=AsyncMock,
) as mock_get_invitation,
patch(
'server.services.org_invitation_service.OrgInvitationStore.is_token_expired'
) as mock_is_expired,
patch(
'server.services.org_invitation_service.UserStore.get_user_by_id_async',
new_callable=AsyncMock,
) as mock_get_user,
patch(
'server.services.org_invitation_service.TokenManager'
) as mock_token_manager_class,
patch(
'server.services.org_invitation_service.OrgMemberStore.get_org_member'
) as mock_get_member,
patch(
'server.services.org_invitation_service.OrgService.create_litellm_integration',
new_callable=AsyncMock,
) as mock_create_litellm,
patch(
'server.services.org_invitation_service.OrgMemberStore.add_user_to_org'
),
patch(
'server.services.org_invitation_service.OrgInvitationStore.update_invitation_status',
new_callable=AsyncMock,
) as mock_update_status,
):
mock_get_invitation.return_value = mock_invitation
mock_is_expired.return_value = False
mock_get_user.return_value = mock_user
# Mock TokenManager instance
mock_token_manager = MagicMock()
mock_token_manager.get_user_info_from_user_id = AsyncMock(
return_value=mock_keycloak_user_info
)
mock_token_manager_class.return_value = mock_token_manager
mock_get_member.return_value = None # Not already a member
mock_create_litellm.return_value = MagicMock(llm_api_key='test-key')
mock_update_status.return_value = mock_invitation
# Act - should not raise error because Keycloak email matches
await OrgInvitationService.accept_invitation(token, user_id)
# Assert
mock_token_manager.get_user_info_from_user_id.assert_called_once_with(
str(user_id)
)
@pytest.mark.asyncio
async def test_accept_invitation_no_email_anywhere_raises_error(
self, mock_invitation
):
"""Test that EmailMismatchError is raised when user has no email in database or Keycloak."""
# Arrange
user_id = UUID('87654321-4321-8765-4321-876543218765')
token = 'inv-test-token-12345'
mock_user = MagicMock()
mock_user.id = user_id
mock_user.email = None # No email in database
with (
patch(
'server.services.org_invitation_service.OrgInvitationStore.get_invitation_by_token',
new_callable=AsyncMock,
) as mock_get_invitation,
patch(
'server.services.org_invitation_service.OrgInvitationStore.is_token_expired'
) as mock_is_expired,
patch(
'server.services.org_invitation_service.UserStore.get_user_by_id_async',
new_callable=AsyncMock,
) as mock_get_user,
patch(
'server.services.org_invitation_service.TokenManager'
) as mock_token_manager_class,
):
mock_get_invitation.return_value = mock_invitation
mock_is_expired.return_value = False
mock_get_user.return_value = mock_user
# Mock TokenManager to return no email
mock_token_manager = MagicMock()
mock_token_manager.get_user_info_from_user_id = AsyncMock(return_value={})
mock_token_manager_class.return_value = mock_token_manager
# Act & Assert
with pytest.raises(EmailMismatchError) as exc_info:
await OrgInvitationService.accept_invitation(token, user_id)
assert 'does not have an email address' in str(exc_info.value)
@pytest.mark.asyncio
async def test_accept_invitation_email_comparison_is_case_insensitive(
self, mock_invitation
):
"""Test that email comparison is case insensitive."""
# Arrange
user_id = UUID('87654321-4321-8765-4321-876543218765')
token = 'inv-test-token-12345'
mock_user = MagicMock()
mock_user.id = user_id
mock_user.email = 'ALICE@EXAMPLE.COM' # Uppercase email
mock_invitation.email = 'alice@example.com' # Lowercase in invitation
with (
patch(
'server.services.org_invitation_service.OrgInvitationStore.get_invitation_by_token',
new_callable=AsyncMock,
) as mock_get_invitation,
patch(
'server.services.org_invitation_service.OrgInvitationStore.is_token_expired'
) as mock_is_expired,
patch(
'server.services.org_invitation_service.UserStore.get_user_by_id_async',
new_callable=AsyncMock,
) as mock_get_user,
patch(
'server.services.org_invitation_service.OrgMemberStore.get_org_member'
) as mock_get_member,
patch(
'server.services.org_invitation_service.OrgService.create_litellm_integration',
new_callable=AsyncMock,
) as mock_create_litellm,
patch(
'server.services.org_invitation_service.OrgMemberStore.add_user_to_org'
),
patch(
'server.services.org_invitation_service.OrgInvitationStore.update_invitation_status',
new_callable=AsyncMock,
) as mock_update_status,
):
mock_get_invitation.return_value = mock_invitation
mock_is_expired.return_value = False
mock_get_user.return_value = mock_user
mock_get_member.return_value = None
mock_create_litellm.return_value = MagicMock(llm_api_key='test-key')
mock_update_status.return_value = mock_invitation
# Act - should not raise error because emails match case-insensitively
await OrgInvitationService.accept_invitation(token, user_id)
# Assert - invitation was accepted (update_invitation_status was called)
mock_update_status.assert_called_once()
class TestCreateInvitationsBatch:
"""Test cases for batch invitation creation."""
@pytest.fixture
def org_id(self):
"""Organization UUID for testing."""
return UUID('12345678-1234-5678-1234-567812345678')
@pytest.fixture
def inviter_id(self):
"""Inviter UUID for testing."""
return UUID('87654321-4321-8765-4321-876543218765')
@pytest.fixture
def mock_org(self):
"""Create a mock organization."""
org = MagicMock()
org.id = UUID('12345678-1234-5678-1234-567812345678')
org.name = 'Test Org'
return org
@pytest.fixture
def mock_inviter_member(self):
"""Create a mock inviter member with owner role."""
member = MagicMock()
member.user_id = UUID('87654321-4321-8765-4321-876543218765')
member.role_id = 1
return member
@pytest.fixture
def mock_owner_role(self):
"""Create a mock owner role."""
role = MagicMock()
role.id = 1
role.name = 'owner'
return role
@pytest.fixture
def mock_member_role(self):
"""Create a mock member role."""
role = MagicMock()
role.id = 3
role.name = 'member'
return role
@pytest.mark.asyncio
async def test_batch_creates_all_invitations_successfully(
self,
org_id,
inviter_id,
mock_org,
mock_inviter_member,
mock_owner_role,
mock_member_role,
):
"""Test that batch creation succeeds for all valid emails."""
# Arrange
emails = ['alice@example.com', 'bob@example.com']
mock_invitation_1 = MagicMock(spec=OrgInvitation)
mock_invitation_1.id = 1
mock_invitation_2 = MagicMock(spec=OrgInvitation)
mock_invitation_2.id = 2
with (
patch(
'server.services.org_invitation_service.OrgStore.get_org_by_id',
return_value=mock_org,
),
patch(
'server.services.org_invitation_service.OrgMemberStore.get_org_member',
return_value=mock_inviter_member,
),
patch(
'server.services.org_invitation_service.RoleStore.get_role_by_id',
return_value=mock_owner_role,
),
patch(
'server.services.org_invitation_service.RoleStore.get_role_by_name',
return_value=mock_member_role,
),
patch.object(
OrgInvitationService,
'create_invitation',
new_callable=AsyncMock,
side_effect=[mock_invitation_1, mock_invitation_2],
),
):
# Act
successful, failed = await OrgInvitationService.create_invitations_batch(
org_id=org_id,
emails=emails,
role_name='member',
inviter_id=inviter_id,
)
# Assert
assert len(successful) == 2
assert len(failed) == 0
@pytest.mark.asyncio
async def test_batch_handles_partial_success(
self,
org_id,
inviter_id,
mock_org,
mock_inviter_member,
mock_owner_role,
mock_member_role,
):
"""Test that batch returns partial results when some emails fail."""
# Arrange
from server.routes.org_invitation_models import UserAlreadyMemberError
emails = ['alice@example.com', 'existing@example.com']
mock_invitation = MagicMock(spec=OrgInvitation)
mock_invitation.id = 1
with (
patch(
'server.services.org_invitation_service.OrgStore.get_org_by_id',
return_value=mock_org,
),
patch(
'server.services.org_invitation_service.OrgMemberStore.get_org_member',
return_value=mock_inviter_member,
),
patch(
'server.services.org_invitation_service.RoleStore.get_role_by_id',
return_value=mock_owner_role,
),
patch(
'server.services.org_invitation_service.RoleStore.get_role_by_name',
return_value=mock_member_role,
),
patch.object(
OrgInvitationService,
'create_invitation',
new_callable=AsyncMock,
side_effect=[mock_invitation, UserAlreadyMemberError()],
),
):
# Act
successful, failed = await OrgInvitationService.create_invitations_batch(
org_id=org_id,
emails=emails,
role_name='member',
inviter_id=inviter_id,
)
# Assert
assert len(successful) == 1
assert len(failed) == 1
assert failed[0][0] == 'existing@example.com'
@pytest.mark.asyncio
async def test_batch_fails_entirely_on_permission_error(self, org_id, inviter_id):
"""Test that permission error fails the entire batch upfront."""
# Arrange
emails = ['alice@example.com', 'bob@example.com']
with patch(
'server.services.org_invitation_service.OrgStore.get_org_by_id',
return_value=None, # Organization not found
):
# Act & Assert
with pytest.raises(ValueError) as exc_info:
await OrgInvitationService.create_invitations_batch(
org_id=org_id,
emails=emails,
role_name='member',
inviter_id=inviter_id,
)
assert 'not found' in str(exc_info.value)
@pytest.mark.asyncio
async def test_batch_fails_on_invalid_role(
self, org_id, inviter_id, mock_org, mock_inviter_member, mock_owner_role
):
"""Test that invalid role fails the entire batch."""
# Arrange
emails = ['alice@example.com']
with (
patch(
'server.services.org_invitation_service.OrgStore.get_org_by_id',
return_value=mock_org,
),
patch(
'server.services.org_invitation_service.OrgMemberStore.get_org_member',
return_value=mock_inviter_member,
),
patch(
'server.services.org_invitation_service.RoleStore.get_role_by_id',
return_value=mock_owner_role,
),
patch(
'server.services.org_invitation_service.RoleStore.get_role_by_name',
return_value=None, # Invalid role
),
):
# Act & Assert
with pytest.raises(ValueError) as exc_info:
await OrgInvitationService.create_invitations_batch(
org_id=org_id,
emails=emails,
role_name='invalid_role',
inviter_id=inviter_id,
)
assert 'Invalid role' in str(exc_info.value)

View File

@@ -0,0 +1,308 @@
"""Tests for organization invitation store."""
from datetime import datetime, timedelta
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from storage.org_invitation import OrgInvitation
from storage.org_invitation_store import (
INVITATION_TOKEN_LENGTH,
INVITATION_TOKEN_PREFIX,
OrgInvitationStore,
)
class TestGenerateToken:
"""Test cases for token generation."""
def test_generate_token_has_correct_prefix(self):
"""Test that generated tokens have the correct prefix."""
token = OrgInvitationStore.generate_token()
assert token.startswith(INVITATION_TOKEN_PREFIX)
def test_generate_token_has_correct_length(self):
"""Test that generated tokens have the correct total length."""
token = OrgInvitationStore.generate_token()
expected_length = len(INVITATION_TOKEN_PREFIX) + INVITATION_TOKEN_LENGTH
assert len(token) == expected_length
def test_generate_token_uses_alphanumeric_characters(self):
"""Test that generated tokens use only alphanumeric characters."""
token = OrgInvitationStore.generate_token()
# Remove prefix and check the rest is alphanumeric
random_part = token[len(INVITATION_TOKEN_PREFIX) :]
assert random_part.isalnum()
def test_generate_token_is_unique(self):
"""Test that generated tokens are unique (probabilistically)."""
tokens = [OrgInvitationStore.generate_token() for _ in range(100)]
assert len(set(tokens)) == 100
class TestIsTokenExpired:
"""Test cases for token expiration checking."""
def test_token_not_expired_when_future(self):
"""Test that tokens with future expiration are not expired."""
invitation = MagicMock(spec=OrgInvitation)
invitation.expires_at = datetime.utcnow() + timedelta(days=1)
result = OrgInvitationStore.is_token_expired(invitation)
assert result is False
def test_token_expired_when_past(self):
"""Test that tokens with past expiration are expired."""
invitation = MagicMock(spec=OrgInvitation)
invitation.expires_at = datetime.utcnow() - timedelta(seconds=1)
result = OrgInvitationStore.is_token_expired(invitation)
assert result is True
def test_token_expired_at_exact_boundary(self):
"""Test that tokens at exact expiration time are expired."""
# A token that expires "now" should be expired
now = datetime.utcnow()
invitation = MagicMock(spec=OrgInvitation)
invitation.expires_at = now - timedelta(microseconds=1)
result = OrgInvitationStore.is_token_expired(invitation)
assert result is True
class TestCreateInvitation:
"""Test cases for invitation creation."""
@pytest.mark.asyncio
async def test_create_invitation_normalizes_email(self):
"""Test that email is normalized (lowercase, stripped) on creation."""
mock_session = AsyncMock()
mock_session.add = MagicMock()
mock_session.commit = AsyncMock()
mock_session.execute = AsyncMock()
# Mock the result of the re-fetch query
mock_result = MagicMock()
mock_invitation = MagicMock()
mock_invitation.id = 1
mock_invitation.email = 'test@example.com'
mock_result.scalars.return_value.first.return_value = mock_invitation
mock_session.execute.return_value = mock_result
with patch(
'storage.org_invitation_store.a_session_maker'
) as mock_session_maker:
mock_session_manager = AsyncMock()
mock_session_manager.__aenter__.return_value = mock_session
mock_session_manager.__aexit__.return_value = None
mock_session_maker.return_value = mock_session_manager
from uuid import UUID
await OrgInvitationStore.create_invitation(
org_id=UUID('12345678-1234-5678-1234-567812345678'),
email=' TEST@EXAMPLE.COM ',
role_id=1,
inviter_id=UUID('87654321-4321-8765-4321-876543218765'),
)
# Verify that the OrgInvitation was created with normalized email
add_call = mock_session.add.call_args
created_invitation = add_call[0][0]
assert created_invitation.email == 'test@example.com'
class TestGetInvitationByToken:
"""Test cases for getting invitation by token."""
@pytest.mark.asyncio
async def test_get_invitation_by_token_returns_invitation(self):
"""Test that get_invitation_by_token returns the invitation when found."""
mock_invitation = MagicMock(spec=OrgInvitation)
mock_invitation.token = 'inv-test-token-12345'
mock_session = AsyncMock()
mock_result = MagicMock()
mock_result.scalars.return_value.first.return_value = mock_invitation
mock_session.execute = AsyncMock(return_value=mock_result)
with patch(
'storage.org_invitation_store.a_session_maker'
) as mock_session_maker:
mock_session_manager = AsyncMock()
mock_session_manager.__aenter__.return_value = mock_session
mock_session_manager.__aexit__.return_value = None
mock_session_maker.return_value = mock_session_manager
result = await OrgInvitationStore.get_invitation_by_token(
'inv-test-token-12345'
)
assert result == mock_invitation
@pytest.mark.asyncio
async def test_get_invitation_by_token_returns_none_when_not_found(self):
"""Test that get_invitation_by_token returns None when not found."""
mock_session = AsyncMock()
mock_result = MagicMock()
mock_result.scalars.return_value.first.return_value = None
mock_session.execute = AsyncMock(return_value=mock_result)
with patch(
'storage.org_invitation_store.a_session_maker'
) as mock_session_maker:
mock_session_manager = AsyncMock()
mock_session_manager.__aenter__.return_value = mock_session
mock_session_manager.__aexit__.return_value = None
mock_session_maker.return_value = mock_session_manager
result = await OrgInvitationStore.get_invitation_by_token(
'inv-nonexistent-token'
)
assert result is None
class TestGetPendingInvitation:
"""Test cases for getting pending invitation."""
@pytest.mark.asyncio
async def test_get_pending_invitation_normalizes_email(self):
"""Test that email is normalized when querying for pending invitations."""
mock_session = AsyncMock()
mock_result = MagicMock()
mock_result.scalars.return_value.first.return_value = None
mock_session.execute = AsyncMock(return_value=mock_result)
with patch(
'storage.org_invitation_store.a_session_maker'
) as mock_session_maker:
mock_session_manager = AsyncMock()
mock_session_manager.__aenter__.return_value = mock_session
mock_session_manager.__aexit__.return_value = None
mock_session_maker.return_value = mock_session_manager
from uuid import UUID
await OrgInvitationStore.get_pending_invitation(
org_id=UUID('12345678-1234-5678-1234-567812345678'),
email=' TEST@EXAMPLE.COM ',
)
# Verify the query was called (email normalization happens in the filter)
assert mock_session.execute.called
class TestUpdateInvitationStatus:
"""Test cases for updating invitation status."""
@pytest.mark.asyncio
async def test_update_status_sets_accepted_at_for_accepted(self):
"""Test that accepted_at is set when status is accepted."""
from uuid import UUID
mock_invitation = MagicMock(spec=OrgInvitation)
mock_invitation.id = 1
mock_invitation.status = OrgInvitation.STATUS_PENDING
mock_session = AsyncMock()
mock_result = MagicMock()
mock_result.scalars.return_value.first.return_value = mock_invitation
mock_session.execute = AsyncMock(return_value=mock_result)
mock_session.commit = AsyncMock()
mock_session.refresh = AsyncMock()
with patch(
'storage.org_invitation_store.a_session_maker'
) as mock_session_maker:
mock_session_manager = AsyncMock()
mock_session_manager.__aenter__.return_value = mock_session
mock_session_manager.__aexit__.return_value = None
mock_session_maker.return_value = mock_session_manager
user_id = UUID('87654321-4321-8765-4321-876543218765')
await OrgInvitationStore.update_invitation_status(
invitation_id=1,
status=OrgInvitation.STATUS_ACCEPTED,
accepted_by_user_id=user_id,
)
assert mock_invitation.accepted_at is not None
assert mock_invitation.accepted_by_user_id == user_id
@pytest.mark.asyncio
async def test_update_status_returns_none_when_not_found(self):
"""Test that update returns None when invitation not found."""
mock_session = AsyncMock()
mock_result = MagicMock()
mock_result.scalars.return_value.first.return_value = None
mock_session.execute = AsyncMock(return_value=mock_result)
with patch(
'storage.org_invitation_store.a_session_maker'
) as mock_session_maker:
mock_session_manager = AsyncMock()
mock_session_manager.__aenter__.return_value = mock_session
mock_session_manager.__aexit__.return_value = None
mock_session_maker.return_value = mock_session_manager
result = await OrgInvitationStore.update_invitation_status(
invitation_id=999,
status=OrgInvitation.STATUS_ACCEPTED,
)
assert result is None
class TestMarkExpiredIfNeeded:
"""Test cases for marking expired invitations."""
@pytest.mark.asyncio
async def test_marks_expired_when_pending_and_past_expiry(self):
"""Test that pending expired invitations are marked as expired."""
mock_invitation = MagicMock(spec=OrgInvitation)
mock_invitation.id = 1
mock_invitation.status = OrgInvitation.STATUS_PENDING
mock_invitation.expires_at = datetime.utcnow() - timedelta(days=1)
with patch.object(
OrgInvitationStore,
'update_invitation_status',
new_callable=AsyncMock,
) as mock_update:
result = await OrgInvitationStore.mark_expired_if_needed(mock_invitation)
assert result is True
mock_update.assert_called_once_with(1, OrgInvitation.STATUS_EXPIRED)
@pytest.mark.asyncio
async def test_does_not_mark_when_not_expired(self):
"""Test that non-expired invitations are not marked."""
mock_invitation = MagicMock(spec=OrgInvitation)
mock_invitation.id = 1
mock_invitation.status = OrgInvitation.STATUS_PENDING
mock_invitation.expires_at = datetime.utcnow() + timedelta(days=1)
with patch.object(
OrgInvitationStore,
'update_invitation_status',
new_callable=AsyncMock,
) as mock_update:
result = await OrgInvitationStore.mark_expired_if_needed(mock_invitation)
assert result is False
mock_update.assert_not_called()
@pytest.mark.asyncio
async def test_does_not_mark_when_not_pending(self):
"""Test that non-pending invitations are not marked even if expired."""
mock_invitation = MagicMock(spec=OrgInvitation)
mock_invitation.id = 1
mock_invitation.status = OrgInvitation.STATUS_ACCEPTED
mock_invitation.expires_at = datetime.utcnow() - timedelta(days=1)
with patch.object(
OrgInvitationStore,
'update_invitation_status',
new_callable=AsyncMock,
) as mock_update:
result = await OrgInvitationStore.mark_expired_if_needed(mock_invitation)
assert result is False
mock_update.assert_not_called()

View File

@@ -0,0 +1,388 @@
"""Tests for organization invitations API router."""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from server.routes.org_invitation_models import (
EmailMismatchError,
InvitationExpiredError,
InvitationInvalidError,
UserAlreadyMemberError,
)
from server.routes.org_invitations import accept_router, invitation_router
@pytest.fixture
def app():
"""Create a FastAPI app with the invitation routers."""
app = FastAPI()
app.include_router(invitation_router)
app.include_router(accept_router)
return app
@pytest.fixture
def client(app):
"""Create a test client for the app."""
return TestClient(app)
class TestRouterPrefixes:
"""Test that router prefixes are configured correctly."""
def test_invitation_router_has_correct_prefix(self):
"""Test that invitation_router has /api/organizations/{org_id}/members prefix."""
assert invitation_router.prefix == '/api/organizations/{org_id}/members'
def test_accept_router_has_correct_prefix(self):
"""Test that accept_router has /api/organizations/members/invite prefix."""
assert accept_router.prefix == '/api/organizations/members/invite'
class TestAcceptInvitationEndpoint:
"""Test cases for the accept invitation endpoint."""
@pytest.fixture
def mock_user_auth(self):
"""Create a mock user auth."""
user_auth = MagicMock()
user_auth.get_user_id = AsyncMock(
return_value='87654321-4321-8765-4321-876543218765'
)
return user_auth
@pytest.mark.asyncio
async def test_accept_unauthenticated_redirects_to_login(self, client):
"""Test that unauthenticated users are redirected to login with invitation token."""
with patch(
'server.routes.org_invitations.get_user_auth',
new_callable=AsyncMock,
return_value=None,
):
response = client.get(
'/api/organizations/members/invite/accept?token=inv-test-token-123',
follow_redirects=False,
)
assert response.status_code == 302
assert '/login?invitation_token=inv-test-token-123' in response.headers.get(
'location', ''
)
@pytest.mark.asyncio
async def test_accept_authenticated_success_redirects_home(
self, client, mock_user_auth
):
"""Test that successful acceptance redirects to home page."""
mock_invitation = MagicMock()
with (
patch(
'server.routes.org_invitations.get_user_auth',
new_callable=AsyncMock,
return_value=mock_user_auth,
),
patch(
'server.routes.org_invitations.OrgInvitationService.accept_invitation',
new_callable=AsyncMock,
return_value=mock_invitation,
),
):
response = client.get(
'/api/organizations/members/invite/accept?token=inv-test-token-123',
follow_redirects=False,
)
assert response.status_code == 302
location = response.headers.get('location', '')
assert location.endswith('/')
assert 'invitation_expired' not in location
assert 'invitation_invalid' not in location
assert 'email_mismatch' not in location
@pytest.mark.asyncio
async def test_accept_expired_invitation_redirects_with_flag(
self, client, mock_user_auth
):
"""Test that expired invitation redirects with invitation_expired=true."""
with (
patch(
'server.routes.org_invitations.get_user_auth',
new_callable=AsyncMock,
return_value=mock_user_auth,
),
patch(
'server.routes.org_invitations.OrgInvitationService.accept_invitation',
new_callable=AsyncMock,
side_effect=InvitationExpiredError(),
),
):
response = client.get(
'/api/organizations/members/invite/accept?token=inv-test-token-123',
follow_redirects=False,
)
assert response.status_code == 302
assert 'invitation_expired=true' in response.headers.get('location', '')
@pytest.mark.asyncio
async def test_accept_invalid_invitation_redirects_with_flag(
self, client, mock_user_auth
):
"""Test that invalid invitation redirects with invitation_invalid=true."""
with (
patch(
'server.routes.org_invitations.get_user_auth',
new_callable=AsyncMock,
return_value=mock_user_auth,
),
patch(
'server.routes.org_invitations.OrgInvitationService.accept_invitation',
new_callable=AsyncMock,
side_effect=InvitationInvalidError(),
),
):
response = client.get(
'/api/organizations/members/invite/accept?token=inv-test-token-123',
follow_redirects=False,
)
assert response.status_code == 302
assert 'invitation_invalid=true' in response.headers.get('location', '')
@pytest.mark.asyncio
async def test_accept_already_member_redirects_with_flag(
self, client, mock_user_auth
):
"""Test that already member error redirects with already_member=true."""
with (
patch(
'server.routes.org_invitations.get_user_auth',
new_callable=AsyncMock,
return_value=mock_user_auth,
),
patch(
'server.routes.org_invitations.OrgInvitationService.accept_invitation',
new_callable=AsyncMock,
side_effect=UserAlreadyMemberError(),
),
):
response = client.get(
'/api/organizations/members/invite/accept?token=inv-test-token-123',
follow_redirects=False,
)
assert response.status_code == 302
assert 'already_member=true' in response.headers.get('location', '')
@pytest.mark.asyncio
async def test_accept_email_mismatch_redirects_with_flag(
self, client, mock_user_auth
):
"""Test that email mismatch error redirects with email_mismatch=true."""
with (
patch(
'server.routes.org_invitations.get_user_auth',
new_callable=AsyncMock,
return_value=mock_user_auth,
),
patch(
'server.routes.org_invitations.OrgInvitationService.accept_invitation',
new_callable=AsyncMock,
side_effect=EmailMismatchError(),
),
):
response = client.get(
'/api/organizations/members/invite/accept?token=inv-test-token-123',
follow_redirects=False,
)
assert response.status_code == 302
assert 'email_mismatch=true' in response.headers.get('location', '')
@pytest.mark.asyncio
async def test_accept_unexpected_error_redirects_with_flag(
self, client, mock_user_auth
):
"""Test that unexpected errors redirect with invitation_error=true."""
with (
patch(
'server.routes.org_invitations.get_user_auth',
new_callable=AsyncMock,
return_value=mock_user_auth,
),
patch(
'server.routes.org_invitations.OrgInvitationService.accept_invitation',
new_callable=AsyncMock,
side_effect=Exception('Unexpected error'),
),
):
response = client.get(
'/api/organizations/members/invite/accept?token=inv-test-token-123',
follow_redirects=False,
)
assert response.status_code == 302
assert 'invitation_error=true' in response.headers.get('location', '')
class TestCreateInvitationBatchEndpoint:
"""Test cases for the batch invitation creation endpoint."""
@pytest.fixture
def batch_app(self):
"""Create a FastAPI app with dependency overrides for batch tests."""
from openhands.server.user_auth import get_user_id
app = FastAPI()
app.include_router(invitation_router)
# Override the get_user_id dependency
app.dependency_overrides[get_user_id] = (
lambda: '87654321-4321-8765-4321-876543218765'
)
return app
@pytest.fixture
def batch_client(self, batch_app):
"""Create a test client with dependency overrides."""
return TestClient(batch_app)
@pytest.fixture
def mock_invitation(self):
"""Create a mock invitation."""
from datetime import datetime
invitation = MagicMock()
invitation.id = 1
invitation.email = 'alice@example.com'
invitation.role = MagicMock(name='member')
invitation.role.name = 'member'
invitation.role_id = 3
invitation.status = 'pending'
invitation.created_at = datetime(2026, 2, 17, 10, 0, 0)
invitation.expires_at = datetime(2026, 2, 24, 10, 0, 0)
return invitation
@pytest.mark.asyncio
async def test_batch_create_returns_successful_invitations(
self, batch_client, mock_invitation
):
"""Test that batch creation returns successful invitations."""
mock_invitation_2 = MagicMock()
mock_invitation_2.id = 2
mock_invitation_2.email = 'bob@example.com'
mock_invitation_2.role = MagicMock()
mock_invitation_2.role.name = 'member'
mock_invitation_2.role_id = 3
mock_invitation_2.status = 'pending'
mock_invitation_2.created_at = mock_invitation.created_at
mock_invitation_2.expires_at = mock_invitation.expires_at
with (
patch(
'server.routes.org_invitations.check_rate_limit_by_user_id',
new_callable=AsyncMock,
),
patch(
'server.routes.org_invitations.OrgInvitationService.create_invitations_batch',
new_callable=AsyncMock,
return_value=([mock_invitation, mock_invitation_2], []),
),
):
response = batch_client.post(
'/api/organizations/12345678-1234-5678-1234-567812345678/members/invite',
json={
'emails': ['alice@example.com', 'bob@example.com'],
'role': 'member',
},
)
assert response.status_code == 201
data = response.json()
assert len(data['successful']) == 2
assert len(data['failed']) == 0
@pytest.mark.asyncio
async def test_batch_create_returns_partial_success(
self, batch_client, mock_invitation
):
"""Test that batch creation returns both successful and failed invitations."""
failed_emails = [('existing@example.com', 'User is already a member')]
with (
patch(
'server.routes.org_invitations.check_rate_limit_by_user_id',
new_callable=AsyncMock,
),
patch(
'server.routes.org_invitations.OrgInvitationService.create_invitations_batch',
new_callable=AsyncMock,
return_value=([mock_invitation], failed_emails),
),
):
response = batch_client.post(
'/api/organizations/12345678-1234-5678-1234-567812345678/members/invite',
json={
'emails': ['alice@example.com', 'existing@example.com'],
'role': 'member',
},
)
assert response.status_code == 201
data = response.json()
assert len(data['successful']) == 1
assert len(data['failed']) == 1
assert data['failed'][0]['email'] == 'existing@example.com'
assert 'already a member' in data['failed'][0]['error']
@pytest.mark.asyncio
async def test_batch_create_permission_denied_returns_403(self, batch_client):
"""Test that permission denied returns 403 for entire batch."""
from server.routes.org_invitation_models import InsufficientPermissionError
with (
patch(
'server.routes.org_invitations.check_rate_limit_by_user_id',
new_callable=AsyncMock,
),
patch(
'server.routes.org_invitations.OrgInvitationService.create_invitations_batch',
new_callable=AsyncMock,
side_effect=InsufficientPermissionError(
'Only owners and admins can invite'
),
),
):
response = batch_client.post(
'/api/organizations/12345678-1234-5678-1234-567812345678/members/invite',
json={'emails': ['alice@example.com'], 'role': 'member'},
)
assert response.status_code == 403
assert 'owners and admins' in response.json()['detail']
@pytest.mark.asyncio
async def test_batch_create_invalid_role_returns_400(self, batch_client):
"""Test that invalid role returns 400."""
with (
patch(
'server.routes.org_invitations.check_rate_limit_by_user_id',
new_callable=AsyncMock,
),
patch(
'server.routes.org_invitations.OrgInvitationService.create_invitations_batch',
new_callable=AsyncMock,
side_effect=ValueError('Invalid role: superuser'),
),
):
response = batch_client.post(
'/api/organizations/12345678-1234-5678-1234-567812345678/members/invite',
json={'emails': ['alice@example.com'], 'role': 'superuser'},
)
assert response.status_code == 400
assert 'Invalid role' in response.json()['detail']

View File

@@ -158,6 +158,57 @@ def test_get_org_member(session_maker):
assert retrieved_org_member.llm_api_key.get_secret_value() == 'test-key'
def test_get_org_member_for_current_org(session_maker):
# Test getting org_member for user's current organization
with session_maker() as session:
# Create test data - user belongs to two orgs but current_org is org1
org1 = Org(name='test-org-1')
org2 = Org(name='test-org-2')
session.add_all([org1, org2])
session.flush()
user = User(id=uuid.uuid4(), current_org_id=org1.id)
role = Role(name='admin', rank=1)
session.add_all([user, role])
session.flush()
org_member1 = OrgMember(
org_id=org1.id,
user_id=user.id,
role_id=role.id,
llm_api_key='test-key-1',
status='active',
)
org_member2 = OrgMember(
org_id=org2.id,
user_id=user.id,
role_id=role.id,
llm_api_key='test-key-2',
status='active',
)
session.add_all([org_member1, org_member2])
session.commit()
user_id = user.id
org1_id = org1.id
# Test retrieval - should return org_member for current_org (org1)
with patch('storage.org_member_store.session_maker', session_maker):
retrieved_org_member = OrgMemberStore.get_org_member_for_current_org(user_id)
assert retrieved_org_member is not None
assert retrieved_org_member.org_id == org1_id
assert retrieved_org_member.user_id == user_id
assert retrieved_org_member.llm_api_key.get_secret_value() == 'test-key-1'
def test_get_org_member_for_current_org_user_not_found(session_maker):
# Test getting org_member for non-existent user
with patch('storage.org_member_store.session_maker', session_maker):
retrieved_org_member = OrgMemberStore.get_org_member_for_current_org(
uuid.uuid4()
)
assert retrieved_org_member is None
def test_add_user_to_org(session_maker):
# Test adding a user to an org
with session_maker() as session:
@@ -604,3 +655,180 @@ async def test_get_org_members_paginated_eager_loading(async_session_maker):
assert member.role is not None
assert member.role.name == 'owner'
assert member.role.rank == 10
@pytest.mark.asyncio
async def test_get_org_members_count_no_filter(async_session_maker):
"""Test get_org_members_count returns correct count without email filter."""
# Arrange
async with async_session_maker() as session:
org = Org(name='test-org')
session.add(org)
await session.flush()
role = Role(name='admin', rank=1)
session.add(role)
await session.flush()
users = [
User(id=uuid.uuid4(), current_org_id=org.id, email=f'user{i}@example.com')
for i in range(5)
]
session.add_all(users)
await session.flush()
org_members = [
OrgMember(
org_id=org.id,
user_id=user.id,
role_id=role.id,
llm_api_key=f'test-key-{i}',
status='active',
)
for i, user in enumerate(users)
]
session.add_all(org_members)
await session.commit()
org_id = org.id
# Act
with patch('storage.org_member_store.a_session_maker', async_session_maker):
count = await OrgMemberStore.get_org_members_count(org_id=org_id)
# Assert
assert count == 5
@pytest.mark.asyncio
async def test_get_org_members_count_with_email_filter(async_session_maker):
"""Test get_org_members_count filters by email correctly."""
# Arrange
async with async_session_maker() as session:
org = Org(name='test-org')
session.add(org)
await session.flush()
role = Role(name='admin', rank=1)
session.add(role)
await session.flush()
users = [
User(id=uuid.uuid4(), current_org_id=org.id, email='alice@example.com'),
User(id=uuid.uuid4(), current_org_id=org.id, email='bob@example.com'),
User(
id=uuid.uuid4(), current_org_id=org.id, email='alice.smith@example.com'
),
]
session.add_all(users)
await session.flush()
org_members = [
OrgMember(
org_id=org.id,
user_id=user.id,
role_id=role.id,
llm_api_key=f'test-key-{i}',
status='active',
)
for i, user in enumerate(users)
]
session.add_all(org_members)
await session.commit()
org_id = org.id
# Act
with patch('storage.org_member_store.a_session_maker', async_session_maker):
count = await OrgMemberStore.get_org_members_count(
org_id=org_id, email_filter='alice'
)
# Assert
assert count == 2
@pytest.mark.asyncio
async def test_get_org_members_paginated_with_email_filter(async_session_maker):
"""Test get_org_members_paginated filters by email correctly."""
# Arrange
async with async_session_maker() as session:
org = Org(name='test-org')
session.add(org)
await session.flush()
role = Role(name='admin', rank=1)
session.add(role)
await session.flush()
users = [
User(id=uuid.uuid4(), current_org_id=org.id, email='alice@example.com'),
User(id=uuid.uuid4(), current_org_id=org.id, email='bob@example.com'),
User(id=uuid.uuid4(), current_org_id=org.id, email='charlie@example.com'),
]
session.add_all(users)
await session.flush()
org_members = [
OrgMember(
org_id=org.id,
user_id=user.id,
role_id=role.id,
llm_api_key=f'test-key-{i}',
status='active',
)
for i, user in enumerate(users)
]
session.add_all(org_members)
await session.commit()
org_id = org.id
# Act
with patch('storage.org_member_store.a_session_maker', async_session_maker):
members, has_more = await OrgMemberStore.get_org_members_paginated(
org_id=org_id, offset=0, limit=10, email_filter='bob'
)
# Assert
assert len(members) == 1
assert members[0].user.email == 'bob@example.com'
assert has_more is False
@pytest.mark.asyncio
async def test_get_org_members_paginated_email_filter_case_insensitive(
async_session_maker,
):
"""Test email filter is case-insensitive."""
# Arrange
async with async_session_maker() as session:
org = Org(name='test-org')
session.add(org)
await session.flush()
role = Role(name='admin', rank=1)
session.add(role)
await session.flush()
user = User(id=uuid.uuid4(), current_org_id=org.id, email='Alice@Example.COM')
session.add(user)
await session.flush()
org_member = OrgMember(
org_id=org.id,
user_id=user.id,
role_id=role.id,
llm_api_key='test-key',
status='active',
)
session.add(org_member)
await session.commit()
org_id = org.id
# Act
with patch('storage.org_member_store.a_session_maker', async_session_maker):
members, has_more = await OrgMemberStore.get_org_members_paginated(
org_id=org_id, offset=0, limit=10, email_filter='alice@example'
)
# Assert
assert len(members) == 1
assert members[0].user.email == 'Alice@Example.COM'

View File

@@ -482,7 +482,7 @@ async def test_get_org_credits_success(mock_litellm_api):
spend = 25.0
mock_team_info = {
'litellm_budget_table': {'max_budget': max_budget},
'max_budget_in_team': max_budget,
'spend': spend,
}

View File

@@ -10,6 +10,7 @@ with patch('storage.database.engine', create=True), patch(
'storage.database.a_engine', create=True
):
from storage.org import Org
from storage.org_invitation import OrgInvitation
from storage.org_member import OrgMember
from storage.org_store import OrgStore
from storage.role import Role
@@ -806,3 +807,88 @@ def test_orphaned_user_error_contains_user_ids():
assert error.user_ids == user_ids
assert '2 user(s)' in str(error)
assert 'no remaining organization' in str(error)
def test_org_deletion_with_invitations_uses_passive_deletes(
session_maker, mock_litellm_api
):
"""
GIVEN: Organization has associated invitations with non-nullable org_id foreign key
WHEN: Organization is deleted via SQLAlchemy session.delete()
THEN: Deletion succeeds without NOT NULL constraint violation
(passive_deletes=True defers to database CASCADE instead of setting org_id to NULL)
This test verifies the fix for the bug where SQLAlchemy would try to
SET org_id=NULL on org_invitation before deleting the org, causing:
"NOT NULL constraint failed: org_invitation.org_id"
With passive_deletes=True on the relationship, SQLAlchemy defers to the
database's CASCADE constraint instead of trying to nullify the foreign key.
Note: SQLite doesn't enforce CASCADE by default, so we only verify that
the deletion succeeds. In production (PostgreSQL), CASCADE handles cleanup.
"""
from datetime import datetime, timedelta
# Arrange
org_id = uuid.uuid4()
other_org_id = uuid.uuid4()
user_id = uuid.uuid4()
with session_maker() as session:
# Create role first (required for invitation)
role = Role(id=1, name='owner', rank=1)
session.add(role)
session.flush()
# Create organization to be deleted
org = Org(id=org_id, name='test-org-with-invitations')
session.add(org)
session.flush()
# Create a second org for the user's current_org_id
# (to avoid the user.current_org_id constraint issue during deletion)
other_org = Org(id=other_org_id, name='other-org')
session.add(other_org)
session.flush()
# Create user with current_org pointing to the OTHER org (not the one being deleted)
user = User(id=user_id, current_org_id=other_org_id)
session.add(user)
session.flush()
# Create invitation associated with the organization to be deleted
invitation = OrgInvitation(
token='test-invitation-token-12345',
org_id=org_id,
email='invitee@example.com',
role_id=1,
inviter_id=user_id,
status='pending',
created_at=datetime.now(),
expires_at=datetime.now() + timedelta(days=7),
)
session.add(invitation)
session.commit()
# Verify invitation was created
invitation_count = session.query(OrgInvitation).filter_by(org_id=org_id).count()
assert invitation_count == 1
# Act - Delete organization via SQLAlchemy (this is what triggered the bug)
# Without passive_deletes=True, SQLAlchemy would try to SET org_id=NULL
# which violates the NOT NULL constraint on org_invitation.org_id
with session_maker() as session:
org = session.query(Org).filter(Org.id == org_id).first()
assert org is not None
# This should NOT raise IntegrityError with passive_deletes=True
# Previously this would fail with:
# "NOT NULL constraint failed: org_invitation.org_id"
session.delete(org)
session.commit() # Success indicates passive_deletes=True is working
# Assert - Organization should be deleted
with session_maker() as session:
deleted_org = session.query(Org).filter(Org.id == org_id).first()
assert deleted_org is None

View File

@@ -398,6 +398,121 @@ async def test_create_user_contact_name_falls_back_to_username():
assert org.contact_name == 'jdoe'
# --- Tests for email fields in create_user() ---
# create_user() should populate user.email and user.email_verified from the
# Keycloak user_info, ensuring the user table has the correct email data.
class _StopAfterUserCreation(Exception):
"""Halt create_user() after User creation for email field inspection."""
pass
@pytest.mark.asyncio
async def test_create_user_sets_email_from_user_info():
"""create_user() should set user.email and user.email_verified from user_info."""
# Arrange
user_id = str(uuid.uuid4())
user_info = {
'preferred_username': 'testuser',
'email': 'testuser@example.com',
'email_verified': True,
}
mock_session = MagicMock()
mock_sm = MagicMock()
mock_sm.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_sm.return_value.__exit__ = MagicMock(return_value=False)
mock_settings = Settings(language='en')
mock_role = MagicMock()
mock_role.id = 1
with (
patch('storage.user_store.session_maker', mock_sm),
patch.object(
UserStore,
'create_default_settings',
new_callable=AsyncMock,
return_value=mock_settings,
),
patch('storage.org_store.OrgStore.get_kwargs_from_settings', return_value={}),
patch.object(UserStore, 'get_kwargs_from_settings', return_value={}),
patch('storage.user_store.RoleStore.get_role_by_name', return_value=mock_role),
patch(
'storage.org_member_store.OrgMemberStore.get_kwargs_from_settings',
return_value={'llm_model': None, 'llm_base_url': None},
),
patch.object(
mock_session,
'commit',
side_effect=_StopAfterUserCreation,
),
):
# Act
with pytest.raises(_StopAfterUserCreation):
await UserStore.create_user(user_id, user_info)
# Assert - User is the second object added to session (after Org)
user = mock_session.add.call_args_list[1][0][0]
assert isinstance(user, User)
assert user.email == 'testuser@example.com'
assert user.email_verified is True
@pytest.mark.asyncio
async def test_create_user_handles_missing_email_verified():
"""create_user() should handle missing email_verified in user_info gracefully."""
# Arrange
user_id = str(uuid.uuid4())
user_info = {
'preferred_username': 'testuser',
'email': 'testuser@example.com',
# email_verified is not present
}
mock_session = MagicMock()
mock_sm = MagicMock()
mock_sm.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_sm.return_value.__exit__ = MagicMock(return_value=False)
mock_settings = Settings(language='en')
mock_role = MagicMock()
mock_role.id = 1
with (
patch('storage.user_store.session_maker', mock_sm),
patch.object(
UserStore,
'create_default_settings',
new_callable=AsyncMock,
return_value=mock_settings,
),
patch('storage.org_store.OrgStore.get_kwargs_from_settings', return_value={}),
patch.object(UserStore, 'get_kwargs_from_settings', return_value={}),
patch('storage.user_store.RoleStore.get_role_by_name', return_value=mock_role),
patch(
'storage.org_member_store.OrgMemberStore.get_kwargs_from_settings',
return_value={'llm_model': None, 'llm_base_url': None},
),
patch.object(
mock_session,
'commit',
side_effect=_StopAfterUserCreation,
),
):
# Act
with pytest.raises(_StopAfterUserCreation):
await UserStore.create_user(user_id, user_info)
# Assert - User should have email but email_verified should be None
user = mock_session.add.call_args_list[1][0][0]
assert isinstance(user, User)
assert user.email == 'testuser@example.com'
assert user.email_verified is None
# --- Tests for backfill_contact_name on login ---
# Existing users created before the resolve_display_name fix may have
# username-style values in contact_name. The backfill updates these to

View File

@@ -1,9 +1,27 @@
import { render, screen } from "@testing-library/react";
import userEvent from "@testing-library/user-event";
import { afterEach, describe, expect, it, test, vi } from "vitest";
import { afterEach, beforeEach, describe, expect, it, test, vi } from "vitest";
import { AccountSettingsContextMenu } from "#/components/features/context-menu/account-settings-context-menu";
import { MemoryRouter } from "react-router";
import { renderWithProviders } from "../../../test-utils";
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
import { createMockWebClientConfig } from "../../helpers/mock-config";
const mockTrackAddTeamMembersButtonClick = vi.fn();
vi.mock("#/hooks/use-tracking", () => ({
useTracking: () => ({
trackAddTeamMembersButtonClick: mockTrackAddTeamMembersButtonClick,
}),
}));
// Mock posthog feature flag
vi.mock("posthog-js/react", () => ({
useFeatureFlagEnabled: vi.fn(),
}));
// Import the mocked module to get access to the mock
import * as posthog from "posthog-js/react";
describe("AccountSettingsContextMenu", () => {
const user = userEvent.setup();
@@ -11,15 +29,45 @@ describe("AccountSettingsContextMenu", () => {
const onLogoutMock = vi.fn();
const onCloseMock = vi.fn();
let queryClient: QueryClient;
beforeEach(() => {
queryClient = new QueryClient({
defaultOptions: { queries: { retry: false } },
});
// Set default feature flag to false
vi.mocked(posthog.useFeatureFlagEnabled).mockReturnValue(false);
});
// Create a wrapper with MemoryRouter and renderWithProviders
const renderWithRouter = (ui: React.ReactElement) => {
return renderWithProviders(<MemoryRouter>{ui}</MemoryRouter>);
};
const renderWithSaasConfig = (ui: React.ReactElement) => {
queryClient.setQueryData(["web-client-config"], createMockWebClientConfig({ app_mode: "saas" }));
return render(
<QueryClientProvider client={queryClient}>
<MemoryRouter>{ui}</MemoryRouter>
</QueryClientProvider>
);
};
const renderWithOssConfig = (ui: React.ReactElement) => {
queryClient.setQueryData(["web-client-config"], createMockWebClientConfig({ app_mode: "oss" }));
return render(
<QueryClientProvider client={queryClient}>
<MemoryRouter>{ui}</MemoryRouter>
</QueryClientProvider>
);
};
afterEach(() => {
onClickAccountSettingsMock.mockClear();
onLogoutMock.mockClear();
onCloseMock.mockClear();
mockTrackAddTeamMembersButtonClick.mockClear();
vi.mocked(posthog.useFeatureFlagEnabled).mockClear();
});
it("should always render the right options", () => {
@@ -93,4 +141,59 @@ describe("AccountSettingsContextMenu", () => {
expect(onCloseMock).toHaveBeenCalledOnce();
});
it("should show Add Team Members button in SaaS mode when feature flag is enabled", () => {
vi.mocked(posthog.useFeatureFlagEnabled).mockReturnValue(true);
renderWithSaasConfig(
<AccountSettingsContextMenu
onLogout={onLogoutMock}
onClose={onCloseMock}
/>,
);
expect(screen.getByTestId("add-team-members-button")).toBeInTheDocument();
expect(screen.getByText("SETTINGS$NAV_ADD_TEAM_MEMBERS")).toBeInTheDocument();
});
it("should not show Add Team Members button in SaaS mode when feature flag is disabled", () => {
vi.mocked(posthog.useFeatureFlagEnabled).mockReturnValue(false);
renderWithSaasConfig(
<AccountSettingsContextMenu
onLogout={onLogoutMock}
onClose={onCloseMock}
/>,
);
expect(screen.queryByTestId("add-team-members-button")).not.toBeInTheDocument();
expect(screen.queryByText("SETTINGS$NAV_ADD_TEAM_MEMBERS")).not.toBeInTheDocument();
});
it("should not show Add Team Members button in OSS mode even when feature flag is enabled", () => {
vi.mocked(posthog.useFeatureFlagEnabled).mockReturnValue(true);
renderWithOssConfig(
<AccountSettingsContextMenu
onLogout={onLogoutMock}
onClose={onCloseMock}
/>,
);
expect(screen.queryByTestId("add-team-members-button")).not.toBeInTheDocument();
expect(screen.queryByText("SETTINGS$NAV_ADD_TEAM_MEMBERS")).not.toBeInTheDocument();
});
it("should call tracking function and onClose when Add Team Members button is clicked", async () => {
vi.mocked(posthog.useFeatureFlagEnabled).mockReturnValue(true);
renderWithSaasConfig(
<AccountSettingsContextMenu
onLogout={onLogoutMock}
onClose={onCloseMock}
/>,
);
const addTeamMembersButton = screen.getByTestId("add-team-members-button");
await user.click(addTeamMembersButton);
expect(mockTrackAddTeamMembersButtonClick).toHaveBeenCalledOnce();
expect(onCloseMock).toHaveBeenCalledOnce();
});
});

View File

@@ -151,8 +151,9 @@ describe("LoginContent", () => {
await user.click(githubButton);
// Wait for async handleAuthRedirect to complete
// The URL includes state parameter added by handleAuthRedirect
await waitFor(() => {
expect(window.location.href).toBe(mockUrl);
expect(window.location.href).toContain(mockUrl);
});
});
@@ -201,4 +202,103 @@ describe("LoginContent", () => {
expect(screen.getByTestId("terms-and-privacy-notice")).toBeInTheDocument();
});
it("should display invitation pending message when hasInvitation is true", () => {
render(
<MemoryRouter>
<LoginContent
githubAuthUrl="https://github.com/oauth/authorize"
appMode="saas"
providersConfigured={["github"]}
hasInvitation
/>
</MemoryRouter>,
);
expect(screen.getByText("AUTH$INVITATION_PENDING")).toBeInTheDocument();
});
it("should not display invitation pending message when hasInvitation is false", () => {
render(
<MemoryRouter>
<LoginContent
githubAuthUrl="https://github.com/oauth/authorize"
appMode="saas"
providersConfigured={["github"]}
hasInvitation={false}
/>
</MemoryRouter>,
);
expect(
screen.queryByText("AUTH$INVITATION_PENDING"),
).not.toBeInTheDocument();
});
it("should call buildOAuthStateData when clicking auth button", async () => {
const user = userEvent.setup();
const mockBuildOAuthStateData = vi.fn((baseState) => ({
...baseState,
invitation_token: "inv-test-token-12345",
}));
render(
<MemoryRouter>
<LoginContent
githubAuthUrl="https://github.com/login/oauth/authorize"
appMode="saas"
providersConfigured={["github"]}
buildOAuthStateData={mockBuildOAuthStateData}
/>
</MemoryRouter>,
);
const githubButton = screen.getByRole("button", {
name: "GITHUB$CONNECT_TO_GITHUB",
});
await user.click(githubButton);
await waitFor(() => {
expect(mockBuildOAuthStateData).toHaveBeenCalled();
const callArg = mockBuildOAuthStateData.mock.calls[0][0];
expect(callArg).toHaveProperty("redirect_url");
});
});
it("should encode state with invitation token when buildOAuthStateData provides token", async () => {
const user = userEvent.setup();
const mockBuildOAuthStateData = vi.fn((baseState) => ({
...baseState,
invitation_token: "inv-test-token-12345",
}));
render(
<MemoryRouter>
<LoginContent
githubAuthUrl="https://github.com/login/oauth/authorize"
appMode="saas"
providersConfigured={["github"]}
buildOAuthStateData={mockBuildOAuthStateData}
/>
</MemoryRouter>,
);
const githubButton = screen.getByRole("button", {
name: "GITHUB$CONNECT_TO_GITHUB",
});
await user.click(githubButton);
await waitFor(() => {
const redirectUrl = window.location.href;
// The URL should contain an encoded state parameter
expect(redirectUrl).toContain("state=");
// Decode and verify the state contains invitation_token
const url = new URL(redirectUrl);
const state = url.searchParams.get("state");
if (state) {
const decodedState = JSON.parse(atob(state));
expect(decodedState.invitation_token).toBe("inv-test-token-12345");
}
});
});
});

View File

@@ -358,6 +358,30 @@ describe("Conversation WebSocket Handler", () => {
});
});
it("should show friendly i18n message for budget ConversationErrorEvent", async () => {
const mockBudgetConversationError = createMockConversationErrorEvent({
detail:
"Budget has been exceeded! Current cost: 18.51, Max budget: 18.24",
});
mswServer.use(
wsLink.addEventListener("connection", ({ client, server }) => {
server.connect();
client.send(JSON.stringify(mockBudgetConversationError));
}),
);
renderWithWebSocketContext(<ErrorMessageStoreComponent />);
expect(screen.getByTestId("error-message")).toHaveTextContent("none");
await waitFor(() => {
expect(screen.getByTestId("error-message")).toHaveTextContent(
"STATUS$ERROR_LLM_OUT_OF_CREDITS",
);
});
});
it("should set error message store on WebSocket connection errors", async () => {
// Simulate a connect-then-fail sequence (the MSW server auto-connects by default).
// This should surface an error message because the app has previously connected.

View File

@@ -0,0 +1,104 @@
import { renderHook, waitFor } from "@testing-library/react";
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
import { describe, expect, it, vi } from "vitest";
import V1ConversationService from "#/api/conversation-service/v1-conversation-service.api";
import { useCreateConversation } from "#/hooks/mutation/use-create-conversation";
import { SuggestedTask } from "#/utils/types";
vi.mock("#/hooks/query/use-settings", async () => {
const actual = await vi.importActual<typeof import("#/hooks/query/use-settings")>(
"#/hooks/query/use-settings",
);
return {
...actual,
useSettings: vi.fn().mockReturnValue({
data: {
v1_enabled: true,
},
isLoading: false,
}),
};
});
vi.mock("#/hooks/use-tracking", () => ({
useTracking: () => ({
trackConversationCreated: vi.fn(),
}),
}));
describe("useCreateConversation", () => {
it("passes suggested tasks to the V1 create conversation API", async () => {
const createConversationSpy = vi
.spyOn(V1ConversationService, "createConversation")
.mockResolvedValue({
id: "task-id",
created_by_user_id: null,
status: "READY",
detail: null,
app_conversation_id: null,
sandbox_id: null,
agent_server_url: "http://agent-server.local",
request: {
sandbox_id: null,
initial_message: {
role: "user",
content: [{ type: "text", text: "Please address the comments" }],
},
processors: [],
llm_model: null,
selected_repository: null,
selected_branch: null,
git_provider: "github",
suggested_task: null,
title: null,
trigger: null,
pr_number: [],
parent_conversation_id: null,
agent_type: "default",
},
created_at: new Date().toISOString(),
updated_at: new Date().toISOString(),
});
const { result } = renderHook(() => useCreateConversation(), {
wrapper: ({ children }) => (
<QueryClientProvider client={new QueryClient()}>
{children}
</QueryClientProvider>
),
});
const suggestedTask: SuggestedTask = {
git_provider: "github",
issue_number: 42,
repo: "owner/repo",
title: "Resolve comments",
task_type: "UNRESOLVED_COMMENTS",
};
await result.current.mutateAsync({
query: "Please address the comments",
repository: {
name: "owner/repo",
gitProvider: "github",
branch: "main",
},
conversationInstructions: "Focus on review comments",
suggestedTask,
});
await waitFor(() => {
expect(createConversationSpy).toHaveBeenCalledWith(
"owner/repo",
"github",
"Please address the comments",
"main",
"Focus on review comments",
suggestedTask,
undefined,
undefined,
undefined,
);
});
});
});

View File

@@ -0,0 +1,113 @@
import { describe, expect, it, vi, beforeEach } from "vitest";
import { renderHook, waitFor } from "@testing-library/react";
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
import React from "react";
import { useGitUser } from "#/hooks/query/use-git-user";
import { useLogout } from "#/hooks/mutation/use-logout";
import UserService from "#/api/user-service/user-service.api";
import * as useShouldShowUserFeaturesModule from "#/hooks/use-should-show-user-features";
import * as useConfigModule from "#/hooks/query/use-config";
import { AxiosError } from "axios";
vi.mock("#/hooks/use-should-show-user-features");
vi.mock("#/hooks/query/use-config");
vi.mock("#/hooks/mutation/use-logout");
vi.mock("#/api/user-service/user-service.api");
vi.mock("posthog-js/react", () => ({
usePostHog: vi.fn(() => ({
identify: vi.fn(),
})),
}));
describe("useGitUser", () => {
let mockLogout: ReturnType<typeof useLogout>;
beforeEach(() => {
vi.clearAllMocks();
mockLogout = {
mutate: vi.fn(),
mutateAsync: vi.fn(),
data: undefined,
error: null,
isPending: false,
isSuccess: false,
isError: false,
isIdle: true,
reset: vi.fn(),
status: "idle",
} as unknown as ReturnType<typeof useLogout>;
vi.mocked(useShouldShowUserFeaturesModule.useShouldShowUserFeatures).mockReturnValue(true);
vi.mocked(useConfigModule.useConfig).mockReturnValue({
data: { app_mode: "saas" },
isLoading: false,
error: null,
} as any);
vi.mocked(useLogout).mockReturnValue(mockLogout);
});
const createWrapper = () => {
const queryClient = new QueryClient({
defaultOptions: {
queries: {
retry: false,
},
},
});
return ({ children }: { children: React.ReactNode }) => (
<QueryClientProvider client={queryClient}>
{children}
</QueryClientProvider>
);
};
it("should call logout when receiving a 401 error", async () => {
// Mock the user service to throw a 401 error
const mockError = new AxiosError("Unauthorized", "401", undefined, undefined, {
status: 401,
data: { message: "Unauthorized" },
} as any);
vi.mocked(UserService.getUser).mockRejectedValue(mockError);
const { result } = renderHook(() => useGitUser(), {
wrapper: createWrapper(),
});
// Wait for the query to fail (status becomes 'error')
await waitFor(() => {
expect(result.current.status).toBe("error");
});
// Wait for the useEffect to trigger logout
await waitFor(() => {
expect(mockLogout.mutate).toHaveBeenCalled();
});
});
it("should not call logout for non-401 errors", async () => {
// Mock the user service to throw a 500 error
const mockError = new AxiosError("Server Error", "500", undefined, undefined, {
status: 500,
data: { message: "Internal Server Error" },
} as any);
vi.mocked(UserService.getUser).mockRejectedValue(mockError);
const { result } = renderHook(() => useGitUser(), {
wrapper: createWrapper(),
});
// Wait for the query to fail (status becomes 'error')
await waitFor(() => {
expect(result.current.status).toBe("error");
});
// Wait a bit to ensure logout is not called
await waitFor(() => {
expect(mockLogout.mutate).not.toHaveBeenCalled();
});
});
});

View File

@@ -0,0 +1,170 @@
import { act, renderHook } from "@testing-library/react";
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
const INVITATION_TOKEN_KEY = "openhands_invitation_token";
// Mock setSearchParams function
const mockSetSearchParams = vi.fn();
// Default mock searchParams
let mockSearchParamsData: Record<string, string> = {};
// Mock react-router
vi.mock("react-router", () => ({
useSearchParams: () => [
{
get: (key: string) => mockSearchParamsData[key] || null,
has: (key: string) => key in mockSearchParamsData,
},
mockSetSearchParams,
],
}));
// Import after mocking
import { useInvitation } from "#/hooks/use-invitation";
describe("useInvitation", () => {
beforeEach(() => {
// Clear localStorage before each test
localStorage.clear();
// Reset mock data
mockSearchParamsData = {};
mockSetSearchParams.mockClear();
});
afterEach(() => {
vi.clearAllMocks();
});
describe("initialization", () => {
it("should initialize with null token when localStorage is empty", () => {
// Arrange - localStorage is empty (cleared in beforeEach)
// Act
const { result } = renderHook(() => useInvitation());
// Assert
expect(result.current.invitationToken).toBeNull();
expect(result.current.hasInvitation).toBe(false);
});
it("should initialize with token from localStorage if present", () => {
// Arrange
const storedToken = "inv-stored-token-12345";
localStorage.setItem(INVITATION_TOKEN_KEY, storedToken);
// Act
const { result } = renderHook(() => useInvitation());
// Assert
expect(result.current.invitationToken).toBe(storedToken);
expect(result.current.hasInvitation).toBe(true);
});
});
describe("URL token capture", () => {
it("should capture invitation_token from URL and store in localStorage", () => {
// Arrange
const urlToken = "inv-url-token-67890";
mockSearchParamsData = { invitation_token: urlToken };
// Act
renderHook(() => useInvitation());
// Assert
expect(localStorage.getItem(INVITATION_TOKEN_KEY)).toBe(urlToken);
expect(mockSetSearchParams).toHaveBeenCalled();
});
});
describe("completion cleanup", () => {
it("should clear localStorage when email_mismatch param is present", () => {
// Arrange
const storedToken = "inv-token-to-clear";
localStorage.setItem(INVITATION_TOKEN_KEY, storedToken);
mockSearchParamsData = { email_mismatch: "true" };
// Act
const { result } = renderHook(() => useInvitation());
// Assert
expect(localStorage.getItem(INVITATION_TOKEN_KEY)).toBeNull();
expect(mockSetSearchParams).toHaveBeenCalled();
});
it("should clear localStorage when invitation_success param is present", () => {
// Arrange
const storedToken = "inv-token-to-clear";
localStorage.setItem(INVITATION_TOKEN_KEY, storedToken);
mockSearchParamsData = { invitation_success: "true" };
// Act
renderHook(() => useInvitation());
// Assert
expect(localStorage.getItem(INVITATION_TOKEN_KEY)).toBeNull();
});
it("should clear localStorage when invitation_expired param is present", () => {
// Arrange
localStorage.setItem(INVITATION_TOKEN_KEY, "inv-token");
mockSearchParamsData = { invitation_expired: "true" };
// Act
renderHook(() => useInvitation());
// Assert
expect(localStorage.getItem(INVITATION_TOKEN_KEY)).toBeNull();
});
});
describe("buildOAuthStateData", () => {
it("should include invitation_token in OAuth state when token is present", () => {
// Arrange
const token = "inv-oauth-token-12345";
localStorage.setItem(INVITATION_TOKEN_KEY, token);
const { result } = renderHook(() => useInvitation());
const baseState = { redirect_url: "/dashboard" };
// Act
const stateData = result.current.buildOAuthStateData(baseState);
// Assert
expect(stateData.invitation_token).toBe(token);
expect(stateData.redirect_url).toBe("/dashboard");
});
it("should not include invitation_token when no token is present", () => {
// Arrange - no token in localStorage
const { result } = renderHook(() => useInvitation());
const baseState = { redirect_url: "/dashboard" };
// Act
const stateData = result.current.buildOAuthStateData(baseState);
// Assert
expect(stateData.invitation_token).toBeUndefined();
expect(stateData.redirect_url).toBe("/dashboard");
});
});
describe("clearInvitation", () => {
it("should remove token from localStorage when called", () => {
// Arrange
localStorage.setItem(INVITATION_TOKEN_KEY, "inv-token-to-clear");
const { result } = renderHook(() => useInvitation());
// Act
act(() => {
result.current.clearInvitation();
});
// Assert
expect(localStorage.getItem(INVITATION_TOKEN_KEY)).toBeNull();
expect(result.current.invitationToken).toBeNull();
expect(result.current.hasInvitation).toBe(false);
});
});
});

View File

@@ -9,7 +9,10 @@ import {
} from "vitest";
import { screen, waitFor, render, cleanup } from "@testing-library/react";
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
import { createMockAgentErrorEvent } from "#/mocks/mock-ws-helpers";
import {
createMockAgentErrorEvent,
createMockConversationErrorEvent,
} from "#/mocks/mock-ws-helpers";
import { ConversationWebSocketProvider } from "#/contexts/conversation-websocket-context";
import { conversationWebSocketTestSetup } from "./helpers/msw-websocket-setup";
import { ConnectionStatusComponent } from "./helpers/websocket-test-components";
@@ -229,5 +232,35 @@ describe("PostHog Analytics Tracking", () => {
}),
);
});
it("should track credit_limit_reached when ConversationErrorEvent contains budget error", async () => {
const mockBudgetConversationError = createMockConversationErrorEvent({
detail:
"Budget has been exceeded! Current cost: 18.51, Max budget: 18.24",
});
mswServer.use(
wsLink.addEventListener("connection", ({ client, server }) => {
server.connect();
client.send(JSON.stringify(mockBudgetConversationError));
}),
);
renderWithProviders(<ConnectionStatusComponent />);
await waitFor(() => {
expect(screen.getByTestId("connection-state")).toHaveTextContent(
"OPEN",
);
});
await waitFor(() => {
expect(mockTrackCreditLimitReached).toHaveBeenCalledWith(
expect.objectContaining({
conversationId: "test-conversation-123",
}),
);
});
});
});
});

View File

@@ -57,6 +57,22 @@ vi.mock("#/hooks/use-tracking", () => ({
}),
}));
const { useInvitationMock, buildOAuthStateDataMock } = vi.hoisted(() => ({
useInvitationMock: vi.fn(() => ({
invitationToken: null as string | null,
hasInvitation: false,
buildOAuthStateData: (baseState: Record<string, string>) => baseState,
clearInvitation: vi.fn(),
})),
buildOAuthStateDataMock: vi.fn(
(baseState: Record<string, string>) => baseState,
),
}));
vi.mock("#/hooks/use-invitation", () => ({
useInvitation: () => useInvitationMock(),
}));
const RouterStub = createRoutesStub([
{
Component: LoginPage,
@@ -234,7 +250,8 @@ describe("LoginPage", () => {
});
await user.click(githubButton);
expect(window.location.href).toBe(mockUrl);
// URL includes state parameter added by handleAuthRedirect
expect(window.location.href).toContain(mockUrl);
});
it("should redirect to GitLab auth URL when GitLab button is clicked", async () => {
@@ -255,7 +272,8 @@ describe("LoginPage", () => {
});
await user.click(gitlabButton);
expect(window.location.href).toBe("https://gitlab.com/oauth/authorize");
// URL includes state parameter added by handleAuthRedirect
expect(window.location.href).toContain("https://gitlab.com/oauth/authorize");
});
it("should redirect to Bitbucket auth URL when Bitbucket button is clicked", async () => {
@@ -282,7 +300,8 @@ describe("LoginPage", () => {
});
await user.click(bitbucketButton);
expect(window.location.href).toBe(
// URL includes state parameter added by handleAuthRedirect
expect(window.location.href).toContain(
"https://bitbucket.org/site/oauth2/authorize",
);
});
@@ -479,4 +498,137 @@ describe("LoginPage", () => {
});
});
});
describe("Invitation Flow", () => {
it("should display invitation pending message when hasInvitation is true", async () => {
useInvitationMock.mockReturnValue({
invitationToken: "inv-test-token-12345",
hasInvitation: true,
buildOAuthStateData: buildOAuthStateDataMock,
clearInvitation: vi.fn(),
});
render(<RouterStub initialEntries={["/login"]} />, {
wrapper: createWrapper(),
});
await waitFor(() => {
expect(screen.getByText("AUTH$INVITATION_PENDING")).toBeInTheDocument();
});
});
it("should not display invitation pending message when hasInvitation is false", async () => {
useInvitationMock.mockReturnValue({
invitationToken: null,
hasInvitation: false,
buildOAuthStateData: buildOAuthStateDataMock,
clearInvitation: vi.fn(),
});
render(<RouterStub initialEntries={["/login"]} />, {
wrapper: createWrapper(),
});
await waitFor(() => {
expect(screen.getByTestId("login-content")).toBeInTheDocument();
});
expect(
screen.queryByText("AUTH$INVITATION_PENDING"),
).not.toBeInTheDocument();
});
it("should pass buildOAuthStateData to LoginContent for OAuth state encoding", async () => {
const user = userEvent.setup();
const mockBuildOAuthStateData = vi.fn((baseState: Record<string, string>) => ({
...baseState,
invitation_token: "inv-test-token-12345",
}));
useInvitationMock.mockReturnValue({
invitationToken: "inv-test-token-12345",
hasInvitation: true,
buildOAuthStateData: mockBuildOAuthStateData,
clearInvitation: vi.fn(),
});
render(<RouterStub initialEntries={["/login"]} />, {
wrapper: createWrapper(),
});
await waitFor(() => {
expect(
screen.getByRole("button", { name: "GITHUB$CONNECT_TO_GITHUB" }),
).toBeInTheDocument();
});
const githubButton = screen.getByRole("button", {
name: "GITHUB$CONNECT_TO_GITHUB",
});
await user.click(githubButton);
// buildOAuthStateData should have been called during the OAuth redirect
expect(mockBuildOAuthStateData).toHaveBeenCalled();
});
it("should include invitation token in OAuth state when invitation is present", async () => {
const user = userEvent.setup();
const mockBuildOAuthStateData = vi.fn((baseState: Record<string, string>) => ({
...baseState,
invitation_token: "inv-test-token-12345",
}));
useInvitationMock.mockReturnValue({
invitationToken: "inv-test-token-12345",
hasInvitation: true,
buildOAuthStateData: mockBuildOAuthStateData,
clearInvitation: vi.fn(),
});
render(<RouterStub initialEntries={["/login"]} />, {
wrapper: createWrapper(),
});
await waitFor(() => {
expect(
screen.getByRole("button", { name: "GITHUB$CONNECT_TO_GITHUB" }),
).toBeInTheDocument();
});
const githubButton = screen.getByRole("button", {
name: "GITHUB$CONNECT_TO_GITHUB",
});
await user.click(githubButton);
// Verify the redirect URL contains the state with invitation token
await waitFor(() => {
expect(window.location.href).toContain("state=");
});
// Decode and verify the state contains invitation_token
const url = new URL(window.location.href);
const state = url.searchParams.get("state");
if (state) {
const decodedState = JSON.parse(atob(state));
expect(decodedState.invitation_token).toBe("inv-test-token-12345");
}
});
it("should handle login with invitation_token URL parameter", async () => {
useInvitationMock.mockReturnValue({
invitationToken: "inv-url-token-67890",
hasInvitation: true,
buildOAuthStateData: buildOAuthStateDataMock,
clearInvitation: vi.fn(),
});
render(<RouterStub initialEntries={["/login?invitation_token=inv-url-token-67890"]} />, {
wrapper: createWrapper(),
});
await waitFor(() => {
expect(screen.getByText("AUTH$INVITATION_PENDING")).toBeInTheDocument();
});
});
});
});

View File

@@ -42,6 +42,15 @@ vi.mock("#/utils/custom-toast-handlers", () => ({
displaySuccessToast: vi.fn(),
}));
vi.mock("#/hooks/use-invitation", () => ({
useInvitation: () => ({
invitationToken: null,
hasInvitation: false,
buildOAuthStateData: (baseState: Record<string, string>) => baseState,
clearInvitation: vi.fn(),
}),
}));
function LoginStub() {
const [searchParams] = useSearchParams();
const emailVerificationRequired =
@@ -353,4 +362,68 @@ describe("MainApp", () => {
);
});
});
describe("Invitation URL Parameters", () => {
beforeEach(() => {
vi.spyOn(AuthService, "authenticate").mockRejectedValue({
response: { status: 401 },
isAxiosError: true,
});
});
it("should redirect to login when email_mismatch=true is in query params", async () => {
renderMainApp(["/?email_mismatch=true"]);
await waitFor(
() => {
expect(screen.getByTestId("login-page")).toBeInTheDocument();
},
{ timeout: 2000 },
);
});
it("should redirect to login when invitation_success=true is in query params", async () => {
renderMainApp(["/?invitation_success=true"]);
await waitFor(
() => {
expect(screen.getByTestId("login-page")).toBeInTheDocument();
},
{ timeout: 2000 },
);
});
it("should redirect to login when invitation_expired=true is in query params", async () => {
renderMainApp(["/?invitation_expired=true"]);
await waitFor(
() => {
expect(screen.getByTestId("login-page")).toBeInTheDocument();
},
{ timeout: 2000 },
);
});
it("should redirect to login when invitation_invalid=true is in query params", async () => {
renderMainApp(["/?invitation_invalid=true"]);
await waitFor(
() => {
expect(screen.getByTestId("login-page")).toBeInTheDocument();
},
{ timeout: 2000 },
);
});
it("should redirect to login when already_member=true is in query params", async () => {
renderMainApp(["/?already_member=true"]);
await waitFor(
() => {
expect(screen.getByTestId("login-page")).toBeInTheDocument();
},
{ timeout: 2000 },
);
});
});
});

View File

@@ -1,12 +1,12 @@
{
"name": "openhands-frontend",
"version": "1.3.0",
"version": "1.4.0",
"lockfileVersion": 3,
"requires": true,
"packages": {
"": {
"name": "openhands-frontend",
"version": "1.3.0",
"version": "1.4.0",
"dependencies": {
"@heroui/react": "2.8.7",
"@microlink/react-json-view": "^1.27.1",

View File

@@ -1,6 +1,6 @@
{
"name": "openhands-frontend",
"version": "1.3.0",
"version": "1.4.0",
"private": true,
"type": "module",
"engines": {

View File

@@ -2,6 +2,7 @@ import axios from "axios";
import { openHands } from "../open-hands-axios";
import { ConversationTrigger, GetVSCodeUrlResponse } from "../open-hands.types";
import { Provider } from "#/types/settings";
import { SuggestedTask } from "#/utils/types";
import { buildHttpBaseUrl } from "#/utils/websocket-url";
import { buildSessionHeaders } from "#/utils/utils";
import type {
@@ -61,6 +62,7 @@ class V1ConversationService {
initialUserMsg?: string,
selected_branch?: string,
conversationInstructions?: string,
suggestedTask?: SuggestedTask,
trigger?: ConversationTrigger,
parent_conversation_id?: string,
agent_type?: "default" | "plan",
@@ -69,14 +71,15 @@ class V1ConversationService {
selected_repository: selectedRepository,
git_provider,
selected_branch,
suggested_task: suggestedTask,
title: conversationInstructions,
trigger,
parent_conversation_id: parent_conversation_id || null,
agent_type,
};
// Add initial message if provided
if (initialUserMsg) {
// suggested_task implies the backend will construct the initial_message
if (!suggestedTask && initialUserMsg) {
body.initial_message = {
role: "user",
content: [

View File

@@ -1,6 +1,7 @@
import { ConversationTrigger } from "../open-hands.types";
import { Provider } from "#/types/settings";
import { V1SandboxStatus } from "../sandbox-service/sandbox-service.types";
import { Provider } from "#/types/settings";
import { SuggestedTask } from "#/utils/types";
// V1 Metrics Types
export interface V1TokenUsage {
@@ -47,6 +48,7 @@ export interface V1AppConversationStartRequest {
selected_repository?: string | null;
selected_branch?: string | null;
git_provider?: Provider | null;
suggested_task?: SuggestedTask | null;
title?: string | null;
trigger?: ConversationTrigger | null;
pr_number?: number[];

View File

@@ -21,6 +21,10 @@ export interface LoginContentProps {
emailVerified?: boolean;
hasDuplicatedEmail?: boolean;
recaptchaBlocked?: boolean;
hasInvitation?: boolean;
buildOAuthStateData?: (
baseStateData: Record<string, string>,
) => Record<string, string>;
}
export function LoginContent({
@@ -31,6 +35,8 @@ export function LoginContent({
emailVerified = false,
hasDuplicatedEmail = false,
recaptchaBlocked = false,
hasInvitation = false,
buildOAuthStateData,
}: LoginContentProps) {
const { t } = useTranslation();
const { trackLoginButtonClick } = useTracking();
@@ -59,31 +65,36 @@ export function LoginContent({
) => {
trackLoginButtonClick({ provider });
if (!config?.recaptcha_site_key || !recaptchaReady) {
// No reCAPTCHA or token generation failed - redirect normally
window.location.href = redirectUrl;
return;
const url = new URL(redirectUrl);
const currentState =
url.searchParams.get("state") || window.location.origin;
// Build base state data
let stateData: Record<string, string> = {
redirect_url: currentState,
};
// Add invitation token if present
if (buildOAuthStateData) {
stateData = buildOAuthStateData(stateData);
}
// If reCAPTCHA is configured, encode token in OAuth state
try {
const token = await executeRecaptcha("LOGIN");
if (token) {
const url = new URL(redirectUrl);
const currentState =
url.searchParams.get("state") || window.location.origin;
// Encode state with reCAPTCHA token for backend verification
const stateData = {
redirect_url: currentState,
recaptcha_token: token,
};
url.searchParams.set("state", btoa(JSON.stringify(stateData)));
window.location.href = url.toString();
// If reCAPTCHA is configured, add token to state
if (config?.recaptcha_site_key && recaptchaReady) {
try {
const token = await executeRecaptcha("LOGIN");
if (token) {
stateData.recaptcha_token = token;
}
} catch (err) {
displayErrorToast(t(I18nKey.AUTH$RECAPTCHA_BLOCKED));
return;
}
} catch (err) {
displayErrorToast(t(I18nKey.AUTH$RECAPTCHA_BLOCKED));
}
// Encode state and redirect
url.searchParams.set("state", btoa(JSON.stringify(stateData)));
window.location.href = url.toString();
};
const handleGitHubAuth = () => {
@@ -123,6 +134,10 @@ export function LoginContent({
const buttonBaseClasses =
"w-[301.5px] h-10 rounded p-2 flex items-center justify-center cursor-pointer transition-opacity hover:opacity-90 disabled:opacity-50 disabled:cursor-not-allowed";
const buttonLabelClasses = "text-sm font-medium leading-5 px-1";
const shouldShownHelperText =
emailVerified || hasDuplicatedEmail || recaptchaBlocked || hasInvitation;
return (
<div
className="flex flex-col items-center w-full gap-12.5"
@@ -136,20 +151,29 @@ export function LoginContent({
{t(I18nKey.AUTH$LETS_GET_STARTED)}
</h1>
{emailVerified && (
<p className="text-sm text-muted-foreground text-center">
{t(I18nKey.AUTH$EMAIL_VERIFIED_PLEASE_LOGIN)}
</p>
)}
{hasDuplicatedEmail && (
<p className="text-sm text-danger text-center">
{t(I18nKey.AUTH$DUPLICATE_EMAIL_ERROR)}
</p>
)}
{recaptchaBlocked && (
<p className="text-sm text-danger text-center max-w-125">
{t(I18nKey.AUTH$RECAPTCHA_BLOCKED)}
</p>
{shouldShownHelperText && (
<div className="flex flex-col items-center gap-3">
{emailVerified && (
<p className="text-sm text-muted-foreground text-center">
{t(I18nKey.AUTH$EMAIL_VERIFIED_PLEASE_LOGIN)}
</p>
)}
{hasDuplicatedEmail && (
<p className="text-sm text-danger text-center">
{t(I18nKey.AUTH$DUPLICATE_EMAIL_ERROR)}
</p>
)}
{recaptchaBlocked && (
<p className="text-sm text-danger text-center max-w-125">
{t(I18nKey.AUTH$RECAPTCHA_BLOCKED)}
</p>
)}
{hasInvitation && (
<p className="text-sm text-muted-foreground text-center">
{t(I18nKey.AUTH$INVITATION_PENDING)}
</p>
)}
</div>
)}
<div className="flex flex-col items-center gap-3">

View File

@@ -1,6 +1,7 @@
import React from "react";
import { useTranslation } from "react-i18next";
import { Link } from "react-router";
import { useFeatureFlagEnabled } from "posthog-js/react";
import { ContextMenu } from "#/ui/context-menu";
import { ContextMenuListItem } from "./context-menu-list-item";
import { Divider } from "#/ui/divider";
@@ -8,7 +9,10 @@ import { useClickOutsideElement } from "#/hooks/use-click-outside-element";
import { I18nKey } from "#/i18n/declaration";
import LogOutIcon from "#/icons/log-out.svg?react";
import DocumentIcon from "#/icons/document.svg?react";
import PlusIcon from "#/icons/plus.svg?react";
import { useSettingsNavItems } from "#/hooks/use-settings-nav-items";
import { useConfig } from "#/hooks/query/use-config";
import { useTracking } from "#/hooks/use-tracking";
interface AccountSettingsContextMenuProps {
onLogout: () => void;
@@ -21,9 +25,17 @@ export function AccountSettingsContextMenu({
}: AccountSettingsContextMenuProps) {
const ref = useClickOutsideElement<HTMLUListElement>(onClose);
const { t } = useTranslation();
const { trackAddTeamMembersButtonClick } = useTracking();
const { data: config } = useConfig();
const isAddTeamMemberEnabled = useFeatureFlagEnabled(
"exp_add_team_member_button",
);
// Get navigation items and filter out LLM settings if the feature flag is enabled
const items = useSettingsNavItems();
const isSaasMode = config?.app_mode === "saas";
const showAddTeamMembers = isSaasMode && isAddTeamMemberEnabled;
const navItems = items.map((item) => ({
...item,
icon: React.cloneElement(item.icon, {
@@ -33,6 +45,11 @@ export function AccountSettingsContextMenu({
}));
const handleNavigationClick = () => onClose();
const handleAddTeamMembers = () => {
trackAddTeamMembersButtonClick();
onClose();
};
return (
<ContextMenu
testId="account-settings-context-menu"
@@ -40,6 +57,18 @@ export function AccountSettingsContextMenu({
alignment="right"
className="mt-0 md:right-full md:left-full md:bottom-0 ml-0 w-fit z-[9999]"
>
{showAddTeamMembers && (
<ContextMenuListItem
testId="add-team-members-button"
onClick={handleAddTeamMembers}
className="flex items-center gap-2 p-2 hover:bg-[#5C5D62] rounded h-[30px]"
>
<PlusIcon width={16} height={16} />
<span className="text-white text-sm">
{t(I18nKey.SETTINGS$NAV_ADD_TEAM_MEMBERS)}
</span>
</ContextMenuListItem>
)}
{navItems.map(({ to, text, icon }) => (
<Link key={to} to={to} className="text-decoration-none">
<ContextMenuListItem

View File

@@ -10,6 +10,7 @@ import {
displayErrorToast,
displaySuccessToast,
} from "#/utils/custom-toast-handlers";
import { mutateWithToast } from "#/utils/mutate-with-toast";
import { CreateApiKeyModal } from "./create-api-key-modal";
import { DeleteApiKeyModal } from "./delete-api-key-modal";
import { NewApiKeyModal } from "./new-api-key-modal";
@@ -60,18 +61,10 @@ function LlmApiKeyManager({
const { t } = useTranslation();
const [showLlmApiKey, setShowLlmApiKey] = useState(false);
const handleRefreshLlmApiKey = () => {
refreshLlmApiKey.mutate(undefined, {
onSuccess: () => {
displaySuccessToast(
t(I18nKey.SETTINGS$API_KEY_REFRESHED, {
defaultValue: "API key refreshed successfully",
}),
);
},
onError: () => {
displayErrorToast(t(I18nKey.ERROR$GENERIC));
},
const handleRefreshLlmApiKey = async () => {
await mutateWithToast(refreshLlmApiKey, undefined, {
success: t(I18nKey.SETTINGS$API_KEY_REFRESHED),
error: t(I18nKey.ERROR$GENERIC),
});
};

Some files were not shown because too many files have changed in this diff Show More