mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-04-29 03:00:45 -04:00
Compare commits
18 Commits
fix/slack-
...
rename-cod
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9f1dc33f1a | ||
|
|
0da18440c2 | ||
|
|
ac76e10048 | ||
|
|
b98bae8b5f | ||
|
|
516721d1ee | ||
|
|
4d6f66ca28 | ||
|
|
b18568da0b | ||
|
|
83dd3c169c | ||
|
|
35bddb14f1 | ||
|
|
e8425218e2 | ||
|
|
0a879fa781 | ||
|
|
41e142bbab | ||
|
|
b06b9eedac | ||
|
|
a9afafa991 | ||
|
|
663ace4b39 | ||
|
|
2d085a6e0a | ||
|
|
8b7112abe8 | ||
|
|
34547ba947 |
1
.github/workflows/ghcr-build.yml
vendored
1
.github/workflows/ghcr-build.yml
vendored
@@ -9,6 +9,7 @@ on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
- "saas-rel-*"
|
||||
tags:
|
||||
- "*"
|
||||
pull_request:
|
||||
|
||||
@@ -54,7 +54,7 @@ The experience will be familiar to anyone who has used Devin or Jules.
|
||||
### OpenHands Cloud
|
||||
This is a deployment of OpenHands GUI, running on hosted infrastructure.
|
||||
|
||||
You can try it with a free $10 credit by [signing in with your GitHub or GitLab account](https://app.all-hands.dev).
|
||||
You can try it for free using the Minimax model by [signing in with your GitHub or GitLab account](https://app.all-hands.dev).
|
||||
|
||||
OpenHands Cloud comes with source-available features and integrations:
|
||||
- Integrations with Slack, Jira, and Linear
|
||||
|
||||
@@ -28,9 +28,11 @@ class SaaSExperimentManager(ExperimentManager):
|
||||
return agent
|
||||
|
||||
if EXPERIMENT_SYSTEM_PROMPT_EXPERIMENT:
|
||||
agent = agent.model_copy(
|
||||
update={'system_prompt_filename': 'system_prompt_long_horizon.j2'}
|
||||
)
|
||||
# Skip experiment for planning agents which require their specialized prompt
|
||||
if agent.system_prompt_filename != 'system_prompt_planning.j2':
|
||||
agent = agent.model_copy(
|
||||
update={'system_prompt_filename': 'system_prompt_long_horizon.j2'}
|
||||
)
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
@@ -0,0 +1,37 @@
|
||||
"""Add pending_free_credits flag to org table.
|
||||
|
||||
Revision ID: 093
|
||||
Revises: 092
|
||||
Create Date: 2025-02-17 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '093'
|
||||
down_revision: Union[str, None] = '092'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add pending_free_credits column to org table with default false.
|
||||
# New orgs will have this set to TRUE at creation time.
|
||||
# Existing orgs default to FALSE (not eligible - they already got $10 at signup).
|
||||
op.add_column(
|
||||
'org',
|
||||
sa.Column(
|
||||
'pending_free_credits',
|
||||
sa.Boolean,
|
||||
nullable=False,
|
||||
server_default=sa.text('false'),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column('org', 'pending_free_credits')
|
||||
@@ -0,0 +1,110 @@
|
||||
"""create org_invitation table
|
||||
|
||||
Revision ID: 094
|
||||
Revises: 093
|
||||
Create Date: 2026-02-18 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '094'
|
||||
down_revision: Union[str, None] = '093'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Create org_invitation table
|
||||
op.create_table(
|
||||
'org_invitation',
|
||||
sa.Column('id', sa.Integer, sa.Identity(), primary_key=True),
|
||||
sa.Column('token', sa.String(64), nullable=False),
|
||||
sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('email', sa.String(255), nullable=False),
|
||||
sa.Column('role_id', sa.Integer, nullable=False),
|
||||
sa.Column('inviter_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column(
|
||||
'status',
|
||||
sa.String(20),
|
||||
nullable=False,
|
||||
server_default=sa.text("'pending'"),
|
||||
),
|
||||
sa.Column(
|
||||
'created_at',
|
||||
sa.DateTime,
|
||||
nullable=False,
|
||||
server_default=sa.text('CURRENT_TIMESTAMP'),
|
||||
),
|
||||
sa.Column('expires_at', sa.DateTime, nullable=False),
|
||||
sa.Column('accepted_at', sa.DateTime, nullable=True),
|
||||
sa.Column('accepted_by_user_id', postgresql.UUID(as_uuid=True), nullable=True),
|
||||
# Foreign key constraints
|
||||
sa.ForeignKeyConstraint(
|
||||
['org_id'],
|
||||
['org.id'],
|
||||
name='org_invitation_org_fkey',
|
||||
ondelete='CASCADE',
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
['role_id'],
|
||||
['role.id'],
|
||||
name='org_invitation_role_fkey',
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
['inviter_id'],
|
||||
['user.id'],
|
||||
name='org_invitation_inviter_fkey',
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
['accepted_by_user_id'],
|
||||
['user.id'],
|
||||
name='org_invitation_accepter_fkey',
|
||||
),
|
||||
)
|
||||
|
||||
# Create indexes
|
||||
op.create_index(
|
||||
'ix_org_invitation_token',
|
||||
'org_invitation',
|
||||
['token'],
|
||||
unique=True,
|
||||
)
|
||||
op.create_index(
|
||||
'ix_org_invitation_org_id',
|
||||
'org_invitation',
|
||||
['org_id'],
|
||||
)
|
||||
op.create_index(
|
||||
'ix_org_invitation_email',
|
||||
'org_invitation',
|
||||
['email'],
|
||||
)
|
||||
op.create_index(
|
||||
'ix_org_invitation_status',
|
||||
'org_invitation',
|
||||
['status'],
|
||||
)
|
||||
# Composite index for checking pending invitations
|
||||
op.create_index(
|
||||
'ix_org_invitation_org_email_status',
|
||||
'org_invitation',
|
||||
['org_id', 'email', 'status'],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop indexes
|
||||
op.drop_index('ix_org_invitation_org_email_status', table_name='org_invitation')
|
||||
op.drop_index('ix_org_invitation_status', table_name='org_invitation')
|
||||
op.drop_index('ix_org_invitation_email', table_name='org_invitation')
|
||||
op.drop_index('ix_org_invitation_org_id', table_name='org_invitation')
|
||||
op.drop_index('ix_org_invitation_token', table_name='org_invitation')
|
||||
|
||||
# Drop table
|
||||
op.drop_table('org_invitation')
|
||||
@@ -0,0 +1,37 @@
|
||||
"""Drop pending_free_credits column from org table.
|
||||
|
||||
Revision ID: 095
|
||||
Revises: 094
|
||||
Create Date: 2025-02-18 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '095'
|
||||
down_revision: Union[str, None] = '094'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Drop the pending_free_credits column from org table.
|
||||
# This column was used for tracking free credit eligibility but is no longer needed.
|
||||
op.drop_column('org', 'pending_free_credits')
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Re-add pending_free_credits column with default false.
|
||||
op.add_column(
|
||||
'org',
|
||||
sa.Column(
|
||||
'pending_free_credits',
|
||||
sa.Boolean,
|
||||
nullable=False,
|
||||
server_default=sa.text('false'),
|
||||
),
|
||||
)
|
||||
@@ -38,6 +38,12 @@ from server.routes.integration.linear import linear_integration_router # noqa:
|
||||
from server.routes.integration.slack import slack_router # noqa: E402
|
||||
from server.routes.mcp_patch import patch_mcp_server # noqa: E402
|
||||
from server.routes.oauth_device import oauth_device_router # noqa: E402
|
||||
from server.routes.org_invitations import ( # noqa: E402
|
||||
accept_router as invitation_accept_router,
|
||||
)
|
||||
from server.routes.org_invitations import ( # noqa: E402
|
||||
invitation_router,
|
||||
)
|
||||
from server.routes.orgs import org_router # noqa: E402
|
||||
from server.routes.readiness import readiness_router # noqa: E402
|
||||
from server.routes.user import saas_user_router # noqa: E402
|
||||
@@ -99,6 +105,8 @@ if GITLAB_APP_CLIENT_ID:
|
||||
|
||||
base_app.include_router(api_keys_router) # Add routes for API key management
|
||||
base_app.include_router(org_router) # Add routes for organization management
|
||||
base_app.include_router(invitation_router) # Add routes for org invitation management
|
||||
base_app.include_router(invitation_accept_router) # Add route for accepting invitations
|
||||
add_github_proxy_routes(base_app)
|
||||
add_debugging_routes(
|
||||
base_app
|
||||
|
||||
306
enterprise/server/auth/authorization.py
Normal file
306
enterprise/server/auth/authorization.py
Normal file
@@ -0,0 +1,306 @@
|
||||
"""
|
||||
Permission-based authorization dependencies for API endpoints.
|
||||
|
||||
This module provides FastAPI dependencies for checking user permissions
|
||||
within organizations. It uses a permission-based authorization model where
|
||||
roles (owner, admin, member) are mapped to specific permissions.
|
||||
|
||||
Permissions are defined in the Permission enum and mapped to roles via
|
||||
ROLE_PERMISSIONS. This allows fine-grained access control while maintaining
|
||||
the familiar role-based hierarchy.
|
||||
|
||||
Usage:
|
||||
from server.auth.authorization import (
|
||||
Permission,
|
||||
require_permission,
|
||||
)
|
||||
|
||||
@router.get('/{org_id}/settings')
|
||||
async def get_settings(
|
||||
org_id: UUID,
|
||||
user_id: str = Depends(require_permission(Permission.VIEW_LLM_SETTINGS)),
|
||||
):
|
||||
# Only users with VIEW_LLM_SETTINGS permission can access
|
||||
...
|
||||
|
||||
@router.patch('/{org_id}/settings')
|
||||
async def update_settings(
|
||||
org_id: UUID,
|
||||
user_id: str = Depends(require_permission(Permission.EDIT_LLM_SETTINGS)),
|
||||
):
|
||||
# Only users with EDIT_LLM_SETTINGS permission can access
|
||||
...
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from storage.org_member_store import OrgMemberStore
|
||||
from storage.role import Role
|
||||
from storage.role_store import RoleStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.server.user_auth import get_user_id
|
||||
|
||||
|
||||
class Permission(str, Enum):
|
||||
"""Permissions that can be assigned to roles."""
|
||||
|
||||
# Secrets
|
||||
MANAGE_SECRETS = 'manage_secrets'
|
||||
|
||||
# MCP
|
||||
MANAGE_MCP = 'manage_mcp'
|
||||
|
||||
# Integrations
|
||||
MANAGE_INTEGRATIONS = 'manage_integrations'
|
||||
|
||||
# Application Settings
|
||||
MANAGE_APPLICATION_SETTINGS = 'manage_application_settings'
|
||||
|
||||
# API Keys
|
||||
MANAGE_API_KEYS = 'manage_api_keys'
|
||||
|
||||
# LLM Settings
|
||||
VIEW_LLM_SETTINGS = 'view_llm_settings'
|
||||
EDIT_LLM_SETTINGS = 'edit_llm_settings'
|
||||
|
||||
# Billing
|
||||
VIEW_BILLING = 'view_billing'
|
||||
ADD_CREDITS = 'add_credits'
|
||||
|
||||
# Organization Members
|
||||
INVITE_USER_TO_ORGANIZATION = 'invite_user_to_organization'
|
||||
CHANGE_USER_ROLE_MEMBER = 'change_user_role:member'
|
||||
CHANGE_USER_ROLE_ADMIN = 'change_user_role:admin'
|
||||
CHANGE_USER_ROLE_OWNER = 'change_user_role:owner'
|
||||
|
||||
# Organization Management
|
||||
VIEW_ORG_SETTINGS = 'view_org_settings'
|
||||
CHANGE_ORGANIZATION_NAME = 'change_organization_name'
|
||||
DELETE_ORGANIZATION = 'delete_organization'
|
||||
|
||||
# Temporary permissions until we finish the API updates.
|
||||
EDIT_ORG_SETTINGS = 'edit_org_settings'
|
||||
|
||||
|
||||
class RoleName(str, Enum):
|
||||
"""Role names used in the system."""
|
||||
|
||||
OWNER = 'owner'
|
||||
ADMIN = 'admin'
|
||||
MEMBER = 'member'
|
||||
|
||||
|
||||
# Permission mappings for each role
|
||||
ROLE_PERMISSIONS: dict[RoleName, frozenset[Permission]] = {
|
||||
RoleName.OWNER: frozenset(
|
||||
[
|
||||
# Settings (Full access)
|
||||
Permission.MANAGE_SECRETS,
|
||||
Permission.MANAGE_MCP,
|
||||
Permission.MANAGE_INTEGRATIONS,
|
||||
Permission.MANAGE_APPLICATION_SETTINGS,
|
||||
Permission.MANAGE_API_KEYS,
|
||||
Permission.VIEW_LLM_SETTINGS,
|
||||
Permission.EDIT_LLM_SETTINGS,
|
||||
Permission.VIEW_BILLING,
|
||||
Permission.ADD_CREDITS,
|
||||
# Organization Members
|
||||
Permission.INVITE_USER_TO_ORGANIZATION,
|
||||
Permission.CHANGE_USER_ROLE_MEMBER,
|
||||
Permission.CHANGE_USER_ROLE_ADMIN,
|
||||
Permission.CHANGE_USER_ROLE_OWNER,
|
||||
# Organization Management
|
||||
Permission.VIEW_ORG_SETTINGS,
|
||||
Permission.EDIT_ORG_SETTINGS,
|
||||
# Organization Management (Owner only)
|
||||
Permission.CHANGE_ORGANIZATION_NAME,
|
||||
Permission.DELETE_ORGANIZATION,
|
||||
]
|
||||
),
|
||||
RoleName.ADMIN: frozenset(
|
||||
[
|
||||
# Settings (Full access)
|
||||
Permission.MANAGE_SECRETS,
|
||||
Permission.MANAGE_MCP,
|
||||
Permission.MANAGE_INTEGRATIONS,
|
||||
Permission.MANAGE_APPLICATION_SETTINGS,
|
||||
Permission.MANAGE_API_KEYS,
|
||||
Permission.VIEW_LLM_SETTINGS,
|
||||
Permission.EDIT_LLM_SETTINGS,
|
||||
Permission.VIEW_BILLING,
|
||||
Permission.ADD_CREDITS,
|
||||
# Organization Members
|
||||
Permission.INVITE_USER_TO_ORGANIZATION,
|
||||
Permission.CHANGE_USER_ROLE_MEMBER,
|
||||
Permission.CHANGE_USER_ROLE_ADMIN,
|
||||
# Organization Management
|
||||
Permission.VIEW_ORG_SETTINGS,
|
||||
Permission.EDIT_ORG_SETTINGS,
|
||||
]
|
||||
),
|
||||
RoleName.MEMBER: frozenset(
|
||||
[
|
||||
# Settings (Full access)
|
||||
Permission.MANAGE_SECRETS,
|
||||
Permission.MANAGE_MCP,
|
||||
Permission.MANAGE_INTEGRATIONS,
|
||||
Permission.MANAGE_APPLICATION_SETTINGS,
|
||||
Permission.MANAGE_API_KEYS,
|
||||
# Settings (View only)
|
||||
Permission.VIEW_ORG_SETTINGS,
|
||||
Permission.VIEW_LLM_SETTINGS,
|
||||
]
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def get_user_org_role(user_id: str, org_id: UUID | None) -> Role | None:
|
||||
"""
|
||||
Get the user's role in an organization (synchronous version).
|
||||
|
||||
Args:
|
||||
user_id: User ID (string that will be converted to UUID)
|
||||
org_id: Organization ID, or None to use the user's current organization
|
||||
|
||||
Returns:
|
||||
Role object if user is a member, None otherwise
|
||||
"""
|
||||
from uuid import UUID as parse_uuid
|
||||
|
||||
if org_id is None:
|
||||
org_member = OrgMemberStore.get_org_member_for_current_org(parse_uuid(user_id))
|
||||
else:
|
||||
org_member = OrgMemberStore.get_org_member(org_id, parse_uuid(user_id))
|
||||
if not org_member:
|
||||
return None
|
||||
|
||||
return RoleStore.get_role_by_id(org_member.role_id)
|
||||
|
||||
|
||||
async def get_user_org_role_async(user_id: str, org_id: UUID | None) -> Role | None:
|
||||
"""
|
||||
Get the user's role in an organization (async version).
|
||||
|
||||
Args:
|
||||
user_id: User ID (string that will be converted to UUID)
|
||||
org_id: Organization ID, or None to use the user's current organization
|
||||
|
||||
Returns:
|
||||
Role object if user is a member, None otherwise
|
||||
"""
|
||||
from uuid import UUID as parse_uuid
|
||||
|
||||
if org_id is None:
|
||||
org_member = await OrgMemberStore.get_org_member_for_current_org_async(
|
||||
parse_uuid(user_id)
|
||||
)
|
||||
else:
|
||||
org_member = await OrgMemberStore.get_org_member_async(
|
||||
org_id, parse_uuid(user_id)
|
||||
)
|
||||
if not org_member:
|
||||
return None
|
||||
|
||||
return await RoleStore.get_role_by_id_async(org_member.role_id)
|
||||
|
||||
|
||||
def get_role_permissions(role_name: str) -> frozenset[Permission]:
|
||||
"""
|
||||
Get the permissions for a role.
|
||||
|
||||
Args:
|
||||
role_name: Name of the role
|
||||
|
||||
Returns:
|
||||
Set of permissions for the role
|
||||
"""
|
||||
try:
|
||||
role_enum = RoleName(role_name)
|
||||
return ROLE_PERMISSIONS.get(role_enum, frozenset())
|
||||
except ValueError:
|
||||
return frozenset()
|
||||
|
||||
|
||||
def has_permission(user_role: Role, permission: Permission) -> bool:
|
||||
"""
|
||||
Check if a role has a specific permission.
|
||||
|
||||
Args:
|
||||
user_role: User's Role object
|
||||
permission: Permission to check
|
||||
|
||||
Returns:
|
||||
True if the role has the permission
|
||||
"""
|
||||
permissions = get_role_permissions(user_role.name)
|
||||
return permission in permissions
|
||||
|
||||
|
||||
def require_permission(permission: Permission):
|
||||
"""
|
||||
Factory function that creates a dependency to require a specific permission.
|
||||
|
||||
This creates a FastAPI dependency that:
|
||||
1. Extracts org_id from the path parameter
|
||||
2. Gets the authenticated user_id
|
||||
3. Checks if the user has the required permission in the organization
|
||||
4. Returns the user_id if authorized, raises HTTPException otherwise
|
||||
|
||||
Usage:
|
||||
@router.get('/{org_id}/settings')
|
||||
async def get_settings(
|
||||
org_id: UUID,
|
||||
user_id: str = Depends(require_permission(Permission.VIEW_LLM_SETTINGS)),
|
||||
):
|
||||
...
|
||||
|
||||
Args:
|
||||
permission: The permission required to access the endpoint
|
||||
|
||||
Returns:
|
||||
Dependency function that validates permission and returns user_id
|
||||
"""
|
||||
|
||||
async def permission_checker(
|
||||
org_id: UUID | None = None,
|
||||
user_id: str | None = Depends(get_user_id),
|
||||
) -> str:
|
||||
if not user_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail='User not authenticated',
|
||||
)
|
||||
|
||||
user_role = await get_user_org_role_async(user_id, org_id)
|
||||
|
||||
if not user_role:
|
||||
logger.warning(
|
||||
'User not a member of organization',
|
||||
extra={'user_id': user_id, 'org_id': str(org_id)},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail='User is not a member of this organization',
|
||||
)
|
||||
|
||||
if not has_permission(user_role, permission):
|
||||
logger.warning(
|
||||
'Insufficient permissions',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
'user_role': user_role.name,
|
||||
'required_permission': permission.value,
|
||||
},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f'Requires {permission.value} permission',
|
||||
)
|
||||
|
||||
return user_id
|
||||
|
||||
return permission_checker
|
||||
@@ -30,7 +30,9 @@ PERSONAL_WORKSPACE_VERSION_TO_MODEL = {
|
||||
2: 'claude-3-7-sonnet-20250219',
|
||||
3: 'claude-sonnet-4-20250514',
|
||||
4: 'claude-sonnet-4-20250514',
|
||||
5: 'claude-opus-4-5-20251101',
|
||||
# Minimax is now the default as it gives results close to claude in terms of quality
|
||||
# but at a much lower price
|
||||
5: 'minimax-m2.5',
|
||||
}
|
||||
|
||||
LITELLM_DEFAULT_MODEL = os.getenv('LITELLM_DEFAULT_MODEL')
|
||||
@@ -59,7 +61,6 @@ SUBSCRIPTION_PRICE_DATA = {
|
||||
},
|
||||
}
|
||||
|
||||
DEFAULT_INITIAL_BUDGET = float(os.environ.get('DEFAULT_INITIAL_BUDGET', '10'))
|
||||
STRIPE_API_KEY = os.environ.get('STRIPE_API_KEY', None)
|
||||
REQUIRE_PAYMENT = os.environ.get('REQUIRE_PAYMENT', '0') in ('1', 'true')
|
||||
|
||||
|
||||
@@ -160,6 +160,7 @@ class SetAuthCookieMiddleware:
|
||||
'/api/billing/customer-setup-success',
|
||||
'/api/billing/stripe-webhook',
|
||||
'/api/email/resend',
|
||||
'/api/organizations/members/invite/accept',
|
||||
'/oauth/device/authorize',
|
||||
'/oauth/device/token',
|
||||
'/api/v1/web-client/config',
|
||||
|
||||
@@ -5,6 +5,7 @@ import warnings
|
||||
from datetime import datetime, timezone
|
||||
from typing import Annotated, Literal, Optional
|
||||
from urllib.parse import quote
|
||||
from uuid import UUID as parse_uuid
|
||||
|
||||
import posthog
|
||||
from fastapi import APIRouter, Header, HTTPException, Request, Response, status
|
||||
@@ -26,6 +27,13 @@ from server.auth.token_manager import TokenManager
|
||||
from server.config import sign_token
|
||||
from server.constants import IS_FEATURE_ENV
|
||||
from server.routes.event_webhook import _get_session_api_key, _get_user_id
|
||||
from server.services.org_invitation_service import (
|
||||
EmailMismatchError,
|
||||
InvitationExpiredError,
|
||||
InvitationInvalidError,
|
||||
OrgInvitationService,
|
||||
UserAlreadyMemberError,
|
||||
)
|
||||
from storage.database import session_maker
|
||||
from storage.user import User
|
||||
from storage.user_store import UserStore
|
||||
@@ -104,22 +112,40 @@ def get_cookie_samesite(request: Request) -> Literal['lax', 'strict']:
|
||||
)
|
||||
|
||||
|
||||
def _extract_oauth_state(state: str | None) -> tuple[str, str | None, str | None]:
|
||||
"""Extract redirect URL, reCAPTCHA token, and invitation token from OAuth state.
|
||||
|
||||
Returns:
|
||||
Tuple of (redirect_url, recaptcha_token, invitation_token).
|
||||
Tokens may be None.
|
||||
"""
|
||||
if not state:
|
||||
return '', None, None
|
||||
|
||||
try:
|
||||
# Try to decode as JSON (new format with reCAPTCHA and/or invitation)
|
||||
state_data = json.loads(base64.urlsafe_b64decode(state.encode()).decode())
|
||||
return (
|
||||
state_data.get('redirect_url', ''),
|
||||
state_data.get('recaptcha_token'),
|
||||
state_data.get('invitation_token'),
|
||||
)
|
||||
except Exception:
|
||||
# Old format - state is just the redirect URL
|
||||
return state, None, None
|
||||
|
||||
|
||||
# Keep alias for backward compatibility
|
||||
def _extract_recaptcha_state(state: str | None) -> tuple[str, str | None]:
|
||||
"""Extract redirect URL and reCAPTCHA token from OAuth state.
|
||||
|
||||
Deprecated: Use _extract_oauth_state instead.
|
||||
|
||||
Returns:
|
||||
Tuple of (redirect_url, recaptcha_token). Token may be None.
|
||||
"""
|
||||
if not state:
|
||||
return '', None
|
||||
|
||||
try:
|
||||
# Try to decode as JSON (new format with reCAPTCHA)
|
||||
state_data = json.loads(base64.urlsafe_b64decode(state.encode()).decode())
|
||||
return state_data.get('redirect_url', ''), state_data.get('recaptcha_token')
|
||||
except Exception:
|
||||
# Old format - state is just the redirect URL
|
||||
return state, None
|
||||
redirect_url, recaptcha_token, _ = _extract_oauth_state(state)
|
||||
return redirect_url, recaptcha_token
|
||||
|
||||
|
||||
@oauth_router.get('/keycloak/callback')
|
||||
@@ -130,8 +156,8 @@ async def keycloak_callback(
|
||||
error: Optional[str] = None,
|
||||
error_description: Optional[str] = None,
|
||||
):
|
||||
# Extract redirect URL and reCAPTCHA token from state
|
||||
redirect_url, recaptcha_token = _extract_recaptcha_state(state)
|
||||
# Extract redirect URL, reCAPTCHA token, and invitation token from state
|
||||
redirect_url, recaptcha_token, invitation_token = _extract_oauth_state(state)
|
||||
if not redirect_url:
|
||||
redirect_url = str(request.base_url)
|
||||
|
||||
@@ -302,8 +328,13 @@ async def keycloak_callback(
|
||||
from server.routes.email import verify_email
|
||||
|
||||
await verify_email(request=request, user_id=user_id, is_auth_flow=True)
|
||||
redirect_url = f'{request.base_url}login?email_verification_required=true&user_id={user_id}'
|
||||
response = RedirectResponse(redirect_url, status_code=302)
|
||||
verification_redirect_url = f'{request.base_url}login?email_verification_required=true&user_id={user_id}'
|
||||
# Preserve invitation token so it can be included in OAuth state after verification
|
||||
if invitation_token:
|
||||
verification_redirect_url = (
|
||||
f'{verification_redirect_url}&invitation_token={invitation_token}'
|
||||
)
|
||||
response = RedirectResponse(verification_redirect_url, status_code=302)
|
||||
return response
|
||||
|
||||
# default to github IDP for now.
|
||||
@@ -381,14 +412,90 @@ async def keycloak_callback(
|
||||
)
|
||||
|
||||
has_accepted_tos = user.accepted_tos is not None
|
||||
|
||||
# Process invitation token if present (after email verification but before TOS)
|
||||
if invitation_token:
|
||||
try:
|
||||
logger.info(
|
||||
'Processing invitation token during auth callback',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'invitation_token_prefix': invitation_token[:10] + '...',
|
||||
},
|
||||
)
|
||||
|
||||
await OrgInvitationService.accept_invitation(
|
||||
invitation_token, parse_uuid(user_id)
|
||||
)
|
||||
logger.info(
|
||||
'Invitation accepted during auth callback',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
|
||||
except InvitationExpiredError:
|
||||
logger.warning(
|
||||
'Invitation expired during auth callback',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
# Add query param to redirect URL
|
||||
if '?' in redirect_url:
|
||||
redirect_url = f'{redirect_url}&invitation_expired=true'
|
||||
else:
|
||||
redirect_url = f'{redirect_url}?invitation_expired=true'
|
||||
|
||||
except InvitationInvalidError as e:
|
||||
logger.warning(
|
||||
'Invalid invitation during auth callback',
|
||||
extra={'user_id': user_id, 'error': str(e)},
|
||||
)
|
||||
if '?' in redirect_url:
|
||||
redirect_url = f'{redirect_url}&invitation_invalid=true'
|
||||
else:
|
||||
redirect_url = f'{redirect_url}?invitation_invalid=true'
|
||||
|
||||
except UserAlreadyMemberError:
|
||||
logger.info(
|
||||
'User already member during invitation acceptance',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
if '?' in redirect_url:
|
||||
redirect_url = f'{redirect_url}&already_member=true'
|
||||
else:
|
||||
redirect_url = f'{redirect_url}?already_member=true'
|
||||
|
||||
except EmailMismatchError as e:
|
||||
logger.warning(
|
||||
'Email mismatch during auth callback invitation acceptance',
|
||||
extra={'user_id': user_id, 'error': str(e)},
|
||||
)
|
||||
if '?' in redirect_url:
|
||||
redirect_url = f'{redirect_url}&email_mismatch=true'
|
||||
else:
|
||||
redirect_url = f'{redirect_url}?email_mismatch=true'
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'Unexpected error processing invitation during auth callback',
|
||||
extra={'user_id': user_id, 'error': str(e)},
|
||||
)
|
||||
# Don't fail the login if invitation processing fails
|
||||
if '?' in redirect_url:
|
||||
redirect_url = f'{redirect_url}&invitation_error=true'
|
||||
else:
|
||||
redirect_url = f'{redirect_url}?invitation_error=true'
|
||||
|
||||
# If the user hasn't accepted the TOS, redirect to the TOS page
|
||||
if not has_accepted_tos:
|
||||
encoded_redirect_url = quote(redirect_url, safe='')
|
||||
tos_redirect_url = (
|
||||
f'{request.base_url}accept-tos?redirect_url={encoded_redirect_url}'
|
||||
)
|
||||
if invitation_token:
|
||||
tos_redirect_url = f'{tos_redirect_url}&invitation_success=true'
|
||||
response = RedirectResponse(tos_redirect_url, status_code=302)
|
||||
else:
|
||||
if invitation_token:
|
||||
redirect_url = f'{redirect_url}&invitation_success=true'
|
||||
response = RedirectResponse(redirect_url, status_code=302)
|
||||
|
||||
set_response_cookie(
|
||||
|
||||
@@ -9,15 +9,13 @@ from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from fastapi.responses import RedirectResponse
|
||||
from integrations import stripe_service
|
||||
from pydantic import BaseModel
|
||||
from server.constants import (
|
||||
STRIPE_API_KEY,
|
||||
)
|
||||
from server.constants import STRIPE_API_KEY
|
||||
from server.logger import logger
|
||||
from starlette.datastructures import URL
|
||||
from storage.billing_session import BillingSession
|
||||
from storage.database import session_maker
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
from storage.org_store import OrgStore
|
||||
from storage.org import Org
|
||||
from storage.subscription_access import SubscriptionAccess
|
||||
from storage.user_store import UserStore
|
||||
|
||||
@@ -149,7 +147,7 @@ async def create_customer_setup_session(
|
||||
customer=customer_info['customer_id'],
|
||||
mode='setup',
|
||||
payment_method_types=['card'],
|
||||
success_url=f'{base_url}?free_credits=success',
|
||||
success_url=f'{base_url}?setup=success',
|
||||
cancel_url=f'{base_url}',
|
||||
)
|
||||
return CreateBillingSessionResponse(redirect_url=checkout_session.url)
|
||||
@@ -254,6 +252,8 @@ async def success_callback(session_id: str, request: Request):
|
||||
max_budget = (user_team_info.get('litellm_budget_table') or {}).get(
|
||||
'max_budget', 0
|
||||
)
|
||||
|
||||
org = session.query(Org).filter(Org.id == user.current_org_id).first()
|
||||
new_max_budget = max_budget + add_credits
|
||||
|
||||
await LiteLlmManager.update_team_and_users_budget(
|
||||
@@ -261,7 +261,8 @@ async def success_callback(session_id: str, request: Request):
|
||||
)
|
||||
|
||||
# Enable BYOR export for the org now that they've purchased credits
|
||||
OrgStore.update_org(user.current_org_id, {'byor_export_enabled': True})
|
||||
if org:
|
||||
org.byor_export_enabled = True
|
||||
|
||||
# Store transaction status
|
||||
billing_session.status = 'completed'
|
||||
|
||||
122
enterprise/server/routes/org_invitation_models.py
Normal file
122
enterprise/server/routes/org_invitation_models.py
Normal file
@@ -0,0 +1,122 @@
|
||||
"""
|
||||
Pydantic models and custom exceptions for organization invitations.
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel, EmailStr
|
||||
from storage.org_invitation import OrgInvitation
|
||||
from storage.role_store import RoleStore
|
||||
|
||||
|
||||
class InvitationError(Exception):
|
||||
"""Base exception for invitation errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class InvitationAlreadyExistsError(InvitationError):
|
||||
"""Raised when a pending invitation already exists for the email."""
|
||||
|
||||
def __init__(
|
||||
self, message: str = 'A pending invitation already exists for this email'
|
||||
):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class UserAlreadyMemberError(InvitationError):
|
||||
"""Raised when the user is already a member of the organization."""
|
||||
|
||||
def __init__(self, message: str = 'User is already a member of this organization'):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class InvitationExpiredError(InvitationError):
|
||||
"""Raised when the invitation has expired."""
|
||||
|
||||
def __init__(self, message: str = 'Invitation has expired'):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class InvitationInvalidError(InvitationError):
|
||||
"""Raised when the invitation is invalid or revoked."""
|
||||
|
||||
def __init__(self, message: str = 'Invitation is no longer valid'):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class InsufficientPermissionError(InvitationError):
|
||||
"""Raised when the user lacks permission to perform the action."""
|
||||
|
||||
def __init__(self, message: str = 'Insufficient permission'):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class EmailMismatchError(InvitationError):
|
||||
"""Raised when the accepting user's email doesn't match the invitation email."""
|
||||
|
||||
def __init__(self, message: str = 'Your email does not match the invitation'):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class InvitationCreate(BaseModel):
|
||||
"""Request model for creating invitation(s)."""
|
||||
|
||||
emails: list[EmailStr]
|
||||
role: str = 'member' # Default to member role
|
||||
|
||||
|
||||
class InvitationResponse(BaseModel):
|
||||
"""Response model for invitation details."""
|
||||
|
||||
id: int
|
||||
email: str
|
||||
role: str
|
||||
status: str
|
||||
created_at: str
|
||||
expires_at: str
|
||||
inviter_email: str | None = None
|
||||
|
||||
@classmethod
|
||||
def from_invitation(
|
||||
cls,
|
||||
invitation: OrgInvitation,
|
||||
inviter_email: str | None = None,
|
||||
) -> 'InvitationResponse':
|
||||
"""Create an InvitationResponse from an OrgInvitation entity.
|
||||
|
||||
Args:
|
||||
invitation: The invitation entity to convert
|
||||
inviter_email: Optional email of the inviter
|
||||
|
||||
Returns:
|
||||
InvitationResponse: The response model instance
|
||||
"""
|
||||
role_name = ''
|
||||
if invitation.role:
|
||||
role_name = invitation.role.name
|
||||
elif invitation.role_id:
|
||||
role = RoleStore.get_role_by_id(invitation.role_id)
|
||||
role_name = role.name if role else ''
|
||||
|
||||
return cls(
|
||||
id=invitation.id,
|
||||
email=invitation.email,
|
||||
role=role_name,
|
||||
status=invitation.status,
|
||||
created_at=invitation.created_at.isoformat(),
|
||||
expires_at=invitation.expires_at.isoformat(),
|
||||
inviter_email=inviter_email,
|
||||
)
|
||||
|
||||
|
||||
class InvitationFailure(BaseModel):
|
||||
"""Response model for a failed invitation."""
|
||||
|
||||
email: str
|
||||
error: str
|
||||
|
||||
|
||||
class BatchInvitationResponse(BaseModel):
|
||||
"""Response model for batch invitation creation."""
|
||||
|
||||
successful: list[InvitationResponse]
|
||||
failed: list[InvitationFailure]
|
||||
226
enterprise/server/routes/org_invitations.py
Normal file
226
enterprise/server/routes/org_invitations.py
Normal file
@@ -0,0 +1,226 @@
|
||||
"""API routes for organization invitations."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from fastapi.responses import RedirectResponse
|
||||
from server.routes.org_invitation_models import (
|
||||
BatchInvitationResponse,
|
||||
EmailMismatchError,
|
||||
InsufficientPermissionError,
|
||||
InvitationCreate,
|
||||
InvitationExpiredError,
|
||||
InvitationFailure,
|
||||
InvitationInvalidError,
|
||||
InvitationResponse,
|
||||
UserAlreadyMemberError,
|
||||
)
|
||||
from server.services.org_invitation_service import OrgInvitationService
|
||||
from server.utils.rate_limit_utils import check_rate_limit_by_user_id
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.server.user_auth import get_user_id
|
||||
from openhands.server.user_auth.user_auth import get_user_auth
|
||||
|
||||
# Router for invitation operations on an organization (requires org_id)
|
||||
invitation_router = APIRouter(prefix='/api/organizations/{org_id}/members')
|
||||
|
||||
# Router for accepting invitations (no org_id required)
|
||||
accept_router = APIRouter(prefix='/api/organizations/members/invite')
|
||||
|
||||
|
||||
@invitation_router.post(
|
||||
'/invite',
|
||||
response_model=BatchInvitationResponse,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
)
|
||||
async def create_invitation(
|
||||
org_id: UUID,
|
||||
invitation_data: InvitationCreate,
|
||||
request: Request,
|
||||
user_id: str = Depends(get_user_id),
|
||||
):
|
||||
"""Create organization invitations for multiple email addresses.
|
||||
|
||||
Sends emails to invitees with secure links to join the organization.
|
||||
Supports batch invitations - some may succeed while others fail.
|
||||
|
||||
Permission rules:
|
||||
- Only owners and admins can create invitations
|
||||
- Admins can only invite with 'member' or 'admin' role (not 'owner')
|
||||
- Owners can invite with any role
|
||||
|
||||
Args:
|
||||
org_id: Organization UUID
|
||||
invitation_data: Invitation details (emails array, role)
|
||||
request: FastAPI request
|
||||
user_id: Authenticated user ID (from dependency)
|
||||
|
||||
Returns:
|
||||
BatchInvitationResponse: Lists of successful and failed invitations
|
||||
|
||||
Raises:
|
||||
HTTPException 400: Invalid role or organization not found
|
||||
HTTPException 403: User lacks permission to invite
|
||||
HTTPException 429: Rate limit exceeded
|
||||
"""
|
||||
# Rate limit: 10 invitations per minute per user (6 seconds between requests)
|
||||
await check_rate_limit_by_user_id(
|
||||
request=request,
|
||||
key_prefix='org_invitation_create',
|
||||
user_id=user_id,
|
||||
user_rate_limit_seconds=6,
|
||||
)
|
||||
|
||||
try:
|
||||
successful, failed = await OrgInvitationService.create_invitations_batch(
|
||||
org_id=org_id,
|
||||
emails=[str(email) for email in invitation_data.emails],
|
||||
role_name=invitation_data.role,
|
||||
inviter_id=UUID(user_id),
|
||||
)
|
||||
|
||||
logger.info(
|
||||
'Batch organization invitations created',
|
||||
extra={
|
||||
'org_id': str(org_id),
|
||||
'total_emails': len(invitation_data.emails),
|
||||
'successful': len(successful),
|
||||
'failed': len(failed),
|
||||
'inviter_id': user_id,
|
||||
},
|
||||
)
|
||||
|
||||
return BatchInvitationResponse(
|
||||
successful=[InvitationResponse.from_invitation(inv) for inv in successful],
|
||||
failed=[
|
||||
InvitationFailure(email=email, error=error) for email, error in failed
|
||||
],
|
||||
)
|
||||
|
||||
except InsufficientPermissionError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=str(e),
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'Unexpected error creating batch invitations',
|
||||
extra={'org_id': str(org_id), 'error': str(e)},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='An unexpected error occurred',
|
||||
)
|
||||
|
||||
|
||||
@accept_router.get('/accept')
|
||||
async def accept_invitation(
|
||||
token: str,
|
||||
request: Request,
|
||||
):
|
||||
"""Accept an organization invitation via token.
|
||||
|
||||
This endpoint is accessed via the link in the invitation email.
|
||||
|
||||
Flow:
|
||||
1. If user is authenticated: Accept invitation directly and redirect to home
|
||||
2. If user is not authenticated: Redirect to login page with invitation token
|
||||
- Frontend stores token and includes it in OAuth state during login
|
||||
- After authentication, keycloak_callback processes the invitation
|
||||
|
||||
Args:
|
||||
token: The invitation token from the email link
|
||||
request: FastAPI request
|
||||
|
||||
Returns:
|
||||
RedirectResponse: Redirect to home page on success, or login page if not authenticated,
|
||||
or home page with error query params on failure
|
||||
"""
|
||||
base_url = str(request.base_url).rstrip('/')
|
||||
|
||||
# Try to get user_id from auth (may not be authenticated)
|
||||
user_id = None
|
||||
try:
|
||||
user_auth = await get_user_auth(request)
|
||||
if user_auth:
|
||||
user_id = await user_auth.get_user_id()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if not user_id:
|
||||
# User not authenticated - redirect to login page with invitation token
|
||||
# Frontend will store the token and include it in OAuth state during login
|
||||
logger.info(
|
||||
'Invitation accept: redirecting unauthenticated user to login',
|
||||
extra={'token_prefix': token[:10] + '...'},
|
||||
)
|
||||
login_url = f'{base_url}/login?invitation_token={token}'
|
||||
return RedirectResponse(login_url, status_code=302)
|
||||
|
||||
# User is authenticated - process the invitation directly
|
||||
try:
|
||||
await OrgInvitationService.accept_invitation(token, UUID(user_id))
|
||||
|
||||
logger.info(
|
||||
'Invitation accepted successfully',
|
||||
extra={
|
||||
'token_prefix': token[:10] + '...',
|
||||
'user_id': user_id,
|
||||
},
|
||||
)
|
||||
|
||||
# Redirect to home page on success
|
||||
return RedirectResponse(f'{base_url}/', status_code=302)
|
||||
|
||||
except InvitationExpiredError:
|
||||
logger.warning(
|
||||
'Invitation accept failed: expired',
|
||||
extra={'token_prefix': token[:10] + '...', 'user_id': user_id},
|
||||
)
|
||||
return RedirectResponse(f'{base_url}/?invitation_expired=true', status_code=302)
|
||||
|
||||
except InvitationInvalidError as e:
|
||||
logger.warning(
|
||||
'Invitation accept failed: invalid',
|
||||
extra={
|
||||
'token_prefix': token[:10] + '...',
|
||||
'user_id': user_id,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
return RedirectResponse(f'{base_url}/?invitation_invalid=true', status_code=302)
|
||||
|
||||
except UserAlreadyMemberError:
|
||||
logger.info(
|
||||
'Invitation accept: user already member',
|
||||
extra={'token_prefix': token[:10] + '...', 'user_id': user_id},
|
||||
)
|
||||
return RedirectResponse(f'{base_url}/?already_member=true', status_code=302)
|
||||
|
||||
except EmailMismatchError as e:
|
||||
logger.warning(
|
||||
'Invitation accept failed: email mismatch',
|
||||
extra={
|
||||
'token_prefix': token[:10] + '...',
|
||||
'user_id': user_id,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
return RedirectResponse(f'{base_url}/?email_mismatch=true', status_code=302)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'Unexpected error accepting invitation',
|
||||
extra={
|
||||
'token_prefix': token[:10] + '...',
|
||||
'user_id': user_id,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
return RedirectResponse(f'{base_url}/?invitation_error=true', status_code=302)
|
||||
@@ -214,6 +214,7 @@ class OrgPage(BaseModel):
|
||||
|
||||
items: list[OrgResponse]
|
||||
next_page_id: str | None = None
|
||||
current_org_id: str | None = None
|
||||
|
||||
|
||||
class OrgUpdate(BaseModel):
|
||||
@@ -257,7 +258,7 @@ class OrgMemberResponse(BaseModel):
|
||||
user_id: str
|
||||
email: str | None
|
||||
role_id: int
|
||||
role_name: str
|
||||
role: str
|
||||
role_rank: int
|
||||
status: str | None
|
||||
|
||||
|
||||
@@ -2,6 +2,10 @@ from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from server.auth.authorization import (
|
||||
Permission,
|
||||
require_permission,
|
||||
)
|
||||
from server.email_validation import get_admin_user_id
|
||||
from server.routes.org_models import (
|
||||
CannotModifySelfError,
|
||||
@@ -28,6 +32,7 @@ from server.routes.org_models import (
|
||||
)
|
||||
from server.services.org_member_service import OrgMemberService
|
||||
from storage.org_service import OrgService
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.server.user_auth import get_user_id
|
||||
@@ -74,6 +79,12 @@ async def list_user_orgs(
|
||||
)
|
||||
|
||||
try:
|
||||
# Fetch user to get current_org_id
|
||||
user = await UserStore.get_user_by_id_async(user_id)
|
||||
current_org_id = (
|
||||
str(user.current_org_id) if user and user.current_org_id else None
|
||||
)
|
||||
|
||||
# Fetch organizations from service layer
|
||||
orgs, next_page_id = OrgService.get_user_orgs_paginated(
|
||||
user_id=user_id,
|
||||
@@ -95,7 +106,11 @@ async def list_user_orgs(
|
||||
},
|
||||
)
|
||||
|
||||
return OrgPage(items=org_responses, next_page_id=next_page_id)
|
||||
return OrgPage(
|
||||
items=org_responses,
|
||||
next_page_id=next_page_id,
|
||||
current_org_id=current_org_id,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
@@ -189,23 +204,26 @@ async def create_org(
|
||||
@org_router.get('/{org_id}', response_model=OrgResponse, status_code=status.HTTP_200_OK)
|
||||
async def get_org(
|
||||
org_id: UUID,
|
||||
user_id: str = Depends(get_user_id),
|
||||
user_id: str = Depends(require_permission(Permission.VIEW_ORG_SETTINGS)),
|
||||
) -> OrgResponse:
|
||||
"""Get organization details by ID.
|
||||
|
||||
This endpoint allows authenticated users who are members of an organization
|
||||
to retrieve its details. Only members of the organization can access this endpoint.
|
||||
This endpoint retrieves details for a specific organization. Access requires
|
||||
the VIEW_ORG_SETTINGS permission, which is granted to all organization members
|
||||
(member, admin, and owner roles).
|
||||
|
||||
Args:
|
||||
org_id: Organization ID (UUID)
|
||||
user_id: Authenticated user ID (injected by dependency)
|
||||
user_id: Authenticated user ID (injected by require_permission dependency)
|
||||
|
||||
Returns:
|
||||
OrgResponse: The organization details
|
||||
|
||||
Raises:
|
||||
HTTPException: 401 if user is not authenticated
|
||||
HTTPException: 403 if user lacks VIEW_ORG_SETTINGS permission
|
||||
HTTPException: 404 if organization not found
|
||||
HTTPException: 422 if org_id is not a valid UUID (handled by FastAPI)
|
||||
HTTPException: 404 if organization not found or user is not a member
|
||||
HTTPException: 500 if retrieval fails
|
||||
"""
|
||||
logger.info(
|
||||
@@ -305,23 +323,24 @@ async def get_me(
|
||||
@org_router.delete('/{org_id}', status_code=status.HTTP_200_OK)
|
||||
async def delete_org(
|
||||
org_id: UUID,
|
||||
user_id: str = Depends(get_user_id),
|
||||
user_id: str = Depends(require_permission(Permission.DELETE_ORGANIZATION)),
|
||||
) -> dict:
|
||||
"""Delete an organization.
|
||||
|
||||
This endpoint allows authenticated organization owners to delete their organization.
|
||||
All associated data including organization members, conversations, billing data,
|
||||
and external LiteLLM team resources will be permanently removed.
|
||||
This endpoint permanently deletes an organization and all associated data including
|
||||
organization members, conversations, billing data, and external LiteLLM team resources.
|
||||
Access requires the DELETE_ORGANIZATION permission, which is granted only to owners.
|
||||
|
||||
Args:
|
||||
org_id: Organization ID to delete
|
||||
user_id: Authenticated user ID (injected by dependency)
|
||||
org_id: Organization ID to delete (UUID)
|
||||
user_id: Authenticated user ID (injected by require_permission dependency)
|
||||
|
||||
Returns:
|
||||
dict: Confirmation message with deleted organization details
|
||||
|
||||
Raises:
|
||||
HTTPException: 403 if user is not the organization owner
|
||||
HTTPException: 401 if user is not authenticated
|
||||
HTTPException: 403 if user lacks DELETE_ORGANIZATION permission
|
||||
HTTPException: 404 if organization not found
|
||||
HTTPException: 500 if deletion fails
|
||||
"""
|
||||
@@ -414,25 +433,26 @@ async def delete_org(
|
||||
async def update_org(
|
||||
org_id: UUID,
|
||||
update_data: OrgUpdate,
|
||||
user_id: str = Depends(get_user_id),
|
||||
user_id: str = Depends(require_permission(Permission.EDIT_ORG_SETTINGS)),
|
||||
) -> OrgResponse:
|
||||
"""Update an existing organization.
|
||||
|
||||
This endpoint allows authenticated users to update organization settings.
|
||||
LLM-related settings require admin or owner role in the organization.
|
||||
This endpoint updates organization settings. Access requires the EDIT_ORG_SETTINGS
|
||||
permission, which is granted to admin and owner roles.
|
||||
|
||||
Args:
|
||||
org_id: Organization ID to update (UUID validated by FastAPI)
|
||||
org_id: Organization ID to update (UUID)
|
||||
update_data: Organization update data
|
||||
user_id: Authenticated user ID (injected by dependency)
|
||||
user_id: Authenticated user ID (injected by require_permission dependency)
|
||||
|
||||
Returns:
|
||||
OrgResponse: The updated organization details
|
||||
|
||||
Raises:
|
||||
HTTPException: 400 if org_id is invalid UUID format (handled by FastAPI)
|
||||
HTTPException: 403 if user lacks permission for LLM settings
|
||||
HTTPException: 401 if user is not authenticated
|
||||
HTTPException: 403 if user lacks EDIT_ORG_SETTINGS permission
|
||||
HTTPException: 404 if organization not found
|
||||
HTTPException: 409 if organization name already exists
|
||||
HTTPException: 422 if validation errors occur (handled by FastAPI)
|
||||
HTTPException: 500 if update fails
|
||||
"""
|
||||
@@ -496,7 +516,7 @@ async def update_org(
|
||||
|
||||
@org_router.get('/{org_id}/members')
|
||||
async def get_org_members(
|
||||
org_id: str,
|
||||
org_id: UUID,
|
||||
page_id: Annotated[
|
||||
str | None,
|
||||
Query(title='Optional next_page_id from the previously returned page'),
|
||||
@@ -509,13 +529,33 @@ async def get_org_members(
|
||||
lte=100,
|
||||
),
|
||||
] = 100,
|
||||
current_user_id: str = Depends(get_user_id),
|
||||
user_id: str = Depends(require_permission(Permission.VIEW_ORG_SETTINGS)),
|
||||
) -> OrgMemberPage:
|
||||
"""Get all members of an organization with cursor-based pagination."""
|
||||
"""Get all members of an organization with cursor-based pagination.
|
||||
|
||||
This endpoint retrieves a paginated list of organization members. Access requires
|
||||
the VIEW_ORG_SETTINGS permission, which is granted to all organization members
|
||||
(member, admin, and owner roles).
|
||||
|
||||
Args:
|
||||
org_id: Organization ID (UUID)
|
||||
page_id: Optional page ID (offset) for pagination
|
||||
limit: Maximum number of members to return (1-100, default 100)
|
||||
user_id: Authenticated user ID (injected by require_permission dependency)
|
||||
|
||||
Returns:
|
||||
OrgMemberPage: Paginated list of organization members
|
||||
|
||||
Raises:
|
||||
HTTPException: 401 if user is not authenticated
|
||||
HTTPException: 403 if user lacks VIEW_ORG_SETTINGS permission
|
||||
HTTPException: 400 if org_id or page_id format is invalid
|
||||
HTTPException: 500 if retrieval fails
|
||||
"""
|
||||
try:
|
||||
success, error_code, data = await OrgMemberService.get_org_members(
|
||||
org_id=UUID(org_id),
|
||||
current_user_id=UUID(current_user_id),
|
||||
org_id=org_id,
|
||||
current_user_id=UUID(user_id),
|
||||
page_id=page_id,
|
||||
limit=limit,
|
||||
)
|
||||
@@ -562,7 +602,7 @@ async def get_org_members(
|
||||
|
||||
@org_router.delete('/{org_id}/members/{user_id}')
|
||||
async def remove_org_member(
|
||||
org_id: str,
|
||||
org_id: UUID,
|
||||
user_id: str,
|
||||
current_user_id: str = Depends(get_user_id),
|
||||
):
|
||||
@@ -576,7 +616,7 @@ async def remove_org_member(
|
||||
"""
|
||||
try:
|
||||
success, error = await OrgMemberService.remove_org_member(
|
||||
org_id=UUID(org_id),
|
||||
org_id=org_id,
|
||||
target_user_id=UUID(user_id),
|
||||
current_user_id=UUID(current_user_id),
|
||||
)
|
||||
@@ -708,7 +748,7 @@ async def switch_org(
|
||||
|
||||
@org_router.patch('/{org_id}/members/{user_id}', response_model=OrgMemberResponse)
|
||||
async def update_org_member(
|
||||
org_id: str,
|
||||
org_id: UUID,
|
||||
user_id: str,
|
||||
update_data: OrgMemberUpdate,
|
||||
current_user_id: str = Depends(get_user_id),
|
||||
@@ -725,7 +765,7 @@ async def update_org_member(
|
||||
"""
|
||||
try:
|
||||
return await OrgMemberService.update_org_member(
|
||||
org_id=UUID(org_id),
|
||||
org_id=org_id,
|
||||
target_user_id=UUID(user_id),
|
||||
current_user_id=UUID(current_user_id),
|
||||
update_data=update_data,
|
||||
|
||||
131
enterprise/server/services/email_service.py
Normal file
131
enterprise/server/services/email_service.py
Normal file
@@ -0,0 +1,131 @@
|
||||
"""Email service for sending transactional emails via Resend."""
|
||||
|
||||
import os
|
||||
|
||||
try:
|
||||
import resend
|
||||
|
||||
RESEND_AVAILABLE = True
|
||||
except ImportError:
|
||||
RESEND_AVAILABLE = False
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
DEFAULT_FROM_EMAIL = 'OpenHands <no-reply@openhands.dev>'
|
||||
DEFAULT_WEB_HOST = 'https://app.all-hands.dev'
|
||||
|
||||
|
||||
class EmailService:
|
||||
"""Service for sending transactional emails."""
|
||||
|
||||
@staticmethod
|
||||
def _get_resend_client() -> bool:
|
||||
"""Initialize and return the Resend client.
|
||||
|
||||
Returns:
|
||||
bool: True if client is ready, False otherwise
|
||||
"""
|
||||
if not RESEND_AVAILABLE:
|
||||
logger.warning('Resend library not installed, skipping email')
|
||||
return False
|
||||
|
||||
resend_api_key = os.environ.get('RESEND_API_KEY')
|
||||
if not resend_api_key:
|
||||
logger.warning('RESEND_API_KEY not configured, skipping email')
|
||||
return False
|
||||
|
||||
resend.api_key = resend_api_key
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def send_invitation_email(
|
||||
to_email: str,
|
||||
org_name: str,
|
||||
inviter_name: str,
|
||||
role_name: str,
|
||||
invitation_token: str,
|
||||
invitation_id: int,
|
||||
) -> None:
|
||||
"""Send an organization invitation email.
|
||||
|
||||
Args:
|
||||
to_email: Recipient's email address
|
||||
org_name: Name of the organization
|
||||
inviter_name: Display name of the person who sent the invite
|
||||
role_name: Role being offered (e.g., 'member', 'admin')
|
||||
invitation_token: The secure invitation token
|
||||
invitation_id: The invitation ID for logging
|
||||
"""
|
||||
if not EmailService._get_resend_client():
|
||||
return
|
||||
|
||||
# Build invitation URL
|
||||
web_host = os.environ.get('WEB_HOST', DEFAULT_WEB_HOST)
|
||||
invitation_url = f'{web_host}/api/organizations/members/invite/accept?token={invitation_token}'
|
||||
|
||||
from_email = os.environ.get('RESEND_FROM_EMAIL', DEFAULT_FROM_EMAIL)
|
||||
|
||||
params = {
|
||||
'from': from_email,
|
||||
'to': [to_email],
|
||||
'subject': f"You're invited to join {org_name} on OpenHands",
|
||||
'html': f"""
|
||||
<div style="font-family: Arial, sans-serif; max-width: 600px; margin: 0 auto;">
|
||||
<p>Hi,</p>
|
||||
|
||||
<p><strong>{inviter_name}</strong> has invited you to join <strong>{org_name}</strong> on OpenHands as a <strong>{role_name}</strong>.</p>
|
||||
|
||||
<p>Click the button below to accept the invitation:</p>
|
||||
|
||||
<p style="margin: 30px 0;">
|
||||
<a href="{invitation_url}"
|
||||
style="background-color: #c9b974; color: #0D0F11; padding: 8px 16px;
|
||||
text-decoration: none; border-radius: 8px; display: inline-block;
|
||||
font-size: 14px; font-weight: 600;">
|
||||
Accept Invitation
|
||||
</a>
|
||||
</p>
|
||||
|
||||
<p style="color: #666; font-size: 14px;">
|
||||
Or copy and paste this link into your browser:<br>
|
||||
<a href="{invitation_url}" style="color: #c9b974; font-weight: 600;">{invitation_url}</a>
|
||||
</p>
|
||||
|
||||
<p style="color: #666; font-size: 14px;">
|
||||
This invitation will expire in 7 days.
|
||||
</p>
|
||||
|
||||
<p style="color: #666; font-size: 14px;">
|
||||
If you weren't expecting this invitation, you can safely ignore this email.
|
||||
</p>
|
||||
|
||||
<hr style="border: none; border-top: 1px solid #eee; margin: 30px 0;">
|
||||
|
||||
<p style="color: #999; font-size: 12px;">
|
||||
Best,<br>
|
||||
The OpenHands Team
|
||||
</p>
|
||||
</div>
|
||||
""",
|
||||
}
|
||||
|
||||
try:
|
||||
response = resend.Emails.send(params)
|
||||
logger.info(
|
||||
'Invitation email sent',
|
||||
extra={
|
||||
'invitation_id': invitation_id,
|
||||
'email': to_email,
|
||||
'response_id': response.get('id') if response else None,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'Failed to send invitation email',
|
||||
extra={
|
||||
'invitation_id': invitation_id,
|
||||
'email': to_email,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
raise
|
||||
397
enterprise/server/services/org_invitation_service.py
Normal file
397
enterprise/server/services/org_invitation_service.py
Normal file
@@ -0,0 +1,397 @@
|
||||
"""Service for managing organization invitations."""
|
||||
|
||||
import asyncio
|
||||
from uuid import UUID
|
||||
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.constants import ROLE_ADMIN, ROLE_OWNER
|
||||
from server.routes.org_invitation_models import (
|
||||
EmailMismatchError,
|
||||
InsufficientPermissionError,
|
||||
InvitationExpiredError,
|
||||
InvitationInvalidError,
|
||||
UserAlreadyMemberError,
|
||||
)
|
||||
from server.services.email_service import EmailService
|
||||
from storage.org_invitation import OrgInvitation
|
||||
from storage.org_invitation_store import OrgInvitationStore
|
||||
from storage.org_member_store import OrgMemberStore
|
||||
from storage.org_service import OrgService
|
||||
from storage.org_store import OrgStore
|
||||
from storage.role_store import RoleStore
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
|
||||
class OrgInvitationService:
|
||||
"""Service for organization invitation operations."""
|
||||
|
||||
@staticmethod
|
||||
async def create_invitation(
|
||||
org_id: UUID,
|
||||
email: str,
|
||||
role_name: str,
|
||||
inviter_id: UUID,
|
||||
) -> OrgInvitation:
|
||||
"""Create a new organization invitation.
|
||||
|
||||
This method:
|
||||
1. Validates the organization exists
|
||||
2. Validates this is not a personal workspace
|
||||
3. Checks inviter has owner/admin role
|
||||
4. Validates role assignment permissions
|
||||
5. Checks if user is already a member
|
||||
6. Creates the invitation
|
||||
7. Sends the invitation email
|
||||
|
||||
Args:
|
||||
org_id: Organization UUID
|
||||
email: Invitee's email address
|
||||
role_name: Role to assign on acceptance (owner, admin, member)
|
||||
inviter_id: User ID of the person creating the invitation
|
||||
|
||||
Returns:
|
||||
OrgInvitation: The created invitation
|
||||
|
||||
Raises:
|
||||
ValueError: If organization or role not found
|
||||
InsufficientPermissionError: If inviter lacks permission
|
||||
UserAlreadyMemberError: If email is already a member
|
||||
InvitationAlreadyExistsError: If pending invitation exists
|
||||
"""
|
||||
email = email.lower().strip()
|
||||
|
||||
logger.info(
|
||||
'Creating organization invitation',
|
||||
extra={
|
||||
'org_id': str(org_id),
|
||||
'email': email,
|
||||
'role_name': role_name,
|
||||
'inviter_id': str(inviter_id),
|
||||
},
|
||||
)
|
||||
|
||||
# Step 1: Validate organization exists
|
||||
org = OrgStore.get_org_by_id(org_id)
|
||||
if not org:
|
||||
raise ValueError(f'Organization {org_id} not found')
|
||||
|
||||
# Step 2: Check this is not a personal workspace
|
||||
# A personal workspace has org_id matching the user's id
|
||||
if str(org_id) == str(inviter_id):
|
||||
raise InsufficientPermissionError(
|
||||
'Cannot invite users to a personal workspace'
|
||||
)
|
||||
|
||||
# Step 3: Check inviter is a member and has permission
|
||||
inviter_member = OrgMemberStore.get_org_member(org_id, inviter_id)
|
||||
if not inviter_member:
|
||||
raise InsufficientPermissionError(
|
||||
'You are not a member of this organization'
|
||||
)
|
||||
|
||||
inviter_role = RoleStore.get_role_by_id(inviter_member.role_id)
|
||||
if not inviter_role or inviter_role.name not in [ROLE_OWNER, ROLE_ADMIN]:
|
||||
raise InsufficientPermissionError('Only owners and admins can invite users')
|
||||
|
||||
# Step 4: Validate role assignment permissions
|
||||
role_name_lower = role_name.lower()
|
||||
if role_name_lower == ROLE_OWNER and inviter_role.name != ROLE_OWNER:
|
||||
raise InsufficientPermissionError('Only owners can invite with owner role')
|
||||
|
||||
# Get the target role
|
||||
target_role = RoleStore.get_role_by_name(role_name_lower)
|
||||
if not target_role:
|
||||
raise ValueError(f'Invalid role: {role_name}')
|
||||
|
||||
# Step 5: Check if user is already a member (by email)
|
||||
existing_user = await UserStore.get_user_by_email_async(email)
|
||||
if existing_user:
|
||||
existing_member = OrgMemberStore.get_org_member(org_id, existing_user.id)
|
||||
if existing_member:
|
||||
raise UserAlreadyMemberError(
|
||||
'User is already a member of this organization'
|
||||
)
|
||||
|
||||
# Step 6: Create the invitation
|
||||
invitation = await OrgInvitationStore.create_invitation(
|
||||
org_id=org_id,
|
||||
email=email,
|
||||
role_id=target_role.id,
|
||||
inviter_id=inviter_id,
|
||||
)
|
||||
|
||||
# Step 7: Send invitation email
|
||||
try:
|
||||
# Get inviter info for the email
|
||||
inviter_user = UserStore.get_user_by_id(str(inviter_member.user_id))
|
||||
inviter_name = 'A team member'
|
||||
if inviter_user and inviter_user.email:
|
||||
inviter_name = inviter_user.email.split('@')[0]
|
||||
|
||||
EmailService.send_invitation_email(
|
||||
to_email=email,
|
||||
org_name=org.name,
|
||||
inviter_name=inviter_name,
|
||||
role_name=target_role.name,
|
||||
invitation_token=invitation.token,
|
||||
invitation_id=invitation.id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'Failed to send invitation email',
|
||||
extra={
|
||||
'invitation_id': invitation.id,
|
||||
'email': email,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
# Don't fail the invitation creation if email fails
|
||||
# The user can still access via direct link
|
||||
|
||||
return invitation
|
||||
|
||||
@staticmethod
|
||||
async def create_invitations_batch(
|
||||
org_id: UUID,
|
||||
emails: list[str],
|
||||
role_name: str,
|
||||
inviter_id: UUID,
|
||||
) -> tuple[list[OrgInvitation], list[tuple[str, str]]]:
|
||||
"""Create multiple organization invitations concurrently.
|
||||
|
||||
Validates permissions once upfront, then creates invitations in parallel.
|
||||
|
||||
Args:
|
||||
org_id: Organization UUID
|
||||
emails: List of invitee email addresses
|
||||
role_name: Role to assign on acceptance (owner, admin, member)
|
||||
inviter_id: User ID of the person creating the invitations
|
||||
|
||||
Returns:
|
||||
Tuple of (successful_invitations, failed_emails_with_errors)
|
||||
|
||||
Raises:
|
||||
ValueError: If organization or role not found
|
||||
InsufficientPermissionError: If inviter lacks permission
|
||||
"""
|
||||
logger.info(
|
||||
'Creating batch organization invitations',
|
||||
extra={
|
||||
'org_id': str(org_id),
|
||||
'email_count': len(emails),
|
||||
'role_name': role_name,
|
||||
'inviter_id': str(inviter_id),
|
||||
},
|
||||
)
|
||||
|
||||
# Step 1: Validate permissions upfront (shared for all emails)
|
||||
org = OrgStore.get_org_by_id(org_id)
|
||||
if not org:
|
||||
raise ValueError(f'Organization {org_id} not found')
|
||||
|
||||
if str(org_id) == str(inviter_id):
|
||||
raise InsufficientPermissionError(
|
||||
'Cannot invite users to a personal workspace'
|
||||
)
|
||||
|
||||
inviter_member = OrgMemberStore.get_org_member(org_id, inviter_id)
|
||||
if not inviter_member:
|
||||
raise InsufficientPermissionError(
|
||||
'You are not a member of this organization'
|
||||
)
|
||||
|
||||
inviter_role = RoleStore.get_role_by_id(inviter_member.role_id)
|
||||
if not inviter_role or inviter_role.name not in [ROLE_OWNER, ROLE_ADMIN]:
|
||||
raise InsufficientPermissionError('Only owners and admins can invite users')
|
||||
|
||||
role_name_lower = role_name.lower()
|
||||
if role_name_lower == ROLE_OWNER and inviter_role.name != ROLE_OWNER:
|
||||
raise InsufficientPermissionError('Only owners can invite with owner role')
|
||||
|
||||
target_role = RoleStore.get_role_by_name(role_name_lower)
|
||||
if not target_role:
|
||||
raise ValueError(f'Invalid role: {role_name}')
|
||||
|
||||
# Step 2: Create invitations concurrently
|
||||
async def create_single(
|
||||
email: str,
|
||||
) -> tuple[str, OrgInvitation | None, str | None]:
|
||||
"""Create single invitation, return (email, invitation, error)."""
|
||||
try:
|
||||
invitation = await OrgInvitationService.create_invitation(
|
||||
org_id=org_id,
|
||||
email=email,
|
||||
role_name=role_name,
|
||||
inviter_id=inviter_id,
|
||||
)
|
||||
return (email, invitation, None)
|
||||
except (UserAlreadyMemberError, ValueError) as e:
|
||||
return (email, None, str(e))
|
||||
|
||||
results = await asyncio.gather(*[create_single(email) for email in emails])
|
||||
|
||||
# Step 3: Separate successes and failures
|
||||
successful: list[OrgInvitation] = []
|
||||
failed: list[tuple[str, str]] = []
|
||||
for email, invitation, error in results:
|
||||
if invitation:
|
||||
successful.append(invitation)
|
||||
elif error:
|
||||
failed.append((email, error))
|
||||
|
||||
logger.info(
|
||||
'Batch invitation creation completed',
|
||||
extra={
|
||||
'org_id': str(org_id),
|
||||
'successful': len(successful),
|
||||
'failed': len(failed),
|
||||
},
|
||||
)
|
||||
|
||||
return successful, failed
|
||||
|
||||
@staticmethod
|
||||
async def accept_invitation(token: str, user_id: UUID) -> OrgInvitation:
|
||||
"""Accept an organization invitation.
|
||||
|
||||
This method:
|
||||
1. Validates the token and invitation status
|
||||
2. Checks expiration
|
||||
3. Verifies user is not already a member
|
||||
4. Creates LiteLLM integration
|
||||
5. Adds user to the organization
|
||||
6. Marks invitation as accepted
|
||||
|
||||
Args:
|
||||
token: The invitation token
|
||||
user_id: The user accepting the invitation
|
||||
|
||||
Returns:
|
||||
OrgInvitation: The accepted invitation
|
||||
|
||||
Raises:
|
||||
InvitationInvalidError: If token is invalid or invitation not pending
|
||||
InvitationExpiredError: If invitation has expired
|
||||
UserAlreadyMemberError: If user is already a member
|
||||
"""
|
||||
logger.info(
|
||||
'Accepting organization invitation',
|
||||
extra={
|
||||
'token_prefix': token[:10] + '...' if len(token) > 10 else token,
|
||||
'user_id': str(user_id),
|
||||
},
|
||||
)
|
||||
|
||||
# Step 1: Get and validate invitation
|
||||
invitation = await OrgInvitationStore.get_invitation_by_token(token)
|
||||
|
||||
if not invitation:
|
||||
raise InvitationInvalidError('Invalid invitation token')
|
||||
|
||||
if invitation.status != OrgInvitation.STATUS_PENDING:
|
||||
if invitation.status == OrgInvitation.STATUS_ACCEPTED:
|
||||
raise InvitationInvalidError('Invitation has already been accepted')
|
||||
elif invitation.status == OrgInvitation.STATUS_REVOKED:
|
||||
raise InvitationInvalidError('Invitation has been revoked')
|
||||
else:
|
||||
raise InvitationInvalidError('Invitation is no longer valid')
|
||||
|
||||
# Step 2: Check expiration
|
||||
if OrgInvitationStore.is_token_expired(invitation):
|
||||
await OrgInvitationStore.update_invitation_status(
|
||||
invitation.id, OrgInvitation.STATUS_EXPIRED
|
||||
)
|
||||
raise InvitationExpiredError('Invitation has expired')
|
||||
|
||||
# Step 2.5: Verify user email matches invitation email
|
||||
user = await UserStore.get_user_by_id_async(str(user_id))
|
||||
if not user:
|
||||
raise InvitationInvalidError('User not found')
|
||||
|
||||
user_email = user.email
|
||||
# Fallback: fetch email from Keycloak if not in database (for existing users)
|
||||
if not user_email:
|
||||
token_manager = TokenManager()
|
||||
user_info = await token_manager.get_user_info_from_user_id(str(user_id))
|
||||
user_email = user_info.get('email') if user_info else None
|
||||
|
||||
if not user_email:
|
||||
raise EmailMismatchError('Your account does not have an email address')
|
||||
|
||||
user_email = user_email.lower().strip()
|
||||
invitation_email = invitation.email.lower().strip()
|
||||
|
||||
if user_email != invitation_email:
|
||||
logger.warning(
|
||||
'Email mismatch during invitation acceptance',
|
||||
extra={
|
||||
'user_id': str(user_id),
|
||||
'user_email': user_email,
|
||||
'invitation_email': invitation_email,
|
||||
'invitation_id': invitation.id,
|
||||
},
|
||||
)
|
||||
raise EmailMismatchError()
|
||||
|
||||
# Step 3: Check if user is already a member
|
||||
existing_member = OrgMemberStore.get_org_member(invitation.org_id, user_id)
|
||||
if existing_member:
|
||||
raise UserAlreadyMemberError(
|
||||
'You are already a member of this organization'
|
||||
)
|
||||
|
||||
# Step 4: Create LiteLLM integration for the user in the new org
|
||||
try:
|
||||
settings = await OrgService.create_litellm_integration(
|
||||
invitation.org_id, str(user_id)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'Failed to create LiteLLM integration for invitation acceptance',
|
||||
extra={
|
||||
'invitation_id': invitation.id,
|
||||
'user_id': str(user_id),
|
||||
'org_id': str(invitation.org_id),
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
raise InvitationInvalidError(
|
||||
'Failed to set up organization access. Please try again.'
|
||||
)
|
||||
|
||||
# Step 5: Add user to organization
|
||||
from storage.org_member_store import OrgMemberStore as OMS
|
||||
|
||||
org_member_kwargs = OMS.get_kwargs_from_settings(settings)
|
||||
# Don't override with org defaults - use invitation-specified role
|
||||
org_member_kwargs.pop('llm_model', None)
|
||||
org_member_kwargs.pop('llm_base_url', None)
|
||||
|
||||
OrgMemberStore.add_user_to_org(
|
||||
org_id=invitation.org_id,
|
||||
user_id=user_id,
|
||||
role_id=invitation.role_id,
|
||||
llm_api_key=settings.llm_api_key,
|
||||
status='active',
|
||||
)
|
||||
|
||||
# Step 6: Mark invitation as accepted
|
||||
updated_invitation = await OrgInvitationStore.update_invitation_status(
|
||||
invitation.id,
|
||||
OrgInvitation.STATUS_ACCEPTED,
|
||||
accepted_by_user_id=user_id,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
'Organization invitation accepted',
|
||||
extra={
|
||||
'invitation_id': invitation.id,
|
||||
'user_id': str(user_id),
|
||||
'org_id': str(invitation.org_id),
|
||||
'role_id': invitation.role_id,
|
||||
},
|
||||
)
|
||||
|
||||
return updated_invitation
|
||||
@@ -104,7 +104,7 @@ class OrgMemberService:
|
||||
user_id=str(member.user_id),
|
||||
email=user.email if user else None,
|
||||
role_id=member.role_id,
|
||||
role_name=role.name if role else '',
|
||||
role=role.name if role else '',
|
||||
role_rank=role.rank if role else 0,
|
||||
status=member.status,
|
||||
)
|
||||
@@ -240,7 +240,7 @@ class OrgMemberService:
|
||||
user_id=str(target_membership.user_id),
|
||||
email=user.email if user else None,
|
||||
role_id=target_membership.role_id,
|
||||
role_name=target_role.name,
|
||||
role=target_role.name,
|
||||
role_rank=target_role.rank,
|
||||
status=target_membership.status,
|
||||
)
|
||||
@@ -280,7 +280,7 @@ class OrgMemberService:
|
||||
user_id=str(updated_member.user_id),
|
||||
email=user.email if user else None,
|
||||
role_id=updated_member.role_id,
|
||||
role_name=new_role.name,
|
||||
role=new_role.name,
|
||||
role_rank=new_role.rank,
|
||||
status=updated_member.status,
|
||||
)
|
||||
|
||||
@@ -20,6 +20,7 @@ from storage.linear_workspace import LinearWorkspace
|
||||
from storage.maintenance_task import MaintenanceTask, MaintenanceTaskStatus
|
||||
from storage.openhands_pr import OpenhandsPR
|
||||
from storage.org import Org
|
||||
from storage.org_invitation import OrgInvitation
|
||||
from storage.org_member import OrgMember
|
||||
from storage.proactive_convos import ProactiveConversation
|
||||
from storage.role import Role
|
||||
@@ -65,6 +66,7 @@ __all__ = [
|
||||
'MaintenanceTaskStatus',
|
||||
'OpenhandsPR',
|
||||
'Org',
|
||||
'OrgInvitation',
|
||||
'OrgMember',
|
||||
'ProactiveConversation',
|
||||
'Role',
|
||||
|
||||
@@ -10,7 +10,6 @@ import httpx
|
||||
from pydantic import SecretStr
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.constants import (
|
||||
DEFAULT_INITIAL_BUDGET,
|
||||
LITE_LLM_API_KEY,
|
||||
LITE_LLM_API_URL,
|
||||
LITE_LLM_TEAM_ID,
|
||||
@@ -72,9 +71,8 @@ class LiteLlmManager:
|
||||
'x-goog-api-key': LITE_LLM_API_KEY,
|
||||
}
|
||||
) as client:
|
||||
await LiteLlmManager._create_team(
|
||||
client, keycloak_user_id, org_id, DEFAULT_INITIAL_BUDGET
|
||||
)
|
||||
# New users start with $0 budget - they must purchase credits
|
||||
await LiteLlmManager._create_team(client, keycloak_user_id, org_id, 0)
|
||||
|
||||
if create_user:
|
||||
await LiteLlmManager._create_user(
|
||||
@@ -82,7 +80,7 @@ class LiteLlmManager:
|
||||
)
|
||||
|
||||
await LiteLlmManager._add_user_to_team(
|
||||
client, keycloak_user_id, org_id, DEFAULT_INITIAL_BUDGET
|
||||
client, keycloak_user_id, org_id, 0
|
||||
)
|
||||
|
||||
key = await LiteLlmManager._generate_key(
|
||||
|
||||
@@ -51,6 +51,7 @@ class Org(Base): # type: ignore
|
||||
# Relationships
|
||||
org_members = relationship('OrgMember', back_populates='org')
|
||||
current_users = relationship('User', back_populates='current_org')
|
||||
invitations = relationship('OrgInvitation', back_populates='org')
|
||||
billing_sessions = relationship('BillingSession', back_populates='org')
|
||||
stored_conversation_metadata_saas = relationship(
|
||||
'StoredConversationMetadataSaas', back_populates='org'
|
||||
|
||||
59
enterprise/storage/org_invitation.py
Normal file
59
enterprise/storage/org_invitation.py
Normal file
@@ -0,0 +1,59 @@
|
||||
"""
|
||||
SQLAlchemy model for Organization Invitation.
|
||||
"""
|
||||
|
||||
from sqlalchemy import UUID, Column, DateTime, ForeignKey, Integer, String, text
|
||||
from sqlalchemy.orm import relationship
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
class OrgInvitation(Base): # type: ignore
|
||||
"""Organization invitation model.
|
||||
|
||||
Represents an invitation for a user to join an organization.
|
||||
Invitations are created by organization owners/admins and contain
|
||||
a secure token that can be used to accept the invitation.
|
||||
"""
|
||||
|
||||
__tablename__ = 'org_invitation'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
token = Column(String(64), nullable=False, unique=True, index=True)
|
||||
org_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey('org.id', ondelete='CASCADE'),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
email = Column(String(255), nullable=False, index=True)
|
||||
role_id = Column(Integer, ForeignKey('role.id'), nullable=False)
|
||||
inviter_id = Column(UUID(as_uuid=True), ForeignKey('user.id'), nullable=False)
|
||||
status = Column(
|
||||
String(20),
|
||||
nullable=False,
|
||||
server_default=text("'pending'"),
|
||||
)
|
||||
created_at = Column(
|
||||
DateTime,
|
||||
nullable=False,
|
||||
server_default=text('CURRENT_TIMESTAMP'),
|
||||
)
|
||||
expires_at = Column(DateTime, nullable=False)
|
||||
accepted_at = Column(DateTime, nullable=True)
|
||||
accepted_by_user_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey('user.id'),
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
# Relationships
|
||||
org = relationship('Org', back_populates='invitations')
|
||||
role = relationship('Role')
|
||||
inviter = relationship('User', foreign_keys=[inviter_id])
|
||||
accepted_by_user = relationship('User', foreign_keys=[accepted_by_user_id])
|
||||
|
||||
# Status constants
|
||||
STATUS_PENDING = 'pending'
|
||||
STATUS_ACCEPTED = 'accepted'
|
||||
STATUS_REVOKED = 'revoked'
|
||||
STATUS_EXPIRED = 'expired'
|
||||
227
enterprise/storage/org_invitation_store.py
Normal file
227
enterprise/storage/org_invitation_store.py
Normal file
@@ -0,0 +1,227 @@
|
||||
"""
|
||||
Store class for managing organization invitations.
|
||||
"""
|
||||
|
||||
import secrets
|
||||
import string
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import and_, select
|
||||
from sqlalchemy.orm import joinedload
|
||||
from storage.database import a_session_maker
|
||||
from storage.org_invitation import OrgInvitation
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
# Invitation token configuration
|
||||
INVITATION_TOKEN_PREFIX = 'inv-'
|
||||
INVITATION_TOKEN_LENGTH = 48 # Total length will be 52 with prefix
|
||||
DEFAULT_EXPIRATION_DAYS = 7
|
||||
|
||||
|
||||
class OrgInvitationStore:
|
||||
"""Store for managing organization invitations."""
|
||||
|
||||
@staticmethod
|
||||
def generate_token(length: int = INVITATION_TOKEN_LENGTH) -> str:
|
||||
"""Generate a secure invitation token.
|
||||
|
||||
Uses cryptographically secure random generation for tokens.
|
||||
Pattern from api_key_store.py.
|
||||
|
||||
Args:
|
||||
length: Length of the random part of the token
|
||||
|
||||
Returns:
|
||||
str: Token with prefix (e.g., 'inv-aBcDeF123...')
|
||||
"""
|
||||
alphabet = string.ascii_letters + string.digits
|
||||
random_part = ''.join(secrets.choice(alphabet) for _ in range(length))
|
||||
return f'{INVITATION_TOKEN_PREFIX}{random_part}'
|
||||
|
||||
@staticmethod
|
||||
async def create_invitation(
|
||||
org_id: UUID,
|
||||
email: str,
|
||||
role_id: int,
|
||||
inviter_id: UUID,
|
||||
expiration_days: int = DEFAULT_EXPIRATION_DAYS,
|
||||
) -> OrgInvitation:
|
||||
"""Create a new organization invitation.
|
||||
|
||||
Args:
|
||||
org_id: Organization UUID
|
||||
email: Invitee's email address
|
||||
role_id: Role ID to assign on acceptance
|
||||
inviter_id: User ID of the person creating the invitation
|
||||
expiration_days: Days until the invitation expires
|
||||
|
||||
Returns:
|
||||
OrgInvitation: The created invitation record
|
||||
"""
|
||||
async with a_session_maker() as session:
|
||||
token = OrgInvitationStore.generate_token()
|
||||
# Use timezone-naive datetime for database compatibility
|
||||
expires_at = datetime.utcnow() + timedelta(days=expiration_days)
|
||||
|
||||
invitation = OrgInvitation(
|
||||
token=token,
|
||||
org_id=org_id,
|
||||
email=email.lower().strip(),
|
||||
role_id=role_id,
|
||||
inviter_id=inviter_id,
|
||||
status=OrgInvitation.STATUS_PENDING,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
session.add(invitation)
|
||||
await session.commit()
|
||||
|
||||
# Re-fetch with eagerly loaded relationships to avoid DetachedInstanceError
|
||||
result = await session.execute(
|
||||
select(OrgInvitation)
|
||||
.options(joinedload(OrgInvitation.role))
|
||||
.filter(OrgInvitation.id == invitation.id)
|
||||
)
|
||||
invitation = result.scalars().first()
|
||||
|
||||
logger.info(
|
||||
'Created organization invitation',
|
||||
extra={
|
||||
'invitation_id': invitation.id,
|
||||
'org_id': str(org_id),
|
||||
'email': email,
|
||||
'inviter_id': str(inviter_id),
|
||||
'expires_at': expires_at.isoformat(),
|
||||
},
|
||||
)
|
||||
|
||||
return invitation
|
||||
|
||||
@staticmethod
|
||||
async def get_invitation_by_token(token: str) -> Optional[OrgInvitation]:
|
||||
"""Get an invitation by its token.
|
||||
|
||||
Args:
|
||||
token: The invitation token
|
||||
|
||||
Returns:
|
||||
OrgInvitation or None if not found
|
||||
"""
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(OrgInvitation)
|
||||
.options(joinedload(OrgInvitation.org), joinedload(OrgInvitation.role))
|
||||
.filter(OrgInvitation.token == token)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
@staticmethod
|
||||
async def get_pending_invitation(
|
||||
org_id: UUID, email: str
|
||||
) -> Optional[OrgInvitation]:
|
||||
"""Get a pending invitation for an email in an organization.
|
||||
|
||||
Args:
|
||||
org_id: Organization UUID
|
||||
email: Email address to check
|
||||
|
||||
Returns:
|
||||
OrgInvitation or None if no pending invitation exists
|
||||
"""
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(OrgInvitation).filter(
|
||||
and_(
|
||||
OrgInvitation.org_id == org_id,
|
||||
OrgInvitation.email == email.lower().strip(),
|
||||
OrgInvitation.status == OrgInvitation.STATUS_PENDING,
|
||||
)
|
||||
)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
@staticmethod
|
||||
async def update_invitation_status(
|
||||
invitation_id: int,
|
||||
status: str,
|
||||
accepted_by_user_id: Optional[UUID] = None,
|
||||
) -> Optional[OrgInvitation]:
|
||||
"""Update an invitation's status.
|
||||
|
||||
Args:
|
||||
invitation_id: The invitation ID
|
||||
status: New status (pending, accepted, revoked, expired)
|
||||
accepted_by_user_id: User ID who accepted (only for 'accepted' status)
|
||||
|
||||
Returns:
|
||||
Updated OrgInvitation or None if not found
|
||||
"""
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(OrgInvitation).filter(OrgInvitation.id == invitation_id)
|
||||
)
|
||||
invitation = result.scalars().first()
|
||||
|
||||
if not invitation:
|
||||
return None
|
||||
|
||||
old_status = invitation.status
|
||||
invitation.status = status
|
||||
|
||||
if status == OrgInvitation.STATUS_ACCEPTED and accepted_by_user_id:
|
||||
# Use timezone-naive datetime for database compatibility
|
||||
invitation.accepted_at = datetime.utcnow()
|
||||
invitation.accepted_by_user_id = accepted_by_user_id
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(invitation)
|
||||
|
||||
logger.info(
|
||||
'Updated invitation status',
|
||||
extra={
|
||||
'invitation_id': invitation_id,
|
||||
'old_status': old_status,
|
||||
'new_status': status,
|
||||
'accepted_by_user_id': (
|
||||
str(accepted_by_user_id) if accepted_by_user_id else None
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
return invitation
|
||||
|
||||
@staticmethod
|
||||
def is_token_expired(invitation: OrgInvitation) -> bool:
|
||||
"""Check if an invitation token has expired.
|
||||
|
||||
Args:
|
||||
invitation: The invitation to check
|
||||
|
||||
Returns:
|
||||
bool: True if expired, False otherwise
|
||||
"""
|
||||
# Use timezone-naive datetime for comparison (database stores without timezone)
|
||||
now = datetime.utcnow()
|
||||
return invitation.expires_at < now
|
||||
|
||||
@staticmethod
|
||||
async def mark_expired_if_needed(invitation: OrgInvitation) -> bool:
|
||||
"""Check if invitation is expired and update status if needed.
|
||||
|
||||
Args:
|
||||
invitation: The invitation to check
|
||||
|
||||
Returns:
|
||||
bool: True if invitation was marked as expired, False otherwise
|
||||
"""
|
||||
if (
|
||||
invitation.status == OrgInvitation.STATUS_PENDING
|
||||
and OrgInvitationStore.is_token_expired(invitation)
|
||||
):
|
||||
await OrgInvitationStore.update_invitation_status(
|
||||
invitation.id, OrgInvitation.STATUS_EXPIRED
|
||||
)
|
||||
return True
|
||||
return False
|
||||
@@ -9,6 +9,7 @@ from sqlalchemy import select
|
||||
from sqlalchemy.orm import joinedload
|
||||
from storage.database import a_session_maker, session_maker
|
||||
from storage.org_member import OrgMember
|
||||
from storage.user import User
|
||||
from storage.user_settings import UserSettings
|
||||
|
||||
from openhands.storage.data_models.settings import Settings
|
||||
@@ -60,6 +61,51 @@ class OrgMemberStore:
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
@staticmethod
|
||||
def get_org_member_for_current_org(user_id: UUID) -> Optional[OrgMember]:
|
||||
"""Get the org member for a user's current organization.
|
||||
|
||||
Args:
|
||||
user_id: The user's UUID.
|
||||
|
||||
Returns:
|
||||
The OrgMember for the user's current organization, or None if not found.
|
||||
"""
|
||||
with session_maker() as session:
|
||||
result = (
|
||||
session.query(OrgMember)
|
||||
.join(User, User.id == OrgMember.user_id)
|
||||
.filter(
|
||||
User.id == user_id,
|
||||
OrgMember.org_id == User.current_org_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
async def get_org_member_for_current_org_async(
|
||||
user_id: UUID,
|
||||
) -> Optional[OrgMember]:
|
||||
"""Get the org member for a user's current organization (async version).
|
||||
|
||||
Args:
|
||||
user_id: The user's UUID.
|
||||
|
||||
Returns:
|
||||
The OrgMember for the user's current organization, or None if not found.
|
||||
"""
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(OrgMember)
|
||||
.join(User, User.id == OrgMember.user_id)
|
||||
.filter(
|
||||
User.id == user_id,
|
||||
OrgMember.org_id == User.current_org_id,
|
||||
)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
@staticmethod
|
||||
def get_user_orgs(user_id: UUID) -> list[OrgMember]:
|
||||
"""Get all organizations for a user."""
|
||||
|
||||
@@ -29,6 +29,20 @@ class RoleStore:
|
||||
with session_maker() as session:
|
||||
return session.query(Role).filter(Role.id == role_id).first()
|
||||
|
||||
@staticmethod
|
||||
async def get_role_by_id_async(
|
||||
role_id: int,
|
||||
session: Optional[AsyncSession] = None,
|
||||
) -> Optional[Role]:
|
||||
"""Get role by ID (async version)."""
|
||||
if session is not None:
|
||||
result = await session.execute(select(Role).where(Role.id == role_id))
|
||||
return result.scalars().first()
|
||||
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(select(Role).where(Role.id == role_id))
|
||||
return result.scalars().first()
|
||||
|
||||
@staticmethod
|
||||
def get_role_by_name(name: str) -> Optional[Role]:
|
||||
"""Get role by name."""
|
||||
|
||||
@@ -768,6 +768,30 @@ class UserStore:
|
||||
finally:
|
||||
await UserStore._release_user_creation_lock(user_id)
|
||||
|
||||
@staticmethod
|
||||
async def get_user_by_email_async(email: str) -> Optional[User]:
|
||||
"""Get user by email address (async version).
|
||||
|
||||
This method looks up a user by their email address. Note that email
|
||||
addresses may not be unique across all users in rare cases.
|
||||
|
||||
Args:
|
||||
email: The email address to search for
|
||||
|
||||
Returns:
|
||||
User: The user with the matching email, or None if not found
|
||||
"""
|
||||
if not email:
|
||||
return None
|
||||
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(User)
|
||||
.options(joinedload(User.org_members))
|
||||
.filter(User.email == email.lower().strip())
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
@staticmethod
|
||||
def list_users() -> list[User]:
|
||||
"""List all users."""
|
||||
|
||||
@@ -1,11 +1,19 @@
|
||||
import asyncio
|
||||
import asyncio # noqa: I001
|
||||
|
||||
from storage.proactive_conversation_store import ProactiveConversationStore
|
||||
# This must be before the import of storage
|
||||
# to set up logging and prevent alembic from
|
||||
# running its mouth.
|
||||
from openhands.core.logger import openhands_logger
|
||||
|
||||
from storage.proactive_conversation_store import (
|
||||
ProactiveConversationStore,
|
||||
)
|
||||
|
||||
OLDER_THAN = 30 # 30 minutes
|
||||
|
||||
|
||||
async def main():
|
||||
openhands_logger.info('clean_proactive_convo_table')
|
||||
convo_store = ProactiveConversationStore()
|
||||
await convo_store.clean_old_convos(older_than_minutes=OLDER_THAN)
|
||||
|
||||
|
||||
@@ -126,3 +126,24 @@ def test_run_agent_variant_tests_v1_calls_handler_and_sets_system_prompt(monkeyp
|
||||
# Should be a different instance than the original (copied after handler runs)
|
||||
assert result is not agent
|
||||
assert result.system_prompt_filename == 'system_prompt_long_horizon.j2'
|
||||
|
||||
|
||||
@patch('experiments.experiment_manager.ENABLE_EXPERIMENT_MANAGER', True)
|
||||
@patch('experiments.experiment_manager.EXPERIMENT_SYSTEM_PROMPT_EXPERIMENT', True)
|
||||
def test_run_agent_variant_tests_v1_preserves_planning_agent_system_prompt():
|
||||
"""Planning agents should retain their specialized system prompt and not be overwritten by the experiment."""
|
||||
# Arrange
|
||||
planning_agent = make_agent().model_copy(
|
||||
update={'system_prompt_filename': 'system_prompt_planning.j2'}
|
||||
)
|
||||
conv_id = uuid4()
|
||||
|
||||
# Act
|
||||
result: Agent = SaaSExperimentManager.run_agent_variant_tests__v1(
|
||||
user_id='user-planning',
|
||||
conversation_id=conv_id,
|
||||
agent=planning_agent,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result.system_prompt_filename == 'system_prompt_planning.j2'
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -179,7 +179,7 @@ class TestOrgMemberServiceGetOrgMembers:
|
||||
assert data.items[0].user_id == str(current_user_id)
|
||||
assert data.items[0].email == 'test@example.com'
|
||||
assert data.items[0].role_id == 1
|
||||
assert data.items[0].role_name == 'owner'
|
||||
assert data.items[0].role == 'owner'
|
||||
assert data.items[0].role_rank == 10
|
||||
assert data.items[0].status == 'active'
|
||||
|
||||
@@ -462,7 +462,7 @@ class TestOrgMemberServiceGetOrgMembers:
|
||||
assert success is True
|
||||
assert data is not None
|
||||
assert len(data.items) == 1
|
||||
assert data.items[0].role_name == ''
|
||||
assert data.items[0].role == ''
|
||||
assert data.items[0].role_rank == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -1099,7 +1099,7 @@ class TestOrgMemberServiceUpdateOrgMember:
|
||||
|
||||
# Assert
|
||||
assert isinstance(data, OrgMemberResponse)
|
||||
assert data.role_name == 'admin'
|
||||
assert data.role == 'admin'
|
||||
assert data.role_rank == 20
|
||||
mock_update.assert_called_once_with(org_id, target_user_id, admin_role.id)
|
||||
|
||||
@@ -1431,7 +1431,7 @@ class TestOrgMemberServiceUpdateOrgMember:
|
||||
|
||||
# Assert
|
||||
assert data is not None
|
||||
assert data.role_name == 'member'
|
||||
assert data.role == 'member'
|
||||
assert data.role_rank == 1000
|
||||
|
||||
|
||||
|
||||
181
enterprise/tests/unit/test_auth_invitation_callback.py
Normal file
181
enterprise/tests/unit/test_auth_invitation_callback.py
Normal file
@@ -0,0 +1,181 @@
|
||||
"""Tests for auth callback invitation acceptance - EmailMismatchError handling."""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestAuthCallbackInvitationEmailMismatch:
|
||||
"""Test cases for EmailMismatchError handling during auth callback."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_redirect_url(self):
|
||||
"""Base redirect URL."""
|
||||
return 'https://app.example.com/'
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user_id(self):
|
||||
"""Mock user ID."""
|
||||
return '87654321-4321-8765-4321-876543218765'
|
||||
|
||||
def test_email_mismatch_appends_to_url_without_query_params(
|
||||
self, mock_redirect_url, mock_user_id
|
||||
):
|
||||
"""Test that email_mismatch=true is appended correctly when URL has no query params."""
|
||||
from server.routes.org_invitation_models import EmailMismatchError
|
||||
|
||||
# Simulate the logic from auth.py
|
||||
redirect_url = mock_redirect_url
|
||||
try:
|
||||
raise EmailMismatchError('Your email does not match the invitation')
|
||||
except EmailMismatchError:
|
||||
if '?' in redirect_url:
|
||||
redirect_url = f'{redirect_url}&email_mismatch=true'
|
||||
else:
|
||||
redirect_url = f'{redirect_url}?email_mismatch=true'
|
||||
|
||||
assert redirect_url == 'https://app.example.com/?email_mismatch=true'
|
||||
|
||||
def test_email_mismatch_appends_to_url_with_query_params(self, mock_user_id):
|
||||
"""Test that email_mismatch=true is appended correctly when URL has existing query params."""
|
||||
from server.routes.org_invitation_models import EmailMismatchError
|
||||
|
||||
redirect_url = 'https://app.example.com/?other_param=value'
|
||||
try:
|
||||
raise EmailMismatchError()
|
||||
except EmailMismatchError:
|
||||
if '?' in redirect_url:
|
||||
redirect_url = f'{redirect_url}&email_mismatch=true'
|
||||
else:
|
||||
redirect_url = f'{redirect_url}?email_mismatch=true'
|
||||
|
||||
assert (
|
||||
redirect_url
|
||||
== 'https://app.example.com/?other_param=value&email_mismatch=true'
|
||||
)
|
||||
|
||||
def test_email_mismatch_error_has_default_message(self):
|
||||
"""Test that EmailMismatchError has the default message."""
|
||||
from server.routes.org_invitation_models import EmailMismatchError
|
||||
|
||||
error = EmailMismatchError()
|
||||
assert str(error) == 'Your email does not match the invitation'
|
||||
|
||||
def test_email_mismatch_error_accepts_custom_message(self):
|
||||
"""Test that EmailMismatchError accepts a custom message."""
|
||||
from server.routes.org_invitation_models import EmailMismatchError
|
||||
|
||||
custom_message = 'Custom error message'
|
||||
error = EmailMismatchError(custom_message)
|
||||
assert str(error) == custom_message
|
||||
|
||||
def test_email_mismatch_error_is_invitation_error(self):
|
||||
"""Test that EmailMismatchError inherits from InvitationError."""
|
||||
from server.routes.org_invitation_models import (
|
||||
EmailMismatchError,
|
||||
InvitationError,
|
||||
)
|
||||
|
||||
error = EmailMismatchError()
|
||||
assert isinstance(error, InvitationError)
|
||||
|
||||
|
||||
class TestInvitationTokenInOAuthState:
|
||||
"""Test cases for invitation token handling in OAuth state."""
|
||||
|
||||
def test_invitation_token_included_in_oauth_state(self):
|
||||
"""Test that invitation token is included in OAuth state data."""
|
||||
import base64
|
||||
import json
|
||||
|
||||
# Simulate building OAuth state with invitation token
|
||||
state_data = {
|
||||
'redirect_url': 'https://app.example.com/',
|
||||
'invitation_token': 'inv-test-token-12345',
|
||||
}
|
||||
|
||||
encoded_state = base64.b64encode(json.dumps(state_data).encode()).decode()
|
||||
decoded_data = json.loads(base64.b64decode(encoded_state))
|
||||
|
||||
assert decoded_data['invitation_token'] == 'inv-test-token-12345'
|
||||
assert decoded_data['redirect_url'] == 'https://app.example.com/'
|
||||
|
||||
def test_invitation_token_extracted_from_oauth_state(self):
|
||||
"""Test that invitation token can be extracted from OAuth state."""
|
||||
import base64
|
||||
import json
|
||||
|
||||
state_data = {
|
||||
'redirect_url': 'https://app.example.com/',
|
||||
'invitation_token': 'inv-test-token-12345',
|
||||
}
|
||||
|
||||
encoded_state = base64.b64encode(json.dumps(state_data).encode()).decode()
|
||||
|
||||
# Simulate decoding in callback
|
||||
decoded_state = json.loads(base64.b64decode(encoded_state))
|
||||
invitation_token = decoded_state.get('invitation_token')
|
||||
|
||||
assert invitation_token == 'inv-test-token-12345'
|
||||
|
||||
def test_oauth_state_without_invitation_token(self):
|
||||
"""Test that OAuth state works without invitation token."""
|
||||
import base64
|
||||
import json
|
||||
|
||||
state_data = {
|
||||
'redirect_url': 'https://app.example.com/',
|
||||
}
|
||||
|
||||
encoded_state = base64.b64encode(json.dumps(state_data).encode()).decode()
|
||||
decoded_data = json.loads(base64.b64decode(encoded_state))
|
||||
|
||||
assert 'invitation_token' not in decoded_data
|
||||
assert decoded_data['redirect_url'] == 'https://app.example.com/'
|
||||
|
||||
|
||||
class TestAuthCallbackInvitationErrors:
|
||||
"""Test cases for various invitation error scenarios in auth callback."""
|
||||
|
||||
def test_invitation_expired_appends_flag(self):
|
||||
"""Test that invitation_expired=true is appended for expired invitations."""
|
||||
from server.routes.org_invitation_models import InvitationExpiredError
|
||||
|
||||
redirect_url = 'https://app.example.com/'
|
||||
try:
|
||||
raise InvitationExpiredError()
|
||||
except InvitationExpiredError:
|
||||
if '?' in redirect_url:
|
||||
redirect_url = f'{redirect_url}&invitation_expired=true'
|
||||
else:
|
||||
redirect_url = f'{redirect_url}?invitation_expired=true'
|
||||
|
||||
assert redirect_url == 'https://app.example.com/?invitation_expired=true'
|
||||
|
||||
def test_invitation_invalid_appends_flag(self):
|
||||
"""Test that invitation_invalid=true is appended for invalid invitations."""
|
||||
from server.routes.org_invitation_models import InvitationInvalidError
|
||||
|
||||
redirect_url = 'https://app.example.com/'
|
||||
try:
|
||||
raise InvitationInvalidError()
|
||||
except InvitationInvalidError:
|
||||
if '?' in redirect_url:
|
||||
redirect_url = f'{redirect_url}&invitation_invalid=true'
|
||||
else:
|
||||
redirect_url = f'{redirect_url}?invitation_invalid=true'
|
||||
|
||||
assert redirect_url == 'https://app.example.com/?invitation_invalid=true'
|
||||
|
||||
def test_already_member_appends_flag(self):
|
||||
"""Test that already_member=true is appended when user is already a member."""
|
||||
from server.routes.org_invitation_models import UserAlreadyMemberError
|
||||
|
||||
redirect_url = 'https://app.example.com/'
|
||||
try:
|
||||
raise UserAlreadyMemberError()
|
||||
except UserAlreadyMemberError:
|
||||
if '?' in redirect_url:
|
||||
redirect_url = f'{redirect_url}&already_member=true'
|
||||
else:
|
||||
redirect_url = f'{redirect_url}?already_member=true'
|
||||
|
||||
assert redirect_url == 'https://app.example.com/?already_member=true'
|
||||
756
enterprise/tests/unit/test_authorization.py
Normal file
756
enterprise/tests/unit/test_authorization.py
Normal file
@@ -0,0 +1,756 @@
|
||||
"""
|
||||
Unit tests for permission-based authorization (authorization.py).
|
||||
|
||||
Tests the FastAPI dependencies that validate user permissions within organizations.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
from server.auth.authorization import (
|
||||
ROLE_PERMISSIONS,
|
||||
Permission,
|
||||
RoleName,
|
||||
get_role_permissions,
|
||||
get_user_org_role,
|
||||
has_permission,
|
||||
require_permission,
|
||||
)
|
||||
|
||||
# =============================================================================
|
||||
# Tests for Permission enum
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestPermission:
|
||||
"""Tests for Permission enum."""
|
||||
|
||||
def test_permission_values(self):
|
||||
"""
|
||||
GIVEN: Permission enum
|
||||
WHEN: Accessing permission values
|
||||
THEN: All expected permissions exist with correct string values
|
||||
"""
|
||||
assert Permission.MANAGE_SECRETS.value == 'manage_secrets'
|
||||
assert Permission.MANAGE_MCP.value == 'manage_mcp'
|
||||
assert Permission.MANAGE_INTEGRATIONS.value == 'manage_integrations'
|
||||
assert (
|
||||
Permission.MANAGE_APPLICATION_SETTINGS.value
|
||||
== 'manage_application_settings'
|
||||
)
|
||||
assert Permission.MANAGE_API_KEYS.value == 'manage_api_keys'
|
||||
assert Permission.VIEW_LLM_SETTINGS.value == 'view_llm_settings'
|
||||
assert Permission.EDIT_LLM_SETTINGS.value == 'edit_llm_settings'
|
||||
assert Permission.VIEW_BILLING.value == 'view_billing'
|
||||
assert Permission.ADD_CREDITS.value == 'add_credits'
|
||||
assert (
|
||||
Permission.INVITE_USER_TO_ORGANIZATION.value
|
||||
== 'invite_user_to_organization'
|
||||
)
|
||||
assert Permission.CHANGE_USER_ROLE_MEMBER.value == 'change_user_role:member'
|
||||
assert Permission.CHANGE_USER_ROLE_ADMIN.value == 'change_user_role:admin'
|
||||
assert Permission.CHANGE_USER_ROLE_OWNER.value == 'change_user_role:owner'
|
||||
assert Permission.VIEW_ORG_SETTINGS.value == 'view_org_settings'
|
||||
assert Permission.CHANGE_ORGANIZATION_NAME.value == 'change_organization_name'
|
||||
assert Permission.DELETE_ORGANIZATION.value == 'delete_organization'
|
||||
|
||||
def test_permission_from_string(self):
|
||||
"""
|
||||
GIVEN: Valid permission string
|
||||
WHEN: Creating Permission from string
|
||||
THEN: Correct enum value is returned
|
||||
"""
|
||||
assert Permission('manage_secrets') == Permission.MANAGE_SECRETS
|
||||
assert Permission('view_llm_settings') == Permission.VIEW_LLM_SETTINGS
|
||||
assert Permission('delete_organization') == Permission.DELETE_ORGANIZATION
|
||||
|
||||
def test_permission_invalid_string(self):
|
||||
"""
|
||||
GIVEN: Invalid permission string
|
||||
WHEN: Creating Permission from string
|
||||
THEN: ValueError is raised
|
||||
"""
|
||||
with pytest.raises(ValueError):
|
||||
Permission('invalid_permission')
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests for RoleName enum
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestRoleName:
|
||||
"""Tests for RoleName enum."""
|
||||
|
||||
def test_role_name_values(self):
|
||||
"""
|
||||
GIVEN: RoleName enum
|
||||
WHEN: Accessing role name values
|
||||
THEN: All expected roles exist with correct string values
|
||||
"""
|
||||
assert RoleName.OWNER.value == 'owner'
|
||||
assert RoleName.ADMIN.value == 'admin'
|
||||
assert RoleName.MEMBER.value == 'member'
|
||||
|
||||
def test_role_name_from_string(self):
|
||||
"""
|
||||
GIVEN: Valid role name string
|
||||
WHEN: Creating RoleName from string
|
||||
THEN: Correct enum value is returned
|
||||
"""
|
||||
assert RoleName('owner') == RoleName.OWNER
|
||||
assert RoleName('admin') == RoleName.ADMIN
|
||||
assert RoleName('member') == RoleName.MEMBER
|
||||
|
||||
def test_role_name_invalid_string(self):
|
||||
"""
|
||||
GIVEN: Invalid role name string
|
||||
WHEN: Creating RoleName from string
|
||||
THEN: ValueError is raised
|
||||
"""
|
||||
with pytest.raises(ValueError):
|
||||
RoleName('invalid_role')
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests for ROLE_PERMISSIONS mapping
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestRolePermissions:
|
||||
"""Tests for role permission mappings."""
|
||||
|
||||
def test_owner_has_all_permissions(self):
|
||||
"""
|
||||
GIVEN: ROLE_PERMISSIONS mapping
|
||||
WHEN: Checking owner permissions
|
||||
THEN: Owner has all permissions including owner-only permissions
|
||||
"""
|
||||
owner_perms = ROLE_PERMISSIONS[RoleName.OWNER]
|
||||
assert Permission.MANAGE_SECRETS in owner_perms
|
||||
assert Permission.MANAGE_MCP in owner_perms
|
||||
assert Permission.VIEW_LLM_SETTINGS in owner_perms
|
||||
assert Permission.EDIT_LLM_SETTINGS in owner_perms
|
||||
assert Permission.VIEW_BILLING in owner_perms
|
||||
assert Permission.ADD_CREDITS in owner_perms
|
||||
assert Permission.INVITE_USER_TO_ORGANIZATION in owner_perms
|
||||
assert Permission.CHANGE_USER_ROLE_MEMBER in owner_perms
|
||||
assert Permission.CHANGE_USER_ROLE_ADMIN in owner_perms
|
||||
assert Permission.CHANGE_USER_ROLE_OWNER in owner_perms
|
||||
assert Permission.CHANGE_ORGANIZATION_NAME in owner_perms
|
||||
assert Permission.DELETE_ORGANIZATION in owner_perms
|
||||
|
||||
def test_admin_has_admin_permissions(self):
|
||||
"""
|
||||
GIVEN: ROLE_PERMISSIONS mapping
|
||||
WHEN: Checking admin permissions
|
||||
THEN: Admin has admin permissions but not owner-only permissions
|
||||
"""
|
||||
admin_perms = ROLE_PERMISSIONS[RoleName.ADMIN]
|
||||
assert Permission.MANAGE_SECRETS in admin_perms
|
||||
assert Permission.MANAGE_MCP in admin_perms
|
||||
assert Permission.VIEW_LLM_SETTINGS in admin_perms
|
||||
assert Permission.EDIT_LLM_SETTINGS in admin_perms
|
||||
assert Permission.VIEW_BILLING in admin_perms
|
||||
assert Permission.ADD_CREDITS in admin_perms
|
||||
assert Permission.INVITE_USER_TO_ORGANIZATION in admin_perms
|
||||
assert Permission.CHANGE_USER_ROLE_MEMBER in admin_perms
|
||||
assert Permission.CHANGE_USER_ROLE_ADMIN in admin_perms
|
||||
# Admin should NOT have owner-only permissions
|
||||
assert Permission.CHANGE_USER_ROLE_OWNER not in admin_perms
|
||||
assert Permission.CHANGE_ORGANIZATION_NAME not in admin_perms
|
||||
assert Permission.DELETE_ORGANIZATION not in admin_perms
|
||||
|
||||
def test_member_has_limited_permissions(self):
|
||||
"""
|
||||
GIVEN: ROLE_PERMISSIONS mapping
|
||||
WHEN: Checking member permissions
|
||||
THEN: Member has limited permissions
|
||||
"""
|
||||
member_perms = ROLE_PERMISSIONS[RoleName.MEMBER]
|
||||
# Member has basic settings permissions
|
||||
assert Permission.MANAGE_SECRETS in member_perms
|
||||
assert Permission.MANAGE_MCP in member_perms
|
||||
assert Permission.MANAGE_INTEGRATIONS in member_perms
|
||||
assert Permission.MANAGE_APPLICATION_SETTINGS in member_perms
|
||||
assert Permission.MANAGE_API_KEYS in member_perms
|
||||
assert Permission.VIEW_LLM_SETTINGS in member_perms
|
||||
assert Permission.VIEW_ORG_SETTINGS in member_perms
|
||||
# Member should NOT have admin/owner permissions
|
||||
assert Permission.EDIT_LLM_SETTINGS not in member_perms
|
||||
assert Permission.VIEW_BILLING not in member_perms
|
||||
assert Permission.ADD_CREDITS not in member_perms
|
||||
assert Permission.INVITE_USER_TO_ORGANIZATION not in member_perms
|
||||
assert Permission.CHANGE_USER_ROLE_MEMBER not in member_perms
|
||||
assert Permission.CHANGE_USER_ROLE_ADMIN not in member_perms
|
||||
assert Permission.CHANGE_USER_ROLE_OWNER not in member_perms
|
||||
assert Permission.CHANGE_ORGANIZATION_NAME not in member_perms
|
||||
assert Permission.DELETE_ORGANIZATION not in member_perms
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests for get_role_permissions function
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestGetRolePermissions:
|
||||
"""Tests for get_role_permissions function."""
|
||||
|
||||
def test_get_owner_permissions(self):
|
||||
"""
|
||||
GIVEN: Role name 'owner'
|
||||
WHEN: get_role_permissions is called
|
||||
THEN: Owner permissions are returned
|
||||
"""
|
||||
perms = get_role_permissions('owner')
|
||||
assert Permission.DELETE_ORGANIZATION in perms
|
||||
assert Permission.CHANGE_ORGANIZATION_NAME in perms
|
||||
|
||||
def test_get_admin_permissions(self):
|
||||
"""
|
||||
GIVEN: Role name 'admin'
|
||||
WHEN: get_role_permissions is called
|
||||
THEN: Admin permissions are returned
|
||||
"""
|
||||
perms = get_role_permissions('admin')
|
||||
assert Permission.EDIT_LLM_SETTINGS in perms
|
||||
assert Permission.DELETE_ORGANIZATION not in perms
|
||||
|
||||
def test_get_member_permissions(self):
|
||||
"""
|
||||
GIVEN: Role name 'member'
|
||||
WHEN: get_role_permissions is called
|
||||
THEN: Member permissions are returned
|
||||
"""
|
||||
perms = get_role_permissions('member')
|
||||
assert Permission.VIEW_LLM_SETTINGS in perms
|
||||
assert Permission.EDIT_LLM_SETTINGS not in perms
|
||||
|
||||
def test_get_invalid_role_permissions(self):
|
||||
"""
|
||||
GIVEN: Invalid role name
|
||||
WHEN: get_role_permissions is called
|
||||
THEN: Empty frozenset is returned
|
||||
"""
|
||||
perms = get_role_permissions('invalid_role')
|
||||
assert perms == frozenset()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests for has_permission function
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestHasPermission:
|
||||
"""Tests for has_permission function."""
|
||||
|
||||
def test_owner_has_delete_organization_permission(self):
|
||||
"""
|
||||
GIVEN: User with owner role
|
||||
WHEN: Checking for DELETE_ORGANIZATION permission
|
||||
THEN: Returns True
|
||||
"""
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'owner'
|
||||
assert has_permission(mock_role, Permission.DELETE_ORGANIZATION) is True
|
||||
|
||||
def test_owner_has_view_llm_settings_permission(self):
|
||||
"""
|
||||
GIVEN: User with owner role
|
||||
WHEN: Checking for VIEW_LLM_SETTINGS permission
|
||||
THEN: Returns True
|
||||
"""
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'owner'
|
||||
assert has_permission(mock_role, Permission.VIEW_LLM_SETTINGS) is True
|
||||
|
||||
def test_admin_has_edit_llm_settings_permission(self):
|
||||
"""
|
||||
GIVEN: User with admin role
|
||||
WHEN: Checking for EDIT_LLM_SETTINGS permission
|
||||
THEN: Returns True
|
||||
"""
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'admin'
|
||||
assert has_permission(mock_role, Permission.EDIT_LLM_SETTINGS) is True
|
||||
|
||||
def test_admin_lacks_delete_organization_permission(self):
|
||||
"""
|
||||
GIVEN: User with admin role
|
||||
WHEN: Checking for DELETE_ORGANIZATION permission
|
||||
THEN: Returns False
|
||||
"""
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'admin'
|
||||
assert has_permission(mock_role, Permission.DELETE_ORGANIZATION) is False
|
||||
|
||||
def test_member_has_view_llm_settings_permission(self):
|
||||
"""
|
||||
GIVEN: User with member role
|
||||
WHEN: Checking for VIEW_LLM_SETTINGS permission
|
||||
THEN: Returns True
|
||||
"""
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'member'
|
||||
assert has_permission(mock_role, Permission.VIEW_LLM_SETTINGS) is True
|
||||
|
||||
def test_member_lacks_edit_llm_settings_permission(self):
|
||||
"""
|
||||
GIVEN: User with member role
|
||||
WHEN: Checking for EDIT_LLM_SETTINGS permission
|
||||
THEN: Returns False
|
||||
"""
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'member'
|
||||
assert has_permission(mock_role, Permission.EDIT_LLM_SETTINGS) is False
|
||||
|
||||
def test_member_lacks_delete_organization_permission(self):
|
||||
"""
|
||||
GIVEN: User with member role
|
||||
WHEN: Checking for DELETE_ORGANIZATION permission
|
||||
THEN: Returns False
|
||||
"""
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'member'
|
||||
assert has_permission(mock_role, Permission.DELETE_ORGANIZATION) is False
|
||||
|
||||
def test_invalid_role_has_no_permissions(self):
|
||||
"""
|
||||
GIVEN: User with invalid role
|
||||
WHEN: Checking for any permission
|
||||
THEN: Returns False
|
||||
"""
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'invalid_role'
|
||||
assert has_permission(mock_role, Permission.VIEW_LLM_SETTINGS) is False
|
||||
assert has_permission(mock_role, Permission.DELETE_ORGANIZATION) is False
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests for get_user_org_role function
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestGetUserOrgRole:
|
||||
"""Tests for get_user_org_role function."""
|
||||
|
||||
def test_returns_role_when_member_exists(self):
|
||||
"""
|
||||
GIVEN: User is a member of organization with role
|
||||
WHEN: get_user_org_role is called
|
||||
THEN: Role object is returned
|
||||
"""
|
||||
user_id = str(uuid4())
|
||||
org_id = uuid4()
|
||||
|
||||
mock_org_member = MagicMock()
|
||||
mock_org_member.role_id = 1
|
||||
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'admin'
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.auth.authorization.OrgMemberStore.get_org_member',
|
||||
return_value=mock_org_member,
|
||||
),
|
||||
patch(
|
||||
'server.auth.authorization.RoleStore.get_role_by_id',
|
||||
return_value=mock_role,
|
||||
),
|
||||
):
|
||||
result = get_user_org_role(user_id, org_id)
|
||||
assert result == mock_role
|
||||
|
||||
def test_returns_none_when_not_member(self):
|
||||
"""
|
||||
GIVEN: User is not a member of organization
|
||||
WHEN: get_user_org_role is called
|
||||
THEN: None is returned
|
||||
"""
|
||||
user_id = str(uuid4())
|
||||
org_id = uuid4()
|
||||
|
||||
with patch(
|
||||
'server.auth.authorization.OrgMemberStore.get_org_member',
|
||||
return_value=None,
|
||||
):
|
||||
result = get_user_org_role(user_id, org_id)
|
||||
assert result is None
|
||||
|
||||
def test_returns_role_when_org_id_is_none(self):
|
||||
"""
|
||||
GIVEN: User with a current organization
|
||||
WHEN: get_user_org_role is called with org_id=None
|
||||
THEN: Role object is returned using get_org_member_for_current_org
|
||||
"""
|
||||
user_id = str(uuid4())
|
||||
|
||||
mock_org_member = MagicMock()
|
||||
mock_org_member.role_id = 1
|
||||
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'admin'
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.auth.authorization.OrgMemberStore.get_org_member_for_current_org',
|
||||
return_value=mock_org_member,
|
||||
) as mock_get_current,
|
||||
patch(
|
||||
'server.auth.authorization.OrgMemberStore.get_org_member',
|
||||
) as mock_get_org_member,
|
||||
patch(
|
||||
'server.auth.authorization.RoleStore.get_role_by_id',
|
||||
return_value=mock_role,
|
||||
),
|
||||
):
|
||||
result = get_user_org_role(user_id, None)
|
||||
assert result == mock_role
|
||||
mock_get_current.assert_called_once()
|
||||
mock_get_org_member.assert_not_called()
|
||||
|
||||
def test_returns_none_when_org_id_is_none_and_no_current_org(self):
|
||||
"""
|
||||
GIVEN: User with no current organization membership
|
||||
WHEN: get_user_org_role is called with org_id=None
|
||||
THEN: None is returned
|
||||
"""
|
||||
user_id = str(uuid4())
|
||||
|
||||
with patch(
|
||||
'server.auth.authorization.OrgMemberStore.get_org_member_for_current_org',
|
||||
return_value=None,
|
||||
):
|
||||
result = get_user_org_role(user_id, None)
|
||||
assert result is None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests for require_permission dependency
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestRequirePermission:
|
||||
"""Tests for require_permission dependency factory."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_user_id_when_authorized(self):
|
||||
"""
|
||||
GIVEN: User with required permission
|
||||
WHEN: Permission checker is called
|
||||
THEN: User ID is returned
|
||||
"""
|
||||
user_id = str(uuid4())
|
||||
org_id = uuid4()
|
||||
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'admin'
|
||||
|
||||
with patch(
|
||||
'server.auth.authorization.get_user_org_role_async',
|
||||
AsyncMock(return_value=mock_role),
|
||||
):
|
||||
permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS)
|
||||
result = await permission_checker(org_id=org_id, user_id=user_id)
|
||||
assert result == user_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_raises_401_when_not_authenticated(self):
|
||||
"""
|
||||
GIVEN: No user ID (not authenticated)
|
||||
WHEN: Permission checker is called
|
||||
THEN: 401 Unauthorized is raised
|
||||
"""
|
||||
org_id = uuid4()
|
||||
|
||||
permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await permission_checker(org_id=org_id, user_id=None)
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert 'not authenticated' in exc_info.value.detail.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_raises_403_when_not_member(self):
|
||||
"""
|
||||
GIVEN: User is not a member of organization
|
||||
WHEN: Permission checker is called
|
||||
THEN: 403 Forbidden is raised
|
||||
"""
|
||||
user_id = str(uuid4())
|
||||
org_id = uuid4()
|
||||
|
||||
with patch(
|
||||
'server.auth.authorization.get_user_org_role_async',
|
||||
AsyncMock(return_value=None),
|
||||
):
|
||||
permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await permission_checker(org_id=org_id, user_id=user_id)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
assert 'not a member' in exc_info.value.detail.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_raises_403_when_insufficient_permission(self):
|
||||
"""
|
||||
GIVEN: User without required permission
|
||||
WHEN: Permission checker is called
|
||||
THEN: 403 Forbidden is raised
|
||||
"""
|
||||
user_id = str(uuid4())
|
||||
org_id = uuid4()
|
||||
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'member'
|
||||
|
||||
with patch(
|
||||
'server.auth.authorization.get_user_org_role_async',
|
||||
AsyncMock(return_value=mock_role),
|
||||
):
|
||||
permission_checker = require_permission(Permission.DELETE_ORGANIZATION)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await permission_checker(org_id=org_id, user_id=user_id)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
assert 'delete_organization' in exc_info.value.detail.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_owner_can_delete_organization(self):
|
||||
"""
|
||||
GIVEN: User with owner role
|
||||
WHEN: DELETE_ORGANIZATION permission is required
|
||||
THEN: User ID is returned
|
||||
"""
|
||||
user_id = str(uuid4())
|
||||
org_id = uuid4()
|
||||
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'owner'
|
||||
|
||||
with patch(
|
||||
'server.auth.authorization.get_user_org_role_async',
|
||||
AsyncMock(return_value=mock_role),
|
||||
):
|
||||
permission_checker = require_permission(Permission.DELETE_ORGANIZATION)
|
||||
result = await permission_checker(org_id=org_id, user_id=user_id)
|
||||
assert result == user_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_cannot_delete_organization(self):
|
||||
"""
|
||||
GIVEN: User with admin role
|
||||
WHEN: DELETE_ORGANIZATION permission is required
|
||||
THEN: 403 Forbidden is raised
|
||||
"""
|
||||
user_id = str(uuid4())
|
||||
org_id = uuid4()
|
||||
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'admin'
|
||||
|
||||
with patch(
|
||||
'server.auth.authorization.get_user_org_role_async',
|
||||
AsyncMock(return_value=mock_role),
|
||||
):
|
||||
permission_checker = require_permission(Permission.DELETE_ORGANIZATION)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await permission_checker(org_id=org_id, user_id=user_id)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_logs_warning_on_insufficient_permission(self):
|
||||
"""
|
||||
GIVEN: User without required permission
|
||||
WHEN: Permission checker is called
|
||||
THEN: Warning is logged with details
|
||||
"""
|
||||
user_id = str(uuid4())
|
||||
org_id = uuid4()
|
||||
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'member'
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.auth.authorization.get_user_org_role_async',
|
||||
AsyncMock(return_value=mock_role),
|
||||
),
|
||||
patch('server.auth.authorization.logger') as mock_logger,
|
||||
):
|
||||
permission_checker = require_permission(Permission.DELETE_ORGANIZATION)
|
||||
with pytest.raises(HTTPException):
|
||||
await permission_checker(org_id=org_id, user_id=user_id)
|
||||
|
||||
mock_logger.warning.assert_called()
|
||||
call_args = mock_logger.warning.call_args
|
||||
assert call_args[1]['extra']['user_id'] == user_id
|
||||
assert call_args[1]['extra']['user_role'] == 'member'
|
||||
assert call_args[1]['extra']['required_permission'] == 'delete_organization'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_user_id_when_org_id_is_none(self):
|
||||
"""
|
||||
GIVEN: User with required permission in their current org
|
||||
WHEN: Permission checker is called with org_id=None
|
||||
THEN: User ID is returned
|
||||
"""
|
||||
user_id = str(uuid4())
|
||||
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'admin'
|
||||
|
||||
with patch(
|
||||
'server.auth.authorization.get_user_org_role_async',
|
||||
AsyncMock(return_value=mock_role),
|
||||
) as mock_get_role:
|
||||
permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS)
|
||||
result = await permission_checker(org_id=None, user_id=user_id)
|
||||
assert result == user_id
|
||||
mock_get_role.assert_called_once_with(user_id, None)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_raises_403_when_org_id_is_none_and_not_member(self):
|
||||
"""
|
||||
GIVEN: User not a member of their current organization
|
||||
WHEN: Permission checker is called with org_id=None
|
||||
THEN: HTTPException with 403 status is raised
|
||||
"""
|
||||
user_id = str(uuid4())
|
||||
|
||||
with patch(
|
||||
'server.auth.authorization.get_user_org_role_async',
|
||||
AsyncMock(return_value=None),
|
||||
):
|
||||
permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await permission_checker(org_id=None, user_id=user_id)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
assert 'not a member' in exc_info.value.detail
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests for permission-based access control scenarios
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestPermissionScenarios:
|
||||
"""Tests for real-world permission scenarios."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_member_can_manage_secrets(self):
|
||||
"""
|
||||
GIVEN: User with member role
|
||||
WHEN: MANAGE_SECRETS permission is required
|
||||
THEN: User ID is returned
|
||||
"""
|
||||
user_id = str(uuid4())
|
||||
org_id = uuid4()
|
||||
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'member'
|
||||
|
||||
with patch(
|
||||
'server.auth.authorization.get_user_org_role_async',
|
||||
AsyncMock(return_value=mock_role),
|
||||
):
|
||||
permission_checker = require_permission(Permission.MANAGE_SECRETS)
|
||||
result = await permission_checker(org_id=org_id, user_id=user_id)
|
||||
assert result == user_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_member_cannot_invite_users(self):
|
||||
"""
|
||||
GIVEN: User with member role
|
||||
WHEN: INVITE_USER_TO_ORGANIZATION permission is required
|
||||
THEN: 403 Forbidden is raised
|
||||
"""
|
||||
user_id = str(uuid4())
|
||||
org_id = uuid4()
|
||||
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'member'
|
||||
|
||||
with patch(
|
||||
'server.auth.authorization.get_user_org_role_async',
|
||||
AsyncMock(return_value=mock_role),
|
||||
):
|
||||
permission_checker = require_permission(
|
||||
Permission.INVITE_USER_TO_ORGANIZATION
|
||||
)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await permission_checker(org_id=org_id, user_id=user_id)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_can_invite_users(self):
|
||||
"""
|
||||
GIVEN: User with admin role
|
||||
WHEN: INVITE_USER_TO_ORGANIZATION permission is required
|
||||
THEN: User ID is returned
|
||||
"""
|
||||
user_id = str(uuid4())
|
||||
org_id = uuid4()
|
||||
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'admin'
|
||||
|
||||
with patch(
|
||||
'server.auth.authorization.get_user_org_role_async',
|
||||
AsyncMock(return_value=mock_role),
|
||||
):
|
||||
permission_checker = require_permission(
|
||||
Permission.INVITE_USER_TO_ORGANIZATION
|
||||
)
|
||||
result = await permission_checker(org_id=org_id, user_id=user_id)
|
||||
assert result == user_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_cannot_change_owner_role(self):
|
||||
"""
|
||||
GIVEN: User with admin role
|
||||
WHEN: CHANGE_USER_ROLE_OWNER permission is required
|
||||
THEN: 403 Forbidden is raised
|
||||
"""
|
||||
user_id = str(uuid4())
|
||||
org_id = uuid4()
|
||||
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'admin'
|
||||
|
||||
with patch(
|
||||
'server.auth.authorization.get_user_org_role_async',
|
||||
AsyncMock(return_value=mock_role),
|
||||
):
|
||||
permission_checker = require_permission(Permission.CHANGE_USER_ROLE_OWNER)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await permission_checker(org_id=org_id, user_id=user_id)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_owner_can_change_owner_role(self):
|
||||
"""
|
||||
GIVEN: User with owner role
|
||||
WHEN: CHANGE_USER_ROLE_OWNER permission is required
|
||||
THEN: User ID is returned
|
||||
"""
|
||||
user_id = str(uuid4())
|
||||
org_id = uuid4()
|
||||
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'owner'
|
||||
|
||||
with patch(
|
||||
'server.auth.authorization.get_user_org_role_async',
|
||||
AsyncMock(return_value=mock_role),
|
||||
):
|
||||
permission_checker = require_permission(Permission.CHANGE_USER_ROLE_OWNER)
|
||||
result = await permission_checker(org_id=org_id, user_id=user_id)
|
||||
assert result == user_id
|
||||
@@ -299,6 +299,8 @@ async def test_success_callback_success():
|
||||
mock_billing_session.status = 'in_progress'
|
||||
mock_billing_session.user_id = 'mock_user'
|
||||
|
||||
mock_org = MagicMock()
|
||||
|
||||
with (
|
||||
patch('server.routes.billing.session_maker') as mock_session_maker,
|
||||
patch('stripe.checkout.Session.retrieve') as mock_stripe_retrieve,
|
||||
@@ -317,10 +319,19 @@ async def test_success_callback_success():
|
||||
patch(
|
||||
'storage.lite_llm_manager.LiteLlmManager.update_team_and_users_budget'
|
||||
) as mock_update_budget,
|
||||
patch('server.routes.billing.OrgStore.update_org') as mock_update_org,
|
||||
):
|
||||
mock_db_session = MagicMock()
|
||||
# First query: BillingSession (query().filter().filter().first())
|
||||
mock_db_session.query.return_value.filter.return_value.filter.return_value.first.return_value = mock_billing_session
|
||||
# Second query: Org (query().filter().first()) - use side_effect for different return chains
|
||||
mock_query_chain_billing = MagicMock()
|
||||
mock_query_chain_billing.filter.return_value.filter.return_value.first.return_value = mock_billing_session
|
||||
mock_query_chain_org = MagicMock()
|
||||
mock_query_chain_org.filter.return_value.first.return_value = mock_org
|
||||
mock_db_session.query.side_effect = [
|
||||
mock_query_chain_billing,
|
||||
mock_query_chain_org,
|
||||
]
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_db_session
|
||||
|
||||
mock_stripe_retrieve.return_value = MagicMock(
|
||||
@@ -338,14 +349,11 @@ async def test_success_callback_success():
|
||||
# Verify LiteLLM API calls
|
||||
mock_update_budget.assert_called_once_with(
|
||||
'mock_org_id',
|
||||
125.0, # 100 + (25.00 from Stripe)
|
||||
125.0, # 100 + 25.00
|
||||
)
|
||||
|
||||
# Verify BYOR export is enabled for the org
|
||||
mock_update_org.assert_called_once_with(
|
||||
'mock_org_id',
|
||||
{'byor_export_enabled': True},
|
||||
)
|
||||
# Verify BYOR export is enabled for the org (updated in same session)
|
||||
assert mock_org.byor_export_enabled is True
|
||||
|
||||
# Verify database updates
|
||||
assert mock_billing_session.status == 'completed'
|
||||
@@ -394,6 +402,68 @@ async def test_success_callback_lite_llm_error():
|
||||
mock_db_session.commit.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_success_callback_lite_llm_update_budget_error_rollback():
|
||||
"""Test that database changes are not committed when update_team_and_users_budget fails.
|
||||
|
||||
This test verifies that if LiteLlmManager.update_team_and_users_budget raises an exception,
|
||||
the database transaction rolls back.
|
||||
"""
|
||||
mock_request = Request(scope={'type': 'http'})
|
||||
mock_request._base_url = URL('http://test.com/')
|
||||
|
||||
mock_billing_session = MagicMock()
|
||||
mock_billing_session.status = 'in_progress'
|
||||
mock_billing_session.user_id = 'mock_user'
|
||||
|
||||
mock_org = MagicMock()
|
||||
|
||||
with (
|
||||
patch('server.routes.billing.session_maker') as mock_session_maker,
|
||||
patch('stripe.checkout.Session.retrieve') as mock_stripe_retrieve,
|
||||
patch(
|
||||
'storage.user_store.UserStore.get_user_by_id_async',
|
||||
new_callable=AsyncMock,
|
||||
return_value=MagicMock(current_org_id='mock_org_id'),
|
||||
),
|
||||
patch(
|
||||
'storage.lite_llm_manager.LiteLlmManager.get_user_team_info',
|
||||
return_value={
|
||||
'spend': 0,
|
||||
'litellm_budget_table': {'max_budget': 0},
|
||||
},
|
||||
),
|
||||
patch(
|
||||
'storage.lite_llm_manager.LiteLlmManager.update_team_and_users_budget',
|
||||
side_effect=Exception('LiteLLM API Error'),
|
||||
),
|
||||
):
|
||||
mock_db_session = MagicMock()
|
||||
mock_query_chain_billing = MagicMock()
|
||||
mock_query_chain_billing.filter.return_value.filter.return_value.first.return_value = mock_billing_session
|
||||
mock_query_chain_org = MagicMock()
|
||||
mock_query_chain_org.filter.return_value.first.return_value = mock_org
|
||||
mock_db_session.query.side_effect = [
|
||||
mock_query_chain_billing,
|
||||
mock_query_chain_org,
|
||||
]
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_db_session
|
||||
|
||||
mock_stripe_retrieve.return_value = MagicMock(
|
||||
status='complete',
|
||||
amount_subtotal=1000, # $10
|
||||
customer='mock_customer_id',
|
||||
)
|
||||
|
||||
with pytest.raises(Exception, match='LiteLLM API Error'):
|
||||
await success_callback('test_session_id', mock_request)
|
||||
|
||||
# Verify no database commit occurred - the transaction should roll back
|
||||
assert mock_billing_session.status == 'in_progress'
|
||||
mock_db_session.merge.assert_not_called()
|
||||
mock_db_session.commit.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_callback_session_not_found():
|
||||
"""Test cancel callback when billing session is not found."""
|
||||
@@ -509,6 +579,6 @@ async def test_create_customer_setup_session_success():
|
||||
customer='mock-customer-id',
|
||||
mode='setup',
|
||||
payment_method_types=['card'],
|
||||
success_url='https://test.com/?free_credits=success',
|
||||
success_url='https://test.com/?setup=success',
|
||||
cancel_url='https://test.com/',
|
||||
)
|
||||
|
||||
@@ -48,7 +48,7 @@ async def test_create_customer_setup_session_uses_customer_id():
|
||||
customer=customer_id,
|
||||
mode='setup',
|
||||
payment_method_types=['card'],
|
||||
success_url=f'{request.base_url}?free_credits=success',
|
||||
success_url=f'{request.base_url}?setup=success',
|
||||
cancel_url=f'{request.base_url}',
|
||||
)
|
||||
|
||||
|
||||
192
enterprise/tests/unit/test_email_service.py
Normal file
192
enterprise/tests/unit/test_email_service.py
Normal file
@@ -0,0 +1,192 @@
|
||||
"""Tests for email service."""
|
||||
|
||||
import os
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from server.services.email_service import (
|
||||
DEFAULT_WEB_HOST,
|
||||
EmailService,
|
||||
)
|
||||
|
||||
|
||||
class TestEmailServiceInvitationUrl:
|
||||
"""Test cases for invitation URL generation."""
|
||||
|
||||
def test_invitation_url_uses_correct_endpoint(self):
|
||||
"""Test that invitation URL points to the correct API endpoint."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.get.return_value = 'test-email-id'
|
||||
|
||||
with (
|
||||
patch.dict(os.environ, {'RESEND_API_KEY': 'test-key'}),
|
||||
patch('server.services.email_service.RESEND_AVAILABLE', True),
|
||||
patch('server.services.email_service.resend') as mock_resend,
|
||||
):
|
||||
mock_resend.Emails.send.return_value = mock_response
|
||||
|
||||
EmailService.send_invitation_email(
|
||||
to_email='test@example.com',
|
||||
org_name='Test Org',
|
||||
inviter_name='Inviter',
|
||||
role_name='member',
|
||||
invitation_token='inv-test-token-12345',
|
||||
invitation_id=1,
|
||||
)
|
||||
|
||||
# Get the call arguments
|
||||
call_args = mock_resend.Emails.send.call_args
|
||||
email_params = call_args[0][0]
|
||||
|
||||
# Verify the URL in the email HTML contains the correct endpoint
|
||||
assert (
|
||||
'/api/organizations/members/invite/accept?token='
|
||||
in email_params['html']
|
||||
)
|
||||
assert 'inv-test-token-12345' in email_params['html']
|
||||
|
||||
def test_invitation_url_uses_web_host_env_var(self):
|
||||
"""Test that invitation URL uses WEB_HOST environment variable."""
|
||||
custom_host = 'https://custom.example.com'
|
||||
mock_response = MagicMock()
|
||||
mock_response.get.return_value = 'test-email-id'
|
||||
|
||||
with (
|
||||
patch.dict(
|
||||
os.environ,
|
||||
{'RESEND_API_KEY': 'test-key', 'WEB_HOST': custom_host},
|
||||
),
|
||||
patch('server.services.email_service.RESEND_AVAILABLE', True),
|
||||
patch('server.services.email_service.resend') as mock_resend,
|
||||
):
|
||||
mock_resend.Emails.send.return_value = mock_response
|
||||
|
||||
EmailService.send_invitation_email(
|
||||
to_email='test@example.com',
|
||||
org_name='Test Org',
|
||||
inviter_name='Inviter',
|
||||
role_name='member',
|
||||
invitation_token='inv-test-token-12345',
|
||||
invitation_id=1,
|
||||
)
|
||||
|
||||
call_args = mock_resend.Emails.send.call_args
|
||||
email_params = call_args[0][0]
|
||||
|
||||
expected_url = f'{custom_host}/api/organizations/members/invite/accept?token=inv-test-token-12345'
|
||||
assert expected_url in email_params['html']
|
||||
|
||||
def test_invitation_url_uses_default_host_when_env_not_set(self):
|
||||
"""Test that invitation URL falls back to DEFAULT_WEB_HOST when env not set."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.get.return_value = 'test-email-id'
|
||||
|
||||
env_without_web_host = {'RESEND_API_KEY': 'test-key'}
|
||||
# Remove WEB_HOST if it exists
|
||||
env_without_web_host.pop('WEB_HOST', None)
|
||||
|
||||
with (
|
||||
patch.dict(os.environ, env_without_web_host, clear=True),
|
||||
patch('server.services.email_service.RESEND_AVAILABLE', True),
|
||||
patch('server.services.email_service.resend') as mock_resend,
|
||||
):
|
||||
# Clear WEB_HOST from the environment
|
||||
os.environ.pop('WEB_HOST', None)
|
||||
mock_resend.Emails.send.return_value = mock_response
|
||||
|
||||
EmailService.send_invitation_email(
|
||||
to_email='test@example.com',
|
||||
org_name='Test Org',
|
||||
inviter_name='Inviter',
|
||||
role_name='member',
|
||||
invitation_token='inv-test-token-12345',
|
||||
invitation_id=1,
|
||||
)
|
||||
|
||||
call_args = mock_resend.Emails.send.call_args
|
||||
email_params = call_args[0][0]
|
||||
|
||||
expected_url = f'{DEFAULT_WEB_HOST}/api/organizations/members/invite/accept?token=inv-test-token-12345'
|
||||
assert expected_url in email_params['html']
|
||||
|
||||
|
||||
class TestEmailServiceGetResendClient:
|
||||
"""Test cases for Resend client initialization."""
|
||||
|
||||
def test_get_resend_client_returns_false_when_resend_not_available(self):
|
||||
"""Test that _get_resend_client returns False when resend is not installed."""
|
||||
with patch('server.services.email_service.RESEND_AVAILABLE', False):
|
||||
result = EmailService._get_resend_client()
|
||||
assert result is False
|
||||
|
||||
def test_get_resend_client_returns_false_when_api_key_not_configured(self):
|
||||
"""Test that _get_resend_client returns False when API key is missing."""
|
||||
with (
|
||||
patch('server.services.email_service.RESEND_AVAILABLE', True),
|
||||
patch.dict(os.environ, {}, clear=True),
|
||||
):
|
||||
os.environ.pop('RESEND_API_KEY', None)
|
||||
result = EmailService._get_resend_client()
|
||||
assert result is False
|
||||
|
||||
def test_get_resend_client_returns_true_when_configured(self):
|
||||
"""Test that _get_resend_client returns True when properly configured."""
|
||||
with (
|
||||
patch.dict(os.environ, {'RESEND_API_KEY': 'test-key'}),
|
||||
patch('server.services.email_service.RESEND_AVAILABLE', True),
|
||||
patch('server.services.email_service.resend') as mock_resend,
|
||||
):
|
||||
result = EmailService._get_resend_client()
|
||||
assert result is True
|
||||
assert mock_resend.api_key == 'test-key'
|
||||
|
||||
|
||||
class TestEmailServiceSendInvitationEmail:
|
||||
"""Test cases for send_invitation_email method."""
|
||||
|
||||
def test_send_invitation_email_skips_when_client_not_ready(self):
|
||||
"""Test that email sending is skipped when client is not ready."""
|
||||
with patch.object(
|
||||
EmailService, '_get_resend_client', return_value=False
|
||||
) as mock_get_client:
|
||||
# Should not raise, just return early
|
||||
EmailService.send_invitation_email(
|
||||
to_email='test@example.com',
|
||||
org_name='Test Org',
|
||||
inviter_name='Inviter',
|
||||
role_name='member',
|
||||
invitation_token='inv-test-token',
|
||||
invitation_id=1,
|
||||
)
|
||||
|
||||
mock_get_client.assert_called_once()
|
||||
|
||||
def test_send_invitation_email_includes_all_required_info(self):
|
||||
"""Test that invitation email includes org name, inviter name, and role."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.get.return_value = 'test-email-id'
|
||||
|
||||
with (
|
||||
patch.dict(os.environ, {'RESEND_API_KEY': 'test-key'}),
|
||||
patch('server.services.email_service.RESEND_AVAILABLE', True),
|
||||
patch('server.services.email_service.resend') as mock_resend,
|
||||
):
|
||||
mock_resend.Emails.send.return_value = mock_response
|
||||
|
||||
EmailService.send_invitation_email(
|
||||
to_email='test@example.com',
|
||||
org_name='Acme Corp',
|
||||
inviter_name='John Doe',
|
||||
role_name='admin',
|
||||
invitation_token='inv-test-token-12345',
|
||||
invitation_id=42,
|
||||
)
|
||||
|
||||
call_args = mock_resend.Emails.send.call_args
|
||||
email_params = call_args[0][0]
|
||||
|
||||
# Verify email content
|
||||
assert email_params['to'] == ['test@example.com']
|
||||
assert 'Acme Corp' in email_params['subject']
|
||||
assert 'John Doe' in email_params['html']
|
||||
assert 'Acme Corp' in email_params['html']
|
||||
assert 'admin' in email_params['html']
|
||||
464
enterprise/tests/unit/test_org_invitation_service.py
Normal file
464
enterprise/tests/unit/test_org_invitation_service.py
Normal file
@@ -0,0 +1,464 @@
|
||||
"""Tests for organization invitation service - email validation."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from server.routes.org_invitation_models import (
|
||||
EmailMismatchError,
|
||||
)
|
||||
from server.services.org_invitation_service import OrgInvitationService
|
||||
from storage.org_invitation import OrgInvitation
|
||||
|
||||
|
||||
class TestAcceptInvitationEmailValidation:
|
||||
"""Test cases for email validation during invitation acceptance."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_invitation(self):
|
||||
"""Create a mock invitation with pending status."""
|
||||
invitation = MagicMock(spec=OrgInvitation)
|
||||
invitation.id = 1
|
||||
invitation.email = 'alice@example.com'
|
||||
invitation.status = OrgInvitation.STATUS_PENDING
|
||||
invitation.org_id = UUID('12345678-1234-5678-1234-567812345678')
|
||||
invitation.role_id = 1
|
||||
return invitation
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user(self):
|
||||
"""Create a mock user with email."""
|
||||
user = MagicMock()
|
||||
user.id = UUID('87654321-4321-8765-4321-876543218765')
|
||||
user.email = 'alice@example.com'
|
||||
return user
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_accept_invitation_email_matches(self, mock_invitation, mock_user):
|
||||
"""Test that invitation is accepted when user email matches invitation email."""
|
||||
# Arrange
|
||||
user_id = mock_user.id
|
||||
token = 'inv-test-token-12345'
|
||||
|
||||
with patch.object(
|
||||
OrgInvitationService, 'accept_invitation', new_callable=AsyncMock
|
||||
) as mock_accept:
|
||||
mock_accept.return_value = mock_invitation
|
||||
|
||||
# Act
|
||||
await OrgInvitationService.accept_invitation(token, user_id)
|
||||
|
||||
# Assert
|
||||
mock_accept.assert_called_once_with(token, user_id)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_accept_invitation_email_mismatch_raises_error(
|
||||
self, mock_invitation, mock_user
|
||||
):
|
||||
"""Test that EmailMismatchError is raised when emails don't match."""
|
||||
# Arrange
|
||||
user_id = mock_user.id
|
||||
token = 'inv-test-token-12345'
|
||||
mock_user.email = 'bob@example.com' # Different email
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.services.org_invitation_service.OrgInvitationStore.get_invitation_by_token',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_invitation,
|
||||
patch(
|
||||
'server.services.org_invitation_service.OrgInvitationStore.is_token_expired'
|
||||
) as mock_is_expired,
|
||||
patch(
|
||||
'server.services.org_invitation_service.UserStore.get_user_by_id_async',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_user,
|
||||
):
|
||||
mock_get_invitation.return_value = mock_invitation
|
||||
mock_is_expired.return_value = False
|
||||
mock_get_user.return_value = mock_user
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(EmailMismatchError):
|
||||
await OrgInvitationService.accept_invitation(token, user_id)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_accept_invitation_user_no_email_keycloak_fallback_matches(
|
||||
self, mock_invitation
|
||||
):
|
||||
"""Test that Keycloak email is used when user has no email in database."""
|
||||
# Arrange
|
||||
user_id = UUID('87654321-4321-8765-4321-876543218765')
|
||||
token = 'inv-test-token-12345'
|
||||
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = user_id
|
||||
mock_user.email = None # No email in database
|
||||
|
||||
mock_keycloak_user_info = {'email': 'alice@example.com'} # Email from Keycloak
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.services.org_invitation_service.OrgInvitationStore.get_invitation_by_token',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_invitation,
|
||||
patch(
|
||||
'server.services.org_invitation_service.OrgInvitationStore.is_token_expired'
|
||||
) as mock_is_expired,
|
||||
patch(
|
||||
'server.services.org_invitation_service.UserStore.get_user_by_id_async',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_user,
|
||||
patch(
|
||||
'server.services.org_invitation_service.TokenManager'
|
||||
) as mock_token_manager_class,
|
||||
patch(
|
||||
'server.services.org_invitation_service.OrgMemberStore.get_org_member'
|
||||
) as mock_get_member,
|
||||
patch(
|
||||
'server.services.org_invitation_service.OrgService.create_litellm_integration',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_create_litellm,
|
||||
patch(
|
||||
'server.services.org_invitation_service.OrgMemberStore.add_user_to_org'
|
||||
),
|
||||
patch(
|
||||
'server.services.org_invitation_service.OrgInvitationStore.update_invitation_status',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_update_status,
|
||||
):
|
||||
mock_get_invitation.return_value = mock_invitation
|
||||
mock_is_expired.return_value = False
|
||||
mock_get_user.return_value = mock_user
|
||||
|
||||
# Mock TokenManager instance
|
||||
mock_token_manager = MagicMock()
|
||||
mock_token_manager.get_user_info_from_user_id = AsyncMock(
|
||||
return_value=mock_keycloak_user_info
|
||||
)
|
||||
mock_token_manager_class.return_value = mock_token_manager
|
||||
|
||||
mock_get_member.return_value = None # Not already a member
|
||||
mock_create_litellm.return_value = MagicMock(llm_api_key='test-key')
|
||||
mock_update_status.return_value = mock_invitation
|
||||
|
||||
# Act - should not raise error because Keycloak email matches
|
||||
await OrgInvitationService.accept_invitation(token, user_id)
|
||||
|
||||
# Assert
|
||||
mock_token_manager.get_user_info_from_user_id.assert_called_once_with(
|
||||
str(user_id)
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_accept_invitation_no_email_anywhere_raises_error(
|
||||
self, mock_invitation
|
||||
):
|
||||
"""Test that EmailMismatchError is raised when user has no email in database or Keycloak."""
|
||||
# Arrange
|
||||
user_id = UUID('87654321-4321-8765-4321-876543218765')
|
||||
token = 'inv-test-token-12345'
|
||||
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = user_id
|
||||
mock_user.email = None # No email in database
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.services.org_invitation_service.OrgInvitationStore.get_invitation_by_token',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_invitation,
|
||||
patch(
|
||||
'server.services.org_invitation_service.OrgInvitationStore.is_token_expired'
|
||||
) as mock_is_expired,
|
||||
patch(
|
||||
'server.services.org_invitation_service.UserStore.get_user_by_id_async',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_user,
|
||||
patch(
|
||||
'server.services.org_invitation_service.TokenManager'
|
||||
) as mock_token_manager_class,
|
||||
):
|
||||
mock_get_invitation.return_value = mock_invitation
|
||||
mock_is_expired.return_value = False
|
||||
mock_get_user.return_value = mock_user
|
||||
|
||||
# Mock TokenManager to return no email
|
||||
mock_token_manager = MagicMock()
|
||||
mock_token_manager.get_user_info_from_user_id = AsyncMock(return_value={})
|
||||
mock_token_manager_class.return_value = mock_token_manager
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(EmailMismatchError) as exc_info:
|
||||
await OrgInvitationService.accept_invitation(token, user_id)
|
||||
|
||||
assert 'does not have an email address' in str(exc_info.value)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_accept_invitation_email_comparison_is_case_insensitive(
|
||||
self, mock_invitation
|
||||
):
|
||||
"""Test that email comparison is case insensitive."""
|
||||
# Arrange
|
||||
user_id = UUID('87654321-4321-8765-4321-876543218765')
|
||||
token = 'inv-test-token-12345'
|
||||
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = user_id
|
||||
mock_user.email = 'ALICE@EXAMPLE.COM' # Uppercase email
|
||||
|
||||
mock_invitation.email = 'alice@example.com' # Lowercase in invitation
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.services.org_invitation_service.OrgInvitationStore.get_invitation_by_token',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_invitation,
|
||||
patch(
|
||||
'server.services.org_invitation_service.OrgInvitationStore.is_token_expired'
|
||||
) as mock_is_expired,
|
||||
patch(
|
||||
'server.services.org_invitation_service.UserStore.get_user_by_id_async',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_user,
|
||||
patch(
|
||||
'server.services.org_invitation_service.OrgMemberStore.get_org_member'
|
||||
) as mock_get_member,
|
||||
patch(
|
||||
'server.services.org_invitation_service.OrgService.create_litellm_integration',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_create_litellm,
|
||||
patch(
|
||||
'server.services.org_invitation_service.OrgMemberStore.add_user_to_org'
|
||||
),
|
||||
patch(
|
||||
'server.services.org_invitation_service.OrgInvitationStore.update_invitation_status',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_update_status,
|
||||
):
|
||||
mock_get_invitation.return_value = mock_invitation
|
||||
mock_is_expired.return_value = False
|
||||
mock_get_user.return_value = mock_user
|
||||
mock_get_member.return_value = None
|
||||
mock_create_litellm.return_value = MagicMock(llm_api_key='test-key')
|
||||
mock_update_status.return_value = mock_invitation
|
||||
|
||||
# Act - should not raise error because emails match case-insensitively
|
||||
await OrgInvitationService.accept_invitation(token, user_id)
|
||||
|
||||
# Assert - invitation was accepted (update_invitation_status was called)
|
||||
mock_update_status.assert_called_once()
|
||||
|
||||
|
||||
class TestCreateInvitationsBatch:
|
||||
"""Test cases for batch invitation creation."""
|
||||
|
||||
@pytest.fixture
|
||||
def org_id(self):
|
||||
"""Organization UUID for testing."""
|
||||
return UUID('12345678-1234-5678-1234-567812345678')
|
||||
|
||||
@pytest.fixture
|
||||
def inviter_id(self):
|
||||
"""Inviter UUID for testing."""
|
||||
return UUID('87654321-4321-8765-4321-876543218765')
|
||||
|
||||
@pytest.fixture
|
||||
def mock_org(self):
|
||||
"""Create a mock organization."""
|
||||
org = MagicMock()
|
||||
org.id = UUID('12345678-1234-5678-1234-567812345678')
|
||||
org.name = 'Test Org'
|
||||
return org
|
||||
|
||||
@pytest.fixture
|
||||
def mock_inviter_member(self):
|
||||
"""Create a mock inviter member with owner role."""
|
||||
member = MagicMock()
|
||||
member.user_id = UUID('87654321-4321-8765-4321-876543218765')
|
||||
member.role_id = 1
|
||||
return member
|
||||
|
||||
@pytest.fixture
|
||||
def mock_owner_role(self):
|
||||
"""Create a mock owner role."""
|
||||
role = MagicMock()
|
||||
role.id = 1
|
||||
role.name = 'owner'
|
||||
return role
|
||||
|
||||
@pytest.fixture
|
||||
def mock_member_role(self):
|
||||
"""Create a mock member role."""
|
||||
role = MagicMock()
|
||||
role.id = 3
|
||||
role.name = 'member'
|
||||
return role
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_batch_creates_all_invitations_successfully(
|
||||
self,
|
||||
org_id,
|
||||
inviter_id,
|
||||
mock_org,
|
||||
mock_inviter_member,
|
||||
mock_owner_role,
|
||||
mock_member_role,
|
||||
):
|
||||
"""Test that batch creation succeeds for all valid emails."""
|
||||
# Arrange
|
||||
emails = ['alice@example.com', 'bob@example.com']
|
||||
mock_invitation_1 = MagicMock(spec=OrgInvitation)
|
||||
mock_invitation_1.id = 1
|
||||
mock_invitation_2 = MagicMock(spec=OrgInvitation)
|
||||
mock_invitation_2.id = 2
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.services.org_invitation_service.OrgStore.get_org_by_id',
|
||||
return_value=mock_org,
|
||||
),
|
||||
patch(
|
||||
'server.services.org_invitation_service.OrgMemberStore.get_org_member',
|
||||
return_value=mock_inviter_member,
|
||||
),
|
||||
patch(
|
||||
'server.services.org_invitation_service.RoleStore.get_role_by_id',
|
||||
return_value=mock_owner_role,
|
||||
),
|
||||
patch(
|
||||
'server.services.org_invitation_service.RoleStore.get_role_by_name',
|
||||
return_value=mock_member_role,
|
||||
),
|
||||
patch.object(
|
||||
OrgInvitationService,
|
||||
'create_invitation',
|
||||
new_callable=AsyncMock,
|
||||
side_effect=[mock_invitation_1, mock_invitation_2],
|
||||
),
|
||||
):
|
||||
# Act
|
||||
successful, failed = await OrgInvitationService.create_invitations_batch(
|
||||
org_id=org_id,
|
||||
emails=emails,
|
||||
role_name='member',
|
||||
inviter_id=inviter_id,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(successful) == 2
|
||||
assert len(failed) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_batch_handles_partial_success(
|
||||
self,
|
||||
org_id,
|
||||
inviter_id,
|
||||
mock_org,
|
||||
mock_inviter_member,
|
||||
mock_owner_role,
|
||||
mock_member_role,
|
||||
):
|
||||
"""Test that batch returns partial results when some emails fail."""
|
||||
# Arrange
|
||||
from server.routes.org_invitation_models import UserAlreadyMemberError
|
||||
|
||||
emails = ['alice@example.com', 'existing@example.com']
|
||||
mock_invitation = MagicMock(spec=OrgInvitation)
|
||||
mock_invitation.id = 1
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.services.org_invitation_service.OrgStore.get_org_by_id',
|
||||
return_value=mock_org,
|
||||
),
|
||||
patch(
|
||||
'server.services.org_invitation_service.OrgMemberStore.get_org_member',
|
||||
return_value=mock_inviter_member,
|
||||
),
|
||||
patch(
|
||||
'server.services.org_invitation_service.RoleStore.get_role_by_id',
|
||||
return_value=mock_owner_role,
|
||||
),
|
||||
patch(
|
||||
'server.services.org_invitation_service.RoleStore.get_role_by_name',
|
||||
return_value=mock_member_role,
|
||||
),
|
||||
patch.object(
|
||||
OrgInvitationService,
|
||||
'create_invitation',
|
||||
new_callable=AsyncMock,
|
||||
side_effect=[mock_invitation, UserAlreadyMemberError()],
|
||||
),
|
||||
):
|
||||
# Act
|
||||
successful, failed = await OrgInvitationService.create_invitations_batch(
|
||||
org_id=org_id,
|
||||
emails=emails,
|
||||
role_name='member',
|
||||
inviter_id=inviter_id,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(successful) == 1
|
||||
assert len(failed) == 1
|
||||
assert failed[0][0] == 'existing@example.com'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_batch_fails_entirely_on_permission_error(self, org_id, inviter_id):
|
||||
"""Test that permission error fails the entire batch upfront."""
|
||||
# Arrange
|
||||
|
||||
emails = ['alice@example.com', 'bob@example.com']
|
||||
|
||||
with patch(
|
||||
'server.services.org_invitation_service.OrgStore.get_org_by_id',
|
||||
return_value=None, # Organization not found
|
||||
):
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await OrgInvitationService.create_invitations_batch(
|
||||
org_id=org_id,
|
||||
emails=emails,
|
||||
role_name='member',
|
||||
inviter_id=inviter_id,
|
||||
)
|
||||
|
||||
assert 'not found' in str(exc_info.value)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_batch_fails_on_invalid_role(
|
||||
self, org_id, inviter_id, mock_org, mock_inviter_member, mock_owner_role
|
||||
):
|
||||
"""Test that invalid role fails the entire batch."""
|
||||
# Arrange
|
||||
emails = ['alice@example.com']
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.services.org_invitation_service.OrgStore.get_org_by_id',
|
||||
return_value=mock_org,
|
||||
),
|
||||
patch(
|
||||
'server.services.org_invitation_service.OrgMemberStore.get_org_member',
|
||||
return_value=mock_inviter_member,
|
||||
),
|
||||
patch(
|
||||
'server.services.org_invitation_service.RoleStore.get_role_by_id',
|
||||
return_value=mock_owner_role,
|
||||
),
|
||||
patch(
|
||||
'server.services.org_invitation_service.RoleStore.get_role_by_name',
|
||||
return_value=None, # Invalid role
|
||||
),
|
||||
):
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await OrgInvitationService.create_invitations_batch(
|
||||
org_id=org_id,
|
||||
emails=emails,
|
||||
role_name='invalid_role',
|
||||
inviter_id=inviter_id,
|
||||
)
|
||||
|
||||
assert 'Invalid role' in str(exc_info.value)
|
||||
308
enterprise/tests/unit/test_org_invitation_store.py
Normal file
308
enterprise/tests/unit/test_org_invitation_store.py
Normal file
@@ -0,0 +1,308 @@
|
||||
"""Tests for organization invitation store."""
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from storage.org_invitation import OrgInvitation
|
||||
from storage.org_invitation_store import (
|
||||
INVITATION_TOKEN_LENGTH,
|
||||
INVITATION_TOKEN_PREFIX,
|
||||
OrgInvitationStore,
|
||||
)
|
||||
|
||||
|
||||
class TestGenerateToken:
|
||||
"""Test cases for token generation."""
|
||||
|
||||
def test_generate_token_has_correct_prefix(self):
|
||||
"""Test that generated tokens have the correct prefix."""
|
||||
token = OrgInvitationStore.generate_token()
|
||||
assert token.startswith(INVITATION_TOKEN_PREFIX)
|
||||
|
||||
def test_generate_token_has_correct_length(self):
|
||||
"""Test that generated tokens have the correct total length."""
|
||||
token = OrgInvitationStore.generate_token()
|
||||
expected_length = len(INVITATION_TOKEN_PREFIX) + INVITATION_TOKEN_LENGTH
|
||||
assert len(token) == expected_length
|
||||
|
||||
def test_generate_token_uses_alphanumeric_characters(self):
|
||||
"""Test that generated tokens use only alphanumeric characters."""
|
||||
token = OrgInvitationStore.generate_token()
|
||||
# Remove prefix and check the rest is alphanumeric
|
||||
random_part = token[len(INVITATION_TOKEN_PREFIX) :]
|
||||
assert random_part.isalnum()
|
||||
|
||||
def test_generate_token_is_unique(self):
|
||||
"""Test that generated tokens are unique (probabilistically)."""
|
||||
tokens = [OrgInvitationStore.generate_token() for _ in range(100)]
|
||||
assert len(set(tokens)) == 100
|
||||
|
||||
|
||||
class TestIsTokenExpired:
|
||||
"""Test cases for token expiration checking."""
|
||||
|
||||
def test_token_not_expired_when_future(self):
|
||||
"""Test that tokens with future expiration are not expired."""
|
||||
invitation = MagicMock(spec=OrgInvitation)
|
||||
invitation.expires_at = datetime.utcnow() + timedelta(days=1)
|
||||
|
||||
result = OrgInvitationStore.is_token_expired(invitation)
|
||||
assert result is False
|
||||
|
||||
def test_token_expired_when_past(self):
|
||||
"""Test that tokens with past expiration are expired."""
|
||||
invitation = MagicMock(spec=OrgInvitation)
|
||||
invitation.expires_at = datetime.utcnow() - timedelta(seconds=1)
|
||||
|
||||
result = OrgInvitationStore.is_token_expired(invitation)
|
||||
assert result is True
|
||||
|
||||
def test_token_expired_at_exact_boundary(self):
|
||||
"""Test that tokens at exact expiration time are expired."""
|
||||
# A token that expires "now" should be expired
|
||||
now = datetime.utcnow()
|
||||
invitation = MagicMock(spec=OrgInvitation)
|
||||
invitation.expires_at = now - timedelta(microseconds=1)
|
||||
|
||||
result = OrgInvitationStore.is_token_expired(invitation)
|
||||
assert result is True
|
||||
|
||||
|
||||
class TestCreateInvitation:
|
||||
"""Test cases for invitation creation."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_invitation_normalizes_email(self):
|
||||
"""Test that email is normalized (lowercase, stripped) on creation."""
|
||||
mock_session = AsyncMock()
|
||||
mock_session.add = MagicMock()
|
||||
mock_session.commit = AsyncMock()
|
||||
mock_session.execute = AsyncMock()
|
||||
|
||||
# Mock the result of the re-fetch query
|
||||
mock_result = MagicMock()
|
||||
mock_invitation = MagicMock()
|
||||
mock_invitation.id = 1
|
||||
mock_invitation.email = 'test@example.com'
|
||||
mock_result.scalars.return_value.first.return_value = mock_invitation
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
with patch(
|
||||
'storage.org_invitation_store.a_session_maker'
|
||||
) as mock_session_maker:
|
||||
mock_session_manager = AsyncMock()
|
||||
mock_session_manager.__aenter__.return_value = mock_session
|
||||
mock_session_manager.__aexit__.return_value = None
|
||||
mock_session_maker.return_value = mock_session_manager
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
await OrgInvitationStore.create_invitation(
|
||||
org_id=UUID('12345678-1234-5678-1234-567812345678'),
|
||||
email=' TEST@EXAMPLE.COM ',
|
||||
role_id=1,
|
||||
inviter_id=UUID('87654321-4321-8765-4321-876543218765'),
|
||||
)
|
||||
|
||||
# Verify that the OrgInvitation was created with normalized email
|
||||
add_call = mock_session.add.call_args
|
||||
created_invitation = add_call[0][0]
|
||||
assert created_invitation.email == 'test@example.com'
|
||||
|
||||
|
||||
class TestGetInvitationByToken:
|
||||
"""Test cases for getting invitation by token."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_invitation_by_token_returns_invitation(self):
|
||||
"""Test that get_invitation_by_token returns the invitation when found."""
|
||||
mock_invitation = MagicMock(spec=OrgInvitation)
|
||||
mock_invitation.token = 'inv-test-token-12345'
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.first.return_value = mock_invitation
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
|
||||
with patch(
|
||||
'storage.org_invitation_store.a_session_maker'
|
||||
) as mock_session_maker:
|
||||
mock_session_manager = AsyncMock()
|
||||
mock_session_manager.__aenter__.return_value = mock_session
|
||||
mock_session_manager.__aexit__.return_value = None
|
||||
mock_session_maker.return_value = mock_session_manager
|
||||
|
||||
result = await OrgInvitationStore.get_invitation_by_token(
|
||||
'inv-test-token-12345'
|
||||
)
|
||||
assert result == mock_invitation
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_invitation_by_token_returns_none_when_not_found(self):
|
||||
"""Test that get_invitation_by_token returns None when not found."""
|
||||
mock_session = AsyncMock()
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.first.return_value = None
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
|
||||
with patch(
|
||||
'storage.org_invitation_store.a_session_maker'
|
||||
) as mock_session_maker:
|
||||
mock_session_manager = AsyncMock()
|
||||
mock_session_manager.__aenter__.return_value = mock_session
|
||||
mock_session_manager.__aexit__.return_value = None
|
||||
mock_session_maker.return_value = mock_session_manager
|
||||
|
||||
result = await OrgInvitationStore.get_invitation_by_token(
|
||||
'inv-nonexistent-token'
|
||||
)
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestGetPendingInvitation:
|
||||
"""Test cases for getting pending invitation."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_pending_invitation_normalizes_email(self):
|
||||
"""Test that email is normalized when querying for pending invitations."""
|
||||
mock_session = AsyncMock()
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.first.return_value = None
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
|
||||
with patch(
|
||||
'storage.org_invitation_store.a_session_maker'
|
||||
) as mock_session_maker:
|
||||
mock_session_manager = AsyncMock()
|
||||
mock_session_manager.__aenter__.return_value = mock_session
|
||||
mock_session_manager.__aexit__.return_value = None
|
||||
mock_session_maker.return_value = mock_session_manager
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
await OrgInvitationStore.get_pending_invitation(
|
||||
org_id=UUID('12345678-1234-5678-1234-567812345678'),
|
||||
email=' TEST@EXAMPLE.COM ',
|
||||
)
|
||||
|
||||
# Verify the query was called (email normalization happens in the filter)
|
||||
assert mock_session.execute.called
|
||||
|
||||
|
||||
class TestUpdateInvitationStatus:
|
||||
"""Test cases for updating invitation status."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_status_sets_accepted_at_for_accepted(self):
|
||||
"""Test that accepted_at is set when status is accepted."""
|
||||
from uuid import UUID
|
||||
|
||||
mock_invitation = MagicMock(spec=OrgInvitation)
|
||||
mock_invitation.id = 1
|
||||
mock_invitation.status = OrgInvitation.STATUS_PENDING
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.first.return_value = mock_invitation
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
mock_session.commit = AsyncMock()
|
||||
mock_session.refresh = AsyncMock()
|
||||
|
||||
with patch(
|
||||
'storage.org_invitation_store.a_session_maker'
|
||||
) as mock_session_maker:
|
||||
mock_session_manager = AsyncMock()
|
||||
mock_session_manager.__aenter__.return_value = mock_session
|
||||
mock_session_manager.__aexit__.return_value = None
|
||||
mock_session_maker.return_value = mock_session_manager
|
||||
|
||||
user_id = UUID('87654321-4321-8765-4321-876543218765')
|
||||
await OrgInvitationStore.update_invitation_status(
|
||||
invitation_id=1,
|
||||
status=OrgInvitation.STATUS_ACCEPTED,
|
||||
accepted_by_user_id=user_id,
|
||||
)
|
||||
|
||||
assert mock_invitation.accepted_at is not None
|
||||
assert mock_invitation.accepted_by_user_id == user_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_status_returns_none_when_not_found(self):
|
||||
"""Test that update returns None when invitation not found."""
|
||||
mock_session = AsyncMock()
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.first.return_value = None
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
|
||||
with patch(
|
||||
'storage.org_invitation_store.a_session_maker'
|
||||
) as mock_session_maker:
|
||||
mock_session_manager = AsyncMock()
|
||||
mock_session_manager.__aenter__.return_value = mock_session
|
||||
mock_session_manager.__aexit__.return_value = None
|
||||
mock_session_maker.return_value = mock_session_manager
|
||||
|
||||
result = await OrgInvitationStore.update_invitation_status(
|
||||
invitation_id=999,
|
||||
status=OrgInvitation.STATUS_ACCEPTED,
|
||||
)
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestMarkExpiredIfNeeded:
|
||||
"""Test cases for marking expired invitations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_marks_expired_when_pending_and_past_expiry(self):
|
||||
"""Test that pending expired invitations are marked as expired."""
|
||||
mock_invitation = MagicMock(spec=OrgInvitation)
|
||||
mock_invitation.id = 1
|
||||
mock_invitation.status = OrgInvitation.STATUS_PENDING
|
||||
mock_invitation.expires_at = datetime.utcnow() - timedelta(days=1)
|
||||
|
||||
with patch.object(
|
||||
OrgInvitationStore,
|
||||
'update_invitation_status',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_update:
|
||||
result = await OrgInvitationStore.mark_expired_if_needed(mock_invitation)
|
||||
|
||||
assert result is True
|
||||
mock_update.assert_called_once_with(1, OrgInvitation.STATUS_EXPIRED)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_does_not_mark_when_not_expired(self):
|
||||
"""Test that non-expired invitations are not marked."""
|
||||
mock_invitation = MagicMock(spec=OrgInvitation)
|
||||
mock_invitation.id = 1
|
||||
mock_invitation.status = OrgInvitation.STATUS_PENDING
|
||||
mock_invitation.expires_at = datetime.utcnow() + timedelta(days=1)
|
||||
|
||||
with patch.object(
|
||||
OrgInvitationStore,
|
||||
'update_invitation_status',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_update:
|
||||
result = await OrgInvitationStore.mark_expired_if_needed(mock_invitation)
|
||||
|
||||
assert result is False
|
||||
mock_update.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_does_not_mark_when_not_pending(self):
|
||||
"""Test that non-pending invitations are not marked even if expired."""
|
||||
mock_invitation = MagicMock(spec=OrgInvitation)
|
||||
mock_invitation.id = 1
|
||||
mock_invitation.status = OrgInvitation.STATUS_ACCEPTED
|
||||
mock_invitation.expires_at = datetime.utcnow() - timedelta(days=1)
|
||||
|
||||
with patch.object(
|
||||
OrgInvitationStore,
|
||||
'update_invitation_status',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_update:
|
||||
result = await OrgInvitationStore.mark_expired_if_needed(mock_invitation)
|
||||
|
||||
assert result is False
|
||||
mock_update.assert_not_called()
|
||||
388
enterprise/tests/unit/test_org_invitations_router.py
Normal file
388
enterprise/tests/unit/test_org_invitations_router.py
Normal file
@@ -0,0 +1,388 @@
|
||||
"""Tests for organization invitations API router."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
from server.routes.org_invitation_models import (
|
||||
EmailMismatchError,
|
||||
InvitationExpiredError,
|
||||
InvitationInvalidError,
|
||||
UserAlreadyMemberError,
|
||||
)
|
||||
from server.routes.org_invitations import accept_router, invitation_router
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
"""Create a FastAPI app with the invitation routers."""
|
||||
app = FastAPI()
|
||||
app.include_router(invitation_router)
|
||||
app.include_router(accept_router)
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(app):
|
||||
"""Create a test client for the app."""
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
class TestRouterPrefixes:
|
||||
"""Test that router prefixes are configured correctly."""
|
||||
|
||||
def test_invitation_router_has_correct_prefix(self):
|
||||
"""Test that invitation_router has /api/organizations/{org_id}/members prefix."""
|
||||
assert invitation_router.prefix == '/api/organizations/{org_id}/members'
|
||||
|
||||
def test_accept_router_has_correct_prefix(self):
|
||||
"""Test that accept_router has /api/organizations/members/invite prefix."""
|
||||
assert accept_router.prefix == '/api/organizations/members/invite'
|
||||
|
||||
|
||||
class TestAcceptInvitationEndpoint:
|
||||
"""Test cases for the accept invitation endpoint."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user_auth(self):
|
||||
"""Create a mock user auth."""
|
||||
user_auth = MagicMock()
|
||||
user_auth.get_user_id = AsyncMock(
|
||||
return_value='87654321-4321-8765-4321-876543218765'
|
||||
)
|
||||
return user_auth
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_accept_unauthenticated_redirects_to_login(self, client):
|
||||
"""Test that unauthenticated users are redirected to login with invitation token."""
|
||||
with patch(
|
||||
'server.routes.org_invitations.get_user_auth',
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
):
|
||||
response = client.get(
|
||||
'/api/organizations/members/invite/accept?token=inv-test-token-123',
|
||||
follow_redirects=False,
|
||||
)
|
||||
|
||||
assert response.status_code == 302
|
||||
assert '/login?invitation_token=inv-test-token-123' in response.headers.get(
|
||||
'location', ''
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_accept_authenticated_success_redirects_home(
|
||||
self, client, mock_user_auth
|
||||
):
|
||||
"""Test that successful acceptance redirects to home page."""
|
||||
mock_invitation = MagicMock()
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.routes.org_invitations.get_user_auth',
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_user_auth,
|
||||
),
|
||||
patch(
|
||||
'server.routes.org_invitations.OrgInvitationService.accept_invitation',
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_invitation,
|
||||
),
|
||||
):
|
||||
response = client.get(
|
||||
'/api/organizations/members/invite/accept?token=inv-test-token-123',
|
||||
follow_redirects=False,
|
||||
)
|
||||
|
||||
assert response.status_code == 302
|
||||
location = response.headers.get('location', '')
|
||||
assert location.endswith('/')
|
||||
assert 'invitation_expired' not in location
|
||||
assert 'invitation_invalid' not in location
|
||||
assert 'email_mismatch' not in location
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_accept_expired_invitation_redirects_with_flag(
|
||||
self, client, mock_user_auth
|
||||
):
|
||||
"""Test that expired invitation redirects with invitation_expired=true."""
|
||||
with (
|
||||
patch(
|
||||
'server.routes.org_invitations.get_user_auth',
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_user_auth,
|
||||
),
|
||||
patch(
|
||||
'server.routes.org_invitations.OrgInvitationService.accept_invitation',
|
||||
new_callable=AsyncMock,
|
||||
side_effect=InvitationExpiredError(),
|
||||
),
|
||||
):
|
||||
response = client.get(
|
||||
'/api/organizations/members/invite/accept?token=inv-test-token-123',
|
||||
follow_redirects=False,
|
||||
)
|
||||
|
||||
assert response.status_code == 302
|
||||
assert 'invitation_expired=true' in response.headers.get('location', '')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_accept_invalid_invitation_redirects_with_flag(
|
||||
self, client, mock_user_auth
|
||||
):
|
||||
"""Test that invalid invitation redirects with invitation_invalid=true."""
|
||||
with (
|
||||
patch(
|
||||
'server.routes.org_invitations.get_user_auth',
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_user_auth,
|
||||
),
|
||||
patch(
|
||||
'server.routes.org_invitations.OrgInvitationService.accept_invitation',
|
||||
new_callable=AsyncMock,
|
||||
side_effect=InvitationInvalidError(),
|
||||
),
|
||||
):
|
||||
response = client.get(
|
||||
'/api/organizations/members/invite/accept?token=inv-test-token-123',
|
||||
follow_redirects=False,
|
||||
)
|
||||
|
||||
assert response.status_code == 302
|
||||
assert 'invitation_invalid=true' in response.headers.get('location', '')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_accept_already_member_redirects_with_flag(
|
||||
self, client, mock_user_auth
|
||||
):
|
||||
"""Test that already member error redirects with already_member=true."""
|
||||
with (
|
||||
patch(
|
||||
'server.routes.org_invitations.get_user_auth',
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_user_auth,
|
||||
),
|
||||
patch(
|
||||
'server.routes.org_invitations.OrgInvitationService.accept_invitation',
|
||||
new_callable=AsyncMock,
|
||||
side_effect=UserAlreadyMemberError(),
|
||||
),
|
||||
):
|
||||
response = client.get(
|
||||
'/api/organizations/members/invite/accept?token=inv-test-token-123',
|
||||
follow_redirects=False,
|
||||
)
|
||||
|
||||
assert response.status_code == 302
|
||||
assert 'already_member=true' in response.headers.get('location', '')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_accept_email_mismatch_redirects_with_flag(
|
||||
self, client, mock_user_auth
|
||||
):
|
||||
"""Test that email mismatch error redirects with email_mismatch=true."""
|
||||
with (
|
||||
patch(
|
||||
'server.routes.org_invitations.get_user_auth',
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_user_auth,
|
||||
),
|
||||
patch(
|
||||
'server.routes.org_invitations.OrgInvitationService.accept_invitation',
|
||||
new_callable=AsyncMock,
|
||||
side_effect=EmailMismatchError(),
|
||||
),
|
||||
):
|
||||
response = client.get(
|
||||
'/api/organizations/members/invite/accept?token=inv-test-token-123',
|
||||
follow_redirects=False,
|
||||
)
|
||||
|
||||
assert response.status_code == 302
|
||||
assert 'email_mismatch=true' in response.headers.get('location', '')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_accept_unexpected_error_redirects_with_flag(
|
||||
self, client, mock_user_auth
|
||||
):
|
||||
"""Test that unexpected errors redirect with invitation_error=true."""
|
||||
with (
|
||||
patch(
|
||||
'server.routes.org_invitations.get_user_auth',
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_user_auth,
|
||||
),
|
||||
patch(
|
||||
'server.routes.org_invitations.OrgInvitationService.accept_invitation',
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Exception('Unexpected error'),
|
||||
),
|
||||
):
|
||||
response = client.get(
|
||||
'/api/organizations/members/invite/accept?token=inv-test-token-123',
|
||||
follow_redirects=False,
|
||||
)
|
||||
|
||||
assert response.status_code == 302
|
||||
assert 'invitation_error=true' in response.headers.get('location', '')
|
||||
|
||||
|
||||
class TestCreateInvitationBatchEndpoint:
|
||||
"""Test cases for the batch invitation creation endpoint."""
|
||||
|
||||
@pytest.fixture
|
||||
def batch_app(self):
|
||||
"""Create a FastAPI app with dependency overrides for batch tests."""
|
||||
from openhands.server.user_auth import get_user_id
|
||||
|
||||
app = FastAPI()
|
||||
app.include_router(invitation_router)
|
||||
|
||||
# Override the get_user_id dependency
|
||||
app.dependency_overrides[get_user_id] = (
|
||||
lambda: '87654321-4321-8765-4321-876543218765'
|
||||
)
|
||||
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def batch_client(self, batch_app):
|
||||
"""Create a test client with dependency overrides."""
|
||||
return TestClient(batch_app)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_invitation(self):
|
||||
"""Create a mock invitation."""
|
||||
from datetime import datetime
|
||||
|
||||
invitation = MagicMock()
|
||||
invitation.id = 1
|
||||
invitation.email = 'alice@example.com'
|
||||
invitation.role = MagicMock(name='member')
|
||||
invitation.role.name = 'member'
|
||||
invitation.role_id = 3
|
||||
invitation.status = 'pending'
|
||||
invitation.created_at = datetime(2026, 2, 17, 10, 0, 0)
|
||||
invitation.expires_at = datetime(2026, 2, 24, 10, 0, 0)
|
||||
return invitation
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_batch_create_returns_successful_invitations(
|
||||
self, batch_client, mock_invitation
|
||||
):
|
||||
"""Test that batch creation returns successful invitations."""
|
||||
mock_invitation_2 = MagicMock()
|
||||
mock_invitation_2.id = 2
|
||||
mock_invitation_2.email = 'bob@example.com'
|
||||
mock_invitation_2.role = MagicMock()
|
||||
mock_invitation_2.role.name = 'member'
|
||||
mock_invitation_2.role_id = 3
|
||||
mock_invitation_2.status = 'pending'
|
||||
mock_invitation_2.created_at = mock_invitation.created_at
|
||||
mock_invitation_2.expires_at = mock_invitation.expires_at
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.routes.org_invitations.check_rate_limit_by_user_id',
|
||||
new_callable=AsyncMock,
|
||||
),
|
||||
patch(
|
||||
'server.routes.org_invitations.OrgInvitationService.create_invitations_batch',
|
||||
new_callable=AsyncMock,
|
||||
return_value=([mock_invitation, mock_invitation_2], []),
|
||||
),
|
||||
):
|
||||
response = batch_client.post(
|
||||
'/api/organizations/12345678-1234-5678-1234-567812345678/members/invite',
|
||||
json={
|
||||
'emails': ['alice@example.com', 'bob@example.com'],
|
||||
'role': 'member',
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 201
|
||||
data = response.json()
|
||||
assert len(data['successful']) == 2
|
||||
assert len(data['failed']) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_batch_create_returns_partial_success(
|
||||
self, batch_client, mock_invitation
|
||||
):
|
||||
"""Test that batch creation returns both successful and failed invitations."""
|
||||
failed_emails = [('existing@example.com', 'User is already a member')]
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.routes.org_invitations.check_rate_limit_by_user_id',
|
||||
new_callable=AsyncMock,
|
||||
),
|
||||
patch(
|
||||
'server.routes.org_invitations.OrgInvitationService.create_invitations_batch',
|
||||
new_callable=AsyncMock,
|
||||
return_value=([mock_invitation], failed_emails),
|
||||
),
|
||||
):
|
||||
response = batch_client.post(
|
||||
'/api/organizations/12345678-1234-5678-1234-567812345678/members/invite',
|
||||
json={
|
||||
'emails': ['alice@example.com', 'existing@example.com'],
|
||||
'role': 'member',
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 201
|
||||
data = response.json()
|
||||
assert len(data['successful']) == 1
|
||||
assert len(data['failed']) == 1
|
||||
assert data['failed'][0]['email'] == 'existing@example.com'
|
||||
assert 'already a member' in data['failed'][0]['error']
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_batch_create_permission_denied_returns_403(self, batch_client):
|
||||
"""Test that permission denied returns 403 for entire batch."""
|
||||
from server.routes.org_invitation_models import InsufficientPermissionError
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.routes.org_invitations.check_rate_limit_by_user_id',
|
||||
new_callable=AsyncMock,
|
||||
),
|
||||
patch(
|
||||
'server.routes.org_invitations.OrgInvitationService.create_invitations_batch',
|
||||
new_callable=AsyncMock,
|
||||
side_effect=InsufficientPermissionError(
|
||||
'Only owners and admins can invite'
|
||||
),
|
||||
),
|
||||
):
|
||||
response = batch_client.post(
|
||||
'/api/organizations/12345678-1234-5678-1234-567812345678/members/invite',
|
||||
json={'emails': ['alice@example.com'], 'role': 'member'},
|
||||
)
|
||||
|
||||
assert response.status_code == 403
|
||||
assert 'owners and admins' in response.json()['detail']
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_batch_create_invalid_role_returns_400(self, batch_client):
|
||||
"""Test that invalid role returns 400."""
|
||||
with (
|
||||
patch(
|
||||
'server.routes.org_invitations.check_rate_limit_by_user_id',
|
||||
new_callable=AsyncMock,
|
||||
),
|
||||
patch(
|
||||
'server.routes.org_invitations.OrgInvitationService.create_invitations_batch',
|
||||
new_callable=AsyncMock,
|
||||
side_effect=ValueError('Invalid role: superuser'),
|
||||
),
|
||||
):
|
||||
response = batch_client.post(
|
||||
'/api/organizations/12345678-1234-5678-1234-567812345678/members/invite',
|
||||
json={'emails': ['alice@example.com'], 'role': 'superuser'},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert 'Invalid role' in response.json()['detail']
|
||||
@@ -158,6 +158,57 @@ def test_get_org_member(session_maker):
|
||||
assert retrieved_org_member.llm_api_key.get_secret_value() == 'test-key'
|
||||
|
||||
|
||||
def test_get_org_member_for_current_org(session_maker):
|
||||
# Test getting org_member for user's current organization
|
||||
with session_maker() as session:
|
||||
# Create test data - user belongs to two orgs but current_org is org1
|
||||
org1 = Org(name='test-org-1')
|
||||
org2 = Org(name='test-org-2')
|
||||
session.add_all([org1, org2])
|
||||
session.flush()
|
||||
|
||||
user = User(id=uuid.uuid4(), current_org_id=org1.id)
|
||||
role = Role(name='admin', rank=1)
|
||||
session.add_all([user, role])
|
||||
session.flush()
|
||||
|
||||
org_member1 = OrgMember(
|
||||
org_id=org1.id,
|
||||
user_id=user.id,
|
||||
role_id=role.id,
|
||||
llm_api_key='test-key-1',
|
||||
status='active',
|
||||
)
|
||||
org_member2 = OrgMember(
|
||||
org_id=org2.id,
|
||||
user_id=user.id,
|
||||
role_id=role.id,
|
||||
llm_api_key='test-key-2',
|
||||
status='active',
|
||||
)
|
||||
session.add_all([org_member1, org_member2])
|
||||
session.commit()
|
||||
user_id = user.id
|
||||
org1_id = org1.id
|
||||
|
||||
# Test retrieval - should return org_member for current_org (org1)
|
||||
with patch('storage.org_member_store.session_maker', session_maker):
|
||||
retrieved_org_member = OrgMemberStore.get_org_member_for_current_org(user_id)
|
||||
assert retrieved_org_member is not None
|
||||
assert retrieved_org_member.org_id == org1_id
|
||||
assert retrieved_org_member.user_id == user_id
|
||||
assert retrieved_org_member.llm_api_key.get_secret_value() == 'test-key-1'
|
||||
|
||||
|
||||
def test_get_org_member_for_current_org_user_not_found(session_maker):
|
||||
# Test getting org_member for non-existent user
|
||||
with patch('storage.org_member_store.session_maker', session_maker):
|
||||
retrieved_org_member = OrgMemberStore.get_org_member_for_current_org(
|
||||
uuid.uuid4()
|
||||
)
|
||||
assert retrieved_org_member is None
|
||||
|
||||
|
||||
def test_add_user_to_org(session_maker):
|
||||
# Test adding a user to an org
|
||||
with session_maker() as session:
|
||||
|
||||
@@ -151,8 +151,9 @@ describe("LoginContent", () => {
|
||||
await user.click(githubButton);
|
||||
|
||||
// Wait for async handleAuthRedirect to complete
|
||||
// The URL includes state parameter added by handleAuthRedirect
|
||||
await waitFor(() => {
|
||||
expect(window.location.href).toBe(mockUrl);
|
||||
expect(window.location.href).toContain(mockUrl);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -201,4 +202,103 @@ describe("LoginContent", () => {
|
||||
|
||||
expect(screen.getByTestId("terms-and-privacy-notice")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should display invitation pending message when hasInvitation is true", () => {
|
||||
render(
|
||||
<MemoryRouter>
|
||||
<LoginContent
|
||||
githubAuthUrl="https://github.com/oauth/authorize"
|
||||
appMode="saas"
|
||||
providersConfigured={["github"]}
|
||||
hasInvitation
|
||||
/>
|
||||
</MemoryRouter>,
|
||||
);
|
||||
|
||||
expect(screen.getByText("AUTH$INVITATION_PENDING")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should not display invitation pending message when hasInvitation is false", () => {
|
||||
render(
|
||||
<MemoryRouter>
|
||||
<LoginContent
|
||||
githubAuthUrl="https://github.com/oauth/authorize"
|
||||
appMode="saas"
|
||||
providersConfigured={["github"]}
|
||||
hasInvitation={false}
|
||||
/>
|
||||
</MemoryRouter>,
|
||||
);
|
||||
|
||||
expect(
|
||||
screen.queryByText("AUTH$INVITATION_PENDING"),
|
||||
).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should call buildOAuthStateData when clicking auth button", async () => {
|
||||
const user = userEvent.setup();
|
||||
const mockBuildOAuthStateData = vi.fn((baseState) => ({
|
||||
...baseState,
|
||||
invitation_token: "inv-test-token-12345",
|
||||
}));
|
||||
|
||||
render(
|
||||
<MemoryRouter>
|
||||
<LoginContent
|
||||
githubAuthUrl="https://github.com/login/oauth/authorize"
|
||||
appMode="saas"
|
||||
providersConfigured={["github"]}
|
||||
buildOAuthStateData={mockBuildOAuthStateData}
|
||||
/>
|
||||
</MemoryRouter>,
|
||||
);
|
||||
|
||||
const githubButton = screen.getByRole("button", {
|
||||
name: "GITHUB$CONNECT_TO_GITHUB",
|
||||
});
|
||||
await user.click(githubButton);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockBuildOAuthStateData).toHaveBeenCalled();
|
||||
const callArg = mockBuildOAuthStateData.mock.calls[0][0];
|
||||
expect(callArg).toHaveProperty("redirect_url");
|
||||
});
|
||||
});
|
||||
|
||||
it("should encode state with invitation token when buildOAuthStateData provides token", async () => {
|
||||
const user = userEvent.setup();
|
||||
const mockBuildOAuthStateData = vi.fn((baseState) => ({
|
||||
...baseState,
|
||||
invitation_token: "inv-test-token-12345",
|
||||
}));
|
||||
|
||||
render(
|
||||
<MemoryRouter>
|
||||
<LoginContent
|
||||
githubAuthUrl="https://github.com/login/oauth/authorize"
|
||||
appMode="saas"
|
||||
providersConfigured={["github"]}
|
||||
buildOAuthStateData={mockBuildOAuthStateData}
|
||||
/>
|
||||
</MemoryRouter>,
|
||||
);
|
||||
|
||||
const githubButton = screen.getByRole("button", {
|
||||
name: "GITHUB$CONNECT_TO_GITHUB",
|
||||
});
|
||||
await user.click(githubButton);
|
||||
|
||||
await waitFor(() => {
|
||||
const redirectUrl = window.location.href;
|
||||
// The URL should contain an encoded state parameter
|
||||
expect(redirectUrl).toContain("state=");
|
||||
// Decode and verify the state contains invitation_token
|
||||
const url = new URL(redirectUrl);
|
||||
const state = url.searchParams.get("state");
|
||||
if (state) {
|
||||
const decodedState = JSON.parse(atob(state));
|
||||
expect(decodedState.invitation_token).toBe("inv-test-token-12345");
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -92,19 +92,21 @@ describe("PlanPreview", () => {
|
||||
});
|
||||
|
||||
it("should render nothing when planContent is null", () => {
|
||||
renderPlanPreview(<PlanPreview planContent={null} />);
|
||||
// Arrange & Act
|
||||
const { container } = renderPlanPreview(<PlanPreview planContent={null} />);
|
||||
|
||||
const contentDiv = screen.getByTestId("plan-preview-content");
|
||||
expect(contentDiv).toBeInTheDocument();
|
||||
expect(contentDiv.textContent?.trim() || "").toBe("");
|
||||
// Assert
|
||||
expect(container.firstChild).toBeNull();
|
||||
});
|
||||
|
||||
it("should render nothing when planContent is undefined", () => {
|
||||
renderPlanPreview(<PlanPreview planContent={undefined} />);
|
||||
// Arrange & Act
|
||||
const { container } = renderPlanPreview(
|
||||
<PlanPreview planContent={undefined} />,
|
||||
);
|
||||
|
||||
const contentDiv = screen.getByTestId("plan-preview-content");
|
||||
expect(contentDiv).toBeInTheDocument();
|
||||
expect(contentDiv.textContent?.trim() || "").toBe("");
|
||||
// Assert
|
||||
expect(container.firstChild).toBeNull();
|
||||
});
|
||||
|
||||
it("should render markdown content when planContent is provided", () => {
|
||||
|
||||
@@ -0,0 +1,104 @@
|
||||
import { renderHook, waitFor } from "@testing-library/react";
|
||||
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
||||
import { describe, expect, it, vi } from "vitest";
|
||||
import V1ConversationService from "#/api/conversation-service/v1-conversation-service.api";
|
||||
import { useCreateConversation } from "#/hooks/mutation/use-create-conversation";
|
||||
import { SuggestedTask } from "#/utils/types";
|
||||
|
||||
vi.mock("#/hooks/query/use-settings", async () => {
|
||||
const actual = await vi.importActual<typeof import("#/hooks/query/use-settings")>(
|
||||
"#/hooks/query/use-settings",
|
||||
);
|
||||
return {
|
||||
...actual,
|
||||
useSettings: vi.fn().mockReturnValue({
|
||||
data: {
|
||||
v1_enabled: true,
|
||||
},
|
||||
isLoading: false,
|
||||
}),
|
||||
};
|
||||
});
|
||||
|
||||
vi.mock("#/hooks/use-tracking", () => ({
|
||||
useTracking: () => ({
|
||||
trackConversationCreated: vi.fn(),
|
||||
}),
|
||||
}));
|
||||
|
||||
describe("useCreateConversation", () => {
|
||||
it("passes suggested tasks to the V1 create conversation API", async () => {
|
||||
const createConversationSpy = vi
|
||||
.spyOn(V1ConversationService, "createConversation")
|
||||
.mockResolvedValue({
|
||||
id: "task-id",
|
||||
created_by_user_id: null,
|
||||
status: "READY",
|
||||
detail: null,
|
||||
app_conversation_id: null,
|
||||
sandbox_id: null,
|
||||
agent_server_url: "http://agent-server.local",
|
||||
request: {
|
||||
sandbox_id: null,
|
||||
initial_message: {
|
||||
role: "user",
|
||||
content: [{ type: "text", text: "Please address the comments" }],
|
||||
},
|
||||
processors: [],
|
||||
llm_model: null,
|
||||
selected_repository: null,
|
||||
selected_branch: null,
|
||||
git_provider: "github",
|
||||
suggested_task: null,
|
||||
title: null,
|
||||
trigger: null,
|
||||
pr_number: [],
|
||||
parent_conversation_id: null,
|
||||
agent_type: "default",
|
||||
},
|
||||
created_at: new Date().toISOString(),
|
||||
updated_at: new Date().toISOString(),
|
||||
});
|
||||
|
||||
const { result } = renderHook(() => useCreateConversation(), {
|
||||
wrapper: ({ children }) => (
|
||||
<QueryClientProvider client={new QueryClient()}>
|
||||
{children}
|
||||
</QueryClientProvider>
|
||||
),
|
||||
});
|
||||
|
||||
const suggestedTask: SuggestedTask = {
|
||||
git_provider: "github",
|
||||
issue_number: 42,
|
||||
repo: "owner/repo",
|
||||
title: "Resolve comments",
|
||||
task_type: "UNRESOLVED_COMMENTS",
|
||||
};
|
||||
|
||||
await result.current.mutateAsync({
|
||||
query: "Please address the comments",
|
||||
repository: {
|
||||
name: "owner/repo",
|
||||
gitProvider: "github",
|
||||
branch: "main",
|
||||
},
|
||||
conversationInstructions: "Focus on review comments",
|
||||
suggestedTask,
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(createConversationSpy).toHaveBeenCalledWith(
|
||||
"owner/repo",
|
||||
"github",
|
||||
"Please address the comments",
|
||||
"main",
|
||||
"Focus on review comments",
|
||||
suggestedTask,
|
||||
undefined,
|
||||
undefined,
|
||||
undefined,
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
170
frontend/__tests__/hooks/use-invitation.test.ts
Normal file
170
frontend/__tests__/hooks/use-invitation.test.ts
Normal file
@@ -0,0 +1,170 @@
|
||||
import { act, renderHook } from "@testing-library/react";
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
|
||||
|
||||
const INVITATION_TOKEN_KEY = "openhands_invitation_token";
|
||||
|
||||
// Mock setSearchParams function
|
||||
const mockSetSearchParams = vi.fn();
|
||||
|
||||
// Default mock searchParams
|
||||
let mockSearchParamsData: Record<string, string> = {};
|
||||
|
||||
// Mock react-router
|
||||
vi.mock("react-router", () => ({
|
||||
useSearchParams: () => [
|
||||
{
|
||||
get: (key: string) => mockSearchParamsData[key] || null,
|
||||
has: (key: string) => key in mockSearchParamsData,
|
||||
},
|
||||
mockSetSearchParams,
|
||||
],
|
||||
}));
|
||||
|
||||
// Import after mocking
|
||||
import { useInvitation } from "#/hooks/use-invitation";
|
||||
|
||||
describe("useInvitation", () => {
|
||||
beforeEach(() => {
|
||||
// Clear localStorage before each test
|
||||
localStorage.clear();
|
||||
// Reset mock data
|
||||
mockSearchParamsData = {};
|
||||
mockSetSearchParams.mockClear();
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
describe("initialization", () => {
|
||||
it("should initialize with null token when localStorage is empty", () => {
|
||||
// Arrange - localStorage is empty (cleared in beforeEach)
|
||||
|
||||
// Act
|
||||
const { result } = renderHook(() => useInvitation());
|
||||
|
||||
// Assert
|
||||
expect(result.current.invitationToken).toBeNull();
|
||||
expect(result.current.hasInvitation).toBe(false);
|
||||
});
|
||||
|
||||
it("should initialize with token from localStorage if present", () => {
|
||||
// Arrange
|
||||
const storedToken = "inv-stored-token-12345";
|
||||
localStorage.setItem(INVITATION_TOKEN_KEY, storedToken);
|
||||
|
||||
// Act
|
||||
const { result } = renderHook(() => useInvitation());
|
||||
|
||||
// Assert
|
||||
expect(result.current.invitationToken).toBe(storedToken);
|
||||
expect(result.current.hasInvitation).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe("URL token capture", () => {
|
||||
it("should capture invitation_token from URL and store in localStorage", () => {
|
||||
// Arrange
|
||||
const urlToken = "inv-url-token-67890";
|
||||
mockSearchParamsData = { invitation_token: urlToken };
|
||||
|
||||
// Act
|
||||
renderHook(() => useInvitation());
|
||||
|
||||
// Assert
|
||||
expect(localStorage.getItem(INVITATION_TOKEN_KEY)).toBe(urlToken);
|
||||
expect(mockSetSearchParams).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe("completion cleanup", () => {
|
||||
it("should clear localStorage when email_mismatch param is present", () => {
|
||||
// Arrange
|
||||
const storedToken = "inv-token-to-clear";
|
||||
localStorage.setItem(INVITATION_TOKEN_KEY, storedToken);
|
||||
mockSearchParamsData = { email_mismatch: "true" };
|
||||
|
||||
// Act
|
||||
const { result } = renderHook(() => useInvitation());
|
||||
|
||||
// Assert
|
||||
expect(localStorage.getItem(INVITATION_TOKEN_KEY)).toBeNull();
|
||||
expect(mockSetSearchParams).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it("should clear localStorage when invitation_success param is present", () => {
|
||||
// Arrange
|
||||
const storedToken = "inv-token-to-clear";
|
||||
localStorage.setItem(INVITATION_TOKEN_KEY, storedToken);
|
||||
mockSearchParamsData = { invitation_success: "true" };
|
||||
|
||||
// Act
|
||||
renderHook(() => useInvitation());
|
||||
|
||||
// Assert
|
||||
expect(localStorage.getItem(INVITATION_TOKEN_KEY)).toBeNull();
|
||||
});
|
||||
|
||||
it("should clear localStorage when invitation_expired param is present", () => {
|
||||
// Arrange
|
||||
localStorage.setItem(INVITATION_TOKEN_KEY, "inv-token");
|
||||
mockSearchParamsData = { invitation_expired: "true" };
|
||||
|
||||
// Act
|
||||
renderHook(() => useInvitation());
|
||||
|
||||
// Assert
|
||||
expect(localStorage.getItem(INVITATION_TOKEN_KEY)).toBeNull();
|
||||
});
|
||||
});
|
||||
|
||||
describe("buildOAuthStateData", () => {
|
||||
it("should include invitation_token in OAuth state when token is present", () => {
|
||||
// Arrange
|
||||
const token = "inv-oauth-token-12345";
|
||||
localStorage.setItem(INVITATION_TOKEN_KEY, token);
|
||||
|
||||
const { result } = renderHook(() => useInvitation());
|
||||
const baseState = { redirect_url: "/dashboard" };
|
||||
|
||||
// Act
|
||||
const stateData = result.current.buildOAuthStateData(baseState);
|
||||
|
||||
// Assert
|
||||
expect(stateData.invitation_token).toBe(token);
|
||||
expect(stateData.redirect_url).toBe("/dashboard");
|
||||
});
|
||||
|
||||
it("should not include invitation_token when no token is present", () => {
|
||||
// Arrange - no token in localStorage
|
||||
|
||||
const { result } = renderHook(() => useInvitation());
|
||||
const baseState = { redirect_url: "/dashboard" };
|
||||
|
||||
// Act
|
||||
const stateData = result.current.buildOAuthStateData(baseState);
|
||||
|
||||
// Assert
|
||||
expect(stateData.invitation_token).toBeUndefined();
|
||||
expect(stateData.redirect_url).toBe("/dashboard");
|
||||
});
|
||||
});
|
||||
|
||||
describe("clearInvitation", () => {
|
||||
it("should remove token from localStorage when called", () => {
|
||||
// Arrange
|
||||
localStorage.setItem(INVITATION_TOKEN_KEY, "inv-token-to-clear");
|
||||
const { result } = renderHook(() => useInvitation());
|
||||
|
||||
// Act
|
||||
act(() => {
|
||||
result.current.clearInvitation();
|
||||
});
|
||||
|
||||
// Assert
|
||||
expect(localStorage.getItem(INVITATION_TOKEN_KEY)).toBeNull();
|
||||
expect(result.current.invitationToken).toBeNull();
|
||||
expect(result.current.hasInvitation).toBe(false);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -57,6 +57,22 @@ vi.mock("#/hooks/use-tracking", () => ({
|
||||
}),
|
||||
}));
|
||||
|
||||
const { useInvitationMock, buildOAuthStateDataMock } = vi.hoisted(() => ({
|
||||
useInvitationMock: vi.fn(() => ({
|
||||
invitationToken: null as string | null,
|
||||
hasInvitation: false,
|
||||
buildOAuthStateData: (baseState: Record<string, string>) => baseState,
|
||||
clearInvitation: vi.fn(),
|
||||
})),
|
||||
buildOAuthStateDataMock: vi.fn(
|
||||
(baseState: Record<string, string>) => baseState,
|
||||
),
|
||||
}));
|
||||
|
||||
vi.mock("#/hooks/use-invitation", () => ({
|
||||
useInvitation: () => useInvitationMock(),
|
||||
}));
|
||||
|
||||
const RouterStub = createRoutesStub([
|
||||
{
|
||||
Component: LoginPage,
|
||||
@@ -234,7 +250,8 @@ describe("LoginPage", () => {
|
||||
});
|
||||
await user.click(githubButton);
|
||||
|
||||
expect(window.location.href).toBe(mockUrl);
|
||||
// URL includes state parameter added by handleAuthRedirect
|
||||
expect(window.location.href).toContain(mockUrl);
|
||||
});
|
||||
|
||||
it("should redirect to GitLab auth URL when GitLab button is clicked", async () => {
|
||||
@@ -255,7 +272,8 @@ describe("LoginPage", () => {
|
||||
});
|
||||
await user.click(gitlabButton);
|
||||
|
||||
expect(window.location.href).toBe("https://gitlab.com/oauth/authorize");
|
||||
// URL includes state parameter added by handleAuthRedirect
|
||||
expect(window.location.href).toContain("https://gitlab.com/oauth/authorize");
|
||||
});
|
||||
|
||||
it("should redirect to Bitbucket auth URL when Bitbucket button is clicked", async () => {
|
||||
@@ -282,7 +300,8 @@ describe("LoginPage", () => {
|
||||
});
|
||||
await user.click(bitbucketButton);
|
||||
|
||||
expect(window.location.href).toBe(
|
||||
// URL includes state parameter added by handleAuthRedirect
|
||||
expect(window.location.href).toContain(
|
||||
"https://bitbucket.org/site/oauth2/authorize",
|
||||
);
|
||||
});
|
||||
@@ -479,4 +498,137 @@ describe("LoginPage", () => {
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("Invitation Flow", () => {
|
||||
it("should display invitation pending message when hasInvitation is true", async () => {
|
||||
useInvitationMock.mockReturnValue({
|
||||
invitationToken: "inv-test-token-12345",
|
||||
hasInvitation: true,
|
||||
buildOAuthStateData: buildOAuthStateDataMock,
|
||||
clearInvitation: vi.fn(),
|
||||
});
|
||||
|
||||
render(<RouterStub initialEntries={["/login"]} />, {
|
||||
wrapper: createWrapper(),
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("AUTH$INVITATION_PENDING")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
it("should not display invitation pending message when hasInvitation is false", async () => {
|
||||
useInvitationMock.mockReturnValue({
|
||||
invitationToken: null,
|
||||
hasInvitation: false,
|
||||
buildOAuthStateData: buildOAuthStateDataMock,
|
||||
clearInvitation: vi.fn(),
|
||||
});
|
||||
|
||||
render(<RouterStub initialEntries={["/login"]} />, {
|
||||
wrapper: createWrapper(),
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId("login-content")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
expect(
|
||||
screen.queryByText("AUTH$INVITATION_PENDING"),
|
||||
).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should pass buildOAuthStateData to LoginContent for OAuth state encoding", async () => {
|
||||
const user = userEvent.setup();
|
||||
const mockBuildOAuthStateData = vi.fn((baseState: Record<string, string>) => ({
|
||||
...baseState,
|
||||
invitation_token: "inv-test-token-12345",
|
||||
}));
|
||||
|
||||
useInvitationMock.mockReturnValue({
|
||||
invitationToken: "inv-test-token-12345",
|
||||
hasInvitation: true,
|
||||
buildOAuthStateData: mockBuildOAuthStateData,
|
||||
clearInvitation: vi.fn(),
|
||||
});
|
||||
|
||||
render(<RouterStub initialEntries={["/login"]} />, {
|
||||
wrapper: createWrapper(),
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByRole("button", { name: "GITHUB$CONNECT_TO_GITHUB" }),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
const githubButton = screen.getByRole("button", {
|
||||
name: "GITHUB$CONNECT_TO_GITHUB",
|
||||
});
|
||||
await user.click(githubButton);
|
||||
|
||||
// buildOAuthStateData should have been called during the OAuth redirect
|
||||
expect(mockBuildOAuthStateData).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it("should include invitation token in OAuth state when invitation is present", async () => {
|
||||
const user = userEvent.setup();
|
||||
const mockBuildOAuthStateData = vi.fn((baseState: Record<string, string>) => ({
|
||||
...baseState,
|
||||
invitation_token: "inv-test-token-12345",
|
||||
}));
|
||||
|
||||
useInvitationMock.mockReturnValue({
|
||||
invitationToken: "inv-test-token-12345",
|
||||
hasInvitation: true,
|
||||
buildOAuthStateData: mockBuildOAuthStateData,
|
||||
clearInvitation: vi.fn(),
|
||||
});
|
||||
|
||||
render(<RouterStub initialEntries={["/login"]} />, {
|
||||
wrapper: createWrapper(),
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByRole("button", { name: "GITHUB$CONNECT_TO_GITHUB" }),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
const githubButton = screen.getByRole("button", {
|
||||
name: "GITHUB$CONNECT_TO_GITHUB",
|
||||
});
|
||||
await user.click(githubButton);
|
||||
|
||||
// Verify the redirect URL contains the state with invitation token
|
||||
await waitFor(() => {
|
||||
expect(window.location.href).toContain("state=");
|
||||
});
|
||||
|
||||
// Decode and verify the state contains invitation_token
|
||||
const url = new URL(window.location.href);
|
||||
const state = url.searchParams.get("state");
|
||||
if (state) {
|
||||
const decodedState = JSON.parse(atob(state));
|
||||
expect(decodedState.invitation_token).toBe("inv-test-token-12345");
|
||||
}
|
||||
});
|
||||
|
||||
it("should handle login with invitation_token URL parameter", async () => {
|
||||
useInvitationMock.mockReturnValue({
|
||||
invitationToken: "inv-url-token-67890",
|
||||
hasInvitation: true,
|
||||
buildOAuthStateData: buildOAuthStateDataMock,
|
||||
clearInvitation: vi.fn(),
|
||||
});
|
||||
|
||||
render(<RouterStub initialEntries={["/login?invitation_token=inv-url-token-67890"]} />, {
|
||||
wrapper: createWrapper(),
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("AUTH$INVITATION_PENDING")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -42,6 +42,15 @@ vi.mock("#/utils/custom-toast-handlers", () => ({
|
||||
displaySuccessToast: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock("#/hooks/use-invitation", () => ({
|
||||
useInvitation: () => ({
|
||||
invitationToken: null,
|
||||
hasInvitation: false,
|
||||
buildOAuthStateData: (baseState: Record<string, string>) => baseState,
|
||||
clearInvitation: vi.fn(),
|
||||
}),
|
||||
}));
|
||||
|
||||
function LoginStub() {
|
||||
const [searchParams] = useSearchParams();
|
||||
const emailVerificationRequired =
|
||||
@@ -353,4 +362,68 @@ describe("MainApp", () => {
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe("Invitation URL Parameters", () => {
|
||||
beforeEach(() => {
|
||||
vi.spyOn(AuthService, "authenticate").mockRejectedValue({
|
||||
response: { status: 401 },
|
||||
isAxiosError: true,
|
||||
});
|
||||
});
|
||||
|
||||
it("should redirect to login when email_mismatch=true is in query params", async () => {
|
||||
renderMainApp(["/?email_mismatch=true"]);
|
||||
|
||||
await waitFor(
|
||||
() => {
|
||||
expect(screen.getByTestId("login-page")).toBeInTheDocument();
|
||||
},
|
||||
{ timeout: 2000 },
|
||||
);
|
||||
});
|
||||
|
||||
it("should redirect to login when invitation_success=true is in query params", async () => {
|
||||
renderMainApp(["/?invitation_success=true"]);
|
||||
|
||||
await waitFor(
|
||||
() => {
|
||||
expect(screen.getByTestId("login-page")).toBeInTheDocument();
|
||||
},
|
||||
{ timeout: 2000 },
|
||||
);
|
||||
});
|
||||
|
||||
it("should redirect to login when invitation_expired=true is in query params", async () => {
|
||||
renderMainApp(["/?invitation_expired=true"]);
|
||||
|
||||
await waitFor(
|
||||
() => {
|
||||
expect(screen.getByTestId("login-page")).toBeInTheDocument();
|
||||
},
|
||||
{ timeout: 2000 },
|
||||
);
|
||||
});
|
||||
|
||||
it("should redirect to login when invitation_invalid=true is in query params", async () => {
|
||||
renderMainApp(["/?invitation_invalid=true"]);
|
||||
|
||||
await waitFor(
|
||||
() => {
|
||||
expect(screen.getByTestId("login-page")).toBeInTheDocument();
|
||||
},
|
||||
{ timeout: 2000 },
|
||||
);
|
||||
});
|
||||
|
||||
it("should redirect to login when already_member=true is in query params", async () => {
|
||||
renderMainApp(["/?already_member=true"]);
|
||||
|
||||
await waitFor(
|
||||
() => {
|
||||
expect(screen.getByTestId("login-page")).toBeInTheDocument();
|
||||
},
|
||||
{ timeout: 2000 },
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
4
frontend/package-lock.json
generated
4
frontend/package-lock.json
generated
@@ -1,12 +1,12 @@
|
||||
{
|
||||
"name": "openhands-frontend",
|
||||
"version": "1.3.0",
|
||||
"version": "1.4.0",
|
||||
"lockfileVersion": 3,
|
||||
"requires": true,
|
||||
"packages": {
|
||||
"": {
|
||||
"name": "openhands-frontend",
|
||||
"version": "1.3.0",
|
||||
"version": "1.4.0",
|
||||
"dependencies": {
|
||||
"@heroui/react": "2.8.7",
|
||||
"@microlink/react-json-view": "^1.27.1",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "openhands-frontend",
|
||||
"version": "1.3.0",
|
||||
"version": "1.4.0",
|
||||
"private": true,
|
||||
"type": "module",
|
||||
"engines": {
|
||||
|
||||
@@ -2,6 +2,7 @@ import axios from "axios";
|
||||
import { openHands } from "../open-hands-axios";
|
||||
import { ConversationTrigger, GetVSCodeUrlResponse } from "../open-hands.types";
|
||||
import { Provider } from "#/types/settings";
|
||||
import { SuggestedTask } from "#/utils/types";
|
||||
import { buildHttpBaseUrl } from "#/utils/websocket-url";
|
||||
import { buildSessionHeaders } from "#/utils/utils";
|
||||
import type {
|
||||
@@ -61,6 +62,7 @@ class V1ConversationService {
|
||||
initialUserMsg?: string,
|
||||
selected_branch?: string,
|
||||
conversationInstructions?: string,
|
||||
suggestedTask?: SuggestedTask,
|
||||
trigger?: ConversationTrigger,
|
||||
parent_conversation_id?: string,
|
||||
agent_type?: "default" | "plan",
|
||||
@@ -69,14 +71,15 @@ class V1ConversationService {
|
||||
selected_repository: selectedRepository,
|
||||
git_provider,
|
||||
selected_branch,
|
||||
suggested_task: suggestedTask,
|
||||
title: conversationInstructions,
|
||||
trigger,
|
||||
parent_conversation_id: parent_conversation_id || null,
|
||||
agent_type,
|
||||
};
|
||||
|
||||
// Add initial message if provided
|
||||
if (initialUserMsg) {
|
||||
// suggested_task implies the backend will construct the initial_message
|
||||
if (!suggestedTask && initialUserMsg) {
|
||||
body.initial_message = {
|
||||
role: "user",
|
||||
content: [
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import { ConversationTrigger } from "../open-hands.types";
|
||||
import { Provider } from "#/types/settings";
|
||||
import { V1SandboxStatus } from "../sandbox-service/sandbox-service.types";
|
||||
import { Provider } from "#/types/settings";
|
||||
import { SuggestedTask } from "#/utils/types";
|
||||
|
||||
// V1 Metrics Types
|
||||
export interface V1TokenUsage {
|
||||
@@ -47,6 +48,7 @@ export interface V1AppConversationStartRequest {
|
||||
selected_repository?: string | null;
|
||||
selected_branch?: string | null;
|
||||
git_provider?: Provider | null;
|
||||
suggested_task?: SuggestedTask | null;
|
||||
title?: string | null;
|
||||
trigger?: ConversationTrigger | null;
|
||||
pr_number?: number[];
|
||||
|
||||
@@ -21,6 +21,10 @@ export interface LoginContentProps {
|
||||
emailVerified?: boolean;
|
||||
hasDuplicatedEmail?: boolean;
|
||||
recaptchaBlocked?: boolean;
|
||||
hasInvitation?: boolean;
|
||||
buildOAuthStateData?: (
|
||||
baseStateData: Record<string, string>,
|
||||
) => Record<string, string>;
|
||||
}
|
||||
|
||||
export function LoginContent({
|
||||
@@ -31,6 +35,8 @@ export function LoginContent({
|
||||
emailVerified = false,
|
||||
hasDuplicatedEmail = false,
|
||||
recaptchaBlocked = false,
|
||||
hasInvitation = false,
|
||||
buildOAuthStateData,
|
||||
}: LoginContentProps) {
|
||||
const { t } = useTranslation();
|
||||
const { trackLoginButtonClick } = useTracking();
|
||||
@@ -59,31 +65,36 @@ export function LoginContent({
|
||||
) => {
|
||||
trackLoginButtonClick({ provider });
|
||||
|
||||
if (!config?.recaptcha_site_key || !recaptchaReady) {
|
||||
// No reCAPTCHA or token generation failed - redirect normally
|
||||
window.location.href = redirectUrl;
|
||||
return;
|
||||
const url = new URL(redirectUrl);
|
||||
const currentState =
|
||||
url.searchParams.get("state") || window.location.origin;
|
||||
|
||||
// Build base state data
|
||||
let stateData: Record<string, string> = {
|
||||
redirect_url: currentState,
|
||||
};
|
||||
|
||||
// Add invitation token if present
|
||||
if (buildOAuthStateData) {
|
||||
stateData = buildOAuthStateData(stateData);
|
||||
}
|
||||
|
||||
// If reCAPTCHA is configured, encode token in OAuth state
|
||||
try {
|
||||
const token = await executeRecaptcha("LOGIN");
|
||||
if (token) {
|
||||
const url = new URL(redirectUrl);
|
||||
const currentState =
|
||||
url.searchParams.get("state") || window.location.origin;
|
||||
|
||||
// Encode state with reCAPTCHA token for backend verification
|
||||
const stateData = {
|
||||
redirect_url: currentState,
|
||||
recaptcha_token: token,
|
||||
};
|
||||
url.searchParams.set("state", btoa(JSON.stringify(stateData)));
|
||||
window.location.href = url.toString();
|
||||
// If reCAPTCHA is configured, add token to state
|
||||
if (config?.recaptcha_site_key && recaptchaReady) {
|
||||
try {
|
||||
const token = await executeRecaptcha("LOGIN");
|
||||
if (token) {
|
||||
stateData.recaptcha_token = token;
|
||||
}
|
||||
} catch (err) {
|
||||
displayErrorToast(t(I18nKey.AUTH$RECAPTCHA_BLOCKED));
|
||||
return;
|
||||
}
|
||||
} catch (err) {
|
||||
displayErrorToast(t(I18nKey.AUTH$RECAPTCHA_BLOCKED));
|
||||
}
|
||||
|
||||
// Encode state and redirect
|
||||
url.searchParams.set("state", btoa(JSON.stringify(stateData)));
|
||||
window.location.href = url.toString();
|
||||
};
|
||||
|
||||
const handleGitHubAuth = () => {
|
||||
@@ -123,6 +134,10 @@ export function LoginContent({
|
||||
const buttonBaseClasses =
|
||||
"w-[301.5px] h-10 rounded p-2 flex items-center justify-center cursor-pointer transition-opacity hover:opacity-90 disabled:opacity-50 disabled:cursor-not-allowed";
|
||||
const buttonLabelClasses = "text-sm font-medium leading-5 px-1";
|
||||
|
||||
const shouldShownHelperText =
|
||||
emailVerified || hasDuplicatedEmail || recaptchaBlocked || hasInvitation;
|
||||
|
||||
return (
|
||||
<div
|
||||
className="flex flex-col items-center w-full gap-12.5"
|
||||
@@ -136,20 +151,29 @@ export function LoginContent({
|
||||
{t(I18nKey.AUTH$LETS_GET_STARTED)}
|
||||
</h1>
|
||||
|
||||
{emailVerified && (
|
||||
<p className="text-sm text-muted-foreground text-center">
|
||||
{t(I18nKey.AUTH$EMAIL_VERIFIED_PLEASE_LOGIN)}
|
||||
</p>
|
||||
)}
|
||||
{hasDuplicatedEmail && (
|
||||
<p className="text-sm text-danger text-center">
|
||||
{t(I18nKey.AUTH$DUPLICATE_EMAIL_ERROR)}
|
||||
</p>
|
||||
)}
|
||||
{recaptchaBlocked && (
|
||||
<p className="text-sm text-danger text-center max-w-125">
|
||||
{t(I18nKey.AUTH$RECAPTCHA_BLOCKED)}
|
||||
</p>
|
||||
{shouldShownHelperText && (
|
||||
<div className="flex flex-col items-center gap-3">
|
||||
{emailVerified && (
|
||||
<p className="text-sm text-muted-foreground text-center">
|
||||
{t(I18nKey.AUTH$EMAIL_VERIFIED_PLEASE_LOGIN)}
|
||||
</p>
|
||||
)}
|
||||
{hasDuplicatedEmail && (
|
||||
<p className="text-sm text-danger text-center">
|
||||
{t(I18nKey.AUTH$DUPLICATE_EMAIL_ERROR)}
|
||||
</p>
|
||||
)}
|
||||
{recaptchaBlocked && (
|
||||
<p className="text-sm text-danger text-center max-w-125">
|
||||
{t(I18nKey.AUTH$RECAPTCHA_BLOCKED)}
|
||||
</p>
|
||||
)}
|
||||
{hasInvitation && (
|
||||
<p className="text-sm text-muted-foreground text-center">
|
||||
{t(I18nKey.AUTH$INVITATION_PENDING)}
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="flex flex-col items-center gap-3">
|
||||
|
||||
@@ -209,6 +209,21 @@ export function ChatInterface() {
|
||||
setFeedbackPolarity(polarity);
|
||||
};
|
||||
|
||||
// Auto-scroll to bottom when new messages arrive
|
||||
React.useEffect(() => {
|
||||
if (autoScroll) {
|
||||
scrollDomToBottom();
|
||||
}
|
||||
// Note: We intentionally exclude autoScroll from deps because we only want
|
||||
// to scroll when message content changes, not when autoScroll state changes.
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [
|
||||
v1UiEvents.length,
|
||||
v0Events.length,
|
||||
optimisticUserMessage,
|
||||
scrollDomToBottom,
|
||||
]);
|
||||
|
||||
// Create a ScrollProvider with the scroll hook values
|
||||
const scrollProviderValue = {
|
||||
scrollRef,
|
||||
|
||||
@@ -65,7 +65,7 @@ export function PlanPreview({
|
||||
return `${planContent.slice(0, MAX_CONTENT_LENGTH)}...`;
|
||||
}, [planContent]);
|
||||
|
||||
if (!shouldUsePlanningAgent) {
|
||||
if (!shouldUsePlanningAgent || !planContent) {
|
||||
return null;
|
||||
}
|
||||
|
||||
|
||||
@@ -61,7 +61,8 @@ export const useCreateConversation = () => {
|
||||
query,
|
||||
repository?.branch,
|
||||
conversationInstructions,
|
||||
undefined, // trigger - will be set by backend
|
||||
suggestedTask,
|
||||
undefined, // trigger - set by backend when applicable
|
||||
parentConversationId,
|
||||
agentType,
|
||||
);
|
||||
|
||||
119
frontend/src/hooks/use-invitation.ts
Normal file
119
frontend/src/hooks/use-invitation.ts
Normal file
@@ -0,0 +1,119 @@
|
||||
import React from "react";
|
||||
import { useSearchParams } from "react-router";
|
||||
|
||||
const INVITATION_TOKEN_KEY = "openhands_invitation_token";
|
||||
|
||||
interface UseInvitationReturn {
|
||||
/** The invitation token, if present */
|
||||
invitationToken: string | null;
|
||||
/** Whether there is an active invitation */
|
||||
hasInvitation: boolean;
|
||||
/** Clear the stored invitation token */
|
||||
clearInvitation: () => void;
|
||||
/** Build OAuth state data including invitation token if present */
|
||||
buildOAuthStateData: (
|
||||
baseStateData: Record<string, string>,
|
||||
) => Record<string, string>;
|
||||
}
|
||||
|
||||
/**
|
||||
* Hook to manage organization invitation tokens during the login flow.
|
||||
*
|
||||
* This hook:
|
||||
* 1. Reads invitation_token from URL query params on mount
|
||||
* 2. Persists the token in localStorage (survives page refresh and works across tabs)
|
||||
* 3. Provides the token for inclusion in OAuth state
|
||||
* 4. Provides cleanup method after successful authentication
|
||||
*
|
||||
* The invitation token flow:
|
||||
* 1. User clicks invitation link → /api/invitations/accept?token=xxx
|
||||
* 2. Backend redirects to /login?invitation_token=xxx
|
||||
* 3. This hook captures token and stores in localStorage
|
||||
* 4. When user clicks login button, token is included in OAuth state
|
||||
* 5. After auth callback processes invitation, frontend clears the token
|
||||
*
|
||||
* Note: localStorage is used instead of sessionStorage to support scenarios where
|
||||
* the user opens the email verification link in a new tab/browser window.
|
||||
*/
|
||||
export function useInvitation(): UseInvitationReturn {
|
||||
const [searchParams, setSearchParams] = useSearchParams();
|
||||
const [invitationToken, setInvitationToken] = React.useState<string | null>(
|
||||
() => {
|
||||
// Initialize from localStorage (persists across tabs and page refreshes)
|
||||
if (typeof window !== "undefined") {
|
||||
return localStorage.getItem(INVITATION_TOKEN_KEY);
|
||||
}
|
||||
return null;
|
||||
},
|
||||
);
|
||||
|
||||
// Capture invitation token from URL and persist to localStorage
|
||||
// This only runs on the login page where the hook is used
|
||||
React.useEffect(() => {
|
||||
const tokenFromUrl = searchParams.get("invitation_token");
|
||||
|
||||
if (tokenFromUrl) {
|
||||
// Store in localStorage for persistence across tabs and refreshes
|
||||
localStorage.setItem(INVITATION_TOKEN_KEY, tokenFromUrl);
|
||||
setInvitationToken(tokenFromUrl);
|
||||
|
||||
// Remove token from URL to clean up (prevents token exposure in browser history)
|
||||
const newSearchParams = new URLSearchParams(searchParams);
|
||||
newSearchParams.delete("invitation_token");
|
||||
setSearchParams(newSearchParams, { replace: true });
|
||||
}
|
||||
}, [searchParams, setSearchParams]);
|
||||
|
||||
// Clear invitation token when invitation flow completes (success or failure)
|
||||
// These query params are set by the backend after processing the invitation
|
||||
React.useEffect(() => {
|
||||
const invitationCompleted =
|
||||
searchParams.has("invitation_success") ||
|
||||
searchParams.has("invitation_expired") ||
|
||||
searchParams.has("invitation_invalid") ||
|
||||
searchParams.has("invitation_error") ||
|
||||
searchParams.has("already_member") ||
|
||||
searchParams.has("email_mismatch");
|
||||
|
||||
if (invitationCompleted) {
|
||||
localStorage.removeItem(INVITATION_TOKEN_KEY);
|
||||
setInvitationToken(null);
|
||||
|
||||
// Remove invitation params from URL to clean up
|
||||
const newSearchParams = new URLSearchParams(searchParams);
|
||||
newSearchParams.delete("invitation_success");
|
||||
newSearchParams.delete("invitation_expired");
|
||||
newSearchParams.delete("invitation_invalid");
|
||||
newSearchParams.delete("invitation_error");
|
||||
newSearchParams.delete("already_member");
|
||||
newSearchParams.delete("email_mismatch");
|
||||
setSearchParams(newSearchParams, { replace: true });
|
||||
}
|
||||
}, [searchParams, setSearchParams]);
|
||||
|
||||
const clearInvitation = React.useCallback(() => {
|
||||
localStorage.removeItem(INVITATION_TOKEN_KEY);
|
||||
setInvitationToken(null);
|
||||
}, []);
|
||||
|
||||
const buildOAuthStateData = React.useCallback(
|
||||
(baseStateData: Record<string, string>): Record<string, string> => {
|
||||
const stateData = { ...baseStateData };
|
||||
|
||||
// Include invitation token in state if present
|
||||
if (invitationToken) {
|
||||
stateData.invitation_token = invitationToken;
|
||||
}
|
||||
|
||||
return stateData;
|
||||
},
|
||||
[invitationToken],
|
||||
);
|
||||
|
||||
return {
|
||||
invitationToken,
|
||||
hasInvitation: invitationToken !== null,
|
||||
clearInvitation,
|
||||
buildOAuthStateData,
|
||||
};
|
||||
}
|
||||
@@ -763,6 +763,7 @@ export enum I18nKey {
|
||||
AUTH$DUPLICATE_EMAIL_ERROR = "AUTH$DUPLICATE_EMAIL_ERROR",
|
||||
AUTH$RECAPTCHA_BLOCKED = "AUTH$RECAPTCHA_BLOCKED",
|
||||
AUTH$LETS_GET_STARTED = "AUTH$LETS_GET_STARTED",
|
||||
AUTH$INVITATION_PENDING = "AUTH$INVITATION_PENDING",
|
||||
COMMON$TERMS_OF_SERVICE = "COMMON$TERMS_OF_SERVICE",
|
||||
COMMON$AND = "COMMON$AND",
|
||||
COMMON$PRIVACY_POLICY = "COMMON$PRIVACY_POLICY",
|
||||
|
||||
@@ -12207,6 +12207,22 @@
|
||||
"de": "Lass uns anfangen",
|
||||
"uk": "Почнімо"
|
||||
},
|
||||
"AUTH$INVITATION_PENDING": {
|
||||
"en": "Sign in to accept your organization invitation",
|
||||
"ja": "組織への招待を受け入れるにはサインインしてください",
|
||||
"zh-CN": "登录以接受您的组织邀请",
|
||||
"zh-TW": "登入以接受您的組織邀請",
|
||||
"ko-KR": "조직 초대를 수락하려면 로그인하세요",
|
||||
"no": "Logg inn for å godta organisasjonsinvitasjonen din",
|
||||
"it": "Accedi per accettare l'invito della tua organizzazione",
|
||||
"pt": "Faça login para aceitar o convite da sua organização",
|
||||
"es": "Inicia sesión para aceptar la invitación de tu organización",
|
||||
"ar": "سجّل الدخول لقبول دعوة مؤسستك",
|
||||
"fr": "Connectez-vous pour accepter l'invitation de votre organisation",
|
||||
"tr": "Organizasyon davetinizi kabul etmek için giriş yapın",
|
||||
"de": "Melden Sie sich an, um Ihre Organisationseinladung anzunehmen",
|
||||
"uk": "Увійдіть, щоб прийняти запрошення до організації"
|
||||
},
|
||||
"COMMON$TERMS_OF_SERVICE": {
|
||||
"en": "Terms of Service",
|
||||
"ja": "利用規約",
|
||||
|
||||
@@ -10,6 +10,7 @@ import "./tailwind.css";
|
||||
import "./index.css";
|
||||
import React from "react";
|
||||
import { Toaster } from "react-hot-toast";
|
||||
import { useInvitation } from "#/hooks/use-invitation";
|
||||
|
||||
export function Layout({ children }: { children: React.ReactNode }) {
|
||||
return (
|
||||
@@ -37,5 +38,9 @@ export const meta: MetaFunction = () => [
|
||||
];
|
||||
|
||||
export default function App() {
|
||||
// Handle invitation token cleanup when invitation flow completes
|
||||
// This runs on all pages to catch redirects from auth callback
|
||||
useInvitation();
|
||||
|
||||
return <Outlet />;
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import { useIsAuthed } from "#/hooks/query/use-is-authed";
|
||||
import { useConfig } from "#/hooks/query/use-config";
|
||||
import { useGitHubAuthUrl } from "#/hooks/use-github-auth-url";
|
||||
import { useEmailVerification } from "#/hooks/use-email-verification";
|
||||
import { useInvitation } from "#/hooks/use-invitation";
|
||||
import { LoginContent } from "#/components/features/auth/login-content";
|
||||
import { EmailVerificationModal } from "#/components/features/waitlist/email-verification-modal";
|
||||
|
||||
@@ -23,6 +24,8 @@ export default function LoginPage() {
|
||||
userId,
|
||||
} = useEmailVerification();
|
||||
|
||||
const { hasInvitation, buildOAuthStateData } = useInvitation();
|
||||
|
||||
const gitHubAuthUrl = useGitHubAuthUrl({
|
||||
appMode: config.data?.app_mode || null,
|
||||
authUrl: config.data?.auth_url,
|
||||
@@ -69,6 +72,8 @@ export default function LoginPage() {
|
||||
emailVerified={emailVerified}
|
||||
hasDuplicatedEmail={hasDuplicatedEmail}
|
||||
recaptchaBlocked={recaptchaBlocked}
|
||||
hasInvitation={hasInvitation}
|
||||
buildOAuthStateData={buildOAuthStateData}
|
||||
/>
|
||||
</main>
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ from openhands.app_server.event_callback.event_callback_models import (
|
||||
EventCallbackProcessor,
|
||||
)
|
||||
from openhands.app_server.sandbox.sandbox_models import SandboxStatus
|
||||
from openhands.integrations.service_types import ProviderType
|
||||
from openhands.integrations.service_types import ProviderType, SuggestedTask
|
||||
from openhands.sdk.conversation.state import ConversationExecutionStatus
|
||||
from openhands.sdk.llm import MetricsSnapshot
|
||||
from openhands.sdk.plugin import PluginSource
|
||||
@@ -150,6 +150,7 @@ class AppConversationStartRequest(OpenHandsModel):
|
||||
selected_repository: str | None = None
|
||||
selected_branch: str | None = None
|
||||
git_provider: ProviderType | None = None
|
||||
suggested_task: SuggestedTask | None = None
|
||||
title: str | None = None
|
||||
trigger: ConversationTrigger | None = None
|
||||
pr_number: list[int] = Field(default_factory=list)
|
||||
|
||||
@@ -18,6 +18,7 @@ from openhands.agent_server.models import (
|
||||
ConversationInfo,
|
||||
SendMessageRequest,
|
||||
StartConversationRequest,
|
||||
TextContent,
|
||||
)
|
||||
from openhands.app_server.app_conversation.app_conversation_info_service import (
|
||||
AppConversationInfoService,
|
||||
@@ -78,6 +79,7 @@ from openhands.app_server.utils.llm_metadata import (
|
||||
)
|
||||
from openhands.experiments.experiment_manager import ExperimentManagerImpl
|
||||
from openhands.integrations.provider import ProviderType
|
||||
from openhands.integrations.service_types import SuggestedTask
|
||||
from openhands.sdk import Agent, AgentContext, LocalWorkspace
|
||||
from openhands.sdk.llm import LLM
|
||||
from openhands.sdk.plugin import PluginSource
|
||||
@@ -85,6 +87,7 @@ from openhands.sdk.secret import LookupSecret, SecretValue, StaticSecret
|
||||
from openhands.sdk.utils.paging import page_iterator
|
||||
from openhands.sdk.workspace.remote.async_remote_workspace import AsyncRemoteWorkspace
|
||||
from openhands.server.types import AppMode
|
||||
from openhands.storage.data_models.conversation_metadata import ConversationTrigger
|
||||
from openhands.tools.preset.default import (
|
||||
get_default_tools,
|
||||
)
|
||||
@@ -209,6 +212,8 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
|
||||
)
|
||||
self._inherit_configuration_from_parent(request, parent_info)
|
||||
|
||||
self._apply_suggested_task(request)
|
||||
|
||||
task = AppConversationStartTask(
|
||||
created_by_user_id=user_id,
|
||||
request=request,
|
||||
@@ -569,6 +574,33 @@ class LiveStatusAppConversationService(AppConversationServiceBase):
|
||||
if not request.llm_model and parent_info.llm_model:
|
||||
request.llm_model = parent_info.llm_model
|
||||
|
||||
def _apply_suggested_task(self, request: AppConversationStartRequest) -> None:
|
||||
"""Apply suggested task defaults to the start request."""
|
||||
suggested_task: SuggestedTask | None = request.suggested_task
|
||||
if not suggested_task:
|
||||
return
|
||||
|
||||
if request.initial_message is not None:
|
||||
raise ValueError(
|
||||
'initial_message cannot be provided when suggested_task is present'
|
||||
)
|
||||
|
||||
prompt = suggested_task.get_prompt_for_task()
|
||||
if not prompt:
|
||||
raise ValueError(
|
||||
f'Suggested task returned empty prompt for task type {suggested_task.task_type}'
|
||||
)
|
||||
request.initial_message = SendMessageRequest(
|
||||
role='user',
|
||||
content=[TextContent(text=prompt)],
|
||||
)
|
||||
request.trigger = ConversationTrigger.SUGGESTED_TASK
|
||||
|
||||
if not request.selected_repository:
|
||||
request.selected_repository = suggested_task.repo
|
||||
if not request.git_provider:
|
||||
request.git_provider = suggested_task.git_provider
|
||||
|
||||
def _compute_plan_path(
|
||||
self,
|
||||
working_dir: str,
|
||||
|
||||
@@ -7,7 +7,7 @@ import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import AsyncGenerator
|
||||
from uuid import UUID
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from fastapi import Request
|
||||
from sqlalchemy import UUID as SQLUUID
|
||||
@@ -59,7 +59,7 @@ class StoredEventCallback(Base): # type: ignore
|
||||
|
||||
class StoredEventCallbackResult(Base): # type: ignore
|
||||
__tablename__ = 'event_callback_result'
|
||||
id = Column(SQLUUID, primary_key=True)
|
||||
id = Column(SQLUUID, primary_key=True, default=uuid4)
|
||||
status = Column(Enum(EventCallbackResultStatus), nullable=True)
|
||||
event_callback_id = Column(SQLUUID, index=True)
|
||||
event_id = Column(String, index=True)
|
||||
|
||||
@@ -142,7 +142,7 @@ runtime = [
|
||||
|
||||
[tool.poetry]
|
||||
name = "openhands-ai"
|
||||
version = "1.3.0"
|
||||
version = "1.4.0"
|
||||
description = "OpenHands: Code Less, Make More"
|
||||
authors = [ "OpenHands" ]
|
||||
license = "MIT"
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
---
|
||||
name: custom-codereview-guide
|
||||
description: Repo-specific code review guidelines for OpenHands/OpenHands. Provides additional review rules alongside the default code review skill.
|
||||
triggers:
|
||||
- /codereview
|
||||
---
|
||||
@@ -14,6 +14,7 @@ from pydantic import SecretStr
|
||||
from openhands.agent_server.models import (
|
||||
SendMessageRequest,
|
||||
StartConversationRequest,
|
||||
TextContent,
|
||||
)
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
AgentType,
|
||||
@@ -32,12 +33,14 @@ from openhands.app_server.sandbox.sandbox_models import (
|
||||
from openhands.app_server.sandbox.sandbox_spec_models import SandboxSpecInfo
|
||||
from openhands.app_server.user.user_context import UserContext
|
||||
from openhands.integrations.provider import ProviderToken, ProviderType
|
||||
from openhands.integrations.service_types import SuggestedTask, TaskType
|
||||
from openhands.sdk import Agent, Event
|
||||
from openhands.sdk.llm import LLM
|
||||
from openhands.sdk.secret import LookupSecret, StaticSecret
|
||||
from openhands.sdk.workspace import LocalWorkspace
|
||||
from openhands.sdk.workspace.remote.async_remote_workspace import AsyncRemoteWorkspace
|
||||
from openhands.server.types import AppMode
|
||||
from openhands.storage.data_models.conversation_metadata import ConversationTrigger
|
||||
|
||||
# Env var used by openhands SDK LLM to skip context-window validation (e.g. for gpt-4 in tests)
|
||||
_ALLOW_SHORT_CONTEXT_WINDOWS = 'ALLOW_SHORT_CONTEXT_WINDOWS'
|
||||
@@ -112,7 +115,62 @@ class TestLiveStatusAppConversationService:
|
||||
self.mock_sandbox.id = uuid4()
|
||||
self.mock_sandbox.status = SandboxStatus.RUNNING
|
||||
|
||||
@pytest.mark.asyncio
|
||||
def test_apply_suggested_task_sets_prompt_and_trigger(self):
|
||||
"""Test suggested task prompts populate initial message and trigger."""
|
||||
suggested_task = SuggestedTask(
|
||||
git_provider=ProviderType.GITHUB,
|
||||
task_type=TaskType.UNRESOLVED_COMMENTS,
|
||||
repo='owner/repo',
|
||||
issue_number=42,
|
||||
title='Handle review comments',
|
||||
)
|
||||
request = AppConversationStartRequest(suggested_task=suggested_task)
|
||||
|
||||
self.service._apply_suggested_task(request)
|
||||
|
||||
assert request.initial_message is not None
|
||||
assert (
|
||||
request.initial_message.content[0].text
|
||||
== suggested_task.get_prompt_for_task()
|
||||
)
|
||||
assert request.trigger == ConversationTrigger.SUGGESTED_TASK
|
||||
assert request.selected_repository == suggested_task.repo
|
||||
assert request.git_provider == suggested_task.git_provider
|
||||
|
||||
def test_apply_suggested_task_raises_if_initial_message_present(self):
|
||||
suggested_task = SuggestedTask(
|
||||
repo='foo/bar',
|
||||
git_provider=ProviderType.GITHUB,
|
||||
title='Some title',
|
||||
task_type=TaskType.OPEN_ISSUE,
|
||||
issue_number=123,
|
||||
)
|
||||
|
||||
request = AppConversationStartRequest(
|
||||
suggested_task=suggested_task,
|
||||
initial_message=SendMessageRequest(
|
||||
role='user',
|
||||
content=[TextContent(text='User provided message')],
|
||||
),
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match='initial_message cannot be provided'):
|
||||
self.service._apply_suggested_task(request)
|
||||
|
||||
def test_apply_suggested_task_raises_if_prompt_empty(self):
|
||||
suggested_task = SuggestedTask(
|
||||
repo='foo/bar',
|
||||
git_provider=ProviderType.GITHUB,
|
||||
title='Some title',
|
||||
task_type=TaskType.OPEN_ISSUE,
|
||||
issue_number=123,
|
||||
)
|
||||
request = AppConversationStartRequest(suggested_task=suggested_task)
|
||||
|
||||
with patch.object(SuggestedTask, 'get_prompt_for_task', return_value=''):
|
||||
with pytest.raises(ValueError, match='empty prompt'):
|
||||
self.service._apply_suggested_task(request)
|
||||
|
||||
async def test_setup_secrets_for_git_providers_no_provider_tokens(self):
|
||||
"""Test _setup_secrets_for_git_providers with no provider tokens."""
|
||||
# Arrange
|
||||
|
||||
Reference in New Issue
Block a user