Compare commits

..

2 Commits

Author SHA1 Message Date
openhands
23aaa93a60 chore: add verification messaging to contributor emails
Ask contributors to verify their changes work as intended and
submit follow-up fixes if needed.

Co-authored-by: openhands <openhands@all-hands.dev>
2026-02-18 01:03:57 +00:00
openhands
9e1e529ea1 feat: add release contributor notification system
This adds a script and GitHub workflow to notify code contributors
when their code goes live as part of a release.

Features:
- Identifies all contributors between two release tags
- Generates personalized HTML emails showing each contributor's commits
- Uses Resend for email delivery (matching existing infrastructure)
- Supports dry-run mode for previewing notifications
- Filters out bot accounts automatically
- Can resolve noreply GitHub emails via API

Co-authored-by: openhands <openhands@all-hands.dev>
2026-02-18 00:59:34 +00:00
85 changed files with 1668 additions and 7925 deletions

View File

@@ -0,0 +1,29 @@
# Feature branch preview for enterprise code
name: Enterprise Preview
# Run on PRs labeled
on:
pull_request:
types: [labeled]
# Match ghcr-build.yml, but don't interrupt it.
concurrency:
group: ${{ github.workflow }}-${{ (github.head_ref && github.ref) || github.run_id }}
cancel-in-progress: false
jobs:
# This must happen for the PR Docker workflow when the label is present,
# and also if it's added after the fact. Thus, it exists in both places.
enterprise-preview:
name: Enterprise preview
if: github.event.label.name == 'deploy'
runs-on: blacksmith-4vcpu-ubuntu-2204
steps:
# This should match the version in ghcr-build.yml
- name: Trigger remote job
run: |
curl --fail-with-body -sS -X POST \
-H "Authorization: Bearer ${{ secrets.ALLHANDS_BOT_GITHUB_PAT }}" \
-H "Accept: application/vnd.github+json" \
-d "{\"ref\": \"main\", \"inputs\": {\"openhandsPrNumber\": \"${{ github.event.pull_request.number }}\", \"deployEnvironment\": \"feature\", \"enterpriseImageTag\": \"pr-${{ github.event.pull_request.number }}\" }}" \
https://api.github.com/repos/OpenHands/deploy/actions/workflows/deploy.yaml/dispatches

View File

@@ -240,6 +240,21 @@ jobs:
# Add build attestations for better security
sbom: true
enterprise-preview:
name: Enterprise preview
if: github.event_name == 'pull_request' && contains(github.event.pull_request.labels.*.name, 'deploy')
runs-on: blacksmith-4vcpu-ubuntu-2204
needs: [ghcr_build_enterprise]
steps:
# This should match the version in enterprise-preview.yml
- name: Trigger remote job
run: |
curl --fail-with-body -sS -X POST \
-H "Authorization: Bearer ${{ secrets.ALLHANDS_BOT_GITHUB_PAT }}" \
-H "Accept: application/vnd.github+json" \
-d "{\"ref\": \"main\", \"inputs\": {\"openhandsPrNumber\": \"${{ github.event.pull_request.number }}\", \"deployEnvironment\": \"feature\", \"enterpriseImageTag\": \"pr-${{ github.event.pull_request.number }}\" }}" \
https://api.github.com/repos/OpenHands/deploy/actions/workflows/deploy.yaml/dispatches
# "All Runtime Tests Passed" is a required job for PRs to merge
# We can remove this once the config changes
runtime_tests_check_success:

View File

@@ -0,0 +1,146 @@
# Workflow to notify contributors when their code goes live in a release
#
# This workflow can be triggered:
# 1. Automatically when a production deployment succeeds (via workflow_call)
# 2. Manually via workflow_dispatch to send notifications for any release
#
# To integrate with deploy repo, add a workflow_call to this workflow after
# successful production deployment.
name: Notify Release Contributors
on:
workflow_dispatch:
inputs:
release_tag:
description: 'Release tag (e.g., 1.3.0)'
required: true
type: string
previous_tag:
description: 'Previous release tag (auto-detected if empty)'
required: false
type: string
dry_run:
description: 'Preview only - do not send emails'
required: false
type: boolean
default: true
resolve_emails:
description: 'Try to resolve emails from GitHub profiles'
required: false
type: boolean
default: true
# Can be called from other workflows (e.g., deploy workflow)
workflow_call:
inputs:
release_tag:
description: 'Release tag'
required: true
type: string
previous_tag:
description: 'Previous release tag'
required: false
type: string
dry_run:
description: 'Preview only'
required: false
type: boolean
default: false
secrets:
RESEND_API_KEY:
description: 'Resend API key for sending emails'
required: false
jobs:
notify-contributors:
name: Notify Contributors
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0 # Need full history for tag comparison
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.12'
- name: Install dependencies
run: |
pip install resend
- name: Determine previous tag
id: prev_tag
run: |
RELEASE_TAG="${{ inputs.release_tag }}"
PREVIOUS_TAG="${{ inputs.previous_tag }}"
if [ -z "$PREVIOUS_TAG" ]; then
# Get the tag before the release tag
PREVIOUS_TAG=$(git tag --sort=-creatordate | grep -A1 "^${RELEASE_TAG}$" | tail -n1)
if [ -z "$PREVIOUS_TAG" ] || [ "$PREVIOUS_TAG" = "$RELEASE_TAG" ]; then
echo "Error: Could not determine previous tag"
exit 1
fi
fi
echo "previous_tag=$PREVIOUS_TAG" >> $GITHUB_OUTPUT
echo "Using previous tag: $PREVIOUS_TAG"
- name: Get contributors and generate report
id: contributors
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
python scripts/notify_release_contributors.py \
--repo "${{ github.repository }}" \
--from-tag "${{ steps.prev_tag.outputs.previous_tag }}" \
--to-tag "${{ inputs.release_tag }}" \
--output-json contributors.json \
--resolve-emails \
--dry-run
# Count contributors
CONTRIBUTOR_COUNT=$(jq '.contributors | length' contributors.json)
echo "contributor_count=$CONTRIBUTOR_COUNT" >> $GITHUB_OUTPUT
- name: Upload contributor report
uses: actions/upload-artifact@v4
with:
name: contributor-report-${{ inputs.release_tag }}
path: contributors.json
- name: Send email notifications
if: ${{ !inputs.dry_run && env.RESEND_API_KEY != '' }}
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
RESEND_API_KEY: ${{ secrets.RESEND_API_KEY }}
run: |
python scripts/notify_release_contributors.py \
--repo "${{ github.repository }}" \
--from-tag "${{ steps.prev_tag.outputs.previous_tag }}" \
--to-tag "${{ inputs.release_tag }}" \
--email-provider resend \
--from-email "OpenHands Team <contact@all-hands.dev>" \
--resolve-emails
- name: Post summary
run: |
echo "## 📧 Release Contributor Notification Summary" >> $GITHUB_STEP_SUMMARY
echo "" >> $GITHUB_STEP_SUMMARY
echo "**Release:** ${{ inputs.release_tag }}" >> $GITHUB_STEP_SUMMARY
echo "**Previous Release:** ${{ steps.prev_tag.outputs.previous_tag }}" >> $GITHUB_STEP_SUMMARY
echo "**Contributors:** ${{ steps.contributors.outputs.contributor_count }}" >> $GITHUB_STEP_SUMMARY
echo "" >> $GITHUB_STEP_SUMMARY
if [ "${{ inputs.dry_run }}" = "true" ]; then
echo "⚠️ **DRY RUN** - No emails were sent" >> $GITHUB_STEP_SUMMARY
else
echo "✅ Notification emails sent to contributors" >> $GITHUB_STEP_SUMMARY
fi
echo "" >> $GITHUB_STEP_SUMMARY
echo "### Contributors" >> $GITHUB_STEP_SUMMARY
echo "" >> $GITHUB_STEP_SUMMARY
jq -r '.contributors[] | "- **\(.name)** (@\(.username)) - \(.commit_count) commit(s)"' contributors.json >> $GITHUB_STEP_SUMMARY

View File

@@ -2,11 +2,16 @@
name: PR Review by OpenHands
on:
# TEMPORARY MITIGATION (Clinejection hardening)
#
# We temporarily avoid `pull_request_target` here. We'll restore it after the PR review
# workflow is fully hardened for untrusted execution.
pull_request:
# Use pull_request_target to allow fork PRs to access secrets when triggered by maintainers
# Security: This workflow runs when:
# 1. A new PR is opened (non-draft), OR
# 2. A draft PR is marked as ready for review, OR
# 3. A maintainer adds the 'review-this' label, OR
# 4. A maintainer requests openhands-agent or all-hands-bot as a reviewer
# Only users with write access can add labels or request reviews, ensuring security.
# The PR code is explicitly checked out for review, but secrets are only accessible
# because the workflow runs in the base repository context
pull_request_target:
types: [opened, ready_for_review, labeled, review_requested]
permissions:
@@ -16,33 +21,107 @@ permissions:
jobs:
pr-review:
# Note: fork PRs will not have access to repository secrets under `pull_request`.
# Skip forks to avoid noisy failures until we restore a hardened `pull_request_target` flow.
# Run when one of the following conditions is met:
# 1. A new non-draft PR is opened by a trusted contributor, OR
# 2. A draft PR is converted to ready for review by a trusted contributor, OR
# 3. 'review-this' label is added, OR
# 4. openhands-agent or all-hands-bot is requested as a reviewer
# Note: FIRST_TIME_CONTRIBUTOR PRs require manual trigger via label/reviewer request
if: |
github.event.pull_request.head.repo.full_name == github.repository &&
(
(github.event.action == 'opened' && github.event.pull_request.draft == false) ||
github.event.action == 'ready_for_review' ||
(github.event.action == 'labeled' && github.event.label.name == 'review-this') ||
(
github.event.action == 'review_requested' &&
(
github.event.requested_reviewer.login == 'openhands-agent' ||
github.event.requested_reviewer.login == 'all-hands-bot'
)
)
)
(github.event.action == 'opened' && github.event.pull_request.draft == false && github.event.pull_request.author_association != 'FIRST_TIME_CONTRIBUTOR') ||
(github.event.action == 'ready_for_review' && github.event.pull_request.author_association != 'FIRST_TIME_CONTRIBUTOR') ||
github.event.label.name == 'review-this' ||
github.event.requested_reviewer.login == 'openhands-agent' ||
github.event.requested_reviewer.login == 'all-hands-bot'
concurrency:
group: pr-review-${{ github.event.pull_request.number }}
cancel-in-progress: true
runs-on: ubuntu-24.04
runs-on: blacksmith-4vcpu-ubuntu-2404
env:
LLM_MODEL: litellm_proxy/claude-sonnet-4-5-20250929
LLM_BASE_URL: https://llm-proxy.app.all-hands.dev
# PR context will be automatically provided by the agent script
PR_NUMBER: ${{ github.event.pull_request.number }}
PR_TITLE: ${{ github.event.pull_request.title }}
PR_BODY: ${{ github.event.pull_request.body }}
PR_BASE_BRANCH: ${{ github.event.pull_request.base.ref }}
PR_HEAD_BRANCH: ${{ github.event.pull_request.head.ref }}
REPO_NAME: ${{ github.repository }}
steps:
- name: Run PR Review
uses: OpenHands/extensions/plugins/pr-review@main
- name: Checkout software-agent-sdk repository
uses: actions/checkout@v5
with:
llm-model: litellm_proxy/claude-sonnet-4-5-20250929
llm-base-url: https://llm-proxy.app.all-hands.dev
review-style: roasted
llm-api-key: ${{ secrets.LLM_API_KEY }}
github-token: ${{ secrets.ALLHANDS_BOT_GITHUB_PAT }}
lmnr-api-key: ${{ secrets.LMNR_SKILLS_API_KEY }}
repository: OpenHands/software-agent-sdk
path: software-agent-sdk
- name: Checkout PR repository
uses: actions/checkout@v5
with:
# When using pull_request_target, explicitly checkout the PR branch
# This ensures we review the actual PR code (including fork PRs)
repository: ${{ github.event.pull_request.head.repo.full_name }}
ref: ${{ github.event.pull_request.head.ref }}
fetch-depth: 0
# Security: Don't persist credentials to prevent untrusted PR code from using them
persist-credentials: false
path: pr-repo
- name: Set up Python
uses: actions/setup-python@v6
with:
python-version: '3.13'
- name: Install uv
uses: astral-sh/setup-uv@v7
with:
enable-cache: true
- name: Install GitHub CLI
run: |
# Install GitHub CLI for posting review comments
sudo apt-get update
sudo apt-get install -y gh
- name: Install OpenHands dependencies
run: |
# Install OpenHands SDK and tools from local checkout
uv pip install --system ./software-agent-sdk/openhands-sdk ./software-agent-sdk/openhands-tools
- name: Check required configuration
env:
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
run: |
if [ -z "$LLM_API_KEY" ]; then
echo "Error: LLM_API_KEY secret is not set."
exit 1
fi
echo "PR Number: $PR_NUMBER"
echo "PR Title: $PR_TITLE"
echo "Repository: $REPO_NAME"
echo "LLM model: $LLM_MODEL"
if [ -n "$LLM_BASE_URL" ]; then
echo "LLM base URL: $LLM_BASE_URL"
fi
- name: Run PR review
env:
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
GITHUB_TOKEN: ${{ secrets.ALLHANDS_BOT_GITHUB_PAT }}
LMNR_PROJECT_API_KEY: ${{ secrets.LMNR_SKILLS_API_KEY }}
run: |
# Change to the PR repository directory so agent can analyze the code
cd pr-repo
# Run the PR review script from the software-agent-sdk checkout
uv run python ../software-agent-sdk/examples/03_github_workflows/02_pr_review/agent_script.py
- name: Upload logs as artifact
uses: actions/upload-artifact@v5
if: always()
with:
name: openhands-pr-review-logs
path: |
*.log
output/
retention-days: 7

View File

@@ -1,85 +0,0 @@
---
name: PR Review Evaluation
# This workflow evaluates how well PR review comments were addressed.
# It runs when a PR is closed to assess review effectiveness.
#
# Security note: pull_request_target is safe here because:
# 1. Only triggers on PR close (not on code changes)
# 2. Does not checkout PR code - only downloads artifacts from trusted workflow runs
# 3. Runs evaluation scripts from the extensions repo, not from the PR
on:
pull_request_target:
types: [closed]
permissions:
contents: read
pull-requests: read
jobs:
evaluate:
runs-on: ubuntu-24.04
env:
PR_NUMBER: ${{ github.event.pull_request.number }}
REPO_NAME: ${{ github.repository }}
PR_MERGED: ${{ github.event.pull_request.merged }}
steps:
- name: Download review trace artifact
id: download-trace
uses: dawidd6/action-download-artifact@v6
continue-on-error: true
with:
workflow: pr-review-by-openhands.yml
name: pr-review-trace-${{ github.event.pull_request.number }}
path: trace-info
search_artifacts: true
if_no_artifact_found: warn
- name: Check if trace file exists
id: check-trace
run: |
if [ -f "trace-info/laminar_trace_info.json" ]; then
echo "trace_exists=true" >> $GITHUB_OUTPUT
echo "Found trace file for PR #$PR_NUMBER"
else
echo "trace_exists=false" >> $GITHUB_OUTPUT
echo "No trace file found for PR #$PR_NUMBER - skipping evaluation"
fi
# Always checkout main branch for security - cannot test script changes in PRs
- name: Checkout extensions repository
if: steps.check-trace.outputs.trace_exists == 'true'
uses: actions/checkout@v5
with:
repository: OpenHands/extensions
path: extensions
- name: Set up Python
if: steps.check-trace.outputs.trace_exists == 'true'
uses: actions/setup-python@v6
with:
python-version: '3.12'
- name: Install dependencies
if: steps.check-trace.outputs.trace_exists == 'true'
run: pip install lmnr
- name: Run evaluation
if: steps.check-trace.outputs.trace_exists == 'true'
env:
# Script expects LMNR_PROJECT_API_KEY; org secret is named LMNR_SKILLS_API_KEY
LMNR_PROJECT_API_KEY: ${{ secrets.LMNR_SKILLS_API_KEY }}
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
python extensions/plugins/pr-review/scripts/evaluate_review.py \
--trace-file trace-info/laminar_trace_info.json
- name: Upload evaluation logs
uses: actions/upload-artifact@v5
if: always() && steps.check-trace.outputs.trace_exists == 'true'
with:
name: pr-review-evaluation-${{ github.event.pull_request.number }}
path: '*.log'
retention-days: 30

View File

@@ -54,7 +54,7 @@ The experience will be familiar to anyone who has used Devin or Jules.
### OpenHands Cloud
This is a deployment of OpenHands GUI, running on hosted infrastructure.
You can try it for free using the Minimax model by [signing in with your GitHub or GitLab account](https://app.all-hands.dev).
You can try it with a free $10 credit by [signing in with your GitHub or GitLab account](https://app.all-hands.dev).
OpenHands Cloud comes with source-available features and integrations:
- Integrations with Slack, Jira, and Linear

View File

@@ -1,110 +0,0 @@
"""create org_invitation table
Revision ID: 094
Revises: 093
Create Date: 2026-02-18 00:00:00.000000
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = '094'
down_revision: Union[str, None] = '093'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Create org_invitation table
op.create_table(
'org_invitation',
sa.Column('id', sa.Integer, sa.Identity(), primary_key=True),
sa.Column('token', sa.String(64), nullable=False),
sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=False),
sa.Column('email', sa.String(255), nullable=False),
sa.Column('role_id', sa.Integer, nullable=False),
sa.Column('inviter_id', postgresql.UUID(as_uuid=True), nullable=False),
sa.Column(
'status',
sa.String(20),
nullable=False,
server_default=sa.text("'pending'"),
),
sa.Column(
'created_at',
sa.DateTime,
nullable=False,
server_default=sa.text('CURRENT_TIMESTAMP'),
),
sa.Column('expires_at', sa.DateTime, nullable=False),
sa.Column('accepted_at', sa.DateTime, nullable=True),
sa.Column('accepted_by_user_id', postgresql.UUID(as_uuid=True), nullable=True),
# Foreign key constraints
sa.ForeignKeyConstraint(
['org_id'],
['org.id'],
name='org_invitation_org_fkey',
ondelete='CASCADE',
),
sa.ForeignKeyConstraint(
['role_id'],
['role.id'],
name='org_invitation_role_fkey',
),
sa.ForeignKeyConstraint(
['inviter_id'],
['user.id'],
name='org_invitation_inviter_fkey',
),
sa.ForeignKeyConstraint(
['accepted_by_user_id'],
['user.id'],
name='org_invitation_accepter_fkey',
),
)
# Create indexes
op.create_index(
'ix_org_invitation_token',
'org_invitation',
['token'],
unique=True,
)
op.create_index(
'ix_org_invitation_org_id',
'org_invitation',
['org_id'],
)
op.create_index(
'ix_org_invitation_email',
'org_invitation',
['email'],
)
op.create_index(
'ix_org_invitation_status',
'org_invitation',
['status'],
)
# Composite index for checking pending invitations
op.create_index(
'ix_org_invitation_org_email_status',
'org_invitation',
['org_id', 'email', 'status'],
)
def downgrade() -> None:
# Drop indexes
op.drop_index('ix_org_invitation_org_email_status', table_name='org_invitation')
op.drop_index('ix_org_invitation_status', table_name='org_invitation')
op.drop_index('ix_org_invitation_email', table_name='org_invitation')
op.drop_index('ix_org_invitation_org_id', table_name='org_invitation')
op.drop_index('ix_org_invitation_token', table_name='org_invitation')
# Drop table
op.drop_table('org_invitation')

View File

@@ -1,37 +0,0 @@
"""Drop pending_free_credits column from org table.
Revision ID: 095
Revises: 094
Create Date: 2025-02-18 00:00:00.000000
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = '095'
down_revision: Union[str, None] = '094'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Drop the pending_free_credits column from org table.
# This column was used for tracking free credit eligibility but is no longer needed.
op.drop_column('org', 'pending_free_credits')
def downgrade() -> None:
# Re-add pending_free_credits column with default false.
op.add_column(
'org',
sa.Column(
'pending_free_credits',
sa.Boolean,
nullable=False,
server_default=sa.text('false'),
),
)

View File

@@ -1,67 +0,0 @@
"""Create resend_synced_users table.
Revision ID: 096
Revises: 095
Create Date: 2025-02-17 00:00:00.000000
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = '096'
down_revision: Union[str, None] = '095'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Create resend_synced_users table for tracking users synced to Resend audiences."""
op.create_table(
'resend_synced_users',
sa.Column(
'id',
sa.UUID(as_uuid=True),
nullable=False,
primary_key=True,
),
sa.Column('email', sa.String(), nullable=False),
sa.Column('audience_id', sa.String(), nullable=False),
sa.Column(
'synced_at',
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.text('CURRENT_TIMESTAMP'),
),
sa.Column('keycloak_user_id', sa.String(), nullable=True),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint(
'email', 'audience_id', name='uq_resend_synced_email_audience'
),
)
# Create index on email for fast lookups
op.create_index(
'ix_resend_synced_users_email',
'resend_synced_users',
['email'],
)
# Create index on audience_id for filtering by audience
op.create_index(
'ix_resend_synced_users_audience_id',
'resend_synced_users',
['audience_id'],
)
def downgrade() -> None:
"""Drop resend_synced_users table."""
op.drop_index(
'ix_resend_synced_users_audience_id', table_name='resend_synced_users'
)
op.drop_index('ix_resend_synced_users_email', table_name='resend_synced_users')
op.drop_table('resend_synced_users')

View File

@@ -38,12 +38,6 @@ from server.routes.integration.linear import linear_integration_router # noqa:
from server.routes.integration.slack import slack_router # noqa: E402
from server.routes.mcp_patch import patch_mcp_server # noqa: E402
from server.routes.oauth_device import oauth_device_router # noqa: E402
from server.routes.org_invitations import ( # noqa: E402
accept_router as invitation_accept_router,
)
from server.routes.org_invitations import ( # noqa: E402
invitation_router,
)
from server.routes.orgs import org_router # noqa: E402
from server.routes.readiness import readiness_router # noqa: E402
from server.routes.user import saas_user_router # noqa: E402
@@ -105,8 +99,6 @@ if GITLAB_APP_CLIENT_ID:
base_app.include_router(api_keys_router) # Add routes for API key management
base_app.include_router(org_router) # Add routes for organization management
base_app.include_router(invitation_router) # Add routes for org invitation management
base_app.include_router(invitation_accept_router) # Add route for accepting invitations
add_github_proxy_routes(base_app)
add_debugging_routes(
base_app

View File

@@ -1,306 +0,0 @@
"""
Permission-based authorization dependencies for API endpoints.
This module provides FastAPI dependencies for checking user permissions
within organizations. It uses a permission-based authorization model where
roles (owner, admin, member) are mapped to specific permissions.
Permissions are defined in the Permission enum and mapped to roles via
ROLE_PERMISSIONS. This allows fine-grained access control while maintaining
the familiar role-based hierarchy.
Usage:
from server.auth.authorization import (
Permission,
require_permission,
)
@router.get('/{org_id}/settings')
async def get_settings(
org_id: UUID,
user_id: str = Depends(require_permission(Permission.VIEW_LLM_SETTINGS)),
):
# Only users with VIEW_LLM_SETTINGS permission can access
...
@router.patch('/{org_id}/settings')
async def update_settings(
org_id: UUID,
user_id: str = Depends(require_permission(Permission.EDIT_LLM_SETTINGS)),
):
# Only users with EDIT_LLM_SETTINGS permission can access
...
"""
from enum import Enum
from uuid import UUID
from fastapi import Depends, HTTPException, status
from storage.org_member_store import OrgMemberStore
from storage.role import Role
from storage.role_store import RoleStore
from openhands.core.logger import openhands_logger as logger
from openhands.server.user_auth import get_user_id
class Permission(str, Enum):
"""Permissions that can be assigned to roles."""
# Secrets
MANAGE_SECRETS = 'manage_secrets'
# MCP
MANAGE_MCP = 'manage_mcp'
# Integrations
MANAGE_INTEGRATIONS = 'manage_integrations'
# Application Settings
MANAGE_APPLICATION_SETTINGS = 'manage_application_settings'
# API Keys
MANAGE_API_KEYS = 'manage_api_keys'
# LLM Settings
VIEW_LLM_SETTINGS = 'view_llm_settings'
EDIT_LLM_SETTINGS = 'edit_llm_settings'
# Billing
VIEW_BILLING = 'view_billing'
ADD_CREDITS = 'add_credits'
# Organization Members
INVITE_USER_TO_ORGANIZATION = 'invite_user_to_organization'
CHANGE_USER_ROLE_MEMBER = 'change_user_role:member'
CHANGE_USER_ROLE_ADMIN = 'change_user_role:admin'
CHANGE_USER_ROLE_OWNER = 'change_user_role:owner'
# Organization Management
VIEW_ORG_SETTINGS = 'view_org_settings'
CHANGE_ORGANIZATION_NAME = 'change_organization_name'
DELETE_ORGANIZATION = 'delete_organization'
# Temporary permissions until we finish the API updates.
EDIT_ORG_SETTINGS = 'edit_org_settings'
class RoleName(str, Enum):
"""Role names used in the system."""
OWNER = 'owner'
ADMIN = 'admin'
MEMBER = 'member'
# Permission mappings for each role
ROLE_PERMISSIONS: dict[RoleName, frozenset[Permission]] = {
RoleName.OWNER: frozenset(
[
# Settings (Full access)
Permission.MANAGE_SECRETS,
Permission.MANAGE_MCP,
Permission.MANAGE_INTEGRATIONS,
Permission.MANAGE_APPLICATION_SETTINGS,
Permission.MANAGE_API_KEYS,
Permission.VIEW_LLM_SETTINGS,
Permission.EDIT_LLM_SETTINGS,
Permission.VIEW_BILLING,
Permission.ADD_CREDITS,
# Organization Members
Permission.INVITE_USER_TO_ORGANIZATION,
Permission.CHANGE_USER_ROLE_MEMBER,
Permission.CHANGE_USER_ROLE_ADMIN,
Permission.CHANGE_USER_ROLE_OWNER,
# Organization Management
Permission.VIEW_ORG_SETTINGS,
Permission.EDIT_ORG_SETTINGS,
# Organization Management (Owner only)
Permission.CHANGE_ORGANIZATION_NAME,
Permission.DELETE_ORGANIZATION,
]
),
RoleName.ADMIN: frozenset(
[
# Settings (Full access)
Permission.MANAGE_SECRETS,
Permission.MANAGE_MCP,
Permission.MANAGE_INTEGRATIONS,
Permission.MANAGE_APPLICATION_SETTINGS,
Permission.MANAGE_API_KEYS,
Permission.VIEW_LLM_SETTINGS,
Permission.EDIT_LLM_SETTINGS,
Permission.VIEW_BILLING,
Permission.ADD_CREDITS,
# Organization Members
Permission.INVITE_USER_TO_ORGANIZATION,
Permission.CHANGE_USER_ROLE_MEMBER,
Permission.CHANGE_USER_ROLE_ADMIN,
# Organization Management
Permission.VIEW_ORG_SETTINGS,
Permission.EDIT_ORG_SETTINGS,
]
),
RoleName.MEMBER: frozenset(
[
# Settings (Full access)
Permission.MANAGE_SECRETS,
Permission.MANAGE_MCP,
Permission.MANAGE_INTEGRATIONS,
Permission.MANAGE_APPLICATION_SETTINGS,
Permission.MANAGE_API_KEYS,
# Settings (View only)
Permission.VIEW_ORG_SETTINGS,
Permission.VIEW_LLM_SETTINGS,
]
),
}
def get_user_org_role(user_id: str, org_id: UUID | None) -> Role | None:
"""
Get the user's role in an organization (synchronous version).
Args:
user_id: User ID (string that will be converted to UUID)
org_id: Organization ID, or None to use the user's current organization
Returns:
Role object if user is a member, None otherwise
"""
from uuid import UUID as parse_uuid
if org_id is None:
org_member = OrgMemberStore.get_org_member_for_current_org(parse_uuid(user_id))
else:
org_member = OrgMemberStore.get_org_member(org_id, parse_uuid(user_id))
if not org_member:
return None
return RoleStore.get_role_by_id(org_member.role_id)
async def get_user_org_role_async(user_id: str, org_id: UUID | None) -> Role | None:
"""
Get the user's role in an organization (async version).
Args:
user_id: User ID (string that will be converted to UUID)
org_id: Organization ID, or None to use the user's current organization
Returns:
Role object if user is a member, None otherwise
"""
from uuid import UUID as parse_uuid
if org_id is None:
org_member = await OrgMemberStore.get_org_member_for_current_org_async(
parse_uuid(user_id)
)
else:
org_member = await OrgMemberStore.get_org_member_async(
org_id, parse_uuid(user_id)
)
if not org_member:
return None
return await RoleStore.get_role_by_id_async(org_member.role_id)
def get_role_permissions(role_name: str) -> frozenset[Permission]:
"""
Get the permissions for a role.
Args:
role_name: Name of the role
Returns:
Set of permissions for the role
"""
try:
role_enum = RoleName(role_name)
return ROLE_PERMISSIONS.get(role_enum, frozenset())
except ValueError:
return frozenset()
def has_permission(user_role: Role, permission: Permission) -> bool:
"""
Check if a role has a specific permission.
Args:
user_role: User's Role object
permission: Permission to check
Returns:
True if the role has the permission
"""
permissions = get_role_permissions(user_role.name)
return permission in permissions
def require_permission(permission: Permission):
"""
Factory function that creates a dependency to require a specific permission.
This creates a FastAPI dependency that:
1. Extracts org_id from the path parameter
2. Gets the authenticated user_id
3. Checks if the user has the required permission in the organization
4. Returns the user_id if authorized, raises HTTPException otherwise
Usage:
@router.get('/{org_id}/settings')
async def get_settings(
org_id: UUID,
user_id: str = Depends(require_permission(Permission.VIEW_LLM_SETTINGS)),
):
...
Args:
permission: The permission required to access the endpoint
Returns:
Dependency function that validates permission and returns user_id
"""
async def permission_checker(
org_id: UUID | None = None,
user_id: str | None = Depends(get_user_id),
) -> str:
if not user_id:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail='User not authenticated',
)
user_role = await get_user_org_role_async(user_id, org_id)
if not user_role:
logger.warning(
'User not a member of organization',
extra={'user_id': user_id, 'org_id': str(org_id)},
)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail='User is not a member of this organization',
)
if not has_permission(user_role, permission):
logger.warning(
'Insufficient permissions',
extra={
'user_id': user_id,
'org_id': str(org_id),
'user_role': user_role.name,
'required_permission': permission.value,
},
)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f'Requires {permission.value} permission',
)
return user_id
return permission_checker

View File

@@ -61,6 +61,8 @@ SUBSCRIPTION_PRICE_DATA = {
},
}
FREE_CREDIT_THRESHOLD = float(os.environ.get('FREE_CREDIT_THRESHOLD', '10'))
FREE_CREDIT_AMOUNT = float(os.environ.get('FREE_CREDIT_AMOUNT', '10'))
STRIPE_API_KEY = os.environ.get('STRIPE_API_KEY', None)
REQUIRE_PAYMENT = os.environ.get('REQUIRE_PAYMENT', '0') in ('1', 'true')

View File

@@ -51,14 +51,6 @@ def custom_json_serializer(obj, **kwargs):
obj['stack_info'] = format_stack(stack_info)
result = json.dumps(obj, **kwargs)
# Swap out newlines to make things easier to read. This will produce
# invalid json but means we can have similar logs in local development
# to production, making things easier to correlate. Obviously,
# LOG_JSON_FOR_CONSOLE should not be used in production environments.
if LOG_JSON_FOR_CONSOLE:
result = result.replace('\\n', '\n')
return result

View File

@@ -160,7 +160,6 @@ class SetAuthCookieMiddleware:
'/api/billing/customer-setup-success',
'/api/billing/stripe-webhook',
'/api/email/resend',
'/api/organizations/members/invite/accept',
'/oauth/device/authorize',
'/oauth/device/token',
'/api/v1/web-client/config',

View File

@@ -2,7 +2,6 @@ from datetime import UTC, datetime
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel, field_validator
from storage.api_key import ApiKey
from storage.api_key_store import ApiKeyStore
from storage.lite_llm_manager import LiteLlmManager
from storage.org_member import OrgMember
@@ -136,9 +135,9 @@ class ApiKeyCreate(BaseModel):
class ApiKeyResponse(BaseModel):
id: int
name: str | None = None
created_at: datetime
last_used_at: datetime | None = None
expires_at: datetime | None = None
created_at: str
last_used_at: str | None = None
expires_at: str | None = None
class ApiKeyCreateResponse(ApiKeyResponse):
@@ -153,29 +152,12 @@ class ByorPermittedResponse(BaseModel):
permitted: bool
class MessageResponse(BaseModel):
message: str
def api_key_to_response(key: ApiKey) -> ApiKeyResponse:
"""Convert an ApiKey model to an ApiKeyResponse."""
return ApiKeyResponse(
id=key.id,
name=key.name,
created_at=key.created_at,
last_used_at=key.last_used_at,
expires_at=key.expires_at,
)
@api_router.get('/llm/byor/permitted', tags=['Keys'])
async def check_byor_permitted(
user_id: str = Depends(get_user_id),
) -> ByorPermittedResponse:
@api_router.get('/llm/byor/permitted', response_model=ByorPermittedResponse)
async def check_byor_permitted(user_id: str = Depends(get_user_id)):
"""Check if BYOR key export is permitted for the user's current org."""
try:
permitted = await OrgService.check_byor_export_enabled(user_id)
return ByorPermittedResponse(permitted=permitted)
return {'permitted': permitted}
except Exception as e:
logger.exception(
'Error checking BYOR export permission', extra={'error': str(e)}
@@ -186,10 +168,8 @@ async def check_byor_permitted(
)
@api_router.post('', tags=['Keys'])
async def create_api_key(
key_data: ApiKeyCreate, user_id: str = Depends(get_user_id)
) -> ApiKeyCreateResponse:
@api_router.post('', response_model=ApiKeyCreateResponse)
async def create_api_key(key_data: ApiKeyCreate, user_id: str = Depends(get_user_id)):
"""Create a new API key for the authenticated user."""
try:
api_key = await api_key_store.create_api_key(
@@ -198,29 +178,48 @@ async def create_api_key(
# Get the created key details
keys = await api_key_store.list_api_keys(user_id)
for key in keys:
if key.name == key_data.name:
return ApiKeyCreateResponse(
id=key.id,
name=key.name,
key=api_key,
created_at=key.created_at,
last_used_at=key.last_used_at,
expires_at=key.expires_at,
)
if key['name'] == key_data.name:
return {
**key,
'key': api_key,
'created_at': (
key['created_at'].isoformat() if key['created_at'] else None
),
'last_used_at': (
key['last_used_at'].isoformat() if key['last_used_at'] else None
),
'expires_at': (
key['expires_at'].isoformat() if key['expires_at'] else None
),
}
except Exception:
logger.exception('Error creating API key')
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail='Failed to create API key',
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail='Failed to create API key',
)
@api_router.get('', tags=['Keys'])
async def list_api_keys(user_id: str = Depends(get_user_id)) -> list[ApiKeyResponse]:
@api_router.get('', response_model=list[ApiKeyResponse])
async def list_api_keys(user_id: str = Depends(get_user_id)):
"""List all API keys for the authenticated user."""
try:
keys = await api_key_store.list_api_keys(user_id)
return [api_key_to_response(key) for key in keys]
return [
{
**key,
'created_at': (
key['created_at'].isoformat() if key['created_at'] else None
),
'last_used_at': (
key['last_used_at'].isoformat() if key['last_used_at'] else None
),
'expires_at': (
key['expires_at'].isoformat() if key['expires_at'] else None
),
}
for key in keys
]
except Exception:
logger.exception('Error listing API keys')
raise HTTPException(
@@ -229,10 +228,8 @@ async def list_api_keys(user_id: str = Depends(get_user_id)) -> list[ApiKeyRespo
)
@api_router.delete('/{key_id}', tags=['Keys'])
async def delete_api_key(
key_id: int, user_id: str = Depends(get_user_id)
) -> MessageResponse:
@api_router.delete('/{key_id}')
async def delete_api_key(key_id: int, user_id: str = Depends(get_user_id)):
"""Delete an API key."""
try:
# First, verify the key belongs to the user
@@ -240,7 +237,7 @@ async def delete_api_key(
key_to_delete = None
for key in keys:
if key.id == key_id:
if key['id'] == key_id:
key_to_delete = key
break
@@ -258,7 +255,7 @@ async def delete_api_key(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail='Failed to delete API key',
)
return MessageResponse(message='API key deleted successfully')
return {'message': 'API key deleted successfully'}
except HTTPException:
raise
except Exception:
@@ -269,10 +266,8 @@ async def delete_api_key(
)
@api_router.get('/llm/byor', tags=['Keys'])
async def get_llm_api_key_for_byor(
user_id: str = Depends(get_user_id),
) -> LlmApiKeyResponse:
@api_router.get('/llm/byor', response_model=LlmApiKeyResponse)
async def get_llm_api_key_for_byor(user_id: str = Depends(get_user_id)):
"""Get the LLM API key for BYOR (Bring Your Own Runtime) for the authenticated user.
This endpoint validates that the key exists in LiteLLM before returning it.
@@ -295,7 +290,7 @@ async def get_llm_api_key_for_byor(
# Validate that the key is actually registered in LiteLLM
is_valid = await LiteLlmManager.verify_key(byor_key, user_id)
if is_valid:
return LlmApiKeyResponse(key=byor_key)
return {'key': byor_key}
else:
# Key exists in DB but is invalid in LiteLLM - regenerate it
logger.warning(
@@ -320,7 +315,7 @@ async def get_llm_api_key_for_byor(
'Successfully generated and stored new BYOR key',
extra={'user_id': user_id},
)
return LlmApiKeyResponse(key=key)
return {'key': key}
else:
logger.error(
'Failed to generate new BYOR LLM API key',
@@ -342,10 +337,8 @@ async def get_llm_api_key_for_byor(
)
@api_router.post('/llm/byor/refresh', tags=['Keys'])
async def refresh_llm_api_key_for_byor(
user_id: str = Depends(get_user_id),
) -> LlmApiKeyResponse:
@api_router.post('/llm/byor/refresh', response_model=LlmApiKeyResponse)
async def refresh_llm_api_key_for_byor(user_id: str = Depends(get_user_id)):
"""Refresh the LLM API key for BYOR (Bring Your Own Runtime) for the authenticated user.
Returns 402 Payment Required if BYOR export is not enabled for the user's org.
@@ -398,7 +391,7 @@ async def refresh_llm_api_key_for_byor(
'BYOR LLM API key refresh completed successfully',
extra={'user_id': user_id},
)
return LlmApiKeyResponse(key=key)
return {'key': key}
except HTTPException as he:
logger.error(
'HTTP exception during BYOR LLM API key refresh',

View File

@@ -5,7 +5,6 @@ import warnings
from datetime import datetime, timezone
from typing import Annotated, Literal, Optional
from urllib.parse import quote
from uuid import UUID as parse_uuid
import posthog
from fastapi import APIRouter, Header, HTTPException, Request, Response, status
@@ -27,13 +26,6 @@ from server.auth.token_manager import TokenManager
from server.config import sign_token
from server.constants import IS_FEATURE_ENV
from server.routes.event_webhook import _get_session_api_key, _get_user_id
from server.services.org_invitation_service import (
EmailMismatchError,
InvitationExpiredError,
InvitationInvalidError,
OrgInvitationService,
UserAlreadyMemberError,
)
from storage.database import session_maker
from storage.user import User
from storage.user_store import UserStore
@@ -112,40 +104,22 @@ def get_cookie_samesite(request: Request) -> Literal['lax', 'strict']:
)
def _extract_oauth_state(state: str | None) -> tuple[str, str | None, str | None]:
"""Extract redirect URL, reCAPTCHA token, and invitation token from OAuth state.
Returns:
Tuple of (redirect_url, recaptcha_token, invitation_token).
Tokens may be None.
"""
if not state:
return '', None, None
try:
# Try to decode as JSON (new format with reCAPTCHA and/or invitation)
state_data = json.loads(base64.urlsafe_b64decode(state.encode()).decode())
return (
state_data.get('redirect_url', ''),
state_data.get('recaptcha_token'),
state_data.get('invitation_token'),
)
except Exception:
# Old format - state is just the redirect URL
return state, None, None
# Keep alias for backward compatibility
def _extract_recaptcha_state(state: str | None) -> tuple[str, str | None]:
"""Extract redirect URL and reCAPTCHA token from OAuth state.
Deprecated: Use _extract_oauth_state instead.
Returns:
Tuple of (redirect_url, recaptcha_token). Token may be None.
"""
redirect_url, recaptcha_token, _ = _extract_oauth_state(state)
return redirect_url, recaptcha_token
if not state:
return '', None
try:
# Try to decode as JSON (new format with reCAPTCHA)
state_data = json.loads(base64.urlsafe_b64decode(state.encode()).decode())
return state_data.get('redirect_url', ''), state_data.get('recaptcha_token')
except Exception:
# Old format - state is just the redirect URL
return state, None
@oauth_router.get('/keycloak/callback')
@@ -156,8 +130,8 @@ async def keycloak_callback(
error: Optional[str] = None,
error_description: Optional[str] = None,
):
# Extract redirect URL, reCAPTCHA token, and invitation token from state
redirect_url, recaptcha_token, invitation_token = _extract_oauth_state(state)
# Extract redirect URL and reCAPTCHA token from state
redirect_url, recaptcha_token = _extract_recaptcha_state(state)
if not redirect_url:
redirect_url = str(request.base_url)
@@ -328,13 +302,8 @@ async def keycloak_callback(
from server.routes.email import verify_email
await verify_email(request=request, user_id=user_id, is_auth_flow=True)
verification_redirect_url = f'{request.base_url}login?email_verification_required=true&user_id={user_id}'
# Preserve invitation token so it can be included in OAuth state after verification
if invitation_token:
verification_redirect_url = (
f'{verification_redirect_url}&invitation_token={invitation_token}'
)
response = RedirectResponse(verification_redirect_url, status_code=302)
redirect_url = f'{request.base_url}login?email_verification_required=true&user_id={user_id}'
response = RedirectResponse(redirect_url, status_code=302)
return response
# default to github IDP for now.
@@ -412,90 +381,14 @@ async def keycloak_callback(
)
has_accepted_tos = user.accepted_tos is not None
# Process invitation token if present (after email verification but before TOS)
if invitation_token:
try:
logger.info(
'Processing invitation token during auth callback',
extra={
'user_id': user_id,
'invitation_token_prefix': invitation_token[:10] + '...',
},
)
await OrgInvitationService.accept_invitation(
invitation_token, parse_uuid(user_id)
)
logger.info(
'Invitation accepted during auth callback',
extra={'user_id': user_id},
)
except InvitationExpiredError:
logger.warning(
'Invitation expired during auth callback',
extra={'user_id': user_id},
)
# Add query param to redirect URL
if '?' in redirect_url:
redirect_url = f'{redirect_url}&invitation_expired=true'
else:
redirect_url = f'{redirect_url}?invitation_expired=true'
except InvitationInvalidError as e:
logger.warning(
'Invalid invitation during auth callback',
extra={'user_id': user_id, 'error': str(e)},
)
if '?' in redirect_url:
redirect_url = f'{redirect_url}&invitation_invalid=true'
else:
redirect_url = f'{redirect_url}?invitation_invalid=true'
except UserAlreadyMemberError:
logger.info(
'User already member during invitation acceptance',
extra={'user_id': user_id},
)
if '?' in redirect_url:
redirect_url = f'{redirect_url}&already_member=true'
else:
redirect_url = f'{redirect_url}?already_member=true'
except EmailMismatchError as e:
logger.warning(
'Email mismatch during auth callback invitation acceptance',
extra={'user_id': user_id, 'error': str(e)},
)
if '?' in redirect_url:
redirect_url = f'{redirect_url}&email_mismatch=true'
else:
redirect_url = f'{redirect_url}?email_mismatch=true'
except Exception as e:
logger.exception(
'Unexpected error processing invitation during auth callback',
extra={'user_id': user_id, 'error': str(e)},
)
# Don't fail the login if invitation processing fails
if '?' in redirect_url:
redirect_url = f'{redirect_url}&invitation_error=true'
else:
redirect_url = f'{redirect_url}?invitation_error=true'
# If the user hasn't accepted the TOS, redirect to the TOS page
if not has_accepted_tos:
encoded_redirect_url = quote(redirect_url, safe='')
tos_redirect_url = (
f'{request.base_url}accept-tos?redirect_url={encoded_redirect_url}'
)
if invitation_token:
tos_redirect_url = f'{tos_redirect_url}&invitation_success=true'
response = RedirectResponse(tos_redirect_url, status_code=302)
else:
if invitation_token:
redirect_url = f'{redirect_url}&invitation_success=true'
response = RedirectResponse(redirect_url, status_code=302)
set_response_cookie(

View File

@@ -9,7 +9,11 @@ from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi.responses import RedirectResponse
from integrations import stripe_service
from pydantic import BaseModel
from server.constants import STRIPE_API_KEY
from server.constants import (
FREE_CREDIT_AMOUNT,
FREE_CREDIT_THRESHOLD,
STRIPE_API_KEY,
)
from server.logger import logger
from starlette.datastructures import URL
from storage.billing_session import BillingSession
@@ -93,9 +97,9 @@ async def get_credits(user_id: str = Depends(get_user_id)) -> GetCreditsResponse
user_team_info = await LiteLlmManager.get_user_team_info(
user_id, str(user.current_org_id)
)
max_budget, spend = LiteLlmManager.get_budget_from_team_info(
user_team_info, user_id, str(user.current_org_id)
)
# Update to use calculate_credits
spend = user_team_info.get('spend', 0)
max_budget = (user_team_info.get('litellm_budget_table') or {}).get('max_budget', 0)
credits = max(max_budget - spend, 0)
return GetCreditsResponse(credits=Decimal('{:.2f}'.format(credits)))
@@ -147,7 +151,7 @@ async def create_customer_setup_session(
customer=customer_info['customer_id'],
mode='setup',
payment_method_types=['card'],
success_url=f'{base_url}?setup=success',
success_url=f'{base_url}?free_credits=success',
cancel_url=f'{base_url}',
)
return CreateBillingSessionResponse(redirect_url=checkout_session.url)
@@ -249,13 +253,31 @@ async def success_callback(session_id: str, request: Request):
)
amount_subtotal = stripe_session.amount_subtotal or 0
add_credits = amount_subtotal / 100
max_budget, _ = LiteLlmManager.get_budget_from_team_info(
user_team_info, billing_session.user_id, str(user.current_org_id)
max_budget = (user_team_info.get('litellm_budget_table') or {}).get(
'max_budget', 0
)
org = session.query(Org).filter(Org.id == user.current_org_id).first()
new_max_budget = max_budget + add_credits
# Grant free credits if:
# 1. The org has pending free credits (new org, eligible)
# 2. The budget after this purchase meets the threshold
should_grant_free_credits = (
org and org.pending_free_credits and new_max_budget >= FREE_CREDIT_THRESHOLD
)
if should_grant_free_credits:
new_max_budget += FREE_CREDIT_AMOUNT
org.pending_free_credits = False
logger.info(
'free_credits_granted',
extra={
'user_id': billing_session.user_id,
'org_id': str(user.current_org_id),
'free_credit_amount': FREE_CREDIT_AMOUNT,
},
)
await LiteLlmManager.update_team_and_users_budget(
str(user.current_org_id), new_max_budget
)
@@ -277,6 +299,7 @@ async def success_callback(session_id: str, request: Request):
'org_id': str(user.current_org_id),
'checkout_session_id': billing_session.id,
'stripe_customer_id': stripe_session.customer,
'free_credits_granted': should_grant_free_credits,
},
)
session.commit()

View File

@@ -4,7 +4,7 @@ import json
import os
import re
import uuid
from urllib.parse import urlencode, urlparse
from urllib.parse import urlparse
import requests
from fastapi import APIRouter, BackgroundTasks, Header, HTTPException, Request, status
@@ -371,7 +371,9 @@ async def create_jira_workspace(request: Request, workspace_data: JiraWorkspaceC
'prompt': 'consent',
}
auth_url = f'{JIRA_AUTH_URL}?{urlencode(auth_params)}'
auth_url = (
f"{JIRA_AUTH_URL}?{'&'.join([f'{k}={v}' for k, v in auth_params.items()])}"
)
return JSONResponse(
content={
@@ -430,7 +432,9 @@ async def create_workspace_link(request: Request, link_data: JiraLinkCreate):
'response_type': 'code',
'prompt': 'consent',
}
auth_url = f'{JIRA_AUTH_URL}?{urlencode(auth_params)}'
auth_url = (
f"{JIRA_AUTH_URL}?{'&'.join([f'{k}={v}' for k, v in auth_params.items()])}"
)
return JSONResponse(
content={

View File

@@ -2,7 +2,7 @@ import json
import os
import re
import uuid
from urllib.parse import urlencode, urlparse
from urllib.parse import urlparse
import requests
from fastapi import (
@@ -316,7 +316,7 @@ async def create_jira_dc_workspace(
'response_type': 'code',
}
auth_url = f'{JIRA_DC_AUTH_URL}?{urlencode(auth_params)}'
auth_url = f"{JIRA_DC_AUTH_URL}?{'&'.join([f'{k}={v}' for k, v in auth_params.items()])}"
return JSONResponse(
content={
@@ -436,7 +436,7 @@ async def create_workspace_link(request: Request, link_data: JiraDcLinkCreate):
'state': state,
'response_type': 'code',
}
auth_url = f'{JIRA_DC_AUTH_URL}?{urlencode(auth_params)}'
auth_url = f"{JIRA_DC_AUTH_URL}?{'&'.join([f'{k}={v}' for k, v in auth_params.items()])}"
return JSONResponse(
content={

View File

@@ -1,122 +0,0 @@
"""
Pydantic models and custom exceptions for organization invitations.
"""
from pydantic import BaseModel, EmailStr
from storage.org_invitation import OrgInvitation
from storage.role_store import RoleStore
class InvitationError(Exception):
"""Base exception for invitation errors."""
pass
class InvitationAlreadyExistsError(InvitationError):
"""Raised when a pending invitation already exists for the email."""
def __init__(
self, message: str = 'A pending invitation already exists for this email'
):
super().__init__(message)
class UserAlreadyMemberError(InvitationError):
"""Raised when the user is already a member of the organization."""
def __init__(self, message: str = 'User is already a member of this organization'):
super().__init__(message)
class InvitationExpiredError(InvitationError):
"""Raised when the invitation has expired."""
def __init__(self, message: str = 'Invitation has expired'):
super().__init__(message)
class InvitationInvalidError(InvitationError):
"""Raised when the invitation is invalid or revoked."""
def __init__(self, message: str = 'Invitation is no longer valid'):
super().__init__(message)
class InsufficientPermissionError(InvitationError):
"""Raised when the user lacks permission to perform the action."""
def __init__(self, message: str = 'Insufficient permission'):
super().__init__(message)
class EmailMismatchError(InvitationError):
"""Raised when the accepting user's email doesn't match the invitation email."""
def __init__(self, message: str = 'Your email does not match the invitation'):
super().__init__(message)
class InvitationCreate(BaseModel):
"""Request model for creating invitation(s)."""
emails: list[EmailStr]
role: str = 'member' # Default to member role
class InvitationResponse(BaseModel):
"""Response model for invitation details."""
id: int
email: str
role: str
status: str
created_at: str
expires_at: str
inviter_email: str | None = None
@classmethod
def from_invitation(
cls,
invitation: OrgInvitation,
inviter_email: str | None = None,
) -> 'InvitationResponse':
"""Create an InvitationResponse from an OrgInvitation entity.
Args:
invitation: The invitation entity to convert
inviter_email: Optional email of the inviter
Returns:
InvitationResponse: The response model instance
"""
role_name = ''
if invitation.role:
role_name = invitation.role.name
elif invitation.role_id:
role = RoleStore.get_role_by_id(invitation.role_id)
role_name = role.name if role else ''
return cls(
id=invitation.id,
email=invitation.email,
role=role_name,
status=invitation.status,
created_at=invitation.created_at.isoformat(),
expires_at=invitation.expires_at.isoformat(),
inviter_email=inviter_email,
)
class InvitationFailure(BaseModel):
"""Response model for a failed invitation."""
email: str
error: str
class BatchInvitationResponse(BaseModel):
"""Response model for batch invitation creation."""
successful: list[InvitationResponse]
failed: list[InvitationFailure]

View File

@@ -1,226 +0,0 @@
"""API routes for organization invitations."""
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi.responses import RedirectResponse
from server.routes.org_invitation_models import (
BatchInvitationResponse,
EmailMismatchError,
InsufficientPermissionError,
InvitationCreate,
InvitationExpiredError,
InvitationFailure,
InvitationInvalidError,
InvitationResponse,
UserAlreadyMemberError,
)
from server.services.org_invitation_service import OrgInvitationService
from server.utils.rate_limit_utils import check_rate_limit_by_user_id
from openhands.core.logger import openhands_logger as logger
from openhands.server.user_auth import get_user_id
from openhands.server.user_auth.user_auth import get_user_auth
# Router for invitation operations on an organization (requires org_id)
invitation_router = APIRouter(prefix='/api/organizations/{org_id}/members')
# Router for accepting invitations (no org_id required)
accept_router = APIRouter(prefix='/api/organizations/members/invite')
@invitation_router.post(
'/invite',
response_model=BatchInvitationResponse,
status_code=status.HTTP_201_CREATED,
)
async def create_invitation(
org_id: UUID,
invitation_data: InvitationCreate,
request: Request,
user_id: str = Depends(get_user_id),
):
"""Create organization invitations for multiple email addresses.
Sends emails to invitees with secure links to join the organization.
Supports batch invitations - some may succeed while others fail.
Permission rules:
- Only owners and admins can create invitations
- Admins can only invite with 'member' or 'admin' role (not 'owner')
- Owners can invite with any role
Args:
org_id: Organization UUID
invitation_data: Invitation details (emails array, role)
request: FastAPI request
user_id: Authenticated user ID (from dependency)
Returns:
BatchInvitationResponse: Lists of successful and failed invitations
Raises:
HTTPException 400: Invalid role or organization not found
HTTPException 403: User lacks permission to invite
HTTPException 429: Rate limit exceeded
"""
# Rate limit: 10 invitations per minute per user (6 seconds between requests)
await check_rate_limit_by_user_id(
request=request,
key_prefix='org_invitation_create',
user_id=user_id,
user_rate_limit_seconds=6,
)
try:
successful, failed = await OrgInvitationService.create_invitations_batch(
org_id=org_id,
emails=[str(email) for email in invitation_data.emails],
role_name=invitation_data.role,
inviter_id=UUID(user_id),
)
logger.info(
'Batch organization invitations created',
extra={
'org_id': str(org_id),
'total_emails': len(invitation_data.emails),
'successful': len(successful),
'failed': len(failed),
'inviter_id': user_id,
},
)
return BatchInvitationResponse(
successful=[InvitationResponse.from_invitation(inv) for inv in successful],
failed=[
InvitationFailure(email=email, error=error) for email, error in failed
],
)
except InsufficientPermissionError as e:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=str(e),
)
except ValueError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e),
)
except Exception as e:
logger.exception(
'Unexpected error creating batch invitations',
extra={'org_id': str(org_id), 'error': str(e)},
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail='An unexpected error occurred',
)
@accept_router.get('/accept')
async def accept_invitation(
token: str,
request: Request,
):
"""Accept an organization invitation via token.
This endpoint is accessed via the link in the invitation email.
Flow:
1. If user is authenticated: Accept invitation directly and redirect to home
2. If user is not authenticated: Redirect to login page with invitation token
- Frontend stores token and includes it in OAuth state during login
- After authentication, keycloak_callback processes the invitation
Args:
token: The invitation token from the email link
request: FastAPI request
Returns:
RedirectResponse: Redirect to home page on success, or login page if not authenticated,
or home page with error query params on failure
"""
base_url = str(request.base_url).rstrip('/')
# Try to get user_id from auth (may not be authenticated)
user_id = None
try:
user_auth = await get_user_auth(request)
if user_auth:
user_id = await user_auth.get_user_id()
except Exception:
pass
if not user_id:
# User not authenticated - redirect to login page with invitation token
# Frontend will store the token and include it in OAuth state during login
logger.info(
'Invitation accept: redirecting unauthenticated user to login',
extra={'token_prefix': token[:10] + '...'},
)
login_url = f'{base_url}/login?invitation_token={token}'
return RedirectResponse(login_url, status_code=302)
# User is authenticated - process the invitation directly
try:
await OrgInvitationService.accept_invitation(token, UUID(user_id))
logger.info(
'Invitation accepted successfully',
extra={
'token_prefix': token[:10] + '...',
'user_id': user_id,
},
)
# Redirect to home page on success
return RedirectResponse(f'{base_url}/', status_code=302)
except InvitationExpiredError:
logger.warning(
'Invitation accept failed: expired',
extra={'token_prefix': token[:10] + '...', 'user_id': user_id},
)
return RedirectResponse(f'{base_url}/?invitation_expired=true', status_code=302)
except InvitationInvalidError as e:
logger.warning(
'Invitation accept failed: invalid',
extra={
'token_prefix': token[:10] + '...',
'user_id': user_id,
'error': str(e),
},
)
return RedirectResponse(f'{base_url}/?invitation_invalid=true', status_code=302)
except UserAlreadyMemberError:
logger.info(
'Invitation accept: user already member',
extra={'token_prefix': token[:10] + '...', 'user_id': user_id},
)
return RedirectResponse(f'{base_url}/?already_member=true', status_code=302)
except EmailMismatchError as e:
logger.warning(
'Invitation accept failed: email mismatch',
extra={
'token_prefix': token[:10] + '...',
'user_id': user_id,
'error': str(e),
},
)
return RedirectResponse(f'{base_url}/?email_mismatch=true', status_code=302)
except Exception as e:
logger.exception(
'Unexpected error accepting invitation',
extra={
'token_prefix': token[:10] + '...',
'user_id': user_id,
'error': str(e),
},
)
return RedirectResponse(f'{base_url}/?invitation_error=true', status_code=302)

View File

@@ -214,7 +214,6 @@ class OrgPage(BaseModel):
items: list[OrgResponse]
next_page_id: str | None = None
current_org_id: str | None = None
class OrgUpdate(BaseModel):
@@ -258,7 +257,7 @@ class OrgMemberResponse(BaseModel):
user_id: str
email: str | None
role_id: int
role: str
role_name: str
role_rank: int
status: str | None

View File

@@ -2,10 +2,6 @@ from typing import Annotated
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException, Query, status
from server.auth.authorization import (
Permission,
require_permission,
)
from server.email_validation import get_admin_user_id
from server.routes.org_models import (
CannotModifySelfError,
@@ -32,7 +28,6 @@ from server.routes.org_models import (
)
from server.services.org_member_service import OrgMemberService
from storage.org_service import OrgService
from storage.user_store import UserStore
from openhands.core.logger import openhands_logger as logger
from openhands.server.user_auth import get_user_id
@@ -79,12 +74,6 @@ async def list_user_orgs(
)
try:
# Fetch user to get current_org_id
user = await UserStore.get_user_by_id_async(user_id)
current_org_id = (
str(user.current_org_id) if user and user.current_org_id else None
)
# Fetch organizations from service layer
orgs, next_page_id = OrgService.get_user_orgs_paginated(
user_id=user_id,
@@ -106,11 +95,7 @@ async def list_user_orgs(
},
)
return OrgPage(
items=org_responses,
next_page_id=next_page_id,
current_org_id=current_org_id,
)
return OrgPage(items=org_responses, next_page_id=next_page_id)
except Exception as e:
logger.exception(
@@ -204,26 +189,23 @@ async def create_org(
@org_router.get('/{org_id}', response_model=OrgResponse, status_code=status.HTTP_200_OK)
async def get_org(
org_id: UUID,
user_id: str = Depends(require_permission(Permission.VIEW_ORG_SETTINGS)),
user_id: str = Depends(get_user_id),
) -> OrgResponse:
"""Get organization details by ID.
This endpoint retrieves details for a specific organization. Access requires
the VIEW_ORG_SETTINGS permission, which is granted to all organization members
(member, admin, and owner roles).
This endpoint allows authenticated users who are members of an organization
to retrieve its details. Only members of the organization can access this endpoint.
Args:
org_id: Organization ID (UUID)
user_id: Authenticated user ID (injected by require_permission dependency)
user_id: Authenticated user ID (injected by dependency)
Returns:
OrgResponse: The organization details
Raises:
HTTPException: 401 if user is not authenticated
HTTPException: 403 if user lacks VIEW_ORG_SETTINGS permission
HTTPException: 404 if organization not found
HTTPException: 422 if org_id is not a valid UUID (handled by FastAPI)
HTTPException: 404 if organization not found or user is not a member
HTTPException: 500 if retrieval fails
"""
logger.info(
@@ -323,24 +305,23 @@ async def get_me(
@org_router.delete('/{org_id}', status_code=status.HTTP_200_OK)
async def delete_org(
org_id: UUID,
user_id: str = Depends(require_permission(Permission.DELETE_ORGANIZATION)),
user_id: str = Depends(get_user_id),
) -> dict:
"""Delete an organization.
This endpoint permanently deletes an organization and all associated data including
organization members, conversations, billing data, and external LiteLLM team resources.
Access requires the DELETE_ORGANIZATION permission, which is granted only to owners.
This endpoint allows authenticated organization owners to delete their organization.
All associated data including organization members, conversations, billing data,
and external LiteLLM team resources will be permanently removed.
Args:
org_id: Organization ID to delete (UUID)
user_id: Authenticated user ID (injected by require_permission dependency)
org_id: Organization ID to delete
user_id: Authenticated user ID (injected by dependency)
Returns:
dict: Confirmation message with deleted organization details
Raises:
HTTPException: 401 if user is not authenticated
HTTPException: 403 if user lacks DELETE_ORGANIZATION permission
HTTPException: 403 if user is not the organization owner
HTTPException: 404 if organization not found
HTTPException: 500 if deletion fails
"""
@@ -433,26 +414,25 @@ async def delete_org(
async def update_org(
org_id: UUID,
update_data: OrgUpdate,
user_id: str = Depends(require_permission(Permission.EDIT_ORG_SETTINGS)),
user_id: str = Depends(get_user_id),
) -> OrgResponse:
"""Update an existing organization.
This endpoint updates organization settings. Access requires the EDIT_ORG_SETTINGS
permission, which is granted to admin and owner roles.
This endpoint allows authenticated users to update organization settings.
LLM-related settings require admin or owner role in the organization.
Args:
org_id: Organization ID to update (UUID)
org_id: Organization ID to update (UUID validated by FastAPI)
update_data: Organization update data
user_id: Authenticated user ID (injected by require_permission dependency)
user_id: Authenticated user ID (injected by dependency)
Returns:
OrgResponse: The updated organization details
Raises:
HTTPException: 401 if user is not authenticated
HTTPException: 403 if user lacks EDIT_ORG_SETTINGS permission
HTTPException: 400 if org_id is invalid UUID format (handled by FastAPI)
HTTPException: 403 if user lacks permission for LLM settings
HTTPException: 404 if organization not found
HTTPException: 409 if organization name already exists
HTTPException: 422 if validation errors occur (handled by FastAPI)
HTTPException: 500 if update fails
"""
@@ -516,7 +496,7 @@ async def update_org(
@org_router.get('/{org_id}/members')
async def get_org_members(
org_id: UUID,
org_id: str,
page_id: Annotated[
str | None,
Query(title='Optional next_page_id from the previously returned page'),
@@ -529,33 +509,13 @@ async def get_org_members(
lte=100,
),
] = 100,
user_id: str = Depends(require_permission(Permission.VIEW_ORG_SETTINGS)),
current_user_id: str = Depends(get_user_id),
) -> OrgMemberPage:
"""Get all members of an organization with cursor-based pagination.
This endpoint retrieves a paginated list of organization members. Access requires
the VIEW_ORG_SETTINGS permission, which is granted to all organization members
(member, admin, and owner roles).
Args:
org_id: Organization ID (UUID)
page_id: Optional page ID (offset) for pagination
limit: Maximum number of members to return (1-100, default 100)
user_id: Authenticated user ID (injected by require_permission dependency)
Returns:
OrgMemberPage: Paginated list of organization members
Raises:
HTTPException: 401 if user is not authenticated
HTTPException: 403 if user lacks VIEW_ORG_SETTINGS permission
HTTPException: 400 if org_id or page_id format is invalid
HTTPException: 500 if retrieval fails
"""
"""Get all members of an organization with cursor-based pagination."""
try:
success, error_code, data = await OrgMemberService.get_org_members(
org_id=org_id,
current_user_id=UUID(user_id),
org_id=UUID(org_id),
current_user_id=UUID(current_user_id),
page_id=page_id,
limit=limit,
)
@@ -602,7 +562,7 @@ async def get_org_members(
@org_router.delete('/{org_id}/members/{user_id}')
async def remove_org_member(
org_id: UUID,
org_id: str,
user_id: str,
current_user_id: str = Depends(get_user_id),
):
@@ -616,7 +576,7 @@ async def remove_org_member(
"""
try:
success, error = await OrgMemberService.remove_org_member(
org_id=org_id,
org_id=UUID(org_id),
target_user_id=UUID(user_id),
current_user_id=UUID(current_user_id),
)
@@ -748,7 +708,7 @@ async def switch_org(
@org_router.patch('/{org_id}/members/{user_id}', response_model=OrgMemberResponse)
async def update_org_member(
org_id: UUID,
org_id: str,
user_id: str,
update_data: OrgMemberUpdate,
current_user_id: str = Depends(get_user_id),
@@ -765,7 +725,7 @@ async def update_org_member(
"""
try:
return await OrgMemberService.update_org_member(
org_id=org_id,
org_id=UUID(org_id),
target_user_id=UUID(user_id),
current_user_id=UUID(current_user_id),
update_data=update_data,

View File

@@ -1,131 +0,0 @@
"""Email service for sending transactional emails via Resend."""
import os
try:
import resend
RESEND_AVAILABLE = True
except ImportError:
RESEND_AVAILABLE = False
from openhands.core.logger import openhands_logger as logger
DEFAULT_FROM_EMAIL = 'OpenHands <no-reply@openhands.dev>'
DEFAULT_WEB_HOST = 'https://app.all-hands.dev'
class EmailService:
"""Service for sending transactional emails."""
@staticmethod
def _get_resend_client() -> bool:
"""Initialize and return the Resend client.
Returns:
bool: True if client is ready, False otherwise
"""
if not RESEND_AVAILABLE:
logger.warning('Resend library not installed, skipping email')
return False
resend_api_key = os.environ.get('RESEND_API_KEY')
if not resend_api_key:
logger.warning('RESEND_API_KEY not configured, skipping email')
return False
resend.api_key = resend_api_key
return True
@staticmethod
def send_invitation_email(
to_email: str,
org_name: str,
inviter_name: str,
role_name: str,
invitation_token: str,
invitation_id: int,
) -> None:
"""Send an organization invitation email.
Args:
to_email: Recipient's email address
org_name: Name of the organization
inviter_name: Display name of the person who sent the invite
role_name: Role being offered (e.g., 'member', 'admin')
invitation_token: The secure invitation token
invitation_id: The invitation ID for logging
"""
if not EmailService._get_resend_client():
return
# Build invitation URL
web_host = os.environ.get('WEB_HOST', DEFAULT_WEB_HOST)
invitation_url = f'{web_host}/api/organizations/members/invite/accept?token={invitation_token}'
from_email = os.environ.get('RESEND_FROM_EMAIL', DEFAULT_FROM_EMAIL)
params = {
'from': from_email,
'to': [to_email],
'subject': f"You're invited to join {org_name} on OpenHands",
'html': f"""
<div style="font-family: Arial, sans-serif; max-width: 600px; margin: 0 auto;">
<p>Hi,</p>
<p><strong>{inviter_name}</strong> has invited you to join <strong>{org_name}</strong> on OpenHands as a <strong>{role_name}</strong>.</p>
<p>Click the button below to accept the invitation:</p>
<p style="margin: 30px 0;">
<a href="{invitation_url}"
style="background-color: #c9b974; color: #0D0F11; padding: 8px 16px;
text-decoration: none; border-radius: 8px; display: inline-block;
font-size: 14px; font-weight: 600;">
Accept Invitation
</a>
</p>
<p style="color: #666; font-size: 14px;">
Or copy and paste this link into your browser:<br>
<a href="{invitation_url}" style="color: #c9b974; font-weight: 600;">{invitation_url}</a>
</p>
<p style="color: #666; font-size: 14px;">
This invitation will expire in 7 days.
</p>
<p style="color: #666; font-size: 14px;">
If you weren't expecting this invitation, you can safely ignore this email.
</p>
<hr style="border: none; border-top: 1px solid #eee; margin: 30px 0;">
<p style="color: #999; font-size: 12px;">
Best,<br>
The OpenHands Team
</p>
</div>
""",
}
try:
response = resend.Emails.send(params)
logger.info(
'Invitation email sent',
extra={
'invitation_id': invitation_id,
'email': to_email,
'response_id': response.get('id') if response else None,
},
)
except Exception as e:
logger.error(
'Failed to send invitation email',
extra={
'invitation_id': invitation_id,
'email': to_email,
'error': str(e),
},
)
raise

View File

@@ -1,397 +0,0 @@
"""Service for managing organization invitations."""
import asyncio
from uuid import UUID
from server.auth.token_manager import TokenManager
from server.constants import ROLE_ADMIN, ROLE_OWNER
from server.routes.org_invitation_models import (
EmailMismatchError,
InsufficientPermissionError,
InvitationExpiredError,
InvitationInvalidError,
UserAlreadyMemberError,
)
from server.services.email_service import EmailService
from storage.org_invitation import OrgInvitation
from storage.org_invitation_store import OrgInvitationStore
from storage.org_member_store import OrgMemberStore
from storage.org_service import OrgService
from storage.org_store import OrgStore
from storage.role_store import RoleStore
from storage.user_store import UserStore
from openhands.core.logger import openhands_logger as logger
class OrgInvitationService:
"""Service for organization invitation operations."""
@staticmethod
async def create_invitation(
org_id: UUID,
email: str,
role_name: str,
inviter_id: UUID,
) -> OrgInvitation:
"""Create a new organization invitation.
This method:
1. Validates the organization exists
2. Validates this is not a personal workspace
3. Checks inviter has owner/admin role
4. Validates role assignment permissions
5. Checks if user is already a member
6. Creates the invitation
7. Sends the invitation email
Args:
org_id: Organization UUID
email: Invitee's email address
role_name: Role to assign on acceptance (owner, admin, member)
inviter_id: User ID of the person creating the invitation
Returns:
OrgInvitation: The created invitation
Raises:
ValueError: If organization or role not found
InsufficientPermissionError: If inviter lacks permission
UserAlreadyMemberError: If email is already a member
InvitationAlreadyExistsError: If pending invitation exists
"""
email = email.lower().strip()
logger.info(
'Creating organization invitation',
extra={
'org_id': str(org_id),
'email': email,
'role_name': role_name,
'inviter_id': str(inviter_id),
},
)
# Step 1: Validate organization exists
org = OrgStore.get_org_by_id(org_id)
if not org:
raise ValueError(f'Organization {org_id} not found')
# Step 2: Check this is not a personal workspace
# A personal workspace has org_id matching the user's id
if str(org_id) == str(inviter_id):
raise InsufficientPermissionError(
'Cannot invite users to a personal workspace'
)
# Step 3: Check inviter is a member and has permission
inviter_member = OrgMemberStore.get_org_member(org_id, inviter_id)
if not inviter_member:
raise InsufficientPermissionError(
'You are not a member of this organization'
)
inviter_role = RoleStore.get_role_by_id(inviter_member.role_id)
if not inviter_role or inviter_role.name not in [ROLE_OWNER, ROLE_ADMIN]:
raise InsufficientPermissionError('Only owners and admins can invite users')
# Step 4: Validate role assignment permissions
role_name_lower = role_name.lower()
if role_name_lower == ROLE_OWNER and inviter_role.name != ROLE_OWNER:
raise InsufficientPermissionError('Only owners can invite with owner role')
# Get the target role
target_role = RoleStore.get_role_by_name(role_name_lower)
if not target_role:
raise ValueError(f'Invalid role: {role_name}')
# Step 5: Check if user is already a member (by email)
existing_user = await UserStore.get_user_by_email_async(email)
if existing_user:
existing_member = OrgMemberStore.get_org_member(org_id, existing_user.id)
if existing_member:
raise UserAlreadyMemberError(
'User is already a member of this organization'
)
# Step 6: Create the invitation
invitation = await OrgInvitationStore.create_invitation(
org_id=org_id,
email=email,
role_id=target_role.id,
inviter_id=inviter_id,
)
# Step 7: Send invitation email
try:
# Get inviter info for the email
inviter_user = UserStore.get_user_by_id(str(inviter_member.user_id))
inviter_name = 'A team member'
if inviter_user and inviter_user.email:
inviter_name = inviter_user.email.split('@')[0]
EmailService.send_invitation_email(
to_email=email,
org_name=org.name,
inviter_name=inviter_name,
role_name=target_role.name,
invitation_token=invitation.token,
invitation_id=invitation.id,
)
except Exception as e:
logger.error(
'Failed to send invitation email',
extra={
'invitation_id': invitation.id,
'email': email,
'error': str(e),
},
)
# Don't fail the invitation creation if email fails
# The user can still access via direct link
return invitation
@staticmethod
async def create_invitations_batch(
org_id: UUID,
emails: list[str],
role_name: str,
inviter_id: UUID,
) -> tuple[list[OrgInvitation], list[tuple[str, str]]]:
"""Create multiple organization invitations concurrently.
Validates permissions once upfront, then creates invitations in parallel.
Args:
org_id: Organization UUID
emails: List of invitee email addresses
role_name: Role to assign on acceptance (owner, admin, member)
inviter_id: User ID of the person creating the invitations
Returns:
Tuple of (successful_invitations, failed_emails_with_errors)
Raises:
ValueError: If organization or role not found
InsufficientPermissionError: If inviter lacks permission
"""
logger.info(
'Creating batch organization invitations',
extra={
'org_id': str(org_id),
'email_count': len(emails),
'role_name': role_name,
'inviter_id': str(inviter_id),
},
)
# Step 1: Validate permissions upfront (shared for all emails)
org = OrgStore.get_org_by_id(org_id)
if not org:
raise ValueError(f'Organization {org_id} not found')
if str(org_id) == str(inviter_id):
raise InsufficientPermissionError(
'Cannot invite users to a personal workspace'
)
inviter_member = OrgMemberStore.get_org_member(org_id, inviter_id)
if not inviter_member:
raise InsufficientPermissionError(
'You are not a member of this organization'
)
inviter_role = RoleStore.get_role_by_id(inviter_member.role_id)
if not inviter_role or inviter_role.name not in [ROLE_OWNER, ROLE_ADMIN]:
raise InsufficientPermissionError('Only owners and admins can invite users')
role_name_lower = role_name.lower()
if role_name_lower == ROLE_OWNER and inviter_role.name != ROLE_OWNER:
raise InsufficientPermissionError('Only owners can invite with owner role')
target_role = RoleStore.get_role_by_name(role_name_lower)
if not target_role:
raise ValueError(f'Invalid role: {role_name}')
# Step 2: Create invitations concurrently
async def create_single(
email: str,
) -> tuple[str, OrgInvitation | None, str | None]:
"""Create single invitation, return (email, invitation, error)."""
try:
invitation = await OrgInvitationService.create_invitation(
org_id=org_id,
email=email,
role_name=role_name,
inviter_id=inviter_id,
)
return (email, invitation, None)
except (UserAlreadyMemberError, ValueError) as e:
return (email, None, str(e))
results = await asyncio.gather(*[create_single(email) for email in emails])
# Step 3: Separate successes and failures
successful: list[OrgInvitation] = []
failed: list[tuple[str, str]] = []
for email, invitation, error in results:
if invitation:
successful.append(invitation)
elif error:
failed.append((email, error))
logger.info(
'Batch invitation creation completed',
extra={
'org_id': str(org_id),
'successful': len(successful),
'failed': len(failed),
},
)
return successful, failed
@staticmethod
async def accept_invitation(token: str, user_id: UUID) -> OrgInvitation:
"""Accept an organization invitation.
This method:
1. Validates the token and invitation status
2. Checks expiration
3. Verifies user is not already a member
4. Creates LiteLLM integration
5. Adds user to the organization
6. Marks invitation as accepted
Args:
token: The invitation token
user_id: The user accepting the invitation
Returns:
OrgInvitation: The accepted invitation
Raises:
InvitationInvalidError: If token is invalid or invitation not pending
InvitationExpiredError: If invitation has expired
UserAlreadyMemberError: If user is already a member
"""
logger.info(
'Accepting organization invitation',
extra={
'token_prefix': token[:10] + '...' if len(token) > 10 else token,
'user_id': str(user_id),
},
)
# Step 1: Get and validate invitation
invitation = await OrgInvitationStore.get_invitation_by_token(token)
if not invitation:
raise InvitationInvalidError('Invalid invitation token')
if invitation.status != OrgInvitation.STATUS_PENDING:
if invitation.status == OrgInvitation.STATUS_ACCEPTED:
raise InvitationInvalidError('Invitation has already been accepted')
elif invitation.status == OrgInvitation.STATUS_REVOKED:
raise InvitationInvalidError('Invitation has been revoked')
else:
raise InvitationInvalidError('Invitation is no longer valid')
# Step 2: Check expiration
if OrgInvitationStore.is_token_expired(invitation):
await OrgInvitationStore.update_invitation_status(
invitation.id, OrgInvitation.STATUS_EXPIRED
)
raise InvitationExpiredError('Invitation has expired')
# Step 2.5: Verify user email matches invitation email
user = await UserStore.get_user_by_id_async(str(user_id))
if not user:
raise InvitationInvalidError('User not found')
user_email = user.email
# Fallback: fetch email from Keycloak if not in database (for existing users)
if not user_email:
token_manager = TokenManager()
user_info = await token_manager.get_user_info_from_user_id(str(user_id))
user_email = user_info.get('email') if user_info else None
if not user_email:
raise EmailMismatchError('Your account does not have an email address')
user_email = user_email.lower().strip()
invitation_email = invitation.email.lower().strip()
if user_email != invitation_email:
logger.warning(
'Email mismatch during invitation acceptance',
extra={
'user_id': str(user_id),
'user_email': user_email,
'invitation_email': invitation_email,
'invitation_id': invitation.id,
},
)
raise EmailMismatchError()
# Step 3: Check if user is already a member
existing_member = OrgMemberStore.get_org_member(invitation.org_id, user_id)
if existing_member:
raise UserAlreadyMemberError(
'You are already a member of this organization'
)
# Step 4: Create LiteLLM integration for the user in the new org
try:
settings = await OrgService.create_litellm_integration(
invitation.org_id, str(user_id)
)
except Exception as e:
logger.error(
'Failed to create LiteLLM integration for invitation acceptance',
extra={
'invitation_id': invitation.id,
'user_id': str(user_id),
'org_id': str(invitation.org_id),
'error': str(e),
},
)
raise InvitationInvalidError(
'Failed to set up organization access. Please try again.'
)
# Step 5: Add user to organization
from storage.org_member_store import OrgMemberStore as OMS
org_member_kwargs = OMS.get_kwargs_from_settings(settings)
# Don't override with org defaults - use invitation-specified role
org_member_kwargs.pop('llm_model', None)
org_member_kwargs.pop('llm_base_url', None)
OrgMemberStore.add_user_to_org(
org_id=invitation.org_id,
user_id=user_id,
role_id=invitation.role_id,
llm_api_key=settings.llm_api_key,
status='active',
)
# Step 6: Mark invitation as accepted
updated_invitation = await OrgInvitationStore.update_invitation_status(
invitation.id,
OrgInvitation.STATUS_ACCEPTED,
accepted_by_user_id=user_id,
)
logger.info(
'Organization invitation accepted',
extra={
'invitation_id': invitation.id,
'user_id': str(user_id),
'org_id': str(invitation.org_id),
'role_id': invitation.role_id,
},
)
return updated_invitation

View File

@@ -104,7 +104,7 @@ class OrgMemberService:
user_id=str(member.user_id),
email=user.email if user else None,
role_id=member.role_id,
role=role.name if role else '',
role_name=role.name if role else '',
role_rank=role.rank if role else 0,
status=member.status,
)
@@ -240,7 +240,7 @@ class OrgMemberService:
user_id=str(target_membership.user_id),
email=user.email if user else None,
role_id=target_membership.role_id,
role=target_role.name,
role_name=target_role.name,
role_rank=target_role.rank,
status=target_membership.status,
)
@@ -280,7 +280,7 @@ class OrgMemberService:
user_id=str(updated_member.user_id),
email=user.email if user else None,
role_id=updated_member.role_id,
role=new_role.name,
role_name=new_role.name,
role_rank=new_role.rank,
status=updated_member.status,
)

View File

@@ -22,63 +22,11 @@ from openhands.app_server.app_conversation.app_conversation_models import (
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
SQLAppConversationInfoService,
)
from openhands.app_server.errors import AuthError
from openhands.app_server.services.injector import InjectorState
class SaasSQLAppConversationInfoService(SQLAppConversationInfoService):
"""Extended SQLAppConversationInfoService with user and organization-based filtering and SAAS metadata handling."""
async def _get_current_user(self) -> User | None:
"""Get the current user using the existing db_session.
Uses self.db_session to avoid opening a separate database session.
Returns:
User object or None if no user_id is available
"""
user_id_str = await self.user_context.get_user_id()
if not user_id_str:
return None
user_id_uuid = UUID(user_id_str)
result = await self.db_session.execute(
select(User).where(User.id == user_id_uuid)
)
return result.scalars().first()
async def _apply_user_and_org_filter(self, query):
"""Apply user_id and org_id filters to ensure conversation isolation.
Filters conversations by:
- user_id: Only show conversations belonging to the current user
- org_id: Only show conversations belonging to the user's current organization
Args:
query: SQLAlchemy query to apply filters to
Returns:
Query with user and organization filters applied
Raises:
AuthError: If no user_id is available (secure default: deny access)
"""
user_id_str = await self.user_context.get_user_id()
if not user_id_str:
# Secure default: no user means no access, not "show everything"
raise AuthError('User authentication required')
user_id_uuid = UUID(user_id_str)
query = query.where(StoredConversationMetadataSaas.user_id == user_id_uuid)
# Filter by organization ID to ensure conversations are isolated per organization
user = await self._get_current_user()
if user and user.current_org_id is not None:
query = query.where(
StoredConversationMetadataSaas.org_id == user.current_org_id
)
return query
"""Extended SQLAppConversationInfoService with user-based filtering and SAAS metadata handling."""
async def _secure_select(self):
query = (
@@ -90,7 +38,13 @@ class SaasSQLAppConversationInfoService(SQLAppConversationInfoService):
)
.where(StoredConversationMetadata.conversation_version == 'V1')
)
return await self._apply_user_and_org_filter(query)
user_id_str = await self.user_context.get_user_id()
if user_id_str:
user_id_uuid = UUID(user_id_str)
query = query.where(StoredConversationMetadataSaas.user_id == user_id_uuid)
return query
async def _secure_select_with_saas_metadata(self):
"""Select query that includes SAAS metadata for retrieving user_id."""
@@ -103,7 +57,13 @@ class SaasSQLAppConversationInfoService(SQLAppConversationInfoService):
)
.where(StoredConversationMetadata.conversation_version == 'V1')
)
return await self._apply_user_and_org_filter(query)
user_id_str = await self.user_context.get_user_id()
if user_id_str:
user_id_uuid = UUID(user_id_str)
query = query.where(StoredConversationMetadataSaas.user_id == user_id_uuid)
return query
async def search_app_conversation_info(
self,
@@ -195,16 +155,21 @@ class SaasSQLAppConversationInfoService(SQLAppConversationInfoService):
"""Count conversations matching the given filters with SAAS metadata."""
query = (
select(func.count(StoredConversationMetadata.conversation_id))
.join(
StoredConversationMetadataSaas,
StoredConversationMetadata.conversation_id
== StoredConversationMetadataSaas.conversation_id,
.select_from(
StoredConversationMetadata.join(
StoredConversationMetadataSaas,
StoredConversationMetadata.conversation_id
== StoredConversationMetadataSaas.conversation_id,
)
)
.where(StoredConversationMetadata.conversation_version == 'V1')
)
# Apply user and organization filtering
query = await self._apply_user_and_org_filter(query)
# Apply user filtering
user_id_str = await self.user_context.get_user_id()
if user_id_str:
user_id_uuid = UUID(user_id_str)
query = query.where(StoredConversationMetadataSaas.user_id == user_id_uuid)
query = self._apply_filters_with_saas_metadata(
query=query,

View File

@@ -20,10 +20,8 @@ from storage.linear_workspace import LinearWorkspace
from storage.maintenance_task import MaintenanceTask, MaintenanceTaskStatus
from storage.openhands_pr import OpenhandsPR
from storage.org import Org
from storage.org_invitation import OrgInvitation
from storage.org_member import OrgMember
from storage.proactive_convos import ProactiveConversation
from storage.resend_synced_user import ResendSyncedUser
from storage.role import Role
from storage.slack_conversation import SlackConversation
from storage.slack_team import SlackTeam
@@ -67,10 +65,8 @@ __all__ = [
'MaintenanceTaskStatus',
'OpenhandsPR',
'Org',
'OrgInvitation',
'OrgMember',
'ProactiveConversation',
'ResendSyncedUser',
'Role',
'SlackConversation',
'SlackTeam',

View File

@@ -126,7 +126,7 @@ class ApiKeyStore:
return True
async def list_api_keys(self, user_id: str) -> list[ApiKey]:
async def list_api_keys(self, user_id: str) -> list[dict]:
"""List all API keys for a user."""
user = await UserStore.get_user_by_id_async(user_id)
org_id = user.current_org_id
@@ -134,14 +134,24 @@ class ApiKeyStore:
def _list_api_keys_from_db(self, user_id: str, org_id: str) -> list[ApiKey]:
with self.session_maker() as session:
keys: list[ApiKey] = (
keys = (
session.query(ApiKey)
.filter(ApiKey.user_id == user_id)
.filter(ApiKey.org_id == org_id)
.all()
)
return [key for key in keys if key.name != 'MCP_API_KEY']
return [
{
'id': key.id,
'name': key.name,
'created_at': key.created_at,
'last_used_at': key.last_used_at,
'expires_at': key.expires_at,
}
for key in keys
if 'MCP_API_KEY' != key.name
]
async def retrieve_mcp_api_key(self, user_id: str) -> str | None:
user = await UserStore.get_user_by_id_async(user_id)

View File

@@ -18,17 +18,17 @@ def _get_db_session_injector():
return _config.db_session
def session_maker(**kwargs):
def session_maker():
db_session_injector = _get_db_session_injector()
factory = db_session_injector.get_session_maker()
return factory(**kwargs)
session_maker = db_session_injector.get_session_maker()
return session_maker()
@contextlib.asynccontextmanager
async def a_session_maker(**kwargs):
async def a_session_maker():
db_session_injector = _get_db_session_injector()
factory = await db_session_injector.get_async_session_maker()
async with factory(**kwargs) as session:
a_session_maker = await db_session_injector.get_async_session_maker()
async with a_session_maker() as session:
yield session

View File

@@ -43,34 +43,6 @@ def get_byor_key_alias(keycloak_user_id: str, org_id: str) -> str:
class LiteLlmManager:
"""Manage LiteLLM interactions."""
@staticmethod
def get_budget_from_team_info(
user_team_info: dict | None, user_id: str, org_id: str
) -> tuple[float, float]:
"""Extract max_budget and spend from user team info.
For personal orgs (user_id == org_id), uses litellm_budget_table.max_budget.
For team orgs, uses max_budget_in_team (populated by get_user_team_info).
Args:
user_team_info: The response from get_user_team_info
user_id: The user's ID
org_id: The organization's ID
Returns:
Tuple of (max_budget, spend)
"""
if not user_team_info:
return 0, 0
spend = user_team_info.get('spend', 0)
if user_id == org_id:
max_budget = (user_team_info.get('litellm_budget_table') or {}).get(
'max_budget', 0
)
else:
max_budget = user_team_info.get('max_budget_in_team') or 0
return max_budget, spend
@staticmethod
async def create_entries(
org_id: str,
@@ -99,34 +71,8 @@ class LiteLlmManager:
'x-goog-api-key': LITE_LLM_API_KEY,
}
) as client:
# Check if team already exists and get its budget
# New users joining existing orgs should inherit the team's budget
team_budget = 0.0
try:
existing_team = await LiteLlmManager._get_team(client, org_id)
if existing_team:
team_info = existing_team.get('team_info', {})
team_budget = team_info.get('max_budget', 0.0) or 0.0
logger.info(
'LiteLlmManager:create_entries:existing_team_budget',
extra={
'org_id': org_id,
'user_id': keycloak_user_id,
'team_budget': team_budget,
},
)
except httpx.HTTPStatusError as e:
# Team doesn't exist yet (404) - this is expected for first user
if e.response.status_code != 404:
raise
logger.info(
'LiteLlmManager:create_entries:no_existing_team',
extra={'org_id': org_id, 'user_id': keycloak_user_id},
)
await LiteLlmManager._create_team(
client, keycloak_user_id, org_id, team_budget
)
# New users start with $0 budget - they must purchase credits
await LiteLlmManager._create_team(client, keycloak_user_id, org_id, 0)
if create_user:
await LiteLlmManager._create_user(
@@ -134,7 +80,7 @@ class LiteLlmManager:
)
await LiteLlmManager._add_user_to_team(
client, keycloak_user_id, org_id, team_budget
client, keycloak_user_id, org_id, 0
)
key = await LiteLlmManager._generate_key(
@@ -946,31 +892,21 @@ class LiteLlmManager:
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
logger.warning('LiteLLM API configuration not found')
return None
team_response = await LiteLlmManager._get_team(client, team_id)
if not team_response:
team_info = await LiteLlmManager._get_team(client, team_id)
if not team_info:
return None
# Filter team_memberships based on team_id and keycloak_user_id
user_membership = next(
(
membership
for membership in team_response.get('team_memberships', [])
for membership in team_info.get('team_memberships', [])
if membership.get('user_id') == keycloak_user_id
and membership.get('team_id') == team_id
),
None,
)
if not user_membership:
return None
# For team orgs (user_id != team_id), include team-level budget info
# The team's max_budget and spend are shared across all members
if keycloak_user_id != team_id:
team_info = team_response.get('team_info', {})
user_membership['max_budget_in_team'] = team_info.get('max_budget')
user_membership['spend'] = team_info.get('spend', 0)
return user_membership
@staticmethod

View File

@@ -47,11 +47,11 @@ class Org(Base): # type: ignore
conversation_expiration = Column(Integer, nullable=True)
condenser_max_size = Column(Integer, nullable=True)
byor_export_enabled = Column(Boolean, nullable=False, default=False)
pending_free_credits = Column(Boolean, nullable=False, default=False)
# Relationships
org_members = relationship('OrgMember', back_populates='org')
current_users = relationship('User', back_populates='current_org')
invitations = relationship('OrgInvitation', back_populates='org')
billing_sessions = relationship('BillingSession', back_populates='org')
stored_conversation_metadata_saas = relationship(
'StoredConversationMetadataSaas', back_populates='org'

View File

@@ -1,59 +0,0 @@
"""
SQLAlchemy model for Organization Invitation.
"""
from sqlalchemy import UUID, Column, DateTime, ForeignKey, Integer, String, text
from sqlalchemy.orm import relationship
from storage.base import Base
class OrgInvitation(Base): # type: ignore
"""Organization invitation model.
Represents an invitation for a user to join an organization.
Invitations are created by organization owners/admins and contain
a secure token that can be used to accept the invitation.
"""
__tablename__ = 'org_invitation'
id = Column(Integer, primary_key=True, autoincrement=True)
token = Column(String(64), nullable=False, unique=True, index=True)
org_id = Column(
UUID(as_uuid=True),
ForeignKey('org.id', ondelete='CASCADE'),
nullable=False,
index=True,
)
email = Column(String(255), nullable=False, index=True)
role_id = Column(Integer, ForeignKey('role.id'), nullable=False)
inviter_id = Column(UUID(as_uuid=True), ForeignKey('user.id'), nullable=False)
status = Column(
String(20),
nullable=False,
server_default=text("'pending'"),
)
created_at = Column(
DateTime,
nullable=False,
server_default=text('CURRENT_TIMESTAMP'),
)
expires_at = Column(DateTime, nullable=False)
accepted_at = Column(DateTime, nullable=True)
accepted_by_user_id = Column(
UUID(as_uuid=True),
ForeignKey('user.id'),
nullable=True,
)
# Relationships
org = relationship('Org', back_populates='invitations')
role = relationship('Role')
inviter = relationship('User', foreign_keys=[inviter_id])
accepted_by_user = relationship('User', foreign_keys=[accepted_by_user_id])
# Status constants
STATUS_PENDING = 'pending'
STATUS_ACCEPTED = 'accepted'
STATUS_REVOKED = 'revoked'
STATUS_EXPIRED = 'expired'

View File

@@ -1,227 +0,0 @@
"""
Store class for managing organization invitations.
"""
import secrets
import string
from datetime import datetime, timedelta
from typing import Optional
from uuid import UUID
from sqlalchemy import and_, select
from sqlalchemy.orm import joinedload
from storage.database import a_session_maker
from storage.org_invitation import OrgInvitation
from openhands.core.logger import openhands_logger as logger
# Invitation token configuration
INVITATION_TOKEN_PREFIX = 'inv-'
INVITATION_TOKEN_LENGTH = 48 # Total length will be 52 with prefix
DEFAULT_EXPIRATION_DAYS = 7
class OrgInvitationStore:
"""Store for managing organization invitations."""
@staticmethod
def generate_token(length: int = INVITATION_TOKEN_LENGTH) -> str:
"""Generate a secure invitation token.
Uses cryptographically secure random generation for tokens.
Pattern from api_key_store.py.
Args:
length: Length of the random part of the token
Returns:
str: Token with prefix (e.g., 'inv-aBcDeF123...')
"""
alphabet = string.ascii_letters + string.digits
random_part = ''.join(secrets.choice(alphabet) for _ in range(length))
return f'{INVITATION_TOKEN_PREFIX}{random_part}'
@staticmethod
async def create_invitation(
org_id: UUID,
email: str,
role_id: int,
inviter_id: UUID,
expiration_days: int = DEFAULT_EXPIRATION_DAYS,
) -> OrgInvitation:
"""Create a new organization invitation.
Args:
org_id: Organization UUID
email: Invitee's email address
role_id: Role ID to assign on acceptance
inviter_id: User ID of the person creating the invitation
expiration_days: Days until the invitation expires
Returns:
OrgInvitation: The created invitation record
"""
async with a_session_maker() as session:
token = OrgInvitationStore.generate_token()
# Use timezone-naive datetime for database compatibility
expires_at = datetime.utcnow() + timedelta(days=expiration_days)
invitation = OrgInvitation(
token=token,
org_id=org_id,
email=email.lower().strip(),
role_id=role_id,
inviter_id=inviter_id,
status=OrgInvitation.STATUS_PENDING,
expires_at=expires_at,
)
session.add(invitation)
await session.commit()
# Re-fetch with eagerly loaded relationships to avoid DetachedInstanceError
result = await session.execute(
select(OrgInvitation)
.options(joinedload(OrgInvitation.role))
.filter(OrgInvitation.id == invitation.id)
)
invitation = result.scalars().first()
logger.info(
'Created organization invitation',
extra={
'invitation_id': invitation.id,
'org_id': str(org_id),
'email': email,
'inviter_id': str(inviter_id),
'expires_at': expires_at.isoformat(),
},
)
return invitation
@staticmethod
async def get_invitation_by_token(token: str) -> Optional[OrgInvitation]:
"""Get an invitation by its token.
Args:
token: The invitation token
Returns:
OrgInvitation or None if not found
"""
async with a_session_maker() as session:
result = await session.execute(
select(OrgInvitation)
.options(joinedload(OrgInvitation.org), joinedload(OrgInvitation.role))
.filter(OrgInvitation.token == token)
)
return result.scalars().first()
@staticmethod
async def get_pending_invitation(
org_id: UUID, email: str
) -> Optional[OrgInvitation]:
"""Get a pending invitation for an email in an organization.
Args:
org_id: Organization UUID
email: Email address to check
Returns:
OrgInvitation or None if no pending invitation exists
"""
async with a_session_maker() as session:
result = await session.execute(
select(OrgInvitation).filter(
and_(
OrgInvitation.org_id == org_id,
OrgInvitation.email == email.lower().strip(),
OrgInvitation.status == OrgInvitation.STATUS_PENDING,
)
)
)
return result.scalars().first()
@staticmethod
async def update_invitation_status(
invitation_id: int,
status: str,
accepted_by_user_id: Optional[UUID] = None,
) -> Optional[OrgInvitation]:
"""Update an invitation's status.
Args:
invitation_id: The invitation ID
status: New status (pending, accepted, revoked, expired)
accepted_by_user_id: User ID who accepted (only for 'accepted' status)
Returns:
Updated OrgInvitation or None if not found
"""
async with a_session_maker() as session:
result = await session.execute(
select(OrgInvitation).filter(OrgInvitation.id == invitation_id)
)
invitation = result.scalars().first()
if not invitation:
return None
old_status = invitation.status
invitation.status = status
if status == OrgInvitation.STATUS_ACCEPTED and accepted_by_user_id:
# Use timezone-naive datetime for database compatibility
invitation.accepted_at = datetime.utcnow()
invitation.accepted_by_user_id = accepted_by_user_id
await session.commit()
await session.refresh(invitation)
logger.info(
'Updated invitation status',
extra={
'invitation_id': invitation_id,
'old_status': old_status,
'new_status': status,
'accepted_by_user_id': (
str(accepted_by_user_id) if accepted_by_user_id else None
),
},
)
return invitation
@staticmethod
def is_token_expired(invitation: OrgInvitation) -> bool:
"""Check if an invitation token has expired.
Args:
invitation: The invitation to check
Returns:
bool: True if expired, False otherwise
"""
# Use timezone-naive datetime for comparison (database stores without timezone)
now = datetime.utcnow()
return invitation.expires_at < now
@staticmethod
async def mark_expired_if_needed(invitation: OrgInvitation) -> bool:
"""Check if invitation is expired and update status if needed.
Args:
invitation: The invitation to check
Returns:
bool: True if invitation was marked as expired, False otherwise
"""
if (
invitation.status == OrgInvitation.STATUS_PENDING
and OrgInvitationStore.is_token_expired(invitation)
):
await OrgInvitationStore.update_invitation_status(
invitation.id, OrgInvitation.STATUS_EXPIRED
)
return True
return False

View File

@@ -9,7 +9,6 @@ from sqlalchemy import select
from sqlalchemy.orm import joinedload
from storage.database import a_session_maker, session_maker
from storage.org_member import OrgMember
from storage.user import User
from storage.user_settings import UserSettings
from openhands.storage.data_models.settings import Settings
@@ -61,51 +60,6 @@ class OrgMemberStore:
)
return result.scalars().first()
@staticmethod
def get_org_member_for_current_org(user_id: UUID) -> Optional[OrgMember]:
"""Get the org member for a user's current organization.
Args:
user_id: The user's UUID.
Returns:
The OrgMember for the user's current organization, or None if not found.
"""
with session_maker() as session:
result = (
session.query(OrgMember)
.join(User, User.id == OrgMember.user_id)
.filter(
User.id == user_id,
OrgMember.org_id == User.current_org_id,
)
.first()
)
return result
@staticmethod
async def get_org_member_for_current_org_async(
user_id: UUID,
) -> Optional[OrgMember]:
"""Get the org member for a user's current organization (async version).
Args:
user_id: The user's UUID.
Returns:
The OrgMember for the user's current organization, or None if not found.
"""
async with a_session_maker() as session:
result = await session.execute(
select(OrgMember)
.join(User, User.id == OrgMember.user_id)
.filter(
User.id == user_id,
OrgMember.org_id == User.current_org_id,
)
)
return result.scalars().first()
@staticmethod
def get_user_orgs(user_id: UUID) -> list[OrgMember]:
"""Get all organizations for a user."""

View File

@@ -112,6 +112,7 @@ class OrgService:
contact_email=contact_email,
org_version=ORG_SETTINGS_VERSION,
default_llm_model=get_default_litellm_model(),
pending_free_credits=True,
)
@staticmethod
@@ -656,9 +657,10 @@ class OrgService:
)
return None
max_budget, spend = LiteLlmManager.get_budget_from_team_info(
user_team_info, user_id, str(org_id)
max_budget = (user_team_info.get('litellm_budget_table') or {}).get(
'max_budget', 0
)
spend = user_team_info.get('spend', 0)
credits = max(max_budget - spend, 0)
logger.debug(

View File

@@ -1,35 +0,0 @@
"""SQLAlchemy model for tracking users synced to Resend audiences."""
from datetime import UTC, datetime
from uuid import uuid4
from sqlalchemy import Column, DateTime, String, UniqueConstraint
from sqlalchemy.dialects.postgresql import UUID
from storage.base import Base
class ResendSyncedUser(Base): # type: ignore
"""Tracks users that have been synced to a Resend audience.
This table ensures that once a user is synced to a Resend audience,
they won't be re-added even if they are later deleted from the
Resend UI. This respects manual deletions/unsubscribes.
"""
__tablename__ = 'resend_synced_users'
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid4)
email = Column(String, nullable=False, index=True)
audience_id = Column(String, nullable=False, index=True)
synced_at = Column(
DateTime(timezone=True),
default=lambda: datetime.now(UTC),
nullable=False,
)
keycloak_user_id = Column(String, nullable=True)
__table_args__ = (
UniqueConstraint(
'email', 'audience_id', name='uq_resend_synced_email_audience'
),
)

View File

@@ -1,125 +0,0 @@
"""Store class for managing Resend synced users."""
from dataclasses import dataclass
from datetime import UTC, datetime
from typing import Optional, Set
from sqlalchemy import delete, select
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.orm import sessionmaker
from storage.resend_synced_user import ResendSyncedUser
@dataclass
class ResendSyncedUserStore:
"""Store for tracking users synced to Resend audiences."""
session_maker: sessionmaker
def is_user_synced(self, email: str, audience_id: str) -> bool:
"""Check if a user has been synced to a specific audience.
Args:
email: The email address to check.
audience_id: The Resend audience ID.
Returns:
True if the user has been synced, False otherwise.
"""
with self.session_maker() as session:
stmt = select(ResendSyncedUser).where(
ResendSyncedUser.email == email.lower(),
ResendSyncedUser.audience_id == audience_id,
)
result = session.execute(stmt).first()
return result is not None
def get_synced_emails_for_audience(self, audience_id: str) -> Set[str]:
"""Get all synced email addresses for a specific audience.
Args:
audience_id: The Resend audience ID.
Returns:
A set of lowercase email addresses that have been synced.
"""
with self.session_maker() as session:
stmt = select(ResendSyncedUser.email).where(
ResendSyncedUser.audience_id == audience_id,
)
result = session.execute(stmt).scalars().all()
return set(result)
def mark_user_synced(
self,
email: str,
audience_id: str,
keycloak_user_id: Optional[str] = None,
) -> ResendSyncedUser:
"""Mark a user as synced to a specific audience.
Uses upsert to handle race conditions - if the user is already
marked as synced, this is a no-op.
Args:
email: The email address of the user.
audience_id: The Resend audience ID.
keycloak_user_id: Optional Keycloak user ID.
Returns:
The ResendSyncedUser record.
Raises:
RuntimeError: If the record could not be created or retrieved.
"""
with self.session_maker() as session:
stmt = (
insert(ResendSyncedUser)
.values(
email=email.lower(),
audience_id=audience_id,
keycloak_user_id=keycloak_user_id,
synced_at=datetime.now(UTC),
)
.on_conflict_do_nothing(constraint='uq_resend_synced_email_audience')
.returning(ResendSyncedUser)
)
result = session.execute(stmt)
session.commit()
row = result.first()
if row:
return row[0]
# on_conflict_do_nothing triggered, fetch the existing record
existing = session.execute(
select(ResendSyncedUser).where(
ResendSyncedUser.email == email.lower(),
ResendSyncedUser.audience_id == audience_id,
)
).first()
if existing:
return existing[0]
raise RuntimeError(
f'Failed to create or retrieve synced user record for {email}'
)
def remove_synced_user(self, email: str, audience_id: str) -> bool:
"""Remove a user's synced status for a specific audience.
Args:
email: The email address of the user.
audience_id: The Resend audience ID.
Returns:
True if a record was deleted, False if no record existed.
"""
with self.session_maker() as session:
stmt = delete(ResendSyncedUser).where(
ResendSyncedUser.email == email.lower(),
ResendSyncedUser.audience_id == audience_id,
)
result = session.execute(stmt)
session.commit()
return result.rowcount > 0

View File

@@ -29,20 +29,6 @@ class RoleStore:
with session_maker() as session:
return session.query(Role).filter(Role.id == role_id).first()
@staticmethod
async def get_role_by_id_async(
role_id: int,
session: Optional[AsyncSession] = None,
) -> Optional[Role]:
"""Get role by ID (async version)."""
if session is not None:
result = await session.execute(select(Role).where(Role.id == role_id))
return result.scalars().first()
async with a_session_maker() as session:
result = await session.execute(select(Role).where(Role.id == role_id))
return result.scalars().first()
@staticmethod
def get_role_by_name(name: str) -> Optional[Role]:
"""Get role by name."""

View File

@@ -59,6 +59,7 @@ class UserStore:
or user_info.get('preferred_username', ''),
contact_email=user_info['email'],
v1_enabled=True,
pending_free_credits=True,
)
session.add(org)
@@ -83,8 +84,6 @@ class UserStore:
role_id=role_id,
**user_kwargs,
)
user.email = user_info.get('email')
user.email_verified = user_info.get('email_verified')
session.add(user)
role = RoleStore.get_role_by_name('owner')
@@ -197,6 +196,7 @@ class UserStore:
or user_info.get('username', ''),
contact_email=user_info['email'],
byor_export_enabled=has_completed_billing,
pending_free_credits=not has_completed_billing,
)
session.add(org)
@@ -770,30 +770,6 @@ class UserStore:
finally:
await UserStore._release_user_creation_lock(user_id)
@staticmethod
async def get_user_by_email_async(email: str) -> Optional[User]:
"""Get user by email address (async version).
This method looks up a user by their email address. Note that email
addresses may not be unique across all users in rare cases.
Args:
email: The email address to search for
Returns:
User: The user with the matching email, or None if not found
"""
if not email:
return None
async with a_session_maker() as session:
result = await session.execute(
select(User)
.options(joinedload(User.org_members))
.filter(User.email == email.lower().strip())
)
return result.scalars().first()
@staticmethod
def list_users() -> list[User]:
"""List all users."""

View File

@@ -35,7 +35,6 @@ import resend
from keycloak.exceptions import KeycloakError
from resend.exceptions import ResendError
from server.auth.token_manager import get_keycloak_admin
from storage.resend_synced_user_store import ResendSyncedUserStore
from tenacity import (
retry,
retry_if_exception_type,
@@ -70,6 +69,9 @@ RATE_LIMIT = float(os.environ.get('RATE_LIMIT', '2')) # Requests per second
# Set up Resend API
resend.api_key = RESEND_API_KEY
print('resend module', resend)
print('has contacts', hasattr(resend, 'Contacts'))
class ResendSyncError(Exception):
"""Base exception for Resend sync errors."""
@@ -96,7 +98,7 @@ class ResendAPIError(ResendSyncError):
EMAIL_REGEX = re.compile(r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$')
def is_valid_email(email: Optional[str]) -> bool:
def is_valid_email(email: str) -> bool:
"""Validate an email address format.
This uses a regex pattern that matches most valid email addresses
@@ -104,10 +106,10 @@ def is_valid_email(email: Optional[str]) -> bool:
does not accept (e.g., exclamation marks).
Args:
email: The email address to validate, or None.
email: The email address to validate.
Returns:
True if the email is valid, False otherwise (including for None).
True if the email is valid, False otherwise.
"""
if not email:
return False
@@ -197,6 +199,8 @@ def get_resend_contacts(audience_id: str) -> Dict[str, Dict[str, Any]]:
Raises:
ResendAPIError: If the API call fails.
"""
print('getting resend contacts')
print('has resend contacts', hasattr(resend, 'Contacts'))
try:
contacts = resend.Contacts.list(audience_id).get('data', [])
# Create a dictionary mapping email addresses to contact data for
@@ -251,15 +255,6 @@ def add_contact_to_resend(
raise
@retry(
stop=stop_after_attempt(MAX_RETRIES),
wait=wait_exponential(
multiplier=INITIAL_BACKOFF_SECONDS,
max=MAX_BACKOFF_SECONDS,
exp_base=BACKOFF_FACTOR,
),
retry=retry_if_exception_type(ResendError),
)
def send_welcome_email(
email: str,
first_name: Optional[str] = None,
@@ -276,7 +271,7 @@ def send_welcome_email(
The API response.
Raises:
ResendError: If the API call fails after retries.
ResendError: If the API call fails.
"""
try:
# Prepare the recipient name
@@ -322,84 +317,8 @@ def send_welcome_email(
raise
def _get_resend_synced_user_store() -> ResendSyncedUserStore:
"""Get the ResendSyncedUserStore instance.
This is separated into a function to allow for easier testing/mocking.
"""
from openhands.app_server.config import get_global_config
config = get_global_config()
db_session_injector = config.db_session
return ResendSyncedUserStore(session_maker=db_session_injector.get_session_maker())
def _backfill_existing_resend_contacts(
synced_user_store: ResendSyncedUserStore,
audience_id: str,
) -> int:
"""Backfill the synced_users table with contacts already in Resend.
This ensures that users who were added to Resend before the tracking
table existed are properly recorded, preventing duplicate welcome emails.
Args:
synced_user_store: The store for tracking synced users.
audience_id: The Resend audience ID.
Returns:
The number of contacts backfilled.
"""
logger.info('Starting backfill of existing Resend contacts...')
try:
resend_contacts = get_resend_contacts(audience_id)
logger.info(f'Found {len(resend_contacts)} contacts in Resend audience')
already_synced_emails = synced_user_store.get_synced_emails_for_audience(
audience_id
)
logger.info(
f'Found {len(already_synced_emails)} already synced emails in database'
)
backfilled_count = 0
for email in resend_contacts:
if email.lower() not in already_synced_emails:
synced_user_store.mark_user_synced(
email=email,
audience_id=audience_id,
keycloak_user_id=None, # We don't have this info during backfill
)
backfilled_count += 1
logger.debug(f'Backfilled existing Resend contact: {email}')
logger.info(
f'Backfill completed: {backfilled_count} contacts added to tracking'
)
return backfilled_count
except Exception:
logger.exception('Error during backfill of existing Resend contacts')
# Don't fail the entire sync if backfill fails - just log and continue
return 0
def sync_users_to_resend():
"""Sync users from Keycloak to Resend.
This function syncs users from Keycloak to a Resend audience. It tracks
which users have been synced in the database to ensure that:
1. Users are only added once (even across multiple sync runs)
2. Users who are manually deleted from Resend are not re-added
The tracking is done via the resend_synced_users table, which records
each email/audience_id combination that has been synced.
On first run (or when new contacts exist in Resend), it will backfill
the tracking table with existing Resend contacts to avoid sending
duplicate welcome emails.
"""
"""Sync users from Keycloak to Resend."""
# Check required environment variables
required_vars = {
'RESEND_API_KEY': RESEND_API_KEY,
@@ -425,36 +344,28 @@ def sync_users_to_resend():
)
try:
# Get the store for tracking synced users
synced_user_store = _get_resend_synced_user_store()
# Backfill existing Resend contacts into our tracking table
# This ensures users already in Resend don't get duplicate welcome emails
backfilled_count = _backfill_existing_resend_contacts(
synced_user_store, RESEND_AUDIENCE_ID
)
# Get the total number of users
total_users = get_total_keycloak_users()
logger.info(
f'Found {total_users} users in Keycloak realm {KEYCLOAK_REALM_NAME}'
)
# Get contacts from Resend
resend_contacts = get_resend_contacts(RESEND_AUDIENCE_ID)
logger.info(
f'Found {len(resend_contacts)} contacts in Resend audience '
f'{RESEND_AUDIENCE_ID}'
)
# Stats
stats = {
'total_users': total_users,
'backfilled_contacts': backfilled_count,
'already_synced': 0,
'existing_contacts': len(resend_contacts),
'added_contacts': 0,
'skipped_invalid_emails': 0,
'errors': 0,
}
synced_emails = synced_user_store.get_synced_emails_for_audience(
RESEND_AUDIENCE_ID
)
logger.info(f'Found {len(synced_emails)} already synced emails in database')
# Process users in batches
offset = 0
while offset < total_users:
@@ -467,12 +378,8 @@ def sync_users_to_resend():
continue
email = email.lower()
if email in synced_emails:
logger.debug(
f'User {email} was already synced to this audience, skipping'
)
stats['already_synced'] += 1
if email in resend_contacts:
logger.debug(f'User {email} already exists in Resend, skipping')
continue
# Validate email format before attempting to add to Resend
@@ -481,51 +388,35 @@ def sync_users_to_resend():
stats['skipped_invalid_emails'] += 1
continue
first_name = user.get('first_name')
last_name = user.get('last_name')
keycloak_user_id = user.get('id')
# Mark as synced first (optimistic) to ensure consistency.
# If Resend API fails, we remove the record.
try:
synced_user_store.mark_user_synced(
email=email,
audience_id=RESEND_AUDIENCE_ID,
keycloak_user_id=keycloak_user_id,
)
except Exception:
logger.exception(f'Failed to mark user {email} as synced')
stats['errors'] += 1
continue
first_name = user.get('first_name')
last_name = user.get('last_name')
try:
# Add the contact to the Resend audience
add_contact_to_resend(
RESEND_AUDIENCE_ID, email, first_name, last_name
)
logger.info(f'Added user {email} to Resend')
stats['added_contacts'] += 1
# Sleep to respect rate limit after first API call
time.sleep(1 / RATE_LIMIT)
# Send a welcome email to the newly added contact
try:
send_welcome_email(email, first_name, last_name)
logger.info(f'Sent welcome email to {email}')
except Exception:
logger.exception(
f'Failed to send welcome email to {email}, but contact was added to audience'
)
# Continue with the sync process even if sending the welcome email fails
# Sleep to respect rate limit after second API call
time.sleep(1 / RATE_LIMIT)
except Exception:
logger.exception(f'Error adding user {email} to Resend')
synced_user_store.remove_synced_user(email, RESEND_AUDIENCE_ID)
stats['errors'] += 1
continue
synced_emails.add(email)
stats['added_contacts'] += 1
# Sleep to respect rate limit after first API call
time.sleep(1 / RATE_LIMIT)
# Send a welcome email to the newly added contact
try:
send_welcome_email(email, first_name, last_name)
logger.info(f'Sent welcome email to {email}')
except Exception:
logger.exception(
f'Failed to send welcome email to {email}, but contact was added to audience'
)
# Sleep to respect rate limit after second API call
time.sleep(1 / RATE_LIMIT)
offset += BATCH_SIZE

View File

@@ -6,8 +6,6 @@ import httpx
import pytest
from fastapi import HTTPException
from server.routes.api_keys import (
ByorPermittedResponse,
LlmApiKeyResponse,
check_byor_permitted,
delete_byor_key_from_litellm,
get_llm_api_key_for_byor,
@@ -205,7 +203,7 @@ class TestGetLlmApiKeyForByor:
result = await get_llm_api_key_for_byor(user_id=user_id)
# Assert
assert result == LlmApiKeyResponse(key=new_key)
assert result == {'key': new_key}
mock_check_enabled.assert_called_once_with(user_id)
mock_get_key.assert_called_once_with(user_id)
mock_generate_key.assert_called_once_with(user_id)
@@ -230,7 +228,7 @@ class TestGetLlmApiKeyForByor:
result = await get_llm_api_key_for_byor(user_id=user_id)
# Assert
assert result == LlmApiKeyResponse(key=existing_key)
assert result == {'key': existing_key}
mock_check_enabled.assert_called_once_with(user_id)
mock_get_key.assert_called_once_with(user_id)
mock_verify_key.assert_called_once_with(existing_key, user_id)
@@ -267,7 +265,7 @@ class TestGetLlmApiKeyForByor:
result = await get_llm_api_key_for_byor(user_id=user_id)
# Assert
assert result == LlmApiKeyResponse(key=new_key)
assert result == {'key': new_key}
mock_check_enabled.assert_called_once_with(user_id)
mock_get_key.assert_called_once_with(user_id)
mock_verify_key.assert_called_once_with(invalid_key, user_id)
@@ -307,7 +305,7 @@ class TestGetLlmApiKeyForByor:
result = await get_llm_api_key_for_byor(user_id=user_id)
# Assert
assert result == LlmApiKeyResponse(key=new_key)
assert result == {'key': new_key}
mock_check_enabled.assert_called_once_with(user_id)
mock_delete_key.assert_called_once_with(user_id, invalid_key)
mock_generate_key.assert_called_once_with(user_id)
@@ -480,7 +478,7 @@ class TestCheckByorPermitted:
result = await check_byor_permitted(user_id=user_id)
# Assert
assert result == ByorPermittedResponse(permitted=True)
assert result == {'permitted': True}
mock_check_enabled.assert_called_once_with(user_id)
@pytest.mark.asyncio
@@ -495,7 +493,7 @@ class TestCheckByorPermitted:
result = await check_byor_permitted(user_id=user_id)
# Assert
assert result == ByorPermittedResponse(permitted=False)
assert result == {'permitted': False}
mock_check_enabled.assert_called_once_with(user_id)
@pytest.mark.asyncio

View File

@@ -1220,60 +1220,3 @@ async def test_validate_workspace_update_permissions_no_current_link(mock_manage
result = await _validate_workspace_update_permissions('user1', 'test-workspace')
assert result == mock_workspace
# Tests for OAuth URL encoding
class TestJiraDcOAuthUrlEncoding:
"""Tests to verify OAuth authorization URLs are properly URL-encoded."""
@pytest.mark.asyncio
@patch('server.routes.integration.jira_dc.get_user_auth')
@patch('server.routes.integration.jira_dc.redis_client')
@patch('server.routes.integration.jira_dc.JIRA_DC_ENABLE_OAUTH', True)
async def test_create_jira_dc_workspace_url_encoding(
self, mock_redis, mock_get_auth, mock_request, mock_user_auth
):
"""Test that create_jira_dc_workspace properly URL-encodes the authorization URL."""
mock_get_auth.return_value = mock_user_auth
mock_redis.setex.return_value = True
workspace_data = JiraDcWorkspaceCreate(
workspace_name='test-workspace',
webhook_secret='secret',
svc_acc_email='svc@test.com',
svc_acc_api_key='key',
is_active=True,
)
response = await create_jira_dc_workspace(mock_request, workspace_data)
content = json.loads(response.body)
auth_url = content['authorizationUrl']
# Verify no raw spaces in the URL (spaces should be encoded as + or %20)
assert ' ' not in auth_url
# Verify scope parameter contains encoded scopes (+ is valid URL encoding for space)
assert 'scope=read%3Ame+read%3Ajira-user+read%3Ajira-work' in auth_url
# Verify redirect_uri is properly encoded
assert 'redirect_uri=https%3A%2F%2F' in auth_url
@pytest.mark.asyncio
@patch('server.routes.integration.jira_dc.get_user_auth')
@patch('server.routes.integration.jira_dc.redis_client')
@patch('server.routes.integration.jira_dc.JIRA_DC_ENABLE_OAUTH', True)
async def test_create_workspace_link_url_encoding(
self, mock_redis, mock_get_auth, mock_request, mock_user_auth
):
"""Test that create_workspace_link properly URL-encodes the authorization URL."""
mock_get_auth.return_value = mock_user_auth
mock_redis.setex.return_value = True
link_data = JiraDcLinkCreate(workspace_name='test-workspace')
response = await create_workspace_link(mock_request, link_data)
content = json.loads(response.body)
auth_url = content['authorizationUrl']
# Verify no raw spaces in the URL (spaces should be encoded as + or %20)
assert ' ' not in auth_url
# Verify scope parameter contains encoded scopes (+ is valid URL encoding for space)
assert 'scope=read%3Ame+read%3Ajira-user+read%3Ajira-work' in auth_url
# Verify redirect_uri is properly encoded
assert 'redirect_uri=https%3A%2F%2F' in auth_url

View File

@@ -1323,58 +1323,3 @@ async def test_validate_workspace_update_permissions_no_current_link(mock_manage
result = await _validate_workspace_update_permissions('user1', 'test-workspace')
assert result == mock_workspace
# Tests for OAuth URL encoding
class TestJiraOAuthUrlEncoding:
"""Tests to verify OAuth authorization URLs are properly URL-encoded."""
@pytest.mark.asyncio
@patch('server.routes.integration.jira.get_user_auth')
@patch('server.routes.integration.jira.redis_client')
async def test_create_jira_workspace_url_encoding(
self, mock_redis, mock_get_auth, mock_request, mock_user_auth
):
"""Test that create_jira_workspace properly URL-encodes the authorization URL."""
mock_get_auth.return_value = mock_user_auth
mock_redis.setex.return_value = True
workspace_data = JiraWorkspaceCreate(
workspace_name='test-workspace',
webhook_secret='secret',
svc_acc_email='svc@test.com',
svc_acc_api_key='key',
is_active=True,
)
response = await create_jira_workspace(mock_request, workspace_data)
content = json.loads(response.body)
auth_url = content['authorizationUrl']
# Verify no raw spaces in the URL (spaces should be encoded as + or %20)
assert ' ' not in auth_url
# Verify scope parameter contains encoded scopes (+ is valid URL encoding for space)
assert 'scope=read%3Ame+read%3Ajira-user+read%3Ajira-work' in auth_url
# Verify redirect_uri is properly encoded
assert 'redirect_uri=https%3A%2F%2F' in auth_url
@pytest.mark.asyncio
@patch('server.routes.integration.jira.get_user_auth')
@patch('server.routes.integration.jira.redis_client')
async def test_create_workspace_link_url_encoding(
self, mock_redis, mock_get_auth, mock_request, mock_user_auth
):
"""Test that create_workspace_link properly URL-encodes the authorization URL."""
mock_get_auth.return_value = mock_user_auth
mock_redis.setex.return_value = True
link_data = JiraLinkCreate(workspace_name='test-workspace')
response = await create_workspace_link(mock_request, link_data)
content = json.loads(response.body)
auth_url = content['authorizationUrl']
# Verify no raw spaces in the URL (spaces should be encoded as + or %20)
assert ' ' not in auth_url
# Verify scope parameter contains encoded scopes (+ is valid URL encoding for space)
assert 'scope=read%3Ame+read%3Ajira-user+read%3Ajira-work' in auth_url
# Verify redirect_uri is properly encoded
assert 'redirect_uri=https%3A%2F%2F' in auth_url

File diff suppressed because it is too large Load Diff

View File

@@ -179,7 +179,7 @@ class TestOrgMemberServiceGetOrgMembers:
assert data.items[0].user_id == str(current_user_id)
assert data.items[0].email == 'test@example.com'
assert data.items[0].role_id == 1
assert data.items[0].role == 'owner'
assert data.items[0].role_name == 'owner'
assert data.items[0].role_rank == 10
assert data.items[0].status == 'active'
@@ -462,7 +462,7 @@ class TestOrgMemberServiceGetOrgMembers:
assert success is True
assert data is not None
assert len(data.items) == 1
assert data.items[0].role == ''
assert data.items[0].role_name == ''
assert data.items[0].role_rank == 0
@pytest.mark.asyncio
@@ -1099,7 +1099,7 @@ class TestOrgMemberServiceUpdateOrgMember:
# Assert
assert isinstance(data, OrgMemberResponse)
assert data.role == 'admin'
assert data.role_name == 'admin'
assert data.role_rank == 20
mock_update.assert_called_once_with(org_id, target_user_id, admin_role.id)
@@ -1431,7 +1431,7 @@ class TestOrgMemberServiceUpdateOrgMember:
# Assert
assert data is not None
assert data.role == 'member'
assert data.role_name == 'member'
assert data.role_rank == 1000

View File

@@ -1,99 +0,0 @@
"""Tests for the enterprise storage.database module.
These tests verify that the session_maker function properly forwards
keyword arguments to the underlying session maker for backward compatibility.
"""
from unittest.mock import MagicMock, patch
class TestSessionMaker:
"""Test cases for the session_maker function."""
@patch('enterprise.storage.database._get_db_session_injector')
def test_session_maker_without_args(self, mock_get_injector):
"""Test that session_maker works without any arguments."""
from enterprise.storage.database import session_maker
# Set up mock
mock_injector = MagicMock()
mock_inner_session_maker = MagicMock()
mock_session = MagicMock()
mock_inner_session_maker.return_value = mock_session
mock_injector.get_session_maker.return_value = mock_inner_session_maker
mock_get_injector.return_value = mock_injector
# Call session_maker without arguments
result = session_maker()
# Verify the inner session maker was called without arguments
mock_inner_session_maker.assert_called_once_with()
assert result == mock_session
@patch('enterprise.storage.database._get_db_session_injector')
def test_session_maker_with_expire_on_commit_false(self, mock_get_injector):
"""Test that session_maker accepts expire_on_commit keyword argument.
This is a critical backward compatibility test - the session_maker
must accept keyword arguments like expire_on_commit=False which is
used in slack.py and potentially other integration modules.
"""
from enterprise.storage.database import session_maker
# Set up mock
mock_injector = MagicMock()
mock_inner_session_maker = MagicMock()
mock_session = MagicMock()
mock_inner_session_maker.return_value = mock_session
mock_injector.get_session_maker.return_value = mock_inner_session_maker
mock_get_injector.return_value = mock_injector
# Call session_maker with expire_on_commit=False
# This is the exact call pattern used in slack.py line 242
result = session_maker(expire_on_commit=False)
# Verify the inner session maker was called with the keyword argument
mock_inner_session_maker.assert_called_once_with(expire_on_commit=False)
assert result == mock_session
@patch('enterprise.storage.database._get_db_session_injector')
def test_session_maker_with_multiple_kwargs(self, mock_get_injector):
"""Test that session_maker passes through multiple keyword arguments."""
from enterprise.storage.database import session_maker
# Set up mock
mock_injector = MagicMock()
mock_inner_session_maker = MagicMock()
mock_session = MagicMock()
mock_inner_session_maker.return_value = mock_session
mock_injector.get_session_maker.return_value = mock_inner_session_maker
mock_get_injector.return_value = mock_injector
# Call with multiple kwargs
result = session_maker(
expire_on_commit=False, autoflush=False, autocommit=False
)
# Verify all kwargs were passed through
mock_inner_session_maker.assert_called_once_with(
expire_on_commit=False, autoflush=False, autocommit=False
)
assert result == mock_session
@patch('enterprise.storage.database._get_db_session_injector')
def test_session_maker_returns_correct_session(self, mock_get_injector):
"""Test that session_maker returns the session from the inner session maker."""
from enterprise.storage.database import session_maker
# Set up mock
mock_injector = MagicMock()
mock_inner_session_maker = MagicMock()
mock_session = MagicMock()
mock_inner_session_maker.return_value = mock_session
mock_injector.get_session_maker.return_value = mock_inner_session_maker
mock_get_injector.return_value = mock_injector
result = session_maker()
# Verify the returned session is from the inner session maker
assert result is mock_session

View File

@@ -1,158 +0,0 @@
"""Unit tests for ResendSyncedUserStore."""
from unittest.mock import MagicMock
import pytest
# Import directly from the module files to avoid loading all of storage/__init__.py
# which has many dependencies
from storage.resend_synced_user import ResendSyncedUser
from storage.resend_synced_user_store import ResendSyncedUserStore
@pytest.fixture
def mock_session():
"""Mock database session."""
session = MagicMock()
return session
@pytest.fixture
def mock_session_maker(mock_session):
"""Mock session maker."""
session_maker = MagicMock()
session_maker.return_value.__enter__.return_value = mock_session
session_maker.return_value.__exit__.return_value = None
return session_maker
@pytest.fixture
def store(mock_session_maker):
"""Create ResendSyncedUserStore instance."""
return ResendSyncedUserStore(session_maker=mock_session_maker)
class TestResendSyncedUserStore:
"""Test cases for ResendSyncedUserStore."""
def test_is_user_synced_returns_true_when_exists(self, store, mock_session):
"""Test is_user_synced returns True when user exists in database."""
email = 'test@example.com'
audience_id = 'test-audience-123'
mock_row = MagicMock()
mock_session.execute.return_value.first.return_value = mock_row
result = store.is_user_synced(email, audience_id)
assert result is True
mock_session.execute.assert_called_once()
def test_is_user_synced_returns_false_when_not_exists(self, store, mock_session):
"""Test is_user_synced returns False when user doesn't exist."""
email = 'test@example.com'
audience_id = 'test-audience-123'
mock_session.execute.return_value.first.return_value = None
result = store.is_user_synced(email, audience_id)
assert result is False
def test_is_user_synced_normalizes_email_to_lowercase(self, store, mock_session):
"""Test that is_user_synced normalizes email to lowercase."""
email = 'TEST@EXAMPLE.COM'
audience_id = 'test-audience-123'
mock_session.execute.return_value.first.return_value = None
store.is_user_synced(email, audience_id)
# Verify the query was called (we can't easily check the exact SQL)
mock_session.execute.assert_called_once()
def test_mark_user_synced_creates_new_record(self, store, mock_session):
"""Test that mark_user_synced creates a new record."""
email = 'test@example.com'
audience_id = 'test-audience-123'
keycloak_user_id = 'kc-user-123'
mock_synced_user = MagicMock(spec=ResendSyncedUser)
mock_result = MagicMock()
mock_result.first.return_value = (mock_synced_user,)
mock_session.execute.return_value = mock_result
result = store.mark_user_synced(email, audience_id, keycloak_user_id)
assert result == mock_synced_user
mock_session.execute.assert_called_once()
mock_session.commit.assert_called_once()
def test_mark_user_synced_handles_existing_record(self, store, mock_session):
"""Test that mark_user_synced handles conflict (existing record)."""
email = 'test@example.com'
audience_id = 'test-audience-123'
# First execute (insert) returns None (conflict occurred)
# Second execute (select existing) returns the record
mock_existing_user = MagicMock(spec=ResendSyncedUser)
mock_result_insert = MagicMock()
mock_result_insert.first.return_value = None
mock_result_select = MagicMock()
mock_result_select.first.return_value = (mock_existing_user,)
mock_session.execute.side_effect = [mock_result_insert, mock_result_select]
result = store.mark_user_synced(email, audience_id)
assert result == mock_existing_user
assert mock_session.execute.call_count == 2
mock_session.commit.assert_called_once()
def test_mark_user_synced_normalizes_email_to_lowercase(self, store, mock_session):
"""Test that mark_user_synced normalizes email to lowercase."""
email = 'TEST@EXAMPLE.COM'
audience_id = 'test-audience-123'
mock_synced_user = MagicMock(spec=ResendSyncedUser)
mock_result = MagicMock()
mock_result.first.return_value = (mock_synced_user,)
mock_session.execute.return_value = mock_result
store.mark_user_synced(email, audience_id)
# Verify execute was called (the email normalization happens in the SQL)
mock_session.execute.assert_called_once()
mock_session.commit.assert_called_once()
def test_mark_user_synced_without_keycloak_user_id(self, store, mock_session):
"""Test that mark_user_synced works without keycloak_user_id."""
email = 'test@example.com'
audience_id = 'test-audience-123'
mock_synced_user = MagicMock(spec=ResendSyncedUser)
mock_result = MagicMock()
mock_result.first.return_value = (mock_synced_user,)
mock_session.execute.return_value = mock_result
result = store.mark_user_synced(email, audience_id)
assert result == mock_synced_user
mock_session.execute.assert_called_once()
class TestResendSyncedUser:
"""Test cases for ResendSyncedUser model."""
def test_model_has_required_fields(self):
"""Test that the model has all required fields."""
assert hasattr(ResendSyncedUser, 'id')
assert hasattr(ResendSyncedUser, 'email')
assert hasattr(ResendSyncedUser, 'audience_id')
assert hasattr(ResendSyncedUser, 'synced_at')
assert hasattr(ResendSyncedUser, 'keycloak_user_id')
def test_model_table_name(self):
"""Test the model's table name."""
assert ResendSyncedUser.__tablename__ == 'resend_synced_users'

View File

@@ -10,12 +10,8 @@ from unittest.mock import AsyncMock, MagicMock
from uuid import UUID, uuid4
import pytest
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.pool import StaticPool
from storage.base import Base
from storage.org import Org
from storage.user import User
from enterprise.server.utils.saas_app_conversation_info_injector import (
SaasSQLAppConversationInfoService,
@@ -24,15 +20,10 @@ from openhands.app_server.app_conversation.app_conversation_models import (
AppConversationInfo,
)
from openhands.app_server.user.specifiy_user_context import SpecifyUserContext
from openhands.app_server.utils.sql_utils import Base
from openhands.integrations.service_types import ProviderType
from openhands.storage.data_models.conversation_metadata import ConversationTrigger
# Test UUIDs
USER1_ID = UUID('a1111111-1111-1111-1111-111111111111')
USER2_ID = UUID('b2222222-2222-2222-2222-222222222222')
ORG1_ID = UUID('c1111111-1111-1111-1111-111111111111')
ORG2_ID = UUID('d2222222-2222-2222-2222-222222222222')
@pytest.fixture
async def async_engine():
@@ -64,41 +55,6 @@ async def async_session(async_engine) -> AsyncGenerator[AsyncSession, None]:
yield db_session
@pytest.fixture
async def async_session_with_users(async_engine) -> AsyncGenerator[AsyncSession, None]:
"""Create an async session with pre-populated Org and User rows for testing."""
async_session_maker = async_sessionmaker(
async_engine, class_=AsyncSession, expire_on_commit=False
)
async with async_session_maker() as db_session:
# Insert Orgs first (required for User foreign key)
org1 = Org(
id=ORG1_ID,
name='test-org-1',
enable_default_condenser=True,
enable_proactive_conversation_starters=True,
)
org2 = Org(
id=ORG2_ID,
name='test-org-2',
enable_default_condenser=True,
enable_proactive_conversation_starters=True,
)
db_session.add(org1)
db_session.add(org2)
await db_session.flush()
# Insert Users
user1 = User(id=USER1_ID, current_org_id=ORG1_ID)
user2 = User(id=USER2_ID, current_org_id=ORG2_ID)
db_session.add(user1)
db_session.add(user2)
await db_session.commit()
yield db_session
@pytest.fixture
def service(async_session) -> SaasSQLAppConversationInfoService:
"""Create a SQLAppConversationInfoService instance for testing."""
@@ -222,26 +178,15 @@ class TestSaasSQLAppConversationInfoService:
assert user1_id != user2_id
@pytest.mark.asyncio
async def test_secure_select_includes_user_and_org_filtering(
async def test_secure_select_includes_user_filtering(
self,
async_session_with_users: AsyncSession,
saas_service_user1: SaasSQLAppConversationInfoService,
):
"""Test that _secure_select method includes both user_id and org_id filtering."""
service = SaasSQLAppConversationInfoService(
db_session=async_session_with_users,
user_context=SpecifyUserContext(user_id=str(USER1_ID)),
)
query = await service._secure_select()
# Convert query to string to verify filters are present
query_str = str(query.compile(compile_kwargs={'literal_binds': True}))
# Verify user_id filter is present
assert str(USER1_ID) in query_str or str(USER1_ID).replace('-', '') in query_str
# Verify org_id filter is present (user1 is in org1)
assert str(ORG1_ID) in query_str or str(ORG1_ID).replace('-', '') in query_str
"""Test that _secure_select method includes user filtering."""
# This test verifies that the _secure_select method exists and can be called
# The actual SQL generation is tested implicitly through integration
query = await saas_service_user1._secure_select()
assert query is not None
@pytest.mark.asyncio
async def test_to_info_with_user_id_functionality(
@@ -296,32 +241,100 @@ class TestSaasSQLAppConversationInfoService:
assert result.sandbox_id == 'test-sandbox'
@pytest.mark.asyncio
async def test_user_isolation_different_users(
async def test_user_isolation(
self,
async_session_with_users: AsyncSession,
async_session: AsyncSession,
multiple_conversation_infos: list[AppConversationInfo],
):
"""Test that different users cannot see each other's conversations."""
"""Test that user isolation works correctly."""
from unittest.mock import MagicMock
from storage.user import User
# Mock the database session execute method to return mock users
# This mock intercepts User queries and returns a mock user object
# with user_id and org_id the same as the user_id_uuid from the query
original_execute = async_session.execute
async def mock_execute(query):
query_str = str(query)
# Check if this is a User query
if '"user"' in query_str.lower() and '"user".id' in query_str.lower():
# Extract the UUID from the query parameters
# The query will have bound parameters, we need to get the UUID value
if hasattr(query, 'compile'):
try:
compiled = query.compile(compile_kwargs={'literal_binds': True})
query_with_params = str(compiled)
# Extract UUID from the query string
import re
# Try both formats: with dashes and without dashes
uuid_pattern_with_dashes = r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}'
uuid_pattern_without_dashes = r'[a-f0-9]{32}'
uuid_match = re.search(
uuid_pattern_with_dashes, query_with_params
)
if not uuid_match:
uuid_match = re.search(
uuid_pattern_without_dashes, query_with_params
)
if uuid_match:
user_id_str = uuid_match.group(0)
# If the UUID doesn't have dashes, add them
if len(user_id_str) == 32 and '-' not in user_id_str:
# Convert from 'a1111111111111111111111111111111' to 'a1111111-1111-1111-1111-111111111111'
user_id_str = f'{user_id_str[:8]}-{user_id_str[8:12]}-{user_id_str[12:16]}-{user_id_str[16:20]}-{user_id_str[20:]}'
user_id_uuid = UUID(user_id_str)
# Create a mock user with user_id and org_id the same as user_id_uuid
mock_user = MagicMock(spec=User)
mock_user.id = user_id_uuid
mock_user.current_org_id = user_id_uuid
# Create a mock result
mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = mock_user
return mock_result
except Exception:
# If there's any error in parsing, fall back to original execute
pass
# For all other queries, use the original execute method
return await original_execute(query)
# Apply the mock
async_session.execute = mock_execute
# Create services for different users
user1_service = SaasSQLAppConversationInfoService(
db_session=async_session_with_users,
user_context=SpecifyUserContext(user_id=str(USER1_ID)),
db_session=async_session,
user_context=SpecifyUserContext(
user_id='a1111111-1111-1111-1111-111111111111'
),
)
user2_service = SaasSQLAppConversationInfoService(
db_session=async_session_with_users,
user_context=SpecifyUserContext(user_id=str(USER2_ID)),
db_session=async_session,
user_context=SpecifyUserContext(
user_id='b2222222-2222-2222-2222-222222222222'
),
)
# Create conversations for different users
user1_info = AppConversationInfo(
id=uuid4(),
created_by_user_id=str(USER1_ID),
created_by_user_id='a1111111-1111-1111-1111-111111111111',
sandbox_id='sandbox_user1',
title='User 1 Conversation',
)
user2_info = AppConversationInfo(
id=uuid4(),
created_by_user_id=str(USER2_ID),
created_by_user_id='b2222222-2222-2222-2222-222222222222',
sandbox_id='sandbox_user2',
title='User 2 Conversation',
)
@@ -333,12 +346,18 @@ class TestSaasSQLAppConversationInfoService:
# User 1 should only see their conversation
user1_page = await user1_service.search_app_conversation_info()
assert len(user1_page.items) == 1
assert user1_page.items[0].created_by_user_id == str(USER1_ID)
assert (
user1_page.items[0].created_by_user_id
== 'a1111111-1111-1111-1111-111111111111'
)
# User 2 should only see their conversation
user2_page = await user2_service.search_app_conversation_info()
assert len(user2_page.items) == 1
assert user2_page.items[0].created_by_user_id == str(USER2_ID)
assert (
user2_page.items[0].created_by_user_id
== 'b2222222-2222-2222-2222-222222222222'
)
# User 1 should not be able to get user 2's conversation
user2_from_user1 = await user1_service.get_app_conversation_info(user2_info.id)
@@ -347,142 +366,3 @@ class TestSaasSQLAppConversationInfoService:
# User 2 should not be able to get user 1's conversation
user1_from_user2 = await user2_service.get_app_conversation_info(user1_info.id)
assert user1_from_user2 is None
@pytest.mark.asyncio
async def test_same_user_org_switching_isolation(
self,
async_session_with_users: AsyncSession,
):
"""Test that the same user switching orgs cannot see conversations from other orgs.
This tests the actual bug scenario: a user creates a conversation in org1,
then switches to org2, and should NOT see org1's conversations.
"""
# Create service for user1 in org1
user1_service_org1 = SaasSQLAppConversationInfoService(
db_session=async_session_with_users,
user_context=SpecifyUserContext(user_id=str(USER1_ID)),
)
# Create a conversation while user is in org1
conv_in_org1 = AppConversationInfo(
id=uuid4(),
created_by_user_id=str(USER1_ID),
sandbox_id='sandbox_org1',
title='Conversation in Org 1',
)
await user1_service_org1.save_app_conversation_info(conv_in_org1)
# Verify user can see the conversation in org1
page_in_org1 = await user1_service_org1.search_app_conversation_info()
assert len(page_in_org1.items) == 1
assert page_in_org1.items[0].title == 'Conversation in Org 1'
# Simulate user switching to org2 by updating current_org_id using ORM
result = await async_session_with_users.execute(
select(User).where(User.id == USER1_ID)
)
user_to_update = result.scalars().first()
user_to_update.current_org_id = ORG2_ID
await async_session_with_users.commit()
# Clear SQLAlchemy's identity map cache to simulate a new request
async_session_with_users.expire_all()
# Create new service instance (simulating a new request after org switch)
user1_service_org2 = SaasSQLAppConversationInfoService(
db_session=async_session_with_users,
user_context=SpecifyUserContext(user_id=str(USER1_ID)),
)
# User should NOT see org1's conversations after switching to org2
page_in_org2 = await user1_service_org2.search_app_conversation_info()
assert (
len(page_in_org2.items) == 0
), 'User should not see conversations from org1 after switching to org2'
# User should not be able to get the specific conversation from org1
conv_from_org2 = await user1_service_org2.get_app_conversation_info(
conv_in_org1.id
)
assert (
conv_from_org2 is None
), 'User should not be able to access org1 conversation from org2'
# Now create a conversation in org2
conv_in_org2 = AppConversationInfo(
id=uuid4(),
created_by_user_id=str(USER1_ID),
sandbox_id='sandbox_org2',
title='Conversation in Org 2',
)
await user1_service_org2.save_app_conversation_info(conv_in_org2)
# User should only see org2's conversation
page_in_org2_after = await user1_service_org2.search_app_conversation_info()
assert len(page_in_org2_after.items) == 1
assert page_in_org2_after.items[0].title == 'Conversation in Org 2'
# Switch back to org1 and verify isolation works both ways
result = await async_session_with_users.execute(
select(User).where(User.id == USER1_ID)
)
user_to_update = result.scalars().first()
user_to_update.current_org_id = ORG1_ID
await async_session_with_users.commit()
async_session_with_users.expire_all()
user1_service_back_to_org1 = SaasSQLAppConversationInfoService(
db_session=async_session_with_users,
user_context=SpecifyUserContext(user_id=str(USER1_ID)),
)
# User should only see org1's conversation now
page_back_in_org1 = (
await user1_service_back_to_org1.search_app_conversation_info()
)
assert len(page_back_in_org1.items) == 1
assert page_back_in_org1.items[0].title == 'Conversation in Org 1'
@pytest.mark.asyncio
async def test_count_respects_org_isolation(
self,
async_session_with_users: AsyncSession,
):
"""Test that count_app_conversation_info respects org isolation."""
# Create service for user1 in org1
user1_service = SaasSQLAppConversationInfoService(
db_session=async_session_with_users,
user_context=SpecifyUserContext(user_id=str(USER1_ID)),
)
# Create conversations in org1
for i in range(3):
conv = AppConversationInfo(
id=uuid4(),
created_by_user_id=str(USER1_ID),
sandbox_id=f'sandbox_org1_{i}',
title=f'Org1 Conversation {i}',
)
await user1_service.save_app_conversation_info(conv)
# Count should be 3
count_org1 = await user1_service.count_app_conversation_info()
assert count_org1 == 3
# Switch to org2 using ORM
result = await async_session_with_users.execute(
select(User).where(User.id == USER1_ID)
)
user_to_update = result.scalars().first()
user_to_update.current_org_id = ORG2_ID
await async_session_with_users.commit()
async_session_with_users.expire_all()
user1_service_org2 = SaasSQLAppConversationInfoService(
db_session=async_session_with_users,
user_context=SpecifyUserContext(user_id=str(USER1_ID)),
)
# Count should be 0 in org2
count_org2 = await user1_service_org2.count_app_conversation_info()
assert count_org2 == 0

View File

@@ -1,25 +1,6 @@
"""Tests for Resend Keycloak sync functionality."""
"""Tests for resend_keycloak email validation."""
import os
from unittest.mock import MagicMock, patch
import pytest
from resend.exceptions import ResendError
from tenacity import RetryError
# Set required environment variables before importing the module
# that reads them at import time
os.environ['RESEND_API_KEY'] = 'test_api_key'
os.environ['RESEND_AUDIENCE_ID'] = 'test_audience_id'
os.environ['KEYCLOAK_SERVER_URL'] = 'http://localhost:8080'
os.environ['KEYCLOAK_REALM_NAME'] = 'test_realm'
os.environ['KEYCLOAK_ADMIN_PASSWORD'] = 'test_password'
from enterprise.sync.resend_keycloak import ( # noqa: E402
add_contact_to_resend,
is_valid_email,
send_welcome_email,
)
from sync.resend_keycloak import is_valid_email
class TestIsValidEmail:
@@ -134,134 +115,3 @@ class TestIsValidEmail:
"""Test that validation works for uppercase emails."""
assert is_valid_email('USER@EXAMPLE.COM') is True
assert is_valid_email('User@Example.Com') is True
class TestSendWelcomeEmail:
"""Tests for send_welcome_email function."""
@patch('enterprise.sync.resend_keycloak.resend.Emails.send')
def test_send_welcome_email_success(self, mock_send: MagicMock) -> None:
"""Test successful welcome email sending."""
mock_send.return_value = {'id': 'email_123'}
result = send_welcome_email(
email='test@example.com',
first_name='John',
last_name='Doe',
)
assert result == {'id': 'email_123'}
mock_send.assert_called_once()
call_args = mock_send.call_args[0][0]
assert call_args['to'] == ['test@example.com']
assert call_args['subject'] == 'Welcome to OpenHands Cloud'
assert 'Hi John Doe,' in call_args['html']
@patch('enterprise.sync.resend_keycloak.resend.Emails.send')
def test_send_welcome_email_retries_on_rate_limit(
self, mock_send: MagicMock
) -> None:
"""Test that send_welcome_email retries on rate limit errors."""
# First two calls raise rate limit error, third succeeds
mock_send.side_effect = [
ResendError(
code=429,
message='Too many requests',
error_type='rate_limit_exceeded',
suggested_action='',
),
ResendError(
code=429,
message='Too many requests',
error_type='rate_limit_exceeded',
suggested_action='',
),
{'id': 'email_123'},
]
result = send_welcome_email(
email='test@example.com',
first_name='John',
last_name='Doe',
)
assert result == {'id': 'email_123'}
assert mock_send.call_count == 3
@patch('enterprise.sync.resend_keycloak.resend.Emails.send')
def test_send_welcome_email_fails_after_max_retries(
self, mock_send: MagicMock
) -> None:
"""Test that send_welcome_email fails after max retries."""
# All calls raise rate limit error
mock_send.side_effect = ResendError(
code=429,
message='Too many requests',
error_type='rate_limit_exceeded',
suggested_action='',
)
# Tenacity wraps the final exception in RetryError
with pytest.raises(RetryError):
send_welcome_email(
email='test@example.com',
first_name='John',
last_name='Doe',
)
# Default MAX_RETRIES is 3
assert mock_send.call_count == 3
@patch('enterprise.sync.resend_keycloak.resend.Emails.send')
def test_send_welcome_email_no_name(self, mock_send: MagicMock) -> None:
"""Test welcome email with no name provided."""
mock_send.return_value = {'id': 'email_123'}
result = send_welcome_email(email='test@example.com')
assert result == {'id': 'email_123'}
call_args = mock_send.call_args[0][0]
assert 'Hi there,' in call_args['html']
class TestAddContactToResend:
"""Tests for add_contact_to_resend function."""
@patch('enterprise.sync.resend_keycloak.resend.Contacts.create')
def test_add_contact_to_resend_success(self, mock_create: MagicMock) -> None:
"""Test successful contact addition."""
mock_create.return_value = {'id': 'contact_123'}
result = add_contact_to_resend(
audience_id='test_audience',
email='test@example.com',
first_name='John',
last_name='Doe',
)
assert result == {'id': 'contact_123'}
mock_create.assert_called_once()
@patch('enterprise.sync.resend_keycloak.resend.Contacts.create')
def test_add_contact_to_resend_retries_on_rate_limit(
self, mock_create: MagicMock
) -> None:
"""Test that add_contact_to_resend retries on rate limit errors."""
# First call raises rate limit error, second succeeds
mock_create.side_effect = [
ResendError(
code=429,
message='Too many requests',
error_type='rate_limit_exceeded',
suggested_action='',
),
{'id': 'contact_123'},
]
result = add_contact_to_resend(
audience_id='test_audience',
email='test@example.com',
)
assert result == {'id': 'contact_123'}
assert mock_create.call_count == 2

View File

@@ -265,17 +265,17 @@ async def test_list_api_keys(
# Verify
mock_get_user.assert_called_once_with(user_id)
assert len(result) == 2
assert result[0].id == 1
assert result[0].name == 'Key 1'
assert result[0].created_at == now
assert result[0].last_used_at == now
assert result[0].expires_at == now + timedelta(days=30)
assert result[0]['id'] == 1
assert result[0]['name'] == 'Key 1'
assert result[0]['created_at'] == now
assert result[0]['last_used_at'] == now
assert result[0]['expires_at'] == now + timedelta(days=30)
assert result[1].id == 2
assert result[1].name == 'Key 2'
assert result[1].created_at == now
assert result[1].last_used_at is None
assert result[1].expires_at is None
assert result[1]['id'] == 2
assert result[1]['name'] == 'Key 2'
assert result[1]['created_at'] == now
assert result[1]['last_used_at'] is None
assert result[1]['expires_at'] is None
@pytest.mark.asyncio

View File

@@ -1,181 +0,0 @@
"""Tests for auth callback invitation acceptance - EmailMismatchError handling."""
import pytest
class TestAuthCallbackInvitationEmailMismatch:
"""Test cases for EmailMismatchError handling during auth callback."""
@pytest.fixture
def mock_redirect_url(self):
"""Base redirect URL."""
return 'https://app.example.com/'
@pytest.fixture
def mock_user_id(self):
"""Mock user ID."""
return '87654321-4321-8765-4321-876543218765'
def test_email_mismatch_appends_to_url_without_query_params(
self, mock_redirect_url, mock_user_id
):
"""Test that email_mismatch=true is appended correctly when URL has no query params."""
from server.routes.org_invitation_models import EmailMismatchError
# Simulate the logic from auth.py
redirect_url = mock_redirect_url
try:
raise EmailMismatchError('Your email does not match the invitation')
except EmailMismatchError:
if '?' in redirect_url:
redirect_url = f'{redirect_url}&email_mismatch=true'
else:
redirect_url = f'{redirect_url}?email_mismatch=true'
assert redirect_url == 'https://app.example.com/?email_mismatch=true'
def test_email_mismatch_appends_to_url_with_query_params(self, mock_user_id):
"""Test that email_mismatch=true is appended correctly when URL has existing query params."""
from server.routes.org_invitation_models import EmailMismatchError
redirect_url = 'https://app.example.com/?other_param=value'
try:
raise EmailMismatchError()
except EmailMismatchError:
if '?' in redirect_url:
redirect_url = f'{redirect_url}&email_mismatch=true'
else:
redirect_url = f'{redirect_url}?email_mismatch=true'
assert (
redirect_url
== 'https://app.example.com/?other_param=value&email_mismatch=true'
)
def test_email_mismatch_error_has_default_message(self):
"""Test that EmailMismatchError has the default message."""
from server.routes.org_invitation_models import EmailMismatchError
error = EmailMismatchError()
assert str(error) == 'Your email does not match the invitation'
def test_email_mismatch_error_accepts_custom_message(self):
"""Test that EmailMismatchError accepts a custom message."""
from server.routes.org_invitation_models import EmailMismatchError
custom_message = 'Custom error message'
error = EmailMismatchError(custom_message)
assert str(error) == custom_message
def test_email_mismatch_error_is_invitation_error(self):
"""Test that EmailMismatchError inherits from InvitationError."""
from server.routes.org_invitation_models import (
EmailMismatchError,
InvitationError,
)
error = EmailMismatchError()
assert isinstance(error, InvitationError)
class TestInvitationTokenInOAuthState:
"""Test cases for invitation token handling in OAuth state."""
def test_invitation_token_included_in_oauth_state(self):
"""Test that invitation token is included in OAuth state data."""
import base64
import json
# Simulate building OAuth state with invitation token
state_data = {
'redirect_url': 'https://app.example.com/',
'invitation_token': 'inv-test-token-12345',
}
encoded_state = base64.b64encode(json.dumps(state_data).encode()).decode()
decoded_data = json.loads(base64.b64decode(encoded_state))
assert decoded_data['invitation_token'] == 'inv-test-token-12345'
assert decoded_data['redirect_url'] == 'https://app.example.com/'
def test_invitation_token_extracted_from_oauth_state(self):
"""Test that invitation token can be extracted from OAuth state."""
import base64
import json
state_data = {
'redirect_url': 'https://app.example.com/',
'invitation_token': 'inv-test-token-12345',
}
encoded_state = base64.b64encode(json.dumps(state_data).encode()).decode()
# Simulate decoding in callback
decoded_state = json.loads(base64.b64decode(encoded_state))
invitation_token = decoded_state.get('invitation_token')
assert invitation_token == 'inv-test-token-12345'
def test_oauth_state_without_invitation_token(self):
"""Test that OAuth state works without invitation token."""
import base64
import json
state_data = {
'redirect_url': 'https://app.example.com/',
}
encoded_state = base64.b64encode(json.dumps(state_data).encode()).decode()
decoded_data = json.loads(base64.b64decode(encoded_state))
assert 'invitation_token' not in decoded_data
assert decoded_data['redirect_url'] == 'https://app.example.com/'
class TestAuthCallbackInvitationErrors:
"""Test cases for various invitation error scenarios in auth callback."""
def test_invitation_expired_appends_flag(self):
"""Test that invitation_expired=true is appended for expired invitations."""
from server.routes.org_invitation_models import InvitationExpiredError
redirect_url = 'https://app.example.com/'
try:
raise InvitationExpiredError()
except InvitationExpiredError:
if '?' in redirect_url:
redirect_url = f'{redirect_url}&invitation_expired=true'
else:
redirect_url = f'{redirect_url}?invitation_expired=true'
assert redirect_url == 'https://app.example.com/?invitation_expired=true'
def test_invitation_invalid_appends_flag(self):
"""Test that invitation_invalid=true is appended for invalid invitations."""
from server.routes.org_invitation_models import InvitationInvalidError
redirect_url = 'https://app.example.com/'
try:
raise InvitationInvalidError()
except InvitationInvalidError:
if '?' in redirect_url:
redirect_url = f'{redirect_url}&invitation_invalid=true'
else:
redirect_url = f'{redirect_url}?invitation_invalid=true'
assert redirect_url == 'https://app.example.com/?invitation_invalid=true'
def test_already_member_appends_flag(self):
"""Test that already_member=true is appended when user is already a member."""
from server.routes.org_invitation_models import UserAlreadyMemberError
redirect_url = 'https://app.example.com/'
try:
raise UserAlreadyMemberError()
except UserAlreadyMemberError:
if '?' in redirect_url:
redirect_url = f'{redirect_url}&already_member=true'
else:
redirect_url = f'{redirect_url}?already_member=true'
assert redirect_url == 'https://app.example.com/?already_member=true'

View File

@@ -1,756 +0,0 @@
"""
Unit tests for permission-based authorization (authorization.py).
Tests the FastAPI dependencies that validate user permissions within organizations.
"""
from unittest.mock import AsyncMock, MagicMock, patch
from uuid import uuid4
import pytest
from fastapi import HTTPException
from server.auth.authorization import (
ROLE_PERMISSIONS,
Permission,
RoleName,
get_role_permissions,
get_user_org_role,
has_permission,
require_permission,
)
# =============================================================================
# Tests for Permission enum
# =============================================================================
class TestPermission:
"""Tests for Permission enum."""
def test_permission_values(self):
"""
GIVEN: Permission enum
WHEN: Accessing permission values
THEN: All expected permissions exist with correct string values
"""
assert Permission.MANAGE_SECRETS.value == 'manage_secrets'
assert Permission.MANAGE_MCP.value == 'manage_mcp'
assert Permission.MANAGE_INTEGRATIONS.value == 'manage_integrations'
assert (
Permission.MANAGE_APPLICATION_SETTINGS.value
== 'manage_application_settings'
)
assert Permission.MANAGE_API_KEYS.value == 'manage_api_keys'
assert Permission.VIEW_LLM_SETTINGS.value == 'view_llm_settings'
assert Permission.EDIT_LLM_SETTINGS.value == 'edit_llm_settings'
assert Permission.VIEW_BILLING.value == 'view_billing'
assert Permission.ADD_CREDITS.value == 'add_credits'
assert (
Permission.INVITE_USER_TO_ORGANIZATION.value
== 'invite_user_to_organization'
)
assert Permission.CHANGE_USER_ROLE_MEMBER.value == 'change_user_role:member'
assert Permission.CHANGE_USER_ROLE_ADMIN.value == 'change_user_role:admin'
assert Permission.CHANGE_USER_ROLE_OWNER.value == 'change_user_role:owner'
assert Permission.VIEW_ORG_SETTINGS.value == 'view_org_settings'
assert Permission.CHANGE_ORGANIZATION_NAME.value == 'change_organization_name'
assert Permission.DELETE_ORGANIZATION.value == 'delete_organization'
def test_permission_from_string(self):
"""
GIVEN: Valid permission string
WHEN: Creating Permission from string
THEN: Correct enum value is returned
"""
assert Permission('manage_secrets') == Permission.MANAGE_SECRETS
assert Permission('view_llm_settings') == Permission.VIEW_LLM_SETTINGS
assert Permission('delete_organization') == Permission.DELETE_ORGANIZATION
def test_permission_invalid_string(self):
"""
GIVEN: Invalid permission string
WHEN: Creating Permission from string
THEN: ValueError is raised
"""
with pytest.raises(ValueError):
Permission('invalid_permission')
# =============================================================================
# Tests for RoleName enum
# =============================================================================
class TestRoleName:
"""Tests for RoleName enum."""
def test_role_name_values(self):
"""
GIVEN: RoleName enum
WHEN: Accessing role name values
THEN: All expected roles exist with correct string values
"""
assert RoleName.OWNER.value == 'owner'
assert RoleName.ADMIN.value == 'admin'
assert RoleName.MEMBER.value == 'member'
def test_role_name_from_string(self):
"""
GIVEN: Valid role name string
WHEN: Creating RoleName from string
THEN: Correct enum value is returned
"""
assert RoleName('owner') == RoleName.OWNER
assert RoleName('admin') == RoleName.ADMIN
assert RoleName('member') == RoleName.MEMBER
def test_role_name_invalid_string(self):
"""
GIVEN: Invalid role name string
WHEN: Creating RoleName from string
THEN: ValueError is raised
"""
with pytest.raises(ValueError):
RoleName('invalid_role')
# =============================================================================
# Tests for ROLE_PERMISSIONS mapping
# =============================================================================
class TestRolePermissions:
"""Tests for role permission mappings."""
def test_owner_has_all_permissions(self):
"""
GIVEN: ROLE_PERMISSIONS mapping
WHEN: Checking owner permissions
THEN: Owner has all permissions including owner-only permissions
"""
owner_perms = ROLE_PERMISSIONS[RoleName.OWNER]
assert Permission.MANAGE_SECRETS in owner_perms
assert Permission.MANAGE_MCP in owner_perms
assert Permission.VIEW_LLM_SETTINGS in owner_perms
assert Permission.EDIT_LLM_SETTINGS in owner_perms
assert Permission.VIEW_BILLING in owner_perms
assert Permission.ADD_CREDITS in owner_perms
assert Permission.INVITE_USER_TO_ORGANIZATION in owner_perms
assert Permission.CHANGE_USER_ROLE_MEMBER in owner_perms
assert Permission.CHANGE_USER_ROLE_ADMIN in owner_perms
assert Permission.CHANGE_USER_ROLE_OWNER in owner_perms
assert Permission.CHANGE_ORGANIZATION_NAME in owner_perms
assert Permission.DELETE_ORGANIZATION in owner_perms
def test_admin_has_admin_permissions(self):
"""
GIVEN: ROLE_PERMISSIONS mapping
WHEN: Checking admin permissions
THEN: Admin has admin permissions but not owner-only permissions
"""
admin_perms = ROLE_PERMISSIONS[RoleName.ADMIN]
assert Permission.MANAGE_SECRETS in admin_perms
assert Permission.MANAGE_MCP in admin_perms
assert Permission.VIEW_LLM_SETTINGS in admin_perms
assert Permission.EDIT_LLM_SETTINGS in admin_perms
assert Permission.VIEW_BILLING in admin_perms
assert Permission.ADD_CREDITS in admin_perms
assert Permission.INVITE_USER_TO_ORGANIZATION in admin_perms
assert Permission.CHANGE_USER_ROLE_MEMBER in admin_perms
assert Permission.CHANGE_USER_ROLE_ADMIN in admin_perms
# Admin should NOT have owner-only permissions
assert Permission.CHANGE_USER_ROLE_OWNER not in admin_perms
assert Permission.CHANGE_ORGANIZATION_NAME not in admin_perms
assert Permission.DELETE_ORGANIZATION not in admin_perms
def test_member_has_limited_permissions(self):
"""
GIVEN: ROLE_PERMISSIONS mapping
WHEN: Checking member permissions
THEN: Member has limited permissions
"""
member_perms = ROLE_PERMISSIONS[RoleName.MEMBER]
# Member has basic settings permissions
assert Permission.MANAGE_SECRETS in member_perms
assert Permission.MANAGE_MCP in member_perms
assert Permission.MANAGE_INTEGRATIONS in member_perms
assert Permission.MANAGE_APPLICATION_SETTINGS in member_perms
assert Permission.MANAGE_API_KEYS in member_perms
assert Permission.VIEW_LLM_SETTINGS in member_perms
assert Permission.VIEW_ORG_SETTINGS in member_perms
# Member should NOT have admin/owner permissions
assert Permission.EDIT_LLM_SETTINGS not in member_perms
assert Permission.VIEW_BILLING not in member_perms
assert Permission.ADD_CREDITS not in member_perms
assert Permission.INVITE_USER_TO_ORGANIZATION not in member_perms
assert Permission.CHANGE_USER_ROLE_MEMBER not in member_perms
assert Permission.CHANGE_USER_ROLE_ADMIN not in member_perms
assert Permission.CHANGE_USER_ROLE_OWNER not in member_perms
assert Permission.CHANGE_ORGANIZATION_NAME not in member_perms
assert Permission.DELETE_ORGANIZATION not in member_perms
# =============================================================================
# Tests for get_role_permissions function
# =============================================================================
class TestGetRolePermissions:
"""Tests for get_role_permissions function."""
def test_get_owner_permissions(self):
"""
GIVEN: Role name 'owner'
WHEN: get_role_permissions is called
THEN: Owner permissions are returned
"""
perms = get_role_permissions('owner')
assert Permission.DELETE_ORGANIZATION in perms
assert Permission.CHANGE_ORGANIZATION_NAME in perms
def test_get_admin_permissions(self):
"""
GIVEN: Role name 'admin'
WHEN: get_role_permissions is called
THEN: Admin permissions are returned
"""
perms = get_role_permissions('admin')
assert Permission.EDIT_LLM_SETTINGS in perms
assert Permission.DELETE_ORGANIZATION not in perms
def test_get_member_permissions(self):
"""
GIVEN: Role name 'member'
WHEN: get_role_permissions is called
THEN: Member permissions are returned
"""
perms = get_role_permissions('member')
assert Permission.VIEW_LLM_SETTINGS in perms
assert Permission.EDIT_LLM_SETTINGS not in perms
def test_get_invalid_role_permissions(self):
"""
GIVEN: Invalid role name
WHEN: get_role_permissions is called
THEN: Empty frozenset is returned
"""
perms = get_role_permissions('invalid_role')
assert perms == frozenset()
# =============================================================================
# Tests for has_permission function
# =============================================================================
class TestHasPermission:
"""Tests for has_permission function."""
def test_owner_has_delete_organization_permission(self):
"""
GIVEN: User with owner role
WHEN: Checking for DELETE_ORGANIZATION permission
THEN: Returns True
"""
mock_role = MagicMock()
mock_role.name = 'owner'
assert has_permission(mock_role, Permission.DELETE_ORGANIZATION) is True
def test_owner_has_view_llm_settings_permission(self):
"""
GIVEN: User with owner role
WHEN: Checking for VIEW_LLM_SETTINGS permission
THEN: Returns True
"""
mock_role = MagicMock()
mock_role.name = 'owner'
assert has_permission(mock_role, Permission.VIEW_LLM_SETTINGS) is True
def test_admin_has_edit_llm_settings_permission(self):
"""
GIVEN: User with admin role
WHEN: Checking for EDIT_LLM_SETTINGS permission
THEN: Returns True
"""
mock_role = MagicMock()
mock_role.name = 'admin'
assert has_permission(mock_role, Permission.EDIT_LLM_SETTINGS) is True
def test_admin_lacks_delete_organization_permission(self):
"""
GIVEN: User with admin role
WHEN: Checking for DELETE_ORGANIZATION permission
THEN: Returns False
"""
mock_role = MagicMock()
mock_role.name = 'admin'
assert has_permission(mock_role, Permission.DELETE_ORGANIZATION) is False
def test_member_has_view_llm_settings_permission(self):
"""
GIVEN: User with member role
WHEN: Checking for VIEW_LLM_SETTINGS permission
THEN: Returns True
"""
mock_role = MagicMock()
mock_role.name = 'member'
assert has_permission(mock_role, Permission.VIEW_LLM_SETTINGS) is True
def test_member_lacks_edit_llm_settings_permission(self):
"""
GIVEN: User with member role
WHEN: Checking for EDIT_LLM_SETTINGS permission
THEN: Returns False
"""
mock_role = MagicMock()
mock_role.name = 'member'
assert has_permission(mock_role, Permission.EDIT_LLM_SETTINGS) is False
def test_member_lacks_delete_organization_permission(self):
"""
GIVEN: User with member role
WHEN: Checking for DELETE_ORGANIZATION permission
THEN: Returns False
"""
mock_role = MagicMock()
mock_role.name = 'member'
assert has_permission(mock_role, Permission.DELETE_ORGANIZATION) is False
def test_invalid_role_has_no_permissions(self):
"""
GIVEN: User with invalid role
WHEN: Checking for any permission
THEN: Returns False
"""
mock_role = MagicMock()
mock_role.name = 'invalid_role'
assert has_permission(mock_role, Permission.VIEW_LLM_SETTINGS) is False
assert has_permission(mock_role, Permission.DELETE_ORGANIZATION) is False
# =============================================================================
# Tests for get_user_org_role function
# =============================================================================
class TestGetUserOrgRole:
"""Tests for get_user_org_role function."""
def test_returns_role_when_member_exists(self):
"""
GIVEN: User is a member of organization with role
WHEN: get_user_org_role is called
THEN: Role object is returned
"""
user_id = str(uuid4())
org_id = uuid4()
mock_org_member = MagicMock()
mock_org_member.role_id = 1
mock_role = MagicMock()
mock_role.name = 'admin'
with (
patch(
'server.auth.authorization.OrgMemberStore.get_org_member',
return_value=mock_org_member,
),
patch(
'server.auth.authorization.RoleStore.get_role_by_id',
return_value=mock_role,
),
):
result = get_user_org_role(user_id, org_id)
assert result == mock_role
def test_returns_none_when_not_member(self):
"""
GIVEN: User is not a member of organization
WHEN: get_user_org_role is called
THEN: None is returned
"""
user_id = str(uuid4())
org_id = uuid4()
with patch(
'server.auth.authorization.OrgMemberStore.get_org_member',
return_value=None,
):
result = get_user_org_role(user_id, org_id)
assert result is None
def test_returns_role_when_org_id_is_none(self):
"""
GIVEN: User with a current organization
WHEN: get_user_org_role is called with org_id=None
THEN: Role object is returned using get_org_member_for_current_org
"""
user_id = str(uuid4())
mock_org_member = MagicMock()
mock_org_member.role_id = 1
mock_role = MagicMock()
mock_role.name = 'admin'
with (
patch(
'server.auth.authorization.OrgMemberStore.get_org_member_for_current_org',
return_value=mock_org_member,
) as mock_get_current,
patch(
'server.auth.authorization.OrgMemberStore.get_org_member',
) as mock_get_org_member,
patch(
'server.auth.authorization.RoleStore.get_role_by_id',
return_value=mock_role,
),
):
result = get_user_org_role(user_id, None)
assert result == mock_role
mock_get_current.assert_called_once()
mock_get_org_member.assert_not_called()
def test_returns_none_when_org_id_is_none_and_no_current_org(self):
"""
GIVEN: User with no current organization membership
WHEN: get_user_org_role is called with org_id=None
THEN: None is returned
"""
user_id = str(uuid4())
with patch(
'server.auth.authorization.OrgMemberStore.get_org_member_for_current_org',
return_value=None,
):
result = get_user_org_role(user_id, None)
assert result is None
# =============================================================================
# Tests for require_permission dependency
# =============================================================================
class TestRequirePermission:
"""Tests for require_permission dependency factory."""
@pytest.mark.asyncio
async def test_returns_user_id_when_authorized(self):
"""
GIVEN: User with required permission
WHEN: Permission checker is called
THEN: User ID is returned
"""
user_id = str(uuid4())
org_id = uuid4()
mock_role = MagicMock()
mock_role.name = 'admin'
with patch(
'server.auth.authorization.get_user_org_role_async',
AsyncMock(return_value=mock_role),
):
permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS)
result = await permission_checker(org_id=org_id, user_id=user_id)
assert result == user_id
@pytest.mark.asyncio
async def test_raises_401_when_not_authenticated(self):
"""
GIVEN: No user ID (not authenticated)
WHEN: Permission checker is called
THEN: 401 Unauthorized is raised
"""
org_id = uuid4()
permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS)
with pytest.raises(HTTPException) as exc_info:
await permission_checker(org_id=org_id, user_id=None)
assert exc_info.value.status_code == 401
assert 'not authenticated' in exc_info.value.detail.lower()
@pytest.mark.asyncio
async def test_raises_403_when_not_member(self):
"""
GIVEN: User is not a member of organization
WHEN: Permission checker is called
THEN: 403 Forbidden is raised
"""
user_id = str(uuid4())
org_id = uuid4()
with patch(
'server.auth.authorization.get_user_org_role_async',
AsyncMock(return_value=None),
):
permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS)
with pytest.raises(HTTPException) as exc_info:
await permission_checker(org_id=org_id, user_id=user_id)
assert exc_info.value.status_code == 403
assert 'not a member' in exc_info.value.detail.lower()
@pytest.mark.asyncio
async def test_raises_403_when_insufficient_permission(self):
"""
GIVEN: User without required permission
WHEN: Permission checker is called
THEN: 403 Forbidden is raised
"""
user_id = str(uuid4())
org_id = uuid4()
mock_role = MagicMock()
mock_role.name = 'member'
with patch(
'server.auth.authorization.get_user_org_role_async',
AsyncMock(return_value=mock_role),
):
permission_checker = require_permission(Permission.DELETE_ORGANIZATION)
with pytest.raises(HTTPException) as exc_info:
await permission_checker(org_id=org_id, user_id=user_id)
assert exc_info.value.status_code == 403
assert 'delete_organization' in exc_info.value.detail.lower()
@pytest.mark.asyncio
async def test_owner_can_delete_organization(self):
"""
GIVEN: User with owner role
WHEN: DELETE_ORGANIZATION permission is required
THEN: User ID is returned
"""
user_id = str(uuid4())
org_id = uuid4()
mock_role = MagicMock()
mock_role.name = 'owner'
with patch(
'server.auth.authorization.get_user_org_role_async',
AsyncMock(return_value=mock_role),
):
permission_checker = require_permission(Permission.DELETE_ORGANIZATION)
result = await permission_checker(org_id=org_id, user_id=user_id)
assert result == user_id
@pytest.mark.asyncio
async def test_admin_cannot_delete_organization(self):
"""
GIVEN: User with admin role
WHEN: DELETE_ORGANIZATION permission is required
THEN: 403 Forbidden is raised
"""
user_id = str(uuid4())
org_id = uuid4()
mock_role = MagicMock()
mock_role.name = 'admin'
with patch(
'server.auth.authorization.get_user_org_role_async',
AsyncMock(return_value=mock_role),
):
permission_checker = require_permission(Permission.DELETE_ORGANIZATION)
with pytest.raises(HTTPException) as exc_info:
await permission_checker(org_id=org_id, user_id=user_id)
assert exc_info.value.status_code == 403
@pytest.mark.asyncio
async def test_logs_warning_on_insufficient_permission(self):
"""
GIVEN: User without required permission
WHEN: Permission checker is called
THEN: Warning is logged with details
"""
user_id = str(uuid4())
org_id = uuid4()
mock_role = MagicMock()
mock_role.name = 'member'
with (
patch(
'server.auth.authorization.get_user_org_role_async',
AsyncMock(return_value=mock_role),
),
patch('server.auth.authorization.logger') as mock_logger,
):
permission_checker = require_permission(Permission.DELETE_ORGANIZATION)
with pytest.raises(HTTPException):
await permission_checker(org_id=org_id, user_id=user_id)
mock_logger.warning.assert_called()
call_args = mock_logger.warning.call_args
assert call_args[1]['extra']['user_id'] == user_id
assert call_args[1]['extra']['user_role'] == 'member'
assert call_args[1]['extra']['required_permission'] == 'delete_organization'
@pytest.mark.asyncio
async def test_returns_user_id_when_org_id_is_none(self):
"""
GIVEN: User with required permission in their current org
WHEN: Permission checker is called with org_id=None
THEN: User ID is returned
"""
user_id = str(uuid4())
mock_role = MagicMock()
mock_role.name = 'admin'
with patch(
'server.auth.authorization.get_user_org_role_async',
AsyncMock(return_value=mock_role),
) as mock_get_role:
permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS)
result = await permission_checker(org_id=None, user_id=user_id)
assert result == user_id
mock_get_role.assert_called_once_with(user_id, None)
@pytest.mark.asyncio
async def test_raises_403_when_org_id_is_none_and_not_member(self):
"""
GIVEN: User not a member of their current organization
WHEN: Permission checker is called with org_id=None
THEN: HTTPException with 403 status is raised
"""
user_id = str(uuid4())
with patch(
'server.auth.authorization.get_user_org_role_async',
AsyncMock(return_value=None),
):
permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS)
with pytest.raises(HTTPException) as exc_info:
await permission_checker(org_id=None, user_id=user_id)
assert exc_info.value.status_code == 403
assert 'not a member' in exc_info.value.detail
# =============================================================================
# Tests for permission-based access control scenarios
# =============================================================================
class TestPermissionScenarios:
"""Tests for real-world permission scenarios."""
@pytest.mark.asyncio
async def test_member_can_manage_secrets(self):
"""
GIVEN: User with member role
WHEN: MANAGE_SECRETS permission is required
THEN: User ID is returned
"""
user_id = str(uuid4())
org_id = uuid4()
mock_role = MagicMock()
mock_role.name = 'member'
with patch(
'server.auth.authorization.get_user_org_role_async',
AsyncMock(return_value=mock_role),
):
permission_checker = require_permission(Permission.MANAGE_SECRETS)
result = await permission_checker(org_id=org_id, user_id=user_id)
assert result == user_id
@pytest.mark.asyncio
async def test_member_cannot_invite_users(self):
"""
GIVEN: User with member role
WHEN: INVITE_USER_TO_ORGANIZATION permission is required
THEN: 403 Forbidden is raised
"""
user_id = str(uuid4())
org_id = uuid4()
mock_role = MagicMock()
mock_role.name = 'member'
with patch(
'server.auth.authorization.get_user_org_role_async',
AsyncMock(return_value=mock_role),
):
permission_checker = require_permission(
Permission.INVITE_USER_TO_ORGANIZATION
)
with pytest.raises(HTTPException) as exc_info:
await permission_checker(org_id=org_id, user_id=user_id)
assert exc_info.value.status_code == 403
@pytest.mark.asyncio
async def test_admin_can_invite_users(self):
"""
GIVEN: User with admin role
WHEN: INVITE_USER_TO_ORGANIZATION permission is required
THEN: User ID is returned
"""
user_id = str(uuid4())
org_id = uuid4()
mock_role = MagicMock()
mock_role.name = 'admin'
with patch(
'server.auth.authorization.get_user_org_role_async',
AsyncMock(return_value=mock_role),
):
permission_checker = require_permission(
Permission.INVITE_USER_TO_ORGANIZATION
)
result = await permission_checker(org_id=org_id, user_id=user_id)
assert result == user_id
@pytest.mark.asyncio
async def test_admin_cannot_change_owner_role(self):
"""
GIVEN: User with admin role
WHEN: CHANGE_USER_ROLE_OWNER permission is required
THEN: 403 Forbidden is raised
"""
user_id = str(uuid4())
org_id = uuid4()
mock_role = MagicMock()
mock_role.name = 'admin'
with patch(
'server.auth.authorization.get_user_org_role_async',
AsyncMock(return_value=mock_role),
):
permission_checker = require_permission(Permission.CHANGE_USER_ROLE_OWNER)
with pytest.raises(HTTPException) as exc_info:
await permission_checker(org_id=org_id, user_id=user_id)
assert exc_info.value.status_code == 403
@pytest.mark.asyncio
async def test_owner_can_change_owner_role(self):
"""
GIVEN: User with owner role
WHEN: CHANGE_USER_ROLE_OWNER permission is required
THEN: User ID is returned
"""
user_id = str(uuid4())
org_id = uuid4()
mock_role = MagicMock()
mock_role.name = 'owner'
with patch(
'server.auth.authorization.get_user_org_role_async',
AsyncMock(return_value=mock_role),
):
permission_checker = require_permission(Permission.CHANGE_USER_ROLE_OWNER)
result = await permission_checker(org_id=org_id, user_id=user_id)
assert result == user_id

View File

@@ -101,7 +101,7 @@ async def test_get_credits_success():
json={
'user_info': {
'spend': 25.50,
'max_budget_in_team': 100.00,
'litellm_budget_table': {'max_budget': 100.00},
}
},
request=MagicMock(),
@@ -121,7 +121,7 @@ async def test_get_credits_success():
'storage.lite_llm_manager.LiteLlmManager.get_user_team_info',
return_value={
'spend': 25.50,
'max_budget_in_team': 100.00,
'litellm_budget_table': {'max_budget': 100.00},
},
),
):
@@ -291,7 +291,7 @@ async def test_success_callback_stripe_incomplete():
@pytest.mark.asyncio
async def test_success_callback_success():
"""Test successful payment completion and credit update."""
"""Test successful payment completion and credit update (bonus already granted)."""
mock_request = Request(scope={'type': 'http'})
mock_request._base_url = URL('http://test.com/')
@@ -300,6 +300,7 @@ async def test_success_callback_success():
mock_billing_session.user_id = 'mock_user'
mock_org = MagicMock()
mock_org.pending_free_credits = False # Not eligible (old org or already granted)
with (
patch('server.routes.billing.session_maker') as mock_session_maker,
@@ -313,7 +314,7 @@ async def test_success_callback_success():
'storage.lite_llm_manager.LiteLlmManager.get_user_team_info',
return_value={
'spend': 25.50,
'max_budget_in_team': 100.00,
'litellm_budget_table': {'max_budget': 100.00},
},
),
patch(
@@ -346,10 +347,10 @@ async def test_success_callback_success():
== 'https://test.com/settings/billing?checkout=success'
)
# Verify LiteLLM API calls
# Verify LiteLLM API calls - no bonus since not eligible
mock_update_budget.assert_called_once_with(
'mock_org_id',
125.0, # 100 + 25.00
125.0, # 100 + 25.00 (no bonus)
)
# Verify BYOR export is enabled for the org (updated in same session)
@@ -362,6 +363,92 @@ async def test_success_callback_success():
mock_db_session.commit.assert_called_once()
@pytest.mark.asyncio
@pytest.mark.parametrize(
'initial_budget,purchase_cents,pending_credits,expected_final_budget,expected_pending_after',
[
# New user buys $10 -> gets free credits, pending becomes False
(0, 1000, True, 20.0, False),
# New user buys $5 -> below threshold, no free credits yet, pending stays True
(0, 500, True, 5.0, True),
# User with $5 buys $5 more -> reaches threshold, gets free credits
(5.0, 500, True, 20.0, False),
# User with $5 buys $3 -> below threshold, no free credits yet
(5.0, 300, True, 8.0, True),
# Old user (not pending) buys $25 -> no free credits, stays False
(20.0, 2500, False, 45.0, False),
],
ids=[
'new_user_buys_10_gets_free_credits',
'new_user_buys_5_below_threshold',
'user_with_5_buys_5_reaches_threshold',
'user_with_5_buys_3_below_threshold',
'old_user_not_eligible',
],
)
async def test_success_callback_free_credits(
initial_budget,
purchase_cents,
pending_credits,
expected_final_budget,
expected_pending_after,
):
"""Test free credits are granted only when pending and threshold is met."""
mock_request = Request(scope={'type': 'http'})
mock_request._base_url = URL('http://test.com/')
mock_billing_session = MagicMock()
mock_billing_session.status = 'in_progress'
mock_billing_session.user_id = 'mock_user'
mock_org = MagicMock()
mock_org.pending_free_credits = pending_credits
with (
patch('server.routes.billing.session_maker') as mock_session_maker,
patch('stripe.checkout.Session.retrieve') as mock_stripe_retrieve,
patch(
'storage.user_store.UserStore.get_user_by_id_async',
new_callable=AsyncMock,
return_value=MagicMock(current_org_id='mock_org_id'),
),
patch(
'storage.lite_llm_manager.LiteLlmManager.get_user_team_info',
return_value={
'spend': 0,
'litellm_budget_table': {'max_budget': initial_budget},
},
),
patch(
'storage.lite_llm_manager.LiteLlmManager.update_team_and_users_budget'
) as mock_update_budget,
patch('server.routes.billing.FREE_CREDIT_THRESHOLD', 10.0),
patch('server.routes.billing.FREE_CREDIT_AMOUNT', 10.0),
):
mock_db_session = MagicMock()
mock_query_chain_billing = MagicMock()
mock_query_chain_billing.filter.return_value.filter.return_value.first.return_value = mock_billing_session
mock_query_chain_org = MagicMock()
mock_query_chain_org.filter.return_value.first.return_value = mock_org
mock_db_session.query.side_effect = [
mock_query_chain_billing,
mock_query_chain_org,
]
mock_session_maker.return_value.__enter__.return_value = mock_db_session
mock_stripe_retrieve.return_value = MagicMock(
status='complete',
amount_subtotal=purchase_cents,
customer='mock_customer_id',
)
response = await success_callback('test_session_id', mock_request)
assert response.status_code == 302
mock_update_budget.assert_called_once_with('mock_org_id', expected_final_budget)
assert mock_org.pending_free_credits is expected_pending_after
@pytest.mark.asyncio
async def test_success_callback_lite_llm_error():
"""Test handling of LiteLLM API errors during success callback."""
@@ -404,10 +491,11 @@ async def test_success_callback_lite_llm_error():
@pytest.mark.asyncio
async def test_success_callback_lite_llm_update_budget_error_rollback():
"""Test that database changes are not committed when update_team_and_users_budget fails.
"""Test that pending_free_credits change is not committed when update_team_and_users_budget fails.
This test verifies that if LiteLlmManager.update_team_and_users_budget raises an exception,
the database transaction rolls back.
This test verifies that if LiteLlmManager.update_team_and_users_budget raises an exception
after pending_free_credits has been set to False, the database transaction rolls back and
pending_free_credits remains True.
"""
mock_request = Request(scope={'type': 'http'})
mock_request._base_url = URL('http://test.com/')
@@ -417,6 +505,7 @@ async def test_success_callback_lite_llm_update_budget_error_rollback():
mock_billing_session.user_id = 'mock_user'
mock_org = MagicMock()
mock_org.pending_free_credits = True
with (
patch('server.routes.billing.session_maker') as mock_session_maker,
@@ -430,13 +519,15 @@ async def test_success_callback_lite_llm_update_budget_error_rollback():
'storage.lite_llm_manager.LiteLlmManager.get_user_team_info',
return_value={
'spend': 0,
'max_budget_in_team': 0,
'litellm_budget_table': {'max_budget': 0},
},
),
patch(
'storage.lite_llm_manager.LiteLlmManager.update_team_and_users_budget',
side_effect=Exception('LiteLLM API Error'),
),
patch('server.routes.billing.FREE_CREDIT_THRESHOLD', 10.0),
patch('server.routes.billing.FREE_CREDIT_AMOUNT', 10.0),
):
mock_db_session = MagicMock()
mock_query_chain_billing = MagicMock()
@@ -449,6 +540,7 @@ async def test_success_callback_lite_llm_update_budget_error_rollback():
]
mock_session_maker.return_value.__enter__.return_value = mock_db_session
# Purchase $10 to reach threshold
mock_stripe_retrieve.return_value = MagicMock(
status='complete',
amount_subtotal=1000, # $10
@@ -579,6 +671,6 @@ async def test_create_customer_setup_session_success():
customer='mock-customer-id',
mode='setup',
payment_method_types=['card'],
success_url='https://test.com/?setup=success',
success_url='https://test.com/?free_credits=success',
cancel_url='https://test.com/',
)

View File

@@ -48,7 +48,7 @@ async def test_create_customer_setup_session_uses_customer_id():
customer=customer_id,
mode='setup',
payment_method_types=['card'],
success_url=f'{request.base_url}?setup=success',
success_url=f'{request.base_url}?free_credits=success',
cancel_url=f'{request.base_url}',
)

View File

@@ -1,192 +0,0 @@
"""Tests for email service."""
import os
from unittest.mock import MagicMock, patch
from server.services.email_service import (
DEFAULT_WEB_HOST,
EmailService,
)
class TestEmailServiceInvitationUrl:
"""Test cases for invitation URL generation."""
def test_invitation_url_uses_correct_endpoint(self):
"""Test that invitation URL points to the correct API endpoint."""
mock_response = MagicMock()
mock_response.get.return_value = 'test-email-id'
with (
patch.dict(os.environ, {'RESEND_API_KEY': 'test-key'}),
patch('server.services.email_service.RESEND_AVAILABLE', True),
patch('server.services.email_service.resend') as mock_resend,
):
mock_resend.Emails.send.return_value = mock_response
EmailService.send_invitation_email(
to_email='test@example.com',
org_name='Test Org',
inviter_name='Inviter',
role_name='member',
invitation_token='inv-test-token-12345',
invitation_id=1,
)
# Get the call arguments
call_args = mock_resend.Emails.send.call_args
email_params = call_args[0][0]
# Verify the URL in the email HTML contains the correct endpoint
assert (
'/api/organizations/members/invite/accept?token='
in email_params['html']
)
assert 'inv-test-token-12345' in email_params['html']
def test_invitation_url_uses_web_host_env_var(self):
"""Test that invitation URL uses WEB_HOST environment variable."""
custom_host = 'https://custom.example.com'
mock_response = MagicMock()
mock_response.get.return_value = 'test-email-id'
with (
patch.dict(
os.environ,
{'RESEND_API_KEY': 'test-key', 'WEB_HOST': custom_host},
),
patch('server.services.email_service.RESEND_AVAILABLE', True),
patch('server.services.email_service.resend') as mock_resend,
):
mock_resend.Emails.send.return_value = mock_response
EmailService.send_invitation_email(
to_email='test@example.com',
org_name='Test Org',
inviter_name='Inviter',
role_name='member',
invitation_token='inv-test-token-12345',
invitation_id=1,
)
call_args = mock_resend.Emails.send.call_args
email_params = call_args[0][0]
expected_url = f'{custom_host}/api/organizations/members/invite/accept?token=inv-test-token-12345'
assert expected_url in email_params['html']
def test_invitation_url_uses_default_host_when_env_not_set(self):
"""Test that invitation URL falls back to DEFAULT_WEB_HOST when env not set."""
mock_response = MagicMock()
mock_response.get.return_value = 'test-email-id'
env_without_web_host = {'RESEND_API_KEY': 'test-key'}
# Remove WEB_HOST if it exists
env_without_web_host.pop('WEB_HOST', None)
with (
patch.dict(os.environ, env_without_web_host, clear=True),
patch('server.services.email_service.RESEND_AVAILABLE', True),
patch('server.services.email_service.resend') as mock_resend,
):
# Clear WEB_HOST from the environment
os.environ.pop('WEB_HOST', None)
mock_resend.Emails.send.return_value = mock_response
EmailService.send_invitation_email(
to_email='test@example.com',
org_name='Test Org',
inviter_name='Inviter',
role_name='member',
invitation_token='inv-test-token-12345',
invitation_id=1,
)
call_args = mock_resend.Emails.send.call_args
email_params = call_args[0][0]
expected_url = f'{DEFAULT_WEB_HOST}/api/organizations/members/invite/accept?token=inv-test-token-12345'
assert expected_url in email_params['html']
class TestEmailServiceGetResendClient:
"""Test cases for Resend client initialization."""
def test_get_resend_client_returns_false_when_resend_not_available(self):
"""Test that _get_resend_client returns False when resend is not installed."""
with patch('server.services.email_service.RESEND_AVAILABLE', False):
result = EmailService._get_resend_client()
assert result is False
def test_get_resend_client_returns_false_when_api_key_not_configured(self):
"""Test that _get_resend_client returns False when API key is missing."""
with (
patch('server.services.email_service.RESEND_AVAILABLE', True),
patch.dict(os.environ, {}, clear=True),
):
os.environ.pop('RESEND_API_KEY', None)
result = EmailService._get_resend_client()
assert result is False
def test_get_resend_client_returns_true_when_configured(self):
"""Test that _get_resend_client returns True when properly configured."""
with (
patch.dict(os.environ, {'RESEND_API_KEY': 'test-key'}),
patch('server.services.email_service.RESEND_AVAILABLE', True),
patch('server.services.email_service.resend') as mock_resend,
):
result = EmailService._get_resend_client()
assert result is True
assert mock_resend.api_key == 'test-key'
class TestEmailServiceSendInvitationEmail:
"""Test cases for send_invitation_email method."""
def test_send_invitation_email_skips_when_client_not_ready(self):
"""Test that email sending is skipped when client is not ready."""
with patch.object(
EmailService, '_get_resend_client', return_value=False
) as mock_get_client:
# Should not raise, just return early
EmailService.send_invitation_email(
to_email='test@example.com',
org_name='Test Org',
inviter_name='Inviter',
role_name='member',
invitation_token='inv-test-token',
invitation_id=1,
)
mock_get_client.assert_called_once()
def test_send_invitation_email_includes_all_required_info(self):
"""Test that invitation email includes org name, inviter name, and role."""
mock_response = MagicMock()
mock_response.get.return_value = 'test-email-id'
with (
patch.dict(os.environ, {'RESEND_API_KEY': 'test-key'}),
patch('server.services.email_service.RESEND_AVAILABLE', True),
patch('server.services.email_service.resend') as mock_resend,
):
mock_resend.Emails.send.return_value = mock_response
EmailService.send_invitation_email(
to_email='test@example.com',
org_name='Acme Corp',
inviter_name='John Doe',
role_name='admin',
invitation_token='inv-test-token-12345',
invitation_id=42,
)
call_args = mock_resend.Emails.send.call_args
email_params = call_args[0][0]
# Verify email content
assert email_params['to'] == ['test@example.com']
assert 'Acme Corp' in email_params['subject']
assert 'John Doe' in email_params['html']
assert 'Acme Corp' in email_params['html']
assert 'admin' in email_params['html']

View File

@@ -142,192 +142,44 @@ class TestLiteLlmManager:
@pytest.mark.asyncio
async def test_create_entries_cloud_deployment(self, mock_settings, mock_response):
"""Test create_entries in cloud deployment mode."""
mock_404_response = MagicMock()
mock_404_response.status_code = 404
mock_404_response.is_success = False
with patch.dict(os.environ, {'LOCAL_DEPLOYMENT': ''}):
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
with patch(
'storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'
):
with patch(
'storage.lite_llm_manager.TokenManager'
) as mock_token_manager:
mock_token_manager.return_value.get_user_info_from_user_id = (
AsyncMock(return_value={'email': 'test@example.com'})
)
mock_token_manager = MagicMock()
mock_token_manager.return_value.get_user_info_from_user_id = AsyncMock(
return_value={'email': 'test@example.com'}
)
with patch('httpx.AsyncClient') as mock_client_class:
mock_client = AsyncMock()
mock_client_class.return_value.__aenter__.return_value = (
mock_client
)
mock_client.post.return_value = mock_response
mock_client = AsyncMock()
mock_client.get.return_value = mock_404_response
mock_client.get.return_value.raise_for_status.side_effect = (
httpx.HTTPStatusError(
message='Not Found', request=MagicMock(), response=mock_404_response
)
)
mock_client.post.return_value = mock_response
result = await LiteLlmManager.create_entries(
'test-org-id',
'test-user-id',
mock_settings,
create_user=False,
)
mock_client_class = MagicMock()
mock_client_class.return_value.__aenter__.return_value = mock_client
assert result is not None
assert result.agent == 'CodeActAgent'
assert result.llm_model == get_default_litellm_model()
assert (
result.llm_api_key.get_secret_value() == 'test-api-key'
)
assert result.llm_base_url == 'http://test.com'
with (
patch.dict(os.environ, {'LOCAL_DEPLOYMENT': ''}),
patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'),
patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'),
patch('storage.lite_llm_manager.TokenManager', mock_token_manager),
patch('httpx.AsyncClient', mock_client_class),
):
result = await LiteLlmManager.create_entries(
'test-org-id', 'test-user-id', mock_settings, create_user=False
)
assert result is not None
assert result.agent == 'CodeActAgent'
assert result.llm_model == get_default_litellm_model()
assert result.llm_api_key.get_secret_value() == 'test-api-key'
assert result.llm_base_url == 'http://test.com'
# Verify API calls were made (get_team + 3 posts)
assert mock_client.get.call_count == 1 # get_team
assert (
mock_client.post.call_count == 3
) # create_team, add_user_to_team, generate_key
@pytest.mark.asyncio
async def test_create_entries_inherits_existing_team_budget(
self, mock_settings, mock_response
):
"""Test that create_entries inherits budget from existing team."""
mock_team_response = MagicMock()
mock_team_response.is_success = True
mock_team_response.status_code = 200
mock_team_response.json.return_value = {
'team_info': {'max_budget': 30.0, 'spend': 5.0},
'team_memberships': [],
}
mock_team_response.raise_for_status = MagicMock()
mock_token_manager = MagicMock()
mock_token_manager.return_value.get_user_info_from_user_id = AsyncMock(
return_value={'email': 'test@example.com'}
)
mock_client = AsyncMock()
mock_client.get.return_value = mock_team_response
mock_client.post.return_value = mock_response
mock_client_class = MagicMock()
mock_client_class.return_value.__aenter__.return_value = mock_client
with (
patch.dict(os.environ, {'LOCAL_DEPLOYMENT': ''}),
patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'),
patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'),
patch('storage.lite_llm_manager.TokenManager', mock_token_manager),
patch('httpx.AsyncClient', mock_client_class),
):
result = await LiteLlmManager.create_entries(
'test-org-id', 'test-user-id', mock_settings, create_user=False
)
assert result is not None
# Verify _get_team was called first
mock_client.get.assert_called_once()
get_call_url = mock_client.get.call_args[0][0]
assert 'team/info' in get_call_url
assert 'test-org-id' in get_call_url
# Verify _create_team was called with inherited budget (30.0)
create_team_call = mock_client.post.call_args_list[0]
assert 'team/new' in create_team_call[0][0]
assert create_team_call[1]['json']['max_budget'] == 30.0
# Verify _add_user_to_team was called with inherited budget (30.0)
add_user_call = mock_client.post.call_args_list[1]
assert 'team/member_add' in add_user_call[0][0]
assert add_user_call[1]['json']['max_budget_in_team'] == 30.0
@pytest.mark.asyncio
async def test_create_entries_new_org_uses_zero_budget(
self, mock_settings, mock_response
):
"""Test that create_entries uses budget=0 for new org (team doesn't exist)."""
mock_404_response = MagicMock()
mock_404_response.status_code = 404
mock_404_response.is_success = False
mock_token_manager = MagicMock()
mock_token_manager.return_value.get_user_info_from_user_id = AsyncMock(
return_value={'email': 'test@example.com'}
)
mock_client = AsyncMock()
mock_client.get.return_value = mock_404_response
mock_client.get.return_value.raise_for_status.side_effect = (
httpx.HTTPStatusError(
message='Not Found', request=MagicMock(), response=mock_404_response
)
)
mock_client.post.return_value = mock_response
mock_client_class = MagicMock()
mock_client_class.return_value.__aenter__.return_value = mock_client
with (
patch.dict(os.environ, {'LOCAL_DEPLOYMENT': ''}),
patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'),
patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'),
patch('storage.lite_llm_manager.TokenManager', mock_token_manager),
patch('httpx.AsyncClient', mock_client_class),
):
result = await LiteLlmManager.create_entries(
'test-org-id', 'test-user-id', mock_settings, create_user=False
)
assert result is not None
# Verify _create_team was called with budget=0
create_team_call = mock_client.post.call_args_list[0]
assert 'team/new' in create_team_call[0][0]
assert create_team_call[1]['json']['max_budget'] == 0.0
# Verify _add_user_to_team was called with budget=0
add_user_call = mock_client.post.call_args_list[1]
assert 'team/member_add' in add_user_call[0][0]
assert add_user_call[1]['json']['max_budget_in_team'] == 0.0
@pytest.mark.asyncio
async def test_create_entries_propagates_non_404_errors(self, mock_settings):
"""Test that create_entries propagates non-404 errors from _get_team."""
mock_500_response = MagicMock()
mock_500_response.status_code = 500
mock_500_response.is_success = False
mock_token_manager = MagicMock()
mock_token_manager.return_value.get_user_info_from_user_id = AsyncMock(
return_value={'email': 'test@example.com'}
)
mock_client = AsyncMock()
mock_client.get.return_value = mock_500_response
mock_client.get.return_value.raise_for_status.side_effect = (
httpx.HTTPStatusError(
message='Internal Server Error',
request=MagicMock(),
response=mock_500_response,
)
)
mock_client_class = MagicMock()
mock_client_class.return_value.__aenter__.return_value = mock_client
with (
patch.dict(os.environ, {'LOCAL_DEPLOYMENT': ''}),
patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'),
patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'),
patch('storage.lite_llm_manager.TokenManager', mock_token_manager),
patch('httpx.AsyncClient', mock_client_class),
):
with pytest.raises(httpx.HTTPStatusError) as exc_info:
await LiteLlmManager.create_entries(
'test-org-id', 'test-user-id', mock_settings, create_user=False
)
assert exc_info.value.response.status_code == 500
# Verify API calls were made
assert (
mock_client.post.call_count == 3
) # create_team, create_user, add_user_to_team, generate_key
@pytest.mark.asyncio
async def test_migrate_entries_missing_config(self, mock_user_settings):

View File

@@ -1,464 +0,0 @@
"""Tests for organization invitation service - email validation."""
from unittest.mock import AsyncMock, MagicMock, patch
from uuid import UUID
import pytest
from server.routes.org_invitation_models import (
EmailMismatchError,
)
from server.services.org_invitation_service import OrgInvitationService
from storage.org_invitation import OrgInvitation
class TestAcceptInvitationEmailValidation:
"""Test cases for email validation during invitation acceptance."""
@pytest.fixture
def mock_invitation(self):
"""Create a mock invitation with pending status."""
invitation = MagicMock(spec=OrgInvitation)
invitation.id = 1
invitation.email = 'alice@example.com'
invitation.status = OrgInvitation.STATUS_PENDING
invitation.org_id = UUID('12345678-1234-5678-1234-567812345678')
invitation.role_id = 1
return invitation
@pytest.fixture
def mock_user(self):
"""Create a mock user with email."""
user = MagicMock()
user.id = UUID('87654321-4321-8765-4321-876543218765')
user.email = 'alice@example.com'
return user
@pytest.mark.asyncio
async def test_accept_invitation_email_matches(self, mock_invitation, mock_user):
"""Test that invitation is accepted when user email matches invitation email."""
# Arrange
user_id = mock_user.id
token = 'inv-test-token-12345'
with patch.object(
OrgInvitationService, 'accept_invitation', new_callable=AsyncMock
) as mock_accept:
mock_accept.return_value = mock_invitation
# Act
await OrgInvitationService.accept_invitation(token, user_id)
# Assert
mock_accept.assert_called_once_with(token, user_id)
@pytest.mark.asyncio
async def test_accept_invitation_email_mismatch_raises_error(
self, mock_invitation, mock_user
):
"""Test that EmailMismatchError is raised when emails don't match."""
# Arrange
user_id = mock_user.id
token = 'inv-test-token-12345'
mock_user.email = 'bob@example.com' # Different email
with (
patch(
'server.services.org_invitation_service.OrgInvitationStore.get_invitation_by_token',
new_callable=AsyncMock,
) as mock_get_invitation,
patch(
'server.services.org_invitation_service.OrgInvitationStore.is_token_expired'
) as mock_is_expired,
patch(
'server.services.org_invitation_service.UserStore.get_user_by_id_async',
new_callable=AsyncMock,
) as mock_get_user,
):
mock_get_invitation.return_value = mock_invitation
mock_is_expired.return_value = False
mock_get_user.return_value = mock_user
# Act & Assert
with pytest.raises(EmailMismatchError):
await OrgInvitationService.accept_invitation(token, user_id)
@pytest.mark.asyncio
async def test_accept_invitation_user_no_email_keycloak_fallback_matches(
self, mock_invitation
):
"""Test that Keycloak email is used when user has no email in database."""
# Arrange
user_id = UUID('87654321-4321-8765-4321-876543218765')
token = 'inv-test-token-12345'
mock_user = MagicMock()
mock_user.id = user_id
mock_user.email = None # No email in database
mock_keycloak_user_info = {'email': 'alice@example.com'} # Email from Keycloak
with (
patch(
'server.services.org_invitation_service.OrgInvitationStore.get_invitation_by_token',
new_callable=AsyncMock,
) as mock_get_invitation,
patch(
'server.services.org_invitation_service.OrgInvitationStore.is_token_expired'
) as mock_is_expired,
patch(
'server.services.org_invitation_service.UserStore.get_user_by_id_async',
new_callable=AsyncMock,
) as mock_get_user,
patch(
'server.services.org_invitation_service.TokenManager'
) as mock_token_manager_class,
patch(
'server.services.org_invitation_service.OrgMemberStore.get_org_member'
) as mock_get_member,
patch(
'server.services.org_invitation_service.OrgService.create_litellm_integration',
new_callable=AsyncMock,
) as mock_create_litellm,
patch(
'server.services.org_invitation_service.OrgMemberStore.add_user_to_org'
),
patch(
'server.services.org_invitation_service.OrgInvitationStore.update_invitation_status',
new_callable=AsyncMock,
) as mock_update_status,
):
mock_get_invitation.return_value = mock_invitation
mock_is_expired.return_value = False
mock_get_user.return_value = mock_user
# Mock TokenManager instance
mock_token_manager = MagicMock()
mock_token_manager.get_user_info_from_user_id = AsyncMock(
return_value=mock_keycloak_user_info
)
mock_token_manager_class.return_value = mock_token_manager
mock_get_member.return_value = None # Not already a member
mock_create_litellm.return_value = MagicMock(llm_api_key='test-key')
mock_update_status.return_value = mock_invitation
# Act - should not raise error because Keycloak email matches
await OrgInvitationService.accept_invitation(token, user_id)
# Assert
mock_token_manager.get_user_info_from_user_id.assert_called_once_with(
str(user_id)
)
@pytest.mark.asyncio
async def test_accept_invitation_no_email_anywhere_raises_error(
self, mock_invitation
):
"""Test that EmailMismatchError is raised when user has no email in database or Keycloak."""
# Arrange
user_id = UUID('87654321-4321-8765-4321-876543218765')
token = 'inv-test-token-12345'
mock_user = MagicMock()
mock_user.id = user_id
mock_user.email = None # No email in database
with (
patch(
'server.services.org_invitation_service.OrgInvitationStore.get_invitation_by_token',
new_callable=AsyncMock,
) as mock_get_invitation,
patch(
'server.services.org_invitation_service.OrgInvitationStore.is_token_expired'
) as mock_is_expired,
patch(
'server.services.org_invitation_service.UserStore.get_user_by_id_async',
new_callable=AsyncMock,
) as mock_get_user,
patch(
'server.services.org_invitation_service.TokenManager'
) as mock_token_manager_class,
):
mock_get_invitation.return_value = mock_invitation
mock_is_expired.return_value = False
mock_get_user.return_value = mock_user
# Mock TokenManager to return no email
mock_token_manager = MagicMock()
mock_token_manager.get_user_info_from_user_id = AsyncMock(return_value={})
mock_token_manager_class.return_value = mock_token_manager
# Act & Assert
with pytest.raises(EmailMismatchError) as exc_info:
await OrgInvitationService.accept_invitation(token, user_id)
assert 'does not have an email address' in str(exc_info.value)
@pytest.mark.asyncio
async def test_accept_invitation_email_comparison_is_case_insensitive(
self, mock_invitation
):
"""Test that email comparison is case insensitive."""
# Arrange
user_id = UUID('87654321-4321-8765-4321-876543218765')
token = 'inv-test-token-12345'
mock_user = MagicMock()
mock_user.id = user_id
mock_user.email = 'ALICE@EXAMPLE.COM' # Uppercase email
mock_invitation.email = 'alice@example.com' # Lowercase in invitation
with (
patch(
'server.services.org_invitation_service.OrgInvitationStore.get_invitation_by_token',
new_callable=AsyncMock,
) as mock_get_invitation,
patch(
'server.services.org_invitation_service.OrgInvitationStore.is_token_expired'
) as mock_is_expired,
patch(
'server.services.org_invitation_service.UserStore.get_user_by_id_async',
new_callable=AsyncMock,
) as mock_get_user,
patch(
'server.services.org_invitation_service.OrgMemberStore.get_org_member'
) as mock_get_member,
patch(
'server.services.org_invitation_service.OrgService.create_litellm_integration',
new_callable=AsyncMock,
) as mock_create_litellm,
patch(
'server.services.org_invitation_service.OrgMemberStore.add_user_to_org'
),
patch(
'server.services.org_invitation_service.OrgInvitationStore.update_invitation_status',
new_callable=AsyncMock,
) as mock_update_status,
):
mock_get_invitation.return_value = mock_invitation
mock_is_expired.return_value = False
mock_get_user.return_value = mock_user
mock_get_member.return_value = None
mock_create_litellm.return_value = MagicMock(llm_api_key='test-key')
mock_update_status.return_value = mock_invitation
# Act - should not raise error because emails match case-insensitively
await OrgInvitationService.accept_invitation(token, user_id)
# Assert - invitation was accepted (update_invitation_status was called)
mock_update_status.assert_called_once()
class TestCreateInvitationsBatch:
"""Test cases for batch invitation creation."""
@pytest.fixture
def org_id(self):
"""Organization UUID for testing."""
return UUID('12345678-1234-5678-1234-567812345678')
@pytest.fixture
def inviter_id(self):
"""Inviter UUID for testing."""
return UUID('87654321-4321-8765-4321-876543218765')
@pytest.fixture
def mock_org(self):
"""Create a mock organization."""
org = MagicMock()
org.id = UUID('12345678-1234-5678-1234-567812345678')
org.name = 'Test Org'
return org
@pytest.fixture
def mock_inviter_member(self):
"""Create a mock inviter member with owner role."""
member = MagicMock()
member.user_id = UUID('87654321-4321-8765-4321-876543218765')
member.role_id = 1
return member
@pytest.fixture
def mock_owner_role(self):
"""Create a mock owner role."""
role = MagicMock()
role.id = 1
role.name = 'owner'
return role
@pytest.fixture
def mock_member_role(self):
"""Create a mock member role."""
role = MagicMock()
role.id = 3
role.name = 'member'
return role
@pytest.mark.asyncio
async def test_batch_creates_all_invitations_successfully(
self,
org_id,
inviter_id,
mock_org,
mock_inviter_member,
mock_owner_role,
mock_member_role,
):
"""Test that batch creation succeeds for all valid emails."""
# Arrange
emails = ['alice@example.com', 'bob@example.com']
mock_invitation_1 = MagicMock(spec=OrgInvitation)
mock_invitation_1.id = 1
mock_invitation_2 = MagicMock(spec=OrgInvitation)
mock_invitation_2.id = 2
with (
patch(
'server.services.org_invitation_service.OrgStore.get_org_by_id',
return_value=mock_org,
),
patch(
'server.services.org_invitation_service.OrgMemberStore.get_org_member',
return_value=mock_inviter_member,
),
patch(
'server.services.org_invitation_service.RoleStore.get_role_by_id',
return_value=mock_owner_role,
),
patch(
'server.services.org_invitation_service.RoleStore.get_role_by_name',
return_value=mock_member_role,
),
patch.object(
OrgInvitationService,
'create_invitation',
new_callable=AsyncMock,
side_effect=[mock_invitation_1, mock_invitation_2],
),
):
# Act
successful, failed = await OrgInvitationService.create_invitations_batch(
org_id=org_id,
emails=emails,
role_name='member',
inviter_id=inviter_id,
)
# Assert
assert len(successful) == 2
assert len(failed) == 0
@pytest.mark.asyncio
async def test_batch_handles_partial_success(
self,
org_id,
inviter_id,
mock_org,
mock_inviter_member,
mock_owner_role,
mock_member_role,
):
"""Test that batch returns partial results when some emails fail."""
# Arrange
from server.routes.org_invitation_models import UserAlreadyMemberError
emails = ['alice@example.com', 'existing@example.com']
mock_invitation = MagicMock(spec=OrgInvitation)
mock_invitation.id = 1
with (
patch(
'server.services.org_invitation_service.OrgStore.get_org_by_id',
return_value=mock_org,
),
patch(
'server.services.org_invitation_service.OrgMemberStore.get_org_member',
return_value=mock_inviter_member,
),
patch(
'server.services.org_invitation_service.RoleStore.get_role_by_id',
return_value=mock_owner_role,
),
patch(
'server.services.org_invitation_service.RoleStore.get_role_by_name',
return_value=mock_member_role,
),
patch.object(
OrgInvitationService,
'create_invitation',
new_callable=AsyncMock,
side_effect=[mock_invitation, UserAlreadyMemberError()],
),
):
# Act
successful, failed = await OrgInvitationService.create_invitations_batch(
org_id=org_id,
emails=emails,
role_name='member',
inviter_id=inviter_id,
)
# Assert
assert len(successful) == 1
assert len(failed) == 1
assert failed[0][0] == 'existing@example.com'
@pytest.mark.asyncio
async def test_batch_fails_entirely_on_permission_error(self, org_id, inviter_id):
"""Test that permission error fails the entire batch upfront."""
# Arrange
emails = ['alice@example.com', 'bob@example.com']
with patch(
'server.services.org_invitation_service.OrgStore.get_org_by_id',
return_value=None, # Organization not found
):
# Act & Assert
with pytest.raises(ValueError) as exc_info:
await OrgInvitationService.create_invitations_batch(
org_id=org_id,
emails=emails,
role_name='member',
inviter_id=inviter_id,
)
assert 'not found' in str(exc_info.value)
@pytest.mark.asyncio
async def test_batch_fails_on_invalid_role(
self, org_id, inviter_id, mock_org, mock_inviter_member, mock_owner_role
):
"""Test that invalid role fails the entire batch."""
# Arrange
emails = ['alice@example.com']
with (
patch(
'server.services.org_invitation_service.OrgStore.get_org_by_id',
return_value=mock_org,
),
patch(
'server.services.org_invitation_service.OrgMemberStore.get_org_member',
return_value=mock_inviter_member,
),
patch(
'server.services.org_invitation_service.RoleStore.get_role_by_id',
return_value=mock_owner_role,
),
patch(
'server.services.org_invitation_service.RoleStore.get_role_by_name',
return_value=None, # Invalid role
),
):
# Act & Assert
with pytest.raises(ValueError) as exc_info:
await OrgInvitationService.create_invitations_batch(
org_id=org_id,
emails=emails,
role_name='invalid_role',
inviter_id=inviter_id,
)
assert 'Invalid role' in str(exc_info.value)

View File

@@ -1,308 +0,0 @@
"""Tests for organization invitation store."""
from datetime import datetime, timedelta
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from storage.org_invitation import OrgInvitation
from storage.org_invitation_store import (
INVITATION_TOKEN_LENGTH,
INVITATION_TOKEN_PREFIX,
OrgInvitationStore,
)
class TestGenerateToken:
"""Test cases for token generation."""
def test_generate_token_has_correct_prefix(self):
"""Test that generated tokens have the correct prefix."""
token = OrgInvitationStore.generate_token()
assert token.startswith(INVITATION_TOKEN_PREFIX)
def test_generate_token_has_correct_length(self):
"""Test that generated tokens have the correct total length."""
token = OrgInvitationStore.generate_token()
expected_length = len(INVITATION_TOKEN_PREFIX) + INVITATION_TOKEN_LENGTH
assert len(token) == expected_length
def test_generate_token_uses_alphanumeric_characters(self):
"""Test that generated tokens use only alphanumeric characters."""
token = OrgInvitationStore.generate_token()
# Remove prefix and check the rest is alphanumeric
random_part = token[len(INVITATION_TOKEN_PREFIX) :]
assert random_part.isalnum()
def test_generate_token_is_unique(self):
"""Test that generated tokens are unique (probabilistically)."""
tokens = [OrgInvitationStore.generate_token() for _ in range(100)]
assert len(set(tokens)) == 100
class TestIsTokenExpired:
"""Test cases for token expiration checking."""
def test_token_not_expired_when_future(self):
"""Test that tokens with future expiration are not expired."""
invitation = MagicMock(spec=OrgInvitation)
invitation.expires_at = datetime.utcnow() + timedelta(days=1)
result = OrgInvitationStore.is_token_expired(invitation)
assert result is False
def test_token_expired_when_past(self):
"""Test that tokens with past expiration are expired."""
invitation = MagicMock(spec=OrgInvitation)
invitation.expires_at = datetime.utcnow() - timedelta(seconds=1)
result = OrgInvitationStore.is_token_expired(invitation)
assert result is True
def test_token_expired_at_exact_boundary(self):
"""Test that tokens at exact expiration time are expired."""
# A token that expires "now" should be expired
now = datetime.utcnow()
invitation = MagicMock(spec=OrgInvitation)
invitation.expires_at = now - timedelta(microseconds=1)
result = OrgInvitationStore.is_token_expired(invitation)
assert result is True
class TestCreateInvitation:
"""Test cases for invitation creation."""
@pytest.mark.asyncio
async def test_create_invitation_normalizes_email(self):
"""Test that email is normalized (lowercase, stripped) on creation."""
mock_session = AsyncMock()
mock_session.add = MagicMock()
mock_session.commit = AsyncMock()
mock_session.execute = AsyncMock()
# Mock the result of the re-fetch query
mock_result = MagicMock()
mock_invitation = MagicMock()
mock_invitation.id = 1
mock_invitation.email = 'test@example.com'
mock_result.scalars.return_value.first.return_value = mock_invitation
mock_session.execute.return_value = mock_result
with patch(
'storage.org_invitation_store.a_session_maker'
) as mock_session_maker:
mock_session_manager = AsyncMock()
mock_session_manager.__aenter__.return_value = mock_session
mock_session_manager.__aexit__.return_value = None
mock_session_maker.return_value = mock_session_manager
from uuid import UUID
await OrgInvitationStore.create_invitation(
org_id=UUID('12345678-1234-5678-1234-567812345678'),
email=' TEST@EXAMPLE.COM ',
role_id=1,
inviter_id=UUID('87654321-4321-8765-4321-876543218765'),
)
# Verify that the OrgInvitation was created with normalized email
add_call = mock_session.add.call_args
created_invitation = add_call[0][0]
assert created_invitation.email == 'test@example.com'
class TestGetInvitationByToken:
"""Test cases for getting invitation by token."""
@pytest.mark.asyncio
async def test_get_invitation_by_token_returns_invitation(self):
"""Test that get_invitation_by_token returns the invitation when found."""
mock_invitation = MagicMock(spec=OrgInvitation)
mock_invitation.token = 'inv-test-token-12345'
mock_session = AsyncMock()
mock_result = MagicMock()
mock_result.scalars.return_value.first.return_value = mock_invitation
mock_session.execute = AsyncMock(return_value=mock_result)
with patch(
'storage.org_invitation_store.a_session_maker'
) as mock_session_maker:
mock_session_manager = AsyncMock()
mock_session_manager.__aenter__.return_value = mock_session
mock_session_manager.__aexit__.return_value = None
mock_session_maker.return_value = mock_session_manager
result = await OrgInvitationStore.get_invitation_by_token(
'inv-test-token-12345'
)
assert result == mock_invitation
@pytest.mark.asyncio
async def test_get_invitation_by_token_returns_none_when_not_found(self):
"""Test that get_invitation_by_token returns None when not found."""
mock_session = AsyncMock()
mock_result = MagicMock()
mock_result.scalars.return_value.first.return_value = None
mock_session.execute = AsyncMock(return_value=mock_result)
with patch(
'storage.org_invitation_store.a_session_maker'
) as mock_session_maker:
mock_session_manager = AsyncMock()
mock_session_manager.__aenter__.return_value = mock_session
mock_session_manager.__aexit__.return_value = None
mock_session_maker.return_value = mock_session_manager
result = await OrgInvitationStore.get_invitation_by_token(
'inv-nonexistent-token'
)
assert result is None
class TestGetPendingInvitation:
"""Test cases for getting pending invitation."""
@pytest.mark.asyncio
async def test_get_pending_invitation_normalizes_email(self):
"""Test that email is normalized when querying for pending invitations."""
mock_session = AsyncMock()
mock_result = MagicMock()
mock_result.scalars.return_value.first.return_value = None
mock_session.execute = AsyncMock(return_value=mock_result)
with patch(
'storage.org_invitation_store.a_session_maker'
) as mock_session_maker:
mock_session_manager = AsyncMock()
mock_session_manager.__aenter__.return_value = mock_session
mock_session_manager.__aexit__.return_value = None
mock_session_maker.return_value = mock_session_manager
from uuid import UUID
await OrgInvitationStore.get_pending_invitation(
org_id=UUID('12345678-1234-5678-1234-567812345678'),
email=' TEST@EXAMPLE.COM ',
)
# Verify the query was called (email normalization happens in the filter)
assert mock_session.execute.called
class TestUpdateInvitationStatus:
"""Test cases for updating invitation status."""
@pytest.mark.asyncio
async def test_update_status_sets_accepted_at_for_accepted(self):
"""Test that accepted_at is set when status is accepted."""
from uuid import UUID
mock_invitation = MagicMock(spec=OrgInvitation)
mock_invitation.id = 1
mock_invitation.status = OrgInvitation.STATUS_PENDING
mock_session = AsyncMock()
mock_result = MagicMock()
mock_result.scalars.return_value.first.return_value = mock_invitation
mock_session.execute = AsyncMock(return_value=mock_result)
mock_session.commit = AsyncMock()
mock_session.refresh = AsyncMock()
with patch(
'storage.org_invitation_store.a_session_maker'
) as mock_session_maker:
mock_session_manager = AsyncMock()
mock_session_manager.__aenter__.return_value = mock_session
mock_session_manager.__aexit__.return_value = None
mock_session_maker.return_value = mock_session_manager
user_id = UUID('87654321-4321-8765-4321-876543218765')
await OrgInvitationStore.update_invitation_status(
invitation_id=1,
status=OrgInvitation.STATUS_ACCEPTED,
accepted_by_user_id=user_id,
)
assert mock_invitation.accepted_at is not None
assert mock_invitation.accepted_by_user_id == user_id
@pytest.mark.asyncio
async def test_update_status_returns_none_when_not_found(self):
"""Test that update returns None when invitation not found."""
mock_session = AsyncMock()
mock_result = MagicMock()
mock_result.scalars.return_value.first.return_value = None
mock_session.execute = AsyncMock(return_value=mock_result)
with patch(
'storage.org_invitation_store.a_session_maker'
) as mock_session_maker:
mock_session_manager = AsyncMock()
mock_session_manager.__aenter__.return_value = mock_session
mock_session_manager.__aexit__.return_value = None
mock_session_maker.return_value = mock_session_manager
result = await OrgInvitationStore.update_invitation_status(
invitation_id=999,
status=OrgInvitation.STATUS_ACCEPTED,
)
assert result is None
class TestMarkExpiredIfNeeded:
"""Test cases for marking expired invitations."""
@pytest.mark.asyncio
async def test_marks_expired_when_pending_and_past_expiry(self):
"""Test that pending expired invitations are marked as expired."""
mock_invitation = MagicMock(spec=OrgInvitation)
mock_invitation.id = 1
mock_invitation.status = OrgInvitation.STATUS_PENDING
mock_invitation.expires_at = datetime.utcnow() - timedelta(days=1)
with patch.object(
OrgInvitationStore,
'update_invitation_status',
new_callable=AsyncMock,
) as mock_update:
result = await OrgInvitationStore.mark_expired_if_needed(mock_invitation)
assert result is True
mock_update.assert_called_once_with(1, OrgInvitation.STATUS_EXPIRED)
@pytest.mark.asyncio
async def test_does_not_mark_when_not_expired(self):
"""Test that non-expired invitations are not marked."""
mock_invitation = MagicMock(spec=OrgInvitation)
mock_invitation.id = 1
mock_invitation.status = OrgInvitation.STATUS_PENDING
mock_invitation.expires_at = datetime.utcnow() + timedelta(days=1)
with patch.object(
OrgInvitationStore,
'update_invitation_status',
new_callable=AsyncMock,
) as mock_update:
result = await OrgInvitationStore.mark_expired_if_needed(mock_invitation)
assert result is False
mock_update.assert_not_called()
@pytest.mark.asyncio
async def test_does_not_mark_when_not_pending(self):
"""Test that non-pending invitations are not marked even if expired."""
mock_invitation = MagicMock(spec=OrgInvitation)
mock_invitation.id = 1
mock_invitation.status = OrgInvitation.STATUS_ACCEPTED
mock_invitation.expires_at = datetime.utcnow() - timedelta(days=1)
with patch.object(
OrgInvitationStore,
'update_invitation_status',
new_callable=AsyncMock,
) as mock_update:
result = await OrgInvitationStore.mark_expired_if_needed(mock_invitation)
assert result is False
mock_update.assert_not_called()

View File

@@ -1,388 +0,0 @@
"""Tests for organization invitations API router."""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from server.routes.org_invitation_models import (
EmailMismatchError,
InvitationExpiredError,
InvitationInvalidError,
UserAlreadyMemberError,
)
from server.routes.org_invitations import accept_router, invitation_router
@pytest.fixture
def app():
"""Create a FastAPI app with the invitation routers."""
app = FastAPI()
app.include_router(invitation_router)
app.include_router(accept_router)
return app
@pytest.fixture
def client(app):
"""Create a test client for the app."""
return TestClient(app)
class TestRouterPrefixes:
"""Test that router prefixes are configured correctly."""
def test_invitation_router_has_correct_prefix(self):
"""Test that invitation_router has /api/organizations/{org_id}/members prefix."""
assert invitation_router.prefix == '/api/organizations/{org_id}/members'
def test_accept_router_has_correct_prefix(self):
"""Test that accept_router has /api/organizations/members/invite prefix."""
assert accept_router.prefix == '/api/organizations/members/invite'
class TestAcceptInvitationEndpoint:
"""Test cases for the accept invitation endpoint."""
@pytest.fixture
def mock_user_auth(self):
"""Create a mock user auth."""
user_auth = MagicMock()
user_auth.get_user_id = AsyncMock(
return_value='87654321-4321-8765-4321-876543218765'
)
return user_auth
@pytest.mark.asyncio
async def test_accept_unauthenticated_redirects_to_login(self, client):
"""Test that unauthenticated users are redirected to login with invitation token."""
with patch(
'server.routes.org_invitations.get_user_auth',
new_callable=AsyncMock,
return_value=None,
):
response = client.get(
'/api/organizations/members/invite/accept?token=inv-test-token-123',
follow_redirects=False,
)
assert response.status_code == 302
assert '/login?invitation_token=inv-test-token-123' in response.headers.get(
'location', ''
)
@pytest.mark.asyncio
async def test_accept_authenticated_success_redirects_home(
self, client, mock_user_auth
):
"""Test that successful acceptance redirects to home page."""
mock_invitation = MagicMock()
with (
patch(
'server.routes.org_invitations.get_user_auth',
new_callable=AsyncMock,
return_value=mock_user_auth,
),
patch(
'server.routes.org_invitations.OrgInvitationService.accept_invitation',
new_callable=AsyncMock,
return_value=mock_invitation,
),
):
response = client.get(
'/api/organizations/members/invite/accept?token=inv-test-token-123',
follow_redirects=False,
)
assert response.status_code == 302
location = response.headers.get('location', '')
assert location.endswith('/')
assert 'invitation_expired' not in location
assert 'invitation_invalid' not in location
assert 'email_mismatch' not in location
@pytest.mark.asyncio
async def test_accept_expired_invitation_redirects_with_flag(
self, client, mock_user_auth
):
"""Test that expired invitation redirects with invitation_expired=true."""
with (
patch(
'server.routes.org_invitations.get_user_auth',
new_callable=AsyncMock,
return_value=mock_user_auth,
),
patch(
'server.routes.org_invitations.OrgInvitationService.accept_invitation',
new_callable=AsyncMock,
side_effect=InvitationExpiredError(),
),
):
response = client.get(
'/api/organizations/members/invite/accept?token=inv-test-token-123',
follow_redirects=False,
)
assert response.status_code == 302
assert 'invitation_expired=true' in response.headers.get('location', '')
@pytest.mark.asyncio
async def test_accept_invalid_invitation_redirects_with_flag(
self, client, mock_user_auth
):
"""Test that invalid invitation redirects with invitation_invalid=true."""
with (
patch(
'server.routes.org_invitations.get_user_auth',
new_callable=AsyncMock,
return_value=mock_user_auth,
),
patch(
'server.routes.org_invitations.OrgInvitationService.accept_invitation',
new_callable=AsyncMock,
side_effect=InvitationInvalidError(),
),
):
response = client.get(
'/api/organizations/members/invite/accept?token=inv-test-token-123',
follow_redirects=False,
)
assert response.status_code == 302
assert 'invitation_invalid=true' in response.headers.get('location', '')
@pytest.mark.asyncio
async def test_accept_already_member_redirects_with_flag(
self, client, mock_user_auth
):
"""Test that already member error redirects with already_member=true."""
with (
patch(
'server.routes.org_invitations.get_user_auth',
new_callable=AsyncMock,
return_value=mock_user_auth,
),
patch(
'server.routes.org_invitations.OrgInvitationService.accept_invitation',
new_callable=AsyncMock,
side_effect=UserAlreadyMemberError(),
),
):
response = client.get(
'/api/organizations/members/invite/accept?token=inv-test-token-123',
follow_redirects=False,
)
assert response.status_code == 302
assert 'already_member=true' in response.headers.get('location', '')
@pytest.mark.asyncio
async def test_accept_email_mismatch_redirects_with_flag(
self, client, mock_user_auth
):
"""Test that email mismatch error redirects with email_mismatch=true."""
with (
patch(
'server.routes.org_invitations.get_user_auth',
new_callable=AsyncMock,
return_value=mock_user_auth,
),
patch(
'server.routes.org_invitations.OrgInvitationService.accept_invitation',
new_callable=AsyncMock,
side_effect=EmailMismatchError(),
),
):
response = client.get(
'/api/organizations/members/invite/accept?token=inv-test-token-123',
follow_redirects=False,
)
assert response.status_code == 302
assert 'email_mismatch=true' in response.headers.get('location', '')
@pytest.mark.asyncio
async def test_accept_unexpected_error_redirects_with_flag(
self, client, mock_user_auth
):
"""Test that unexpected errors redirect with invitation_error=true."""
with (
patch(
'server.routes.org_invitations.get_user_auth',
new_callable=AsyncMock,
return_value=mock_user_auth,
),
patch(
'server.routes.org_invitations.OrgInvitationService.accept_invitation',
new_callable=AsyncMock,
side_effect=Exception('Unexpected error'),
),
):
response = client.get(
'/api/organizations/members/invite/accept?token=inv-test-token-123',
follow_redirects=False,
)
assert response.status_code == 302
assert 'invitation_error=true' in response.headers.get('location', '')
class TestCreateInvitationBatchEndpoint:
"""Test cases for the batch invitation creation endpoint."""
@pytest.fixture
def batch_app(self):
"""Create a FastAPI app with dependency overrides for batch tests."""
from openhands.server.user_auth import get_user_id
app = FastAPI()
app.include_router(invitation_router)
# Override the get_user_id dependency
app.dependency_overrides[get_user_id] = (
lambda: '87654321-4321-8765-4321-876543218765'
)
return app
@pytest.fixture
def batch_client(self, batch_app):
"""Create a test client with dependency overrides."""
return TestClient(batch_app)
@pytest.fixture
def mock_invitation(self):
"""Create a mock invitation."""
from datetime import datetime
invitation = MagicMock()
invitation.id = 1
invitation.email = 'alice@example.com'
invitation.role = MagicMock(name='member')
invitation.role.name = 'member'
invitation.role_id = 3
invitation.status = 'pending'
invitation.created_at = datetime(2026, 2, 17, 10, 0, 0)
invitation.expires_at = datetime(2026, 2, 24, 10, 0, 0)
return invitation
@pytest.mark.asyncio
async def test_batch_create_returns_successful_invitations(
self, batch_client, mock_invitation
):
"""Test that batch creation returns successful invitations."""
mock_invitation_2 = MagicMock()
mock_invitation_2.id = 2
mock_invitation_2.email = 'bob@example.com'
mock_invitation_2.role = MagicMock()
mock_invitation_2.role.name = 'member'
mock_invitation_2.role_id = 3
mock_invitation_2.status = 'pending'
mock_invitation_2.created_at = mock_invitation.created_at
mock_invitation_2.expires_at = mock_invitation.expires_at
with (
patch(
'server.routes.org_invitations.check_rate_limit_by_user_id',
new_callable=AsyncMock,
),
patch(
'server.routes.org_invitations.OrgInvitationService.create_invitations_batch',
new_callable=AsyncMock,
return_value=([mock_invitation, mock_invitation_2], []),
),
):
response = batch_client.post(
'/api/organizations/12345678-1234-5678-1234-567812345678/members/invite',
json={
'emails': ['alice@example.com', 'bob@example.com'],
'role': 'member',
},
)
assert response.status_code == 201
data = response.json()
assert len(data['successful']) == 2
assert len(data['failed']) == 0
@pytest.mark.asyncio
async def test_batch_create_returns_partial_success(
self, batch_client, mock_invitation
):
"""Test that batch creation returns both successful and failed invitations."""
failed_emails = [('existing@example.com', 'User is already a member')]
with (
patch(
'server.routes.org_invitations.check_rate_limit_by_user_id',
new_callable=AsyncMock,
),
patch(
'server.routes.org_invitations.OrgInvitationService.create_invitations_batch',
new_callable=AsyncMock,
return_value=([mock_invitation], failed_emails),
),
):
response = batch_client.post(
'/api/organizations/12345678-1234-5678-1234-567812345678/members/invite',
json={
'emails': ['alice@example.com', 'existing@example.com'],
'role': 'member',
},
)
assert response.status_code == 201
data = response.json()
assert len(data['successful']) == 1
assert len(data['failed']) == 1
assert data['failed'][0]['email'] == 'existing@example.com'
assert 'already a member' in data['failed'][0]['error']
@pytest.mark.asyncio
async def test_batch_create_permission_denied_returns_403(self, batch_client):
"""Test that permission denied returns 403 for entire batch."""
from server.routes.org_invitation_models import InsufficientPermissionError
with (
patch(
'server.routes.org_invitations.check_rate_limit_by_user_id',
new_callable=AsyncMock,
),
patch(
'server.routes.org_invitations.OrgInvitationService.create_invitations_batch',
new_callable=AsyncMock,
side_effect=InsufficientPermissionError(
'Only owners and admins can invite'
),
),
):
response = batch_client.post(
'/api/organizations/12345678-1234-5678-1234-567812345678/members/invite',
json={'emails': ['alice@example.com'], 'role': 'member'},
)
assert response.status_code == 403
assert 'owners and admins' in response.json()['detail']
@pytest.mark.asyncio
async def test_batch_create_invalid_role_returns_400(self, batch_client):
"""Test that invalid role returns 400."""
with (
patch(
'server.routes.org_invitations.check_rate_limit_by_user_id',
new_callable=AsyncMock,
),
patch(
'server.routes.org_invitations.OrgInvitationService.create_invitations_batch',
new_callable=AsyncMock,
side_effect=ValueError('Invalid role: superuser'),
),
):
response = batch_client.post(
'/api/organizations/12345678-1234-5678-1234-567812345678/members/invite',
json={'emails': ['alice@example.com'], 'role': 'superuser'},
)
assert response.status_code == 400
assert 'Invalid role' in response.json()['detail']

View File

@@ -158,57 +158,6 @@ def test_get_org_member(session_maker):
assert retrieved_org_member.llm_api_key.get_secret_value() == 'test-key'
def test_get_org_member_for_current_org(session_maker):
# Test getting org_member for user's current organization
with session_maker() as session:
# Create test data - user belongs to two orgs but current_org is org1
org1 = Org(name='test-org-1')
org2 = Org(name='test-org-2')
session.add_all([org1, org2])
session.flush()
user = User(id=uuid.uuid4(), current_org_id=org1.id)
role = Role(name='admin', rank=1)
session.add_all([user, role])
session.flush()
org_member1 = OrgMember(
org_id=org1.id,
user_id=user.id,
role_id=role.id,
llm_api_key='test-key-1',
status='active',
)
org_member2 = OrgMember(
org_id=org2.id,
user_id=user.id,
role_id=role.id,
llm_api_key='test-key-2',
status='active',
)
session.add_all([org_member1, org_member2])
session.commit()
user_id = user.id
org1_id = org1.id
# Test retrieval - should return org_member for current_org (org1)
with patch('storage.org_member_store.session_maker', session_maker):
retrieved_org_member = OrgMemberStore.get_org_member_for_current_org(user_id)
assert retrieved_org_member is not None
assert retrieved_org_member.org_id == org1_id
assert retrieved_org_member.user_id == user_id
assert retrieved_org_member.llm_api_key.get_secret_value() == 'test-key-1'
def test_get_org_member_for_current_org_user_not_found(session_maker):
# Test getting org_member for non-existent user
with patch('storage.org_member_store.session_maker', session_maker):
retrieved_org_member = OrgMemberStore.get_org_member_for_current_org(
uuid.uuid4()
)
assert retrieved_org_member is None
def test_add_user_to_org(session_maker):
# Test adding a user to an org
with session_maker() as session:

View File

@@ -482,7 +482,7 @@ async def test_get_org_credits_success(mock_litellm_api):
spend = 25.0
mock_team_info = {
'max_budget_in_team': max_budget,
'litellm_budget_table': {'max_budget': max_budget},
'spend': spend,
}

View File

@@ -398,121 +398,6 @@ async def test_create_user_contact_name_falls_back_to_username():
assert org.contact_name == 'jdoe'
# --- Tests for email fields in create_user() ---
# create_user() should populate user.email and user.email_verified from the
# Keycloak user_info, ensuring the user table has the correct email data.
class _StopAfterUserCreation(Exception):
"""Halt create_user() after User creation for email field inspection."""
pass
@pytest.mark.asyncio
async def test_create_user_sets_email_from_user_info():
"""create_user() should set user.email and user.email_verified from user_info."""
# Arrange
user_id = str(uuid.uuid4())
user_info = {
'preferred_username': 'testuser',
'email': 'testuser@example.com',
'email_verified': True,
}
mock_session = MagicMock()
mock_sm = MagicMock()
mock_sm.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_sm.return_value.__exit__ = MagicMock(return_value=False)
mock_settings = Settings(language='en')
mock_role = MagicMock()
mock_role.id = 1
with (
patch('storage.user_store.session_maker', mock_sm),
patch.object(
UserStore,
'create_default_settings',
new_callable=AsyncMock,
return_value=mock_settings,
),
patch('storage.org_store.OrgStore.get_kwargs_from_settings', return_value={}),
patch.object(UserStore, 'get_kwargs_from_settings', return_value={}),
patch('storage.user_store.RoleStore.get_role_by_name', return_value=mock_role),
patch(
'storage.org_member_store.OrgMemberStore.get_kwargs_from_settings',
return_value={'llm_model': None, 'llm_base_url': None},
),
patch.object(
mock_session,
'commit',
side_effect=_StopAfterUserCreation,
),
):
# Act
with pytest.raises(_StopAfterUserCreation):
await UserStore.create_user(user_id, user_info)
# Assert - User is the second object added to session (after Org)
user = mock_session.add.call_args_list[1][0][0]
assert isinstance(user, User)
assert user.email == 'testuser@example.com'
assert user.email_verified is True
@pytest.mark.asyncio
async def test_create_user_handles_missing_email_verified():
"""create_user() should handle missing email_verified in user_info gracefully."""
# Arrange
user_id = str(uuid.uuid4())
user_info = {
'preferred_username': 'testuser',
'email': 'testuser@example.com',
# email_verified is not present
}
mock_session = MagicMock()
mock_sm = MagicMock()
mock_sm.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_sm.return_value.__exit__ = MagicMock(return_value=False)
mock_settings = Settings(language='en')
mock_role = MagicMock()
mock_role.id = 1
with (
patch('storage.user_store.session_maker', mock_sm),
patch.object(
UserStore,
'create_default_settings',
new_callable=AsyncMock,
return_value=mock_settings,
),
patch('storage.org_store.OrgStore.get_kwargs_from_settings', return_value={}),
patch.object(UserStore, 'get_kwargs_from_settings', return_value={}),
patch('storage.user_store.RoleStore.get_role_by_name', return_value=mock_role),
patch(
'storage.org_member_store.OrgMemberStore.get_kwargs_from_settings',
return_value={'llm_model': None, 'llm_base_url': None},
),
patch.object(
mock_session,
'commit',
side_effect=_StopAfterUserCreation,
),
):
# Act
with pytest.raises(_StopAfterUserCreation):
await UserStore.create_user(user_id, user_info)
# Assert - User should have email but email_verified should be None
user = mock_session.add.call_args_list[1][0][0]
assert isinstance(user, User)
assert user.email == 'testuser@example.com'
assert user.email_verified is None
# --- Tests for backfill_contact_name on login ---
# Existing users created before the resolve_display_name fix may have
# username-style values in contact_name. The backfill updates these to

View File

@@ -151,9 +151,8 @@ describe("LoginContent", () => {
await user.click(githubButton);
// Wait for async handleAuthRedirect to complete
// The URL includes state parameter added by handleAuthRedirect
await waitFor(() => {
expect(window.location.href).toContain(mockUrl);
expect(window.location.href).toBe(mockUrl);
});
});
@@ -202,103 +201,4 @@ describe("LoginContent", () => {
expect(screen.getByTestId("terms-and-privacy-notice")).toBeInTheDocument();
});
it("should display invitation pending message when hasInvitation is true", () => {
render(
<MemoryRouter>
<LoginContent
githubAuthUrl="https://github.com/oauth/authorize"
appMode="saas"
providersConfigured={["github"]}
hasInvitation
/>
</MemoryRouter>,
);
expect(screen.getByText("AUTH$INVITATION_PENDING")).toBeInTheDocument();
});
it("should not display invitation pending message when hasInvitation is false", () => {
render(
<MemoryRouter>
<LoginContent
githubAuthUrl="https://github.com/oauth/authorize"
appMode="saas"
providersConfigured={["github"]}
hasInvitation={false}
/>
</MemoryRouter>,
);
expect(
screen.queryByText("AUTH$INVITATION_PENDING"),
).not.toBeInTheDocument();
});
it("should call buildOAuthStateData when clicking auth button", async () => {
const user = userEvent.setup();
const mockBuildOAuthStateData = vi.fn((baseState) => ({
...baseState,
invitation_token: "inv-test-token-12345",
}));
render(
<MemoryRouter>
<LoginContent
githubAuthUrl="https://github.com/login/oauth/authorize"
appMode="saas"
providersConfigured={["github"]}
buildOAuthStateData={mockBuildOAuthStateData}
/>
</MemoryRouter>,
);
const githubButton = screen.getByRole("button", {
name: "GITHUB$CONNECT_TO_GITHUB",
});
await user.click(githubButton);
await waitFor(() => {
expect(mockBuildOAuthStateData).toHaveBeenCalled();
const callArg = mockBuildOAuthStateData.mock.calls[0][0];
expect(callArg).toHaveProperty("redirect_url");
});
});
it("should encode state with invitation token when buildOAuthStateData provides token", async () => {
const user = userEvent.setup();
const mockBuildOAuthStateData = vi.fn((baseState) => ({
...baseState,
invitation_token: "inv-test-token-12345",
}));
render(
<MemoryRouter>
<LoginContent
githubAuthUrl="https://github.com/login/oauth/authorize"
appMode="saas"
providersConfigured={["github"]}
buildOAuthStateData={mockBuildOAuthStateData}
/>
</MemoryRouter>,
);
const githubButton = screen.getByRole("button", {
name: "GITHUB$CONNECT_TO_GITHUB",
});
await user.click(githubButton);
await waitFor(() => {
const redirectUrl = window.location.href;
// The URL should contain an encoded state parameter
expect(redirectUrl).toContain("state=");
// Decode and verify the state contains invitation_token
const url = new URL(redirectUrl);
const state = url.searchParams.get("state");
if (state) {
const decodedState = JSON.parse(atob(state));
expect(decodedState.invitation_token).toBe("inv-test-token-12345");
}
});
});
});

View File

@@ -1,170 +0,0 @@
import { act, renderHook } from "@testing-library/react";
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
const INVITATION_TOKEN_KEY = "openhands_invitation_token";
// Mock setSearchParams function
const mockSetSearchParams = vi.fn();
// Default mock searchParams
let mockSearchParamsData: Record<string, string> = {};
// Mock react-router
vi.mock("react-router", () => ({
useSearchParams: () => [
{
get: (key: string) => mockSearchParamsData[key] || null,
has: (key: string) => key in mockSearchParamsData,
},
mockSetSearchParams,
],
}));
// Import after mocking
import { useInvitation } from "#/hooks/use-invitation";
describe("useInvitation", () => {
beforeEach(() => {
// Clear localStorage before each test
localStorage.clear();
// Reset mock data
mockSearchParamsData = {};
mockSetSearchParams.mockClear();
});
afterEach(() => {
vi.clearAllMocks();
});
describe("initialization", () => {
it("should initialize with null token when localStorage is empty", () => {
// Arrange - localStorage is empty (cleared in beforeEach)
// Act
const { result } = renderHook(() => useInvitation());
// Assert
expect(result.current.invitationToken).toBeNull();
expect(result.current.hasInvitation).toBe(false);
});
it("should initialize with token from localStorage if present", () => {
// Arrange
const storedToken = "inv-stored-token-12345";
localStorage.setItem(INVITATION_TOKEN_KEY, storedToken);
// Act
const { result } = renderHook(() => useInvitation());
// Assert
expect(result.current.invitationToken).toBe(storedToken);
expect(result.current.hasInvitation).toBe(true);
});
});
describe("URL token capture", () => {
it("should capture invitation_token from URL and store in localStorage", () => {
// Arrange
const urlToken = "inv-url-token-67890";
mockSearchParamsData = { invitation_token: urlToken };
// Act
renderHook(() => useInvitation());
// Assert
expect(localStorage.getItem(INVITATION_TOKEN_KEY)).toBe(urlToken);
expect(mockSetSearchParams).toHaveBeenCalled();
});
});
describe("completion cleanup", () => {
it("should clear localStorage when email_mismatch param is present", () => {
// Arrange
const storedToken = "inv-token-to-clear";
localStorage.setItem(INVITATION_TOKEN_KEY, storedToken);
mockSearchParamsData = { email_mismatch: "true" };
// Act
const { result } = renderHook(() => useInvitation());
// Assert
expect(localStorage.getItem(INVITATION_TOKEN_KEY)).toBeNull();
expect(mockSetSearchParams).toHaveBeenCalled();
});
it("should clear localStorage when invitation_success param is present", () => {
// Arrange
const storedToken = "inv-token-to-clear";
localStorage.setItem(INVITATION_TOKEN_KEY, storedToken);
mockSearchParamsData = { invitation_success: "true" };
// Act
renderHook(() => useInvitation());
// Assert
expect(localStorage.getItem(INVITATION_TOKEN_KEY)).toBeNull();
});
it("should clear localStorage when invitation_expired param is present", () => {
// Arrange
localStorage.setItem(INVITATION_TOKEN_KEY, "inv-token");
mockSearchParamsData = { invitation_expired: "true" };
// Act
renderHook(() => useInvitation());
// Assert
expect(localStorage.getItem(INVITATION_TOKEN_KEY)).toBeNull();
});
});
describe("buildOAuthStateData", () => {
it("should include invitation_token in OAuth state when token is present", () => {
// Arrange
const token = "inv-oauth-token-12345";
localStorage.setItem(INVITATION_TOKEN_KEY, token);
const { result } = renderHook(() => useInvitation());
const baseState = { redirect_url: "/dashboard" };
// Act
const stateData = result.current.buildOAuthStateData(baseState);
// Assert
expect(stateData.invitation_token).toBe(token);
expect(stateData.redirect_url).toBe("/dashboard");
});
it("should not include invitation_token when no token is present", () => {
// Arrange - no token in localStorage
const { result } = renderHook(() => useInvitation());
const baseState = { redirect_url: "/dashboard" };
// Act
const stateData = result.current.buildOAuthStateData(baseState);
// Assert
expect(stateData.invitation_token).toBeUndefined();
expect(stateData.redirect_url).toBe("/dashboard");
});
});
describe("clearInvitation", () => {
it("should remove token from localStorage when called", () => {
// Arrange
localStorage.setItem(INVITATION_TOKEN_KEY, "inv-token-to-clear");
const { result } = renderHook(() => useInvitation());
// Act
act(() => {
result.current.clearInvitation();
});
// Assert
expect(localStorage.getItem(INVITATION_TOKEN_KEY)).toBeNull();
expect(result.current.invitationToken).toBeNull();
expect(result.current.hasInvitation).toBe(false);
});
});
});

View File

@@ -57,22 +57,6 @@ vi.mock("#/hooks/use-tracking", () => ({
}),
}));
const { useInvitationMock, buildOAuthStateDataMock } = vi.hoisted(() => ({
useInvitationMock: vi.fn(() => ({
invitationToken: null as string | null,
hasInvitation: false,
buildOAuthStateData: (baseState: Record<string, string>) => baseState,
clearInvitation: vi.fn(),
})),
buildOAuthStateDataMock: vi.fn(
(baseState: Record<string, string>) => baseState,
),
}));
vi.mock("#/hooks/use-invitation", () => ({
useInvitation: () => useInvitationMock(),
}));
const RouterStub = createRoutesStub([
{
Component: LoginPage,
@@ -250,8 +234,7 @@ describe("LoginPage", () => {
});
await user.click(githubButton);
// URL includes state parameter added by handleAuthRedirect
expect(window.location.href).toContain(mockUrl);
expect(window.location.href).toBe(mockUrl);
});
it("should redirect to GitLab auth URL when GitLab button is clicked", async () => {
@@ -272,8 +255,7 @@ describe("LoginPage", () => {
});
await user.click(gitlabButton);
// URL includes state parameter added by handleAuthRedirect
expect(window.location.href).toContain("https://gitlab.com/oauth/authorize");
expect(window.location.href).toBe("https://gitlab.com/oauth/authorize");
});
it("should redirect to Bitbucket auth URL when Bitbucket button is clicked", async () => {
@@ -300,8 +282,7 @@ describe("LoginPage", () => {
});
await user.click(bitbucketButton);
// URL includes state parameter added by handleAuthRedirect
expect(window.location.href).toContain(
expect(window.location.href).toBe(
"https://bitbucket.org/site/oauth2/authorize",
);
});
@@ -498,137 +479,4 @@ describe("LoginPage", () => {
});
});
});
describe("Invitation Flow", () => {
it("should display invitation pending message when hasInvitation is true", async () => {
useInvitationMock.mockReturnValue({
invitationToken: "inv-test-token-12345",
hasInvitation: true,
buildOAuthStateData: buildOAuthStateDataMock,
clearInvitation: vi.fn(),
});
render(<RouterStub initialEntries={["/login"]} />, {
wrapper: createWrapper(),
});
await waitFor(() => {
expect(screen.getByText("AUTH$INVITATION_PENDING")).toBeInTheDocument();
});
});
it("should not display invitation pending message when hasInvitation is false", async () => {
useInvitationMock.mockReturnValue({
invitationToken: null,
hasInvitation: false,
buildOAuthStateData: buildOAuthStateDataMock,
clearInvitation: vi.fn(),
});
render(<RouterStub initialEntries={["/login"]} />, {
wrapper: createWrapper(),
});
await waitFor(() => {
expect(screen.getByTestId("login-content")).toBeInTheDocument();
});
expect(
screen.queryByText("AUTH$INVITATION_PENDING"),
).not.toBeInTheDocument();
});
it("should pass buildOAuthStateData to LoginContent for OAuth state encoding", async () => {
const user = userEvent.setup();
const mockBuildOAuthStateData = vi.fn((baseState: Record<string, string>) => ({
...baseState,
invitation_token: "inv-test-token-12345",
}));
useInvitationMock.mockReturnValue({
invitationToken: "inv-test-token-12345",
hasInvitation: true,
buildOAuthStateData: mockBuildOAuthStateData,
clearInvitation: vi.fn(),
});
render(<RouterStub initialEntries={["/login"]} />, {
wrapper: createWrapper(),
});
await waitFor(() => {
expect(
screen.getByRole("button", { name: "GITHUB$CONNECT_TO_GITHUB" }),
).toBeInTheDocument();
});
const githubButton = screen.getByRole("button", {
name: "GITHUB$CONNECT_TO_GITHUB",
});
await user.click(githubButton);
// buildOAuthStateData should have been called during the OAuth redirect
expect(mockBuildOAuthStateData).toHaveBeenCalled();
});
it("should include invitation token in OAuth state when invitation is present", async () => {
const user = userEvent.setup();
const mockBuildOAuthStateData = vi.fn((baseState: Record<string, string>) => ({
...baseState,
invitation_token: "inv-test-token-12345",
}));
useInvitationMock.mockReturnValue({
invitationToken: "inv-test-token-12345",
hasInvitation: true,
buildOAuthStateData: mockBuildOAuthStateData,
clearInvitation: vi.fn(),
});
render(<RouterStub initialEntries={["/login"]} />, {
wrapper: createWrapper(),
});
await waitFor(() => {
expect(
screen.getByRole("button", { name: "GITHUB$CONNECT_TO_GITHUB" }),
).toBeInTheDocument();
});
const githubButton = screen.getByRole("button", {
name: "GITHUB$CONNECT_TO_GITHUB",
});
await user.click(githubButton);
// Verify the redirect URL contains the state with invitation token
await waitFor(() => {
expect(window.location.href).toContain("state=");
});
// Decode and verify the state contains invitation_token
const url = new URL(window.location.href);
const state = url.searchParams.get("state");
if (state) {
const decodedState = JSON.parse(atob(state));
expect(decodedState.invitation_token).toBe("inv-test-token-12345");
}
});
it("should handle login with invitation_token URL parameter", async () => {
useInvitationMock.mockReturnValue({
invitationToken: "inv-url-token-67890",
hasInvitation: true,
buildOAuthStateData: buildOAuthStateDataMock,
clearInvitation: vi.fn(),
});
render(<RouterStub initialEntries={["/login?invitation_token=inv-url-token-67890"]} />, {
wrapper: createWrapper(),
});
await waitFor(() => {
expect(screen.getByText("AUTH$INVITATION_PENDING")).toBeInTheDocument();
});
});
});
});

View File

@@ -42,15 +42,6 @@ vi.mock("#/utils/custom-toast-handlers", () => ({
displaySuccessToast: vi.fn(),
}));
vi.mock("#/hooks/use-invitation", () => ({
useInvitation: () => ({
invitationToken: null,
hasInvitation: false,
buildOAuthStateData: (baseState: Record<string, string>) => baseState,
clearInvitation: vi.fn(),
}),
}));
function LoginStub() {
const [searchParams] = useSearchParams();
const emailVerificationRequired =
@@ -362,68 +353,4 @@ describe("MainApp", () => {
);
});
});
describe("Invitation URL Parameters", () => {
beforeEach(() => {
vi.spyOn(AuthService, "authenticate").mockRejectedValue({
response: { status: 401 },
isAxiosError: true,
});
});
it("should redirect to login when email_mismatch=true is in query params", async () => {
renderMainApp(["/?email_mismatch=true"]);
await waitFor(
() => {
expect(screen.getByTestId("login-page")).toBeInTheDocument();
},
{ timeout: 2000 },
);
});
it("should redirect to login when invitation_success=true is in query params", async () => {
renderMainApp(["/?invitation_success=true"]);
await waitFor(
() => {
expect(screen.getByTestId("login-page")).toBeInTheDocument();
},
{ timeout: 2000 },
);
});
it("should redirect to login when invitation_expired=true is in query params", async () => {
renderMainApp(["/?invitation_expired=true"]);
await waitFor(
() => {
expect(screen.getByTestId("login-page")).toBeInTheDocument();
},
{ timeout: 2000 },
);
});
it("should redirect to login when invitation_invalid=true is in query params", async () => {
renderMainApp(["/?invitation_invalid=true"]);
await waitFor(
() => {
expect(screen.getByTestId("login-page")).toBeInTheDocument();
},
{ timeout: 2000 },
);
});
it("should redirect to login when already_member=true is in query params", async () => {
renderMainApp(["/?already_member=true"]);
await waitFor(
() => {
expect(screen.getByTestId("login-page")).toBeInTheDocument();
},
{ timeout: 2000 },
);
});
});
});

View File

@@ -21,10 +21,6 @@ export interface LoginContentProps {
emailVerified?: boolean;
hasDuplicatedEmail?: boolean;
recaptchaBlocked?: boolean;
hasInvitation?: boolean;
buildOAuthStateData?: (
baseStateData: Record<string, string>,
) => Record<string, string>;
}
export function LoginContent({
@@ -35,8 +31,6 @@ export function LoginContent({
emailVerified = false,
hasDuplicatedEmail = false,
recaptchaBlocked = false,
hasInvitation = false,
buildOAuthStateData,
}: LoginContentProps) {
const { t } = useTranslation();
const { trackLoginButtonClick } = useTracking();
@@ -65,36 +59,31 @@ export function LoginContent({
) => {
trackLoginButtonClick({ provider });
const url = new URL(redirectUrl);
const currentState =
url.searchParams.get("state") || window.location.origin;
// Build base state data
let stateData: Record<string, string> = {
redirect_url: currentState,
};
// Add invitation token if present
if (buildOAuthStateData) {
stateData = buildOAuthStateData(stateData);
if (!config?.recaptcha_site_key || !recaptchaReady) {
// No reCAPTCHA or token generation failed - redirect normally
window.location.href = redirectUrl;
return;
}
// If reCAPTCHA is configured, add token to state
if (config?.recaptcha_site_key && recaptchaReady) {
try {
const token = await executeRecaptcha("LOGIN");
if (token) {
stateData.recaptcha_token = token;
}
} catch (err) {
displayErrorToast(t(I18nKey.AUTH$RECAPTCHA_BLOCKED));
return;
// If reCAPTCHA is configured, encode token in OAuth state
try {
const token = await executeRecaptcha("LOGIN");
if (token) {
const url = new URL(redirectUrl);
const currentState =
url.searchParams.get("state") || window.location.origin;
// Encode state with reCAPTCHA token for backend verification
const stateData = {
redirect_url: currentState,
recaptcha_token: token,
};
url.searchParams.set("state", btoa(JSON.stringify(stateData)));
window.location.href = url.toString();
}
} catch (err) {
displayErrorToast(t(I18nKey.AUTH$RECAPTCHA_BLOCKED));
}
// Encode state and redirect
url.searchParams.set("state", btoa(JSON.stringify(stateData)));
window.location.href = url.toString();
};
const handleGitHubAuth = () => {
@@ -134,10 +123,6 @@ export function LoginContent({
const buttonBaseClasses =
"w-[301.5px] h-10 rounded p-2 flex items-center justify-center cursor-pointer transition-opacity hover:opacity-90 disabled:opacity-50 disabled:cursor-not-allowed";
const buttonLabelClasses = "text-sm font-medium leading-5 px-1";
const shouldShownHelperText =
emailVerified || hasDuplicatedEmail || recaptchaBlocked || hasInvitation;
return (
<div
className="flex flex-col items-center w-full gap-12.5"
@@ -151,29 +136,20 @@ export function LoginContent({
{t(I18nKey.AUTH$LETS_GET_STARTED)}
</h1>
{shouldShownHelperText && (
<div className="flex flex-col items-center gap-3">
{emailVerified && (
<p className="text-sm text-muted-foreground text-center">
{t(I18nKey.AUTH$EMAIL_VERIFIED_PLEASE_LOGIN)}
</p>
)}
{hasDuplicatedEmail && (
<p className="text-sm text-danger text-center">
{t(I18nKey.AUTH$DUPLICATE_EMAIL_ERROR)}
</p>
)}
{recaptchaBlocked && (
<p className="text-sm text-danger text-center max-w-125">
{t(I18nKey.AUTH$RECAPTCHA_BLOCKED)}
</p>
)}
{hasInvitation && (
<p className="text-sm text-muted-foreground text-center">
{t(I18nKey.AUTH$INVITATION_PENDING)}
</p>
)}
</div>
{emailVerified && (
<p className="text-sm text-muted-foreground text-center">
{t(I18nKey.AUTH$EMAIL_VERIFIED_PLEASE_LOGIN)}
</p>
)}
{hasDuplicatedEmail && (
<p className="text-sm text-danger text-center">
{t(I18nKey.AUTH$DUPLICATE_EMAIL_ERROR)}
</p>
)}
{recaptchaBlocked && (
<p className="text-sm text-danger text-center max-w-125">
{t(I18nKey.AUTH$RECAPTCHA_BLOCKED)}
</p>
)}
<div className="flex flex-col items-center gap-3">

View File

@@ -1,119 +0,0 @@
import React from "react";
import { useSearchParams } from "react-router";
const INVITATION_TOKEN_KEY = "openhands_invitation_token";
interface UseInvitationReturn {
/** The invitation token, if present */
invitationToken: string | null;
/** Whether there is an active invitation */
hasInvitation: boolean;
/** Clear the stored invitation token */
clearInvitation: () => void;
/** Build OAuth state data including invitation token if present */
buildOAuthStateData: (
baseStateData: Record<string, string>,
) => Record<string, string>;
}
/**
* Hook to manage organization invitation tokens during the login flow.
*
* This hook:
* 1. Reads invitation_token from URL query params on mount
* 2. Persists the token in localStorage (survives page refresh and works across tabs)
* 3. Provides the token for inclusion in OAuth state
* 4. Provides cleanup method after successful authentication
*
* The invitation token flow:
* 1. User clicks invitation link → /api/invitations/accept?token=xxx
* 2. Backend redirects to /login?invitation_token=xxx
* 3. This hook captures token and stores in localStorage
* 4. When user clicks login button, token is included in OAuth state
* 5. After auth callback processes invitation, frontend clears the token
*
* Note: localStorage is used instead of sessionStorage to support scenarios where
* the user opens the email verification link in a new tab/browser window.
*/
export function useInvitation(): UseInvitationReturn {
const [searchParams, setSearchParams] = useSearchParams();
const [invitationToken, setInvitationToken] = React.useState<string | null>(
() => {
// Initialize from localStorage (persists across tabs and page refreshes)
if (typeof window !== "undefined") {
return localStorage.getItem(INVITATION_TOKEN_KEY);
}
return null;
},
);
// Capture invitation token from URL and persist to localStorage
// This only runs on the login page where the hook is used
React.useEffect(() => {
const tokenFromUrl = searchParams.get("invitation_token");
if (tokenFromUrl) {
// Store in localStorage for persistence across tabs and refreshes
localStorage.setItem(INVITATION_TOKEN_KEY, tokenFromUrl);
setInvitationToken(tokenFromUrl);
// Remove token from URL to clean up (prevents token exposure in browser history)
const newSearchParams = new URLSearchParams(searchParams);
newSearchParams.delete("invitation_token");
setSearchParams(newSearchParams, { replace: true });
}
}, [searchParams, setSearchParams]);
// Clear invitation token when invitation flow completes (success or failure)
// These query params are set by the backend after processing the invitation
React.useEffect(() => {
const invitationCompleted =
searchParams.has("invitation_success") ||
searchParams.has("invitation_expired") ||
searchParams.has("invitation_invalid") ||
searchParams.has("invitation_error") ||
searchParams.has("already_member") ||
searchParams.has("email_mismatch");
if (invitationCompleted) {
localStorage.removeItem(INVITATION_TOKEN_KEY);
setInvitationToken(null);
// Remove invitation params from URL to clean up
const newSearchParams = new URLSearchParams(searchParams);
newSearchParams.delete("invitation_success");
newSearchParams.delete("invitation_expired");
newSearchParams.delete("invitation_invalid");
newSearchParams.delete("invitation_error");
newSearchParams.delete("already_member");
newSearchParams.delete("email_mismatch");
setSearchParams(newSearchParams, { replace: true });
}
}, [searchParams, setSearchParams]);
const clearInvitation = React.useCallback(() => {
localStorage.removeItem(INVITATION_TOKEN_KEY);
setInvitationToken(null);
}, []);
const buildOAuthStateData = React.useCallback(
(baseStateData: Record<string, string>): Record<string, string> => {
const stateData = { ...baseStateData };
// Include invitation token in state if present
if (invitationToken) {
stateData.invitation_token = invitationToken;
}
return stateData;
},
[invitationToken],
);
return {
invitationToken,
hasInvitation: invitationToken !== null,
clearInvitation,
buildOAuthStateData,
};
}

View File

@@ -763,7 +763,6 @@ export enum I18nKey {
AUTH$DUPLICATE_EMAIL_ERROR = "AUTH$DUPLICATE_EMAIL_ERROR",
AUTH$RECAPTCHA_BLOCKED = "AUTH$RECAPTCHA_BLOCKED",
AUTH$LETS_GET_STARTED = "AUTH$LETS_GET_STARTED",
AUTH$INVITATION_PENDING = "AUTH$INVITATION_PENDING",
COMMON$TERMS_OF_SERVICE = "COMMON$TERMS_OF_SERVICE",
COMMON$AND = "COMMON$AND",
COMMON$PRIVACY_POLICY = "COMMON$PRIVACY_POLICY",

View File

@@ -12207,22 +12207,6 @@
"de": "Lass uns anfangen",
"uk": "Почнімо"
},
"AUTH$INVITATION_PENDING": {
"en": "Sign in to accept your organization invitation",
"ja": "組織への招待を受け入れるにはサインインしてください",
"zh-CN": "登录以接受您的组织邀请",
"zh-TW": "登入以接受您的組織邀請",
"ko-KR": "조직 초대를 수락하려면 로그인하세요",
"no": "Logg inn for å godta organisasjonsinvitasjonen din",
"it": "Accedi per accettare l'invito della tua organizzazione",
"pt": "Faça login para aceitar o convite da sua organização",
"es": "Inicia sesión para aceptar la invitación de tu organización",
"ar": "سجّل الدخول لقبول دعوة مؤسستك",
"fr": "Connectez-vous pour accepter l'invitation de votre organisation",
"tr": "Organizasyon davetinizi kabul etmek için giriş yapın",
"de": "Melden Sie sich an, um Ihre Organisationseinladung anzunehmen",
"uk": "Увійдіть, щоб прийняти запрошення до організації"
},
"COMMON$TERMS_OF_SERVICE": {
"en": "Terms of Service",
"ja": "利用規約",

View File

@@ -10,7 +10,6 @@ import "./tailwind.css";
import "./index.css";
import React from "react";
import { Toaster } from "react-hot-toast";
import { useInvitation } from "#/hooks/use-invitation";
export function Layout({ children }: { children: React.ReactNode }) {
return (
@@ -38,9 +37,5 @@ export const meta: MetaFunction = () => [
];
export default function App() {
// Handle invitation token cleanup when invitation flow completes
// This runs on all pages to catch redirects from auth callback
useInvitation();
return <Outlet />;
}

View File

@@ -4,7 +4,6 @@ import { useIsAuthed } from "#/hooks/query/use-is-authed";
import { useConfig } from "#/hooks/query/use-config";
import { useGitHubAuthUrl } from "#/hooks/use-github-auth-url";
import { useEmailVerification } from "#/hooks/use-email-verification";
import { useInvitation } from "#/hooks/use-invitation";
import { LoginContent } from "#/components/features/auth/login-content";
import { EmailVerificationModal } from "#/components/features/waitlist/email-verification-modal";
@@ -24,8 +23,6 @@ export default function LoginPage() {
userId,
} = useEmailVerification();
const { hasInvitation, buildOAuthStateData } = useInvitation();
const gitHubAuthUrl = useGitHubAuthUrl({
appMode: config.data?.app_mode || null,
authUrl: config.data?.auth_url,
@@ -72,8 +69,6 @@ export default function LoginPage() {
emailVerified={emailVerified}
hasDuplicatedEmail={hasDuplicatedEmail}
recaptchaBlocked={recaptchaBlocked}
hasInvitation={hasInvitation}
buildOAuthStateData={buildOAuthStateData}
/>
</main>

View File

@@ -7,7 +7,7 @@ import asyncio
import logging
from dataclasses import dataclass
from typing import AsyncGenerator
from uuid import UUID, uuid4
from uuid import UUID
from fastapi import Request
from sqlalchemy import UUID as SQLUUID
@@ -59,7 +59,7 @@ class StoredEventCallback(Base): # type: ignore
class StoredEventCallbackResult(Base): # type: ignore
__tablename__ = 'event_callback_result'
id = Column(SQLUUID, primary_key=True, default=uuid4)
id = Column(SQLUUID, primary_key=True)
status = Column(Enum(EventCallbackResultStatus), nullable=True)
event_callback_id = Column(SQLUUID, index=True)
event_id = Column(String, index=True)

View File

@@ -19,104 +19,12 @@ def _get_recaptcha_site_key() -> str | None:
return key if key else None
# OSS default PostHog key - used when no environment variable is configured
_OSS_POSTHOG_KEY = 'phc_3ESMmY9SgqEAGBB6sMGK5ayYHkeUuknH2vP6FmWH9RA'
def _get_posthog_client_key() -> str:
"""Get PostHog client key from environment variable.
Reads POSTHOG_CLIENT_KEY from environment. If not set or empty,
returns the OSS default key for backwards compatibility.
"""
key = os.getenv('POSTHOG_CLIENT_KEY', '').strip()
return key if key else _OSS_POSTHOG_KEY
def _get_auth_url() -> str | None:
"""Get authentication service URL from environment variable.
Reads AUTH_URL from environment. If not set or empty, returns None.
"""
url = os.getenv('AUTH_URL', '').strip()
return url if url else None
def _get_maintenance_start_time() -> datetime | None:
"""Get maintenance start time from environment variable.
Reads MAINTENANCE_START_TIME from environment. If set to a valid ISO 8601
timestamp, returns the parsed datetime. If empty, unset, or invalid,
returns None (graceful fallback).
"""
value = os.getenv('MAINTENANCE_START_TIME', '').strip()
if not value:
return None
try:
return datetime.fromisoformat(value)
except ValueError:
return None
def _get_providers_configured() -> list[ProviderType]:
"""Get configured OAuth providers from environment variables.
Checks for presence of OAuth client ID env vars and returns a list of
configured providers. Mirrors legacy logic from SaaSServerConfig.
"""
providers: list[ProviderType] = []
if os.getenv('GITHUB_APP_CLIENT_ID', '').strip():
providers.append(ProviderType.GITHUB)
if os.getenv('GITLAB_APP_CLIENT_ID', '').strip():
providers.append(ProviderType.GITLAB)
if os.getenv('BITBUCKET_APP_CLIENT_ID', '').strip():
providers.append(ProviderType.BITBUCKET)
if os.getenv('ENABLE_ENTERPRISE_SSO', '').strip():
providers.append(ProviderType.ENTERPRISE_SSO)
return providers
def _get_github_app_slug() -> str | None:
"""Get GitHub app slug from environment variable.
Reads GITHUB_APP_SLUG from environment. If set, returns the value.
If empty or unset, returns None.
"""
slug = os.getenv('GITHUB_APP_SLUG', '').strip()
return slug if slug else None
def _get_feature_flags() -> WebClientFeatureFlags:
"""Get feature flags from environment variables.
Reads ENABLE_BILLING, HIDE_LLM_SETTINGS, ENABLE_JIRA, ENABLE_JIRA_DC,
and ENABLE_LINEAR from environment. Each flag is True only if the
corresponding env var is exactly 'true', otherwise False.
"""
return WebClientFeatureFlags(
enable_billing=os.getenv('ENABLE_BILLING', 'false') == 'true',
hide_llm_settings=os.getenv('HIDE_LLM_SETTINGS', 'false') == 'true',
enable_jira=os.getenv('ENABLE_JIRA', 'false') == 'true',
enable_jira_dc=os.getenv('ENABLE_JIRA_DC', 'false') == 'true',
enable_linear=os.getenv('ENABLE_LINEAR', 'false') == 'true',
)
class DefaultWebClientConfigInjector(WebClientConfigInjector):
posthog_client_key: str = Field(default_factory=_get_posthog_client_key)
feature_flags: WebClientFeatureFlags = Field(default_factory=_get_feature_flags)
providers_configured: list[ProviderType] = Field(
default_factory=_get_providers_configured
)
maintenance_start_time: datetime | None = Field(
default_factory=_get_maintenance_start_time
)
auth_url: str | None = Field(default_factory=_get_auth_url)
posthog_client_key: str | None = 'phc_3ESMmY9SgqEAGBB6sMGK5ayYHkeUuknH2vP6FmWH9RA'
feature_flags: WebClientFeatureFlags = Field(default_factory=WebClientFeatureFlags)
providers_configured: list[ProviderType] = Field(default_factory=list)
maintenance_start_time: datetime | None = None
auth_url: str | None = None
recaptcha_site_key: str | None = Field(default_factory=_get_recaptcha_site_key)
faulty_models: list[str] = Field(default_factory=list)
error_message: str | None = None
@@ -128,7 +36,7 @@ class DefaultWebClientConfigInjector(WebClientConfigInjector):
'new and should be displayed. (Default to start of 2026)'
),
)
github_app_slug: str | None = Field(default_factory=_get_github_app_slug)
github_app_slug: str | None = None
async def get_web_client_config(self) -> WebClientConfig:
from openhands.app_server.config import get_global_config

View File

@@ -1,6 +1,5 @@
import json
from datetime import datetime
from typing import Any
from json_repair import repair_json
from litellm.types.utils import ModelResponse
@@ -33,7 +32,7 @@ class OpenHandsJSONEncoder(json.JSONEncoder):
_json_encoder = OpenHandsJSONEncoder()
def dumps(obj, **kwargs) -> str:
def dumps(obj, **kwargs):
"""Serialize an object to str format"""
if not kwargs:
return _json_encoder.encode(obj)
@@ -48,7 +47,7 @@ def dumps(obj, **kwargs) -> str:
return json.dumps(obj, **encoder_kwargs)
def loads(json_str: str, **kwargs) -> Any:
def loads(json_str, **kwargs):
"""Create a JSON object from str"""
try:
return json.loads(json_str, **kwargs)

View File

@@ -198,14 +198,13 @@ class LLM(RetryMixin, DebugMixin):
if 'claude-opus-4-1' in self.config.model.lower():
kwargs['thinking'] = {'type': 'disabled'}
# Anthropic constraint: Opus 4.1, Opus 4.5, Opus 4.6, and Sonnet 4 models cannot accept both temperature and top_p
# Anthropic constraint: Opus 4.1, Opus 4.5, and Sonnet 4 models cannot accept both temperature and top_p
# Prefer temperature (drop top_p) if both are specified.
_model_lower = self.config.model.lower()
# Apply to Opus 4.1, Opus 4.5, Opus 4.6, and Sonnet 4 models to avoid API errors
# Apply to Opus 4.1, Opus 4.5, and Sonnet 4 models to avoid API errors
if (
('claude-opus-4-1' in _model_lower)
or ('claude-opus-4-5' in _model_lower)
or ('claude-opus-4-6' in _model_lower)
or ('claude-sonnet-4' in _model_lower)
) and ('temperature' in kwargs and 'top_p' in kwargs):
kwargs.pop('top_p', None)

View File

@@ -793,11 +793,7 @@ if __name__ == '__main__':
@app.middleware('http')
async def authenticate_requests(request: Request, call_next):
if (
request.url.path != '/alive'
and request.url.path != '/ready'
and request.url.path != '/server_info'
):
if request.url.path != '/alive' and request.url.path != '/server_info':
try:
verify_api_key(request.headers.get('X-Session-API-Key'))
except HTTPException as e:

View File

@@ -14,8 +14,6 @@ from openhands.runtime.utils.system_stats import get_system_info
def add_health_endpoints(app: FastAPI):
@app.get('/alive')
async def alive():
"""Endpoint for liveness probes. If this responds then the server is
considered alive."""
return {'status': 'ok'}
@app.get('/health')
@@ -25,11 +23,3 @@ def add_health_endpoints(app: FastAPI):
@app.get('/server_info')
async def get_server_info():
return get_system_info()
@app.get('/ready')
async def ready() -> str:
"""Endpoint for readiness probes. For now this is functionally the same as
the liveness probe, but should be need to establish further invariants in
the future, having a separate endpoint will mean we don't need to change
client code."""
return 'OK'

View File

@@ -32,15 +32,12 @@ CONVERSATION_URL = HOST + '/conversations/{}'
async def get_conversation_link(
service: GitService, conversation_id: str | None, body: str
service: GitService, conversation_id: str, body: str
) -> str:
"""Appends a followup link, in the PR body, to the OpenHands conversation that opened the PR"""
if server_config.app_mode != AppMode.SAAS:
return body
if not conversation_id:
return body
user = await service.get_user()
username = user.login
conversation_url = CONVERSATION_URL.format(conversation_id)

View File

@@ -0,0 +1,556 @@
#!/usr/bin/env python3
"""
Script to notify contributors when their code goes live in a release.
This script:
1. Identifies all contributors between two release tags
2. Gathers their commit information
3. Can send email notifications via Resend, SMTP, or other providers
4. Supports dry-run mode to preview notifications
Usage:
# Preview contributors for a release
python scripts/notify_release_contributors.py --from-tag 1.2.1 --to-tag 1.3.0 --dry-run
# Send email notifications via Resend (recommended)
python scripts/notify_release_contributors.py --from-tag 1.2.1 --to-tag 1.3.0 \
--email-provider resend
# Send via SMTP
python scripts/notify_release_contributors.py --from-tag 1.2.1 --to-tag 1.3.0 \
--email-provider smtp --smtp-host smtp.gmail.com --smtp-port 587 \
--smtp-user user@gmail.com --smtp-password $SMTP_PASSWORD
Environment Variables:
GITHUB_TOKEN: Required for fetching contributor info from GitHub API
RESEND_API_KEY: Required when using Resend provider
SMTP_PASSWORD: Can be used instead of --smtp-password flag
"""
import argparse
import json
import os
import re
import smtplib
import subprocess
import sys
from dataclasses import dataclass
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from typing import Any
@dataclass
class Contributor:
"""Represents a contributor to the release."""
username: str
email: str
name: str
commit_count: int
commits: list[dict[str, str]]
@property
def is_bot(self) -> bool:
return "[bot]" in self.username or self.username.endswith("-bot")
@property
def has_valid_email(self) -> bool:
"""Check if email is valid (not a GitHub noreply address)."""
if not self.email:
return False
# GitHub noreply emails are not useful for direct contact
# but we can still try to resolve actual emails via GitHub API
return "@users.noreply.github.com" not in self.email
def run_command(cmd: list[str], capture_output: bool = True) -> subprocess.CompletedProcess:
"""Run a shell command and return the result."""
return subprocess.run(cmd, capture_output=capture_output, text=True, check=True)
def get_github_token() -> str | None:
"""Get GitHub token from environment."""
return os.environ.get("GITHUB_TOKEN") or os.environ.get("GH_TOKEN")
def github_api_request(endpoint: str) -> dict[str, Any]:
"""Make a GitHub API request using gh CLI."""
token = get_github_token()
if not token:
raise ValueError("GITHUB_TOKEN environment variable is required")
env = os.environ.copy()
env["GH_TOKEN"] = token
result = subprocess.run(
["gh", "api", endpoint, "-H", "Accept: application/vnd.github+json"],
capture_output=True,
text=True,
env=env,
)
if result.returncode != 0:
raise RuntimeError(f"GitHub API request failed: {result.stderr}")
return json.loads(result.stdout)
def get_contributors_between_tags(
repo: str, from_tag: str, to_tag: str
) -> list[Contributor]:
"""
Get all contributors between two git tags using GitHub API.
Args:
repo: Repository in 'owner/repo' format
from_tag: Starting tag (exclusive)
to_tag: Ending tag (inclusive)
Returns:
List of Contributor objects
"""
# Use GitHub compare API to get commits between tags
endpoint = f"repos/{repo}/compare/{from_tag}...{to_tag}"
try:
data = github_api_request(endpoint)
except RuntimeError as e:
print(f"Warning: GitHub API request failed: {e}")
print("Falling back to git log...")
return get_contributors_via_git(from_tag, to_tag)
# Group commits by author
contributors_map: dict[str, Contributor] = {}
for commit_data in data.get("commits", []):
author = commit_data.get("author") or {}
commit_info = commit_data.get("commit", {})
author_info = commit_info.get("author", {})
username = author.get("login", "unknown")
email = author_info.get("email", "")
name = author_info.get("name", username)
sha = commit_data.get("sha", "")[:7]
message = commit_info.get("message", "").split("\n")[0] # First line only
if username not in contributors_map:
contributors_map[username] = Contributor(
username=username,
email=email,
name=name,
commit_count=0,
commits=[],
)
contributors_map[username].commit_count += 1
contributors_map[username].commits.append({"sha": sha, "message": message})
return list(contributors_map.values())
def get_contributors_via_git(from_tag: str, to_tag: str) -> list[Contributor]:
"""Fallback method using git log directly."""
result = run_command(
[
"git",
"--no-pager",
"log",
f"{from_tag}..{to_tag}",
"--format=%H|%ae|%an|%s",
"--no-merges",
]
)
contributors_map: dict[str, Contributor] = {}
for line in result.stdout.strip().split("\n"):
if not line:
continue
parts = line.split("|", 3)
if len(parts) < 4:
continue
sha, email, name, message = parts
username = email.split("@")[0]
# Extract username from GitHub noreply email
noreply_match = re.match(r"(\d+\+)?([^@]+)@users\.noreply\.github\.com", email)
if noreply_match:
username = noreply_match.group(2)
if username not in contributors_map:
contributors_map[username] = Contributor(
username=username,
email=email,
name=name,
commit_count=0,
commits=[],
)
contributors_map[username].commit_count += 1
contributors_map[username].commits.append({"sha": sha[:7], "message": message})
return list(contributors_map.values())
def resolve_user_email(username: str) -> str | None:
"""Try to resolve a user's public email from their GitHub profile."""
try:
user_data = github_api_request(f"users/{username}")
return user_data.get("email")
except Exception:
return None
def generate_email_content(
contributor: Contributor, release_tag: str, repo: str, release_url: str
) -> tuple[str, str, str]:
"""
Generate email subject and body (plain text and HTML).
Returns:
Tuple of (subject, plain_text_body, html_body)
"""
subject = f"🎉 Your code is live in {repo.split('/')[-1]} {release_tag}!"
# List of commits (limit to 10 for readability)
commits_display = contributor.commits[:10]
commits_text = "\n".join(
f"{c['sha']}: {c['message'][:60]}..." if len(c["message"]) > 60 else f"{c['sha']}: {c['message']}"
for c in commits_display
)
if len(contributor.commits) > 10:
commits_text += f"\n ... and {len(contributor.commits) - 10} more commits"
plain_text = f"""Hi {contributor.name},
Great news! Your contributions have been included in {repo.split('/')[-1]} {release_tag}, which is now live in production! 🚀
Your contributions ({contributor.commit_count} commit{'s' if contributor.commit_count > 1 else ''}):
{commits_text}
View the full release notes: {release_url}
Please take a moment to verify that your changes are working as intended in production. If you notice any issues or unexpected behavior, we'd appreciate it if you could submit follow-up fixes as needed.
Thank you for your valuable contributions to the project!
Best regards,
The OpenHands Team
"""
commits_html = "".join(
f"<li><code>{c['sha']}</code>: {c['message'][:60]}{'...' if len(c['message']) > 60 else ''}</li>"
for c in commits_display
)
if len(contributor.commits) > 10:
commits_html += f"<li><em>... and {len(contributor.commits) - 10} more commits</em></li>"
html_body = f"""
<!DOCTYPE html>
<html>
<head>
<style>
body {{ font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; line-height: 1.6; color: #333; }}
.container {{ max-width: 600px; margin: 0 auto; padding: 20px; }}
.header {{ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; padding: 30px; border-radius: 10px 10px 0 0; }}
.content {{ background: #f9fafb; padding: 30px; border-radius: 0 0 10px 10px; }}
.commits {{ background: white; padding: 15px; border-radius: 8px; margin: 20px 0; }}
.commits ul {{ margin: 0; padding-left: 20px; }}
.commits li {{ margin: 8px 0; }}
.cta {{ display: inline-block; background: #667eea; color: white; padding: 12px 24px; text-decoration: none; border-radius: 6px; margin-top: 20px; }}
.notice {{ background: #fef3c7; border-left: 4px solid #f59e0b; padding: 12px 16px; margin: 20px 0; border-radius: 0 8px 8px 0; }}
code {{ background: #e5e7eb; padding: 2px 6px; border-radius: 4px; font-size: 0.9em; }}
</style>
</head>
<body>
<div class="container">
<div class="header">
<h1 style="margin: 0;">🎉 Your Code is Live!</h1>
<p style="margin: 10px 0 0 0; opacity: 0.9;">{repo.split('/')[-1]} {release_tag}</p>
</div>
<div class="content">
<p>Hi {contributor.name},</p>
<p>Great news! Your contributions have been included in <strong>{repo.split('/')[-1]} {release_tag}</strong>, which is now live in production! 🚀</p>
<div class="commits">
<strong>Your contributions ({contributor.commit_count} commit{'s' if contributor.commit_count > 1 else ''}):</strong>
<ul>{commits_html}</ul>
</div>
<div class="notice">
<strong>🔍 Please verify your changes</strong><br>
Take a moment to check that your changes are working as intended in production. If you notice any issues or unexpected behavior, please submit follow-up fixes as needed.
</div>
<p>Thank you for your valuable contributions to the project!</p>
<a href="{release_url}" class="cta">View Release Notes →</a>
<p style="margin-top: 30px; color: #6b7280; font-size: 0.9em;">
Best regards,<br>
The OpenHands Team
</p>
</div>
</div>
</body>
</html>
"""
return subject, plain_text, html_body
def send_email_smtp(
to_email: str,
subject: str,
plain_text: str,
html_body: str,
smtp_host: str,
smtp_port: int,
smtp_user: str,
smtp_password: str,
from_email: str,
) -> bool:
"""Send email via SMTP."""
msg = MIMEMultipart("alternative")
msg["Subject"] = subject
msg["From"] = from_email
msg["To"] = to_email
msg.attach(MIMEText(plain_text, "plain"))
msg.attach(MIMEText(html_body, "html"))
try:
with smtplib.SMTP(smtp_host, smtp_port) as server:
server.starttls()
server.login(smtp_user, smtp_password)
server.sendmail(from_email, [to_email], msg.as_string())
return True
except Exception as e:
print(f" Error sending email to {to_email}: {e}")
return False
def send_email_resend(
to_email: str,
subject: str,
plain_text: str,
html_body: str,
api_key: str,
from_email: str,
) -> bool:
"""Send email via Resend API."""
try:
import resend
except ImportError:
print("Error: 'resend' package required. Install with: pip install resend")
return False
resend.api_key = api_key
try:
resend.Emails.send({
"from": from_email,
"to": [to_email],
"subject": subject,
"html": html_body,
"text": plain_text,
})
return True
except Exception as e:
print(f" Error sending email to {to_email}: {e}")
return False
def main():
parser = argparse.ArgumentParser(
description="Notify contributors when their code goes live in a release"
)
parser.add_argument(
"--repo",
default="OpenHands/OpenHands",
help="GitHub repository (owner/repo format)",
)
parser.add_argument("--from-tag", required=True, help="Starting tag (exclusive)")
parser.add_argument("--to-tag", required=True, help="Ending tag (inclusive)")
parser.add_argument(
"--dry-run",
action="store_true",
help="Preview notifications without sending emails",
)
parser.add_argument(
"--include-bots",
action="store_true",
help="Include bot accounts in notifications",
)
parser.add_argument(
"--resolve-emails",
action="store_true",
help="Try to resolve emails from GitHub profiles for noreply addresses",
)
parser.add_argument(
"--output-json",
type=str,
help="Output contributor data to JSON file",
)
# Email configuration
parser.add_argument(
"--email-provider",
choices=["smtp", "resend", "none"],
default="none",
help="Email provider to use",
)
parser.add_argument(
"--from-email",
default="OpenHands Team <contact@all-hands.dev>",
help="Sender email address",
)
# SMTP settings
parser.add_argument("--smtp-host", help="SMTP server hostname")
parser.add_argument("--smtp-port", type=int, default=587, help="SMTP server port")
parser.add_argument("--smtp-user", help="SMTP username")
parser.add_argument("--smtp-password", help="SMTP password (or use SMTP_PASSWORD env)")
# Resend settings
parser.add_argument(
"--resend-api-key",
help="Resend API key (or use RESEND_API_KEY env)",
)
args = parser.parse_args()
print(f"🔍 Finding contributors between {args.from_tag} and {args.to_tag}...")
# Get contributors
contributors = get_contributors_between_tags(args.repo, args.from_tag, args.to_tag)
# Filter bots if needed
if not args.include_bots:
bots = [c for c in contributors if c.is_bot]
contributors = [c for c in contributors if not c.is_bot]
if bots:
print(f" Excluding {len(bots)} bot account(s): {', '.join(b.username for b in bots)}")
# Sort by commit count
contributors.sort(key=lambda c: c.commit_count, reverse=True)
print(f"\n📊 Found {len(contributors)} contributor(s):\n")
for c in contributors:
email_status = "" if c.has_valid_email else "⚠ noreply"
print(f" {c.name} (@{c.username})")
print(f" Email: {c.email} [{email_status}]")
print(f" Commits: {c.commit_count}")
print()
# Try to resolve emails for noreply addresses
if args.resolve_emails:
print("🔎 Resolving emails from GitHub profiles...")
for c in contributors:
if not c.has_valid_email:
resolved = resolve_user_email(c.username)
if resolved:
print(f" ✓ Resolved {c.username}: {resolved}")
c.email = resolved
else:
print(f" ✗ Could not resolve email for {c.username}")
# Output JSON if requested
if args.output_json:
output_data = {
"release": args.to_tag,
"previous_release": args.from_tag,
"repository": args.repo,
"contributors": [
{
"username": c.username,
"email": c.email,
"name": c.name,
"commit_count": c.commit_count,
"has_valid_email": c.has_valid_email,
"commits": c.commits,
}
for c in contributors
],
}
with open(args.output_json, "w") as f:
json.dump(output_data, f, indent=2)
print(f"\n📁 Contributor data saved to {args.output_json}")
# Generate release URL
release_url = f"https://github.com/{args.repo}/releases/tag/{args.to_tag}"
# Send emails if not dry run
if args.dry_run:
print("\n📧 DRY RUN - Email preview:\n")
for c in contributors:
if c.has_valid_email:
subject, plain_text, _ = generate_email_content(
c, args.to_tag, args.repo, release_url
)
print(f"To: {c.email}")
print(f"Subject: {subject}")
print("-" * 40)
print(plain_text[:500] + "..." if len(plain_text) > 500 else plain_text)
print("=" * 60 + "\n")
elif args.email_provider != "none":
print(f"\n📧 Sending notifications via {args.email_provider}...")
# Filter to only those with valid emails
notifiable = [c for c in contributors if c.has_valid_email]
if len(notifiable) < len(contributors):
print(
f" ⚠ Skipping {len(contributors) - len(notifiable)} contributor(s) with noreply emails"
)
sent_count = 0
failed_count = 0
for c in notifiable:
subject, plain_text, html_body = generate_email_content(
c, args.to_tag, args.repo, release_url
)
success = False
if args.email_provider == "smtp":
password = args.smtp_password or os.environ.get("SMTP_PASSWORD")
if not all([args.smtp_host, args.smtp_user, password]):
print("Error: SMTP requires --smtp-host, --smtp-user, and --smtp-password")
sys.exit(1)
success = send_email_smtp(
c.email,
subject,
plain_text,
html_body,
args.smtp_host,
args.smtp_port,
args.smtp_user,
password,
args.from_email,
)
elif args.email_provider == "resend":
api_key = args.resend_api_key or os.environ.get("RESEND_API_KEY")
if not api_key:
print("Error: Resend requires --resend-api-key or RESEND_API_KEY env")
sys.exit(1)
success = send_email_resend(
c.email, subject, plain_text, html_body, api_key, args.from_email
)
if success:
print(f" ✓ Sent to {c.name} <{c.email}>")
sent_count += 1
else:
print(f" ✗ Failed to send to {c.name} <{c.email}>")
failed_count += 1
print(f"\n📊 Summary: {sent_count} sent, {failed_count} failed")
else:
print("\n💡 To send emails, use --email-provider (smtp or sendgrid)")
print(" Or use --dry-run to preview email content")
if __name__ == "__main__":
main()

View File

@@ -6,13 +6,10 @@ then ensures the GitHub token is set in Settings → Integrations and that the
home screen shows the repository selector.
"""
import logging
import os
from playwright.sync_api import Page, expect
logger = logging.getLogger(__name__)
def test_github_token_configuration(page: Page, base_url: str):
"""
@@ -31,51 +28,51 @@ def test_github_token_configuration(page: Page, base_url: str):
base_url = 'http://localhost:12000'
# Navigate to the OpenHands application
logger.info(f'Step 1: Navigating to OpenHands application at {base_url}...')
print(f'Step 1: Navigating to OpenHands application at {base_url}...')
page.goto(base_url)
page.wait_for_load_state('networkidle', timeout=30000)
# Take initial screenshot
page.screenshot(path='test-results/token_01_initial_load.png')
logger.info('Screenshot saved: token_01_initial_load.png')
print('Screenshot saved: token_01_initial_load.png')
# Step 1.5: Handle any initial modals that might appear (LLM API key configuration)
try:
# Check for AI Provider Configuration modal
config_modal = page.locator('text=AI Provider Configuration')
if config_modal.is_visible(timeout=5000):
logger.info('AI Provider Configuration modal detected')
print('AI Provider Configuration modal detected')
# Fill in the LLM API key if available
llm_api_key_input = page.locator('[data-testid="llm-api-key-input"]')
if llm_api_key_input.is_visible(timeout=3000):
llm_api_key = os.getenv('LLM_API_KEY', 'test-key')
llm_api_key_input.fill(llm_api_key)
logger.info(f'Filled LLM API key (length: {len(llm_api_key)})')
print(f'Filled LLM API key (length: {len(llm_api_key)})')
# Click the Save button
save_button = page.locator('button:has-text("Save")')
if save_button.is_visible(timeout=3000):
save_button.click()
page.wait_for_timeout(2000)
logger.info('Saved LLM API key configuration')
print('Saved LLM API key configuration')
# Check for Privacy Preferences modal
privacy_modal = page.locator('text=Your Privacy Preferences')
if privacy_modal.is_visible(timeout=5000):
logger.info('Privacy Preferences modal detected')
print('Privacy Preferences modal detected')
confirm_button = page.locator('button:has-text("Confirm Preferences")')
if confirm_button.is_visible(timeout=3000):
confirm_button.click()
page.wait_for_timeout(2000)
logger.info('Confirmed privacy preferences')
print('Confirmed privacy preferences')
except Exception as e:
logger.error(f'Error handling initial modals: {e}')
print(f'Error handling initial modals: {e}')
page.screenshot(path='test-results/token_01_5_modal_error.png')
logger.info('Screenshot saved: token_01_5_modal_error.png')
print('Screenshot saved: token_01_5_modal_error.png')
# Step 2: Check if GitHub token is already configured or needs to be set
logger.info('Step 2: Checking if GitHub token is configured...')
print('Step 2: Checking if GitHub token is configured...')
try:
# First, check if we're already on the home screen with repository selection
@@ -83,7 +80,7 @@ def test_github_token_configuration(page: Page, base_url: str):
connect_to_provider = page.locator('text=Connect to a Repository')
if connect_to_provider.is_visible(timeout=3000):
logger.info('Found "Connect to a Repository" section')
print('Found "Connect to a Repository" section')
# Check if we need to configure a provider (GitHub token)
navigate_to_settings_button = page.locator(
@@ -91,9 +88,7 @@ def test_github_token_configuration(page: Page, base_url: str):
)
if navigate_to_settings_button.is_visible(timeout=3000):
logger.info(
'GitHub token not configured. Need to navigate to settings...'
)
print('GitHub token not configured. Need to navigate to settings...')
# Click the Settings button to navigate to the settings page
navigate_to_settings_button.click()
@@ -101,21 +96,19 @@ def test_github_token_configuration(page: Page, base_url: str):
page.wait_for_timeout(3000) # Wait for navigation to complete
# We should now be on the /settings/integrations page
logger.info(
'Navigated to settings page, looking for GitHub token input...'
)
print('Navigated to settings page, looking for GitHub token input...')
# Check if we're on the settings page with the integrations tab
settings_screen = page.locator('[data-testid="settings-screen"]')
if settings_screen.is_visible(timeout=5000):
logger.info('Settings screen is visible')
print('Settings screen is visible')
# Make sure we're on the Integrations tab
integrations_tab = page.locator('text=Integrations')
if integrations_tab.is_visible(timeout=3000):
# Check if we need to click the tab
if not page.url.endswith('/settings/integrations'):
logger.info('Clicking Integrations tab...')
print('Clicking Integrations tab...')
integrations_tab.click()
page.wait_for_load_state('networkidle')
page.wait_for_timeout(2000)
@@ -125,7 +118,7 @@ def test_github_token_configuration(page: Page, base_url: str):
'[data-testid="github-token-input"]'
)
if github_token_input.is_visible(timeout=5000):
logger.info('Found GitHub token input field')
print('Found GitHub token input field')
# Fill in the GitHub token from environment variable
github_token = os.getenv('GITHUB_TOKEN', '')
@@ -133,18 +126,18 @@ def test_github_token_configuration(page: Page, base_url: str):
# Clear the field first, then fill it
github_token_input.clear()
github_token_input.fill(github_token)
logger.info(
print(
f'Filled GitHub token from environment variable (length: {len(github_token)})'
)
# Verify the token was filled
filled_value = github_token_input.input_value()
if filled_value:
logger.info(
print(
f'Token field now contains value of length: {len(filled_value)}'
)
else:
logger.warning(
print(
'WARNING: Token field appears to be empty after filling'
)
@@ -153,12 +146,12 @@ def test_github_token_configuration(page: Page, base_url: str):
if save_button.is_visible(timeout=3000):
# Check if button is enabled
is_disabled = save_button.is_disabled()
logger.info(
print(
f'Save Changes button found, disabled: {is_disabled}'
)
if not is_disabled:
logger.info('Clicking Save Changes button...')
print('Clicking Save Changes button...')
save_button.click()
# Wait for the save operation to complete
@@ -171,52 +164,46 @@ def test_github_token_configuration(page: Page, base_url: str):
'document.querySelector(\'[data-testid="submit-button"]\').disabled === true',
timeout=10000,
)
logger.info(
print(
'Save operation completed - form is now clean'
)
except Exception:
logger.warning(
print(
'Save operation completed (timeout waiting for form clean state)'
)
# Navigate back to home page after successful save
logger.info('Navigating back to home page...')
print('Navigating back to home page...')
page.goto(base_url)
page.wait_for_load_state('networkidle')
page.wait_for_timeout(
5000
) # Wait longer for providers to be updated
else:
logger.warning(
print(
'Save Changes button is disabled - form may be invalid'
)
else:
logger.warning('Save Changes button not found')
print('Save Changes button not found')
else:
logger.warning(
'No GitHub token found in environment variables'
)
print('No GitHub token found in environment variables')
else:
logger.warning(
'GitHub token input field not found on settings page'
)
print('GitHub token input field not found on settings page')
# Take a screenshot to see what's on the page
page.screenshot(path='test-results/token_02_settings_debug.png')
logger.info(
'Debug screenshot saved: token_02_settings_debug.png'
)
print('Debug screenshot saved: token_02_settings_debug.png')
else:
logger.warning('Settings screen not found')
print('Settings screen not found')
else:
# Branch 2: GitHub token is already configured, repository selection is available
logger.info(
print(
'GitHub token is already configured, repository selection is available'
)
# Check if we need to update the token by going to settings manually
settings_button = page.locator('button:has-text("Settings")')
if settings_button.is_visible(timeout=3000):
logger.info(
print(
'Settings button found, clicking to navigate to settings page...'
)
settings_button.click()
@@ -226,7 +213,7 @@ def test_github_token_configuration(page: Page, base_url: str):
# Navigate to the Integrations tab
integrations_tab = page.locator('text=Integrations')
if integrations_tab.is_visible(timeout=3000):
logger.info('Clicking Integrations tab...')
print('Clicking Integrations tab...')
integrations_tab.click()
page.wait_for_load_state('networkidle')
page.wait_for_timeout(2000)
@@ -236,7 +223,7 @@ def test_github_token_configuration(page: Page, base_url: str):
'[data-testid="github-token-input"]'
)
if github_token_input.is_visible(timeout=5000):
logger.info('Found GitHub token input field')
print('Found GitHub token input field')
# Fill in the GitHub token from environment variable
github_token = os.getenv('GITHUB_TOKEN', '')
@@ -244,7 +231,7 @@ def test_github_token_configuration(page: Page, base_url: str):
# Clear the field first, then fill it
github_token_input.clear()
github_token_input.fill(github_token)
logger.info(
print(
f'Filled GitHub token from environment variable (length: {len(github_token)})'
)
@@ -256,56 +243,52 @@ def test_github_token_configuration(page: Page, base_url: str):
save_button.is_visible(timeout=3000)
and not save_button.is_disabled()
):
logger.info('Clicking Save Changes button...')
print('Clicking Save Changes button...')
save_button.click()
page.wait_for_timeout(3000)
# Navigate back to home page
logger.info('Navigating back to home page...')
print('Navigating back to home page...')
page.goto(base_url)
page.wait_for_load_state('networkidle')
page.wait_for_timeout(3000)
else:
logger.warning(
print(
'GitHub token input field not found, going back to home page'
)
page.goto(base_url)
page.wait_for_load_state('networkidle')
else:
logger.warning(
'Integrations tab not found, going back to home page'
)
print('Integrations tab not found, going back to home page')
page.goto(base_url)
page.wait_for_load_state('networkidle')
else:
logger.info(
'Settings button not found, continuing with existing token'
)
print('Settings button not found, continuing with existing token')
else:
logger.warning('Could not find "Connect to a Repository" section')
print('Could not find "Connect to a Repository" section')
page.screenshot(path='test-results/token_03_after_settings.png')
logger.info('Screenshot saved: token_03_after_settings.png')
print('Screenshot saved: token_03_after_settings.png')
except Exception as e:
logger.error(f'Error checking GitHub token configuration: {e}')
print(f'Error checking GitHub token configuration: {e}')
page.screenshot(path='test-results/token_04_error.png')
logger.info('Screenshot saved: token_04_error.png')
print('Screenshot saved: token_04_error.png')
# Step 3: Verify we're back on the home screen with repository selection available
logger.info('Step 3: Verifying repository selection is available...')
print('Step 3: Verifying repository selection is available...')
# Wait for the home screen to load
home_screen = page.locator('[data-testid="home-screen"]')
expect(home_screen).to_be_visible(timeout=15000)
logger.info('Home screen is visible')
print('Home screen is visible')
# Look for the repository dropdown/selector
repo_dropdown = page.locator('[data-testid="repo-dropdown"]')
expect(repo_dropdown).to_be_visible(timeout=15000)
logger.info('Repository dropdown is visible')
print('Repository dropdown is visible')
# Success - we've verified the GitHub token configuration
logger.info('GitHub token configuration verified successfully')
print('GitHub token configuration verified successfully')
page.screenshot(path='test-results/token_05_success.png')
logger.info('Screenshot saved: token_05_success.png')
print('Screenshot saved: token_05_success.png')

View File

@@ -1,468 +0,0 @@
"""Tests for DefaultWebClientConfigInjector.
This module tests environment variable handling in DefaultWebClientConfigInjector.
"""
import os
from unittest.mock import patch
class TestGetPosthogClientKey:
"""Test cases for _get_posthog_client_key helper function."""
OSS_DEFAULT_KEY = 'phc_3ESMmY9SgqEAGBB6sMGK5ayYHkeUuknH2vP6FmWH9RA'
def test_returns_env_var_when_set(self):
"""When POSTHOG_CLIENT_KEY is set, return that value."""
from openhands.app_server.web_client.default_web_client_config_injector import (
_get_posthog_client_key,
)
with patch.dict(os.environ, {'POSTHOG_CLIENT_KEY': 'phc_saas_key_123'}):
result = _get_posthog_client_key()
assert result == 'phc_saas_key_123'
def test_returns_oss_default_when_env_var_unset(self):
"""When POSTHOG_CLIENT_KEY is not set, return the OSS default key."""
from openhands.app_server.web_client.default_web_client_config_injector import (
_get_posthog_client_key,
)
with patch.dict(os.environ, {}, clear=True):
# Ensure POSTHOG_CLIENT_KEY is not in environment
os.environ.pop('POSTHOG_CLIENT_KEY', None)
result = _get_posthog_client_key()
assert result == self.OSS_DEFAULT_KEY
def test_returns_oss_default_when_env_var_empty(self):
"""When POSTHOG_CLIENT_KEY is empty string, return the OSS default key."""
from openhands.app_server.web_client.default_web_client_config_injector import (
_get_posthog_client_key,
)
with patch.dict(os.environ, {'POSTHOG_CLIENT_KEY': ''}):
result = _get_posthog_client_key()
assert result == self.OSS_DEFAULT_KEY
def test_strips_whitespace_from_env_var(self):
"""When POSTHOG_CLIENT_KEY has whitespace, strip it."""
from openhands.app_server.web_client.default_web_client_config_injector import (
_get_posthog_client_key,
)
with patch.dict(os.environ, {'POSTHOG_CLIENT_KEY': ' phc_trimmed_key '}):
result = _get_posthog_client_key()
assert result == 'phc_trimmed_key'
def test_returns_oss_default_when_env_var_only_whitespace(self):
"""When POSTHOG_CLIENT_KEY is only whitespace, return the OSS default key."""
from openhands.app_server.web_client.default_web_client_config_injector import (
_get_posthog_client_key,
)
with patch.dict(os.environ, {'POSTHOG_CLIENT_KEY': ' '}):
result = _get_posthog_client_key()
assert result == self.OSS_DEFAULT_KEY
class TestGetAuthUrl:
"""Test cases for _get_auth_url helper function."""
def test_returns_env_var_when_set(self):
"""When AUTH_URL is set, return that value."""
from openhands.app_server.web_client.default_web_client_config_injector import (
_get_auth_url,
)
with patch.dict(os.environ, {'AUTH_URL': 'https://auth.example.com'}):
result = _get_auth_url()
assert result == 'https://auth.example.com'
def test_returns_none_when_env_var_unset(self):
"""When AUTH_URL is not set, return None."""
from openhands.app_server.web_client.default_web_client_config_injector import (
_get_auth_url,
)
with patch.dict(os.environ, {}, clear=True):
os.environ.pop('AUTH_URL', None)
result = _get_auth_url()
assert result is None
def test_returns_none_when_env_var_empty(self):
"""When AUTH_URL is empty string, return None."""
from openhands.app_server.web_client.default_web_client_config_injector import (
_get_auth_url,
)
with patch.dict(os.environ, {'AUTH_URL': ''}):
result = _get_auth_url()
assert result is None
def test_strips_whitespace_from_env_var(self):
"""When AUTH_URL has whitespace, strip it."""
from openhands.app_server.web_client.default_web_client_config_injector import (
_get_auth_url,
)
with patch.dict(os.environ, {'AUTH_URL': ' https://auth.example.com '}):
result = _get_auth_url()
assert result == 'https://auth.example.com'
def test_returns_none_when_env_var_only_whitespace(self):
"""When AUTH_URL is only whitespace, return None."""
from openhands.app_server.web_client.default_web_client_config_injector import (
_get_auth_url,
)
with patch.dict(os.environ, {'AUTH_URL': ' '}):
result = _get_auth_url()
assert result is None
class TestGetFeatureFlags:
"""Test cases for _get_feature_flags helper function."""
def test_returns_all_false_when_no_env_vars_set(self):
"""When no feature flag env vars are set, all flags default to False."""
from openhands.app_server.web_client.default_web_client_config_injector import (
_get_feature_flags,
)
with patch.dict(os.environ, {}, clear=True):
# Remove any existing feature flag env vars
for var in [
'ENABLE_BILLING',
'HIDE_LLM_SETTINGS',
'ENABLE_JIRA',
'ENABLE_JIRA_DC',
'ENABLE_LINEAR',
]:
os.environ.pop(var, None)
result = _get_feature_flags()
assert result.enable_billing is False
assert result.hide_llm_settings is False
assert result.enable_jira is False
assert result.enable_jira_dc is False
assert result.enable_linear is False
def test_enable_billing_true_when_env_var_true(self):
"""When ENABLE_BILLING is 'true', enable_billing flag is True."""
from openhands.app_server.web_client.default_web_client_config_injector import (
_get_feature_flags,
)
with patch.dict(os.environ, {'ENABLE_BILLING': 'true'}):
result = _get_feature_flags()
assert result.enable_billing is True
def test_enable_billing_false_when_env_var_false(self):
"""When ENABLE_BILLING is 'false', enable_billing flag is False."""
from openhands.app_server.web_client.default_web_client_config_injector import (
_get_feature_flags,
)
with patch.dict(os.environ, {'ENABLE_BILLING': 'false'}):
result = _get_feature_flags()
assert result.enable_billing is False
def test_enable_billing_false_when_env_var_other_value(self):
"""When ENABLE_BILLING is any value other than 'true', enable_billing is False."""
from openhands.app_server.web_client.default_web_client_config_injector import (
_get_feature_flags,
)
with patch.dict(os.environ, {'ENABLE_BILLING': 'yes'}):
result = _get_feature_flags()
assert result.enable_billing is False
def test_hide_llm_settings_true_when_env_var_true(self):
"""When HIDE_LLM_SETTINGS is 'true', hide_llm_settings flag is True."""
from openhands.app_server.web_client.default_web_client_config_injector import (
_get_feature_flags,
)
with patch.dict(os.environ, {'HIDE_LLM_SETTINGS': 'true'}):
result = _get_feature_flags()
assert result.hide_llm_settings is True
def test_enable_jira_true_when_env_var_true(self):
"""When ENABLE_JIRA is 'true', enable_jira flag is True."""
from openhands.app_server.web_client.default_web_client_config_injector import (
_get_feature_flags,
)
with patch.dict(os.environ, {'ENABLE_JIRA': 'true'}):
result = _get_feature_flags()
assert result.enable_jira is True
def test_enable_jira_dc_true_when_env_var_true(self):
"""When ENABLE_JIRA_DC is 'true', enable_jira_dc flag is True."""
from openhands.app_server.web_client.default_web_client_config_injector import (
_get_feature_flags,
)
with patch.dict(os.environ, {'ENABLE_JIRA_DC': 'true'}):
result = _get_feature_flags()
assert result.enable_jira_dc is True
def test_enable_linear_true_when_env_var_true(self):
"""When ENABLE_LINEAR is 'true', enable_linear flag is True."""
from openhands.app_server.web_client.default_web_client_config_injector import (
_get_feature_flags,
)
with patch.dict(os.environ, {'ENABLE_LINEAR': 'true'}):
result = _get_feature_flags()
assert result.enable_linear is True
def test_multiple_flags_can_be_set(self):
"""Multiple feature flags can be enabled simultaneously."""
from openhands.app_server.web_client.default_web_client_config_injector import (
_get_feature_flags,
)
with patch.dict(
os.environ,
{
'ENABLE_BILLING': 'true',
'HIDE_LLM_SETTINGS': 'true',
'ENABLE_JIRA': 'false',
'ENABLE_LINEAR': 'true',
},
):
result = _get_feature_flags()
assert result.enable_billing is True
assert result.hide_llm_settings is True
assert result.enable_jira is False
assert result.enable_jira_dc is False
assert result.enable_linear is True
class TestGetMaintenanceStartTime:
"""Test cases for _get_maintenance_start_time helper function."""
def test_returns_datetime_when_valid_iso_timestamp_set(self):
"""When MAINTENANCE_START_TIME is a valid ISO 8601 timestamp, return parsed datetime."""
from datetime import datetime, timezone
from openhands.app_server.web_client.default_web_client_config_injector import (
_get_maintenance_start_time,
)
with patch.dict(os.environ, {'MAINTENANCE_START_TIME': '2026-03-15T10:00:00Z'}):
result = _get_maintenance_start_time()
assert result == datetime(2026, 3, 15, 10, 0, 0, tzinfo=timezone.utc)
def test_returns_none_when_env_var_unset(self):
"""When MAINTENANCE_START_TIME is not set, return None."""
from openhands.app_server.web_client.default_web_client_config_injector import (
_get_maintenance_start_time,
)
with patch.dict(os.environ, {}, clear=True):
os.environ.pop('MAINTENANCE_START_TIME', None)
result = _get_maintenance_start_time()
assert result is None
def test_returns_none_when_env_var_empty(self):
"""When MAINTENANCE_START_TIME is empty string, return None."""
from openhands.app_server.web_client.default_web_client_config_injector import (
_get_maintenance_start_time,
)
with patch.dict(os.environ, {'MAINTENANCE_START_TIME': ''}):
result = _get_maintenance_start_time()
assert result is None
def test_returns_none_when_env_var_invalid(self):
"""When MAINTENANCE_START_TIME is invalid format, return None (graceful fallback)."""
from openhands.app_server.web_client.default_web_client_config_injector import (
_get_maintenance_start_time,
)
with patch.dict(
os.environ, {'MAINTENANCE_START_TIME': 'not-a-valid-timestamp'}
):
result = _get_maintenance_start_time()
assert result is None
def test_strips_whitespace_from_env_var(self):
"""When MAINTENANCE_START_TIME has whitespace, strip it before parsing."""
from datetime import datetime, timezone
from openhands.app_server.web_client.default_web_client_config_injector import (
_get_maintenance_start_time,
)
with patch.dict(
os.environ, {'MAINTENANCE_START_TIME': ' 2026-03-15T10:00:00Z '}
):
result = _get_maintenance_start_time()
assert result == datetime(2026, 3, 15, 10, 0, 0, tzinfo=timezone.utc)
class TestGetProvidersConfigured:
"""Test cases for _get_providers_configured helper function."""
def test_returns_empty_list_when_no_env_vars_set(self):
"""When no provider env vars are set, return empty list."""
from openhands.app_server.web_client.default_web_client_config_injector import (
_get_providers_configured,
)
with patch.dict(os.environ, {}, clear=True):
# Remove any existing provider env vars
for var in [
'GITHUB_APP_CLIENT_ID',
'GITLAB_APP_CLIENT_ID',
'BITBUCKET_APP_CLIENT_ID',
'ENABLE_ENTERPRISE_SSO',
]:
os.environ.pop(var, None)
result = _get_providers_configured()
assert result == []
def test_includes_github_when_client_id_set(self):
"""When GITHUB_APP_CLIENT_ID is set, include GitHub in providers."""
from openhands.app_server.web_client.default_web_client_config_injector import (
_get_providers_configured,
)
from openhands.integrations.service_types import ProviderType
with patch.dict(os.environ, {'GITHUB_APP_CLIENT_ID': 'some-client-id'}):
result = _get_providers_configured()
assert ProviderType.GITHUB in result
def test_includes_gitlab_when_client_id_set(self):
"""When GITLAB_APP_CLIENT_ID is set, include GitLab in providers."""
from openhands.app_server.web_client.default_web_client_config_injector import (
_get_providers_configured,
)
from openhands.integrations.service_types import ProviderType
with patch.dict(os.environ, {'GITLAB_APP_CLIENT_ID': 'some-client-id'}):
result = _get_providers_configured()
assert ProviderType.GITLAB in result
def test_includes_bitbucket_when_client_id_set(self):
"""When BITBUCKET_APP_CLIENT_ID is set, include Bitbucket in providers."""
from openhands.app_server.web_client.default_web_client_config_injector import (
_get_providers_configured,
)
from openhands.integrations.service_types import ProviderType
with patch.dict(os.environ, {'BITBUCKET_APP_CLIENT_ID': 'some-client-id'}):
result = _get_providers_configured()
assert ProviderType.BITBUCKET in result
def test_includes_enterprise_sso_when_enabled(self):
"""When ENABLE_ENTERPRISE_SSO is set, include Enterprise SSO in providers."""
from openhands.app_server.web_client.default_web_client_config_injector import (
_get_providers_configured,
)
from openhands.integrations.service_types import ProviderType
with patch.dict(os.environ, {'ENABLE_ENTERPRISE_SSO': 'true'}):
result = _get_providers_configured()
assert ProviderType.ENTERPRISE_SSO in result
def test_excludes_provider_when_env_var_empty(self):
"""When env var is empty string, do not include provider."""
from openhands.app_server.web_client.default_web_client_config_injector import (
_get_providers_configured,
)
from openhands.integrations.service_types import ProviderType
with patch.dict(os.environ, {'GITHUB_APP_CLIENT_ID': ''}):
result = _get_providers_configured()
assert ProviderType.GITHUB not in result
def test_excludes_provider_when_env_var_only_whitespace(self):
"""When env var is only whitespace, do not include provider."""
from openhands.app_server.web_client.default_web_client_config_injector import (
_get_providers_configured,
)
from openhands.integrations.service_types import ProviderType
with patch.dict(os.environ, {'GITHUB_APP_CLIENT_ID': ' '}):
result = _get_providers_configured()
assert ProviderType.GITHUB not in result
def test_includes_multiple_providers(self):
"""Multiple providers can be configured simultaneously."""
from openhands.app_server.web_client.default_web_client_config_injector import (
_get_providers_configured,
)
from openhands.integrations.service_types import ProviderType
with patch.dict(
os.environ,
{
'GITHUB_APP_CLIENT_ID': 'github-id',
'GITLAB_APP_CLIENT_ID': 'gitlab-id',
'BITBUCKET_APP_CLIENT_ID': '',
'ENABLE_ENTERPRISE_SSO': 'enabled',
},
):
result = _get_providers_configured()
assert ProviderType.GITHUB in result
assert ProviderType.GITLAB in result
assert ProviderType.BITBUCKET not in result
assert ProviderType.ENTERPRISE_SSO in result
assert len(result) == 3
class TestGetGithubAppSlug:
"""Test cases for _get_github_app_slug helper function."""
def test_returns_env_var_when_set(self):
"""When GITHUB_APP_SLUG is set, return that value."""
from openhands.app_server.web_client.default_web_client_config_injector import (
_get_github_app_slug,
)
with patch.dict(os.environ, {'GITHUB_APP_SLUG': 'openhands-app'}):
result = _get_github_app_slug()
assert result == 'openhands-app'
def test_returns_none_when_env_var_unset(self):
"""When GITHUB_APP_SLUG is not set, return None."""
from openhands.app_server.web_client.default_web_client_config_injector import (
_get_github_app_slug,
)
with patch.dict(os.environ, {}, clear=True):
os.environ.pop('GITHUB_APP_SLUG', None)
result = _get_github_app_slug()
assert result is None
def test_returns_none_when_env_var_empty(self):
"""When GITHUB_APP_SLUG is empty string, return None."""
from openhands.app_server.web_client.default_web_client_config_injector import (
_get_github_app_slug,
)
with patch.dict(os.environ, {'GITHUB_APP_SLUG': ''}):
result = _get_github_app_slug()
assert result is None
def test_strips_whitespace_from_env_var(self):
"""When GITHUB_APP_SLUG has whitespace, strip it."""
from openhands.app_server.web_client.default_web_client_config_injector import (
_get_github_app_slug,
)
with patch.dict(os.environ, {'GITHUB_APP_SLUG': ' openhands-app '}):
result = _get_github_app_slug()
assert result == 'openhands-app'
def test_returns_none_when_env_var_only_whitespace(self):
"""When GITHUB_APP_SLUG is only whitespace, return None."""
from openhands.app_server.web_client.default_web_client_config_injector import (
_get_github_app_slug,
)
with patch.dict(os.environ, {'GITHUB_APP_SLUG': ' '}):
result = _get_github_app_slug()
assert result is None

View File

@@ -1274,25 +1274,6 @@ def test_opus_45_keeps_temperature_drops_top_p(mock_completion):
assert 'top_p' not in call_kwargs
@patch('openhands.llm.llm.litellm_completion')
def test_opus_46_keeps_temperature_drops_top_p(mock_completion):
mock_completion.return_value = {
'choices': [{'message': {'content': 'ok'}}],
}
config = LLMConfig(
model='anthropic/claude-opus-4-6',
api_key='k',
temperature=0.7,
top_p=0.9,
)
llm = LLM(config, service_id='svc')
llm.completion(messages=[{'role': 'user', 'content': 'hi'}])
call_kwargs = mock_completion.call_args[1]
assert call_kwargs.get('temperature') == 0.7
# Anthropic rejects both temperature and top_p together on Opus 4.6; we keep temperature and drop top_p
assert 'top_p' not in call_kwargs
@patch('openhands.llm.llm.litellm_completion')
def test_sonnet_4_keeps_temperature_drops_top_p(mock_completion):
mock_completion.return_value = {

View File

@@ -123,29 +123,3 @@ async def test_get_conversation_link_empty_body():
# Verify that get_user was called
mock_service.get_user.assert_called_once()
@pytest.mark.asyncio
async def test_get_conversation_link_none_conversation_id():
"""Test get_conversation_link returns body unchanged when conversation_id is None."""
mock_service = AsyncMock(spec=GitService)
with patch('openhands.server.routes.mcp.server_config') as mock_config:
mock_config.app_mode = AppMode.SAAS
body = 'This is the PR body.'
# Test with None conversation_id
result = await get_conversation_link(
service=mock_service, conversation_id=None, body=body
)
assert result == body
# Test with empty string conversation_id
result = await get_conversation_link(
service=mock_service, conversation_id='', body=body
)
assert result == body
# Verify get_user was never called (early return)
mock_service.get_user.assert_not_called()