mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-04-29 03:00:45 -04:00
Compare commits
68 Commits
fix-basic-
...
fix/settin
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4f16748505 | ||
|
|
8927ac2230 | ||
|
|
f3429e33ca | ||
|
|
7cd219792b | ||
|
|
2aabe2ed8c | ||
|
|
731a9a813e | ||
|
|
123e556fed | ||
|
|
6676cae249 | ||
|
|
fede37b496 | ||
|
|
3bcd6f18df | ||
|
|
0da18440c2 | ||
|
|
ac76e10048 | ||
|
|
b98bae8b5f | ||
|
|
516721d1ee | ||
|
|
4d6f66ca28 | ||
|
|
b18568da0b | ||
|
|
83dd3c169c | ||
|
|
35bddb14f1 | ||
|
|
e8425218e2 | ||
|
|
0a879fa781 | ||
|
|
41e142bbab | ||
|
|
b06b9eedac | ||
|
|
a9afafa991 | ||
|
|
663ace4b39 | ||
|
|
2d085a6e0a | ||
|
|
8b7112abe8 | ||
|
|
34547ba947 | ||
|
|
5f958ab60d | ||
|
|
d7656bf1c9 | ||
|
|
2bc107564c | ||
|
|
85eb1e1504 | ||
|
|
cd235cc8c7 | ||
|
|
40f52dfabc | ||
|
|
bab7bf85e8 | ||
|
|
c856537f65 | ||
|
|
736f5b2255 | ||
|
|
c1d9d11772 | ||
|
|
85244499fe | ||
|
|
c55084e223 | ||
|
|
e3bb75deb4 | ||
|
|
1948200762 | ||
|
|
affe0af361 | ||
|
|
f20c956196 | ||
|
|
4a089a3a0d | ||
|
|
aa0b2d0b74 | ||
|
|
bef9b80b9d | ||
|
|
c4a90b1f89 | ||
|
|
0d13c57d9f | ||
|
|
b3422f1275 | ||
|
|
f139a9970b | ||
|
|
54d156122c | ||
|
|
ac072bf686 | ||
|
|
a53812c029 | ||
|
|
1d1c0925b5 | ||
|
|
872f41e3c0 | ||
|
|
d43ff82534 | ||
|
|
8cd8c011b2 | ||
|
|
5c68b10983 | ||
|
|
a97fad1976 | ||
|
|
4c3542a91c | ||
|
|
f460057f58 | ||
|
|
4fa2ad0f47 | ||
|
|
dd8be12809 | ||
|
|
89475095d9 | ||
|
|
05d5f8848a | ||
|
|
ee2885eb0b | ||
|
|
545257f870 | ||
|
|
a071b0651e |
1
.github/workflows/ghcr-build.yml
vendored
1
.github/workflows/ghcr-build.yml
vendored
@@ -9,6 +9,7 @@ on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
- "saas-rel-*"
|
||||
tags:
|
||||
- "*"
|
||||
pull_request:
|
||||
|
||||
127
.github/workflows/pr-review-by-openhands.yml
vendored
Normal file
127
.github/workflows/pr-review-by-openhands.yml
vendored
Normal file
@@ -0,0 +1,127 @@
|
||||
---
|
||||
name: PR Review by OpenHands
|
||||
|
||||
on:
|
||||
# Use pull_request_target to allow fork PRs to access secrets when triggered by maintainers
|
||||
# Security: This workflow runs when:
|
||||
# 1. A new PR is opened (non-draft), OR
|
||||
# 2. A draft PR is marked as ready for review, OR
|
||||
# 3. A maintainer adds the 'review-this' label, OR
|
||||
# 4. A maintainer requests openhands-agent or all-hands-bot as a reviewer
|
||||
# Only users with write access can add labels or request reviews, ensuring security.
|
||||
# The PR code is explicitly checked out for review, but secrets are only accessible
|
||||
# because the workflow runs in the base repository context
|
||||
pull_request_target:
|
||||
types: [opened, ready_for_review, labeled, review_requested]
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: write
|
||||
issues: write
|
||||
|
||||
jobs:
|
||||
pr-review:
|
||||
# Run when one of the following conditions is met:
|
||||
# 1. A new non-draft PR is opened by a trusted contributor, OR
|
||||
# 2. A draft PR is converted to ready for review by a trusted contributor, OR
|
||||
# 3. 'review-this' label is added, OR
|
||||
# 4. openhands-agent or all-hands-bot is requested as a reviewer
|
||||
# Note: FIRST_TIME_CONTRIBUTOR PRs require manual trigger via label/reviewer request
|
||||
if: |
|
||||
(github.event.action == 'opened' && github.event.pull_request.draft == false && github.event.pull_request.author_association != 'FIRST_TIME_CONTRIBUTOR') ||
|
||||
(github.event.action == 'ready_for_review' && github.event.pull_request.author_association != 'FIRST_TIME_CONTRIBUTOR') ||
|
||||
github.event.label.name == 'review-this' ||
|
||||
github.event.requested_reviewer.login == 'openhands-agent' ||
|
||||
github.event.requested_reviewer.login == 'all-hands-bot'
|
||||
concurrency:
|
||||
group: pr-review-${{ github.event.pull_request.number }}
|
||||
cancel-in-progress: true
|
||||
runs-on: blacksmith-4vcpu-ubuntu-2404
|
||||
env:
|
||||
LLM_MODEL: litellm_proxy/claude-sonnet-4-5-20250929
|
||||
LLM_BASE_URL: https://llm-proxy.app.all-hands.dev
|
||||
# PR context will be automatically provided by the agent script
|
||||
PR_NUMBER: ${{ github.event.pull_request.number }}
|
||||
PR_TITLE: ${{ github.event.pull_request.title }}
|
||||
PR_BODY: ${{ github.event.pull_request.body }}
|
||||
PR_BASE_BRANCH: ${{ github.event.pull_request.base.ref }}
|
||||
PR_HEAD_BRANCH: ${{ github.event.pull_request.head.ref }}
|
||||
REPO_NAME: ${{ github.repository }}
|
||||
steps:
|
||||
- name: Checkout software-agent-sdk repository
|
||||
uses: actions/checkout@v5
|
||||
with:
|
||||
repository: OpenHands/software-agent-sdk
|
||||
path: software-agent-sdk
|
||||
|
||||
- name: Checkout PR repository
|
||||
uses: actions/checkout@v5
|
||||
with:
|
||||
# When using pull_request_target, explicitly checkout the PR branch
|
||||
# This ensures we review the actual PR code (including fork PRs)
|
||||
repository: ${{ github.event.pull_request.head.repo.full_name }}
|
||||
ref: ${{ github.event.pull_request.head.ref }}
|
||||
fetch-depth: 0
|
||||
# Security: Don't persist credentials to prevent untrusted PR code from using them
|
||||
persist-credentials: false
|
||||
path: pr-repo
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: '3.13'
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v7
|
||||
with:
|
||||
enable-cache: true
|
||||
|
||||
- name: Install GitHub CLI
|
||||
run: |
|
||||
# Install GitHub CLI for posting review comments
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y gh
|
||||
|
||||
- name: Install OpenHands dependencies
|
||||
run: |
|
||||
# Install OpenHands SDK and tools from local checkout
|
||||
uv pip install --system ./software-agent-sdk/openhands-sdk ./software-agent-sdk/openhands-tools
|
||||
|
||||
- name: Check required configuration
|
||||
env:
|
||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||
run: |
|
||||
if [ -z "$LLM_API_KEY" ]; then
|
||||
echo "Error: LLM_API_KEY secret is not set."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "PR Number: $PR_NUMBER"
|
||||
echo "PR Title: $PR_TITLE"
|
||||
echo "Repository: $REPO_NAME"
|
||||
echo "LLM model: $LLM_MODEL"
|
||||
if [ -n "$LLM_BASE_URL" ]; then
|
||||
echo "LLM base URL: $LLM_BASE_URL"
|
||||
fi
|
||||
|
||||
- name: Run PR review
|
||||
env:
|
||||
LLM_API_KEY: ${{ secrets.LLM_API_KEY }}
|
||||
GITHUB_TOKEN: ${{ secrets.ALLHANDS_BOT_GITHUB_PAT }}
|
||||
LMNR_PROJECT_API_KEY: ${{ secrets.LMNR_SKILLS_API_KEY }}
|
||||
run: |
|
||||
# Change to the PR repository directory so agent can analyze the code
|
||||
cd pr-repo
|
||||
|
||||
# Run the PR review script from the software-agent-sdk checkout
|
||||
uv run python ../software-agent-sdk/examples/03_github_workflows/02_pr_review/agent_script.py
|
||||
|
||||
- name: Upload logs as artifact
|
||||
uses: actions/upload-artifact@v5
|
||||
if: always()
|
||||
with:
|
||||
name: openhands-pr-review-logs
|
||||
path: |
|
||||
*.log
|
||||
output/
|
||||
retention-days: 7
|
||||
@@ -54,7 +54,7 @@ The experience will be familiar to anyone who has used Devin or Jules.
|
||||
### OpenHands Cloud
|
||||
This is a deployment of OpenHands GUI, running on hosted infrastructure.
|
||||
|
||||
You can try it with a free $10 credit by [signing in with your GitHub or GitLab account](https://app.all-hands.dev).
|
||||
You can try it for free using the Minimax model by [signing in with your GitHub or GitLab account](https://app.all-hands.dev).
|
||||
|
||||
OpenHands Cloud comes with source-available features and integrations:
|
||||
- Integrations with Slack, Jira, and Linear
|
||||
|
||||
@@ -23,12 +23,23 @@ RUN apt-get update && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Install Python packages with security fixes
|
||||
RUN /app/.venv/bin/pip install alembic psycopg2-binary cloud-sql-python-connector pg8000 gspread stripe python-keycloak asyncpg sqlalchemy[asyncio] resend tenacity slack-sdk ddtrace "posthog>=6.0.0" "limits==5.2.0" coredis prometheus-client shap scikit-learn pandas numpy google-cloud-recaptcha-enterprise && \
|
||||
# Update packages with known CVE fixes
|
||||
/app/.venv/bin/pip install --upgrade \
|
||||
"mcp>=1.10.0" \
|
||||
"pillow>=11.3.0"
|
||||
# Install poetry and export before importing current code.
|
||||
RUN /app/.venv/bin/pip install poetry poetry-plugin-export
|
||||
|
||||
# Install Python dependencies from poetry.lock for reproducible builds
|
||||
# Copy lock files first for better Docker layer caching
|
||||
COPY --chown=openhands:openhands enterprise/pyproject.toml enterprise/poetry.lock /tmp/enterprise/
|
||||
RUN cd /tmp/enterprise && \
|
||||
# Export only main dependencies with hashes for supply chain security
|
||||
/app/.venv/bin/poetry export --only main -o requirements.txt && \
|
||||
# Remove the local path dependency (openhands-ai is already in base image)
|
||||
sed -i '/^-e /d; /openhands-ai/d' requirements.txt && \
|
||||
# Install pinned dependencies from lock file
|
||||
/app/.venv/bin/pip install -r requirements.txt && \
|
||||
# Cleanup - return to /app before removing /tmp/enterprise
|
||||
cd /app && \
|
||||
rm -rf /tmp/enterprise && \
|
||||
/app/.venv/bin/pip uninstall -y poetry poetry-plugin-export
|
||||
|
||||
WORKDIR /app
|
||||
COPY --chown=openhands:openhands --chmod=770 enterprise .
|
||||
|
||||
@@ -28,9 +28,11 @@ class SaaSExperimentManager(ExperimentManager):
|
||||
return agent
|
||||
|
||||
if EXPERIMENT_SYSTEM_PROMPT_EXPERIMENT:
|
||||
agent = agent.model_copy(
|
||||
update={'system_prompt_filename': 'system_prompt_long_horizon.j2'}
|
||||
)
|
||||
# Skip experiment for planning agents which require their specialized prompt
|
||||
if agent.system_prompt_filename != 'system_prompt_planning.j2':
|
||||
agent = agent.model_copy(
|
||||
update={'system_prompt_filename': 'system_prompt_long_horizon.j2'}
|
||||
)
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
@@ -145,11 +145,7 @@ class GithubManager(Manager):
|
||||
).get('body', ''):
|
||||
return False
|
||||
|
||||
if GithubFactory.is_eligible_for_conversation_starter(
|
||||
message
|
||||
) and self._user_has_write_access_to_repo(installation_id, repo_name, username):
|
||||
await GithubFactory.trigger_conversation_starter(message)
|
||||
|
||||
# Check event types before making expensive API calls (e.g., _user_has_write_access_to_repo)
|
||||
if not (
|
||||
GithubFactory.is_labeled_issue(message)
|
||||
or GithubFactory.is_issue_comment(message)
|
||||
@@ -159,8 +155,17 @@ class GithubManager(Manager):
|
||||
return False
|
||||
|
||||
logger.info(f'[GitHub] Checking permissions for {username} in {repo_name}')
|
||||
user_has_write_access = self._user_has_write_access_to_repo(
|
||||
installation_id, repo_name, username
|
||||
)
|
||||
|
||||
return self._user_has_write_access_to_repo(installation_id, repo_name, username)
|
||||
if (
|
||||
GithubFactory.is_eligible_for_conversation_starter(message)
|
||||
and user_has_write_access
|
||||
):
|
||||
await GithubFactory.trigger_conversation_starter(message)
|
||||
|
||||
return user_has_write_access
|
||||
|
||||
async def receive_message(self, message: Message):
|
||||
self._confirm_incoming_source_type(message)
|
||||
|
||||
@@ -167,17 +167,15 @@ async def install_webhook_on_resource(
|
||||
scopes=SCOPES,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
'Creating new webhook',
|
||||
extra={
|
||||
'webhook_id': webhook_id,
|
||||
'status': status,
|
||||
'resource_id': resource_id,
|
||||
'resource_type': resource_type,
|
||||
},
|
||||
)
|
||||
log_extra = {
|
||||
'webhook_id': webhook_id,
|
||||
'status': status,
|
||||
'resource_id': resource_id,
|
||||
'resource_type': resource_type,
|
||||
}
|
||||
|
||||
if status == WebhookStatus.RATE_LIMITED:
|
||||
logger.warning('Rate limited while creating webhook', extra=log_extra)
|
||||
raise BreakLoopException()
|
||||
|
||||
if webhook_id:
|
||||
@@ -191,9 +189,8 @@ async def install_webhook_on_resource(
|
||||
'webhook_uuid': webhook_uuid, # required to identify which webhook installation is sending payload
|
||||
},
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f'Installed webhook for {webhook.user_id} on {resource_type}:{resource_id}'
|
||||
)
|
||||
logger.info('Created new webhook', extra=log_extra)
|
||||
else:
|
||||
logger.error('Failed to create webhook', extra=log_extra)
|
||||
|
||||
return webhook_id, status
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from openhands.app_server.user.user_context import UserContext
|
||||
from openhands.app_server.user.user_models import UserInfo
|
||||
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE
|
||||
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE, ProviderHandler
|
||||
from openhands.integrations.service_types import ProviderType
|
||||
from openhands.sdk.secret import SecretSource, StaticSecret
|
||||
from openhands.server.user_auth.user_auth import UserAuth
|
||||
@@ -14,6 +14,7 @@ class ResolverUserContext(UserContext):
|
||||
saas_user_auth: UserAuth,
|
||||
):
|
||||
self.saas_user_auth = saas_user_auth
|
||||
self._provider_handler: ProviderHandler | None = None
|
||||
|
||||
async def get_user_id(self) -> str | None:
|
||||
return await self.saas_user_auth.get_user_id()
|
||||
@@ -29,12 +30,26 @@ class ResolverUserContext(UserContext):
|
||||
|
||||
return UserInfo(id=user_id)
|
||||
|
||||
async def _get_provider_handler(self) -> ProviderHandler:
|
||||
"""Get or create a ProviderHandler for git operations."""
|
||||
if self._provider_handler is None:
|
||||
provider_tokens = await self.saas_user_auth.get_provider_tokens()
|
||||
if provider_tokens is None:
|
||||
raise ValueError('No provider tokens available')
|
||||
user_id = await self.saas_user_auth.get_user_id()
|
||||
self._provider_handler = ProviderHandler(
|
||||
provider_tokens=provider_tokens, external_auth_id=user_id
|
||||
)
|
||||
return self._provider_handler
|
||||
|
||||
async def get_authenticated_git_url(
|
||||
self, repository: str, is_optional: bool = False
|
||||
) -> str:
|
||||
# This would need to be implemented based on the git provider tokens
|
||||
# For now, return a basic HTTPS URL
|
||||
return f'https://github.com/{repository}.git'
|
||||
provider_handler = await self._get_provider_handler()
|
||||
url = await provider_handler.get_authenticated_git_url(
|
||||
repository, is_optional=is_optional
|
||||
)
|
||||
return url
|
||||
|
||||
async def get_latest_token(self, provider_type: ProviderType) -> str | None:
|
||||
# Return the appropriate token string from git_provider_tokens
|
||||
|
||||
@@ -1,10 +1,15 @@
|
||||
import logging
|
||||
import os
|
||||
from logging.config import fileConfig
|
||||
|
||||
from alembic import context
|
||||
from google.cloud.sql.connector import Connector
|
||||
from sqlalchemy import create_engine
|
||||
from storage.base import Base
|
||||
# Suppress alembic.runtime.plugins INFO logs during import to prevent non-JSON logs in production
|
||||
# These plugin setup messages would otherwise appear before logging is configured
|
||||
logging.getLogger('alembic.runtime.plugins').setLevel(logging.WARNING)
|
||||
|
||||
from alembic import context # noqa: E402
|
||||
from google.cloud.sql.connector import Connector # noqa: E402
|
||||
from sqlalchemy import create_engine # noqa: E402
|
||||
from storage.base import Base # noqa: E402
|
||||
|
||||
target_metadata = Base.metadata
|
||||
|
||||
|
||||
@@ -0,0 +1,46 @@
|
||||
"""Add byor_export_enabled flag to org table.
|
||||
|
||||
Revision ID: 091
|
||||
Revises: 090
|
||||
Create Date: 2025-01-15 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '091'
|
||||
down_revision: Union[str, None] = '090'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add byor_export_enabled column to org table with default false
|
||||
op.add_column(
|
||||
'org',
|
||||
sa.Column(
|
||||
'byor_export_enabled',
|
||||
sa.Boolean,
|
||||
nullable=False,
|
||||
server_default=sa.text('false'),
|
||||
),
|
||||
)
|
||||
|
||||
# Set byor_export_enabled to true for orgs that have completed billing sessions
|
||||
op.execute(
|
||||
sa.text("""
|
||||
UPDATE org SET byor_export_enabled = TRUE
|
||||
WHERE id IN (
|
||||
SELECT DISTINCT org_id FROM billing_sessions
|
||||
WHERE status = 'completed' AND org_id IS NOT NULL
|
||||
)
|
||||
""")
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column('org', 'byor_export_enabled')
|
||||
@@ -0,0 +1,29 @@
|
||||
"""Rename 'user' role to 'member' in role table.
|
||||
|
||||
Revision ID: 092
|
||||
Revises: 091
|
||||
Create Date: 2025-02-12 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '092'
|
||||
down_revision: Union[str, None] = '091'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Rename 'user' role to 'member' for clarity
|
||||
# This avoids confusion between the 'user' role and the 'user' entity/account
|
||||
op.execute(sa.text("UPDATE role SET name = 'member' WHERE name = 'user'"))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Revert 'member' role back to 'user'
|
||||
op.execute(sa.text("UPDATE role SET name = 'user' WHERE name = 'member'"))
|
||||
@@ -0,0 +1,37 @@
|
||||
"""Add pending_free_credits flag to org table.
|
||||
|
||||
Revision ID: 093
|
||||
Revises: 092
|
||||
Create Date: 2025-02-17 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '093'
|
||||
down_revision: Union[str, None] = '092'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add pending_free_credits column to org table with default false.
|
||||
# New orgs will have this set to TRUE at creation time.
|
||||
# Existing orgs default to FALSE (not eligible - they already got $10 at signup).
|
||||
op.add_column(
|
||||
'org',
|
||||
sa.Column(
|
||||
'pending_free_credits',
|
||||
sa.Boolean,
|
||||
nullable=False,
|
||||
server_default=sa.text('false'),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column('org', 'pending_free_credits')
|
||||
@@ -0,0 +1,110 @@
|
||||
"""create org_invitation table
|
||||
|
||||
Revision ID: 094
|
||||
Revises: 093
|
||||
Create Date: 2026-02-18 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '094'
|
||||
down_revision: Union[str, None] = '093'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Create org_invitation table
|
||||
op.create_table(
|
||||
'org_invitation',
|
||||
sa.Column('id', sa.Integer, sa.Identity(), primary_key=True),
|
||||
sa.Column('token', sa.String(64), nullable=False),
|
||||
sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('email', sa.String(255), nullable=False),
|
||||
sa.Column('role_id', sa.Integer, nullable=False),
|
||||
sa.Column('inviter_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column(
|
||||
'status',
|
||||
sa.String(20),
|
||||
nullable=False,
|
||||
server_default=sa.text("'pending'"),
|
||||
),
|
||||
sa.Column(
|
||||
'created_at',
|
||||
sa.DateTime,
|
||||
nullable=False,
|
||||
server_default=sa.text('CURRENT_TIMESTAMP'),
|
||||
),
|
||||
sa.Column('expires_at', sa.DateTime, nullable=False),
|
||||
sa.Column('accepted_at', sa.DateTime, nullable=True),
|
||||
sa.Column('accepted_by_user_id', postgresql.UUID(as_uuid=True), nullable=True),
|
||||
# Foreign key constraints
|
||||
sa.ForeignKeyConstraint(
|
||||
['org_id'],
|
||||
['org.id'],
|
||||
name='org_invitation_org_fkey',
|
||||
ondelete='CASCADE',
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
['role_id'],
|
||||
['role.id'],
|
||||
name='org_invitation_role_fkey',
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
['inviter_id'],
|
||||
['user.id'],
|
||||
name='org_invitation_inviter_fkey',
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
['accepted_by_user_id'],
|
||||
['user.id'],
|
||||
name='org_invitation_accepter_fkey',
|
||||
),
|
||||
)
|
||||
|
||||
# Create indexes
|
||||
op.create_index(
|
||||
'ix_org_invitation_token',
|
||||
'org_invitation',
|
||||
['token'],
|
||||
unique=True,
|
||||
)
|
||||
op.create_index(
|
||||
'ix_org_invitation_org_id',
|
||||
'org_invitation',
|
||||
['org_id'],
|
||||
)
|
||||
op.create_index(
|
||||
'ix_org_invitation_email',
|
||||
'org_invitation',
|
||||
['email'],
|
||||
)
|
||||
op.create_index(
|
||||
'ix_org_invitation_status',
|
||||
'org_invitation',
|
||||
['status'],
|
||||
)
|
||||
# Composite index for checking pending invitations
|
||||
op.create_index(
|
||||
'ix_org_invitation_org_email_status',
|
||||
'org_invitation',
|
||||
['org_id', 'email', 'status'],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop indexes
|
||||
op.drop_index('ix_org_invitation_org_email_status', table_name='org_invitation')
|
||||
op.drop_index('ix_org_invitation_status', table_name='org_invitation')
|
||||
op.drop_index('ix_org_invitation_email', table_name='org_invitation')
|
||||
op.drop_index('ix_org_invitation_org_id', table_name='org_invitation')
|
||||
op.drop_index('ix_org_invitation_token', table_name='org_invitation')
|
||||
|
||||
# Drop table
|
||||
op.drop_table('org_invitation')
|
||||
@@ -0,0 +1,37 @@
|
||||
"""Drop pending_free_credits column from org table.
|
||||
|
||||
Revision ID: 095
|
||||
Revises: 094
|
||||
Create Date: 2025-02-18 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '095'
|
||||
down_revision: Union[str, None] = '094'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Drop the pending_free_credits column from org table.
|
||||
# This column was used for tracking free credit eligibility but is no longer needed.
|
||||
op.drop_column('org', 'pending_free_credits')
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Re-add pending_free_credits column with default false.
|
||||
op.add_column(
|
||||
'org',
|
||||
sa.Column(
|
||||
'pending_free_credits',
|
||||
sa.Boolean,
|
||||
nullable=False,
|
||||
server_default=sa.text('false'),
|
||||
),
|
||||
)
|
||||
@@ -0,0 +1,67 @@
|
||||
"""Create resend_synced_users table.
|
||||
|
||||
Revision ID: 096
|
||||
Revises: 095
|
||||
Create Date: 2025-02-17 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '096'
|
||||
down_revision: Union[str, None] = '095'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Create resend_synced_users table for tracking users synced to Resend audiences."""
|
||||
op.create_table(
|
||||
'resend_synced_users',
|
||||
sa.Column(
|
||||
'id',
|
||||
sa.UUID(as_uuid=True),
|
||||
nullable=False,
|
||||
primary_key=True,
|
||||
),
|
||||
sa.Column('email', sa.String(), nullable=False),
|
||||
sa.Column('audience_id', sa.String(), nullable=False),
|
||||
sa.Column(
|
||||
'synced_at',
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.text('CURRENT_TIMESTAMP'),
|
||||
),
|
||||
sa.Column('keycloak_user_id', sa.String(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint(
|
||||
'email', 'audience_id', name='uq_resend_synced_email_audience'
|
||||
),
|
||||
)
|
||||
|
||||
# Create index on email for fast lookups
|
||||
op.create_index(
|
||||
'ix_resend_synced_users_email',
|
||||
'resend_synced_users',
|
||||
['email'],
|
||||
)
|
||||
|
||||
# Create index on audience_id for filtering by audience
|
||||
op.create_index(
|
||||
'ix_resend_synced_users_audience_id',
|
||||
'resend_synced_users',
|
||||
['audience_id'],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Drop resend_synced_users table."""
|
||||
op.drop_index(
|
||||
'ix_resend_synced_users_audience_id', table_name='resend_synced_users'
|
||||
)
|
||||
op.drop_index('ix_resend_synced_users_email', table_name='resend_synced_users')
|
||||
op.drop_table('resend_synced_users')
|
||||
210
enterprise/poetry.lock
generated
210
enterprise/poetry.lock
generated
@@ -6102,14 +6102,14 @@ llama = ["llama-index (>=0.12.29,<0.13.0)", "llama-index-core (>=0.12.29,<0.13.0
|
||||
|
||||
[[package]]
|
||||
name = "openhands-agent-server"
|
||||
version = "1.11.1"
|
||||
version = "1.11.4"
|
||||
description = "OpenHands Agent Server - REST/WebSocket interface for OpenHands AI Agent"
|
||||
optional = false
|
||||
python-versions = ">=3.12"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "openhands_agent_server-1.11.1-py3-none-any.whl", hash = "sha256:28e3ca670114c7a936a33f2d193238fbdc75f429c4e0bb99a03b14e6c01663c9"},
|
||||
{file = "openhands_agent_server-1.11.1.tar.gz", hash = "sha256:06eaf8b8eda4ca05de24751a7d269b22f611328c6cb2b4b91f2486011228b69a"},
|
||||
{file = "openhands_agent_server-1.11.4-py3-none-any.whl", hash = "sha256:739bdb774dbfcd23d6e87ee6ee32bc0999f22300037506b6dd33e9ea67fa5c2a"},
|
||||
{file = "openhands_agent_server-1.11.4.tar.gz", hash = "sha256:41247f7022a046eb50ca3b552bc6d12bfa9776e1bd27d0989da91b9f7ac77ca2"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -6168,9 +6168,9 @@ memory-profiler = ">=0.61"
|
||||
numpy = "*"
|
||||
openai = "2.8"
|
||||
openhands-aci = "0.3.2"
|
||||
openhands-agent-server = "1.11.1"
|
||||
openhands-sdk = "1.11.1"
|
||||
openhands-tools = "1.11.1"
|
||||
openhands-agent-server = "1.11.4"
|
||||
openhands-sdk = "1.11.4"
|
||||
openhands-tools = "1.11.4"
|
||||
opentelemetry-api = ">=1.33.1"
|
||||
opentelemetry-exporter-otlp-proto-grpc = ">=1.33.1"
|
||||
pathspec = ">=0.12.1"
|
||||
@@ -6225,14 +6225,14 @@ url = ".."
|
||||
|
||||
[[package]]
|
||||
name = "openhands-sdk"
|
||||
version = "1.11.1"
|
||||
version = "1.11.4"
|
||||
description = "OpenHands SDK - Core functionality for building AI agents"
|
||||
optional = false
|
||||
python-versions = ">=3.12"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "openhands_sdk-1.11.1-py3-none-any.whl", hash = "sha256:10ee0777286b149db21bdeeadb6d4c57f461da4049a4ba07576e7228b5c76c85"},
|
||||
{file = "openhands_sdk-1.11.1.tar.gz", hash = "sha256:57f5884d0596a8659b7c0cdbe86ebaa74c810c4e2645fcff45f0113894dd9376"},
|
||||
{file = "openhands_sdk-1.11.4-py3-none-any.whl", hash = "sha256:9f4607c5d94b56fbcd533207026ee892779dd50e29bce79277ff82454a4f76d5"},
|
||||
{file = "openhands_sdk-1.11.4.tar.gz", hash = "sha256:4088744f6b8856eeab22d3bc17e47d1736ea7ced945c2fa126bd7d48c14bb313"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -6253,14 +6253,14 @@ boto3 = ["boto3 (>=1.35.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "openhands-tools"
|
||||
version = "1.11.1"
|
||||
version = "1.11.4"
|
||||
description = "OpenHands Tools - Runtime tools for AI agents"
|
||||
optional = false
|
||||
python-versions = ">=3.12"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "openhands_tools-1.11.1-py3-none-any.whl", hash = "sha256:0b64763def90dda5b6545a356a437437c2029ec9bc47a4e6dac5c06dea6a4e77"},
|
||||
{file = "openhands_tools-1.11.1.tar.gz", hash = "sha256:2a71d2d0619ca631b3b7f5bd741bfdf97f7ebe6f96dc2540f79b9a688a6309fc"},
|
||||
{file = "openhands_tools-1.11.4-py3-none-any.whl", hash = "sha256:efd721b73e87a0dac69171a76931363fa59fcde98107ca86081ee7bf0253673a"},
|
||||
{file = "openhands_tools-1.11.4.tar.gz", hash = "sha256:80671b1ea8c85a5247a75ea2340ae31d76363e9c723b104699a9a77e66d2043c"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -6851,103 +6851,103 @@ scramp = ">=1.4.5"
|
||||
|
||||
[[package]]
|
||||
name = "pillow"
|
||||
version = "12.1.0"
|
||||
version = "12.1.1"
|
||||
description = "Python Imaging Library (fork)"
|
||||
optional = false
|
||||
python-versions = ">=3.10"
|
||||
groups = ["main", "test"]
|
||||
files = [
|
||||
{file = "pillow-12.1.0-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:fb125d860738a09d363a88daa0f59c4533529a90e564785e20fe875b200b6dbd"},
|
||||
{file = "pillow-12.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:cad302dc10fac357d3467a74a9561c90609768a6f73a1923b0fd851b6486f8b0"},
|
||||
{file = "pillow-12.1.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:a40905599d8079e09f25027423aed94f2823adaf2868940de991e53a449e14a8"},
|
||||
{file = "pillow-12.1.0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:92a7fe4225365c5e3a8e598982269c6d6698d3e783b3b1ae979e7819f9cd55c1"},
|
||||
{file = "pillow-12.1.0-cp310-cp310-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f10c98f49227ed8383d28174ee95155a675c4ed7f85e2e573b04414f7e371bda"},
|
||||
{file = "pillow-12.1.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8637e29d13f478bc4f153d8daa9ffb16455f0a6cb287da1b432fdad2bfbd66c7"},
|
||||
{file = "pillow-12.1.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:21e686a21078b0f9cb8c8a961d99e6a4ddb88e0fc5ea6e130172ddddc2e5221a"},
|
||||
{file = "pillow-12.1.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:2415373395a831f53933c23ce051021e79c8cd7979822d8cc478547a3f4da8ef"},
|
||||
{file = "pillow-12.1.0-cp310-cp310-win32.whl", hash = "sha256:e75d3dba8fc1ddfec0cd752108f93b83b4f8d6ab40e524a95d35f016b9683b09"},
|
||||
{file = "pillow-12.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:64efdf00c09e31efd754448a383ea241f55a994fd079866b92d2bbff598aad91"},
|
||||
{file = "pillow-12.1.0-cp310-cp310-win_arm64.whl", hash = "sha256:f188028b5af6b8fb2e9a76ac0f841a575bd1bd396e46ef0840d9b88a48fdbcea"},
|
||||
{file = "pillow-12.1.0-cp311-cp311-macosx_10_10_x86_64.whl", hash = "sha256:a83e0850cb8f5ac975291ebfc4170ba481f41a28065277f7f735c202cd8e0af3"},
|
||||
{file = "pillow-12.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:b6e53e82ec2db0717eabb276aa56cf4e500c9a7cec2c2e189b55c24f65a3e8c0"},
|
||||
{file = "pillow-12.1.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:40a8e3b9e8773876d6e30daed22f016509e3987bab61b3b7fe309d7019a87451"},
|
||||
{file = "pillow-12.1.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:800429ac32c9b72909c671aaf17ecd13110f823ddb7db4dfef412a5587c2c24e"},
|
||||
{file = "pillow-12.1.0-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0b022eaaf709541b391ee069f0022ee5b36c709df71986e3f7be312e46f42c84"},
|
||||
{file = "pillow-12.1.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1f345e7bc9d7f368887c712aa5054558bad44d2a301ddf9248599f4161abc7c0"},
|
||||
{file = "pillow-12.1.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d70347c8a5b7ccd803ec0c85c8709f036e6348f1e6a5bf048ecd9c64d3550b8b"},
|
||||
{file = "pillow-12.1.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:1fcc52d86ce7a34fd17cb04e87cfdb164648a3662a6f20565910a99653d66c18"},
|
||||
{file = "pillow-12.1.0-cp311-cp311-win32.whl", hash = "sha256:3ffaa2f0659e2f740473bcf03c702c39a8d4b2b7ffc629052028764324842c64"},
|
||||
{file = "pillow-12.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:806f3987ffe10e867bab0ddad45df1148a2b98221798457fa097ad85d6e8bc75"},
|
||||
{file = "pillow-12.1.0-cp311-cp311-win_arm64.whl", hash = "sha256:9f5fefaca968e700ad1a4a9de98bf0869a94e397fe3524c4c9450c1445252304"},
|
||||
{file = "pillow-12.1.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:a332ac4ccb84b6dde65dbace8431f3af08874bf9770719d32a635c4ef411b18b"},
|
||||
{file = "pillow-12.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:907bfa8a9cb790748a9aa4513e37c88c59660da3bcfffbd24a7d9e6abf224551"},
|
||||
{file = "pillow-12.1.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:efdc140e7b63b8f739d09a99033aa430accce485ff78e6d311973a67b6bf3208"},
|
||||
{file = "pillow-12.1.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:bef9768cab184e7ae6e559c032e95ba8d07b3023c289f79a2bd36e8bf85605a5"},
|
||||
{file = "pillow-12.1.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:742aea052cf5ab5034a53c3846165bc3ce88d7c38e954120db0ab867ca242661"},
|
||||
{file = "pillow-12.1.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a6dfc2af5b082b635af6e08e0d1f9f1c4e04d17d4e2ca0ef96131e85eda6eb17"},
|
||||
{file = "pillow-12.1.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:609e89d9f90b581c8d16358c9087df76024cf058fa693dd3e1e1620823f39670"},
|
||||
{file = "pillow-12.1.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:43b4899cfd091a9693a1278c4982f3e50f7fb7cff5153b05174b4afc9593b616"},
|
||||
{file = "pillow-12.1.0-cp312-cp312-win32.whl", hash = "sha256:aa0c9cc0b82b14766a99fbe6084409972266e82f459821cd26997a488a7261a7"},
|
||||
{file = "pillow-12.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:d70534cea9e7966169ad29a903b99fc507e932069a881d0965a1a84bb57f6c6d"},
|
||||
{file = "pillow-12.1.0-cp312-cp312-win_arm64.whl", hash = "sha256:65b80c1ee7e14a87d6a068dd3b0aea268ffcabfe0498d38661b00c5b4b22e74c"},
|
||||
{file = "pillow-12.1.0-cp313-cp313-ios_13_0_arm64_iphoneos.whl", hash = "sha256:7b5dd7cbae20285cdb597b10eb5a2c13aa9de6cde9bb64a3c1317427b1db1ae1"},
|
||||
{file = "pillow-12.1.0-cp313-cp313-ios_13_0_arm64_iphonesimulator.whl", hash = "sha256:29a4cef9cb672363926f0470afc516dbf7305a14d8c54f7abbb5c199cd8f8179"},
|
||||
{file = "pillow-12.1.0-cp313-cp313-ios_13_0_x86_64_iphonesimulator.whl", hash = "sha256:681088909d7e8fa9e31b9799aaa59ba5234c58e5e4f1951b4c4d1082a2e980e0"},
|
||||
{file = "pillow-12.1.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:983976c2ab753166dc66d36af6e8ec15bb511e4a25856e2227e5f7e00a160587"},
|
||||
{file = "pillow-12.1.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:db44d5c160a90df2d24a24760bbd37607d53da0b34fb546c4c232af7192298ac"},
|
||||
{file = "pillow-12.1.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:6b7a9d1db5dad90e2991645874f708e87d9a3c370c243c2d7684d28f7e133e6b"},
|
||||
{file = "pillow-12.1.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6258f3260986990ba2fa8a874f8b6e808cf5abb51a94015ca3dc3c68aa4f30ea"},
|
||||
{file = "pillow-12.1.0-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e115c15e3bc727b1ca3e641a909f77f8ca72a64fff150f666fcc85e57701c26c"},
|
||||
{file = "pillow-12.1.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6741e6f3074a35e47c77b23a4e4f2d90db3ed905cb1c5e6e0d49bff2045632bc"},
|
||||
{file = "pillow-12.1.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:935b9d1aed48fcfb3f838caac506f38e29621b44ccc4f8a64d575cb1b2a88644"},
|
||||
{file = "pillow-12.1.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:5fee4c04aad8932da9f8f710af2c1a15a83582cfb884152a9caa79d4efcdbf9c"},
|
||||
{file = "pillow-12.1.0-cp313-cp313-win32.whl", hash = "sha256:a786bf667724d84aa29b5db1c61b7bfdde380202aaca12c3461afd6b71743171"},
|
||||
{file = "pillow-12.1.0-cp313-cp313-win_amd64.whl", hash = "sha256:461f9dfdafa394c59cd6d818bdfdbab4028b83b02caadaff0ffd433faf4c9a7a"},
|
||||
{file = "pillow-12.1.0-cp313-cp313-win_arm64.whl", hash = "sha256:9212d6b86917a2300669511ed094a9406888362e085f2431a7da985a6b124f45"},
|
||||
{file = "pillow-12.1.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:00162e9ca6d22b7c3ee8e61faa3c3253cd19b6a37f126cad04f2f88b306f557d"},
|
||||
{file = "pillow-12.1.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:7d6daa89a00b58c37cb1747ec9fb7ac3bc5ffd5949f5888657dfddde6d1312e0"},
|
||||
{file = "pillow-12.1.0-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:e2479c7f02f9d505682dc47df8c0ea1fc5e264c4d1629a5d63fe3e2334b89554"},
|
||||
{file = "pillow-12.1.0-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:f188d580bd870cda1e15183790d1cc2fa78f666e76077d103edf048eed9c356e"},
|
||||
{file = "pillow-12.1.0-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0fde7ec5538ab5095cc02df38ee99b0443ff0e1c847a045554cf5f9af1f4aa82"},
|
||||
{file = "pillow-12.1.0-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0ed07dca4a8464bada6139ab38f5382f83e5f111698caf3191cb8dbf27d908b4"},
|
||||
{file = "pillow-12.1.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:f45bd71d1fa5e5749587613037b172e0b3b23159d1c00ef2fc920da6f470e6f0"},
|
||||
{file = "pillow-12.1.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:277518bf4fe74aa91489e1b20577473b19ee70fb97c374aa50830b279f25841b"},
|
||||
{file = "pillow-12.1.0-cp313-cp313t-win32.whl", hash = "sha256:7315f9137087c4e0ee73a761b163fc9aa3b19f5f606a7fc08d83fd3e4379af65"},
|
||||
{file = "pillow-12.1.0-cp313-cp313t-win_amd64.whl", hash = "sha256:0ddedfaa8b5f0b4ffbc2fa87b556dc59f6bb4ecb14a53b33f9189713ae8053c0"},
|
||||
{file = "pillow-12.1.0-cp313-cp313t-win_arm64.whl", hash = "sha256:80941e6d573197a0c28f394753de529bb436b1ca990ed6e765cf42426abc39f8"},
|
||||
{file = "pillow-12.1.0-cp314-cp314-ios_13_0_arm64_iphoneos.whl", hash = "sha256:5cb7bc1966d031aec37ddb9dcf15c2da5b2e9f7cc3ca7c54473a20a927e1eb91"},
|
||||
{file = "pillow-12.1.0-cp314-cp314-ios_13_0_arm64_iphonesimulator.whl", hash = "sha256:97e9993d5ed946aba26baf9c1e8cf18adbab584b99f452ee72f7ee8acb882796"},
|
||||
{file = "pillow-12.1.0-cp314-cp314-ios_13_0_x86_64_iphonesimulator.whl", hash = "sha256:414b9a78e14ffeb98128863314e62c3f24b8a86081066625700b7985b3f529bd"},
|
||||
{file = "pillow-12.1.0-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:e6bdb408f7c9dd2a5ff2b14a3b0bb6d4deb29fb9961e6eb3ae2031ae9a5cec13"},
|
||||
{file = "pillow-12.1.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:3413c2ae377550f5487991d444428f1a8ae92784aac79caa8b1e3b89b175f77e"},
|
||||
{file = "pillow-12.1.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:e5dcbe95016e88437ecf33544ba5db21ef1b8dd6e1b434a2cb2a3d605299e643"},
|
||||
{file = "pillow-12.1.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:d0a7735df32ccbcc98b98a1ac785cc4b19b580be1bdf0aeb5c03223220ea09d5"},
|
||||
{file = "pillow-12.1.0-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0c27407a2d1b96774cbc4a7594129cc027339fd800cd081e44497722ea1179de"},
|
||||
{file = "pillow-12.1.0-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:15c794d74303828eaa957ff8070846d0efe8c630901a1c753fdc63850e19ecd9"},
|
||||
{file = "pillow-12.1.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:c990547452ee2800d8506c4150280757f88532f3de2a58e3022e9b179107862a"},
|
||||
{file = "pillow-12.1.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:b63e13dd27da389ed9475b3d28510f0f954bca0041e8e551b2a4eb1eab56a39a"},
|
||||
{file = "pillow-12.1.0-cp314-cp314-win32.whl", hash = "sha256:1a949604f73eb07a8adab38c4fe50791f9919344398bdc8ac6b307f755fc7030"},
|
||||
{file = "pillow-12.1.0-cp314-cp314-win_amd64.whl", hash = "sha256:4f9f6a650743f0ddee5593ac9e954ba1bdbc5e150bc066586d4f26127853ab94"},
|
||||
{file = "pillow-12.1.0-cp314-cp314-win_arm64.whl", hash = "sha256:808b99604f7873c800c4840f55ff389936ef1948e4e87645eaf3fccbc8477ac4"},
|
||||
{file = "pillow-12.1.0-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:bc11908616c8a283cf7d664f77411a5ed2a02009b0097ff8abbba5e79128ccf2"},
|
||||
{file = "pillow-12.1.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:896866d2d436563fa2a43a9d72f417874f16b5545955c54a64941e87c1376c61"},
|
||||
{file = "pillow-12.1.0-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:8e178e3e99d3c0ea8fc64b88447f7cac8ccf058af422a6cedc690d0eadd98c51"},
|
||||
{file = "pillow-12.1.0-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:079af2fb0c599c2ec144ba2c02766d1b55498e373b3ac64687e43849fbbef5bc"},
|
||||
{file = "pillow-12.1.0-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:bdec5e43377761c5dbca620efb69a77f6855c5a379e32ac5b158f54c84212b14"},
|
||||
{file = "pillow-12.1.0-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:565c986f4b45c020f5421a4cea13ef294dde9509a8577f29b2fc5edc7587fff8"},
|
||||
{file = "pillow-12.1.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:43aca0a55ce1eefc0aefa6253661cb54571857b1a7b2964bd8a1e3ef4b729924"},
|
||||
{file = "pillow-12.1.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:0deedf2ea233722476b3a81e8cdfbad786f7adbed5d848469fa59fe52396e4ef"},
|
||||
{file = "pillow-12.1.0-cp314-cp314t-win32.whl", hash = "sha256:b17fbdbe01c196e7e159aacb889e091f28e61020a8abeac07b68079b6e626988"},
|
||||
{file = "pillow-12.1.0-cp314-cp314t-win_amd64.whl", hash = "sha256:27b9baecb428899db6c0de572d6d305cfaf38ca1596b5c0542a5182e3e74e8c6"},
|
||||
{file = "pillow-12.1.0-cp314-cp314t-win_arm64.whl", hash = "sha256:f61333d817698bdcdd0f9d7793e365ac3d2a21c1f1eb02b32ad6aefb8d8ea831"},
|
||||
{file = "pillow-12.1.0-pp311-pypy311_pp73-macosx_10_15_x86_64.whl", hash = "sha256:ca94b6aac0d7af2a10ba08c0f888b3d5114439b6b3ef39968378723622fed377"},
|
||||
{file = "pillow-12.1.0-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:351889afef0f485b84078ea40fe33727a0492b9af3904661b0abbafee0355b72"},
|
||||
{file = "pillow-12.1.0-pp311-pypy311_pp73-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:bb0984b30e973f7e2884362b7d23d0a348c7143ee559f38ef3eaab640144204c"},
|
||||
{file = "pillow-12.1.0-pp311-pypy311_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:84cabc7095dd535ca934d57e9ce2a72ffd216e435a84acb06b2277b1de2689bd"},
|
||||
{file = "pillow-12.1.0-pp311-pypy311_pp73-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:53d8b764726d3af1a138dd353116f774e3862ec7e3794e0c8781e30db0f35dfc"},
|
||||
{file = "pillow-12.1.0-pp311-pypy311_pp73-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:5da841d81b1a05ef940a8567da92decaa15bc4d7dedb540a8c219ad83d91808a"},
|
||||
{file = "pillow-12.1.0-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:75af0b4c229ac519b155028fa1be632d812a519abba9b46b20e50c6caa184f19"},
|
||||
{file = "pillow-12.1.0.tar.gz", hash = "sha256:5c5ae0a06e9ea030ab786b0251b32c7e4ce10e58d983c0d5c56029455180b5b9"},
|
||||
{file = "pillow-12.1.1-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:1f1625b72740fdda5d77b4def688eb8fd6490975d06b909fd19f13f391e077e0"},
|
||||
{file = "pillow-12.1.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:178aa072084bd88ec759052feca8e56cbb14a60b39322b99a049e58090479713"},
|
||||
{file = "pillow-12.1.1-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:b66e95d05ba806247aaa1561f080abc7975daf715c30780ff92a20e4ec546e1b"},
|
||||
{file = "pillow-12.1.1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:89c7e895002bbe49cdc5426150377cbbc04767d7547ed145473f496dfa40408b"},
|
||||
{file = "pillow-12.1.1-cp310-cp310-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3a5cbdcddad0af3da87cb16b60d23648bc3b51967eb07223e9fed77a82b457c4"},
|
||||
{file = "pillow-12.1.1-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9f51079765661884a486727f0729d29054242f74b46186026582b4e4769918e4"},
|
||||
{file = "pillow-12.1.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:99c1506ea77c11531d75e3a412832a13a71c7ebc8192ab9e4b2e355555920e3e"},
|
||||
{file = "pillow-12.1.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:36341d06738a9f66c8287cf8b876d24b18db9bd8740fa0672c74e259ad408cff"},
|
||||
{file = "pillow-12.1.1-cp310-cp310-win32.whl", hash = "sha256:6c52f062424c523d6c4db85518774cc3d50f5539dd6eed32b8f6229b26f24d40"},
|
||||
{file = "pillow-12.1.1-cp310-cp310-win_amd64.whl", hash = "sha256:c6008de247150668a705a6338156efb92334113421ceecf7438a12c9a12dab23"},
|
||||
{file = "pillow-12.1.1-cp310-cp310-win_arm64.whl", hash = "sha256:1a9b0ee305220b392e1124a764ee4265bd063e54a751a6b62eff69992f457fa9"},
|
||||
{file = "pillow-12.1.1-cp311-cp311-macosx_10_10_x86_64.whl", hash = "sha256:e879bb6cd5c73848ef3b2b48b8af9ff08c5b71ecda8048b7dd22d8a33f60be32"},
|
||||
{file = "pillow-12.1.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:365b10bb9417dd4498c0e3b128018c4a624dc11c7b97d8cc54effe3b096f4c38"},
|
||||
{file = "pillow-12.1.1-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:d4ce8e329c93845720cd2014659ca67eac35f6433fd3050393d85f3ecef0dad5"},
|
||||
{file = "pillow-12.1.1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:fc354a04072b765eccf2204f588a7a532c9511e8b9c7f900e1b64e3e33487090"},
|
||||
{file = "pillow-12.1.1-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:7e7976bf1910a8116b523b9f9f58bf410f3e8aa330cd9a2bb2953f9266ab49af"},
|
||||
{file = "pillow-12.1.1-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:597bd9c8419bc7c6af5604e55847789b69123bbe25d65cc6ad3012b4f3c98d8b"},
|
||||
{file = "pillow-12.1.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:2c1fc0f2ca5f96a3c8407e41cca26a16e46b21060fe6d5b099d2cb01412222f5"},
|
||||
{file = "pillow-12.1.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:578510d88c6229d735855e1f278aa305270438d36a05031dfaae5067cc8eb04d"},
|
||||
{file = "pillow-12.1.1-cp311-cp311-win32.whl", hash = "sha256:7311c0a0dcadb89b36b7025dfd8326ecfa36964e29913074d47382706e516a7c"},
|
||||
{file = "pillow-12.1.1-cp311-cp311-win_amd64.whl", hash = "sha256:fbfa2a7c10cc2623f412753cddf391c7f971c52ca40a3f65dc5039b2939e8563"},
|
||||
{file = "pillow-12.1.1-cp311-cp311-win_arm64.whl", hash = "sha256:b81b5e3511211631b3f672a595e3221252c90af017e399056d0faabb9538aa80"},
|
||||
{file = "pillow-12.1.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:ab323b787d6e18b3d91a72fc99b1a2c28651e4358749842b8f8dfacd28ef2052"},
|
||||
{file = "pillow-12.1.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:adebb5bee0f0af4909c30db0d890c773d1a92ffe83da908e2e9e720f8edf3984"},
|
||||
{file = "pillow-12.1.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:bb66b7cc26f50977108790e2456b7921e773f23db5630261102233eb355a3b79"},
|
||||
{file = "pillow-12.1.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:aee2810642b2898bb187ced9b349e95d2a7272930796e022efaf12e99dccd293"},
|
||||
{file = "pillow-12.1.1-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a0b1cd6232e2b618adcc54d9882e4e662a089d5768cd188f7c245b4c8c44a397"},
|
||||
{file = "pillow-12.1.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7aac39bcf8d4770d089588a2e1dd111cbaa42df5a94be3114222057d68336bd0"},
|
||||
{file = "pillow-12.1.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:ab174cd7d29a62dd139c44bf74b698039328f45cb03b4596c43473a46656b2f3"},
|
||||
{file = "pillow-12.1.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:339ffdcb7cbeaa08221cd401d517d4b1fe7a9ed5d400e4a8039719238620ca35"},
|
||||
{file = "pillow-12.1.1-cp312-cp312-win32.whl", hash = "sha256:5d1f9575a12bed9e9eedd9a4972834b08c97a352bd17955ccdebfeca5913fa0a"},
|
||||
{file = "pillow-12.1.1-cp312-cp312-win_amd64.whl", hash = "sha256:21329ec8c96c6e979cd0dfd29406c40c1d52521a90544463057d2aaa937d66a6"},
|
||||
{file = "pillow-12.1.1-cp312-cp312-win_arm64.whl", hash = "sha256:af9a332e572978f0218686636610555ae3defd1633597be015ed50289a03c523"},
|
||||
{file = "pillow-12.1.1-cp313-cp313-ios_13_0_arm64_iphoneos.whl", hash = "sha256:d242e8ac078781f1de88bf823d70c1a9b3c7950a44cdf4b7c012e22ccbcd8e4e"},
|
||||
{file = "pillow-12.1.1-cp313-cp313-ios_13_0_arm64_iphonesimulator.whl", hash = "sha256:02f84dfad02693676692746df05b89cf25597560db2857363a208e393429f5e9"},
|
||||
{file = "pillow-12.1.1-cp313-cp313-ios_13_0_x86_64_iphonesimulator.whl", hash = "sha256:e65498daf4b583091ccbb2556c7000abf0f3349fcd57ef7adc9a84a394ed29f6"},
|
||||
{file = "pillow-12.1.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:6c6db3b84c87d48d0088943bf33440e0c42370b99b1c2a7989216f7b42eede60"},
|
||||
{file = "pillow-12.1.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:8b7e5304e34942bf62e15184219a7b5ad4ff7f3bb5cca4d984f37df1a0e1aee2"},
|
||||
{file = "pillow-12.1.1-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:18e5bddd742a44b7e6b1e773ab5db102bd7a94c32555ba656e76d319d19c3850"},
|
||||
{file = "pillow-12.1.1-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:fc44ef1f3de4f45b50ccf9136999d71abb99dca7706bc75d222ed350b9fd2289"},
|
||||
{file = "pillow-12.1.1-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5a8eb7ed8d4198bccbd07058416eeec51686b498e784eda166395a23eb99138e"},
|
||||
{file = "pillow-12.1.1-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:47b94983da0c642de92ced1702c5b6c292a84bd3a8e1d1702ff923f183594717"},
|
||||
{file = "pillow-12.1.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:518a48c2aab7ce596d3bf79d0e275661b846e86e4d0e7dec34712c30fe07f02a"},
|
||||
{file = "pillow-12.1.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:a550ae29b95c6dc13cf69e2c9dc5747f814c54eeb2e32d683e5e93af56caa029"},
|
||||
{file = "pillow-12.1.1-cp313-cp313-win32.whl", hash = "sha256:a003d7422449f6d1e3a34e3dd4110c22148336918ddbfc6a32581cd54b2e0b2b"},
|
||||
{file = "pillow-12.1.1-cp313-cp313-win_amd64.whl", hash = "sha256:344cf1e3dab3be4b1fa08e449323d98a2a3f819ad20f4b22e77a0ede31f0faa1"},
|
||||
{file = "pillow-12.1.1-cp313-cp313-win_arm64.whl", hash = "sha256:5c0dd1636633e7e6a0afe7bf6a51a14992b7f8e60de5789018ebbdfae55b040a"},
|
||||
{file = "pillow-12.1.1-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:0330d233c1a0ead844fc097a7d16c0abff4c12e856c0b325f231820fee1f39da"},
|
||||
{file = "pillow-12.1.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:5dae5f21afb91322f2ff791895ddd8889e5e947ff59f71b46041c8ce6db790bc"},
|
||||
{file = "pillow-12.1.1-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:2e0c664be47252947d870ac0d327fea7e63985a08794758aa8af5b6cb6ec0c9c"},
|
||||
{file = "pillow-12.1.1-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:691ab2ac363b8217f7d31b3497108fb1f50faab2f75dfb03284ec2f217e87bf8"},
|
||||
{file = "pillow-12.1.1-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e9e8064fb1cc019296958595f6db671fba95209e3ceb0c4734c9baf97de04b20"},
|
||||
{file = "pillow-12.1.1-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:472a8d7ded663e6162dafdf20015c486a7009483ca671cece7a9279b512fcb13"},
|
||||
{file = "pillow-12.1.1-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:89b54027a766529136a06cfebeecb3a04900397a3590fd252160b888479517bf"},
|
||||
{file = "pillow-12.1.1-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:86172b0831b82ce4f7877f280055892b31179e1576aa00d0df3bb1bbf8c3e524"},
|
||||
{file = "pillow-12.1.1-cp313-cp313t-win32.whl", hash = "sha256:44ce27545b6efcf0fdbdceb31c9a5bdea9333e664cda58a7e674bb74608b3986"},
|
||||
{file = "pillow-12.1.1-cp313-cp313t-win_amd64.whl", hash = "sha256:a285e3eb7a5a45a2ff504e31f4a8d1b12ef62e84e5411c6804a42197c1cf586c"},
|
||||
{file = "pillow-12.1.1-cp313-cp313t-win_arm64.whl", hash = "sha256:cc7d296b5ea4d29e6570dabeaed58d31c3fea35a633a69679fb03d7664f43fb3"},
|
||||
{file = "pillow-12.1.1-cp314-cp314-ios_13_0_arm64_iphoneos.whl", hash = "sha256:417423db963cb4be8bac3fc1204fe61610f6abeed1580a7a2cbb2fbda20f12af"},
|
||||
{file = "pillow-12.1.1-cp314-cp314-ios_13_0_arm64_iphonesimulator.whl", hash = "sha256:b957b71c6b2387610f556a7eb0828afbe40b4a98036fc0d2acfa5a44a0c2036f"},
|
||||
{file = "pillow-12.1.1-cp314-cp314-ios_13_0_x86_64_iphonesimulator.whl", hash = "sha256:097690ba1f2efdeb165a20469d59d8bb03c55fb6621eb2041a060ae8ea3e9642"},
|
||||
{file = "pillow-12.1.1-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:2815a87ab27848db0321fb78c7f0b2c8649dee134b7f2b80c6a45c6831d75ccd"},
|
||||
{file = "pillow-12.1.1-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:f7ed2c6543bad5a7d5530eb9e78c53132f93dfa44a28492db88b41cdab885202"},
|
||||
{file = "pillow-12.1.1-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:652a2c9ccfb556235b2b501a3a7cf3742148cd22e04b5625c5fe057ea3e3191f"},
|
||||
{file = "pillow-12.1.1-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:d6e4571eedf43af33d0fc233a382a76e849badbccdf1ac438841308652a08e1f"},
|
||||
{file = "pillow-12.1.1-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b574c51cf7d5d62e9be37ba446224b59a2da26dc4c1bb2ecbe936a4fb1a7cb7f"},
|
||||
{file = "pillow-12.1.1-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a37691702ed687799de29a518d63d4682d9016932db66d4e90c345831b02fb4e"},
|
||||
{file = "pillow-12.1.1-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:f95c00d5d6700b2b890479664a06e754974848afaae5e21beb4d83c106923fd0"},
|
||||
{file = "pillow-12.1.1-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:559b38da23606e68681337ad74622c4dbba02254fc9cb4488a305dd5975c7eeb"},
|
||||
{file = "pillow-12.1.1-cp314-cp314-win32.whl", hash = "sha256:03edcc34d688572014ff223c125a3f77fb08091e4607e7745002fc214070b35f"},
|
||||
{file = "pillow-12.1.1-cp314-cp314-win_amd64.whl", hash = "sha256:50480dcd74fa63b8e78235957d302d98d98d82ccbfac4c7e12108ba9ecbdba15"},
|
||||
{file = "pillow-12.1.1-cp314-cp314-win_arm64.whl", hash = "sha256:5cb1785d97b0c3d1d1a16bc1d710c4a0049daefc4935f3a8f31f827f4d3d2e7f"},
|
||||
{file = "pillow-12.1.1-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:1f90cff8aa76835cba5769f0b3121a22bd4eb9e6884cfe338216e557a9a548b8"},
|
||||
{file = "pillow-12.1.1-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:1f1be78ce9466a7ee64bfda57bdba0f7cc499d9794d518b854816c41bf0aa4e9"},
|
||||
{file = "pillow-12.1.1-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:42fc1f4677106188ad9a55562bbade416f8b55456f522430fadab3cef7cd4e60"},
|
||||
{file = "pillow-12.1.1-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:98edb152429ab62a1818039744d8fbb3ccab98a7c29fc3d5fcef158f3f1f68b7"},
|
||||
{file = "pillow-12.1.1-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d470ab1178551dd17fdba0fef463359c41aaa613cdcd7ff8373f54be629f9f8f"},
|
||||
{file = "pillow-12.1.1-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6408a7b064595afcab0a49393a413732a35788f2a5092fdc6266952ed67de586"},
|
||||
{file = "pillow-12.1.1-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:5d8c41325b382c07799a3682c1c258469ea2ff97103c53717b7893862d0c98ce"},
|
||||
{file = "pillow-12.1.1-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:c7697918b5be27424e9ce568193efd13d925c4481dd364e43f5dff72d33e10f8"},
|
||||
{file = "pillow-12.1.1-cp314-cp314t-win32.whl", hash = "sha256:d2912fd8114fc5545aa3a4b5576512f64c55a03f3ebcca4c10194d593d43ea36"},
|
||||
{file = "pillow-12.1.1-cp314-cp314t-win_amd64.whl", hash = "sha256:4ceb838d4bd9dab43e06c363cab2eebf63846d6a4aeaea283bbdfd8f1a8ed58b"},
|
||||
{file = "pillow-12.1.1-cp314-cp314t-win_arm64.whl", hash = "sha256:7b03048319bfc6170e93bd60728a1af51d3dd7704935feb228c4d4faab35d334"},
|
||||
{file = "pillow-12.1.1-pp311-pypy311_pp73-macosx_10_15_x86_64.whl", hash = "sha256:600fd103672b925fe62ed08e0d874ea34d692474df6f4bf7ebe148b30f89f39f"},
|
||||
{file = "pillow-12.1.1-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:665e1b916b043cef294bc54d47bf02d87e13f769bc4bc5fa225a24b3a6c5aca9"},
|
||||
{file = "pillow-12.1.1-pp311-pypy311_pp73-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:495c302af3aad1ca67420ddd5c7bd480c8867ad173528767d906428057a11f0e"},
|
||||
{file = "pillow-12.1.1-pp311-pypy311_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:8fd420ef0c52c88b5a035a0886f367748c72147b2b8f384c9d12656678dfdfa9"},
|
||||
{file = "pillow-12.1.1-pp311-pypy311_pp73-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f975aa7ef9684ce7e2c18a3aa8f8e2106ce1e46b94ab713d156b2898811651d3"},
|
||||
{file = "pillow-12.1.1-pp311-pypy311_pp73-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8089c852a56c2966cf18835db62d9b34fef7ba74c726ad943928d494fa7f4735"},
|
||||
{file = "pillow-12.1.1-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:cb9bb857b2d057c6dfc72ac5f3b44836924ba15721882ef103cecb40d002d80e"},
|
||||
{file = "pillow-12.1.1.tar.gz", hash = "sha256:9ad8fa5937ab05218e2b6a4cff30295ad35afd2f83ac592e68c0d871bb0fdbc4"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
@@ -14917,4 +14917,4 @@ cffi = ["cffi (>=1.17,<2.0) ; platform_python_implementation != \"PyPy\" and pyt
|
||||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = "^3.12,<3.14"
|
||||
content-hash = "b5cbb1e25176845ac9f95650a802667e2f8be1a536e3e55a9269b5af5a42e3fc"
|
||||
content-hash = "1cad6029269393af67155e930c72eae2c03da02e4b3a3699823f6168c14a4218"
|
||||
|
||||
@@ -44,6 +44,12 @@ httpx = "*"
|
||||
scikit-learn = "^1.7.0"
|
||||
shap = "^0.48.0"
|
||||
google-cloud-recaptcha-enterprise = "^1.24.0"
|
||||
# Dependencies previously only in Dockerfile, now managed via poetry.lock
|
||||
prometheus-client = "^0.24.0"
|
||||
pandas = "^2.2.0"
|
||||
numpy = "^2.2.0"
|
||||
mcp = "^1.10.0"
|
||||
pillow = "^12.1.0"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
ruff = "0.8.3"
|
||||
|
||||
@@ -38,6 +38,12 @@ from server.routes.integration.linear import linear_integration_router # noqa:
|
||||
from server.routes.integration.slack import slack_router # noqa: E402
|
||||
from server.routes.mcp_patch import patch_mcp_server # noqa: E402
|
||||
from server.routes.oauth_device import oauth_device_router # noqa: E402
|
||||
from server.routes.org_invitations import ( # noqa: E402
|
||||
accept_router as invitation_accept_router,
|
||||
)
|
||||
from server.routes.org_invitations import ( # noqa: E402
|
||||
invitation_router,
|
||||
)
|
||||
from server.routes.orgs import org_router # noqa: E402
|
||||
from server.routes.readiness import readiness_router # noqa: E402
|
||||
from server.routes.user import saas_user_router # noqa: E402
|
||||
@@ -78,8 +84,15 @@ base_app.include_router(shared_event_router)
|
||||
|
||||
# Add GitHub integration router only if GITHUB_APP_CLIENT_ID is set
|
||||
if GITHUB_APP_CLIENT_ID:
|
||||
# Make sure that the callback processor is loaded here so we don't get an error when deserializing
|
||||
from integrations.github.github_v1_callback_processor import ( # noqa: E402
|
||||
GithubV1CallbackProcessor,
|
||||
)
|
||||
from server.routes.integration.github import github_integration_router # noqa: E402
|
||||
|
||||
# Bludgeon mypy into not deleting my import
|
||||
logger.debug(f'Loaded {GithubV1CallbackProcessor.__name__}')
|
||||
|
||||
base_app.include_router(
|
||||
github_integration_router
|
||||
) # Add additional route for integration webhook events
|
||||
@@ -92,6 +105,8 @@ if GITLAB_APP_CLIENT_ID:
|
||||
|
||||
base_app.include_router(api_keys_router) # Add routes for API key management
|
||||
base_app.include_router(org_router) # Add routes for organization management
|
||||
base_app.include_router(invitation_router) # Add routes for org invitation management
|
||||
base_app.include_router(invitation_accept_router) # Add route for accepting invitations
|
||||
add_github_proxy_routes(base_app)
|
||||
add_debugging_routes(
|
||||
base_app
|
||||
|
||||
306
enterprise/server/auth/authorization.py
Normal file
306
enterprise/server/auth/authorization.py
Normal file
@@ -0,0 +1,306 @@
|
||||
"""
|
||||
Permission-based authorization dependencies for API endpoints.
|
||||
|
||||
This module provides FastAPI dependencies for checking user permissions
|
||||
within organizations. It uses a permission-based authorization model where
|
||||
roles (owner, admin, member) are mapped to specific permissions.
|
||||
|
||||
Permissions are defined in the Permission enum and mapped to roles via
|
||||
ROLE_PERMISSIONS. This allows fine-grained access control while maintaining
|
||||
the familiar role-based hierarchy.
|
||||
|
||||
Usage:
|
||||
from server.auth.authorization import (
|
||||
Permission,
|
||||
require_permission,
|
||||
)
|
||||
|
||||
@router.get('/{org_id}/settings')
|
||||
async def get_settings(
|
||||
org_id: UUID,
|
||||
user_id: str = Depends(require_permission(Permission.VIEW_LLM_SETTINGS)),
|
||||
):
|
||||
# Only users with VIEW_LLM_SETTINGS permission can access
|
||||
...
|
||||
|
||||
@router.patch('/{org_id}/settings')
|
||||
async def update_settings(
|
||||
org_id: UUID,
|
||||
user_id: str = Depends(require_permission(Permission.EDIT_LLM_SETTINGS)),
|
||||
):
|
||||
# Only users with EDIT_LLM_SETTINGS permission can access
|
||||
...
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from storage.org_member_store import OrgMemberStore
|
||||
from storage.role import Role
|
||||
from storage.role_store import RoleStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.server.user_auth import get_user_id
|
||||
|
||||
|
||||
class Permission(str, Enum):
|
||||
"""Permissions that can be assigned to roles."""
|
||||
|
||||
# Secrets
|
||||
MANAGE_SECRETS = 'manage_secrets'
|
||||
|
||||
# MCP
|
||||
MANAGE_MCP = 'manage_mcp'
|
||||
|
||||
# Integrations
|
||||
MANAGE_INTEGRATIONS = 'manage_integrations'
|
||||
|
||||
# Application Settings
|
||||
MANAGE_APPLICATION_SETTINGS = 'manage_application_settings'
|
||||
|
||||
# API Keys
|
||||
MANAGE_API_KEYS = 'manage_api_keys'
|
||||
|
||||
# LLM Settings
|
||||
VIEW_LLM_SETTINGS = 'view_llm_settings'
|
||||
EDIT_LLM_SETTINGS = 'edit_llm_settings'
|
||||
|
||||
# Billing
|
||||
VIEW_BILLING = 'view_billing'
|
||||
ADD_CREDITS = 'add_credits'
|
||||
|
||||
# Organization Members
|
||||
INVITE_USER_TO_ORGANIZATION = 'invite_user_to_organization'
|
||||
CHANGE_USER_ROLE_MEMBER = 'change_user_role:member'
|
||||
CHANGE_USER_ROLE_ADMIN = 'change_user_role:admin'
|
||||
CHANGE_USER_ROLE_OWNER = 'change_user_role:owner'
|
||||
|
||||
# Organization Management
|
||||
VIEW_ORG_SETTINGS = 'view_org_settings'
|
||||
CHANGE_ORGANIZATION_NAME = 'change_organization_name'
|
||||
DELETE_ORGANIZATION = 'delete_organization'
|
||||
|
||||
# Temporary permissions until we finish the API updates.
|
||||
EDIT_ORG_SETTINGS = 'edit_org_settings'
|
||||
|
||||
|
||||
class RoleName(str, Enum):
|
||||
"""Role names used in the system."""
|
||||
|
||||
OWNER = 'owner'
|
||||
ADMIN = 'admin'
|
||||
MEMBER = 'member'
|
||||
|
||||
|
||||
# Permission mappings for each role
|
||||
ROLE_PERMISSIONS: dict[RoleName, frozenset[Permission]] = {
|
||||
RoleName.OWNER: frozenset(
|
||||
[
|
||||
# Settings (Full access)
|
||||
Permission.MANAGE_SECRETS,
|
||||
Permission.MANAGE_MCP,
|
||||
Permission.MANAGE_INTEGRATIONS,
|
||||
Permission.MANAGE_APPLICATION_SETTINGS,
|
||||
Permission.MANAGE_API_KEYS,
|
||||
Permission.VIEW_LLM_SETTINGS,
|
||||
Permission.EDIT_LLM_SETTINGS,
|
||||
Permission.VIEW_BILLING,
|
||||
Permission.ADD_CREDITS,
|
||||
# Organization Members
|
||||
Permission.INVITE_USER_TO_ORGANIZATION,
|
||||
Permission.CHANGE_USER_ROLE_MEMBER,
|
||||
Permission.CHANGE_USER_ROLE_ADMIN,
|
||||
Permission.CHANGE_USER_ROLE_OWNER,
|
||||
# Organization Management
|
||||
Permission.VIEW_ORG_SETTINGS,
|
||||
Permission.EDIT_ORG_SETTINGS,
|
||||
# Organization Management (Owner only)
|
||||
Permission.CHANGE_ORGANIZATION_NAME,
|
||||
Permission.DELETE_ORGANIZATION,
|
||||
]
|
||||
),
|
||||
RoleName.ADMIN: frozenset(
|
||||
[
|
||||
# Settings (Full access)
|
||||
Permission.MANAGE_SECRETS,
|
||||
Permission.MANAGE_MCP,
|
||||
Permission.MANAGE_INTEGRATIONS,
|
||||
Permission.MANAGE_APPLICATION_SETTINGS,
|
||||
Permission.MANAGE_API_KEYS,
|
||||
Permission.VIEW_LLM_SETTINGS,
|
||||
Permission.EDIT_LLM_SETTINGS,
|
||||
Permission.VIEW_BILLING,
|
||||
Permission.ADD_CREDITS,
|
||||
# Organization Members
|
||||
Permission.INVITE_USER_TO_ORGANIZATION,
|
||||
Permission.CHANGE_USER_ROLE_MEMBER,
|
||||
Permission.CHANGE_USER_ROLE_ADMIN,
|
||||
# Organization Management
|
||||
Permission.VIEW_ORG_SETTINGS,
|
||||
Permission.EDIT_ORG_SETTINGS,
|
||||
]
|
||||
),
|
||||
RoleName.MEMBER: frozenset(
|
||||
[
|
||||
# Settings (Full access)
|
||||
Permission.MANAGE_SECRETS,
|
||||
Permission.MANAGE_MCP,
|
||||
Permission.MANAGE_INTEGRATIONS,
|
||||
Permission.MANAGE_APPLICATION_SETTINGS,
|
||||
Permission.MANAGE_API_KEYS,
|
||||
# Settings (View only)
|
||||
Permission.VIEW_ORG_SETTINGS,
|
||||
Permission.VIEW_LLM_SETTINGS,
|
||||
]
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def get_user_org_role(user_id: str, org_id: UUID | None) -> Role | None:
|
||||
"""
|
||||
Get the user's role in an organization (synchronous version).
|
||||
|
||||
Args:
|
||||
user_id: User ID (string that will be converted to UUID)
|
||||
org_id: Organization ID, or None to use the user's current organization
|
||||
|
||||
Returns:
|
||||
Role object if user is a member, None otherwise
|
||||
"""
|
||||
from uuid import UUID as parse_uuid
|
||||
|
||||
if org_id is None:
|
||||
org_member = OrgMemberStore.get_org_member_for_current_org(parse_uuid(user_id))
|
||||
else:
|
||||
org_member = OrgMemberStore.get_org_member(org_id, parse_uuid(user_id))
|
||||
if not org_member:
|
||||
return None
|
||||
|
||||
return RoleStore.get_role_by_id(org_member.role_id)
|
||||
|
||||
|
||||
async def get_user_org_role_async(user_id: str, org_id: UUID | None) -> Role | None:
|
||||
"""
|
||||
Get the user's role in an organization (async version).
|
||||
|
||||
Args:
|
||||
user_id: User ID (string that will be converted to UUID)
|
||||
org_id: Organization ID, or None to use the user's current organization
|
||||
|
||||
Returns:
|
||||
Role object if user is a member, None otherwise
|
||||
"""
|
||||
from uuid import UUID as parse_uuid
|
||||
|
||||
if org_id is None:
|
||||
org_member = await OrgMemberStore.get_org_member_for_current_org_async(
|
||||
parse_uuid(user_id)
|
||||
)
|
||||
else:
|
||||
org_member = await OrgMemberStore.get_org_member_async(
|
||||
org_id, parse_uuid(user_id)
|
||||
)
|
||||
if not org_member:
|
||||
return None
|
||||
|
||||
return await RoleStore.get_role_by_id_async(org_member.role_id)
|
||||
|
||||
|
||||
def get_role_permissions(role_name: str) -> frozenset[Permission]:
|
||||
"""
|
||||
Get the permissions for a role.
|
||||
|
||||
Args:
|
||||
role_name: Name of the role
|
||||
|
||||
Returns:
|
||||
Set of permissions for the role
|
||||
"""
|
||||
try:
|
||||
role_enum = RoleName(role_name)
|
||||
return ROLE_PERMISSIONS.get(role_enum, frozenset())
|
||||
except ValueError:
|
||||
return frozenset()
|
||||
|
||||
|
||||
def has_permission(user_role: Role, permission: Permission) -> bool:
|
||||
"""
|
||||
Check if a role has a specific permission.
|
||||
|
||||
Args:
|
||||
user_role: User's Role object
|
||||
permission: Permission to check
|
||||
|
||||
Returns:
|
||||
True if the role has the permission
|
||||
"""
|
||||
permissions = get_role_permissions(user_role.name)
|
||||
return permission in permissions
|
||||
|
||||
|
||||
def require_permission(permission: Permission):
|
||||
"""
|
||||
Factory function that creates a dependency to require a specific permission.
|
||||
|
||||
This creates a FastAPI dependency that:
|
||||
1. Extracts org_id from the path parameter
|
||||
2. Gets the authenticated user_id
|
||||
3. Checks if the user has the required permission in the organization
|
||||
4. Returns the user_id if authorized, raises HTTPException otherwise
|
||||
|
||||
Usage:
|
||||
@router.get('/{org_id}/settings')
|
||||
async def get_settings(
|
||||
org_id: UUID,
|
||||
user_id: str = Depends(require_permission(Permission.VIEW_LLM_SETTINGS)),
|
||||
):
|
||||
...
|
||||
|
||||
Args:
|
||||
permission: The permission required to access the endpoint
|
||||
|
||||
Returns:
|
||||
Dependency function that validates permission and returns user_id
|
||||
"""
|
||||
|
||||
async def permission_checker(
|
||||
org_id: UUID | None = None,
|
||||
user_id: str | None = Depends(get_user_id),
|
||||
) -> str:
|
||||
if not user_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail='User not authenticated',
|
||||
)
|
||||
|
||||
user_role = await get_user_org_role_async(user_id, org_id)
|
||||
|
||||
if not user_role:
|
||||
logger.warning(
|
||||
'User not a member of organization',
|
||||
extra={'user_id': user_id, 'org_id': str(org_id)},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail='User is not a member of this organization',
|
||||
)
|
||||
|
||||
if not has_permission(user_role, permission):
|
||||
logger.warning(
|
||||
'Insufficient permissions',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
'user_role': user_role.name,
|
||||
'required_permission': permission.value,
|
||||
},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f'Requires {permission.value} permission',
|
||||
)
|
||||
|
||||
return user_id
|
||||
|
||||
return permission_checker
|
||||
@@ -15,6 +15,11 @@ IS_FEATURE_ENV = (
|
||||
) # Does not include the staging deployment
|
||||
IS_LOCAL_ENV = bool(HOST == 'localhost')
|
||||
|
||||
# Role name constants
|
||||
ROLE_OWNER = 'owner'
|
||||
ROLE_ADMIN = 'admin'
|
||||
ROLE_MEMBER = 'member'
|
||||
|
||||
# Deprecated - billing margins are now handled internally in litellm
|
||||
DEFAULT_BILLING_MARGIN = float(os.environ.get('DEFAULT_BILLING_MARGIN', '1.0'))
|
||||
|
||||
@@ -25,7 +30,9 @@ PERSONAL_WORKSPACE_VERSION_TO_MODEL = {
|
||||
2: 'claude-3-7-sonnet-20250219',
|
||||
3: 'claude-sonnet-4-20250514',
|
||||
4: 'claude-sonnet-4-20250514',
|
||||
5: 'claude-opus-4-5-20251101',
|
||||
# Minimax is now the default as it gives results close to claude in terms of quality
|
||||
# but at a much lower price
|
||||
5: 'minimax-m2.5',
|
||||
}
|
||||
|
||||
LITELLM_DEFAULT_MODEL = os.getenv('LITELLM_DEFAULT_MODEL')
|
||||
@@ -54,7 +61,6 @@ SUBSCRIPTION_PRICE_DATA = {
|
||||
},
|
||||
}
|
||||
|
||||
DEFAULT_INITIAL_BUDGET = float(os.environ.get('DEFAULT_INITIAL_BUDGET', '10'))
|
||||
STRIPE_API_KEY = os.environ.get('STRIPE_API_KEY', None)
|
||||
REQUIRE_PAYMENT = os.environ.get('REQUIRE_PAYMENT', '0') in ('1', 'true')
|
||||
|
||||
|
||||
@@ -51,6 +51,14 @@ def custom_json_serializer(obj, **kwargs):
|
||||
obj['stack_info'] = format_stack(stack_info)
|
||||
|
||||
result = json.dumps(obj, **kwargs)
|
||||
|
||||
# Swap out newlines to make things easier to read. This will produce
|
||||
# invalid json but means we can have similar logs in local development
|
||||
# to production, making things easier to correlate. Obviously,
|
||||
# LOG_JSON_FOR_CONSOLE should not be used in production environments.
|
||||
if LOG_JSON_FOR_CONSOLE:
|
||||
result = result.replace('\\n', '\n')
|
||||
|
||||
return result
|
||||
|
||||
|
||||
|
||||
@@ -160,6 +160,7 @@ class SetAuthCookieMiddleware:
|
||||
'/api/billing/customer-setup-success',
|
||||
'/api/billing/stripe-webhook',
|
||||
'/api/email/resend',
|
||||
'/api/organizations/members/invite/accept',
|
||||
'/oauth/device/authorize',
|
||||
'/oauth/device/token',
|
||||
'/api/v1/web-client/config',
|
||||
|
||||
@@ -2,10 +2,12 @@ from datetime import UTC, datetime
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from pydantic import BaseModel, field_validator
|
||||
from storage.api_key import ApiKey
|
||||
from storage.api_key_store import ApiKeyStore
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
from storage.org_member import OrgMember
|
||||
from storage.org_member_store import OrgMemberStore
|
||||
from storage.org_service import OrgService
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
@@ -52,7 +54,6 @@ async def store_byor_key_in_db(user_id: str, key: str) -> None:
|
||||
|
||||
async def generate_byor_key(user_id: str) -> str | None:
|
||||
"""Generate a new BYOR key for a user."""
|
||||
|
||||
try:
|
||||
user = await UserStore.get_user_by_id_async(user_id)
|
||||
if not user:
|
||||
@@ -135,9 +136,9 @@ class ApiKeyCreate(BaseModel):
|
||||
class ApiKeyResponse(BaseModel):
|
||||
id: int
|
||||
name: str | None = None
|
||||
created_at: str
|
||||
last_used_at: str | None = None
|
||||
expires_at: str | None = None
|
||||
created_at: datetime
|
||||
last_used_at: datetime | None = None
|
||||
expires_at: datetime | None = None
|
||||
|
||||
|
||||
class ApiKeyCreateResponse(ApiKeyResponse):
|
||||
@@ -148,8 +149,47 @@ class LlmApiKeyResponse(BaseModel):
|
||||
key: str | None
|
||||
|
||||
|
||||
@api_router.post('', response_model=ApiKeyCreateResponse)
|
||||
async def create_api_key(key_data: ApiKeyCreate, user_id: str = Depends(get_user_id)):
|
||||
class ByorPermittedResponse(BaseModel):
|
||||
permitted: bool
|
||||
|
||||
|
||||
class MessageResponse(BaseModel):
|
||||
message: str
|
||||
|
||||
|
||||
def api_key_to_response(key: ApiKey) -> ApiKeyResponse:
|
||||
"""Convert an ApiKey model to an ApiKeyResponse."""
|
||||
return ApiKeyResponse(
|
||||
id=key.id,
|
||||
name=key.name,
|
||||
created_at=key.created_at,
|
||||
last_used_at=key.last_used_at,
|
||||
expires_at=key.expires_at,
|
||||
)
|
||||
|
||||
|
||||
@api_router.get('/llm/byor/permitted', tags=['Keys'])
|
||||
async def check_byor_permitted(
|
||||
user_id: str = Depends(get_user_id),
|
||||
) -> ByorPermittedResponse:
|
||||
"""Check if BYOR key export is permitted for the user's current org."""
|
||||
try:
|
||||
permitted = await OrgService.check_byor_export_enabled(user_id)
|
||||
return ByorPermittedResponse(permitted=permitted)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'Error checking BYOR export permission', extra={'error': str(e)}
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to check BYOR export permission',
|
||||
)
|
||||
|
||||
|
||||
@api_router.post('', tags=['Keys'])
|
||||
async def create_api_key(
|
||||
key_data: ApiKeyCreate, user_id: str = Depends(get_user_id)
|
||||
) -> ApiKeyCreateResponse:
|
||||
"""Create a new API key for the authenticated user."""
|
||||
try:
|
||||
api_key = await api_key_store.create_api_key(
|
||||
@@ -158,48 +198,29 @@ async def create_api_key(key_data: ApiKeyCreate, user_id: str = Depends(get_user
|
||||
# Get the created key details
|
||||
keys = await api_key_store.list_api_keys(user_id)
|
||||
for key in keys:
|
||||
if key['name'] == key_data.name:
|
||||
return {
|
||||
**key,
|
||||
'key': api_key,
|
||||
'created_at': (
|
||||
key['created_at'].isoformat() if key['created_at'] else None
|
||||
),
|
||||
'last_used_at': (
|
||||
key['last_used_at'].isoformat() if key['last_used_at'] else None
|
||||
),
|
||||
'expires_at': (
|
||||
key['expires_at'].isoformat() if key['expires_at'] else None
|
||||
),
|
||||
}
|
||||
if key.name == key_data.name:
|
||||
return ApiKeyCreateResponse(
|
||||
id=key.id,
|
||||
name=key.name,
|
||||
key=api_key,
|
||||
created_at=key.created_at,
|
||||
last_used_at=key.last_used_at,
|
||||
expires_at=key.expires_at,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception('Error creating API key')
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to create API key',
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to create API key',
|
||||
)
|
||||
|
||||
|
||||
@api_router.get('', response_model=list[ApiKeyResponse])
|
||||
async def list_api_keys(user_id: str = Depends(get_user_id)):
|
||||
@api_router.get('', tags=['Keys'])
|
||||
async def list_api_keys(user_id: str = Depends(get_user_id)) -> list[ApiKeyResponse]:
|
||||
"""List all API keys for the authenticated user."""
|
||||
try:
|
||||
keys = await api_key_store.list_api_keys(user_id)
|
||||
return [
|
||||
{
|
||||
**key,
|
||||
'created_at': (
|
||||
key['created_at'].isoformat() if key['created_at'] else None
|
||||
),
|
||||
'last_used_at': (
|
||||
key['last_used_at'].isoformat() if key['last_used_at'] else None
|
||||
),
|
||||
'expires_at': (
|
||||
key['expires_at'].isoformat() if key['expires_at'] else None
|
||||
),
|
||||
}
|
||||
for key in keys
|
||||
]
|
||||
return [api_key_to_response(key) for key in keys]
|
||||
except Exception:
|
||||
logger.exception('Error listing API keys')
|
||||
raise HTTPException(
|
||||
@@ -208,8 +229,10 @@ async def list_api_keys(user_id: str = Depends(get_user_id)):
|
||||
)
|
||||
|
||||
|
||||
@api_router.delete('/{key_id}')
|
||||
async def delete_api_key(key_id: int, user_id: str = Depends(get_user_id)):
|
||||
@api_router.delete('/{key_id}', tags=['Keys'])
|
||||
async def delete_api_key(
|
||||
key_id: int, user_id: str = Depends(get_user_id)
|
||||
) -> MessageResponse:
|
||||
"""Delete an API key."""
|
||||
try:
|
||||
# First, verify the key belongs to the user
|
||||
@@ -217,7 +240,7 @@ async def delete_api_key(key_id: int, user_id: str = Depends(get_user_id)):
|
||||
key_to_delete = None
|
||||
|
||||
for key in keys:
|
||||
if key['id'] == key_id:
|
||||
if key.id == key_id:
|
||||
key_to_delete = key
|
||||
break
|
||||
|
||||
@@ -235,7 +258,7 @@ async def delete_api_key(key_id: int, user_id: str = Depends(get_user_id)):
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to delete API key',
|
||||
)
|
||||
return {'message': 'API key deleted successfully'}
|
||||
return MessageResponse(message='API key deleted successfully')
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception:
|
||||
@@ -246,22 +269,33 @@ async def delete_api_key(key_id: int, user_id: str = Depends(get_user_id)):
|
||||
)
|
||||
|
||||
|
||||
@api_router.get('/llm/byor', response_model=LlmApiKeyResponse)
|
||||
async def get_llm_api_key_for_byor(user_id: str = Depends(get_user_id)):
|
||||
@api_router.get('/llm/byor', tags=['Keys'])
|
||||
async def get_llm_api_key_for_byor(
|
||||
user_id: str = Depends(get_user_id),
|
||||
) -> LlmApiKeyResponse:
|
||||
"""Get the LLM API key for BYOR (Bring Your Own Runtime) for the authenticated user.
|
||||
|
||||
This endpoint validates that the key exists in LiteLLM before returning it.
|
||||
If validation fails, it automatically generates a new key to ensure users
|
||||
always receive a working key.
|
||||
|
||||
Returns 402 Payment Required if BYOR export is not enabled for the user's org.
|
||||
"""
|
||||
try:
|
||||
# Check if BYOR export is enabled for the user's org
|
||||
if not await OrgService.check_byor_export_enabled(user_id):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||
detail='BYOR key export is not enabled. Purchase credits to enable this feature.',
|
||||
)
|
||||
|
||||
# Check if the BYOR key exists in the database
|
||||
byor_key = await get_byor_key_from_db(user_id)
|
||||
if byor_key:
|
||||
# Validate that the key is actually registered in LiteLLM
|
||||
is_valid = await LiteLlmManager.verify_key(byor_key, user_id)
|
||||
if is_valid:
|
||||
return {'key': byor_key}
|
||||
return LlmApiKeyResponse(key=byor_key)
|
||||
else:
|
||||
# Key exists in DB but is invalid in LiteLLM - regenerate it
|
||||
logger.warning(
|
||||
@@ -286,7 +320,7 @@ async def get_llm_api_key_for_byor(user_id: str = Depends(get_user_id)):
|
||||
'Successfully generated and stored new BYOR key',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
return {'key': key}
|
||||
return LlmApiKeyResponse(key=key)
|
||||
else:
|
||||
logger.error(
|
||||
'Failed to generate new BYOR LLM API key',
|
||||
@@ -308,12 +342,24 @@ async def get_llm_api_key_for_byor(user_id: str = Depends(get_user_id)):
|
||||
)
|
||||
|
||||
|
||||
@api_router.post('/llm/byor/refresh', response_model=LlmApiKeyResponse)
|
||||
async def refresh_llm_api_key_for_byor(user_id: str = Depends(get_user_id)):
|
||||
"""Refresh the LLM API key for BYOR (Bring Your Own Runtime) for the authenticated user."""
|
||||
@api_router.post('/llm/byor/refresh', tags=['Keys'])
|
||||
async def refresh_llm_api_key_for_byor(
|
||||
user_id: str = Depends(get_user_id),
|
||||
) -> LlmApiKeyResponse:
|
||||
"""Refresh the LLM API key for BYOR (Bring Your Own Runtime) for the authenticated user.
|
||||
|
||||
Returns 402 Payment Required if BYOR export is not enabled for the user's org.
|
||||
"""
|
||||
logger.info('Starting BYOR LLM API key refresh', extra={'user_id': user_id})
|
||||
|
||||
try:
|
||||
# Check if BYOR export is enabled for the user's org
|
||||
if not await OrgService.check_byor_export_enabled(user_id):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||
detail='BYOR key export is not enabled. Purchase credits to enable this feature.',
|
||||
)
|
||||
|
||||
# Get the existing BYOR key from the database
|
||||
existing_byor_key = await get_byor_key_from_db(user_id)
|
||||
|
||||
@@ -352,7 +398,7 @@ async def refresh_llm_api_key_for_byor(user_id: str = Depends(get_user_id)):
|
||||
'BYOR LLM API key refresh completed successfully',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
return {'key': key}
|
||||
return LlmApiKeyResponse(key=key)
|
||||
except HTTPException as he:
|
||||
logger.error(
|
||||
'HTTP exception during BYOR LLM API key refresh',
|
||||
|
||||
@@ -5,6 +5,7 @@ import warnings
|
||||
from datetime import datetime, timezone
|
||||
from typing import Annotated, Literal, Optional
|
||||
from urllib.parse import quote
|
||||
from uuid import UUID as parse_uuid
|
||||
|
||||
import posthog
|
||||
from fastapi import APIRouter, Header, HTTPException, Request, Response, status
|
||||
@@ -26,6 +27,13 @@ from server.auth.token_manager import TokenManager
|
||||
from server.config import sign_token
|
||||
from server.constants import IS_FEATURE_ENV
|
||||
from server.routes.event_webhook import _get_session_api_key, _get_user_id
|
||||
from server.services.org_invitation_service import (
|
||||
EmailMismatchError,
|
||||
InvitationExpiredError,
|
||||
InvitationInvalidError,
|
||||
OrgInvitationService,
|
||||
UserAlreadyMemberError,
|
||||
)
|
||||
from storage.database import session_maker
|
||||
from storage.user import User
|
||||
from storage.user_store import UserStore
|
||||
@@ -104,22 +112,40 @@ def get_cookie_samesite(request: Request) -> Literal['lax', 'strict']:
|
||||
)
|
||||
|
||||
|
||||
def _extract_oauth_state(state: str | None) -> tuple[str, str | None, str | None]:
|
||||
"""Extract redirect URL, reCAPTCHA token, and invitation token from OAuth state.
|
||||
|
||||
Returns:
|
||||
Tuple of (redirect_url, recaptcha_token, invitation_token).
|
||||
Tokens may be None.
|
||||
"""
|
||||
if not state:
|
||||
return '', None, None
|
||||
|
||||
try:
|
||||
# Try to decode as JSON (new format with reCAPTCHA and/or invitation)
|
||||
state_data = json.loads(base64.urlsafe_b64decode(state.encode()).decode())
|
||||
return (
|
||||
state_data.get('redirect_url', ''),
|
||||
state_data.get('recaptcha_token'),
|
||||
state_data.get('invitation_token'),
|
||||
)
|
||||
except Exception:
|
||||
# Old format - state is just the redirect URL
|
||||
return state, None, None
|
||||
|
||||
|
||||
# Keep alias for backward compatibility
|
||||
def _extract_recaptcha_state(state: str | None) -> tuple[str, str | None]:
|
||||
"""Extract redirect URL and reCAPTCHA token from OAuth state.
|
||||
|
||||
Deprecated: Use _extract_oauth_state instead.
|
||||
|
||||
Returns:
|
||||
Tuple of (redirect_url, recaptcha_token). Token may be None.
|
||||
"""
|
||||
if not state:
|
||||
return '', None
|
||||
|
||||
try:
|
||||
# Try to decode as JSON (new format with reCAPTCHA)
|
||||
state_data = json.loads(base64.urlsafe_b64decode(state.encode()).decode())
|
||||
return state_data.get('redirect_url', ''), state_data.get('recaptcha_token')
|
||||
except Exception:
|
||||
# Old format - state is just the redirect URL
|
||||
return state, None
|
||||
redirect_url, recaptcha_token, _ = _extract_oauth_state(state)
|
||||
return redirect_url, recaptcha_token
|
||||
|
||||
|
||||
@oauth_router.get('/keycloak/callback')
|
||||
@@ -130,8 +156,8 @@ async def keycloak_callback(
|
||||
error: Optional[str] = None,
|
||||
error_description: Optional[str] = None,
|
||||
):
|
||||
# Extract redirect URL and reCAPTCHA token from state
|
||||
redirect_url, recaptcha_token = _extract_recaptcha_state(state)
|
||||
# Extract redirect URL, reCAPTCHA token, and invitation token from state
|
||||
redirect_url, recaptcha_token, invitation_token = _extract_oauth_state(state)
|
||||
if not redirect_url:
|
||||
redirect_url = str(request.base_url)
|
||||
|
||||
@@ -302,8 +328,13 @@ async def keycloak_callback(
|
||||
from server.routes.email import verify_email
|
||||
|
||||
await verify_email(request=request, user_id=user_id, is_auth_flow=True)
|
||||
redirect_url = f'{request.base_url}login?email_verification_required=true&user_id={user_id}'
|
||||
response = RedirectResponse(redirect_url, status_code=302)
|
||||
verification_redirect_url = f'{request.base_url}login?email_verification_required=true&user_id={user_id}'
|
||||
# Preserve invitation token so it can be included in OAuth state after verification
|
||||
if invitation_token:
|
||||
verification_redirect_url = (
|
||||
f'{verification_redirect_url}&invitation_token={invitation_token}'
|
||||
)
|
||||
response = RedirectResponse(verification_redirect_url, status_code=302)
|
||||
return response
|
||||
|
||||
# default to github IDP for now.
|
||||
@@ -381,14 +412,90 @@ async def keycloak_callback(
|
||||
)
|
||||
|
||||
has_accepted_tos = user.accepted_tos is not None
|
||||
|
||||
# Process invitation token if present (after email verification but before TOS)
|
||||
if invitation_token:
|
||||
try:
|
||||
logger.info(
|
||||
'Processing invitation token during auth callback',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'invitation_token_prefix': invitation_token[:10] + '...',
|
||||
},
|
||||
)
|
||||
|
||||
await OrgInvitationService.accept_invitation(
|
||||
invitation_token, parse_uuid(user_id)
|
||||
)
|
||||
logger.info(
|
||||
'Invitation accepted during auth callback',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
|
||||
except InvitationExpiredError:
|
||||
logger.warning(
|
||||
'Invitation expired during auth callback',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
# Add query param to redirect URL
|
||||
if '?' in redirect_url:
|
||||
redirect_url = f'{redirect_url}&invitation_expired=true'
|
||||
else:
|
||||
redirect_url = f'{redirect_url}?invitation_expired=true'
|
||||
|
||||
except InvitationInvalidError as e:
|
||||
logger.warning(
|
||||
'Invalid invitation during auth callback',
|
||||
extra={'user_id': user_id, 'error': str(e)},
|
||||
)
|
||||
if '?' in redirect_url:
|
||||
redirect_url = f'{redirect_url}&invitation_invalid=true'
|
||||
else:
|
||||
redirect_url = f'{redirect_url}?invitation_invalid=true'
|
||||
|
||||
except UserAlreadyMemberError:
|
||||
logger.info(
|
||||
'User already member during invitation acceptance',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
if '?' in redirect_url:
|
||||
redirect_url = f'{redirect_url}&already_member=true'
|
||||
else:
|
||||
redirect_url = f'{redirect_url}?already_member=true'
|
||||
|
||||
except EmailMismatchError as e:
|
||||
logger.warning(
|
||||
'Email mismatch during auth callback invitation acceptance',
|
||||
extra={'user_id': user_id, 'error': str(e)},
|
||||
)
|
||||
if '?' in redirect_url:
|
||||
redirect_url = f'{redirect_url}&email_mismatch=true'
|
||||
else:
|
||||
redirect_url = f'{redirect_url}?email_mismatch=true'
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'Unexpected error processing invitation during auth callback',
|
||||
extra={'user_id': user_id, 'error': str(e)},
|
||||
)
|
||||
# Don't fail the login if invitation processing fails
|
||||
if '?' in redirect_url:
|
||||
redirect_url = f'{redirect_url}&invitation_error=true'
|
||||
else:
|
||||
redirect_url = f'{redirect_url}?invitation_error=true'
|
||||
|
||||
# If the user hasn't accepted the TOS, redirect to the TOS page
|
||||
if not has_accepted_tos:
|
||||
encoded_redirect_url = quote(redirect_url, safe='')
|
||||
tos_redirect_url = (
|
||||
f'{request.base_url}accept-tos?redirect_url={encoded_redirect_url}'
|
||||
)
|
||||
if invitation_token:
|
||||
tos_redirect_url = f'{tos_redirect_url}&invitation_success=true'
|
||||
response = RedirectResponse(tos_redirect_url, status_code=302)
|
||||
else:
|
||||
if invitation_token:
|
||||
redirect_url = f'{redirect_url}&invitation_success=true'
|
||||
response = RedirectResponse(redirect_url, status_code=302)
|
||||
|
||||
set_response_cookie(
|
||||
|
||||
@@ -9,14 +9,13 @@ from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from fastapi.responses import RedirectResponse
|
||||
from integrations import stripe_service
|
||||
from pydantic import BaseModel
|
||||
from server.constants import (
|
||||
STRIPE_API_KEY,
|
||||
)
|
||||
from server.constants import STRIPE_API_KEY
|
||||
from server.logger import logger
|
||||
from starlette.datastructures import URL
|
||||
from storage.billing_session import BillingSession
|
||||
from storage.database import session_maker
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
from storage.org import Org
|
||||
from storage.subscription_access import SubscriptionAccess
|
||||
from storage.user_store import UserStore
|
||||
|
||||
@@ -94,9 +93,9 @@ async def get_credits(user_id: str = Depends(get_user_id)) -> GetCreditsResponse
|
||||
user_team_info = await LiteLlmManager.get_user_team_info(
|
||||
user_id, str(user.current_org_id)
|
||||
)
|
||||
# Update to use calculate_credits
|
||||
spend = user_team_info.get('spend', 0)
|
||||
max_budget = (user_team_info.get('litellm_budget_table') or {}).get('max_budget', 0)
|
||||
max_budget, spend = LiteLlmManager.get_budget_from_team_info(
|
||||
user_team_info, user_id, str(user.current_org_id)
|
||||
)
|
||||
credits = max(max_budget - spend, 0)
|
||||
return GetCreditsResponse(credits=Decimal('{:.2f}'.format(credits)))
|
||||
|
||||
@@ -148,7 +147,7 @@ async def create_customer_setup_session(
|
||||
customer=customer_info['customer_id'],
|
||||
mode='setup',
|
||||
payment_method_types=['card'],
|
||||
success_url=f'{base_url}?free_credits=success',
|
||||
success_url=f'{base_url}?setup=success',
|
||||
cancel_url=f'{base_url}',
|
||||
)
|
||||
return CreateBillingSessionResponse(redirect_url=checkout_session.url)
|
||||
@@ -250,15 +249,21 @@ async def success_callback(session_id: str, request: Request):
|
||||
)
|
||||
amount_subtotal = stripe_session.amount_subtotal or 0
|
||||
add_credits = amount_subtotal / 100
|
||||
max_budget = (user_team_info.get('litellm_budget_table') or {}).get(
|
||||
'max_budget', 0
|
||||
max_budget, _ = LiteLlmManager.get_budget_from_team_info(
|
||||
user_team_info, billing_session.user_id, str(user.current_org_id)
|
||||
)
|
||||
|
||||
org = session.query(Org).filter(Org.id == user.current_org_id).first()
|
||||
new_max_budget = max_budget + add_credits
|
||||
|
||||
await LiteLlmManager.update_team_and_users_budget(
|
||||
str(user.current_org_id), new_max_budget
|
||||
)
|
||||
|
||||
# Enable BYOR export for the org now that they've purchased credits
|
||||
if org:
|
||||
org.byor_export_enabled = True
|
||||
|
||||
# Store transaction status
|
||||
billing_session.status = 'completed'
|
||||
billing_session.price = add_credits
|
||||
|
||||
122
enterprise/server/routes/org_invitation_models.py
Normal file
122
enterprise/server/routes/org_invitation_models.py
Normal file
@@ -0,0 +1,122 @@
|
||||
"""
|
||||
Pydantic models and custom exceptions for organization invitations.
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel, EmailStr
|
||||
from storage.org_invitation import OrgInvitation
|
||||
from storage.role_store import RoleStore
|
||||
|
||||
|
||||
class InvitationError(Exception):
|
||||
"""Base exception for invitation errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class InvitationAlreadyExistsError(InvitationError):
|
||||
"""Raised when a pending invitation already exists for the email."""
|
||||
|
||||
def __init__(
|
||||
self, message: str = 'A pending invitation already exists for this email'
|
||||
):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class UserAlreadyMemberError(InvitationError):
|
||||
"""Raised when the user is already a member of the organization."""
|
||||
|
||||
def __init__(self, message: str = 'User is already a member of this organization'):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class InvitationExpiredError(InvitationError):
|
||||
"""Raised when the invitation has expired."""
|
||||
|
||||
def __init__(self, message: str = 'Invitation has expired'):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class InvitationInvalidError(InvitationError):
|
||||
"""Raised when the invitation is invalid or revoked."""
|
||||
|
||||
def __init__(self, message: str = 'Invitation is no longer valid'):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class InsufficientPermissionError(InvitationError):
|
||||
"""Raised when the user lacks permission to perform the action."""
|
||||
|
||||
def __init__(self, message: str = 'Insufficient permission'):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class EmailMismatchError(InvitationError):
|
||||
"""Raised when the accepting user's email doesn't match the invitation email."""
|
||||
|
||||
def __init__(self, message: str = 'Your email does not match the invitation'):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class InvitationCreate(BaseModel):
|
||||
"""Request model for creating invitation(s)."""
|
||||
|
||||
emails: list[EmailStr]
|
||||
role: str = 'member' # Default to member role
|
||||
|
||||
|
||||
class InvitationResponse(BaseModel):
|
||||
"""Response model for invitation details."""
|
||||
|
||||
id: int
|
||||
email: str
|
||||
role: str
|
||||
status: str
|
||||
created_at: str
|
||||
expires_at: str
|
||||
inviter_email: str | None = None
|
||||
|
||||
@classmethod
|
||||
def from_invitation(
|
||||
cls,
|
||||
invitation: OrgInvitation,
|
||||
inviter_email: str | None = None,
|
||||
) -> 'InvitationResponse':
|
||||
"""Create an InvitationResponse from an OrgInvitation entity.
|
||||
|
||||
Args:
|
||||
invitation: The invitation entity to convert
|
||||
inviter_email: Optional email of the inviter
|
||||
|
||||
Returns:
|
||||
InvitationResponse: The response model instance
|
||||
"""
|
||||
role_name = ''
|
||||
if invitation.role:
|
||||
role_name = invitation.role.name
|
||||
elif invitation.role_id:
|
||||
role = RoleStore.get_role_by_id(invitation.role_id)
|
||||
role_name = role.name if role else ''
|
||||
|
||||
return cls(
|
||||
id=invitation.id,
|
||||
email=invitation.email,
|
||||
role=role_name,
|
||||
status=invitation.status,
|
||||
created_at=invitation.created_at.isoformat(),
|
||||
expires_at=invitation.expires_at.isoformat(),
|
||||
inviter_email=inviter_email,
|
||||
)
|
||||
|
||||
|
||||
class InvitationFailure(BaseModel):
|
||||
"""Response model for a failed invitation."""
|
||||
|
||||
email: str
|
||||
error: str
|
||||
|
||||
|
||||
class BatchInvitationResponse(BaseModel):
|
||||
"""Response model for batch invitation creation."""
|
||||
|
||||
successful: list[InvitationResponse]
|
||||
failed: list[InvitationFailure]
|
||||
226
enterprise/server/routes/org_invitations.py
Normal file
226
enterprise/server/routes/org_invitations.py
Normal file
@@ -0,0 +1,226 @@
|
||||
"""API routes for organization invitations."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from fastapi.responses import RedirectResponse
|
||||
from server.routes.org_invitation_models import (
|
||||
BatchInvitationResponse,
|
||||
EmailMismatchError,
|
||||
InsufficientPermissionError,
|
||||
InvitationCreate,
|
||||
InvitationExpiredError,
|
||||
InvitationFailure,
|
||||
InvitationInvalidError,
|
||||
InvitationResponse,
|
||||
UserAlreadyMemberError,
|
||||
)
|
||||
from server.services.org_invitation_service import OrgInvitationService
|
||||
from server.utils.rate_limit_utils import check_rate_limit_by_user_id
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.server.user_auth import get_user_id
|
||||
from openhands.server.user_auth.user_auth import get_user_auth
|
||||
|
||||
# Router for invitation operations on an organization (requires org_id)
|
||||
invitation_router = APIRouter(prefix='/api/organizations/{org_id}/members')
|
||||
|
||||
# Router for accepting invitations (no org_id required)
|
||||
accept_router = APIRouter(prefix='/api/organizations/members/invite')
|
||||
|
||||
|
||||
@invitation_router.post(
|
||||
'/invite',
|
||||
response_model=BatchInvitationResponse,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
)
|
||||
async def create_invitation(
|
||||
org_id: UUID,
|
||||
invitation_data: InvitationCreate,
|
||||
request: Request,
|
||||
user_id: str = Depends(get_user_id),
|
||||
):
|
||||
"""Create organization invitations for multiple email addresses.
|
||||
|
||||
Sends emails to invitees with secure links to join the organization.
|
||||
Supports batch invitations - some may succeed while others fail.
|
||||
|
||||
Permission rules:
|
||||
- Only owners and admins can create invitations
|
||||
- Admins can only invite with 'member' or 'admin' role (not 'owner')
|
||||
- Owners can invite with any role
|
||||
|
||||
Args:
|
||||
org_id: Organization UUID
|
||||
invitation_data: Invitation details (emails array, role)
|
||||
request: FastAPI request
|
||||
user_id: Authenticated user ID (from dependency)
|
||||
|
||||
Returns:
|
||||
BatchInvitationResponse: Lists of successful and failed invitations
|
||||
|
||||
Raises:
|
||||
HTTPException 400: Invalid role or organization not found
|
||||
HTTPException 403: User lacks permission to invite
|
||||
HTTPException 429: Rate limit exceeded
|
||||
"""
|
||||
# Rate limit: 10 invitations per minute per user (6 seconds between requests)
|
||||
await check_rate_limit_by_user_id(
|
||||
request=request,
|
||||
key_prefix='org_invitation_create',
|
||||
user_id=user_id,
|
||||
user_rate_limit_seconds=6,
|
||||
)
|
||||
|
||||
try:
|
||||
successful, failed = await OrgInvitationService.create_invitations_batch(
|
||||
org_id=org_id,
|
||||
emails=[str(email) for email in invitation_data.emails],
|
||||
role_name=invitation_data.role,
|
||||
inviter_id=UUID(user_id),
|
||||
)
|
||||
|
||||
logger.info(
|
||||
'Batch organization invitations created',
|
||||
extra={
|
||||
'org_id': str(org_id),
|
||||
'total_emails': len(invitation_data.emails),
|
||||
'successful': len(successful),
|
||||
'failed': len(failed),
|
||||
'inviter_id': user_id,
|
||||
},
|
||||
)
|
||||
|
||||
return BatchInvitationResponse(
|
||||
successful=[InvitationResponse.from_invitation(inv) for inv in successful],
|
||||
failed=[
|
||||
InvitationFailure(email=email, error=error) for email, error in failed
|
||||
],
|
||||
)
|
||||
|
||||
except InsufficientPermissionError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=str(e),
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'Unexpected error creating batch invitations',
|
||||
extra={'org_id': str(org_id), 'error': str(e)},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='An unexpected error occurred',
|
||||
)
|
||||
|
||||
|
||||
@accept_router.get('/accept')
|
||||
async def accept_invitation(
|
||||
token: str,
|
||||
request: Request,
|
||||
):
|
||||
"""Accept an organization invitation via token.
|
||||
|
||||
This endpoint is accessed via the link in the invitation email.
|
||||
|
||||
Flow:
|
||||
1. If user is authenticated: Accept invitation directly and redirect to home
|
||||
2. If user is not authenticated: Redirect to login page with invitation token
|
||||
- Frontend stores token and includes it in OAuth state during login
|
||||
- After authentication, keycloak_callback processes the invitation
|
||||
|
||||
Args:
|
||||
token: The invitation token from the email link
|
||||
request: FastAPI request
|
||||
|
||||
Returns:
|
||||
RedirectResponse: Redirect to home page on success, or login page if not authenticated,
|
||||
or home page with error query params on failure
|
||||
"""
|
||||
base_url = str(request.base_url).rstrip('/')
|
||||
|
||||
# Try to get user_id from auth (may not be authenticated)
|
||||
user_id = None
|
||||
try:
|
||||
user_auth = await get_user_auth(request)
|
||||
if user_auth:
|
||||
user_id = await user_auth.get_user_id()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if not user_id:
|
||||
# User not authenticated - redirect to login page with invitation token
|
||||
# Frontend will store the token and include it in OAuth state during login
|
||||
logger.info(
|
||||
'Invitation accept: redirecting unauthenticated user to login',
|
||||
extra={'token_prefix': token[:10] + '...'},
|
||||
)
|
||||
login_url = f'{base_url}/login?invitation_token={token}'
|
||||
return RedirectResponse(login_url, status_code=302)
|
||||
|
||||
# User is authenticated - process the invitation directly
|
||||
try:
|
||||
await OrgInvitationService.accept_invitation(token, UUID(user_id))
|
||||
|
||||
logger.info(
|
||||
'Invitation accepted successfully',
|
||||
extra={
|
||||
'token_prefix': token[:10] + '...',
|
||||
'user_id': user_id,
|
||||
},
|
||||
)
|
||||
|
||||
# Redirect to home page on success
|
||||
return RedirectResponse(f'{base_url}/', status_code=302)
|
||||
|
||||
except InvitationExpiredError:
|
||||
logger.warning(
|
||||
'Invitation accept failed: expired',
|
||||
extra={'token_prefix': token[:10] + '...', 'user_id': user_id},
|
||||
)
|
||||
return RedirectResponse(f'{base_url}/?invitation_expired=true', status_code=302)
|
||||
|
||||
except InvitationInvalidError as e:
|
||||
logger.warning(
|
||||
'Invitation accept failed: invalid',
|
||||
extra={
|
||||
'token_prefix': token[:10] + '...',
|
||||
'user_id': user_id,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
return RedirectResponse(f'{base_url}/?invitation_invalid=true', status_code=302)
|
||||
|
||||
except UserAlreadyMemberError:
|
||||
logger.info(
|
||||
'Invitation accept: user already member',
|
||||
extra={'token_prefix': token[:10] + '...', 'user_id': user_id},
|
||||
)
|
||||
return RedirectResponse(f'{base_url}/?already_member=true', status_code=302)
|
||||
|
||||
except EmailMismatchError as e:
|
||||
logger.warning(
|
||||
'Invitation accept failed: email mismatch',
|
||||
extra={
|
||||
'token_prefix': token[:10] + '...',
|
||||
'user_id': user_id,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
return RedirectResponse(f'{base_url}/?email_mismatch=true', status_code=302)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'Unexpected error accepting invitation',
|
||||
extra={
|
||||
'token_prefix': token[:10] + '...',
|
||||
'user_id': user_id,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
return RedirectResponse(f'{base_url}/?invitation_error=true', status_code=302)
|
||||
@@ -1,7 +1,9 @@
|
||||
from typing import Annotated
|
||||
|
||||
from pydantic import BaseModel, EmailStr, Field, StringConstraints
|
||||
from pydantic import BaseModel, EmailStr, Field, SecretStr, StringConstraints
|
||||
from storage.org import Org
|
||||
from storage.org_member import OrgMember
|
||||
from storage.role import Role
|
||||
|
||||
|
||||
class OrgCreationError(Exception):
|
||||
@@ -43,6 +45,16 @@ class OrgAuthorizationError(OrgDeletionError):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class OrphanedUserError(OrgDeletionError):
|
||||
"""Raised when deleting an org would leave users without any organization."""
|
||||
|
||||
def __init__(self, user_ids: list[str]):
|
||||
self.user_ids = user_ids
|
||||
super().__init__(
|
||||
f'Cannot delete organization: {len(user_ids)} user(s) would have no remaining organization'
|
||||
)
|
||||
|
||||
|
||||
class OrgNotFoundError(Exception):
|
||||
"""Raised when organization is not found or user doesn't have access."""
|
||||
|
||||
@@ -51,6 +63,61 @@ class OrgNotFoundError(Exception):
|
||||
super().__init__(f'Organization with id "{org_id}" not found')
|
||||
|
||||
|
||||
class OrgMemberNotFoundError(Exception):
|
||||
"""Raised when a member is not found in an organization."""
|
||||
|
||||
def __init__(self, org_id: str, user_id: str):
|
||||
self.org_id = org_id
|
||||
self.user_id = user_id
|
||||
super().__init__(f'Member "{user_id}" not found in organization "{org_id}"')
|
||||
|
||||
|
||||
class RoleNotFoundError(Exception):
|
||||
"""Raised when a role is not found."""
|
||||
|
||||
def __init__(self, role_id: int):
|
||||
self.role_id = role_id
|
||||
super().__init__(f'Role with id "{role_id}" not found')
|
||||
|
||||
|
||||
class InvalidRoleError(Exception):
|
||||
"""Raised when an invalid role name is specified."""
|
||||
|
||||
def __init__(self, role_name: str):
|
||||
self.role_name = role_name
|
||||
super().__init__(f'Invalid role: "{role_name}"')
|
||||
|
||||
|
||||
class InsufficientPermissionError(Exception):
|
||||
"""Raised when user lacks permission to perform an operation."""
|
||||
|
||||
def __init__(self, message: str = 'Insufficient permission'):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class CannotModifySelfError(Exception):
|
||||
"""Raised when user attempts to modify their own membership."""
|
||||
|
||||
def __init__(self, action: str = 'modify'):
|
||||
self.action = action
|
||||
super().__init__(f'Cannot {action} your own membership')
|
||||
|
||||
|
||||
class LastOwnerError(Exception):
|
||||
"""Raised when attempting to remove or demote the last owner."""
|
||||
|
||||
def __init__(self, action: str = 'remove'):
|
||||
self.action = action
|
||||
super().__init__(f'Cannot {action} the last owner of an organization')
|
||||
|
||||
|
||||
class MemberUpdateError(Exception):
|
||||
"""Raised when member update operation fails."""
|
||||
|
||||
def __init__(self, message: str = 'Failed to update member'):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class OrgCreate(BaseModel):
|
||||
"""Request model for creating a new organization."""
|
||||
|
||||
@@ -91,14 +158,18 @@ class OrgResponse(BaseModel):
|
||||
enable_solvability_analysis: bool | None = None
|
||||
v1_enabled: bool | None = None
|
||||
credits: float | None = None
|
||||
is_personal: bool = False
|
||||
|
||||
@classmethod
|
||||
def from_org(cls, org: Org, credits: float | None = None) -> 'OrgResponse':
|
||||
def from_org(
|
||||
cls, org: Org, credits: float | None = None, user_id: str | None = None
|
||||
) -> 'OrgResponse':
|
||||
"""Create an OrgResponse from an Org entity.
|
||||
|
||||
Args:
|
||||
org: The organization entity to convert
|
||||
credits: Optional credits value (defaults to None)
|
||||
user_id: Optional user ID to determine if org is personal (defaults to None)
|
||||
|
||||
Returns:
|
||||
OrgResponse: The response model instance
|
||||
@@ -134,6 +205,7 @@ class OrgResponse(BaseModel):
|
||||
enable_solvability_analysis=org.enable_solvability_analysis,
|
||||
v1_enabled=org.v1_enabled,
|
||||
credits=credits,
|
||||
is_personal=str(org.id) == user_id if user_id else False,
|
||||
)
|
||||
|
||||
|
||||
@@ -142,12 +214,17 @@ class OrgPage(BaseModel):
|
||||
|
||||
items: list[OrgResponse]
|
||||
next_page_id: str | None = None
|
||||
current_org_id: str | None = None
|
||||
|
||||
|
||||
class OrgUpdate(BaseModel):
|
||||
"""Request model for updating an organization."""
|
||||
|
||||
# Basic organization information (any authenticated user can update)
|
||||
name: Annotated[
|
||||
str | None,
|
||||
StringConstraints(strip_whitespace=True, min_length=1, max_length=255),
|
||||
] = None
|
||||
contact_name: str | None = None
|
||||
contact_email: EmailStr | None = None
|
||||
conversation_expiration: int | None = None
|
||||
@@ -173,3 +250,79 @@ class OrgUpdate(BaseModel):
|
||||
confirmation_mode: bool | None = None
|
||||
enable_default_condenser: bool | None = None
|
||||
condenser_max_size: int | None = Field(default=None, ge=20)
|
||||
|
||||
|
||||
class OrgMemberResponse(BaseModel):
|
||||
"""Response model for a single organization member."""
|
||||
|
||||
user_id: str
|
||||
email: str | None
|
||||
role_id: int
|
||||
role: str
|
||||
role_rank: int
|
||||
status: str | None
|
||||
|
||||
|
||||
class OrgMemberPage(BaseModel):
|
||||
"""Paginated response for organization members."""
|
||||
|
||||
items: list[OrgMemberResponse]
|
||||
next_page_id: str | None = None
|
||||
|
||||
|
||||
class OrgMemberUpdate(BaseModel):
|
||||
"""Request model for updating an organization member."""
|
||||
|
||||
role: str | None = None # Role name: 'owner', 'admin', or 'member'
|
||||
|
||||
|
||||
class MeResponse(BaseModel):
|
||||
"""Response model for the current user's membership in an organization."""
|
||||
|
||||
org_id: str
|
||||
user_id: str
|
||||
email: str
|
||||
role: str
|
||||
llm_api_key: str
|
||||
max_iterations: int | None = None
|
||||
llm_model: str | None = None
|
||||
llm_api_key_for_byor: str | None = None
|
||||
llm_base_url: str | None = None
|
||||
status: str | None = None
|
||||
|
||||
@staticmethod
|
||||
def _mask_key(secret: SecretStr | None) -> str:
|
||||
"""Mask an API key, showing only last 4 characters."""
|
||||
if secret is None:
|
||||
return ''
|
||||
raw = secret.get_secret_value()
|
||||
if not raw:
|
||||
return ''
|
||||
if len(raw) <= 4:
|
||||
return '****'
|
||||
return '****' + raw[-4:]
|
||||
|
||||
@classmethod
|
||||
def from_org_member(cls, member: OrgMember, role: Role, email: str) -> 'MeResponse':
|
||||
"""Create a MeResponse from an OrgMember, Role, and user email.
|
||||
|
||||
Args:
|
||||
member: The OrgMember entity
|
||||
role: The Role entity (provides role name)
|
||||
email: The user's email address
|
||||
|
||||
Returns:
|
||||
MeResponse with masked API keys
|
||||
"""
|
||||
return cls(
|
||||
org_id=str(member.org_id),
|
||||
user_id=str(member.user_id),
|
||||
email=email,
|
||||
role=role.name,
|
||||
llm_api_key=cls._mask_key(member.llm_api_key),
|
||||
max_iterations=member.max_iterations,
|
||||
llm_model=member.llm_model,
|
||||
llm_api_key_for_byor=cls._mask_key(member.llm_api_key_for_byor) or None,
|
||||
llm_base_url=member.llm_base_url,
|
||||
status=member.status,
|
||||
)
|
||||
|
||||
@@ -2,19 +2,37 @@ from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from server.auth.authorization import (
|
||||
Permission,
|
||||
require_permission,
|
||||
)
|
||||
from server.email_validation import get_admin_user_id
|
||||
from server.routes.org_models import (
|
||||
CannotModifySelfError,
|
||||
InsufficientPermissionError,
|
||||
InvalidRoleError,
|
||||
LastOwnerError,
|
||||
LiteLLMIntegrationError,
|
||||
MemberUpdateError,
|
||||
MeResponse,
|
||||
OrgAuthorizationError,
|
||||
OrgCreate,
|
||||
OrgDatabaseError,
|
||||
OrgMemberNotFoundError,
|
||||
OrgMemberPage,
|
||||
OrgMemberResponse,
|
||||
OrgMemberUpdate,
|
||||
OrgNameExistsError,
|
||||
OrgNotFoundError,
|
||||
OrgPage,
|
||||
OrgResponse,
|
||||
OrgUpdate,
|
||||
OrphanedUserError,
|
||||
RoleNotFoundError,
|
||||
)
|
||||
from server.services.org_member_service import OrgMemberService
|
||||
from storage.org_service import OrgService
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.server.user_auth import get_user_id
|
||||
@@ -61,6 +79,12 @@ async def list_user_orgs(
|
||||
)
|
||||
|
||||
try:
|
||||
# Fetch user to get current_org_id
|
||||
user = await UserStore.get_user_by_id_async(user_id)
|
||||
current_org_id = (
|
||||
str(user.current_org_id) if user and user.current_org_id else None
|
||||
)
|
||||
|
||||
# Fetch organizations from service layer
|
||||
orgs, next_page_id = OrgService.get_user_orgs_paginated(
|
||||
user_id=user_id,
|
||||
@@ -69,7 +93,9 @@ async def list_user_orgs(
|
||||
)
|
||||
|
||||
# Convert Org entities to OrgResponse objects
|
||||
org_responses = [OrgResponse.from_org(org, credits=None) for org in orgs]
|
||||
org_responses = [
|
||||
OrgResponse.from_org(org, credits=None, user_id=user_id) for org in orgs
|
||||
]
|
||||
|
||||
logger.info(
|
||||
'Successfully retrieved organizations',
|
||||
@@ -80,7 +106,11 @@ async def list_user_orgs(
|
||||
},
|
||||
)
|
||||
|
||||
return OrgPage(items=org_responses, next_page_id=next_page_id)
|
||||
return OrgPage(
|
||||
items=org_responses,
|
||||
next_page_id=next_page_id,
|
||||
current_org_id=current_org_id,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
@@ -136,7 +166,7 @@ async def create_org(
|
||||
# Retrieve credits from LiteLLM
|
||||
credits = await OrgService.get_org_credits(user_id, org.id)
|
||||
|
||||
return OrgResponse.from_org(org, credits=credits)
|
||||
return OrgResponse.from_org(org, credits=credits, user_id=user_id)
|
||||
except OrgNameExistsError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
@@ -174,23 +204,26 @@ async def create_org(
|
||||
@org_router.get('/{org_id}', response_model=OrgResponse, status_code=status.HTTP_200_OK)
|
||||
async def get_org(
|
||||
org_id: UUID,
|
||||
user_id: str = Depends(get_user_id),
|
||||
user_id: str = Depends(require_permission(Permission.VIEW_ORG_SETTINGS)),
|
||||
) -> OrgResponse:
|
||||
"""Get organization details by ID.
|
||||
|
||||
This endpoint allows authenticated users who are members of an organization
|
||||
to retrieve its details. Only members of the organization can access this endpoint.
|
||||
This endpoint retrieves details for a specific organization. Access requires
|
||||
the VIEW_ORG_SETTINGS permission, which is granted to all organization members
|
||||
(member, admin, and owner roles).
|
||||
|
||||
Args:
|
||||
org_id: Organization ID (UUID)
|
||||
user_id: Authenticated user ID (injected by dependency)
|
||||
user_id: Authenticated user ID (injected by require_permission dependency)
|
||||
|
||||
Returns:
|
||||
OrgResponse: The organization details
|
||||
|
||||
Raises:
|
||||
HTTPException: 401 if user is not authenticated
|
||||
HTTPException: 403 if user lacks VIEW_ORG_SETTINGS permission
|
||||
HTTPException: 404 if organization not found
|
||||
HTTPException: 422 if org_id is not a valid UUID (handled by FastAPI)
|
||||
HTTPException: 404 if organization not found or user is not a member
|
||||
HTTPException: 500 if retrieval fails
|
||||
"""
|
||||
logger.info(
|
||||
@@ -211,7 +244,7 @@ async def get_org(
|
||||
# Retrieve credits from LiteLLM
|
||||
credits = await OrgService.get_org_credits(user_id, org.id)
|
||||
|
||||
return OrgResponse.from_org(org, credits=credits)
|
||||
return OrgResponse.from_org(org, credits=credits, user_id=user_id)
|
||||
except OrgNotFoundError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
@@ -228,26 +261,86 @@ async def get_org(
|
||||
)
|
||||
|
||||
|
||||
@org_router.get('/{org_id}/me', response_model=MeResponse)
|
||||
async def get_me(
|
||||
org_id: UUID,
|
||||
user_id: str = Depends(get_user_id),
|
||||
) -> MeResponse:
|
||||
"""Get the current user's membership record for an organization.
|
||||
|
||||
Returns the authenticated user's role, status, email, and LLM override
|
||||
fields (with masked API keys) within the specified organization.
|
||||
|
||||
Args:
|
||||
org_id: Organization ID (UUID)
|
||||
user_id: Authenticated user ID (injected by dependency)
|
||||
|
||||
Returns:
|
||||
MeResponse: The user's membership data
|
||||
|
||||
Raises:
|
||||
HTTPException: 404 if user is not a member or org doesn't exist
|
||||
HTTPException: 500 if retrieval fails
|
||||
"""
|
||||
logger.info(
|
||||
'Retrieving current member details',
|
||||
extra={'user_id': user_id, 'org_id': str(org_id)},
|
||||
)
|
||||
|
||||
try:
|
||||
user_uuid = UUID(user_id)
|
||||
return OrgMemberService.get_me(org_id, user_uuid)
|
||||
|
||||
except OrgMemberNotFoundError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f'Organization with id "{org_id}" not found',
|
||||
)
|
||||
except RoleNotFoundError as e:
|
||||
logger.exception(
|
||||
'Role not found for org member',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
'role_id': e.role_id,
|
||||
},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='An unexpected error occurred',
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'Unexpected error retrieving member details',
|
||||
extra={'user_id': user_id, 'org_id': str(org_id), 'error': str(e)},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='An unexpected error occurred',
|
||||
)
|
||||
|
||||
|
||||
@org_router.delete('/{org_id}', status_code=status.HTTP_200_OK)
|
||||
async def delete_org(
|
||||
org_id: UUID,
|
||||
user_id: str = Depends(get_admin_user_id),
|
||||
user_id: str = Depends(require_permission(Permission.DELETE_ORGANIZATION)),
|
||||
) -> dict:
|
||||
"""Delete an organization.
|
||||
|
||||
This endpoint allows authenticated organization owners to delete their organization.
|
||||
All associated data including organization members, conversations, billing data,
|
||||
and external LiteLLM team resources will be permanently removed.
|
||||
This endpoint permanently deletes an organization and all associated data including
|
||||
organization members, conversations, billing data, and external LiteLLM team resources.
|
||||
Access requires the DELETE_ORGANIZATION permission, which is granted only to owners.
|
||||
|
||||
Args:
|
||||
org_id: Organization ID to delete
|
||||
user_id: Authenticated user ID (injected by dependency)
|
||||
org_id: Organization ID to delete (UUID)
|
||||
user_id: Authenticated user ID (injected by require_permission dependency)
|
||||
|
||||
Returns:
|
||||
dict: Confirmation message with deleted organization details
|
||||
|
||||
Raises:
|
||||
HTTPException: 403 if user is not the organization owner
|
||||
HTTPException: 401 if user is not authenticated
|
||||
HTTPException: 403 if user lacks DELETE_ORGANIZATION permission
|
||||
HTTPException: 404 if organization not found
|
||||
HTTPException: 500 if deletion fails
|
||||
"""
|
||||
@@ -303,6 +396,19 @@ async def delete_org(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=str(e),
|
||||
)
|
||||
except OrphanedUserError as e:
|
||||
logger.warning(
|
||||
'Cannot delete organization: users would be orphaned',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
'orphaned_users': e.user_ids,
|
||||
},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e),
|
||||
)
|
||||
except OrgDatabaseError as e:
|
||||
logger.error(
|
||||
'Database error during organization deletion',
|
||||
@@ -327,25 +433,26 @@ async def delete_org(
|
||||
async def update_org(
|
||||
org_id: UUID,
|
||||
update_data: OrgUpdate,
|
||||
user_id: str = Depends(get_user_id),
|
||||
user_id: str = Depends(require_permission(Permission.EDIT_ORG_SETTINGS)),
|
||||
) -> OrgResponse:
|
||||
"""Update an existing organization.
|
||||
|
||||
This endpoint allows authenticated users to update organization settings.
|
||||
LLM-related settings require admin or owner role in the organization.
|
||||
This endpoint updates organization settings. Access requires the EDIT_ORG_SETTINGS
|
||||
permission, which is granted to admin and owner roles.
|
||||
|
||||
Args:
|
||||
org_id: Organization ID to update (UUID validated by FastAPI)
|
||||
org_id: Organization ID to update (UUID)
|
||||
update_data: Organization update data
|
||||
user_id: Authenticated user ID (injected by dependency)
|
||||
user_id: Authenticated user ID (injected by require_permission dependency)
|
||||
|
||||
Returns:
|
||||
OrgResponse: The updated organization details
|
||||
|
||||
Raises:
|
||||
HTTPException: 400 if org_id is invalid UUID format (handled by FastAPI)
|
||||
HTTPException: 403 if user lacks permission for LLM settings
|
||||
HTTPException: 401 if user is not authenticated
|
||||
HTTPException: 403 if user lacks EDIT_ORG_SETTINGS permission
|
||||
HTTPException: 404 if organization not found
|
||||
HTTPException: 409 if organization name already exists
|
||||
HTTPException: 422 if validation errors occur (handled by FastAPI)
|
||||
HTTPException: 500 if update fails
|
||||
"""
|
||||
@@ -368,7 +475,7 @@ async def update_org(
|
||||
# Retrieve credits from LiteLLM (following same pattern as create endpoint)
|
||||
credits = await OrgService.get_org_credits(user_id, updated_org.id)
|
||||
|
||||
return OrgResponse.from_org(updated_org, credits=credits)
|
||||
return OrgResponse.from_org(updated_org, credits=credits, user_id=user_id)
|
||||
|
||||
except ValueError as e:
|
||||
# Organization not found
|
||||
@@ -376,6 +483,11 @@ async def update_org(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=str(e),
|
||||
)
|
||||
except OrgNameExistsError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail=str(e),
|
||||
)
|
||||
except PermissionError as e:
|
||||
# User lacks permission for LLM settings
|
||||
raise HTTPException(
|
||||
@@ -400,3 +512,314 @@ async def update_org(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='An unexpected error occurred',
|
||||
)
|
||||
|
||||
|
||||
@org_router.get('/{org_id}/members')
|
||||
async def get_org_members(
|
||||
org_id: UUID,
|
||||
page_id: Annotated[
|
||||
str | None,
|
||||
Query(title='Optional next_page_id from the previously returned page'),
|
||||
] = None,
|
||||
limit: Annotated[
|
||||
int,
|
||||
Query(
|
||||
title='The max number of results in the page',
|
||||
gt=0,
|
||||
lte=100,
|
||||
),
|
||||
] = 100,
|
||||
user_id: str = Depends(require_permission(Permission.VIEW_ORG_SETTINGS)),
|
||||
) -> OrgMemberPage:
|
||||
"""Get all members of an organization with cursor-based pagination.
|
||||
|
||||
This endpoint retrieves a paginated list of organization members. Access requires
|
||||
the VIEW_ORG_SETTINGS permission, which is granted to all organization members
|
||||
(member, admin, and owner roles).
|
||||
|
||||
Args:
|
||||
org_id: Organization ID (UUID)
|
||||
page_id: Optional page ID (offset) for pagination
|
||||
limit: Maximum number of members to return (1-100, default 100)
|
||||
user_id: Authenticated user ID (injected by require_permission dependency)
|
||||
|
||||
Returns:
|
||||
OrgMemberPage: Paginated list of organization members
|
||||
|
||||
Raises:
|
||||
HTTPException: 401 if user is not authenticated
|
||||
HTTPException: 403 if user lacks VIEW_ORG_SETTINGS permission
|
||||
HTTPException: 400 if org_id or page_id format is invalid
|
||||
HTTPException: 500 if retrieval fails
|
||||
"""
|
||||
try:
|
||||
success, error_code, data = await OrgMemberService.get_org_members(
|
||||
org_id=org_id,
|
||||
current_user_id=UUID(user_id),
|
||||
page_id=page_id,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
if not success:
|
||||
error_map = {
|
||||
'not_a_member': (
|
||||
status.HTTP_403_FORBIDDEN,
|
||||
'You are not a member of this organization',
|
||||
),
|
||||
'invalid_page_id': (
|
||||
status.HTTP_400_BAD_REQUEST,
|
||||
'Invalid page_id format',
|
||||
),
|
||||
}
|
||||
status_code, detail = error_map.get(
|
||||
error_code, (status.HTTP_500_INTERNAL_SERVER_ERROR, 'An error occurred')
|
||||
)
|
||||
raise HTTPException(status_code=status_code, detail=detail)
|
||||
|
||||
if data is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to retrieve members',
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except ValueError:
|
||||
logger.exception('Invalid UUID format')
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail='Invalid organization ID format',
|
||||
)
|
||||
except Exception:
|
||||
logger.exception('Error retrieving organization members')
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to retrieve members',
|
||||
)
|
||||
|
||||
|
||||
@org_router.delete('/{org_id}/members/{user_id}')
|
||||
async def remove_org_member(
|
||||
org_id: UUID,
|
||||
user_id: str,
|
||||
current_user_id: str = Depends(get_user_id),
|
||||
):
|
||||
"""Remove a member from an organization.
|
||||
|
||||
Only owners and admins can remove members:
|
||||
- Owners can remove admins and regular users
|
||||
- Admins can only remove regular users
|
||||
|
||||
Users cannot remove themselves. The last owner cannot be removed.
|
||||
"""
|
||||
try:
|
||||
success, error = await OrgMemberService.remove_org_member(
|
||||
org_id=org_id,
|
||||
target_user_id=UUID(user_id),
|
||||
current_user_id=UUID(current_user_id),
|
||||
)
|
||||
|
||||
if not success:
|
||||
error_map = {
|
||||
'not_a_member': (
|
||||
status.HTTP_403_FORBIDDEN,
|
||||
'You are not a member of this organization',
|
||||
),
|
||||
'cannot_remove_self': (
|
||||
status.HTTP_403_FORBIDDEN,
|
||||
'Cannot remove yourself from an organization',
|
||||
),
|
||||
'member_not_found': (
|
||||
status.HTTP_404_NOT_FOUND,
|
||||
'Member not found in this organization',
|
||||
),
|
||||
'insufficient_permission': (
|
||||
status.HTTP_403_FORBIDDEN,
|
||||
'You do not have permission to remove this member',
|
||||
),
|
||||
'cannot_remove_last_owner': (
|
||||
status.HTTP_400_BAD_REQUEST,
|
||||
'Cannot remove the last owner of an organization',
|
||||
),
|
||||
'removal_failed': (
|
||||
status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
'Failed to remove member',
|
||||
),
|
||||
}
|
||||
status_code, detail = error_map.get(
|
||||
error, (status.HTTP_500_INTERNAL_SERVER_ERROR, 'An error occurred')
|
||||
)
|
||||
raise HTTPException(status_code=status_code, detail=detail)
|
||||
|
||||
return {'message': 'Member removed successfully'}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except ValueError:
|
||||
logger.exception('Invalid UUID format')
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail='Invalid organization or user ID format',
|
||||
)
|
||||
except Exception:
|
||||
logger.exception('Error removing organization member')
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to remove member',
|
||||
)
|
||||
|
||||
|
||||
@org_router.post(
|
||||
'/{org_id}/switch', response_model=OrgResponse, status_code=status.HTTP_200_OK
|
||||
)
|
||||
async def switch_org(
|
||||
org_id: UUID,
|
||||
user_id: str = Depends(get_user_id),
|
||||
) -> OrgResponse:
|
||||
"""Switch to a different organization.
|
||||
|
||||
This endpoint allows authenticated users to switch their current active
|
||||
organization. The user must be a member of the target organization.
|
||||
|
||||
Args:
|
||||
org_id: Organization ID to switch to (UUID)
|
||||
user_id: Authenticated user ID (injected by dependency)
|
||||
|
||||
Returns:
|
||||
OrgResponse: The organization details that was switched to
|
||||
|
||||
Raises:
|
||||
HTTPException: 422 if org_id is not a valid UUID (handled by FastAPI)
|
||||
HTTPException: 403 if user is not a member of the organization
|
||||
HTTPException: 404 if organization not found
|
||||
HTTPException: 500 if switch fails
|
||||
"""
|
||||
logger.info(
|
||||
'Switching organization',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
# Use service layer to switch organization with membership validation
|
||||
org = await OrgService.switch_org(
|
||||
user_id=user_id,
|
||||
org_id=org_id,
|
||||
)
|
||||
|
||||
# Retrieve credits from LiteLLM for the new current org
|
||||
credits = await OrgService.get_org_credits(user_id, org.id)
|
||||
|
||||
return OrgResponse.from_org(org, credits=credits, user_id=user_id)
|
||||
|
||||
except OrgNotFoundError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=str(e),
|
||||
)
|
||||
except OrgAuthorizationError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=str(e),
|
||||
)
|
||||
except OrgDatabaseError as e:
|
||||
logger.error(
|
||||
'Database operation failed during organization switch',
|
||||
extra={'user_id': user_id, 'org_id': str(org_id), 'error': str(e)},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to switch organization',
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'Unexpected error switching organization',
|
||||
extra={'user_id': user_id, 'org_id': str(org_id), 'error': str(e)},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='An unexpected error occurred',
|
||||
)
|
||||
|
||||
|
||||
@org_router.patch('/{org_id}/members/{user_id}', response_model=OrgMemberResponse)
|
||||
async def update_org_member(
|
||||
org_id: UUID,
|
||||
user_id: str,
|
||||
update_data: OrgMemberUpdate,
|
||||
current_user_id: str = Depends(get_user_id),
|
||||
) -> OrgMemberResponse:
|
||||
"""Update a member's role in an organization.
|
||||
|
||||
Permission rules:
|
||||
- Admins can change roles of regular members to Admin or Member
|
||||
- Admins cannot modify other Admins or Owners
|
||||
- Owners can change roles of Admins and Members to any role (Owner, Admin, Member)
|
||||
- Owners cannot modify other Owners
|
||||
|
||||
Members cannot modify their own role. The last owner cannot be demoted.
|
||||
"""
|
||||
try:
|
||||
return await OrgMemberService.update_org_member(
|
||||
org_id=org_id,
|
||||
target_user_id=UUID(user_id),
|
||||
current_user_id=UUID(current_user_id),
|
||||
update_data=update_data,
|
||||
)
|
||||
except OrgMemberNotFoundError as e:
|
||||
# Distinguish between requester not being a member vs target not found
|
||||
if str(current_user_id) in str(e):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail='You are not a member of this organization',
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail='Member not found in this organization',
|
||||
)
|
||||
except CannotModifySelfError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail='Cannot modify your own role',
|
||||
)
|
||||
except RoleNotFoundError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Role configuration error',
|
||||
)
|
||||
except InvalidRoleError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail='Invalid role specified',
|
||||
)
|
||||
except InsufficientPermissionError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail='You do not have permission to modify this member',
|
||||
)
|
||||
except LastOwnerError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail='Cannot demote the last owner of an organization',
|
||||
)
|
||||
except MemberUpdateError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to update member',
|
||||
)
|
||||
except ValueError:
|
||||
logger.exception('Invalid UUID format')
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail='Invalid organization or user ID format',
|
||||
)
|
||||
except Exception:
|
||||
logger.exception('Error updating organization member')
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to update member',
|
||||
)
|
||||
|
||||
@@ -1139,6 +1139,71 @@ class SaasNestedConversationManager(ConversationManager):
|
||||
}
|
||||
update_conversation_metadata(conversation_id, metadata_content)
|
||||
|
||||
async def list_files(self, sid: str, path: str | None = None) -> list[str]:
|
||||
"""List files in the workspace for a conversation.
|
||||
|
||||
Delegates to the nested container's list-files endpoint.
|
||||
|
||||
Args:
|
||||
sid: The session/conversation ID.
|
||||
path: Optional path to list files from. If None, lists from workspace root.
|
||||
|
||||
Returns:
|
||||
A list of file paths.
|
||||
|
||||
Raises:
|
||||
ValueError: If the conversation is not running.
|
||||
httpx.HTTPError: If there's an error communicating with the nested runtime.
|
||||
"""
|
||||
runtime = await self._get_runtime(sid)
|
||||
if runtime is None or runtime.get('status') != 'running':
|
||||
raise ValueError(f'Conversation {sid} is not running')
|
||||
|
||||
nested_url = self._get_nested_url_for_runtime(runtime['runtime_id'], sid)
|
||||
session_api_key = runtime.get('session_api_key')
|
||||
|
||||
return await self._fetch_list_files_from_nested(
|
||||
sid, nested_url, session_api_key, path
|
||||
)
|
||||
|
||||
async def select_file(self, sid: str, file: str) -> tuple[str | None, str | None]:
|
||||
"""Read a file from the workspace via nested container.
|
||||
|
||||
Raises:
|
||||
ValueError: If the conversation is not running.
|
||||
httpx.HTTPError: If there's an error communicating with the nested runtime.
|
||||
"""
|
||||
runtime = await self._get_runtime(sid)
|
||||
if runtime is None or runtime.get('status') != 'running':
|
||||
raise ValueError(f'Conversation {sid} is not running')
|
||||
|
||||
nested_url = self._get_nested_url_for_runtime(runtime['runtime_id'], sid)
|
||||
session_api_key = runtime.get('session_api_key')
|
||||
|
||||
return await self._fetch_select_file_from_nested(
|
||||
sid, nested_url, session_api_key, file
|
||||
)
|
||||
|
||||
async def upload_files(
|
||||
self, sid: str, files: list[tuple[str, bytes]]
|
||||
) -> tuple[list[str], list[dict[str, str]]]:
|
||||
"""Upload files to the workspace via nested container.
|
||||
|
||||
Raises:
|
||||
ValueError: If the conversation is not running.
|
||||
httpx.HTTPError: If there's an error communicating with the nested runtime.
|
||||
"""
|
||||
runtime = await self._get_runtime(sid)
|
||||
if runtime is None or runtime.get('status') != 'running':
|
||||
raise ValueError(f'Conversation {sid} is not running')
|
||||
|
||||
nested_url = self._get_nested_url_for_runtime(runtime['runtime_id'], sid)
|
||||
session_api_key = runtime.get('session_api_key')
|
||||
|
||||
return await self._fetch_upload_files_to_nested(
|
||||
sid, nested_url, session_api_key, files
|
||||
)
|
||||
|
||||
|
||||
def _last_updated_at_key(conversation: ConversationMetadata) -> float:
|
||||
last_updated_at = conversation.last_updated_at
|
||||
|
||||
131
enterprise/server/services/email_service.py
Normal file
131
enterprise/server/services/email_service.py
Normal file
@@ -0,0 +1,131 @@
|
||||
"""Email service for sending transactional emails via Resend."""
|
||||
|
||||
import os
|
||||
|
||||
try:
|
||||
import resend
|
||||
|
||||
RESEND_AVAILABLE = True
|
||||
except ImportError:
|
||||
RESEND_AVAILABLE = False
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
DEFAULT_FROM_EMAIL = 'OpenHands <no-reply@openhands.dev>'
|
||||
DEFAULT_WEB_HOST = 'https://app.all-hands.dev'
|
||||
|
||||
|
||||
class EmailService:
|
||||
"""Service for sending transactional emails."""
|
||||
|
||||
@staticmethod
|
||||
def _get_resend_client() -> bool:
|
||||
"""Initialize and return the Resend client.
|
||||
|
||||
Returns:
|
||||
bool: True if client is ready, False otherwise
|
||||
"""
|
||||
if not RESEND_AVAILABLE:
|
||||
logger.warning('Resend library not installed, skipping email')
|
||||
return False
|
||||
|
||||
resend_api_key = os.environ.get('RESEND_API_KEY')
|
||||
if not resend_api_key:
|
||||
logger.warning('RESEND_API_KEY not configured, skipping email')
|
||||
return False
|
||||
|
||||
resend.api_key = resend_api_key
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def send_invitation_email(
|
||||
to_email: str,
|
||||
org_name: str,
|
||||
inviter_name: str,
|
||||
role_name: str,
|
||||
invitation_token: str,
|
||||
invitation_id: int,
|
||||
) -> None:
|
||||
"""Send an organization invitation email.
|
||||
|
||||
Args:
|
||||
to_email: Recipient's email address
|
||||
org_name: Name of the organization
|
||||
inviter_name: Display name of the person who sent the invite
|
||||
role_name: Role being offered (e.g., 'member', 'admin')
|
||||
invitation_token: The secure invitation token
|
||||
invitation_id: The invitation ID for logging
|
||||
"""
|
||||
if not EmailService._get_resend_client():
|
||||
return
|
||||
|
||||
# Build invitation URL
|
||||
web_host = os.environ.get('WEB_HOST', DEFAULT_WEB_HOST)
|
||||
invitation_url = f'{web_host}/api/organizations/members/invite/accept?token={invitation_token}'
|
||||
|
||||
from_email = os.environ.get('RESEND_FROM_EMAIL', DEFAULT_FROM_EMAIL)
|
||||
|
||||
params = {
|
||||
'from': from_email,
|
||||
'to': [to_email],
|
||||
'subject': f"You're invited to join {org_name} on OpenHands",
|
||||
'html': f"""
|
||||
<div style="font-family: Arial, sans-serif; max-width: 600px; margin: 0 auto;">
|
||||
<p>Hi,</p>
|
||||
|
||||
<p><strong>{inviter_name}</strong> has invited you to join <strong>{org_name}</strong> on OpenHands as a <strong>{role_name}</strong>.</p>
|
||||
|
||||
<p>Click the button below to accept the invitation:</p>
|
||||
|
||||
<p style="margin: 30px 0;">
|
||||
<a href="{invitation_url}"
|
||||
style="background-color: #c9b974; color: #0D0F11; padding: 8px 16px;
|
||||
text-decoration: none; border-radius: 8px; display: inline-block;
|
||||
font-size: 14px; font-weight: 600;">
|
||||
Accept Invitation
|
||||
</a>
|
||||
</p>
|
||||
|
||||
<p style="color: #666; font-size: 14px;">
|
||||
Or copy and paste this link into your browser:<br>
|
||||
<a href="{invitation_url}" style="color: #c9b974; font-weight: 600;">{invitation_url}</a>
|
||||
</p>
|
||||
|
||||
<p style="color: #666; font-size: 14px;">
|
||||
This invitation will expire in 7 days.
|
||||
</p>
|
||||
|
||||
<p style="color: #666; font-size: 14px;">
|
||||
If you weren't expecting this invitation, you can safely ignore this email.
|
||||
</p>
|
||||
|
||||
<hr style="border: none; border-top: 1px solid #eee; margin: 30px 0;">
|
||||
|
||||
<p style="color: #999; font-size: 12px;">
|
||||
Best,<br>
|
||||
The OpenHands Team
|
||||
</p>
|
||||
</div>
|
||||
""",
|
||||
}
|
||||
|
||||
try:
|
||||
response = resend.Emails.send(params)
|
||||
logger.info(
|
||||
'Invitation email sent',
|
||||
extra={
|
||||
'invitation_id': invitation_id,
|
||||
'email': to_email,
|
||||
'response_id': response.get('id') if response else None,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'Failed to send invitation email',
|
||||
extra={
|
||||
'invitation_id': invitation_id,
|
||||
'email': to_email,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
raise
|
||||
397
enterprise/server/services/org_invitation_service.py
Normal file
397
enterprise/server/services/org_invitation_service.py
Normal file
@@ -0,0 +1,397 @@
|
||||
"""Service for managing organization invitations."""
|
||||
|
||||
import asyncio
|
||||
from uuid import UUID
|
||||
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.constants import ROLE_ADMIN, ROLE_OWNER
|
||||
from server.routes.org_invitation_models import (
|
||||
EmailMismatchError,
|
||||
InsufficientPermissionError,
|
||||
InvitationExpiredError,
|
||||
InvitationInvalidError,
|
||||
UserAlreadyMemberError,
|
||||
)
|
||||
from server.services.email_service import EmailService
|
||||
from storage.org_invitation import OrgInvitation
|
||||
from storage.org_invitation_store import OrgInvitationStore
|
||||
from storage.org_member_store import OrgMemberStore
|
||||
from storage.org_service import OrgService
|
||||
from storage.org_store import OrgStore
|
||||
from storage.role_store import RoleStore
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
|
||||
class OrgInvitationService:
|
||||
"""Service for organization invitation operations."""
|
||||
|
||||
@staticmethod
|
||||
async def create_invitation(
|
||||
org_id: UUID,
|
||||
email: str,
|
||||
role_name: str,
|
||||
inviter_id: UUID,
|
||||
) -> OrgInvitation:
|
||||
"""Create a new organization invitation.
|
||||
|
||||
This method:
|
||||
1. Validates the organization exists
|
||||
2. Validates this is not a personal workspace
|
||||
3. Checks inviter has owner/admin role
|
||||
4. Validates role assignment permissions
|
||||
5. Checks if user is already a member
|
||||
6. Creates the invitation
|
||||
7. Sends the invitation email
|
||||
|
||||
Args:
|
||||
org_id: Organization UUID
|
||||
email: Invitee's email address
|
||||
role_name: Role to assign on acceptance (owner, admin, member)
|
||||
inviter_id: User ID of the person creating the invitation
|
||||
|
||||
Returns:
|
||||
OrgInvitation: The created invitation
|
||||
|
||||
Raises:
|
||||
ValueError: If organization or role not found
|
||||
InsufficientPermissionError: If inviter lacks permission
|
||||
UserAlreadyMemberError: If email is already a member
|
||||
InvitationAlreadyExistsError: If pending invitation exists
|
||||
"""
|
||||
email = email.lower().strip()
|
||||
|
||||
logger.info(
|
||||
'Creating organization invitation',
|
||||
extra={
|
||||
'org_id': str(org_id),
|
||||
'email': email,
|
||||
'role_name': role_name,
|
||||
'inviter_id': str(inviter_id),
|
||||
},
|
||||
)
|
||||
|
||||
# Step 1: Validate organization exists
|
||||
org = OrgStore.get_org_by_id(org_id)
|
||||
if not org:
|
||||
raise ValueError(f'Organization {org_id} not found')
|
||||
|
||||
# Step 2: Check this is not a personal workspace
|
||||
# A personal workspace has org_id matching the user's id
|
||||
if str(org_id) == str(inviter_id):
|
||||
raise InsufficientPermissionError(
|
||||
'Cannot invite users to a personal workspace'
|
||||
)
|
||||
|
||||
# Step 3: Check inviter is a member and has permission
|
||||
inviter_member = OrgMemberStore.get_org_member(org_id, inviter_id)
|
||||
if not inviter_member:
|
||||
raise InsufficientPermissionError(
|
||||
'You are not a member of this organization'
|
||||
)
|
||||
|
||||
inviter_role = RoleStore.get_role_by_id(inviter_member.role_id)
|
||||
if not inviter_role or inviter_role.name not in [ROLE_OWNER, ROLE_ADMIN]:
|
||||
raise InsufficientPermissionError('Only owners and admins can invite users')
|
||||
|
||||
# Step 4: Validate role assignment permissions
|
||||
role_name_lower = role_name.lower()
|
||||
if role_name_lower == ROLE_OWNER and inviter_role.name != ROLE_OWNER:
|
||||
raise InsufficientPermissionError('Only owners can invite with owner role')
|
||||
|
||||
# Get the target role
|
||||
target_role = RoleStore.get_role_by_name(role_name_lower)
|
||||
if not target_role:
|
||||
raise ValueError(f'Invalid role: {role_name}')
|
||||
|
||||
# Step 5: Check if user is already a member (by email)
|
||||
existing_user = await UserStore.get_user_by_email_async(email)
|
||||
if existing_user:
|
||||
existing_member = OrgMemberStore.get_org_member(org_id, existing_user.id)
|
||||
if existing_member:
|
||||
raise UserAlreadyMemberError(
|
||||
'User is already a member of this organization'
|
||||
)
|
||||
|
||||
# Step 6: Create the invitation
|
||||
invitation = await OrgInvitationStore.create_invitation(
|
||||
org_id=org_id,
|
||||
email=email,
|
||||
role_id=target_role.id,
|
||||
inviter_id=inviter_id,
|
||||
)
|
||||
|
||||
# Step 7: Send invitation email
|
||||
try:
|
||||
# Get inviter info for the email
|
||||
inviter_user = UserStore.get_user_by_id(str(inviter_member.user_id))
|
||||
inviter_name = 'A team member'
|
||||
if inviter_user and inviter_user.email:
|
||||
inviter_name = inviter_user.email.split('@')[0]
|
||||
|
||||
EmailService.send_invitation_email(
|
||||
to_email=email,
|
||||
org_name=org.name,
|
||||
inviter_name=inviter_name,
|
||||
role_name=target_role.name,
|
||||
invitation_token=invitation.token,
|
||||
invitation_id=invitation.id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'Failed to send invitation email',
|
||||
extra={
|
||||
'invitation_id': invitation.id,
|
||||
'email': email,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
# Don't fail the invitation creation if email fails
|
||||
# The user can still access via direct link
|
||||
|
||||
return invitation
|
||||
|
||||
@staticmethod
|
||||
async def create_invitations_batch(
|
||||
org_id: UUID,
|
||||
emails: list[str],
|
||||
role_name: str,
|
||||
inviter_id: UUID,
|
||||
) -> tuple[list[OrgInvitation], list[tuple[str, str]]]:
|
||||
"""Create multiple organization invitations concurrently.
|
||||
|
||||
Validates permissions once upfront, then creates invitations in parallel.
|
||||
|
||||
Args:
|
||||
org_id: Organization UUID
|
||||
emails: List of invitee email addresses
|
||||
role_name: Role to assign on acceptance (owner, admin, member)
|
||||
inviter_id: User ID of the person creating the invitations
|
||||
|
||||
Returns:
|
||||
Tuple of (successful_invitations, failed_emails_with_errors)
|
||||
|
||||
Raises:
|
||||
ValueError: If organization or role not found
|
||||
InsufficientPermissionError: If inviter lacks permission
|
||||
"""
|
||||
logger.info(
|
||||
'Creating batch organization invitations',
|
||||
extra={
|
||||
'org_id': str(org_id),
|
||||
'email_count': len(emails),
|
||||
'role_name': role_name,
|
||||
'inviter_id': str(inviter_id),
|
||||
},
|
||||
)
|
||||
|
||||
# Step 1: Validate permissions upfront (shared for all emails)
|
||||
org = OrgStore.get_org_by_id(org_id)
|
||||
if not org:
|
||||
raise ValueError(f'Organization {org_id} not found')
|
||||
|
||||
if str(org_id) == str(inviter_id):
|
||||
raise InsufficientPermissionError(
|
||||
'Cannot invite users to a personal workspace'
|
||||
)
|
||||
|
||||
inviter_member = OrgMemberStore.get_org_member(org_id, inviter_id)
|
||||
if not inviter_member:
|
||||
raise InsufficientPermissionError(
|
||||
'You are not a member of this organization'
|
||||
)
|
||||
|
||||
inviter_role = RoleStore.get_role_by_id(inviter_member.role_id)
|
||||
if not inviter_role or inviter_role.name not in [ROLE_OWNER, ROLE_ADMIN]:
|
||||
raise InsufficientPermissionError('Only owners and admins can invite users')
|
||||
|
||||
role_name_lower = role_name.lower()
|
||||
if role_name_lower == ROLE_OWNER and inviter_role.name != ROLE_OWNER:
|
||||
raise InsufficientPermissionError('Only owners can invite with owner role')
|
||||
|
||||
target_role = RoleStore.get_role_by_name(role_name_lower)
|
||||
if not target_role:
|
||||
raise ValueError(f'Invalid role: {role_name}')
|
||||
|
||||
# Step 2: Create invitations concurrently
|
||||
async def create_single(
|
||||
email: str,
|
||||
) -> tuple[str, OrgInvitation | None, str | None]:
|
||||
"""Create single invitation, return (email, invitation, error)."""
|
||||
try:
|
||||
invitation = await OrgInvitationService.create_invitation(
|
||||
org_id=org_id,
|
||||
email=email,
|
||||
role_name=role_name,
|
||||
inviter_id=inviter_id,
|
||||
)
|
||||
return (email, invitation, None)
|
||||
except (UserAlreadyMemberError, ValueError) as e:
|
||||
return (email, None, str(e))
|
||||
|
||||
results = await asyncio.gather(*[create_single(email) for email in emails])
|
||||
|
||||
# Step 3: Separate successes and failures
|
||||
successful: list[OrgInvitation] = []
|
||||
failed: list[tuple[str, str]] = []
|
||||
for email, invitation, error in results:
|
||||
if invitation:
|
||||
successful.append(invitation)
|
||||
elif error:
|
||||
failed.append((email, error))
|
||||
|
||||
logger.info(
|
||||
'Batch invitation creation completed',
|
||||
extra={
|
||||
'org_id': str(org_id),
|
||||
'successful': len(successful),
|
||||
'failed': len(failed),
|
||||
},
|
||||
)
|
||||
|
||||
return successful, failed
|
||||
|
||||
@staticmethod
|
||||
async def accept_invitation(token: str, user_id: UUID) -> OrgInvitation:
|
||||
"""Accept an organization invitation.
|
||||
|
||||
This method:
|
||||
1. Validates the token and invitation status
|
||||
2. Checks expiration
|
||||
3. Verifies user is not already a member
|
||||
4. Creates LiteLLM integration
|
||||
5. Adds user to the organization
|
||||
6. Marks invitation as accepted
|
||||
|
||||
Args:
|
||||
token: The invitation token
|
||||
user_id: The user accepting the invitation
|
||||
|
||||
Returns:
|
||||
OrgInvitation: The accepted invitation
|
||||
|
||||
Raises:
|
||||
InvitationInvalidError: If token is invalid or invitation not pending
|
||||
InvitationExpiredError: If invitation has expired
|
||||
UserAlreadyMemberError: If user is already a member
|
||||
"""
|
||||
logger.info(
|
||||
'Accepting organization invitation',
|
||||
extra={
|
||||
'token_prefix': token[:10] + '...' if len(token) > 10 else token,
|
||||
'user_id': str(user_id),
|
||||
},
|
||||
)
|
||||
|
||||
# Step 1: Get and validate invitation
|
||||
invitation = await OrgInvitationStore.get_invitation_by_token(token)
|
||||
|
||||
if not invitation:
|
||||
raise InvitationInvalidError('Invalid invitation token')
|
||||
|
||||
if invitation.status != OrgInvitation.STATUS_PENDING:
|
||||
if invitation.status == OrgInvitation.STATUS_ACCEPTED:
|
||||
raise InvitationInvalidError('Invitation has already been accepted')
|
||||
elif invitation.status == OrgInvitation.STATUS_REVOKED:
|
||||
raise InvitationInvalidError('Invitation has been revoked')
|
||||
else:
|
||||
raise InvitationInvalidError('Invitation is no longer valid')
|
||||
|
||||
# Step 2: Check expiration
|
||||
if OrgInvitationStore.is_token_expired(invitation):
|
||||
await OrgInvitationStore.update_invitation_status(
|
||||
invitation.id, OrgInvitation.STATUS_EXPIRED
|
||||
)
|
||||
raise InvitationExpiredError('Invitation has expired')
|
||||
|
||||
# Step 2.5: Verify user email matches invitation email
|
||||
user = await UserStore.get_user_by_id_async(str(user_id))
|
||||
if not user:
|
||||
raise InvitationInvalidError('User not found')
|
||||
|
||||
user_email = user.email
|
||||
# Fallback: fetch email from Keycloak if not in database (for existing users)
|
||||
if not user_email:
|
||||
token_manager = TokenManager()
|
||||
user_info = await token_manager.get_user_info_from_user_id(str(user_id))
|
||||
user_email = user_info.get('email') if user_info else None
|
||||
|
||||
if not user_email:
|
||||
raise EmailMismatchError('Your account does not have an email address')
|
||||
|
||||
user_email = user_email.lower().strip()
|
||||
invitation_email = invitation.email.lower().strip()
|
||||
|
||||
if user_email != invitation_email:
|
||||
logger.warning(
|
||||
'Email mismatch during invitation acceptance',
|
||||
extra={
|
||||
'user_id': str(user_id),
|
||||
'user_email': user_email,
|
||||
'invitation_email': invitation_email,
|
||||
'invitation_id': invitation.id,
|
||||
},
|
||||
)
|
||||
raise EmailMismatchError()
|
||||
|
||||
# Step 3: Check if user is already a member
|
||||
existing_member = OrgMemberStore.get_org_member(invitation.org_id, user_id)
|
||||
if existing_member:
|
||||
raise UserAlreadyMemberError(
|
||||
'You are already a member of this organization'
|
||||
)
|
||||
|
||||
# Step 4: Create LiteLLM integration for the user in the new org
|
||||
try:
|
||||
settings = await OrgService.create_litellm_integration(
|
||||
invitation.org_id, str(user_id)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'Failed to create LiteLLM integration for invitation acceptance',
|
||||
extra={
|
||||
'invitation_id': invitation.id,
|
||||
'user_id': str(user_id),
|
||||
'org_id': str(invitation.org_id),
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
raise InvitationInvalidError(
|
||||
'Failed to set up organization access. Please try again.'
|
||||
)
|
||||
|
||||
# Step 5: Add user to organization
|
||||
from storage.org_member_store import OrgMemberStore as OMS
|
||||
|
||||
org_member_kwargs = OMS.get_kwargs_from_settings(settings)
|
||||
# Don't override with org defaults - use invitation-specified role
|
||||
org_member_kwargs.pop('llm_model', None)
|
||||
org_member_kwargs.pop('llm_base_url', None)
|
||||
|
||||
OrgMemberStore.add_user_to_org(
|
||||
org_id=invitation.org_id,
|
||||
user_id=user_id,
|
||||
role_id=invitation.role_id,
|
||||
llm_api_key=settings.llm_api_key,
|
||||
status='active',
|
||||
)
|
||||
|
||||
# Step 6: Mark invitation as accepted
|
||||
updated_invitation = await OrgInvitationStore.update_invitation_status(
|
||||
invitation.id,
|
||||
OrgInvitation.STATUS_ACCEPTED,
|
||||
accepted_by_user_id=user_id,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
'Organization invitation accepted',
|
||||
extra={
|
||||
'invitation_id': invitation.id,
|
||||
'user_id': str(user_id),
|
||||
'org_id': str(invitation.org_id),
|
||||
'role_id': invitation.role_id,
|
||||
},
|
||||
)
|
||||
|
||||
return updated_invitation
|
||||
342
enterprise/server/services/org_member_service.py
Normal file
342
enterprise/server/services/org_member_service.py
Normal file
@@ -0,0 +1,342 @@
|
||||
"""Service for managing organization members."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from server.constants import ROLE_ADMIN, ROLE_MEMBER, ROLE_OWNER
|
||||
from server.routes.org_models import (
|
||||
CannotModifySelfError,
|
||||
InsufficientPermissionError,
|
||||
InvalidRoleError,
|
||||
LastOwnerError,
|
||||
MemberUpdateError,
|
||||
MeResponse,
|
||||
OrgMemberNotFoundError,
|
||||
OrgMemberPage,
|
||||
OrgMemberResponse,
|
||||
OrgMemberUpdate,
|
||||
RoleNotFoundError,
|
||||
)
|
||||
from storage.org_member_store import OrgMemberStore
|
||||
from storage.role_store import RoleStore
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
|
||||
|
||||
class OrgMemberService:
|
||||
"""Service for organization member operations."""
|
||||
|
||||
@staticmethod
|
||||
def get_me(org_id: UUID, user_id: UUID) -> MeResponse:
|
||||
"""Get the current user's membership record for an organization.
|
||||
|
||||
Retrieves the authenticated user's role, status, email, and LLM override
|
||||
fields (with masked API keys) within the specified organization.
|
||||
|
||||
Args:
|
||||
org_id: Organization ID (UUID)
|
||||
user_id: User ID (UUID)
|
||||
|
||||
Returns:
|
||||
MeResponse: The user's membership data with masked API keys
|
||||
|
||||
Raises:
|
||||
OrgMemberNotFoundError: If user is not a member of the organization
|
||||
RoleNotFoundError: If the role associated with the member is not found
|
||||
"""
|
||||
# Look up the user's membership in this org
|
||||
org_member = OrgMemberStore.get_org_member(org_id, user_id)
|
||||
if org_member is None:
|
||||
raise OrgMemberNotFoundError(str(org_id), str(user_id))
|
||||
|
||||
# Resolve role name from role_id
|
||||
role = RoleStore.get_role_by_id(org_member.role_id)
|
||||
if role is None:
|
||||
raise RoleNotFoundError(org_member.role_id)
|
||||
|
||||
# Get user email
|
||||
user = UserStore.get_user_by_id(str(user_id))
|
||||
email = user.email if user and user.email else ''
|
||||
|
||||
return MeResponse.from_org_member(org_member, role, email)
|
||||
|
||||
@staticmethod
|
||||
async def get_org_members(
|
||||
org_id: UUID,
|
||||
current_user_id: UUID,
|
||||
page_id: str | None = None,
|
||||
limit: int = 100,
|
||||
) -> tuple[bool, str | None, OrgMemberPage | None]:
|
||||
"""Get organization members with authorization check.
|
||||
|
||||
Returns:
|
||||
Tuple of (success, error_code, data). If success is True, error_code is None.
|
||||
"""
|
||||
# Verify current user is a member of the organization
|
||||
requester_membership = OrgMemberStore.get_org_member(org_id, current_user_id)
|
||||
if not requester_membership:
|
||||
return False, 'not_a_member', None
|
||||
|
||||
# Parse page_id to get offset (page_id is offset encoded as string)
|
||||
offset = 0
|
||||
if page_id is not None:
|
||||
try:
|
||||
offset = int(page_id)
|
||||
if offset < 0:
|
||||
return False, 'invalid_page_id', None
|
||||
except ValueError:
|
||||
return False, 'invalid_page_id', None
|
||||
|
||||
# Call store to get paginated members
|
||||
members, has_more = await OrgMemberStore.get_org_members_paginated(
|
||||
org_id=org_id, offset=offset, limit=limit
|
||||
)
|
||||
|
||||
# Transform data to response format
|
||||
items = []
|
||||
for member in members:
|
||||
# Access user and role relationships (eagerly loaded)
|
||||
user = member.user
|
||||
role = member.role
|
||||
|
||||
items.append(
|
||||
OrgMemberResponse(
|
||||
user_id=str(member.user_id),
|
||||
email=user.email if user else None,
|
||||
role_id=member.role_id,
|
||||
role=role.name if role else '',
|
||||
role_rank=role.rank if role else 0,
|
||||
status=member.status,
|
||||
)
|
||||
)
|
||||
|
||||
# Calculate next_page_id
|
||||
next_page_id = None
|
||||
if has_more:
|
||||
next_page_id = str(offset + limit)
|
||||
|
||||
return True, None, OrgMemberPage(items=items, next_page_id=next_page_id)
|
||||
|
||||
@staticmethod
|
||||
async def remove_org_member(
|
||||
org_id: UUID,
|
||||
target_user_id: UUID,
|
||||
current_user_id: UUID,
|
||||
) -> tuple[bool, str | None]:
|
||||
"""Remove a member from an organization.
|
||||
|
||||
Returns:
|
||||
Tuple of (success, error_message). If success is True, error_message is None.
|
||||
"""
|
||||
|
||||
def _remove_member():
|
||||
# Get current user's membership in the org
|
||||
requester_membership = OrgMemberStore.get_org_member(
|
||||
org_id, current_user_id
|
||||
)
|
||||
if not requester_membership:
|
||||
return False, 'not_a_member'
|
||||
|
||||
# Check if trying to remove self
|
||||
if str(current_user_id) == str(target_user_id):
|
||||
return False, 'cannot_remove_self'
|
||||
|
||||
# Get target user's membership
|
||||
target_membership = OrgMemberStore.get_org_member(org_id, target_user_id)
|
||||
if not target_membership:
|
||||
return False, 'member_not_found'
|
||||
|
||||
requester_role = RoleStore.get_role_by_id(requester_membership.role_id)
|
||||
target_role = RoleStore.get_role_by_id(target_membership.role_id)
|
||||
|
||||
if not requester_role or not target_role:
|
||||
return False, 'role_not_found'
|
||||
|
||||
# Check permission based on roles
|
||||
if not OrgMemberService._can_remove_member(
|
||||
requester_role.name, target_role.name
|
||||
):
|
||||
return False, 'insufficient_permission'
|
||||
|
||||
# Check if removing the last owner
|
||||
if target_role.name == ROLE_OWNER:
|
||||
if OrgMemberService._is_last_owner(org_id, target_user_id):
|
||||
return False, 'cannot_remove_last_owner'
|
||||
|
||||
# Perform the removal
|
||||
success = OrgMemberStore.remove_user_from_org(org_id, target_user_id)
|
||||
if not success:
|
||||
return False, 'removal_failed'
|
||||
|
||||
return True, None
|
||||
|
||||
return await call_sync_from_async(_remove_member)
|
||||
|
||||
@staticmethod
|
||||
async def update_org_member(
|
||||
org_id: UUID,
|
||||
target_user_id: UUID,
|
||||
current_user_id: UUID,
|
||||
update_data: OrgMemberUpdate,
|
||||
) -> OrgMemberResponse:
|
||||
"""Update a member's role in an organization.
|
||||
|
||||
Permission rules:
|
||||
- Admins can change roles of users (rank > ADMIN_RANK) to Admin or User
|
||||
- Admins cannot modify other Admins or Owners
|
||||
- Owners can change roles of non-owners (rank > OWNER_RANK) to any role
|
||||
- Owners cannot modify other Owners
|
||||
|
||||
Args:
|
||||
org_id: Organization ID
|
||||
target_user_id: User ID of the member to update
|
||||
current_user_id: User ID of the requester
|
||||
update_data: Update data containing fields to modify
|
||||
|
||||
Returns:
|
||||
OrgMemberResponse: The updated member data
|
||||
|
||||
Raises:
|
||||
OrgMemberNotFoundError: If requester or target is not a member
|
||||
CannotModifySelfError: If trying to modify self
|
||||
RoleNotFoundError: If role configuration is invalid
|
||||
InvalidRoleError: If new_role_name is not a valid role
|
||||
InsufficientPermissionError: If requester lacks permission
|
||||
LastOwnerError: If trying to demote the last owner
|
||||
MemberUpdateError: If update operation fails
|
||||
"""
|
||||
new_role_name = update_data.role
|
||||
|
||||
def _update_member():
|
||||
# Get current user's membership in the org
|
||||
requester_membership = OrgMemberStore.get_org_member(
|
||||
org_id, current_user_id
|
||||
)
|
||||
if not requester_membership:
|
||||
raise OrgMemberNotFoundError(str(org_id), str(current_user_id))
|
||||
|
||||
# Check if trying to modify self
|
||||
if str(current_user_id) == str(target_user_id):
|
||||
raise CannotModifySelfError('modify')
|
||||
|
||||
# Get target user's membership
|
||||
target_membership = OrgMemberStore.get_org_member(org_id, target_user_id)
|
||||
if not target_membership:
|
||||
raise OrgMemberNotFoundError(str(org_id), str(target_user_id))
|
||||
|
||||
# Get roles
|
||||
requester_role = RoleStore.get_role_by_id(requester_membership.role_id)
|
||||
target_role = RoleStore.get_role_by_id(target_membership.role_id)
|
||||
|
||||
if not requester_role:
|
||||
raise RoleNotFoundError(requester_membership.role_id)
|
||||
if not target_role:
|
||||
raise RoleNotFoundError(target_membership.role_id)
|
||||
|
||||
# If no role change requested, return current state
|
||||
if new_role_name is None:
|
||||
user = UserStore.get_user_by_id(str(target_user_id))
|
||||
return OrgMemberResponse(
|
||||
user_id=str(target_membership.user_id),
|
||||
email=user.email if user else None,
|
||||
role_id=target_membership.role_id,
|
||||
role=target_role.name,
|
||||
role_rank=target_role.rank,
|
||||
status=target_membership.status,
|
||||
)
|
||||
|
||||
# Validate new role exists
|
||||
new_role = RoleStore.get_role_by_name(new_role_name.lower())
|
||||
if not new_role:
|
||||
raise InvalidRoleError(new_role_name)
|
||||
|
||||
# Check permission to modify target
|
||||
if not OrgMemberService._can_update_member_role(
|
||||
requester_role.name, target_role.name, new_role.name
|
||||
):
|
||||
raise InsufficientPermissionError(
|
||||
'You do not have permission to modify this member'
|
||||
)
|
||||
|
||||
# Check if demoting the last owner
|
||||
if (
|
||||
target_role.name == ROLE_OWNER
|
||||
and new_role.name != ROLE_OWNER
|
||||
and OrgMemberService._is_last_owner(org_id, target_user_id)
|
||||
):
|
||||
raise LastOwnerError('demote')
|
||||
|
||||
# Perform the update
|
||||
updated_member = OrgMemberStore.update_user_role_in_org(
|
||||
org_id, target_user_id, new_role.id
|
||||
)
|
||||
if not updated_member:
|
||||
raise MemberUpdateError('Failed to update member')
|
||||
|
||||
# Get user email for response
|
||||
user = UserStore.get_user_by_id(str(target_user_id))
|
||||
|
||||
return OrgMemberResponse(
|
||||
user_id=str(updated_member.user_id),
|
||||
email=user.email if user else None,
|
||||
role_id=updated_member.role_id,
|
||||
role=new_role.name,
|
||||
role_rank=new_role.rank,
|
||||
status=updated_member.status,
|
||||
)
|
||||
|
||||
return await call_sync_from_async(_update_member)
|
||||
|
||||
@staticmethod
|
||||
def _can_update_member_role(
|
||||
requester_role_name: str, target_role_name: str, new_role_name: str
|
||||
) -> bool:
|
||||
"""Check if requester can change target's role to new_role.
|
||||
|
||||
Permission rules:
|
||||
- Owners can modify admins and users, can set any role
|
||||
- Owners cannot modify other owners
|
||||
- Admins can only modify users
|
||||
- Admins can only set admin or user roles (not owner)
|
||||
"""
|
||||
is_requester_owner = requester_role_name == ROLE_OWNER
|
||||
is_requester_admin = requester_role_name == ROLE_ADMIN
|
||||
is_target_owner = target_role_name == ROLE_OWNER
|
||||
is_target_admin = target_role_name == ROLE_ADMIN
|
||||
is_new_role_owner = new_role_name == ROLE_OWNER
|
||||
|
||||
if is_requester_owner:
|
||||
# Owners cannot modify other owners
|
||||
if is_target_owner:
|
||||
return False
|
||||
# Owners can set any role (owner, admin, user)
|
||||
return True
|
||||
elif is_requester_admin:
|
||||
# Admins cannot modify owners or other admins
|
||||
if is_target_owner or is_target_admin:
|
||||
return False
|
||||
# Admins can only set admin or user roles (not owner)
|
||||
return not is_new_role_owner
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _can_remove_member(requester_role_name: str, target_role_name: str) -> bool:
|
||||
"""Check if requester can remove target based on roles."""
|
||||
if requester_role_name == ROLE_OWNER:
|
||||
return True
|
||||
elif requester_role_name == ROLE_ADMIN:
|
||||
# Admins can only remove members (not owners or other admins)
|
||||
return target_role_name == ROLE_MEMBER
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _is_last_owner(org_id: UUID, user_id: UUID) -> bool:
|
||||
"""Check if user is the last owner of the organization."""
|
||||
members = OrgMemberStore.get_org_members(org_id)
|
||||
owners = []
|
||||
for m in members:
|
||||
# Use role_id (column) instead of role (relationship) to avoid DetachedInstanceError
|
||||
role = RoleStore.get_role_by_id(m.role_id)
|
||||
if role and role.name == ROLE_OWNER:
|
||||
owners.append(m)
|
||||
return len(owners) == 1 and str(owners[0].user_id) == str(user_id)
|
||||
@@ -22,11 +22,63 @@ from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
|
||||
SQLAppConversationInfoService,
|
||||
)
|
||||
from openhands.app_server.errors import AuthError
|
||||
from openhands.app_server.services.injector import InjectorState
|
||||
|
||||
|
||||
class SaasSQLAppConversationInfoService(SQLAppConversationInfoService):
|
||||
"""Extended SQLAppConversationInfoService with user-based filtering and SAAS metadata handling."""
|
||||
"""Extended SQLAppConversationInfoService with user and organization-based filtering and SAAS metadata handling."""
|
||||
|
||||
async def _get_current_user(self) -> User | None:
|
||||
"""Get the current user using the existing db_session.
|
||||
|
||||
Uses self.db_session to avoid opening a separate database session.
|
||||
|
||||
Returns:
|
||||
User object or None if no user_id is available
|
||||
"""
|
||||
user_id_str = await self.user_context.get_user_id()
|
||||
if not user_id_str:
|
||||
return None
|
||||
|
||||
user_id_uuid = UUID(user_id_str)
|
||||
result = await self.db_session.execute(
|
||||
select(User).where(User.id == user_id_uuid)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
async def _apply_user_and_org_filter(self, query):
|
||||
"""Apply user_id and org_id filters to ensure conversation isolation.
|
||||
|
||||
Filters conversations by:
|
||||
- user_id: Only show conversations belonging to the current user
|
||||
- org_id: Only show conversations belonging to the user's current organization
|
||||
|
||||
Args:
|
||||
query: SQLAlchemy query to apply filters to
|
||||
|
||||
Returns:
|
||||
Query with user and organization filters applied
|
||||
|
||||
Raises:
|
||||
AuthError: If no user_id is available (secure default: deny access)
|
||||
"""
|
||||
user_id_str = await self.user_context.get_user_id()
|
||||
if not user_id_str:
|
||||
# Secure default: no user means no access, not "show everything"
|
||||
raise AuthError('User authentication required')
|
||||
|
||||
user_id_uuid = UUID(user_id_str)
|
||||
query = query.where(StoredConversationMetadataSaas.user_id == user_id_uuid)
|
||||
|
||||
# Filter by organization ID to ensure conversations are isolated per organization
|
||||
user = await self._get_current_user()
|
||||
if user and user.current_org_id is not None:
|
||||
query = query.where(
|
||||
StoredConversationMetadataSaas.org_id == user.current_org_id
|
||||
)
|
||||
|
||||
return query
|
||||
|
||||
async def _secure_select(self):
|
||||
query = (
|
||||
@@ -38,13 +90,7 @@ class SaasSQLAppConversationInfoService(SQLAppConversationInfoService):
|
||||
)
|
||||
.where(StoredConversationMetadata.conversation_version == 'V1')
|
||||
)
|
||||
|
||||
user_id_str = await self.user_context.get_user_id()
|
||||
if user_id_str:
|
||||
user_id_uuid = UUID(user_id_str)
|
||||
query = query.where(StoredConversationMetadataSaas.user_id == user_id_uuid)
|
||||
|
||||
return query
|
||||
return await self._apply_user_and_org_filter(query)
|
||||
|
||||
async def _secure_select_with_saas_metadata(self):
|
||||
"""Select query that includes SAAS metadata for retrieving user_id."""
|
||||
@@ -57,13 +103,7 @@ class SaasSQLAppConversationInfoService(SQLAppConversationInfoService):
|
||||
)
|
||||
.where(StoredConversationMetadata.conversation_version == 'V1')
|
||||
)
|
||||
|
||||
user_id_str = await self.user_context.get_user_id()
|
||||
if user_id_str:
|
||||
user_id_uuid = UUID(user_id_str)
|
||||
query = query.where(StoredConversationMetadataSaas.user_id == user_id_uuid)
|
||||
|
||||
return query
|
||||
return await self._apply_user_and_org_filter(query)
|
||||
|
||||
async def search_app_conversation_info(
|
||||
self,
|
||||
@@ -155,21 +195,16 @@ class SaasSQLAppConversationInfoService(SQLAppConversationInfoService):
|
||||
"""Count conversations matching the given filters with SAAS metadata."""
|
||||
query = (
|
||||
select(func.count(StoredConversationMetadata.conversation_id))
|
||||
.select_from(
|
||||
StoredConversationMetadata.join(
|
||||
StoredConversationMetadataSaas,
|
||||
StoredConversationMetadata.conversation_id
|
||||
== StoredConversationMetadataSaas.conversation_id,
|
||||
)
|
||||
.join(
|
||||
StoredConversationMetadataSaas,
|
||||
StoredConversationMetadata.conversation_id
|
||||
== StoredConversationMetadataSaas.conversation_id,
|
||||
)
|
||||
.where(StoredConversationMetadata.conversation_version == 'V1')
|
||||
)
|
||||
|
||||
# Apply user filtering
|
||||
user_id_str = await self.user_context.get_user_id()
|
||||
if user_id_str:
|
||||
user_id_uuid = UUID(user_id_str)
|
||||
query = query.where(StoredConversationMetadataSaas.user_id == user_id_uuid)
|
||||
# Apply user and organization filtering
|
||||
query = await self._apply_user_and_org_filter(query)
|
||||
|
||||
query = self._apply_filters_with_saas_metadata(
|
||||
query=query,
|
||||
@@ -233,7 +268,13 @@ class SaasSQLAppConversationInfoService(SQLAppConversationInfoService):
|
||||
result = result_set.first()
|
||||
if result:
|
||||
stored_metadata, saas_metadata = result
|
||||
return self._to_info_with_user_id(stored_metadata, saas_metadata)
|
||||
# Fetch sub-conversation IDs
|
||||
sub_conversation_ids = await self.get_sub_conversation_ids(conversation_id)
|
||||
return self._to_info_with_user_id(
|
||||
stored_metadata,
|
||||
saas_metadata,
|
||||
sub_conversation_ids=sub_conversation_ids,
|
||||
)
|
||||
return None
|
||||
|
||||
async def batch_get_app_conversation_info(
|
||||
@@ -262,8 +303,16 @@ class SaasSQLAppConversationInfoService(SQLAppConversationInfoService):
|
||||
for conversation_id in conversation_id_strs:
|
||||
if conversation_id in info_by_id:
|
||||
stored_metadata, saas_metadata = info_by_id[conversation_id]
|
||||
# Fetch sub-conversation IDs for each conversation
|
||||
sub_conversation_ids = await self.get_sub_conversation_ids(
|
||||
UUID(conversation_id)
|
||||
)
|
||||
results.append(
|
||||
self._to_info_with_user_id(stored_metadata, saas_metadata)
|
||||
self._to_info_with_user_id(
|
||||
stored_metadata,
|
||||
saas_metadata,
|
||||
sub_conversation_ids=sub_conversation_ids,
|
||||
)
|
||||
)
|
||||
else:
|
||||
results.append(None)
|
||||
@@ -316,10 +365,11 @@ class SaasSQLAppConversationInfoService(SQLAppConversationInfoService):
|
||||
self,
|
||||
stored: StoredConversationMetadata,
|
||||
saas_metadata: StoredConversationMetadataSaas,
|
||||
sub_conversation_ids: list[UUID] | None = None,
|
||||
) -> AppConversationInfo:
|
||||
"""Convert stored metadata to AppConversationInfo with user_id from SAAS metadata."""
|
||||
# Use the base _to_info method to get the basic info
|
||||
info = self._to_info(stored)
|
||||
info = self._to_info(stored, sub_conversation_ids=sub_conversation_ids)
|
||||
|
||||
# Override the created_by_user_id with the user_id from SAAS metadata
|
||||
info.created_by_user_id = (
|
||||
|
||||
@@ -20,8 +20,10 @@ from storage.linear_workspace import LinearWorkspace
|
||||
from storage.maintenance_task import MaintenanceTask, MaintenanceTaskStatus
|
||||
from storage.openhands_pr import OpenhandsPR
|
||||
from storage.org import Org
|
||||
from storage.org_invitation import OrgInvitation
|
||||
from storage.org_member import OrgMember
|
||||
from storage.proactive_convos import ProactiveConversation
|
||||
from storage.resend_synced_user import ResendSyncedUser
|
||||
from storage.role import Role
|
||||
from storage.slack_conversation import SlackConversation
|
||||
from storage.slack_team import SlackTeam
|
||||
@@ -65,8 +67,10 @@ __all__ = [
|
||||
'MaintenanceTaskStatus',
|
||||
'OpenhandsPR',
|
||||
'Org',
|
||||
'OrgInvitation',
|
||||
'OrgMember',
|
||||
'ProactiveConversation',
|
||||
'ResendSyncedUser',
|
||||
'Role',
|
||||
'SlackConversation',
|
||||
'SlackTeam',
|
||||
|
||||
@@ -126,7 +126,7 @@ class ApiKeyStore:
|
||||
|
||||
return True
|
||||
|
||||
async def list_api_keys(self, user_id: str) -> list[dict]:
|
||||
async def list_api_keys(self, user_id: str) -> list[ApiKey]:
|
||||
"""List all API keys for a user."""
|
||||
user = await UserStore.get_user_by_id_async(user_id)
|
||||
org_id = user.current_org_id
|
||||
@@ -134,24 +134,14 @@ class ApiKeyStore:
|
||||
|
||||
def _list_api_keys_from_db(self, user_id: str, org_id: str) -> list[ApiKey]:
|
||||
with self.session_maker() as session:
|
||||
keys = (
|
||||
keys: list[ApiKey] = (
|
||||
session.query(ApiKey)
|
||||
.filter(ApiKey.user_id == user_id)
|
||||
.filter(ApiKey.org_id == org_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
return [
|
||||
{
|
||||
'id': key.id,
|
||||
'name': key.name,
|
||||
'created_at': key.created_at,
|
||||
'last_used_at': key.last_used_at,
|
||||
'expires_at': key.expires_at,
|
||||
}
|
||||
for key in keys
|
||||
if 'MCP_API_KEY' != key.name
|
||||
]
|
||||
return [key for key in keys if key.name != 'MCP_API_KEY']
|
||||
|
||||
async def retrieve_mcp_api_key(self, user_id: str) -> str | None:
|
||||
user = await UserStore.get_user_by_id_async(user_id)
|
||||
|
||||
@@ -10,7 +10,6 @@ import httpx
|
||||
from pydantic import SecretStr
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.constants import (
|
||||
DEFAULT_INITIAL_BUDGET,
|
||||
LITE_LLM_API_KEY,
|
||||
LITE_LLM_API_URL,
|
||||
LITE_LLM_TEAM_ID,
|
||||
@@ -44,6 +43,34 @@ def get_byor_key_alias(keycloak_user_id: str, org_id: str) -> str:
|
||||
class LiteLlmManager:
|
||||
"""Manage LiteLLM interactions."""
|
||||
|
||||
@staticmethod
|
||||
def get_budget_from_team_info(
|
||||
user_team_info: dict | None, user_id: str, org_id: str
|
||||
) -> tuple[float, float]:
|
||||
"""Extract max_budget and spend from user team info.
|
||||
|
||||
For personal orgs (user_id == org_id), uses litellm_budget_table.max_budget.
|
||||
For team orgs, uses max_budget_in_team (populated by get_user_team_info).
|
||||
|
||||
Args:
|
||||
user_team_info: The response from get_user_team_info
|
||||
user_id: The user's ID
|
||||
org_id: The organization's ID
|
||||
|
||||
Returns:
|
||||
Tuple of (max_budget, spend)
|
||||
"""
|
||||
if not user_team_info:
|
||||
return 0, 0
|
||||
spend = user_team_info.get('spend', 0)
|
||||
if user_id == org_id:
|
||||
max_budget = (user_team_info.get('litellm_budget_table') or {}).get(
|
||||
'max_budget', 0
|
||||
)
|
||||
else:
|
||||
max_budget = user_team_info.get('max_budget_in_team') or 0
|
||||
return max_budget, spend
|
||||
|
||||
@staticmethod
|
||||
async def create_entries(
|
||||
org_id: str,
|
||||
@@ -72,8 +99,33 @@ class LiteLlmManager:
|
||||
'x-goog-api-key': LITE_LLM_API_KEY,
|
||||
}
|
||||
) as client:
|
||||
# Check if team already exists and get its budget
|
||||
# New users joining existing orgs should inherit the team's budget
|
||||
team_budget = 0.0
|
||||
try:
|
||||
existing_team = await LiteLlmManager._get_team(client, org_id)
|
||||
if existing_team:
|
||||
team_info = existing_team.get('team_info', {})
|
||||
team_budget = team_info.get('max_budget', 0.0) or 0.0
|
||||
logger.info(
|
||||
'LiteLlmManager:create_entries:existing_team_budget',
|
||||
extra={
|
||||
'org_id': org_id,
|
||||
'user_id': keycloak_user_id,
|
||||
'team_budget': team_budget,
|
||||
},
|
||||
)
|
||||
except httpx.HTTPStatusError as e:
|
||||
# Team doesn't exist yet (404) - this is expected for first user
|
||||
if e.response.status_code != 404:
|
||||
raise
|
||||
logger.info(
|
||||
'LiteLlmManager:create_entries:no_existing_team',
|
||||
extra={'org_id': org_id, 'user_id': keycloak_user_id},
|
||||
)
|
||||
|
||||
await LiteLlmManager._create_team(
|
||||
client, keycloak_user_id, org_id, DEFAULT_INITIAL_BUDGET
|
||||
client, keycloak_user_id, org_id, team_budget
|
||||
)
|
||||
|
||||
if create_user:
|
||||
@@ -82,7 +134,7 @@ class LiteLlmManager:
|
||||
)
|
||||
|
||||
await LiteLlmManager._add_user_to_team(
|
||||
client, keycloak_user_id, org_id, DEFAULT_INITIAL_BUDGET
|
||||
client, keycloak_user_id, org_id, team_budget
|
||||
)
|
||||
|
||||
key = await LiteLlmManager._generate_key(
|
||||
@@ -894,21 +946,31 @@ class LiteLlmManager:
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return None
|
||||
team_info = await LiteLlmManager._get_team(client, team_id)
|
||||
if not team_info:
|
||||
team_response = await LiteLlmManager._get_team(client, team_id)
|
||||
if not team_response:
|
||||
return None
|
||||
|
||||
# Filter team_memberships based on team_id and keycloak_user_id
|
||||
user_membership = next(
|
||||
(
|
||||
membership
|
||||
for membership in team_info.get('team_memberships', [])
|
||||
for membership in team_response.get('team_memberships', [])
|
||||
if membership.get('user_id') == keycloak_user_id
|
||||
and membership.get('team_id') == team_id
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if not user_membership:
|
||||
return None
|
||||
|
||||
# For team orgs (user_id != team_id), include team-level budget info
|
||||
# The team's max_budget and spend are shared across all members
|
||||
if keycloak_user_id != team_id:
|
||||
team_info = team_response.get('team_info', {})
|
||||
user_membership['max_budget_in_team'] = team_info.get('max_budget')
|
||||
user_membership['spend'] = team_info.get('spend', 0)
|
||||
|
||||
return user_membership
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -46,10 +46,12 @@ class Org(Base): # type: ignore
|
||||
v1_enabled = Column(Boolean, nullable=True)
|
||||
conversation_expiration = Column(Integer, nullable=True)
|
||||
condenser_max_size = Column(Integer, nullable=True)
|
||||
byor_export_enabled = Column(Boolean, nullable=False, default=False)
|
||||
|
||||
# Relationships
|
||||
org_members = relationship('OrgMember', back_populates='org')
|
||||
current_users = relationship('User', back_populates='current_org')
|
||||
invitations = relationship('OrgInvitation', back_populates='org')
|
||||
billing_sessions = relationship('BillingSession', back_populates='org')
|
||||
stored_conversation_metadata_saas = relationship(
|
||||
'StoredConversationMetadataSaas', back_populates='org'
|
||||
|
||||
59
enterprise/storage/org_invitation.py
Normal file
59
enterprise/storage/org_invitation.py
Normal file
@@ -0,0 +1,59 @@
|
||||
"""
|
||||
SQLAlchemy model for Organization Invitation.
|
||||
"""
|
||||
|
||||
from sqlalchemy import UUID, Column, DateTime, ForeignKey, Integer, String, text
|
||||
from sqlalchemy.orm import relationship
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
class OrgInvitation(Base): # type: ignore
|
||||
"""Organization invitation model.
|
||||
|
||||
Represents an invitation for a user to join an organization.
|
||||
Invitations are created by organization owners/admins and contain
|
||||
a secure token that can be used to accept the invitation.
|
||||
"""
|
||||
|
||||
__tablename__ = 'org_invitation'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
token = Column(String(64), nullable=False, unique=True, index=True)
|
||||
org_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey('org.id', ondelete='CASCADE'),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
email = Column(String(255), nullable=False, index=True)
|
||||
role_id = Column(Integer, ForeignKey('role.id'), nullable=False)
|
||||
inviter_id = Column(UUID(as_uuid=True), ForeignKey('user.id'), nullable=False)
|
||||
status = Column(
|
||||
String(20),
|
||||
nullable=False,
|
||||
server_default=text("'pending'"),
|
||||
)
|
||||
created_at = Column(
|
||||
DateTime,
|
||||
nullable=False,
|
||||
server_default=text('CURRENT_TIMESTAMP'),
|
||||
)
|
||||
expires_at = Column(DateTime, nullable=False)
|
||||
accepted_at = Column(DateTime, nullable=True)
|
||||
accepted_by_user_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey('user.id'),
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
# Relationships
|
||||
org = relationship('Org', back_populates='invitations')
|
||||
role = relationship('Role')
|
||||
inviter = relationship('User', foreign_keys=[inviter_id])
|
||||
accepted_by_user = relationship('User', foreign_keys=[accepted_by_user_id])
|
||||
|
||||
# Status constants
|
||||
STATUS_PENDING = 'pending'
|
||||
STATUS_ACCEPTED = 'accepted'
|
||||
STATUS_REVOKED = 'revoked'
|
||||
STATUS_EXPIRED = 'expired'
|
||||
227
enterprise/storage/org_invitation_store.py
Normal file
227
enterprise/storage/org_invitation_store.py
Normal file
@@ -0,0 +1,227 @@
|
||||
"""
|
||||
Store class for managing organization invitations.
|
||||
"""
|
||||
|
||||
import secrets
|
||||
import string
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import and_, select
|
||||
from sqlalchemy.orm import joinedload
|
||||
from storage.database import a_session_maker
|
||||
from storage.org_invitation import OrgInvitation
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
# Invitation token configuration
|
||||
INVITATION_TOKEN_PREFIX = 'inv-'
|
||||
INVITATION_TOKEN_LENGTH = 48 # Total length will be 52 with prefix
|
||||
DEFAULT_EXPIRATION_DAYS = 7
|
||||
|
||||
|
||||
class OrgInvitationStore:
|
||||
"""Store for managing organization invitations."""
|
||||
|
||||
@staticmethod
|
||||
def generate_token(length: int = INVITATION_TOKEN_LENGTH) -> str:
|
||||
"""Generate a secure invitation token.
|
||||
|
||||
Uses cryptographically secure random generation for tokens.
|
||||
Pattern from api_key_store.py.
|
||||
|
||||
Args:
|
||||
length: Length of the random part of the token
|
||||
|
||||
Returns:
|
||||
str: Token with prefix (e.g., 'inv-aBcDeF123...')
|
||||
"""
|
||||
alphabet = string.ascii_letters + string.digits
|
||||
random_part = ''.join(secrets.choice(alphabet) for _ in range(length))
|
||||
return f'{INVITATION_TOKEN_PREFIX}{random_part}'
|
||||
|
||||
@staticmethod
|
||||
async def create_invitation(
|
||||
org_id: UUID,
|
||||
email: str,
|
||||
role_id: int,
|
||||
inviter_id: UUID,
|
||||
expiration_days: int = DEFAULT_EXPIRATION_DAYS,
|
||||
) -> OrgInvitation:
|
||||
"""Create a new organization invitation.
|
||||
|
||||
Args:
|
||||
org_id: Organization UUID
|
||||
email: Invitee's email address
|
||||
role_id: Role ID to assign on acceptance
|
||||
inviter_id: User ID of the person creating the invitation
|
||||
expiration_days: Days until the invitation expires
|
||||
|
||||
Returns:
|
||||
OrgInvitation: The created invitation record
|
||||
"""
|
||||
async with a_session_maker() as session:
|
||||
token = OrgInvitationStore.generate_token()
|
||||
# Use timezone-naive datetime for database compatibility
|
||||
expires_at = datetime.utcnow() + timedelta(days=expiration_days)
|
||||
|
||||
invitation = OrgInvitation(
|
||||
token=token,
|
||||
org_id=org_id,
|
||||
email=email.lower().strip(),
|
||||
role_id=role_id,
|
||||
inviter_id=inviter_id,
|
||||
status=OrgInvitation.STATUS_PENDING,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
session.add(invitation)
|
||||
await session.commit()
|
||||
|
||||
# Re-fetch with eagerly loaded relationships to avoid DetachedInstanceError
|
||||
result = await session.execute(
|
||||
select(OrgInvitation)
|
||||
.options(joinedload(OrgInvitation.role))
|
||||
.filter(OrgInvitation.id == invitation.id)
|
||||
)
|
||||
invitation = result.scalars().first()
|
||||
|
||||
logger.info(
|
||||
'Created organization invitation',
|
||||
extra={
|
||||
'invitation_id': invitation.id,
|
||||
'org_id': str(org_id),
|
||||
'email': email,
|
||||
'inviter_id': str(inviter_id),
|
||||
'expires_at': expires_at.isoformat(),
|
||||
},
|
||||
)
|
||||
|
||||
return invitation
|
||||
|
||||
@staticmethod
|
||||
async def get_invitation_by_token(token: str) -> Optional[OrgInvitation]:
|
||||
"""Get an invitation by its token.
|
||||
|
||||
Args:
|
||||
token: The invitation token
|
||||
|
||||
Returns:
|
||||
OrgInvitation or None if not found
|
||||
"""
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(OrgInvitation)
|
||||
.options(joinedload(OrgInvitation.org), joinedload(OrgInvitation.role))
|
||||
.filter(OrgInvitation.token == token)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
@staticmethod
|
||||
async def get_pending_invitation(
|
||||
org_id: UUID, email: str
|
||||
) -> Optional[OrgInvitation]:
|
||||
"""Get a pending invitation for an email in an organization.
|
||||
|
||||
Args:
|
||||
org_id: Organization UUID
|
||||
email: Email address to check
|
||||
|
||||
Returns:
|
||||
OrgInvitation or None if no pending invitation exists
|
||||
"""
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(OrgInvitation).filter(
|
||||
and_(
|
||||
OrgInvitation.org_id == org_id,
|
||||
OrgInvitation.email == email.lower().strip(),
|
||||
OrgInvitation.status == OrgInvitation.STATUS_PENDING,
|
||||
)
|
||||
)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
@staticmethod
|
||||
async def update_invitation_status(
|
||||
invitation_id: int,
|
||||
status: str,
|
||||
accepted_by_user_id: Optional[UUID] = None,
|
||||
) -> Optional[OrgInvitation]:
|
||||
"""Update an invitation's status.
|
||||
|
||||
Args:
|
||||
invitation_id: The invitation ID
|
||||
status: New status (pending, accepted, revoked, expired)
|
||||
accepted_by_user_id: User ID who accepted (only for 'accepted' status)
|
||||
|
||||
Returns:
|
||||
Updated OrgInvitation or None if not found
|
||||
"""
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(OrgInvitation).filter(OrgInvitation.id == invitation_id)
|
||||
)
|
||||
invitation = result.scalars().first()
|
||||
|
||||
if not invitation:
|
||||
return None
|
||||
|
||||
old_status = invitation.status
|
||||
invitation.status = status
|
||||
|
||||
if status == OrgInvitation.STATUS_ACCEPTED and accepted_by_user_id:
|
||||
# Use timezone-naive datetime for database compatibility
|
||||
invitation.accepted_at = datetime.utcnow()
|
||||
invitation.accepted_by_user_id = accepted_by_user_id
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(invitation)
|
||||
|
||||
logger.info(
|
||||
'Updated invitation status',
|
||||
extra={
|
||||
'invitation_id': invitation_id,
|
||||
'old_status': old_status,
|
||||
'new_status': status,
|
||||
'accepted_by_user_id': (
|
||||
str(accepted_by_user_id) if accepted_by_user_id else None
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
return invitation
|
||||
|
||||
@staticmethod
|
||||
def is_token_expired(invitation: OrgInvitation) -> bool:
|
||||
"""Check if an invitation token has expired.
|
||||
|
||||
Args:
|
||||
invitation: The invitation to check
|
||||
|
||||
Returns:
|
||||
bool: True if expired, False otherwise
|
||||
"""
|
||||
# Use timezone-naive datetime for comparison (database stores without timezone)
|
||||
now = datetime.utcnow()
|
||||
return invitation.expires_at < now
|
||||
|
||||
@staticmethod
|
||||
async def mark_expired_if_needed(invitation: OrgInvitation) -> bool:
|
||||
"""Check if invitation is expired and update status if needed.
|
||||
|
||||
Args:
|
||||
invitation: The invitation to check
|
||||
|
||||
Returns:
|
||||
bool: True if invitation was marked as expired, False otherwise
|
||||
"""
|
||||
if (
|
||||
invitation.status == OrgInvitation.STATUS_PENDING
|
||||
and OrgInvitationStore.is_token_expired(invitation)
|
||||
):
|
||||
await OrgInvitationStore.update_invitation_status(
|
||||
invitation.id, OrgInvitation.STATUS_EXPIRED
|
||||
)
|
||||
return True
|
||||
return False
|
||||
@@ -6,8 +6,10 @@ from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import joinedload
|
||||
from storage.database import a_session_maker, session_maker
|
||||
from storage.org_member import OrgMember
|
||||
from storage.user import User
|
||||
from storage.user_settings import UserSettings
|
||||
|
||||
from openhands.storage.data_models.settings import Settings
|
||||
@@ -59,6 +61,51 @@ class OrgMemberStore:
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
@staticmethod
|
||||
def get_org_member_for_current_org(user_id: UUID) -> Optional[OrgMember]:
|
||||
"""Get the org member for a user's current organization.
|
||||
|
||||
Args:
|
||||
user_id: The user's UUID.
|
||||
|
||||
Returns:
|
||||
The OrgMember for the user's current organization, or None if not found.
|
||||
"""
|
||||
with session_maker() as session:
|
||||
result = (
|
||||
session.query(OrgMember)
|
||||
.join(User, User.id == OrgMember.user_id)
|
||||
.filter(
|
||||
User.id == user_id,
|
||||
OrgMember.org_id == User.current_org_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
async def get_org_member_for_current_org_async(
|
||||
user_id: UUID,
|
||||
) -> Optional[OrgMember]:
|
||||
"""Get the org member for a user's current organization (async version).
|
||||
|
||||
Args:
|
||||
user_id: The user's UUID.
|
||||
|
||||
Returns:
|
||||
The OrgMember for the user's current organization, or None if not found.
|
||||
"""
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(OrgMember)
|
||||
.join(User, User.id == OrgMember.user_id)
|
||||
.filter(
|
||||
User.id == user_id,
|
||||
OrgMember.org_id == User.current_org_id,
|
||||
)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
@staticmethod
|
||||
def get_user_orgs(user_id: UUID) -> list[OrgMember]:
|
||||
"""Get all organizations for a user."""
|
||||
@@ -135,3 +182,36 @@ class OrgMemberStore:
|
||||
if (normalized := c.name.lstrip('_')) and hasattr(user_settings, normalized)
|
||||
}
|
||||
return kwargs
|
||||
|
||||
@staticmethod
|
||||
async def get_org_members_paginated(
|
||||
org_id: UUID,
|
||||
offset: int = 0,
|
||||
limit: int = 100,
|
||||
) -> tuple[list[OrgMember], bool]:
|
||||
"""Get paginated list of organization members with user and role info.
|
||||
|
||||
Returns:
|
||||
Tuple of (members_list, has_more) where has_more indicates if there are more results.
|
||||
"""
|
||||
async with a_session_maker() as session:
|
||||
# Query for limit + 1 items to determine if there are more results
|
||||
# Order by user_id for consistent pagination
|
||||
query = (
|
||||
select(OrgMember)
|
||||
.options(joinedload(OrgMember.user), joinedload(OrgMember.role))
|
||||
.filter(OrgMember.org_id == org_id)
|
||||
.order_by(OrgMember.user_id)
|
||||
.offset(offset)
|
||||
.limit(limit + 1)
|
||||
)
|
||||
result = await session.execute(query)
|
||||
members = list(result.scalars().all())
|
||||
|
||||
# Check if there are more results
|
||||
has_more = len(members) > limit
|
||||
if has_more:
|
||||
# Remove the extra item
|
||||
members = members[:limit]
|
||||
|
||||
return members, has_more
|
||||
|
||||
@@ -521,6 +521,7 @@ class OrgService:
|
||||
Raises:
|
||||
ValueError: If organization not found
|
||||
PermissionError: If user is not a member, or lacks admin/owner role for LLM settings
|
||||
OrgNameExistsError: If new name already exists for another organization
|
||||
OrgDatabaseError: If database update fails
|
||||
"""
|
||||
logger.info(
|
||||
@@ -550,6 +551,24 @@ class OrgService:
|
||||
'User must be a member of the organization to update it'
|
||||
)
|
||||
|
||||
# Check if name is being updated and validate uniqueness
|
||||
if update_data.name is not None:
|
||||
# Check if new name conflicts with another org
|
||||
existing_org_with_name = OrgStore.get_org_by_name(update_data.name)
|
||||
if (
|
||||
existing_org_with_name is not None
|
||||
and existing_org_with_name.id != org_id
|
||||
):
|
||||
logger.warning(
|
||||
'Attempted to update organization with duplicate name',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
'attempted_name': update_data.name,
|
||||
},
|
||||
)
|
||||
raise OrgNameExistsError(update_data.name)
|
||||
|
||||
# Check if update contains any LLM settings
|
||||
llm_fields_being_updated = OrgService._has_llm_settings_updates(update_data)
|
||||
if llm_fields_being_updated:
|
||||
@@ -637,10 +656,9 @@ class OrgService:
|
||||
)
|
||||
return None
|
||||
|
||||
max_budget = (user_team_info.get('litellm_budget_table') or {}).get(
|
||||
'max_budget', 0
|
||||
max_budget, spend = LiteLlmManager.get_budget_from_team_info(
|
||||
user_team_info, user_id, str(org_id)
|
||||
)
|
||||
spend = user_team_info.get('spend', 0)
|
||||
credits = max(max_budget - spend, 0)
|
||||
|
||||
logger.debug(
|
||||
@@ -842,3 +860,94 @@ class OrgService:
|
||||
extra={'user_id': user_id, 'org_id': str(org_id), 'error': str(e)},
|
||||
)
|
||||
raise OrgDatabaseError(f'Failed to delete organization: {str(e)}')
|
||||
|
||||
@staticmethod
|
||||
async def check_byor_export_enabled(user_id: str) -> bool:
|
||||
"""Check if BYOR export is enabled for the user's current org.
|
||||
|
||||
Returns True if the user's current org has byor_export_enabled set to True.
|
||||
Returns False if the user is not found, has no current org, or the flag is False.
|
||||
|
||||
Args:
|
||||
user_id: User ID to check
|
||||
|
||||
Returns:
|
||||
bool: True if BYOR export is enabled, False otherwise
|
||||
"""
|
||||
user = await UserStore.get_user_by_id_async(user_id)
|
||||
if not user or not user.current_org_id:
|
||||
return False
|
||||
|
||||
org = OrgStore.get_org_by_id(user.current_org_id)
|
||||
if not org:
|
||||
return False
|
||||
|
||||
return org.byor_export_enabled
|
||||
|
||||
@staticmethod
|
||||
async def switch_org(user_id: str, org_id: UUID) -> Org:
|
||||
"""
|
||||
Switch user's current organization to the specified organization.
|
||||
|
||||
This method:
|
||||
1. Validates that the organization exists
|
||||
2. Validates that the user is a member of the organization
|
||||
3. Updates the user's current_org_id
|
||||
|
||||
Args:
|
||||
user_id: User ID (string that will be converted to UUID)
|
||||
org_id: Organization ID to switch to
|
||||
|
||||
Returns:
|
||||
Org: The organization that was switched to
|
||||
|
||||
Raises:
|
||||
OrgNotFoundError: If organization doesn't exist
|
||||
OrgAuthorizationError: If user is not a member of the organization
|
||||
OrgDatabaseError: If database update fails
|
||||
"""
|
||||
logger.info(
|
||||
'Switching user organization',
|
||||
extra={'user_id': user_id, 'org_id': str(org_id)},
|
||||
)
|
||||
|
||||
# Step 1: Check if organization exists
|
||||
org = OrgStore.get_org_by_id(org_id)
|
||||
if not org:
|
||||
raise OrgNotFoundError(str(org_id))
|
||||
|
||||
# Step 2: Validate user is a member of the organization
|
||||
if not OrgService.is_org_member(user_id, org_id):
|
||||
logger.warning(
|
||||
'User attempted to switch to organization they are not a member of',
|
||||
extra={'user_id': user_id, 'org_id': str(org_id)},
|
||||
)
|
||||
raise OrgAuthorizationError(
|
||||
'User must be a member of the organization to switch to it'
|
||||
)
|
||||
|
||||
# Step 3: Update user's current_org_id
|
||||
try:
|
||||
updated_user = UserStore.update_current_org(user_id, org_id)
|
||||
if not updated_user:
|
||||
raise OrgDatabaseError('User not found')
|
||||
|
||||
logger.info(
|
||||
'Successfully switched user organization',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
'org_name': org.name,
|
||||
},
|
||||
)
|
||||
|
||||
return org
|
||||
|
||||
except OrgDatabaseError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'Failed to switch user organization',
|
||||
extra={'user_id': user_id, 'org_id': str(org_id), 'error': str(e)},
|
||||
)
|
||||
raise OrgDatabaseError(f'Failed to switch organization: {str(e)}')
|
||||
|
||||
@@ -10,6 +10,7 @@ from server.constants import (
|
||||
ORG_SETTINGS_VERSION,
|
||||
get_default_litellm_model,
|
||||
)
|
||||
from server.routes.org_models import OrphanedUserError
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import joinedload
|
||||
from storage.database import session_maker
|
||||
@@ -320,17 +321,41 @@ class OrgStore:
|
||||
{'org_id': str(org_id)},
|
||||
)
|
||||
|
||||
# 3. Delete organization memberships
|
||||
# 3. Handle users with this as current_org_id BEFORE deleting memberships
|
||||
# Single query to find orphaned users (those with no alternative org)
|
||||
orphaned_users = session.execute(
|
||||
text("""
|
||||
SELECT u.id
|
||||
FROM "user" u
|
||||
WHERE u.current_org_id = :org_id
|
||||
AND NOT EXISTS (
|
||||
SELECT 1 FROM org_member om
|
||||
WHERE om.user_id = u.id AND om.org_id != :org_id
|
||||
)
|
||||
"""),
|
||||
{'org_id': str(org_id)},
|
||||
).fetchall()
|
||||
|
||||
if orphaned_users:
|
||||
raise OrphanedUserError([str(row[0]) for row in orphaned_users])
|
||||
|
||||
# Batch update: reassign current_org_id to an alternative org for all affected users
|
||||
session.execute(
|
||||
text('DELETE FROM org_member WHERE org_id = :org_id'),
|
||||
text("""
|
||||
UPDATE "user" u
|
||||
SET current_org_id = (
|
||||
SELECT om.org_id FROM org_member om
|
||||
WHERE om.user_id = u.id AND om.org_id != :org_id
|
||||
LIMIT 1
|
||||
)
|
||||
WHERE u.current_org_id = :org_id
|
||||
"""),
|
||||
{'org_id': str(org_id)},
|
||||
)
|
||||
|
||||
# 4. Handle users with this as current_org_id
|
||||
# 4. Delete organization memberships (now safe)
|
||||
session.execute(
|
||||
text(
|
||||
'UPDATE "user" SET current_org_id = NULL WHERE current_org_id = :org_id'
|
||||
),
|
||||
text('DELETE FROM org_member WHERE org_id = :org_id'),
|
||||
{'org_id': str(org_id)},
|
||||
)
|
||||
|
||||
|
||||
35
enterprise/storage/resend_synced_user.py
Normal file
35
enterprise/storage/resend_synced_user.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""SQLAlchemy model for tracking users synced to Resend audiences."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import Column, DateTime, String, UniqueConstraint
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
class ResendSyncedUser(Base): # type: ignore
|
||||
"""Tracks users that have been synced to a Resend audience.
|
||||
|
||||
This table ensures that once a user is synced to a Resend audience,
|
||||
they won't be re-added even if they are later deleted from the
|
||||
Resend UI. This respects manual deletions/unsubscribes.
|
||||
"""
|
||||
|
||||
__tablename__ = 'resend_synced_users'
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid4)
|
||||
email = Column(String, nullable=False, index=True)
|
||||
audience_id = Column(String, nullable=False, index=True)
|
||||
synced_at = Column(
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(UTC),
|
||||
nullable=False,
|
||||
)
|
||||
keycloak_user_id = Column(String, nullable=True)
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
'email', 'audience_id', name='uq_resend_synced_email_audience'
|
||||
),
|
||||
)
|
||||
125
enterprise/storage/resend_synced_user_store.py
Normal file
125
enterprise/storage/resend_synced_user_store.py
Normal file
@@ -0,0 +1,125 @@
|
||||
"""Store class for managing Resend synced users."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from typing import Optional, Set
|
||||
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from storage.resend_synced_user import ResendSyncedUser
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResendSyncedUserStore:
|
||||
"""Store for tracking users synced to Resend audiences."""
|
||||
|
||||
session_maker: sessionmaker
|
||||
|
||||
def is_user_synced(self, email: str, audience_id: str) -> bool:
|
||||
"""Check if a user has been synced to a specific audience.
|
||||
|
||||
Args:
|
||||
email: The email address to check.
|
||||
audience_id: The Resend audience ID.
|
||||
|
||||
Returns:
|
||||
True if the user has been synced, False otherwise.
|
||||
"""
|
||||
with self.session_maker() as session:
|
||||
stmt = select(ResendSyncedUser).where(
|
||||
ResendSyncedUser.email == email.lower(),
|
||||
ResendSyncedUser.audience_id == audience_id,
|
||||
)
|
||||
result = session.execute(stmt).first()
|
||||
return result is not None
|
||||
|
||||
def get_synced_emails_for_audience(self, audience_id: str) -> Set[str]:
|
||||
"""Get all synced email addresses for a specific audience.
|
||||
|
||||
Args:
|
||||
audience_id: The Resend audience ID.
|
||||
|
||||
Returns:
|
||||
A set of lowercase email addresses that have been synced.
|
||||
"""
|
||||
with self.session_maker() as session:
|
||||
stmt = select(ResendSyncedUser.email).where(
|
||||
ResendSyncedUser.audience_id == audience_id,
|
||||
)
|
||||
result = session.execute(stmt).scalars().all()
|
||||
return set(result)
|
||||
|
||||
def mark_user_synced(
|
||||
self,
|
||||
email: str,
|
||||
audience_id: str,
|
||||
keycloak_user_id: Optional[str] = None,
|
||||
) -> ResendSyncedUser:
|
||||
"""Mark a user as synced to a specific audience.
|
||||
|
||||
Uses upsert to handle race conditions - if the user is already
|
||||
marked as synced, this is a no-op.
|
||||
|
||||
Args:
|
||||
email: The email address of the user.
|
||||
audience_id: The Resend audience ID.
|
||||
keycloak_user_id: Optional Keycloak user ID.
|
||||
|
||||
Returns:
|
||||
The ResendSyncedUser record.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the record could not be created or retrieved.
|
||||
"""
|
||||
with self.session_maker() as session:
|
||||
stmt = (
|
||||
insert(ResendSyncedUser)
|
||||
.values(
|
||||
email=email.lower(),
|
||||
audience_id=audience_id,
|
||||
keycloak_user_id=keycloak_user_id,
|
||||
synced_at=datetime.now(UTC),
|
||||
)
|
||||
.on_conflict_do_nothing(constraint='uq_resend_synced_email_audience')
|
||||
.returning(ResendSyncedUser)
|
||||
)
|
||||
result = session.execute(stmt)
|
||||
session.commit()
|
||||
|
||||
row = result.first()
|
||||
if row:
|
||||
return row[0]
|
||||
|
||||
# on_conflict_do_nothing triggered, fetch the existing record
|
||||
existing = session.execute(
|
||||
select(ResendSyncedUser).where(
|
||||
ResendSyncedUser.email == email.lower(),
|
||||
ResendSyncedUser.audience_id == audience_id,
|
||||
)
|
||||
).first()
|
||||
if existing:
|
||||
return existing[0]
|
||||
|
||||
raise RuntimeError(
|
||||
f'Failed to create or retrieve synced user record for {email}'
|
||||
)
|
||||
|
||||
def remove_synced_user(self, email: str, audience_id: str) -> bool:
|
||||
"""Remove a user's synced status for a specific audience.
|
||||
|
||||
Args:
|
||||
email: The email address of the user.
|
||||
audience_id: The Resend audience ID.
|
||||
|
||||
Returns:
|
||||
True if a record was deleted, False if no record existed.
|
||||
"""
|
||||
with self.session_maker() as session:
|
||||
stmt = delete(ResendSyncedUser).where(
|
||||
ResendSyncedUser.email == email.lower(),
|
||||
ResendSyncedUser.audience_id == audience_id,
|
||||
)
|
||||
result = session.execute(stmt)
|
||||
session.commit()
|
||||
return result.rowcount > 0
|
||||
@@ -29,6 +29,20 @@ class RoleStore:
|
||||
with session_maker() as session:
|
||||
return session.query(Role).filter(Role.id == role_id).first()
|
||||
|
||||
@staticmethod
|
||||
async def get_role_by_id_async(
|
||||
role_id: int,
|
||||
session: Optional[AsyncSession] = None,
|
||||
) -> Optional[Role]:
|
||||
"""Get role by ID (async version)."""
|
||||
if session is not None:
|
||||
result = await session.execute(select(Role).where(Role.id == role_id))
|
||||
return result.scalars().first()
|
||||
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(select(Role).where(Role.id == role_id))
|
||||
return result.scalars().first()
|
||||
|
||||
@staticmethod
|
||||
def get_role_by_name(name: str) -> Optional[Role]:
|
||||
"""Get role by name."""
|
||||
|
||||
@@ -24,6 +24,7 @@ from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
from openhands.server.settings import Settings
|
||||
from openhands.storage.settings.settings_store import SettingsStore
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
from openhands.utils.llm import is_openhands_model
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -106,13 +107,13 @@ class SaasSettingsStore(SettingsStore):
|
||||
},
|
||||
}
|
||||
kwargs['llm_api_key'] = org_member.llm_api_key
|
||||
if org_member.max_iterations is not None:
|
||||
if org_member.max_iterations:
|
||||
kwargs['max_iterations'] = org_member.max_iterations
|
||||
if org_member.llm_model is not None:
|
||||
if org_member.llm_model:
|
||||
kwargs['llm_model'] = org_member.llm_model
|
||||
if org_member.llm_api_key_for_byor is not None:
|
||||
if org_member.llm_api_key_for_byor:
|
||||
kwargs['llm_api_key_for_byor'] = org_member.llm_api_key_for_byor
|
||||
if org_member.llm_base_url is not None:
|
||||
if org_member.llm_base_url:
|
||||
kwargs['llm_base_url'] = org_member.llm_base_url
|
||||
if org.v1_enabled is None:
|
||||
kwargs['v1_enabled'] = True
|
||||
@@ -161,9 +162,12 @@ class SaasSettingsStore(SettingsStore):
|
||||
return None
|
||||
|
||||
# Check if we need to generate an LLM key.
|
||||
is_openhands_provider = self._is_openhands_provider(item)
|
||||
if is_openhands_provider or item.llm_base_url == LITE_LLM_API_URL:
|
||||
await self._ensure_api_key(item, str(org_id), is_openhands_provider)
|
||||
if not item.llm_base_url:
|
||||
item.llm_base_url = LITE_LLM_API_URL
|
||||
if item.llm_base_url == LITE_LLM_API_URL:
|
||||
await self._ensure_api_key(
|
||||
item, str(org_id), openhands_type=is_openhands_model(item.llm_model)
|
||||
)
|
||||
|
||||
kwargs = item.model_dump(context={'expose_secrets': True})
|
||||
for model in (user, org, org_member):
|
||||
@@ -228,10 +232,6 @@ class SaasSettingsStore(SettingsStore):
|
||||
fernet_key = b64encode(hashlib.sha256(jwt_secret.encode()).digest())
|
||||
return Fernet(fernet_key)
|
||||
|
||||
def _is_openhands_provider(self, item: Settings) -> bool:
|
||||
"""Check if the settings use the OpenHands provider."""
|
||||
return bool(item.llm_model and item.llm_model.startswith('openhands/'))
|
||||
|
||||
async def _ensure_api_key(
|
||||
self, item: Settings, org_id: str, openhands_type: bool = False
|
||||
) -> None:
|
||||
|
||||
@@ -5,6 +5,7 @@ Store class for managing users.
|
||||
import asyncio
|
||||
import uuid
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.constants import (
|
||||
@@ -82,6 +83,8 @@ class UserStore:
|
||||
role_id=role_id,
|
||||
**user_kwargs,
|
||||
)
|
||||
user.email = user_info.get('email')
|
||||
user.email_verified = user_info.get('email_verified')
|
||||
session.add(user)
|
||||
|
||||
role = RoleStore.get_role_by_name('owner')
|
||||
@@ -172,6 +175,19 @@ class UserStore:
|
||||
)
|
||||
decrypted_user_settings = UserSettings(**kwargs)
|
||||
with session_maker() as session:
|
||||
# Check if user has completed billing sessions to enable BYOR export
|
||||
from storage.billing_session import BillingSession
|
||||
|
||||
has_completed_billing = (
|
||||
session.query(BillingSession)
|
||||
.filter(
|
||||
BillingSession.user_id == user_id,
|
||||
BillingSession.status == 'completed',
|
||||
)
|
||||
.first()
|
||||
is not None
|
||||
)
|
||||
|
||||
# create personal org
|
||||
org = Org(
|
||||
id=uuid.UUID(user_id),
|
||||
@@ -180,6 +196,7 @@ class UserStore:
|
||||
contact_name=resolve_display_name(user_info)
|
||||
or user_info.get('username', ''),
|
||||
contact_email=user_info['email'],
|
||||
byor_export_enabled=has_completed_billing,
|
||||
)
|
||||
session.add(org)
|
||||
|
||||
@@ -753,12 +770,62 @@ class UserStore:
|
||||
finally:
|
||||
await UserStore._release_user_creation_lock(user_id)
|
||||
|
||||
@staticmethod
|
||||
async def get_user_by_email_async(email: str) -> Optional[User]:
|
||||
"""Get user by email address (async version).
|
||||
|
||||
This method looks up a user by their email address. Note that email
|
||||
addresses may not be unique across all users in rare cases.
|
||||
|
||||
Args:
|
||||
email: The email address to search for
|
||||
|
||||
Returns:
|
||||
User: The user with the matching email, or None if not found
|
||||
"""
|
||||
if not email:
|
||||
return None
|
||||
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(User)
|
||||
.options(joinedload(User.org_members))
|
||||
.filter(User.email == email.lower().strip())
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
@staticmethod
|
||||
def list_users() -> list[User]:
|
||||
"""List all users."""
|
||||
with session_maker() as session:
|
||||
return session.query(User).all()
|
||||
|
||||
@staticmethod
|
||||
def update_current_org(user_id: str, org_id: UUID) -> Optional[User]:
|
||||
"""Update the user's current organization.
|
||||
|
||||
Args:
|
||||
user_id: The user's ID (Keycloak user ID)
|
||||
org_id: The organization ID to set as current
|
||||
|
||||
Returns:
|
||||
User: The updated user object, or None if user not found
|
||||
"""
|
||||
with session_maker() as session:
|
||||
user = (
|
||||
session.query(User)
|
||||
.filter(User.id == uuid.UUID(user_id))
|
||||
.with_for_update()
|
||||
.first()
|
||||
)
|
||||
if not user:
|
||||
return None
|
||||
|
||||
user.current_org_id = org_id
|
||||
session.commit()
|
||||
session.refresh(user)
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
async def backfill_contact_name(user_id: str, user_info: dict) -> None:
|
||||
"""Update contact_name on the personal org if it still has a username-style value.
|
||||
|
||||
@@ -1,11 +1,19 @@
|
||||
import asyncio
|
||||
import asyncio # noqa: I001
|
||||
|
||||
from storage.proactive_conversation_store import ProactiveConversationStore
|
||||
# This must be before the import of storage
|
||||
# to set up logging and prevent alembic from
|
||||
# running its mouth.
|
||||
from openhands.core.logger import openhands_logger
|
||||
|
||||
from storage.proactive_conversation_store import (
|
||||
ProactiveConversationStore,
|
||||
)
|
||||
|
||||
OLDER_THAN = 30 # 30 minutes
|
||||
|
||||
|
||||
async def main():
|
||||
openhands_logger.info('clean_proactive_convo_table')
|
||||
convo_store = ProactiveConversationStore()
|
||||
await convo_store.clean_old_convos(older_than_minutes=OLDER_THAN)
|
||||
|
||||
|
||||
@@ -26,6 +26,7 @@ Optional environment variables:
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional
|
||||
@@ -34,6 +35,7 @@ import resend
|
||||
from keycloak.exceptions import KeycloakError
|
||||
from resend.exceptions import ResendError
|
||||
from server.auth.token_manager import get_keycloak_admin
|
||||
from storage.resend_synced_user_store import ResendSyncedUserStore
|
||||
from tenacity import (
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
@@ -68,9 +70,6 @@ RATE_LIMIT = float(os.environ.get('RATE_LIMIT', '2')) # Requests per second
|
||||
# Set up Resend API
|
||||
resend.api_key = RESEND_API_KEY
|
||||
|
||||
print('resend module', resend)
|
||||
print('has contacts', hasattr(resend, 'Contacts'))
|
||||
|
||||
|
||||
class ResendSyncError(Exception):
|
||||
"""Base exception for Resend sync errors."""
|
||||
@@ -90,6 +89,31 @@ class ResendAPIError(ResendSyncError):
|
||||
pass
|
||||
|
||||
|
||||
# Email validation regex pattern - matches standard email format
|
||||
# This pattern is intentionally strict to avoid Resend API validation errors
|
||||
# It rejects special characters like ! that some email providers technically allow
|
||||
# but Resend's API does not accept
|
||||
EMAIL_REGEX = re.compile(r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$')
|
||||
|
||||
|
||||
def is_valid_email(email: str) -> bool:
|
||||
"""Validate an email address format.
|
||||
|
||||
This uses a regex pattern that matches most valid email addresses
|
||||
while rejecting addresses with special characters that Resend's API
|
||||
does not accept (e.g., exclamation marks).
|
||||
|
||||
Args:
|
||||
email: The email address to validate.
|
||||
|
||||
Returns:
|
||||
True if the email is valid, False otherwise.
|
||||
"""
|
||||
if not email:
|
||||
return False
|
||||
return bool(EMAIL_REGEX.match(email))
|
||||
|
||||
|
||||
def get_keycloak_users(offset: int = 0, limit: int = 100) -> List[Dict[str, Any]]:
|
||||
"""Get users from Keycloak using the admin client.
|
||||
|
||||
@@ -173,8 +197,6 @@ def get_resend_contacts(audience_id: str) -> Dict[str, Dict[str, Any]]:
|
||||
Raises:
|
||||
ResendAPIError: If the API call fails.
|
||||
"""
|
||||
print('getting resend contacts')
|
||||
print('has resend contacts', hasattr(resend, 'Contacts'))
|
||||
try:
|
||||
contacts = resend.Contacts.list(audience_id).get('data', [])
|
||||
# Create a dictionary mapping email addresses to contact data for
|
||||
@@ -291,8 +313,84 @@ def send_welcome_email(
|
||||
raise
|
||||
|
||||
|
||||
def _get_resend_synced_user_store() -> ResendSyncedUserStore:
|
||||
"""Get the ResendSyncedUserStore instance.
|
||||
|
||||
This is separated into a function to allow for easier testing/mocking.
|
||||
"""
|
||||
from openhands.app_server.config import get_global_config
|
||||
|
||||
config = get_global_config()
|
||||
db_session_injector = config.db_session
|
||||
return ResendSyncedUserStore(session_maker=db_session_injector.get_session_maker())
|
||||
|
||||
|
||||
def _backfill_existing_resend_contacts(
|
||||
synced_user_store: ResendSyncedUserStore,
|
||||
audience_id: str,
|
||||
) -> int:
|
||||
"""Backfill the synced_users table with contacts already in Resend.
|
||||
|
||||
This ensures that users who were added to Resend before the tracking
|
||||
table existed are properly recorded, preventing duplicate welcome emails.
|
||||
|
||||
Args:
|
||||
synced_user_store: The store for tracking synced users.
|
||||
audience_id: The Resend audience ID.
|
||||
|
||||
Returns:
|
||||
The number of contacts backfilled.
|
||||
"""
|
||||
logger.info('Starting backfill of existing Resend contacts...')
|
||||
|
||||
try:
|
||||
resend_contacts = get_resend_contacts(audience_id)
|
||||
logger.info(f'Found {len(resend_contacts)} contacts in Resend audience')
|
||||
|
||||
already_synced_emails = synced_user_store.get_synced_emails_for_audience(
|
||||
audience_id
|
||||
)
|
||||
logger.info(
|
||||
f'Found {len(already_synced_emails)} already synced emails in database'
|
||||
)
|
||||
|
||||
backfilled_count = 0
|
||||
for email in resend_contacts:
|
||||
if email.lower() not in already_synced_emails:
|
||||
synced_user_store.mark_user_synced(
|
||||
email=email,
|
||||
audience_id=audience_id,
|
||||
keycloak_user_id=None, # We don't have this info during backfill
|
||||
)
|
||||
backfilled_count += 1
|
||||
logger.debug(f'Backfilled existing Resend contact: {email}')
|
||||
|
||||
logger.info(
|
||||
f'Backfill completed: {backfilled_count} contacts added to tracking'
|
||||
)
|
||||
return backfilled_count
|
||||
|
||||
except Exception:
|
||||
logger.exception('Error during backfill of existing Resend contacts')
|
||||
# Don't fail the entire sync if backfill fails - just log and continue
|
||||
return 0
|
||||
|
||||
|
||||
def sync_users_to_resend():
|
||||
"""Sync users from Keycloak to Resend."""
|
||||
"""Sync users from Keycloak to Resend.
|
||||
|
||||
This function syncs users from Keycloak to a Resend audience. It tracks
|
||||
which users have been synced in the database to ensure that:
|
||||
1. Users are only added once (even across multiple sync runs)
|
||||
2. Users who are manually deleted from Resend are not re-added
|
||||
|
||||
The tracking is done via the resend_synced_users table, which records
|
||||
each email/audience_id combination that has been synced.
|
||||
|
||||
On first run (or when new contacts exist in Resend), it will backfill
|
||||
the tracking table with existing Resend contacts to avoid sending
|
||||
duplicate welcome emails.
|
||||
"""
|
||||
# Check required environment variables
|
||||
required_vars = {
|
||||
'RESEND_API_KEY': RESEND_API_KEY,
|
||||
@@ -318,27 +416,36 @@ def sync_users_to_resend():
|
||||
)
|
||||
|
||||
try:
|
||||
# Get the store for tracking synced users
|
||||
synced_user_store = _get_resend_synced_user_store()
|
||||
|
||||
# Backfill existing Resend contacts into our tracking table
|
||||
# This ensures users already in Resend don't get duplicate welcome emails
|
||||
backfilled_count = _backfill_existing_resend_contacts(
|
||||
synced_user_store, RESEND_AUDIENCE_ID
|
||||
)
|
||||
|
||||
# Get the total number of users
|
||||
total_users = get_total_keycloak_users()
|
||||
logger.info(
|
||||
f'Found {total_users} users in Keycloak realm {KEYCLOAK_REALM_NAME}'
|
||||
)
|
||||
|
||||
# Get contacts from Resend
|
||||
resend_contacts = get_resend_contacts(RESEND_AUDIENCE_ID)
|
||||
logger.info(
|
||||
f'Found {len(resend_contacts)} contacts in Resend audience '
|
||||
f'{RESEND_AUDIENCE_ID}'
|
||||
)
|
||||
|
||||
# Stats
|
||||
stats = {
|
||||
'total_users': total_users,
|
||||
'existing_contacts': len(resend_contacts),
|
||||
'backfilled_contacts': backfilled_count,
|
||||
'already_synced': 0,
|
||||
'added_contacts': 0,
|
||||
'skipped_invalid_emails': 0,
|
||||
'errors': 0,
|
||||
}
|
||||
|
||||
synced_emails = synced_user_store.get_synced_emails_for_audience(
|
||||
RESEND_AUDIENCE_ID
|
||||
)
|
||||
logger.info(f'Found {len(synced_emails)} already synced emails in database')
|
||||
|
||||
# Process users in batches
|
||||
offset = 0
|
||||
while offset < total_users:
|
||||
@@ -351,39 +458,65 @@ def sync_users_to_resend():
|
||||
continue
|
||||
|
||||
email = email.lower()
|
||||
if email in resend_contacts:
|
||||
logger.debug(f'User {email} already exists in Resend, skipping')
|
||||
|
||||
if email in synced_emails:
|
||||
logger.debug(
|
||||
f'User {email} was already synced to this audience, skipping'
|
||||
)
|
||||
stats['already_synced'] += 1
|
||||
continue
|
||||
|
||||
# Validate email format before attempting to add to Resend
|
||||
if not is_valid_email(email):
|
||||
logger.warning(f'Skipping user with invalid email format: {email}')
|
||||
stats['skipped_invalid_emails'] += 1
|
||||
continue
|
||||
|
||||
first_name = user.get('first_name')
|
||||
last_name = user.get('last_name')
|
||||
keycloak_user_id = user.get('id')
|
||||
|
||||
# Mark as synced first (optimistic) to ensure consistency.
|
||||
# If Resend API fails, we remove the record.
|
||||
try:
|
||||
synced_user_store.mark_user_synced(
|
||||
email=email,
|
||||
audience_id=RESEND_AUDIENCE_ID,
|
||||
keycloak_user_id=keycloak_user_id,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(f'Failed to mark user {email} as synced')
|
||||
stats['errors'] += 1
|
||||
continue
|
||||
|
||||
try:
|
||||
first_name = user.get('first_name')
|
||||
last_name = user.get('last_name')
|
||||
|
||||
# Add the contact to the Resend audience
|
||||
add_contact_to_resend(
|
||||
RESEND_AUDIENCE_ID, email, first_name, last_name
|
||||
)
|
||||
logger.info(f'Added user {email} to Resend')
|
||||
stats['added_contacts'] += 1
|
||||
|
||||
# Sleep to respect rate limit after first API call
|
||||
time.sleep(1 / RATE_LIMIT)
|
||||
|
||||
# Send a welcome email to the newly added contact
|
||||
try:
|
||||
send_welcome_email(email, first_name, last_name)
|
||||
logger.info(f'Sent welcome email to {email}')
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f'Failed to send welcome email to {email}, but contact was added to audience'
|
||||
)
|
||||
# Continue with the sync process even if sending the welcome email fails
|
||||
|
||||
# Sleep to respect rate limit after second API call
|
||||
time.sleep(1 / RATE_LIMIT)
|
||||
except Exception:
|
||||
logger.exception(f'Error adding user {email} to Resend')
|
||||
synced_user_store.remove_synced_user(email, RESEND_AUDIENCE_ID)
|
||||
stats['errors'] += 1
|
||||
continue
|
||||
|
||||
synced_emails.add(email)
|
||||
stats['added_contacts'] += 1
|
||||
|
||||
# Sleep to respect rate limit after first API call
|
||||
time.sleep(1 / RATE_LIMIT)
|
||||
|
||||
# Send a welcome email to the newly added contact
|
||||
try:
|
||||
send_welcome_email(email, first_name, last_name)
|
||||
logger.info(f'Sent welcome email to {email}')
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f'Failed to send welcome email to {email}, but contact was added to audience'
|
||||
)
|
||||
|
||||
# Sleep to respect rate limit after second API call
|
||||
time.sleep(1 / RATE_LIMIT)
|
||||
|
||||
offset += BATCH_SIZE
|
||||
|
||||
|
||||
@@ -126,3 +126,24 @@ def test_run_agent_variant_tests_v1_calls_handler_and_sets_system_prompt(monkeyp
|
||||
# Should be a different instance than the original (copied after handler runs)
|
||||
assert result is not agent
|
||||
assert result.system_prompt_filename == 'system_prompt_long_horizon.j2'
|
||||
|
||||
|
||||
@patch('experiments.experiment_manager.ENABLE_EXPERIMENT_MANAGER', True)
|
||||
@patch('experiments.experiment_manager.EXPERIMENT_SYSTEM_PROMPT_EXPERIMENT', True)
|
||||
def test_run_agent_variant_tests_v1_preserves_planning_agent_system_prompt():
|
||||
"""Planning agents should retain their specialized system prompt and not be overwritten by the experiment."""
|
||||
# Arrange
|
||||
planning_agent = make_agent().model_copy(
|
||||
update={'system_prompt_filename': 'system_prompt_planning.j2'}
|
||||
)
|
||||
conv_id = uuid4()
|
||||
|
||||
# Act
|
||||
result: Agent = SaaSExperimentManager.run_agent_variant_tests__v1(
|
||||
user_id='user-planning',
|
||||
conversation_id=conv_id,
|
||||
agent=planning_agent,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result.system_prompt_filename == 'system_prompt_planning.j2'
|
||||
|
||||
@@ -141,12 +141,14 @@ def test_custom_to_static_conversion():
|
||||
|
||||
def create_provider_tokens(
|
||||
tokens_dict: dict[ProviderType, str],
|
||||
) -> dict[ProviderType, ProviderToken]:
|
||||
"""Helper to create provider tokens dictionary."""
|
||||
return {
|
||||
provider_type: ProviderToken(token=SecretStr(token_value))
|
||||
for provider_type, token_value in tokens_dict.items()
|
||||
}
|
||||
) -> MappingProxyType:
|
||||
"""Helper to create provider tokens as MappingProxyType."""
|
||||
return MappingProxyType(
|
||||
{
|
||||
provider_type: ProviderToken(token=SecretStr(token_value))
|
||||
for provider_type, token_value in tokens_dict.items()
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -264,3 +266,63 @@ async def test_get_latest_token_can_be_used_with_static_secret(
|
||||
# Assert - this should NOT raise a ValidationError
|
||||
static_secret = StaticSecret(value=token, description='GITHUB authentication token')
|
||||
assert static_secret.get_value() == token_value
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests for get_authenticated_git_url - ensuring proper authenticated URLs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_authenticated_git_url_raises_when_no_tokens(
|
||||
resolver_context, mock_saas_user_auth
|
||||
):
|
||||
"""Test that get_authenticated_git_url raises error when no provider tokens available."""
|
||||
# Arrange
|
||||
mock_saas_user_auth.get_provider_tokens = AsyncMock(return_value=None)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match='No provider tokens available'):
|
||||
await resolver_context.get_authenticated_git_url('owner/repo')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_provider_handler_caches_instance(
|
||||
resolver_context, mock_saas_user_auth
|
||||
):
|
||||
"""Test that _get_provider_handler caches the handler instance."""
|
||||
# Arrange
|
||||
token_value = 'ghp_test_token'
|
||||
provider_tokens = create_provider_tokens({ProviderType.GITHUB: token_value})
|
||||
mock_saas_user_auth.get_provider_tokens = AsyncMock(return_value=provider_tokens)
|
||||
mock_saas_user_auth.get_user_id = AsyncMock(return_value='test-user-id')
|
||||
|
||||
# Act - call _get_provider_handler twice
|
||||
handler1 = await resolver_context._get_provider_handler()
|
||||
handler2 = await resolver_context._get_provider_handler()
|
||||
|
||||
# Assert - should be the same instance (cached)
|
||||
assert handler1 is handler2
|
||||
# get_provider_tokens should only be called once
|
||||
assert mock_saas_user_auth.get_provider_tokens.call_count == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_provider_handler_creates_handler_with_correct_params(
|
||||
resolver_context, mock_saas_user_auth
|
||||
):
|
||||
"""Test that _get_provider_handler creates ProviderHandler with correct parameters."""
|
||||
# Arrange
|
||||
token_value = 'ghp_test_token'
|
||||
provider_tokens = create_provider_tokens({ProviderType.GITHUB: token_value})
|
||||
mock_saas_user_auth.get_provider_tokens = AsyncMock(return_value=provider_tokens)
|
||||
mock_saas_user_auth.get_user_id = AsyncMock(return_value='test-user-id')
|
||||
|
||||
# Act
|
||||
handler = await resolver_context._get_provider_handler()
|
||||
|
||||
# Assert
|
||||
from openhands.integrations.provider import ProviderHandler
|
||||
|
||||
assert isinstance(handler, ProviderHandler)
|
||||
assert handler.provider_tokens == provider_tokens
|
||||
|
||||
@@ -6,6 +6,9 @@ import httpx
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
from server.routes.api_keys import (
|
||||
ByorPermittedResponse,
|
||||
LlmApiKeyResponse,
|
||||
check_byor_permitted,
|
||||
delete_byor_key_from_litellm,
|
||||
get_llm_api_key_for_byor,
|
||||
)
|
||||
@@ -182,16 +185,18 @@ class TestGetLlmApiKeyForByor:
|
||||
"""Test the get_llm_api_key_for_byor endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.org_service.OrgService.check_byor_export_enabled')
|
||||
@patch('server.routes.api_keys.store_byor_key_in_db')
|
||||
@patch('server.routes.api_keys.generate_byor_key')
|
||||
@patch('server.routes.api_keys.get_byor_key_from_db')
|
||||
async def test_no_key_in_database_generates_new(
|
||||
self, mock_get_key, mock_generate_key, mock_store_key
|
||||
self, mock_get_key, mock_generate_key, mock_store_key, mock_check_enabled
|
||||
):
|
||||
"""Test that when no key exists in database, a new one is generated."""
|
||||
# Arrange
|
||||
user_id = 'user-123'
|
||||
new_key = 'sk-new-generated-key'
|
||||
mock_check_enabled.return_value = True
|
||||
mock_get_key.return_value = None
|
||||
mock_generate_key.return_value = new_key
|
||||
mock_store_key.return_value = None
|
||||
@@ -200,21 +205,24 @@ class TestGetLlmApiKeyForByor:
|
||||
result = await get_llm_api_key_for_byor(user_id=user_id)
|
||||
|
||||
# Assert
|
||||
assert result == {'key': new_key}
|
||||
assert result == LlmApiKeyResponse(key=new_key)
|
||||
mock_check_enabled.assert_called_once_with(user_id)
|
||||
mock_get_key.assert_called_once_with(user_id)
|
||||
mock_generate_key.assert_called_once_with(user_id)
|
||||
mock_store_key.assert_called_once_with(user_id, new_key)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.org_service.OrgService.check_byor_export_enabled')
|
||||
@patch('storage.lite_llm_manager.LiteLlmManager.verify_key')
|
||||
@patch('server.routes.api_keys.get_byor_key_from_db')
|
||||
async def test_valid_key_in_database_returns_key(
|
||||
self, mock_get_key, mock_verify_key
|
||||
self, mock_get_key, mock_verify_key, mock_check_enabled
|
||||
):
|
||||
"""Test that when a valid key exists in database, it is returned."""
|
||||
# Arrange
|
||||
user_id = 'user-123'
|
||||
existing_key = 'sk-existing-valid-key'
|
||||
mock_check_enabled.return_value = True
|
||||
mock_get_key.return_value = existing_key
|
||||
mock_verify_key.return_value = True
|
||||
|
||||
@@ -222,11 +230,13 @@ class TestGetLlmApiKeyForByor:
|
||||
result = await get_llm_api_key_for_byor(user_id=user_id)
|
||||
|
||||
# Assert
|
||||
assert result == {'key': existing_key}
|
||||
assert result == LlmApiKeyResponse(key=existing_key)
|
||||
mock_check_enabled.assert_called_once_with(user_id)
|
||||
mock_get_key.assert_called_once_with(user_id)
|
||||
mock_verify_key.assert_called_once_with(existing_key, user_id)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.org_service.OrgService.check_byor_export_enabled')
|
||||
@patch('server.routes.api_keys.store_byor_key_in_db')
|
||||
@patch('server.routes.api_keys.generate_byor_key')
|
||||
@patch('server.routes.api_keys.delete_byor_key_from_litellm')
|
||||
@@ -239,12 +249,14 @@ class TestGetLlmApiKeyForByor:
|
||||
mock_delete_key,
|
||||
mock_generate_key,
|
||||
mock_store_key,
|
||||
mock_check_enabled,
|
||||
):
|
||||
"""Test that when an invalid key exists in database, it is regenerated."""
|
||||
# Arrange
|
||||
user_id = 'user-123'
|
||||
invalid_key = 'sk-invalid-key'
|
||||
new_key = 'sk-new-generated-key'
|
||||
mock_check_enabled.return_value = True
|
||||
mock_get_key.return_value = invalid_key
|
||||
mock_verify_key.return_value = False
|
||||
mock_delete_key.return_value = True
|
||||
@@ -255,7 +267,8 @@ class TestGetLlmApiKeyForByor:
|
||||
result = await get_llm_api_key_for_byor(user_id=user_id)
|
||||
|
||||
# Assert
|
||||
assert result == {'key': new_key}
|
||||
assert result == LlmApiKeyResponse(key=new_key)
|
||||
mock_check_enabled.assert_called_once_with(user_id)
|
||||
mock_get_key.assert_called_once_with(user_id)
|
||||
mock_verify_key.assert_called_once_with(invalid_key, user_id)
|
||||
mock_delete_key.assert_called_once_with(user_id, invalid_key)
|
||||
@@ -263,6 +276,7 @@ class TestGetLlmApiKeyForByor:
|
||||
mock_store_key.assert_called_once_with(user_id, new_key)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.org_service.OrgService.check_byor_export_enabled')
|
||||
@patch('server.routes.api_keys.store_byor_key_in_db')
|
||||
@patch('server.routes.api_keys.generate_byor_key')
|
||||
@patch('server.routes.api_keys.delete_byor_key_from_litellm')
|
||||
@@ -275,12 +289,14 @@ class TestGetLlmApiKeyForByor:
|
||||
mock_delete_key,
|
||||
mock_generate_key,
|
||||
mock_store_key,
|
||||
mock_check_enabled,
|
||||
):
|
||||
"""Test that even if deletion fails, regeneration still proceeds."""
|
||||
# Arrange
|
||||
user_id = 'user-123'
|
||||
invalid_key = 'sk-invalid-key'
|
||||
new_key = 'sk-new-generated-key'
|
||||
mock_check_enabled.return_value = True
|
||||
mock_get_key.return_value = invalid_key
|
||||
mock_verify_key.return_value = False
|
||||
mock_delete_key.return_value = False # Deletion fails
|
||||
@@ -291,20 +307,23 @@ class TestGetLlmApiKeyForByor:
|
||||
result = await get_llm_api_key_for_byor(user_id=user_id)
|
||||
|
||||
# Assert
|
||||
assert result == {'key': new_key}
|
||||
assert result == LlmApiKeyResponse(key=new_key)
|
||||
mock_check_enabled.assert_called_once_with(user_id)
|
||||
mock_delete_key.assert_called_once_with(user_id, invalid_key)
|
||||
mock_generate_key.assert_called_once_with(user_id)
|
||||
mock_store_key.assert_called_once_with(user_id, new_key)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.org_service.OrgService.check_byor_export_enabled')
|
||||
@patch('server.routes.api_keys.generate_byor_key')
|
||||
@patch('server.routes.api_keys.get_byor_key_from_db')
|
||||
async def test_key_generation_failure_raises_exception(
|
||||
self, mock_get_key, mock_generate_key
|
||||
self, mock_get_key, mock_generate_key, mock_check_enabled
|
||||
):
|
||||
"""Test that when key generation fails, an HTTPException is raised."""
|
||||
# Arrange
|
||||
user_id = 'user-123'
|
||||
mock_check_enabled.return_value = True
|
||||
mock_get_key.return_value = None
|
||||
mock_generate_key.return_value = None
|
||||
|
||||
@@ -316,11 +335,15 @@ class TestGetLlmApiKeyForByor:
|
||||
assert 'Failed to generate new BYOR LLM API key' in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.org_service.OrgService.check_byor_export_enabled')
|
||||
@patch('server.routes.api_keys.get_byor_key_from_db')
|
||||
async def test_database_error_raises_exception(self, mock_get_key):
|
||||
async def test_database_error_raises_exception(
|
||||
self, mock_get_key, mock_check_enabled
|
||||
):
|
||||
"""Test that database errors are properly handled."""
|
||||
# Arrange
|
||||
user_id = 'user-123'
|
||||
mock_check_enabled.return_value = True
|
||||
mock_get_key.side_effect = Exception('Database connection error')
|
||||
|
||||
# Act & Assert
|
||||
@@ -330,6 +353,21 @@ class TestGetLlmApiKeyForByor:
|
||||
assert exc_info.value.status_code == 500
|
||||
assert 'Failed to retrieve BYOR LLM API key' in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.org_service.OrgService.check_byor_export_enabled')
|
||||
async def test_byor_export_disabled_returns_402(self, mock_check_enabled):
|
||||
"""Test that when BYOR export is disabled, 402 is returned."""
|
||||
# Arrange
|
||||
user_id = 'user-123'
|
||||
mock_check_enabled.return_value = False
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_llm_api_key_for_byor(user_id=user_id)
|
||||
|
||||
assert exc_info.value.status_code == 402
|
||||
assert 'BYOR key export is not enabled' in exc_info.value.detail
|
||||
|
||||
|
||||
class TestDeleteByorKeyFromLitellm:
|
||||
"""Test the delete_byor_key_from_litellm function with alias cleanup."""
|
||||
@@ -425,3 +463,52 @@ class TestDeleteByorKeyFromLitellm:
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestCheckByorPermitted:
|
||||
"""Test the check_byor_permitted endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.org_service.OrgService.check_byor_export_enabled')
|
||||
async def test_permitted_when_enabled(self, mock_check_enabled):
|
||||
"""Test that permitted=True is returned when BYOR export is enabled."""
|
||||
# Arrange
|
||||
user_id = 'user-123'
|
||||
mock_check_enabled.return_value = True
|
||||
|
||||
# Act
|
||||
result = await check_byor_permitted(user_id=user_id)
|
||||
|
||||
# Assert
|
||||
assert result == ByorPermittedResponse(permitted=True)
|
||||
mock_check_enabled.assert_called_once_with(user_id)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.org_service.OrgService.check_byor_export_enabled')
|
||||
async def test_not_permitted_when_disabled(self, mock_check_enabled):
|
||||
"""Test that permitted=False is returned when BYOR export is disabled."""
|
||||
# Arrange
|
||||
user_id = 'user-123'
|
||||
mock_check_enabled.return_value = False
|
||||
|
||||
# Act
|
||||
result = await check_byor_permitted(user_id=user_id)
|
||||
|
||||
# Assert
|
||||
assert result == ByorPermittedResponse(permitted=False)
|
||||
mock_check_enabled.assert_called_once_with(user_id)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.org_service.OrgService.check_byor_export_enabled')
|
||||
async def test_error_raises_500(self, mock_check_enabled):
|
||||
"""Test that an exception raises 500 error."""
|
||||
# Arrange
|
||||
user_id = 'user-123'
|
||||
mock_check_enabled.side_effect = Exception('Database error')
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await check_byor_permitted(user_id=user_id)
|
||||
|
||||
assert exc_info.value.status_code == 500
|
||||
assert 'Failed to check BYOR export permission' in exc_info.value.detail
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
1711
enterprise/tests/unit/server/services/test_org_member_service.py
Normal file
1711
enterprise/tests/unit/server/services/test_org_member_service.py
Normal file
File diff suppressed because it is too large
Load Diff
158
enterprise/tests/unit/storage/test_resend_synced_user_store.py
Normal file
158
enterprise/tests/unit/storage/test_resend_synced_user_store.py
Normal file
@@ -0,0 +1,158 @@
|
||||
"""Unit tests for ResendSyncedUserStore."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
# Import directly from the module files to avoid loading all of storage/__init__.py
|
||||
# which has many dependencies
|
||||
from storage.resend_synced_user import ResendSyncedUser
|
||||
from storage.resend_synced_user_store import ResendSyncedUserStore
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session():
|
||||
"""Mock database session."""
|
||||
session = MagicMock()
|
||||
return session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session_maker(mock_session):
|
||||
"""Mock session maker."""
|
||||
session_maker = MagicMock()
|
||||
session_maker.return_value.__enter__.return_value = mock_session
|
||||
session_maker.return_value.__exit__.return_value = None
|
||||
return session_maker
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def store(mock_session_maker):
|
||||
"""Create ResendSyncedUserStore instance."""
|
||||
return ResendSyncedUserStore(session_maker=mock_session_maker)
|
||||
|
||||
|
||||
class TestResendSyncedUserStore:
|
||||
"""Test cases for ResendSyncedUserStore."""
|
||||
|
||||
def test_is_user_synced_returns_true_when_exists(self, store, mock_session):
|
||||
"""Test is_user_synced returns True when user exists in database."""
|
||||
email = 'test@example.com'
|
||||
audience_id = 'test-audience-123'
|
||||
|
||||
mock_row = MagicMock()
|
||||
mock_session.execute.return_value.first.return_value = mock_row
|
||||
|
||||
result = store.is_user_synced(email, audience_id)
|
||||
|
||||
assert result is True
|
||||
mock_session.execute.assert_called_once()
|
||||
|
||||
def test_is_user_synced_returns_false_when_not_exists(self, store, mock_session):
|
||||
"""Test is_user_synced returns False when user doesn't exist."""
|
||||
email = 'test@example.com'
|
||||
audience_id = 'test-audience-123'
|
||||
|
||||
mock_session.execute.return_value.first.return_value = None
|
||||
|
||||
result = store.is_user_synced(email, audience_id)
|
||||
|
||||
assert result is False
|
||||
|
||||
def test_is_user_synced_normalizes_email_to_lowercase(self, store, mock_session):
|
||||
"""Test that is_user_synced normalizes email to lowercase."""
|
||||
email = 'TEST@EXAMPLE.COM'
|
||||
audience_id = 'test-audience-123'
|
||||
|
||||
mock_session.execute.return_value.first.return_value = None
|
||||
|
||||
store.is_user_synced(email, audience_id)
|
||||
|
||||
# Verify the query was called (we can't easily check the exact SQL)
|
||||
mock_session.execute.assert_called_once()
|
||||
|
||||
def test_mark_user_synced_creates_new_record(self, store, mock_session):
|
||||
"""Test that mark_user_synced creates a new record."""
|
||||
email = 'test@example.com'
|
||||
audience_id = 'test-audience-123'
|
||||
keycloak_user_id = 'kc-user-123'
|
||||
|
||||
mock_synced_user = MagicMock(spec=ResendSyncedUser)
|
||||
mock_result = MagicMock()
|
||||
mock_result.first.return_value = (mock_synced_user,)
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
result = store.mark_user_synced(email, audience_id, keycloak_user_id)
|
||||
|
||||
assert result == mock_synced_user
|
||||
mock_session.execute.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
def test_mark_user_synced_handles_existing_record(self, store, mock_session):
|
||||
"""Test that mark_user_synced handles conflict (existing record)."""
|
||||
email = 'test@example.com'
|
||||
audience_id = 'test-audience-123'
|
||||
|
||||
# First execute (insert) returns None (conflict occurred)
|
||||
# Second execute (select existing) returns the record
|
||||
mock_existing_user = MagicMock(spec=ResendSyncedUser)
|
||||
mock_result_insert = MagicMock()
|
||||
mock_result_insert.first.return_value = None
|
||||
|
||||
mock_result_select = MagicMock()
|
||||
mock_result_select.first.return_value = (mock_existing_user,)
|
||||
|
||||
mock_session.execute.side_effect = [mock_result_insert, mock_result_select]
|
||||
|
||||
result = store.mark_user_synced(email, audience_id)
|
||||
|
||||
assert result == mock_existing_user
|
||||
assert mock_session.execute.call_count == 2
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
def test_mark_user_synced_normalizes_email_to_lowercase(self, store, mock_session):
|
||||
"""Test that mark_user_synced normalizes email to lowercase."""
|
||||
email = 'TEST@EXAMPLE.COM'
|
||||
audience_id = 'test-audience-123'
|
||||
|
||||
mock_synced_user = MagicMock(spec=ResendSyncedUser)
|
||||
mock_result = MagicMock()
|
||||
mock_result.first.return_value = (mock_synced_user,)
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
store.mark_user_synced(email, audience_id)
|
||||
|
||||
# Verify execute was called (the email normalization happens in the SQL)
|
||||
mock_session.execute.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
def test_mark_user_synced_without_keycloak_user_id(self, store, mock_session):
|
||||
"""Test that mark_user_synced works without keycloak_user_id."""
|
||||
email = 'test@example.com'
|
||||
audience_id = 'test-audience-123'
|
||||
|
||||
mock_synced_user = MagicMock(spec=ResendSyncedUser)
|
||||
mock_result = MagicMock()
|
||||
mock_result.first.return_value = (mock_synced_user,)
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
result = store.mark_user_synced(email, audience_id)
|
||||
|
||||
assert result == mock_synced_user
|
||||
mock_session.execute.assert_called_once()
|
||||
|
||||
|
||||
class TestResendSyncedUser:
|
||||
"""Test cases for ResendSyncedUser model."""
|
||||
|
||||
def test_model_has_required_fields(self):
|
||||
"""Test that the model has all required fields."""
|
||||
assert hasattr(ResendSyncedUser, 'id')
|
||||
assert hasattr(ResendSyncedUser, 'email')
|
||||
assert hasattr(ResendSyncedUser, 'audience_id')
|
||||
assert hasattr(ResendSyncedUser, 'synced_at')
|
||||
assert hasattr(ResendSyncedUser, 'keycloak_user_id')
|
||||
|
||||
def test_model_table_name(self):
|
||||
"""Test the model's table name."""
|
||||
assert ResendSyncedUser.__tablename__ == 'resend_synced_users'
|
||||
@@ -10,8 +10,12 @@ from unittest.mock import AsyncMock, MagicMock
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.pool import StaticPool
|
||||
from storage.base import Base
|
||||
from storage.org import Org
|
||||
from storage.user import User
|
||||
|
||||
from enterprise.server.utils.saas_app_conversation_info_injector import (
|
||||
SaasSQLAppConversationInfoService,
|
||||
@@ -20,10 +24,15 @@ from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
AppConversationInfo,
|
||||
)
|
||||
from openhands.app_server.user.specifiy_user_context import SpecifyUserContext
|
||||
from openhands.app_server.utils.sql_utils import Base
|
||||
from openhands.integrations.service_types import ProviderType
|
||||
from openhands.storage.data_models.conversation_metadata import ConversationTrigger
|
||||
|
||||
# Test UUIDs
|
||||
USER1_ID = UUID('a1111111-1111-1111-1111-111111111111')
|
||||
USER2_ID = UUID('b2222222-2222-2222-2222-222222222222')
|
||||
ORG1_ID = UUID('c1111111-1111-1111-1111-111111111111')
|
||||
ORG2_ID = UUID('d2222222-2222-2222-2222-222222222222')
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def async_engine():
|
||||
@@ -55,6 +64,41 @@ async def async_session(async_engine) -> AsyncGenerator[AsyncSession, None]:
|
||||
yield db_session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def async_session_with_users(async_engine) -> AsyncGenerator[AsyncSession, None]:
|
||||
"""Create an async session with pre-populated Org and User rows for testing."""
|
||||
async_session_maker = async_sessionmaker(
|
||||
async_engine, class_=AsyncSession, expire_on_commit=False
|
||||
)
|
||||
|
||||
async with async_session_maker() as db_session:
|
||||
# Insert Orgs first (required for User foreign key)
|
||||
org1 = Org(
|
||||
id=ORG1_ID,
|
||||
name='test-org-1',
|
||||
enable_default_condenser=True,
|
||||
enable_proactive_conversation_starters=True,
|
||||
)
|
||||
org2 = Org(
|
||||
id=ORG2_ID,
|
||||
name='test-org-2',
|
||||
enable_default_condenser=True,
|
||||
enable_proactive_conversation_starters=True,
|
||||
)
|
||||
db_session.add(org1)
|
||||
db_session.add(org2)
|
||||
await db_session.flush()
|
||||
|
||||
# Insert Users
|
||||
user1 = User(id=USER1_ID, current_org_id=ORG1_ID)
|
||||
user2 = User(id=USER2_ID, current_org_id=ORG2_ID)
|
||||
db_session.add(user1)
|
||||
db_session.add(user2)
|
||||
await db_session.commit()
|
||||
|
||||
yield db_session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def service(async_session) -> SaasSQLAppConversationInfoService:
|
||||
"""Create a SQLAppConversationInfoService instance for testing."""
|
||||
@@ -178,15 +222,26 @@ class TestSaasSQLAppConversationInfoService:
|
||||
assert user1_id != user2_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_secure_select_includes_user_filtering(
|
||||
async def test_secure_select_includes_user_and_org_filtering(
|
||||
self,
|
||||
saas_service_user1: SaasSQLAppConversationInfoService,
|
||||
async_session_with_users: AsyncSession,
|
||||
):
|
||||
"""Test that _secure_select method includes user filtering."""
|
||||
# This test verifies that the _secure_select method exists and can be called
|
||||
# The actual SQL generation is tested implicitly through integration
|
||||
query = await saas_service_user1._secure_select()
|
||||
assert query is not None
|
||||
"""Test that _secure_select method includes both user_id and org_id filtering."""
|
||||
service = SaasSQLAppConversationInfoService(
|
||||
db_session=async_session_with_users,
|
||||
user_context=SpecifyUserContext(user_id=str(USER1_ID)),
|
||||
)
|
||||
|
||||
query = await service._secure_select()
|
||||
|
||||
# Convert query to string to verify filters are present
|
||||
query_str = str(query.compile(compile_kwargs={'literal_binds': True}))
|
||||
|
||||
# Verify user_id filter is present
|
||||
assert str(USER1_ID) in query_str or str(USER1_ID).replace('-', '') in query_str
|
||||
|
||||
# Verify org_id filter is present (user1 is in org1)
|
||||
assert str(ORG1_ID) in query_str or str(ORG1_ID).replace('-', '') in query_str
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_to_info_with_user_id_functionality(
|
||||
@@ -241,100 +296,32 @@ class TestSaasSQLAppConversationInfoService:
|
||||
assert result.sandbox_id == 'test-sandbox'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_isolation(
|
||||
async def test_user_isolation_different_users(
|
||||
self,
|
||||
async_session: AsyncSession,
|
||||
multiple_conversation_infos: list[AppConversationInfo],
|
||||
async_session_with_users: AsyncSession,
|
||||
):
|
||||
"""Test that user isolation works correctly."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from storage.user import User
|
||||
|
||||
# Mock the database session execute method to return mock users
|
||||
# This mock intercepts User queries and returns a mock user object
|
||||
# with user_id and org_id the same as the user_id_uuid from the query
|
||||
original_execute = async_session.execute
|
||||
|
||||
async def mock_execute(query):
|
||||
query_str = str(query)
|
||||
|
||||
# Check if this is a User query
|
||||
if '"user"' in query_str.lower() and '"user".id' in query_str.lower():
|
||||
# Extract the UUID from the query parameters
|
||||
# The query will have bound parameters, we need to get the UUID value
|
||||
if hasattr(query, 'compile'):
|
||||
try:
|
||||
compiled = query.compile(compile_kwargs={'literal_binds': True})
|
||||
query_with_params = str(compiled)
|
||||
|
||||
# Extract UUID from the query string
|
||||
import re
|
||||
|
||||
# Try both formats: with dashes and without dashes
|
||||
uuid_pattern_with_dashes = r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}'
|
||||
uuid_pattern_without_dashes = r'[a-f0-9]{32}'
|
||||
|
||||
uuid_match = re.search(
|
||||
uuid_pattern_with_dashes, query_with_params
|
||||
)
|
||||
if not uuid_match:
|
||||
uuid_match = re.search(
|
||||
uuid_pattern_without_dashes, query_with_params
|
||||
)
|
||||
|
||||
if uuid_match:
|
||||
user_id_str = uuid_match.group(0)
|
||||
# If the UUID doesn't have dashes, add them
|
||||
if len(user_id_str) == 32 and '-' not in user_id_str:
|
||||
# Convert from 'a1111111111111111111111111111111' to 'a1111111-1111-1111-1111-111111111111'
|
||||
user_id_str = f'{user_id_str[:8]}-{user_id_str[8:12]}-{user_id_str[12:16]}-{user_id_str[16:20]}-{user_id_str[20:]}'
|
||||
user_id_uuid = UUID(user_id_str)
|
||||
|
||||
# Create a mock user with user_id and org_id the same as user_id_uuid
|
||||
mock_user = MagicMock(spec=User)
|
||||
mock_user.id = user_id_uuid
|
||||
mock_user.current_org_id = user_id_uuid
|
||||
|
||||
# Create a mock result
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalar_one_or_none.return_value = mock_user
|
||||
return mock_result
|
||||
except Exception:
|
||||
# If there's any error in parsing, fall back to original execute
|
||||
pass
|
||||
|
||||
# For all other queries, use the original execute method
|
||||
return await original_execute(query)
|
||||
|
||||
# Apply the mock
|
||||
async_session.execute = mock_execute
|
||||
|
||||
"""Test that different users cannot see each other's conversations."""
|
||||
# Create services for different users
|
||||
user1_service = SaasSQLAppConversationInfoService(
|
||||
db_session=async_session,
|
||||
user_context=SpecifyUserContext(
|
||||
user_id='a1111111-1111-1111-1111-111111111111'
|
||||
),
|
||||
db_session=async_session_with_users,
|
||||
user_context=SpecifyUserContext(user_id=str(USER1_ID)),
|
||||
)
|
||||
user2_service = SaasSQLAppConversationInfoService(
|
||||
db_session=async_session,
|
||||
user_context=SpecifyUserContext(
|
||||
user_id='b2222222-2222-2222-2222-222222222222'
|
||||
),
|
||||
db_session=async_session_with_users,
|
||||
user_context=SpecifyUserContext(user_id=str(USER2_ID)),
|
||||
)
|
||||
|
||||
# Create conversations for different users
|
||||
user1_info = AppConversationInfo(
|
||||
id=uuid4(),
|
||||
created_by_user_id='a1111111-1111-1111-1111-111111111111',
|
||||
created_by_user_id=str(USER1_ID),
|
||||
sandbox_id='sandbox_user1',
|
||||
title='User 1 Conversation',
|
||||
)
|
||||
|
||||
user2_info = AppConversationInfo(
|
||||
id=uuid4(),
|
||||
created_by_user_id='b2222222-2222-2222-2222-222222222222',
|
||||
created_by_user_id=str(USER2_ID),
|
||||
sandbox_id='sandbox_user2',
|
||||
title='User 2 Conversation',
|
||||
)
|
||||
@@ -346,18 +333,12 @@ class TestSaasSQLAppConversationInfoService:
|
||||
# User 1 should only see their conversation
|
||||
user1_page = await user1_service.search_app_conversation_info()
|
||||
assert len(user1_page.items) == 1
|
||||
assert (
|
||||
user1_page.items[0].created_by_user_id
|
||||
== 'a1111111-1111-1111-1111-111111111111'
|
||||
)
|
||||
assert user1_page.items[0].created_by_user_id == str(USER1_ID)
|
||||
|
||||
# User 2 should only see their conversation
|
||||
user2_page = await user2_service.search_app_conversation_info()
|
||||
assert len(user2_page.items) == 1
|
||||
assert (
|
||||
user2_page.items[0].created_by_user_id
|
||||
== 'b2222222-2222-2222-2222-222222222222'
|
||||
)
|
||||
assert user2_page.items[0].created_by_user_id == str(USER2_ID)
|
||||
|
||||
# User 1 should not be able to get user 2's conversation
|
||||
user2_from_user1 = await user1_service.get_app_conversation_info(user2_info.id)
|
||||
@@ -366,3 +347,142 @@ class TestSaasSQLAppConversationInfoService:
|
||||
# User 2 should not be able to get user 1's conversation
|
||||
user1_from_user2 = await user2_service.get_app_conversation_info(user1_info.id)
|
||||
assert user1_from_user2 is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_same_user_org_switching_isolation(
|
||||
self,
|
||||
async_session_with_users: AsyncSession,
|
||||
):
|
||||
"""Test that the same user switching orgs cannot see conversations from other orgs.
|
||||
|
||||
This tests the actual bug scenario: a user creates a conversation in org1,
|
||||
then switches to org2, and should NOT see org1's conversations.
|
||||
"""
|
||||
# Create service for user1 in org1
|
||||
user1_service_org1 = SaasSQLAppConversationInfoService(
|
||||
db_session=async_session_with_users,
|
||||
user_context=SpecifyUserContext(user_id=str(USER1_ID)),
|
||||
)
|
||||
|
||||
# Create a conversation while user is in org1
|
||||
conv_in_org1 = AppConversationInfo(
|
||||
id=uuid4(),
|
||||
created_by_user_id=str(USER1_ID),
|
||||
sandbox_id='sandbox_org1',
|
||||
title='Conversation in Org 1',
|
||||
)
|
||||
await user1_service_org1.save_app_conversation_info(conv_in_org1)
|
||||
|
||||
# Verify user can see the conversation in org1
|
||||
page_in_org1 = await user1_service_org1.search_app_conversation_info()
|
||||
assert len(page_in_org1.items) == 1
|
||||
assert page_in_org1.items[0].title == 'Conversation in Org 1'
|
||||
|
||||
# Simulate user switching to org2 by updating current_org_id using ORM
|
||||
result = await async_session_with_users.execute(
|
||||
select(User).where(User.id == USER1_ID)
|
||||
)
|
||||
user_to_update = result.scalars().first()
|
||||
user_to_update.current_org_id = ORG2_ID
|
||||
await async_session_with_users.commit()
|
||||
# Clear SQLAlchemy's identity map cache to simulate a new request
|
||||
async_session_with_users.expire_all()
|
||||
|
||||
# Create new service instance (simulating a new request after org switch)
|
||||
user1_service_org2 = SaasSQLAppConversationInfoService(
|
||||
db_session=async_session_with_users,
|
||||
user_context=SpecifyUserContext(user_id=str(USER1_ID)),
|
||||
)
|
||||
|
||||
# User should NOT see org1's conversations after switching to org2
|
||||
page_in_org2 = await user1_service_org2.search_app_conversation_info()
|
||||
assert (
|
||||
len(page_in_org2.items) == 0
|
||||
), 'User should not see conversations from org1 after switching to org2'
|
||||
|
||||
# User should not be able to get the specific conversation from org1
|
||||
conv_from_org2 = await user1_service_org2.get_app_conversation_info(
|
||||
conv_in_org1.id
|
||||
)
|
||||
assert (
|
||||
conv_from_org2 is None
|
||||
), 'User should not be able to access org1 conversation from org2'
|
||||
|
||||
# Now create a conversation in org2
|
||||
conv_in_org2 = AppConversationInfo(
|
||||
id=uuid4(),
|
||||
created_by_user_id=str(USER1_ID),
|
||||
sandbox_id='sandbox_org2',
|
||||
title='Conversation in Org 2',
|
||||
)
|
||||
await user1_service_org2.save_app_conversation_info(conv_in_org2)
|
||||
|
||||
# User should only see org2's conversation
|
||||
page_in_org2_after = await user1_service_org2.search_app_conversation_info()
|
||||
assert len(page_in_org2_after.items) == 1
|
||||
assert page_in_org2_after.items[0].title == 'Conversation in Org 2'
|
||||
|
||||
# Switch back to org1 and verify isolation works both ways
|
||||
result = await async_session_with_users.execute(
|
||||
select(User).where(User.id == USER1_ID)
|
||||
)
|
||||
user_to_update = result.scalars().first()
|
||||
user_to_update.current_org_id = ORG1_ID
|
||||
await async_session_with_users.commit()
|
||||
async_session_with_users.expire_all()
|
||||
|
||||
user1_service_back_to_org1 = SaasSQLAppConversationInfoService(
|
||||
db_session=async_session_with_users,
|
||||
user_context=SpecifyUserContext(user_id=str(USER1_ID)),
|
||||
)
|
||||
|
||||
# User should only see org1's conversation now
|
||||
page_back_in_org1 = (
|
||||
await user1_service_back_to_org1.search_app_conversation_info()
|
||||
)
|
||||
assert len(page_back_in_org1.items) == 1
|
||||
assert page_back_in_org1.items[0].title == 'Conversation in Org 1'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_count_respects_org_isolation(
|
||||
self,
|
||||
async_session_with_users: AsyncSession,
|
||||
):
|
||||
"""Test that count_app_conversation_info respects org isolation."""
|
||||
# Create service for user1 in org1
|
||||
user1_service = SaasSQLAppConversationInfoService(
|
||||
db_session=async_session_with_users,
|
||||
user_context=SpecifyUserContext(user_id=str(USER1_ID)),
|
||||
)
|
||||
|
||||
# Create conversations in org1
|
||||
for i in range(3):
|
||||
conv = AppConversationInfo(
|
||||
id=uuid4(),
|
||||
created_by_user_id=str(USER1_ID),
|
||||
sandbox_id=f'sandbox_org1_{i}',
|
||||
title=f'Org1 Conversation {i}',
|
||||
)
|
||||
await user1_service.save_app_conversation_info(conv)
|
||||
|
||||
# Count should be 3
|
||||
count_org1 = await user1_service.count_app_conversation_info()
|
||||
assert count_org1 == 3
|
||||
|
||||
# Switch to org2 using ORM
|
||||
result = await async_session_with_users.execute(
|
||||
select(User).where(User.id == USER1_ID)
|
||||
)
|
||||
user_to_update = result.scalars().first()
|
||||
user_to_update.current_org_id = ORG2_ID
|
||||
await async_session_with_users.commit()
|
||||
async_session_with_users.expire_all()
|
||||
|
||||
user1_service_org2 = SaasSQLAppConversationInfoService(
|
||||
db_session=async_session_with_users,
|
||||
user_context=SpecifyUserContext(user_id=str(USER1_ID)),
|
||||
)
|
||||
|
||||
# Count should be 0 in org2
|
||||
count_org2 = await user1_service_org2.count_app_conversation_info()
|
||||
assert count_org2 == 0
|
||||
|
||||
117
enterprise/tests/unit/sync/test_resend_keycloak.py
Normal file
117
enterprise/tests/unit/sync/test_resend_keycloak.py
Normal file
@@ -0,0 +1,117 @@
|
||||
"""Tests for resend_keycloak email validation."""
|
||||
|
||||
from sync.resend_keycloak import is_valid_email
|
||||
|
||||
|
||||
class TestIsValidEmail:
|
||||
"""Test cases for is_valid_email function."""
|
||||
|
||||
def test_valid_simple_email(self):
|
||||
"""Test that a simple valid email passes validation."""
|
||||
assert is_valid_email('user@example.com') is True
|
||||
|
||||
def test_valid_email_with_plus(self):
|
||||
"""Test that email with + modifier passes validation."""
|
||||
assert is_valid_email('user+tag@example.com') is True
|
||||
|
||||
def test_valid_email_with_dots(self):
|
||||
"""Test that email with dots in local part passes validation."""
|
||||
assert is_valid_email('first.last@example.com') is True
|
||||
|
||||
def test_valid_email_with_numbers(self):
|
||||
"""Test that email with numbers passes validation."""
|
||||
assert is_valid_email('user123@example.com') is True
|
||||
|
||||
def test_valid_email_with_subdomain(self):
|
||||
"""Test that email with subdomain passes validation."""
|
||||
assert is_valid_email('user@mail.example.com') is True
|
||||
|
||||
def test_valid_email_with_hyphen_domain(self):
|
||||
"""Test that email with hyphen in domain passes validation."""
|
||||
assert is_valid_email('user@example-site.com') is True
|
||||
|
||||
def test_valid_email_with_underscore(self):
|
||||
"""Test that email with underscore passes validation."""
|
||||
assert is_valid_email('user_name@example.com') is True
|
||||
|
||||
def test_valid_email_with_percent(self):
|
||||
"""Test that email with percent sign passes validation."""
|
||||
assert is_valid_email('user%name@example.com') is True
|
||||
|
||||
def test_invalid_email_with_exclamation(self):
|
||||
"""Test that email with exclamation mark fails validation.
|
||||
|
||||
This is the specific case from the bug report:
|
||||
ethanjames3713+!@gmail.com
|
||||
"""
|
||||
assert is_valid_email('ethanjames3713+!@gmail.com') is False
|
||||
|
||||
def test_invalid_email_with_special_chars(self):
|
||||
"""Test that email with other special characters fails validation."""
|
||||
assert is_valid_email('user!name@example.com') is False
|
||||
assert is_valid_email('user#name@example.com') is False
|
||||
assert is_valid_email('user$name@example.com') is False
|
||||
assert is_valid_email('user&name@example.com') is False
|
||||
assert is_valid_email("user'name@example.com") is False
|
||||
assert is_valid_email('user*name@example.com') is False
|
||||
assert is_valid_email('user=name@example.com') is False
|
||||
assert is_valid_email('user^name@example.com') is False
|
||||
assert is_valid_email('user`name@example.com') is False
|
||||
assert is_valid_email('user{name@example.com') is False
|
||||
assert is_valid_email('user|name@example.com') is False
|
||||
assert is_valid_email('user}name@example.com') is False
|
||||
assert is_valid_email('user~name@example.com') is False
|
||||
|
||||
def test_invalid_email_no_at_symbol(self):
|
||||
"""Test that email without @ symbol fails validation."""
|
||||
assert is_valid_email('userexample.com') is False
|
||||
|
||||
def test_invalid_email_no_domain(self):
|
||||
"""Test that email without domain fails validation."""
|
||||
assert is_valid_email('user@') is False
|
||||
|
||||
def test_invalid_email_no_local_part(self):
|
||||
"""Test that email without local part fails validation."""
|
||||
assert is_valid_email('@example.com') is False
|
||||
|
||||
def test_invalid_email_no_tld(self):
|
||||
"""Test that email without TLD fails validation."""
|
||||
assert is_valid_email('user@example') is False
|
||||
|
||||
def test_invalid_email_single_char_tld(self):
|
||||
"""Test that email with single character TLD fails validation."""
|
||||
assert is_valid_email('user@example.c') is False
|
||||
|
||||
def test_invalid_email_empty_string(self):
|
||||
"""Test that empty string fails validation."""
|
||||
assert is_valid_email('') is False
|
||||
|
||||
def test_invalid_email_none(self):
|
||||
"""Test that None fails validation."""
|
||||
assert is_valid_email(None) is False
|
||||
|
||||
def test_invalid_email_whitespace(self):
|
||||
"""Test that email with whitespace fails validation."""
|
||||
assert is_valid_email('user @example.com') is False
|
||||
assert is_valid_email('user@ example.com') is False
|
||||
assert is_valid_email(' user@example.com') is False
|
||||
assert is_valid_email('user@example.com ') is False
|
||||
|
||||
def test_invalid_email_double_at(self):
|
||||
"""Test that email with double @ fails validation."""
|
||||
assert is_valid_email('user@@example.com') is False
|
||||
|
||||
def test_email_double_dot_domain(self):
|
||||
"""Test email with double dot in domain.
|
||||
|
||||
Note: The regex allows this as it's technically valid in some edge cases,
|
||||
and Resend's API may accept it. The main goal is to reject special
|
||||
characters like ! that Resend definitely rejects.
|
||||
"""
|
||||
# This is allowed by our regex - Resend may or may not accept it
|
||||
assert is_valid_email('user@example..com') is True
|
||||
|
||||
def test_case_insensitive_validation(self):
|
||||
"""Test that validation works for uppercase emails."""
|
||||
assert is_valid_email('USER@EXAMPLE.COM') is True
|
||||
assert is_valid_email('User@Example.Com') is True
|
||||
@@ -265,17 +265,17 @@ async def test_list_api_keys(
|
||||
# Verify
|
||||
mock_get_user.assert_called_once_with(user_id)
|
||||
assert len(result) == 2
|
||||
assert result[0]['id'] == 1
|
||||
assert result[0]['name'] == 'Key 1'
|
||||
assert result[0]['created_at'] == now
|
||||
assert result[0]['last_used_at'] == now
|
||||
assert result[0]['expires_at'] == now + timedelta(days=30)
|
||||
assert result[0].id == 1
|
||||
assert result[0].name == 'Key 1'
|
||||
assert result[0].created_at == now
|
||||
assert result[0].last_used_at == now
|
||||
assert result[0].expires_at == now + timedelta(days=30)
|
||||
|
||||
assert result[1]['id'] == 2
|
||||
assert result[1]['name'] == 'Key 2'
|
||||
assert result[1]['created_at'] == now
|
||||
assert result[1]['last_used_at'] is None
|
||||
assert result[1]['expires_at'] is None
|
||||
assert result[1].id == 2
|
||||
assert result[1].name == 'Key 2'
|
||||
assert result[1].created_at == now
|
||||
assert result[1].last_used_at is None
|
||||
assert result[1].expires_at is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
181
enterprise/tests/unit/test_auth_invitation_callback.py
Normal file
181
enterprise/tests/unit/test_auth_invitation_callback.py
Normal file
@@ -0,0 +1,181 @@
|
||||
"""Tests for auth callback invitation acceptance - EmailMismatchError handling."""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestAuthCallbackInvitationEmailMismatch:
|
||||
"""Test cases for EmailMismatchError handling during auth callback."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_redirect_url(self):
|
||||
"""Base redirect URL."""
|
||||
return 'https://app.example.com/'
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user_id(self):
|
||||
"""Mock user ID."""
|
||||
return '87654321-4321-8765-4321-876543218765'
|
||||
|
||||
def test_email_mismatch_appends_to_url_without_query_params(
|
||||
self, mock_redirect_url, mock_user_id
|
||||
):
|
||||
"""Test that email_mismatch=true is appended correctly when URL has no query params."""
|
||||
from server.routes.org_invitation_models import EmailMismatchError
|
||||
|
||||
# Simulate the logic from auth.py
|
||||
redirect_url = mock_redirect_url
|
||||
try:
|
||||
raise EmailMismatchError('Your email does not match the invitation')
|
||||
except EmailMismatchError:
|
||||
if '?' in redirect_url:
|
||||
redirect_url = f'{redirect_url}&email_mismatch=true'
|
||||
else:
|
||||
redirect_url = f'{redirect_url}?email_mismatch=true'
|
||||
|
||||
assert redirect_url == 'https://app.example.com/?email_mismatch=true'
|
||||
|
||||
def test_email_mismatch_appends_to_url_with_query_params(self, mock_user_id):
|
||||
"""Test that email_mismatch=true is appended correctly when URL has existing query params."""
|
||||
from server.routes.org_invitation_models import EmailMismatchError
|
||||
|
||||
redirect_url = 'https://app.example.com/?other_param=value'
|
||||
try:
|
||||
raise EmailMismatchError()
|
||||
except EmailMismatchError:
|
||||
if '?' in redirect_url:
|
||||
redirect_url = f'{redirect_url}&email_mismatch=true'
|
||||
else:
|
||||
redirect_url = f'{redirect_url}?email_mismatch=true'
|
||||
|
||||
assert (
|
||||
redirect_url
|
||||
== 'https://app.example.com/?other_param=value&email_mismatch=true'
|
||||
)
|
||||
|
||||
def test_email_mismatch_error_has_default_message(self):
|
||||
"""Test that EmailMismatchError has the default message."""
|
||||
from server.routes.org_invitation_models import EmailMismatchError
|
||||
|
||||
error = EmailMismatchError()
|
||||
assert str(error) == 'Your email does not match the invitation'
|
||||
|
||||
def test_email_mismatch_error_accepts_custom_message(self):
|
||||
"""Test that EmailMismatchError accepts a custom message."""
|
||||
from server.routes.org_invitation_models import EmailMismatchError
|
||||
|
||||
custom_message = 'Custom error message'
|
||||
error = EmailMismatchError(custom_message)
|
||||
assert str(error) == custom_message
|
||||
|
||||
def test_email_mismatch_error_is_invitation_error(self):
|
||||
"""Test that EmailMismatchError inherits from InvitationError."""
|
||||
from server.routes.org_invitation_models import (
|
||||
EmailMismatchError,
|
||||
InvitationError,
|
||||
)
|
||||
|
||||
error = EmailMismatchError()
|
||||
assert isinstance(error, InvitationError)
|
||||
|
||||
|
||||
class TestInvitationTokenInOAuthState:
|
||||
"""Test cases for invitation token handling in OAuth state."""
|
||||
|
||||
def test_invitation_token_included_in_oauth_state(self):
|
||||
"""Test that invitation token is included in OAuth state data."""
|
||||
import base64
|
||||
import json
|
||||
|
||||
# Simulate building OAuth state with invitation token
|
||||
state_data = {
|
||||
'redirect_url': 'https://app.example.com/',
|
||||
'invitation_token': 'inv-test-token-12345',
|
||||
}
|
||||
|
||||
encoded_state = base64.b64encode(json.dumps(state_data).encode()).decode()
|
||||
decoded_data = json.loads(base64.b64decode(encoded_state))
|
||||
|
||||
assert decoded_data['invitation_token'] == 'inv-test-token-12345'
|
||||
assert decoded_data['redirect_url'] == 'https://app.example.com/'
|
||||
|
||||
def test_invitation_token_extracted_from_oauth_state(self):
|
||||
"""Test that invitation token can be extracted from OAuth state."""
|
||||
import base64
|
||||
import json
|
||||
|
||||
state_data = {
|
||||
'redirect_url': 'https://app.example.com/',
|
||||
'invitation_token': 'inv-test-token-12345',
|
||||
}
|
||||
|
||||
encoded_state = base64.b64encode(json.dumps(state_data).encode()).decode()
|
||||
|
||||
# Simulate decoding in callback
|
||||
decoded_state = json.loads(base64.b64decode(encoded_state))
|
||||
invitation_token = decoded_state.get('invitation_token')
|
||||
|
||||
assert invitation_token == 'inv-test-token-12345'
|
||||
|
||||
def test_oauth_state_without_invitation_token(self):
|
||||
"""Test that OAuth state works without invitation token."""
|
||||
import base64
|
||||
import json
|
||||
|
||||
state_data = {
|
||||
'redirect_url': 'https://app.example.com/',
|
||||
}
|
||||
|
||||
encoded_state = base64.b64encode(json.dumps(state_data).encode()).decode()
|
||||
decoded_data = json.loads(base64.b64decode(encoded_state))
|
||||
|
||||
assert 'invitation_token' not in decoded_data
|
||||
assert decoded_data['redirect_url'] == 'https://app.example.com/'
|
||||
|
||||
|
||||
class TestAuthCallbackInvitationErrors:
|
||||
"""Test cases for various invitation error scenarios in auth callback."""
|
||||
|
||||
def test_invitation_expired_appends_flag(self):
|
||||
"""Test that invitation_expired=true is appended for expired invitations."""
|
||||
from server.routes.org_invitation_models import InvitationExpiredError
|
||||
|
||||
redirect_url = 'https://app.example.com/'
|
||||
try:
|
||||
raise InvitationExpiredError()
|
||||
except InvitationExpiredError:
|
||||
if '?' in redirect_url:
|
||||
redirect_url = f'{redirect_url}&invitation_expired=true'
|
||||
else:
|
||||
redirect_url = f'{redirect_url}?invitation_expired=true'
|
||||
|
||||
assert redirect_url == 'https://app.example.com/?invitation_expired=true'
|
||||
|
||||
def test_invitation_invalid_appends_flag(self):
|
||||
"""Test that invitation_invalid=true is appended for invalid invitations."""
|
||||
from server.routes.org_invitation_models import InvitationInvalidError
|
||||
|
||||
redirect_url = 'https://app.example.com/'
|
||||
try:
|
||||
raise InvitationInvalidError()
|
||||
except InvitationInvalidError:
|
||||
if '?' in redirect_url:
|
||||
redirect_url = f'{redirect_url}&invitation_invalid=true'
|
||||
else:
|
||||
redirect_url = f'{redirect_url}?invitation_invalid=true'
|
||||
|
||||
assert redirect_url == 'https://app.example.com/?invitation_invalid=true'
|
||||
|
||||
def test_already_member_appends_flag(self):
|
||||
"""Test that already_member=true is appended when user is already a member."""
|
||||
from server.routes.org_invitation_models import UserAlreadyMemberError
|
||||
|
||||
redirect_url = 'https://app.example.com/'
|
||||
try:
|
||||
raise UserAlreadyMemberError()
|
||||
except UserAlreadyMemberError:
|
||||
if '?' in redirect_url:
|
||||
redirect_url = f'{redirect_url}&already_member=true'
|
||||
else:
|
||||
redirect_url = f'{redirect_url}?already_member=true'
|
||||
|
||||
assert redirect_url == 'https://app.example.com/?already_member=true'
|
||||
756
enterprise/tests/unit/test_authorization.py
Normal file
756
enterprise/tests/unit/test_authorization.py
Normal file
@@ -0,0 +1,756 @@
|
||||
"""
|
||||
Unit tests for permission-based authorization (authorization.py).
|
||||
|
||||
Tests the FastAPI dependencies that validate user permissions within organizations.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
from server.auth.authorization import (
|
||||
ROLE_PERMISSIONS,
|
||||
Permission,
|
||||
RoleName,
|
||||
get_role_permissions,
|
||||
get_user_org_role,
|
||||
has_permission,
|
||||
require_permission,
|
||||
)
|
||||
|
||||
# =============================================================================
|
||||
# Tests for Permission enum
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestPermission:
|
||||
"""Tests for Permission enum."""
|
||||
|
||||
def test_permission_values(self):
|
||||
"""
|
||||
GIVEN: Permission enum
|
||||
WHEN: Accessing permission values
|
||||
THEN: All expected permissions exist with correct string values
|
||||
"""
|
||||
assert Permission.MANAGE_SECRETS.value == 'manage_secrets'
|
||||
assert Permission.MANAGE_MCP.value == 'manage_mcp'
|
||||
assert Permission.MANAGE_INTEGRATIONS.value == 'manage_integrations'
|
||||
assert (
|
||||
Permission.MANAGE_APPLICATION_SETTINGS.value
|
||||
== 'manage_application_settings'
|
||||
)
|
||||
assert Permission.MANAGE_API_KEYS.value == 'manage_api_keys'
|
||||
assert Permission.VIEW_LLM_SETTINGS.value == 'view_llm_settings'
|
||||
assert Permission.EDIT_LLM_SETTINGS.value == 'edit_llm_settings'
|
||||
assert Permission.VIEW_BILLING.value == 'view_billing'
|
||||
assert Permission.ADD_CREDITS.value == 'add_credits'
|
||||
assert (
|
||||
Permission.INVITE_USER_TO_ORGANIZATION.value
|
||||
== 'invite_user_to_organization'
|
||||
)
|
||||
assert Permission.CHANGE_USER_ROLE_MEMBER.value == 'change_user_role:member'
|
||||
assert Permission.CHANGE_USER_ROLE_ADMIN.value == 'change_user_role:admin'
|
||||
assert Permission.CHANGE_USER_ROLE_OWNER.value == 'change_user_role:owner'
|
||||
assert Permission.VIEW_ORG_SETTINGS.value == 'view_org_settings'
|
||||
assert Permission.CHANGE_ORGANIZATION_NAME.value == 'change_organization_name'
|
||||
assert Permission.DELETE_ORGANIZATION.value == 'delete_organization'
|
||||
|
||||
def test_permission_from_string(self):
|
||||
"""
|
||||
GIVEN: Valid permission string
|
||||
WHEN: Creating Permission from string
|
||||
THEN: Correct enum value is returned
|
||||
"""
|
||||
assert Permission('manage_secrets') == Permission.MANAGE_SECRETS
|
||||
assert Permission('view_llm_settings') == Permission.VIEW_LLM_SETTINGS
|
||||
assert Permission('delete_organization') == Permission.DELETE_ORGANIZATION
|
||||
|
||||
def test_permission_invalid_string(self):
|
||||
"""
|
||||
GIVEN: Invalid permission string
|
||||
WHEN: Creating Permission from string
|
||||
THEN: ValueError is raised
|
||||
"""
|
||||
with pytest.raises(ValueError):
|
||||
Permission('invalid_permission')
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests for RoleName enum
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestRoleName:
|
||||
"""Tests for RoleName enum."""
|
||||
|
||||
def test_role_name_values(self):
|
||||
"""
|
||||
GIVEN: RoleName enum
|
||||
WHEN: Accessing role name values
|
||||
THEN: All expected roles exist with correct string values
|
||||
"""
|
||||
assert RoleName.OWNER.value == 'owner'
|
||||
assert RoleName.ADMIN.value == 'admin'
|
||||
assert RoleName.MEMBER.value == 'member'
|
||||
|
||||
def test_role_name_from_string(self):
|
||||
"""
|
||||
GIVEN: Valid role name string
|
||||
WHEN: Creating RoleName from string
|
||||
THEN: Correct enum value is returned
|
||||
"""
|
||||
assert RoleName('owner') == RoleName.OWNER
|
||||
assert RoleName('admin') == RoleName.ADMIN
|
||||
assert RoleName('member') == RoleName.MEMBER
|
||||
|
||||
def test_role_name_invalid_string(self):
|
||||
"""
|
||||
GIVEN: Invalid role name string
|
||||
WHEN: Creating RoleName from string
|
||||
THEN: ValueError is raised
|
||||
"""
|
||||
with pytest.raises(ValueError):
|
||||
RoleName('invalid_role')
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests for ROLE_PERMISSIONS mapping
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestRolePermissions:
|
||||
"""Tests for role permission mappings."""
|
||||
|
||||
def test_owner_has_all_permissions(self):
|
||||
"""
|
||||
GIVEN: ROLE_PERMISSIONS mapping
|
||||
WHEN: Checking owner permissions
|
||||
THEN: Owner has all permissions including owner-only permissions
|
||||
"""
|
||||
owner_perms = ROLE_PERMISSIONS[RoleName.OWNER]
|
||||
assert Permission.MANAGE_SECRETS in owner_perms
|
||||
assert Permission.MANAGE_MCP in owner_perms
|
||||
assert Permission.VIEW_LLM_SETTINGS in owner_perms
|
||||
assert Permission.EDIT_LLM_SETTINGS in owner_perms
|
||||
assert Permission.VIEW_BILLING in owner_perms
|
||||
assert Permission.ADD_CREDITS in owner_perms
|
||||
assert Permission.INVITE_USER_TO_ORGANIZATION in owner_perms
|
||||
assert Permission.CHANGE_USER_ROLE_MEMBER in owner_perms
|
||||
assert Permission.CHANGE_USER_ROLE_ADMIN in owner_perms
|
||||
assert Permission.CHANGE_USER_ROLE_OWNER in owner_perms
|
||||
assert Permission.CHANGE_ORGANIZATION_NAME in owner_perms
|
||||
assert Permission.DELETE_ORGANIZATION in owner_perms
|
||||
|
||||
def test_admin_has_admin_permissions(self):
|
||||
"""
|
||||
GIVEN: ROLE_PERMISSIONS mapping
|
||||
WHEN: Checking admin permissions
|
||||
THEN: Admin has admin permissions but not owner-only permissions
|
||||
"""
|
||||
admin_perms = ROLE_PERMISSIONS[RoleName.ADMIN]
|
||||
assert Permission.MANAGE_SECRETS in admin_perms
|
||||
assert Permission.MANAGE_MCP in admin_perms
|
||||
assert Permission.VIEW_LLM_SETTINGS in admin_perms
|
||||
assert Permission.EDIT_LLM_SETTINGS in admin_perms
|
||||
assert Permission.VIEW_BILLING in admin_perms
|
||||
assert Permission.ADD_CREDITS in admin_perms
|
||||
assert Permission.INVITE_USER_TO_ORGANIZATION in admin_perms
|
||||
assert Permission.CHANGE_USER_ROLE_MEMBER in admin_perms
|
||||
assert Permission.CHANGE_USER_ROLE_ADMIN in admin_perms
|
||||
# Admin should NOT have owner-only permissions
|
||||
assert Permission.CHANGE_USER_ROLE_OWNER not in admin_perms
|
||||
assert Permission.CHANGE_ORGANIZATION_NAME not in admin_perms
|
||||
assert Permission.DELETE_ORGANIZATION not in admin_perms
|
||||
|
||||
def test_member_has_limited_permissions(self):
|
||||
"""
|
||||
GIVEN: ROLE_PERMISSIONS mapping
|
||||
WHEN: Checking member permissions
|
||||
THEN: Member has limited permissions
|
||||
"""
|
||||
member_perms = ROLE_PERMISSIONS[RoleName.MEMBER]
|
||||
# Member has basic settings permissions
|
||||
assert Permission.MANAGE_SECRETS in member_perms
|
||||
assert Permission.MANAGE_MCP in member_perms
|
||||
assert Permission.MANAGE_INTEGRATIONS in member_perms
|
||||
assert Permission.MANAGE_APPLICATION_SETTINGS in member_perms
|
||||
assert Permission.MANAGE_API_KEYS in member_perms
|
||||
assert Permission.VIEW_LLM_SETTINGS in member_perms
|
||||
assert Permission.VIEW_ORG_SETTINGS in member_perms
|
||||
# Member should NOT have admin/owner permissions
|
||||
assert Permission.EDIT_LLM_SETTINGS not in member_perms
|
||||
assert Permission.VIEW_BILLING not in member_perms
|
||||
assert Permission.ADD_CREDITS not in member_perms
|
||||
assert Permission.INVITE_USER_TO_ORGANIZATION not in member_perms
|
||||
assert Permission.CHANGE_USER_ROLE_MEMBER not in member_perms
|
||||
assert Permission.CHANGE_USER_ROLE_ADMIN not in member_perms
|
||||
assert Permission.CHANGE_USER_ROLE_OWNER not in member_perms
|
||||
assert Permission.CHANGE_ORGANIZATION_NAME not in member_perms
|
||||
assert Permission.DELETE_ORGANIZATION not in member_perms
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests for get_role_permissions function
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestGetRolePermissions:
|
||||
"""Tests for get_role_permissions function."""
|
||||
|
||||
def test_get_owner_permissions(self):
|
||||
"""
|
||||
GIVEN: Role name 'owner'
|
||||
WHEN: get_role_permissions is called
|
||||
THEN: Owner permissions are returned
|
||||
"""
|
||||
perms = get_role_permissions('owner')
|
||||
assert Permission.DELETE_ORGANIZATION in perms
|
||||
assert Permission.CHANGE_ORGANIZATION_NAME in perms
|
||||
|
||||
def test_get_admin_permissions(self):
|
||||
"""
|
||||
GIVEN: Role name 'admin'
|
||||
WHEN: get_role_permissions is called
|
||||
THEN: Admin permissions are returned
|
||||
"""
|
||||
perms = get_role_permissions('admin')
|
||||
assert Permission.EDIT_LLM_SETTINGS in perms
|
||||
assert Permission.DELETE_ORGANIZATION not in perms
|
||||
|
||||
def test_get_member_permissions(self):
|
||||
"""
|
||||
GIVEN: Role name 'member'
|
||||
WHEN: get_role_permissions is called
|
||||
THEN: Member permissions are returned
|
||||
"""
|
||||
perms = get_role_permissions('member')
|
||||
assert Permission.VIEW_LLM_SETTINGS in perms
|
||||
assert Permission.EDIT_LLM_SETTINGS not in perms
|
||||
|
||||
def test_get_invalid_role_permissions(self):
|
||||
"""
|
||||
GIVEN: Invalid role name
|
||||
WHEN: get_role_permissions is called
|
||||
THEN: Empty frozenset is returned
|
||||
"""
|
||||
perms = get_role_permissions('invalid_role')
|
||||
assert perms == frozenset()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests for has_permission function
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestHasPermission:
|
||||
"""Tests for has_permission function."""
|
||||
|
||||
def test_owner_has_delete_organization_permission(self):
|
||||
"""
|
||||
GIVEN: User with owner role
|
||||
WHEN: Checking for DELETE_ORGANIZATION permission
|
||||
THEN: Returns True
|
||||
"""
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'owner'
|
||||
assert has_permission(mock_role, Permission.DELETE_ORGANIZATION) is True
|
||||
|
||||
def test_owner_has_view_llm_settings_permission(self):
|
||||
"""
|
||||
GIVEN: User with owner role
|
||||
WHEN: Checking for VIEW_LLM_SETTINGS permission
|
||||
THEN: Returns True
|
||||
"""
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'owner'
|
||||
assert has_permission(mock_role, Permission.VIEW_LLM_SETTINGS) is True
|
||||
|
||||
def test_admin_has_edit_llm_settings_permission(self):
|
||||
"""
|
||||
GIVEN: User with admin role
|
||||
WHEN: Checking for EDIT_LLM_SETTINGS permission
|
||||
THEN: Returns True
|
||||
"""
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'admin'
|
||||
assert has_permission(mock_role, Permission.EDIT_LLM_SETTINGS) is True
|
||||
|
||||
def test_admin_lacks_delete_organization_permission(self):
|
||||
"""
|
||||
GIVEN: User with admin role
|
||||
WHEN: Checking for DELETE_ORGANIZATION permission
|
||||
THEN: Returns False
|
||||
"""
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'admin'
|
||||
assert has_permission(mock_role, Permission.DELETE_ORGANIZATION) is False
|
||||
|
||||
def test_member_has_view_llm_settings_permission(self):
|
||||
"""
|
||||
GIVEN: User with member role
|
||||
WHEN: Checking for VIEW_LLM_SETTINGS permission
|
||||
THEN: Returns True
|
||||
"""
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'member'
|
||||
assert has_permission(mock_role, Permission.VIEW_LLM_SETTINGS) is True
|
||||
|
||||
def test_member_lacks_edit_llm_settings_permission(self):
|
||||
"""
|
||||
GIVEN: User with member role
|
||||
WHEN: Checking for EDIT_LLM_SETTINGS permission
|
||||
THEN: Returns False
|
||||
"""
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'member'
|
||||
assert has_permission(mock_role, Permission.EDIT_LLM_SETTINGS) is False
|
||||
|
||||
def test_member_lacks_delete_organization_permission(self):
|
||||
"""
|
||||
GIVEN: User with member role
|
||||
WHEN: Checking for DELETE_ORGANIZATION permission
|
||||
THEN: Returns False
|
||||
"""
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'member'
|
||||
assert has_permission(mock_role, Permission.DELETE_ORGANIZATION) is False
|
||||
|
||||
def test_invalid_role_has_no_permissions(self):
|
||||
"""
|
||||
GIVEN: User with invalid role
|
||||
WHEN: Checking for any permission
|
||||
THEN: Returns False
|
||||
"""
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'invalid_role'
|
||||
assert has_permission(mock_role, Permission.VIEW_LLM_SETTINGS) is False
|
||||
assert has_permission(mock_role, Permission.DELETE_ORGANIZATION) is False
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests for get_user_org_role function
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestGetUserOrgRole:
|
||||
"""Tests for get_user_org_role function."""
|
||||
|
||||
def test_returns_role_when_member_exists(self):
|
||||
"""
|
||||
GIVEN: User is a member of organization with role
|
||||
WHEN: get_user_org_role is called
|
||||
THEN: Role object is returned
|
||||
"""
|
||||
user_id = str(uuid4())
|
||||
org_id = uuid4()
|
||||
|
||||
mock_org_member = MagicMock()
|
||||
mock_org_member.role_id = 1
|
||||
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'admin'
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.auth.authorization.OrgMemberStore.get_org_member',
|
||||
return_value=mock_org_member,
|
||||
),
|
||||
patch(
|
||||
'server.auth.authorization.RoleStore.get_role_by_id',
|
||||
return_value=mock_role,
|
||||
),
|
||||
):
|
||||
result = get_user_org_role(user_id, org_id)
|
||||
assert result == mock_role
|
||||
|
||||
def test_returns_none_when_not_member(self):
|
||||
"""
|
||||
GIVEN: User is not a member of organization
|
||||
WHEN: get_user_org_role is called
|
||||
THEN: None is returned
|
||||
"""
|
||||
user_id = str(uuid4())
|
||||
org_id = uuid4()
|
||||
|
||||
with patch(
|
||||
'server.auth.authorization.OrgMemberStore.get_org_member',
|
||||
return_value=None,
|
||||
):
|
||||
result = get_user_org_role(user_id, org_id)
|
||||
assert result is None
|
||||
|
||||
def test_returns_role_when_org_id_is_none(self):
|
||||
"""
|
||||
GIVEN: User with a current organization
|
||||
WHEN: get_user_org_role is called with org_id=None
|
||||
THEN: Role object is returned using get_org_member_for_current_org
|
||||
"""
|
||||
user_id = str(uuid4())
|
||||
|
||||
mock_org_member = MagicMock()
|
||||
mock_org_member.role_id = 1
|
||||
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'admin'
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.auth.authorization.OrgMemberStore.get_org_member_for_current_org',
|
||||
return_value=mock_org_member,
|
||||
) as mock_get_current,
|
||||
patch(
|
||||
'server.auth.authorization.OrgMemberStore.get_org_member',
|
||||
) as mock_get_org_member,
|
||||
patch(
|
||||
'server.auth.authorization.RoleStore.get_role_by_id',
|
||||
return_value=mock_role,
|
||||
),
|
||||
):
|
||||
result = get_user_org_role(user_id, None)
|
||||
assert result == mock_role
|
||||
mock_get_current.assert_called_once()
|
||||
mock_get_org_member.assert_not_called()
|
||||
|
||||
def test_returns_none_when_org_id_is_none_and_no_current_org(self):
|
||||
"""
|
||||
GIVEN: User with no current organization membership
|
||||
WHEN: get_user_org_role is called with org_id=None
|
||||
THEN: None is returned
|
||||
"""
|
||||
user_id = str(uuid4())
|
||||
|
||||
with patch(
|
||||
'server.auth.authorization.OrgMemberStore.get_org_member_for_current_org',
|
||||
return_value=None,
|
||||
):
|
||||
result = get_user_org_role(user_id, None)
|
||||
assert result is None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests for require_permission dependency
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestRequirePermission:
|
||||
"""Tests for require_permission dependency factory."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_user_id_when_authorized(self):
|
||||
"""
|
||||
GIVEN: User with required permission
|
||||
WHEN: Permission checker is called
|
||||
THEN: User ID is returned
|
||||
"""
|
||||
user_id = str(uuid4())
|
||||
org_id = uuid4()
|
||||
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'admin'
|
||||
|
||||
with patch(
|
||||
'server.auth.authorization.get_user_org_role_async',
|
||||
AsyncMock(return_value=mock_role),
|
||||
):
|
||||
permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS)
|
||||
result = await permission_checker(org_id=org_id, user_id=user_id)
|
||||
assert result == user_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_raises_401_when_not_authenticated(self):
|
||||
"""
|
||||
GIVEN: No user ID (not authenticated)
|
||||
WHEN: Permission checker is called
|
||||
THEN: 401 Unauthorized is raised
|
||||
"""
|
||||
org_id = uuid4()
|
||||
|
||||
permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await permission_checker(org_id=org_id, user_id=None)
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert 'not authenticated' in exc_info.value.detail.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_raises_403_when_not_member(self):
|
||||
"""
|
||||
GIVEN: User is not a member of organization
|
||||
WHEN: Permission checker is called
|
||||
THEN: 403 Forbidden is raised
|
||||
"""
|
||||
user_id = str(uuid4())
|
||||
org_id = uuid4()
|
||||
|
||||
with patch(
|
||||
'server.auth.authorization.get_user_org_role_async',
|
||||
AsyncMock(return_value=None),
|
||||
):
|
||||
permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await permission_checker(org_id=org_id, user_id=user_id)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
assert 'not a member' in exc_info.value.detail.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_raises_403_when_insufficient_permission(self):
|
||||
"""
|
||||
GIVEN: User without required permission
|
||||
WHEN: Permission checker is called
|
||||
THEN: 403 Forbidden is raised
|
||||
"""
|
||||
user_id = str(uuid4())
|
||||
org_id = uuid4()
|
||||
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'member'
|
||||
|
||||
with patch(
|
||||
'server.auth.authorization.get_user_org_role_async',
|
||||
AsyncMock(return_value=mock_role),
|
||||
):
|
||||
permission_checker = require_permission(Permission.DELETE_ORGANIZATION)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await permission_checker(org_id=org_id, user_id=user_id)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
assert 'delete_organization' in exc_info.value.detail.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_owner_can_delete_organization(self):
|
||||
"""
|
||||
GIVEN: User with owner role
|
||||
WHEN: DELETE_ORGANIZATION permission is required
|
||||
THEN: User ID is returned
|
||||
"""
|
||||
user_id = str(uuid4())
|
||||
org_id = uuid4()
|
||||
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'owner'
|
||||
|
||||
with patch(
|
||||
'server.auth.authorization.get_user_org_role_async',
|
||||
AsyncMock(return_value=mock_role),
|
||||
):
|
||||
permission_checker = require_permission(Permission.DELETE_ORGANIZATION)
|
||||
result = await permission_checker(org_id=org_id, user_id=user_id)
|
||||
assert result == user_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_cannot_delete_organization(self):
|
||||
"""
|
||||
GIVEN: User with admin role
|
||||
WHEN: DELETE_ORGANIZATION permission is required
|
||||
THEN: 403 Forbidden is raised
|
||||
"""
|
||||
user_id = str(uuid4())
|
||||
org_id = uuid4()
|
||||
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'admin'
|
||||
|
||||
with patch(
|
||||
'server.auth.authorization.get_user_org_role_async',
|
||||
AsyncMock(return_value=mock_role),
|
||||
):
|
||||
permission_checker = require_permission(Permission.DELETE_ORGANIZATION)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await permission_checker(org_id=org_id, user_id=user_id)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_logs_warning_on_insufficient_permission(self):
|
||||
"""
|
||||
GIVEN: User without required permission
|
||||
WHEN: Permission checker is called
|
||||
THEN: Warning is logged with details
|
||||
"""
|
||||
user_id = str(uuid4())
|
||||
org_id = uuid4()
|
||||
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'member'
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.auth.authorization.get_user_org_role_async',
|
||||
AsyncMock(return_value=mock_role),
|
||||
),
|
||||
patch('server.auth.authorization.logger') as mock_logger,
|
||||
):
|
||||
permission_checker = require_permission(Permission.DELETE_ORGANIZATION)
|
||||
with pytest.raises(HTTPException):
|
||||
await permission_checker(org_id=org_id, user_id=user_id)
|
||||
|
||||
mock_logger.warning.assert_called()
|
||||
call_args = mock_logger.warning.call_args
|
||||
assert call_args[1]['extra']['user_id'] == user_id
|
||||
assert call_args[1]['extra']['user_role'] == 'member'
|
||||
assert call_args[1]['extra']['required_permission'] == 'delete_organization'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_user_id_when_org_id_is_none(self):
|
||||
"""
|
||||
GIVEN: User with required permission in their current org
|
||||
WHEN: Permission checker is called with org_id=None
|
||||
THEN: User ID is returned
|
||||
"""
|
||||
user_id = str(uuid4())
|
||||
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'admin'
|
||||
|
||||
with patch(
|
||||
'server.auth.authorization.get_user_org_role_async',
|
||||
AsyncMock(return_value=mock_role),
|
||||
) as mock_get_role:
|
||||
permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS)
|
||||
result = await permission_checker(org_id=None, user_id=user_id)
|
||||
assert result == user_id
|
||||
mock_get_role.assert_called_once_with(user_id, None)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_raises_403_when_org_id_is_none_and_not_member(self):
|
||||
"""
|
||||
GIVEN: User not a member of their current organization
|
||||
WHEN: Permission checker is called with org_id=None
|
||||
THEN: HTTPException with 403 status is raised
|
||||
"""
|
||||
user_id = str(uuid4())
|
||||
|
||||
with patch(
|
||||
'server.auth.authorization.get_user_org_role_async',
|
||||
AsyncMock(return_value=None),
|
||||
):
|
||||
permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await permission_checker(org_id=None, user_id=user_id)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
assert 'not a member' in exc_info.value.detail
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests for permission-based access control scenarios
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestPermissionScenarios:
|
||||
"""Tests for real-world permission scenarios."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_member_can_manage_secrets(self):
|
||||
"""
|
||||
GIVEN: User with member role
|
||||
WHEN: MANAGE_SECRETS permission is required
|
||||
THEN: User ID is returned
|
||||
"""
|
||||
user_id = str(uuid4())
|
||||
org_id = uuid4()
|
||||
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'member'
|
||||
|
||||
with patch(
|
||||
'server.auth.authorization.get_user_org_role_async',
|
||||
AsyncMock(return_value=mock_role),
|
||||
):
|
||||
permission_checker = require_permission(Permission.MANAGE_SECRETS)
|
||||
result = await permission_checker(org_id=org_id, user_id=user_id)
|
||||
assert result == user_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_member_cannot_invite_users(self):
|
||||
"""
|
||||
GIVEN: User with member role
|
||||
WHEN: INVITE_USER_TO_ORGANIZATION permission is required
|
||||
THEN: 403 Forbidden is raised
|
||||
"""
|
||||
user_id = str(uuid4())
|
||||
org_id = uuid4()
|
||||
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'member'
|
||||
|
||||
with patch(
|
||||
'server.auth.authorization.get_user_org_role_async',
|
||||
AsyncMock(return_value=mock_role),
|
||||
):
|
||||
permission_checker = require_permission(
|
||||
Permission.INVITE_USER_TO_ORGANIZATION
|
||||
)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await permission_checker(org_id=org_id, user_id=user_id)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_can_invite_users(self):
|
||||
"""
|
||||
GIVEN: User with admin role
|
||||
WHEN: INVITE_USER_TO_ORGANIZATION permission is required
|
||||
THEN: User ID is returned
|
||||
"""
|
||||
user_id = str(uuid4())
|
||||
org_id = uuid4()
|
||||
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'admin'
|
||||
|
||||
with patch(
|
||||
'server.auth.authorization.get_user_org_role_async',
|
||||
AsyncMock(return_value=mock_role),
|
||||
):
|
||||
permission_checker = require_permission(
|
||||
Permission.INVITE_USER_TO_ORGANIZATION
|
||||
)
|
||||
result = await permission_checker(org_id=org_id, user_id=user_id)
|
||||
assert result == user_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_cannot_change_owner_role(self):
|
||||
"""
|
||||
GIVEN: User with admin role
|
||||
WHEN: CHANGE_USER_ROLE_OWNER permission is required
|
||||
THEN: 403 Forbidden is raised
|
||||
"""
|
||||
user_id = str(uuid4())
|
||||
org_id = uuid4()
|
||||
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'admin'
|
||||
|
||||
with patch(
|
||||
'server.auth.authorization.get_user_org_role_async',
|
||||
AsyncMock(return_value=mock_role),
|
||||
):
|
||||
permission_checker = require_permission(Permission.CHANGE_USER_ROLE_OWNER)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await permission_checker(org_id=org_id, user_id=user_id)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_owner_can_change_owner_role(self):
|
||||
"""
|
||||
GIVEN: User with owner role
|
||||
WHEN: CHANGE_USER_ROLE_OWNER permission is required
|
||||
THEN: User ID is returned
|
||||
"""
|
||||
user_id = str(uuid4())
|
||||
org_id = uuid4()
|
||||
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'owner'
|
||||
|
||||
with patch(
|
||||
'server.auth.authorization.get_user_org_role_async',
|
||||
AsyncMock(return_value=mock_role),
|
||||
):
|
||||
permission_checker = require_permission(Permission.CHANGE_USER_ROLE_OWNER)
|
||||
result = await permission_checker(org_id=org_id, user_id=user_id)
|
||||
assert result == user_id
|
||||
@@ -101,7 +101,7 @@ async def test_get_credits_success():
|
||||
json={
|
||||
'user_info': {
|
||||
'spend': 25.50,
|
||||
'litellm_budget_table': {'max_budget': 100.00},
|
||||
'max_budget_in_team': 100.00,
|
||||
}
|
||||
},
|
||||
request=MagicMock(),
|
||||
@@ -121,7 +121,7 @@ async def test_get_credits_success():
|
||||
'storage.lite_llm_manager.LiteLlmManager.get_user_team_info',
|
||||
return_value={
|
||||
'spend': 25.50,
|
||||
'litellm_budget_table': {'max_budget': 100.00},
|
||||
'max_budget_in_team': 100.00,
|
||||
},
|
||||
),
|
||||
):
|
||||
@@ -299,6 +299,8 @@ async def test_success_callback_success():
|
||||
mock_billing_session.status = 'in_progress'
|
||||
mock_billing_session.user_id = 'mock_user'
|
||||
|
||||
mock_org = MagicMock()
|
||||
|
||||
with (
|
||||
patch('server.routes.billing.session_maker') as mock_session_maker,
|
||||
patch('stripe.checkout.Session.retrieve') as mock_stripe_retrieve,
|
||||
@@ -311,7 +313,7 @@ async def test_success_callback_success():
|
||||
'storage.lite_llm_manager.LiteLlmManager.get_user_team_info',
|
||||
return_value={
|
||||
'spend': 25.50,
|
||||
'litellm_budget_table': {'max_budget': 100.00},
|
||||
'max_budget_in_team': 100.00,
|
||||
},
|
||||
),
|
||||
patch(
|
||||
@@ -319,7 +321,17 @@ async def test_success_callback_success():
|
||||
) as mock_update_budget,
|
||||
):
|
||||
mock_db_session = MagicMock()
|
||||
# First query: BillingSession (query().filter().filter().first())
|
||||
mock_db_session.query.return_value.filter.return_value.filter.return_value.first.return_value = mock_billing_session
|
||||
# Second query: Org (query().filter().first()) - use side_effect for different return chains
|
||||
mock_query_chain_billing = MagicMock()
|
||||
mock_query_chain_billing.filter.return_value.filter.return_value.first.return_value = mock_billing_session
|
||||
mock_query_chain_org = MagicMock()
|
||||
mock_query_chain_org.filter.return_value.first.return_value = mock_org
|
||||
mock_db_session.query.side_effect = [
|
||||
mock_query_chain_billing,
|
||||
mock_query_chain_org,
|
||||
]
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_db_session
|
||||
|
||||
mock_stripe_retrieve.return_value = MagicMock(
|
||||
@@ -337,9 +349,12 @@ async def test_success_callback_success():
|
||||
# Verify LiteLLM API calls
|
||||
mock_update_budget.assert_called_once_with(
|
||||
'mock_org_id',
|
||||
125.0, # 100 + (25.00 from Stripe)
|
||||
125.0, # 100 + 25.00
|
||||
)
|
||||
|
||||
# Verify BYOR export is enabled for the org (updated in same session)
|
||||
assert mock_org.byor_export_enabled is True
|
||||
|
||||
# Verify database updates
|
||||
assert mock_billing_session.status == 'completed'
|
||||
assert mock_billing_session.price == 25.0
|
||||
@@ -387,6 +402,68 @@ async def test_success_callback_lite_llm_error():
|
||||
mock_db_session.commit.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_success_callback_lite_llm_update_budget_error_rollback():
|
||||
"""Test that database changes are not committed when update_team_and_users_budget fails.
|
||||
|
||||
This test verifies that if LiteLlmManager.update_team_and_users_budget raises an exception,
|
||||
the database transaction rolls back.
|
||||
"""
|
||||
mock_request = Request(scope={'type': 'http'})
|
||||
mock_request._base_url = URL('http://test.com/')
|
||||
|
||||
mock_billing_session = MagicMock()
|
||||
mock_billing_session.status = 'in_progress'
|
||||
mock_billing_session.user_id = 'mock_user'
|
||||
|
||||
mock_org = MagicMock()
|
||||
|
||||
with (
|
||||
patch('server.routes.billing.session_maker') as mock_session_maker,
|
||||
patch('stripe.checkout.Session.retrieve') as mock_stripe_retrieve,
|
||||
patch(
|
||||
'storage.user_store.UserStore.get_user_by_id_async',
|
||||
new_callable=AsyncMock,
|
||||
return_value=MagicMock(current_org_id='mock_org_id'),
|
||||
),
|
||||
patch(
|
||||
'storage.lite_llm_manager.LiteLlmManager.get_user_team_info',
|
||||
return_value={
|
||||
'spend': 0,
|
||||
'max_budget_in_team': 0,
|
||||
},
|
||||
),
|
||||
patch(
|
||||
'storage.lite_llm_manager.LiteLlmManager.update_team_and_users_budget',
|
||||
side_effect=Exception('LiteLLM API Error'),
|
||||
),
|
||||
):
|
||||
mock_db_session = MagicMock()
|
||||
mock_query_chain_billing = MagicMock()
|
||||
mock_query_chain_billing.filter.return_value.filter.return_value.first.return_value = mock_billing_session
|
||||
mock_query_chain_org = MagicMock()
|
||||
mock_query_chain_org.filter.return_value.first.return_value = mock_org
|
||||
mock_db_session.query.side_effect = [
|
||||
mock_query_chain_billing,
|
||||
mock_query_chain_org,
|
||||
]
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_db_session
|
||||
|
||||
mock_stripe_retrieve.return_value = MagicMock(
|
||||
status='complete',
|
||||
amount_subtotal=1000, # $10
|
||||
customer='mock_customer_id',
|
||||
)
|
||||
|
||||
with pytest.raises(Exception, match='LiteLLM API Error'):
|
||||
await success_callback('test_session_id', mock_request)
|
||||
|
||||
# Verify no database commit occurred - the transaction should roll back
|
||||
assert mock_billing_session.status == 'in_progress'
|
||||
mock_db_session.merge.assert_not_called()
|
||||
mock_db_session.commit.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_callback_session_not_found():
|
||||
"""Test cancel callback when billing session is not found."""
|
||||
@@ -502,6 +579,6 @@ async def test_create_customer_setup_session_success():
|
||||
customer='mock-customer-id',
|
||||
mode='setup',
|
||||
payment_method_types=['card'],
|
||||
success_url='https://test.com/?free_credits=success',
|
||||
success_url='https://test.com/?setup=success',
|
||||
cancel_url='https://test.com/',
|
||||
)
|
||||
|
||||
@@ -48,7 +48,7 @@ async def test_create_customer_setup_session_uses_customer_id():
|
||||
customer=customer_id,
|
||||
mode='setup',
|
||||
payment_method_types=['card'],
|
||||
success_url=f'{request.base_url}?free_credits=success',
|
||||
success_url=f'{request.base_url}?setup=success',
|
||||
cancel_url=f'{request.base_url}',
|
||||
)
|
||||
|
||||
|
||||
192
enterprise/tests/unit/test_email_service.py
Normal file
192
enterprise/tests/unit/test_email_service.py
Normal file
@@ -0,0 +1,192 @@
|
||||
"""Tests for email service."""
|
||||
|
||||
import os
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from server.services.email_service import (
|
||||
DEFAULT_WEB_HOST,
|
||||
EmailService,
|
||||
)
|
||||
|
||||
|
||||
class TestEmailServiceInvitationUrl:
|
||||
"""Test cases for invitation URL generation."""
|
||||
|
||||
def test_invitation_url_uses_correct_endpoint(self):
|
||||
"""Test that invitation URL points to the correct API endpoint."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.get.return_value = 'test-email-id'
|
||||
|
||||
with (
|
||||
patch.dict(os.environ, {'RESEND_API_KEY': 'test-key'}),
|
||||
patch('server.services.email_service.RESEND_AVAILABLE', True),
|
||||
patch('server.services.email_service.resend') as mock_resend,
|
||||
):
|
||||
mock_resend.Emails.send.return_value = mock_response
|
||||
|
||||
EmailService.send_invitation_email(
|
||||
to_email='test@example.com',
|
||||
org_name='Test Org',
|
||||
inviter_name='Inviter',
|
||||
role_name='member',
|
||||
invitation_token='inv-test-token-12345',
|
||||
invitation_id=1,
|
||||
)
|
||||
|
||||
# Get the call arguments
|
||||
call_args = mock_resend.Emails.send.call_args
|
||||
email_params = call_args[0][0]
|
||||
|
||||
# Verify the URL in the email HTML contains the correct endpoint
|
||||
assert (
|
||||
'/api/organizations/members/invite/accept?token='
|
||||
in email_params['html']
|
||||
)
|
||||
assert 'inv-test-token-12345' in email_params['html']
|
||||
|
||||
def test_invitation_url_uses_web_host_env_var(self):
|
||||
"""Test that invitation URL uses WEB_HOST environment variable."""
|
||||
custom_host = 'https://custom.example.com'
|
||||
mock_response = MagicMock()
|
||||
mock_response.get.return_value = 'test-email-id'
|
||||
|
||||
with (
|
||||
patch.dict(
|
||||
os.environ,
|
||||
{'RESEND_API_KEY': 'test-key', 'WEB_HOST': custom_host},
|
||||
),
|
||||
patch('server.services.email_service.RESEND_AVAILABLE', True),
|
||||
patch('server.services.email_service.resend') as mock_resend,
|
||||
):
|
||||
mock_resend.Emails.send.return_value = mock_response
|
||||
|
||||
EmailService.send_invitation_email(
|
||||
to_email='test@example.com',
|
||||
org_name='Test Org',
|
||||
inviter_name='Inviter',
|
||||
role_name='member',
|
||||
invitation_token='inv-test-token-12345',
|
||||
invitation_id=1,
|
||||
)
|
||||
|
||||
call_args = mock_resend.Emails.send.call_args
|
||||
email_params = call_args[0][0]
|
||||
|
||||
expected_url = f'{custom_host}/api/organizations/members/invite/accept?token=inv-test-token-12345'
|
||||
assert expected_url in email_params['html']
|
||||
|
||||
def test_invitation_url_uses_default_host_when_env_not_set(self):
|
||||
"""Test that invitation URL falls back to DEFAULT_WEB_HOST when env not set."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.get.return_value = 'test-email-id'
|
||||
|
||||
env_without_web_host = {'RESEND_API_KEY': 'test-key'}
|
||||
# Remove WEB_HOST if it exists
|
||||
env_without_web_host.pop('WEB_HOST', None)
|
||||
|
||||
with (
|
||||
patch.dict(os.environ, env_without_web_host, clear=True),
|
||||
patch('server.services.email_service.RESEND_AVAILABLE', True),
|
||||
patch('server.services.email_service.resend') as mock_resend,
|
||||
):
|
||||
# Clear WEB_HOST from the environment
|
||||
os.environ.pop('WEB_HOST', None)
|
||||
mock_resend.Emails.send.return_value = mock_response
|
||||
|
||||
EmailService.send_invitation_email(
|
||||
to_email='test@example.com',
|
||||
org_name='Test Org',
|
||||
inviter_name='Inviter',
|
||||
role_name='member',
|
||||
invitation_token='inv-test-token-12345',
|
||||
invitation_id=1,
|
||||
)
|
||||
|
||||
call_args = mock_resend.Emails.send.call_args
|
||||
email_params = call_args[0][0]
|
||||
|
||||
expected_url = f'{DEFAULT_WEB_HOST}/api/organizations/members/invite/accept?token=inv-test-token-12345'
|
||||
assert expected_url in email_params['html']
|
||||
|
||||
|
||||
class TestEmailServiceGetResendClient:
|
||||
"""Test cases for Resend client initialization."""
|
||||
|
||||
def test_get_resend_client_returns_false_when_resend_not_available(self):
|
||||
"""Test that _get_resend_client returns False when resend is not installed."""
|
||||
with patch('server.services.email_service.RESEND_AVAILABLE', False):
|
||||
result = EmailService._get_resend_client()
|
||||
assert result is False
|
||||
|
||||
def test_get_resend_client_returns_false_when_api_key_not_configured(self):
|
||||
"""Test that _get_resend_client returns False when API key is missing."""
|
||||
with (
|
||||
patch('server.services.email_service.RESEND_AVAILABLE', True),
|
||||
patch.dict(os.environ, {}, clear=True),
|
||||
):
|
||||
os.environ.pop('RESEND_API_KEY', None)
|
||||
result = EmailService._get_resend_client()
|
||||
assert result is False
|
||||
|
||||
def test_get_resend_client_returns_true_when_configured(self):
|
||||
"""Test that _get_resend_client returns True when properly configured."""
|
||||
with (
|
||||
patch.dict(os.environ, {'RESEND_API_KEY': 'test-key'}),
|
||||
patch('server.services.email_service.RESEND_AVAILABLE', True),
|
||||
patch('server.services.email_service.resend') as mock_resend,
|
||||
):
|
||||
result = EmailService._get_resend_client()
|
||||
assert result is True
|
||||
assert mock_resend.api_key == 'test-key'
|
||||
|
||||
|
||||
class TestEmailServiceSendInvitationEmail:
|
||||
"""Test cases for send_invitation_email method."""
|
||||
|
||||
def test_send_invitation_email_skips_when_client_not_ready(self):
|
||||
"""Test that email sending is skipped when client is not ready."""
|
||||
with patch.object(
|
||||
EmailService, '_get_resend_client', return_value=False
|
||||
) as mock_get_client:
|
||||
# Should not raise, just return early
|
||||
EmailService.send_invitation_email(
|
||||
to_email='test@example.com',
|
||||
org_name='Test Org',
|
||||
inviter_name='Inviter',
|
||||
role_name='member',
|
||||
invitation_token='inv-test-token',
|
||||
invitation_id=1,
|
||||
)
|
||||
|
||||
mock_get_client.assert_called_once()
|
||||
|
||||
def test_send_invitation_email_includes_all_required_info(self):
|
||||
"""Test that invitation email includes org name, inviter name, and role."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.get.return_value = 'test-email-id'
|
||||
|
||||
with (
|
||||
patch.dict(os.environ, {'RESEND_API_KEY': 'test-key'}),
|
||||
patch('server.services.email_service.RESEND_AVAILABLE', True),
|
||||
patch('server.services.email_service.resend') as mock_resend,
|
||||
):
|
||||
mock_resend.Emails.send.return_value = mock_response
|
||||
|
||||
EmailService.send_invitation_email(
|
||||
to_email='test@example.com',
|
||||
org_name='Acme Corp',
|
||||
inviter_name='John Doe',
|
||||
role_name='admin',
|
||||
invitation_token='inv-test-token-12345',
|
||||
invitation_id=42,
|
||||
)
|
||||
|
||||
call_args = mock_resend.Emails.send.call_args
|
||||
email_params = call_args[0][0]
|
||||
|
||||
# Verify email content
|
||||
assert email_params['to'] == ['test@example.com']
|
||||
assert 'Acme Corp' in email_params['subject']
|
||||
assert 'John Doe' in email_params['html']
|
||||
assert 'Acme Corp' in email_params['html']
|
||||
assert 'admin' in email_params['html']
|
||||
@@ -142,44 +142,192 @@ class TestLiteLlmManager:
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_entries_cloud_deployment(self, mock_settings, mock_response):
|
||||
"""Test create_entries in cloud deployment mode."""
|
||||
with patch.dict(os.environ, {'LOCAL_DEPLOYMENT': ''}):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
|
||||
with patch(
|
||||
'storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'
|
||||
):
|
||||
with patch(
|
||||
'storage.lite_llm_manager.TokenManager'
|
||||
) as mock_token_manager:
|
||||
mock_token_manager.return_value.get_user_info_from_user_id = (
|
||||
AsyncMock(return_value={'email': 'test@example.com'})
|
||||
)
|
||||
mock_404_response = MagicMock()
|
||||
mock_404_response.status_code = 404
|
||||
mock_404_response.is_success = False
|
||||
|
||||
with patch('httpx.AsyncClient') as mock_client_class:
|
||||
mock_client = AsyncMock()
|
||||
mock_client_class.return_value.__aenter__.return_value = (
|
||||
mock_client
|
||||
)
|
||||
mock_client.post.return_value = mock_response
|
||||
mock_token_manager = MagicMock()
|
||||
mock_token_manager.return_value.get_user_info_from_user_id = AsyncMock(
|
||||
return_value={'email': 'test@example.com'}
|
||||
)
|
||||
|
||||
result = await LiteLlmManager.create_entries(
|
||||
'test-org-id',
|
||||
'test-user-id',
|
||||
mock_settings,
|
||||
create_user=False,
|
||||
)
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get.return_value = mock_404_response
|
||||
mock_client.get.return_value.raise_for_status.side_effect = (
|
||||
httpx.HTTPStatusError(
|
||||
message='Not Found', request=MagicMock(), response=mock_404_response
|
||||
)
|
||||
)
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
assert result is not None
|
||||
assert result.agent == 'CodeActAgent'
|
||||
assert result.llm_model == get_default_litellm_model()
|
||||
assert (
|
||||
result.llm_api_key.get_secret_value() == 'test-api-key'
|
||||
)
|
||||
assert result.llm_base_url == 'http://test.com'
|
||||
mock_client_class = MagicMock()
|
||||
mock_client_class.return_value.__aenter__.return_value = mock_client
|
||||
|
||||
# Verify API calls were made
|
||||
assert (
|
||||
mock_client.post.call_count == 3
|
||||
) # create_team, create_user, add_user_to_team, generate_key
|
||||
with (
|
||||
patch.dict(os.environ, {'LOCAL_DEPLOYMENT': ''}),
|
||||
patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'),
|
||||
patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'),
|
||||
patch('storage.lite_llm_manager.TokenManager', mock_token_manager),
|
||||
patch('httpx.AsyncClient', mock_client_class),
|
||||
):
|
||||
result = await LiteLlmManager.create_entries(
|
||||
'test-org-id', 'test-user-id', mock_settings, create_user=False
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.agent == 'CodeActAgent'
|
||||
assert result.llm_model == get_default_litellm_model()
|
||||
assert result.llm_api_key.get_secret_value() == 'test-api-key'
|
||||
assert result.llm_base_url == 'http://test.com'
|
||||
|
||||
# Verify API calls were made (get_team + 3 posts)
|
||||
assert mock_client.get.call_count == 1 # get_team
|
||||
assert (
|
||||
mock_client.post.call_count == 3
|
||||
) # create_team, add_user_to_team, generate_key
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_entries_inherits_existing_team_budget(
|
||||
self, mock_settings, mock_response
|
||||
):
|
||||
"""Test that create_entries inherits budget from existing team."""
|
||||
mock_team_response = MagicMock()
|
||||
mock_team_response.is_success = True
|
||||
mock_team_response.status_code = 200
|
||||
mock_team_response.json.return_value = {
|
||||
'team_info': {'max_budget': 30.0, 'spend': 5.0},
|
||||
'team_memberships': [],
|
||||
}
|
||||
mock_team_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_token_manager = MagicMock()
|
||||
mock_token_manager.return_value.get_user_info_from_user_id = AsyncMock(
|
||||
return_value={'email': 'test@example.com'}
|
||||
)
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get.return_value = mock_team_response
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
mock_client_class = MagicMock()
|
||||
mock_client_class.return_value.__aenter__.return_value = mock_client
|
||||
|
||||
with (
|
||||
patch.dict(os.environ, {'LOCAL_DEPLOYMENT': ''}),
|
||||
patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'),
|
||||
patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'),
|
||||
patch('storage.lite_llm_manager.TokenManager', mock_token_manager),
|
||||
patch('httpx.AsyncClient', mock_client_class),
|
||||
):
|
||||
result = await LiteLlmManager.create_entries(
|
||||
'test-org-id', 'test-user-id', mock_settings, create_user=False
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
|
||||
# Verify _get_team was called first
|
||||
mock_client.get.assert_called_once()
|
||||
get_call_url = mock_client.get.call_args[0][0]
|
||||
assert 'team/info' in get_call_url
|
||||
assert 'test-org-id' in get_call_url
|
||||
|
||||
# Verify _create_team was called with inherited budget (30.0)
|
||||
create_team_call = mock_client.post.call_args_list[0]
|
||||
assert 'team/new' in create_team_call[0][0]
|
||||
assert create_team_call[1]['json']['max_budget'] == 30.0
|
||||
|
||||
# Verify _add_user_to_team was called with inherited budget (30.0)
|
||||
add_user_call = mock_client.post.call_args_list[1]
|
||||
assert 'team/member_add' in add_user_call[0][0]
|
||||
assert add_user_call[1]['json']['max_budget_in_team'] == 30.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_entries_new_org_uses_zero_budget(
|
||||
self, mock_settings, mock_response
|
||||
):
|
||||
"""Test that create_entries uses budget=0 for new org (team doesn't exist)."""
|
||||
mock_404_response = MagicMock()
|
||||
mock_404_response.status_code = 404
|
||||
mock_404_response.is_success = False
|
||||
|
||||
mock_token_manager = MagicMock()
|
||||
mock_token_manager.return_value.get_user_info_from_user_id = AsyncMock(
|
||||
return_value={'email': 'test@example.com'}
|
||||
)
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get.return_value = mock_404_response
|
||||
mock_client.get.return_value.raise_for_status.side_effect = (
|
||||
httpx.HTTPStatusError(
|
||||
message='Not Found', request=MagicMock(), response=mock_404_response
|
||||
)
|
||||
)
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
mock_client_class = MagicMock()
|
||||
mock_client_class.return_value.__aenter__.return_value = mock_client
|
||||
|
||||
with (
|
||||
patch.dict(os.environ, {'LOCAL_DEPLOYMENT': ''}),
|
||||
patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'),
|
||||
patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'),
|
||||
patch('storage.lite_llm_manager.TokenManager', mock_token_manager),
|
||||
patch('httpx.AsyncClient', mock_client_class),
|
||||
):
|
||||
result = await LiteLlmManager.create_entries(
|
||||
'test-org-id', 'test-user-id', mock_settings, create_user=False
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
|
||||
# Verify _create_team was called with budget=0
|
||||
create_team_call = mock_client.post.call_args_list[0]
|
||||
assert 'team/new' in create_team_call[0][0]
|
||||
assert create_team_call[1]['json']['max_budget'] == 0.0
|
||||
|
||||
# Verify _add_user_to_team was called with budget=0
|
||||
add_user_call = mock_client.post.call_args_list[1]
|
||||
assert 'team/member_add' in add_user_call[0][0]
|
||||
assert add_user_call[1]['json']['max_budget_in_team'] == 0.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_entries_propagates_non_404_errors(self, mock_settings):
|
||||
"""Test that create_entries propagates non-404 errors from _get_team."""
|
||||
mock_500_response = MagicMock()
|
||||
mock_500_response.status_code = 500
|
||||
mock_500_response.is_success = False
|
||||
|
||||
mock_token_manager = MagicMock()
|
||||
mock_token_manager.return_value.get_user_info_from_user_id = AsyncMock(
|
||||
return_value={'email': 'test@example.com'}
|
||||
)
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get.return_value = mock_500_response
|
||||
mock_client.get.return_value.raise_for_status.side_effect = (
|
||||
httpx.HTTPStatusError(
|
||||
message='Internal Server Error',
|
||||
request=MagicMock(),
|
||||
response=mock_500_response,
|
||||
)
|
||||
)
|
||||
|
||||
mock_client_class = MagicMock()
|
||||
mock_client_class.return_value.__aenter__.return_value = mock_client
|
||||
|
||||
with (
|
||||
patch.dict(os.environ, {'LOCAL_DEPLOYMENT': ''}),
|
||||
patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'),
|
||||
patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'),
|
||||
patch('storage.lite_llm_manager.TokenManager', mock_token_manager),
|
||||
patch('httpx.AsyncClient', mock_client_class),
|
||||
):
|
||||
with pytest.raises(httpx.HTTPStatusError) as exc_info:
|
||||
await LiteLlmManager.create_entries(
|
||||
'test-org-id', 'test-user-id', mock_settings, create_user=False
|
||||
)
|
||||
|
||||
assert exc_info.value.response.status_code == 500
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_migrate_entries_missing_config(self, mock_user_settings):
|
||||
|
||||
464
enterprise/tests/unit/test_org_invitation_service.py
Normal file
464
enterprise/tests/unit/test_org_invitation_service.py
Normal file
@@ -0,0 +1,464 @@
|
||||
"""Tests for organization invitation service - email validation."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from server.routes.org_invitation_models import (
|
||||
EmailMismatchError,
|
||||
)
|
||||
from server.services.org_invitation_service import OrgInvitationService
|
||||
from storage.org_invitation import OrgInvitation
|
||||
|
||||
|
||||
class TestAcceptInvitationEmailValidation:
|
||||
"""Test cases for email validation during invitation acceptance."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_invitation(self):
|
||||
"""Create a mock invitation with pending status."""
|
||||
invitation = MagicMock(spec=OrgInvitation)
|
||||
invitation.id = 1
|
||||
invitation.email = 'alice@example.com'
|
||||
invitation.status = OrgInvitation.STATUS_PENDING
|
||||
invitation.org_id = UUID('12345678-1234-5678-1234-567812345678')
|
||||
invitation.role_id = 1
|
||||
return invitation
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user(self):
|
||||
"""Create a mock user with email."""
|
||||
user = MagicMock()
|
||||
user.id = UUID('87654321-4321-8765-4321-876543218765')
|
||||
user.email = 'alice@example.com'
|
||||
return user
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_accept_invitation_email_matches(self, mock_invitation, mock_user):
|
||||
"""Test that invitation is accepted when user email matches invitation email."""
|
||||
# Arrange
|
||||
user_id = mock_user.id
|
||||
token = 'inv-test-token-12345'
|
||||
|
||||
with patch.object(
|
||||
OrgInvitationService, 'accept_invitation', new_callable=AsyncMock
|
||||
) as mock_accept:
|
||||
mock_accept.return_value = mock_invitation
|
||||
|
||||
# Act
|
||||
await OrgInvitationService.accept_invitation(token, user_id)
|
||||
|
||||
# Assert
|
||||
mock_accept.assert_called_once_with(token, user_id)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_accept_invitation_email_mismatch_raises_error(
|
||||
self, mock_invitation, mock_user
|
||||
):
|
||||
"""Test that EmailMismatchError is raised when emails don't match."""
|
||||
# Arrange
|
||||
user_id = mock_user.id
|
||||
token = 'inv-test-token-12345'
|
||||
mock_user.email = 'bob@example.com' # Different email
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.services.org_invitation_service.OrgInvitationStore.get_invitation_by_token',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_invitation,
|
||||
patch(
|
||||
'server.services.org_invitation_service.OrgInvitationStore.is_token_expired'
|
||||
) as mock_is_expired,
|
||||
patch(
|
||||
'server.services.org_invitation_service.UserStore.get_user_by_id_async',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_user,
|
||||
):
|
||||
mock_get_invitation.return_value = mock_invitation
|
||||
mock_is_expired.return_value = False
|
||||
mock_get_user.return_value = mock_user
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(EmailMismatchError):
|
||||
await OrgInvitationService.accept_invitation(token, user_id)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_accept_invitation_user_no_email_keycloak_fallback_matches(
|
||||
self, mock_invitation
|
||||
):
|
||||
"""Test that Keycloak email is used when user has no email in database."""
|
||||
# Arrange
|
||||
user_id = UUID('87654321-4321-8765-4321-876543218765')
|
||||
token = 'inv-test-token-12345'
|
||||
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = user_id
|
||||
mock_user.email = None # No email in database
|
||||
|
||||
mock_keycloak_user_info = {'email': 'alice@example.com'} # Email from Keycloak
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.services.org_invitation_service.OrgInvitationStore.get_invitation_by_token',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_invitation,
|
||||
patch(
|
||||
'server.services.org_invitation_service.OrgInvitationStore.is_token_expired'
|
||||
) as mock_is_expired,
|
||||
patch(
|
||||
'server.services.org_invitation_service.UserStore.get_user_by_id_async',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_user,
|
||||
patch(
|
||||
'server.services.org_invitation_service.TokenManager'
|
||||
) as mock_token_manager_class,
|
||||
patch(
|
||||
'server.services.org_invitation_service.OrgMemberStore.get_org_member'
|
||||
) as mock_get_member,
|
||||
patch(
|
||||
'server.services.org_invitation_service.OrgService.create_litellm_integration',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_create_litellm,
|
||||
patch(
|
||||
'server.services.org_invitation_service.OrgMemberStore.add_user_to_org'
|
||||
),
|
||||
patch(
|
||||
'server.services.org_invitation_service.OrgInvitationStore.update_invitation_status',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_update_status,
|
||||
):
|
||||
mock_get_invitation.return_value = mock_invitation
|
||||
mock_is_expired.return_value = False
|
||||
mock_get_user.return_value = mock_user
|
||||
|
||||
# Mock TokenManager instance
|
||||
mock_token_manager = MagicMock()
|
||||
mock_token_manager.get_user_info_from_user_id = AsyncMock(
|
||||
return_value=mock_keycloak_user_info
|
||||
)
|
||||
mock_token_manager_class.return_value = mock_token_manager
|
||||
|
||||
mock_get_member.return_value = None # Not already a member
|
||||
mock_create_litellm.return_value = MagicMock(llm_api_key='test-key')
|
||||
mock_update_status.return_value = mock_invitation
|
||||
|
||||
# Act - should not raise error because Keycloak email matches
|
||||
await OrgInvitationService.accept_invitation(token, user_id)
|
||||
|
||||
# Assert
|
||||
mock_token_manager.get_user_info_from_user_id.assert_called_once_with(
|
||||
str(user_id)
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_accept_invitation_no_email_anywhere_raises_error(
|
||||
self, mock_invitation
|
||||
):
|
||||
"""Test that EmailMismatchError is raised when user has no email in database or Keycloak."""
|
||||
# Arrange
|
||||
user_id = UUID('87654321-4321-8765-4321-876543218765')
|
||||
token = 'inv-test-token-12345'
|
||||
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = user_id
|
||||
mock_user.email = None # No email in database
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.services.org_invitation_service.OrgInvitationStore.get_invitation_by_token',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_invitation,
|
||||
patch(
|
||||
'server.services.org_invitation_service.OrgInvitationStore.is_token_expired'
|
||||
) as mock_is_expired,
|
||||
patch(
|
||||
'server.services.org_invitation_service.UserStore.get_user_by_id_async',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_user,
|
||||
patch(
|
||||
'server.services.org_invitation_service.TokenManager'
|
||||
) as mock_token_manager_class,
|
||||
):
|
||||
mock_get_invitation.return_value = mock_invitation
|
||||
mock_is_expired.return_value = False
|
||||
mock_get_user.return_value = mock_user
|
||||
|
||||
# Mock TokenManager to return no email
|
||||
mock_token_manager = MagicMock()
|
||||
mock_token_manager.get_user_info_from_user_id = AsyncMock(return_value={})
|
||||
mock_token_manager_class.return_value = mock_token_manager
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(EmailMismatchError) as exc_info:
|
||||
await OrgInvitationService.accept_invitation(token, user_id)
|
||||
|
||||
assert 'does not have an email address' in str(exc_info.value)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_accept_invitation_email_comparison_is_case_insensitive(
|
||||
self, mock_invitation
|
||||
):
|
||||
"""Test that email comparison is case insensitive."""
|
||||
# Arrange
|
||||
user_id = UUID('87654321-4321-8765-4321-876543218765')
|
||||
token = 'inv-test-token-12345'
|
||||
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = user_id
|
||||
mock_user.email = 'ALICE@EXAMPLE.COM' # Uppercase email
|
||||
|
||||
mock_invitation.email = 'alice@example.com' # Lowercase in invitation
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.services.org_invitation_service.OrgInvitationStore.get_invitation_by_token',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_invitation,
|
||||
patch(
|
||||
'server.services.org_invitation_service.OrgInvitationStore.is_token_expired'
|
||||
) as mock_is_expired,
|
||||
patch(
|
||||
'server.services.org_invitation_service.UserStore.get_user_by_id_async',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_user,
|
||||
patch(
|
||||
'server.services.org_invitation_service.OrgMemberStore.get_org_member'
|
||||
) as mock_get_member,
|
||||
patch(
|
||||
'server.services.org_invitation_service.OrgService.create_litellm_integration',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_create_litellm,
|
||||
patch(
|
||||
'server.services.org_invitation_service.OrgMemberStore.add_user_to_org'
|
||||
),
|
||||
patch(
|
||||
'server.services.org_invitation_service.OrgInvitationStore.update_invitation_status',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_update_status,
|
||||
):
|
||||
mock_get_invitation.return_value = mock_invitation
|
||||
mock_is_expired.return_value = False
|
||||
mock_get_user.return_value = mock_user
|
||||
mock_get_member.return_value = None
|
||||
mock_create_litellm.return_value = MagicMock(llm_api_key='test-key')
|
||||
mock_update_status.return_value = mock_invitation
|
||||
|
||||
# Act - should not raise error because emails match case-insensitively
|
||||
await OrgInvitationService.accept_invitation(token, user_id)
|
||||
|
||||
# Assert - invitation was accepted (update_invitation_status was called)
|
||||
mock_update_status.assert_called_once()
|
||||
|
||||
|
||||
class TestCreateInvitationsBatch:
|
||||
"""Test cases for batch invitation creation."""
|
||||
|
||||
@pytest.fixture
|
||||
def org_id(self):
|
||||
"""Organization UUID for testing."""
|
||||
return UUID('12345678-1234-5678-1234-567812345678')
|
||||
|
||||
@pytest.fixture
|
||||
def inviter_id(self):
|
||||
"""Inviter UUID for testing."""
|
||||
return UUID('87654321-4321-8765-4321-876543218765')
|
||||
|
||||
@pytest.fixture
|
||||
def mock_org(self):
|
||||
"""Create a mock organization."""
|
||||
org = MagicMock()
|
||||
org.id = UUID('12345678-1234-5678-1234-567812345678')
|
||||
org.name = 'Test Org'
|
||||
return org
|
||||
|
||||
@pytest.fixture
|
||||
def mock_inviter_member(self):
|
||||
"""Create a mock inviter member with owner role."""
|
||||
member = MagicMock()
|
||||
member.user_id = UUID('87654321-4321-8765-4321-876543218765')
|
||||
member.role_id = 1
|
||||
return member
|
||||
|
||||
@pytest.fixture
|
||||
def mock_owner_role(self):
|
||||
"""Create a mock owner role."""
|
||||
role = MagicMock()
|
||||
role.id = 1
|
||||
role.name = 'owner'
|
||||
return role
|
||||
|
||||
@pytest.fixture
|
||||
def mock_member_role(self):
|
||||
"""Create a mock member role."""
|
||||
role = MagicMock()
|
||||
role.id = 3
|
||||
role.name = 'member'
|
||||
return role
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_batch_creates_all_invitations_successfully(
|
||||
self,
|
||||
org_id,
|
||||
inviter_id,
|
||||
mock_org,
|
||||
mock_inviter_member,
|
||||
mock_owner_role,
|
||||
mock_member_role,
|
||||
):
|
||||
"""Test that batch creation succeeds for all valid emails."""
|
||||
# Arrange
|
||||
emails = ['alice@example.com', 'bob@example.com']
|
||||
mock_invitation_1 = MagicMock(spec=OrgInvitation)
|
||||
mock_invitation_1.id = 1
|
||||
mock_invitation_2 = MagicMock(spec=OrgInvitation)
|
||||
mock_invitation_2.id = 2
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.services.org_invitation_service.OrgStore.get_org_by_id',
|
||||
return_value=mock_org,
|
||||
),
|
||||
patch(
|
||||
'server.services.org_invitation_service.OrgMemberStore.get_org_member',
|
||||
return_value=mock_inviter_member,
|
||||
),
|
||||
patch(
|
||||
'server.services.org_invitation_service.RoleStore.get_role_by_id',
|
||||
return_value=mock_owner_role,
|
||||
),
|
||||
patch(
|
||||
'server.services.org_invitation_service.RoleStore.get_role_by_name',
|
||||
return_value=mock_member_role,
|
||||
),
|
||||
patch.object(
|
||||
OrgInvitationService,
|
||||
'create_invitation',
|
||||
new_callable=AsyncMock,
|
||||
side_effect=[mock_invitation_1, mock_invitation_2],
|
||||
),
|
||||
):
|
||||
# Act
|
||||
successful, failed = await OrgInvitationService.create_invitations_batch(
|
||||
org_id=org_id,
|
||||
emails=emails,
|
||||
role_name='member',
|
||||
inviter_id=inviter_id,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(successful) == 2
|
||||
assert len(failed) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_batch_handles_partial_success(
|
||||
self,
|
||||
org_id,
|
||||
inviter_id,
|
||||
mock_org,
|
||||
mock_inviter_member,
|
||||
mock_owner_role,
|
||||
mock_member_role,
|
||||
):
|
||||
"""Test that batch returns partial results when some emails fail."""
|
||||
# Arrange
|
||||
from server.routes.org_invitation_models import UserAlreadyMemberError
|
||||
|
||||
emails = ['alice@example.com', 'existing@example.com']
|
||||
mock_invitation = MagicMock(spec=OrgInvitation)
|
||||
mock_invitation.id = 1
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.services.org_invitation_service.OrgStore.get_org_by_id',
|
||||
return_value=mock_org,
|
||||
),
|
||||
patch(
|
||||
'server.services.org_invitation_service.OrgMemberStore.get_org_member',
|
||||
return_value=mock_inviter_member,
|
||||
),
|
||||
patch(
|
||||
'server.services.org_invitation_service.RoleStore.get_role_by_id',
|
||||
return_value=mock_owner_role,
|
||||
),
|
||||
patch(
|
||||
'server.services.org_invitation_service.RoleStore.get_role_by_name',
|
||||
return_value=mock_member_role,
|
||||
),
|
||||
patch.object(
|
||||
OrgInvitationService,
|
||||
'create_invitation',
|
||||
new_callable=AsyncMock,
|
||||
side_effect=[mock_invitation, UserAlreadyMemberError()],
|
||||
),
|
||||
):
|
||||
# Act
|
||||
successful, failed = await OrgInvitationService.create_invitations_batch(
|
||||
org_id=org_id,
|
||||
emails=emails,
|
||||
role_name='member',
|
||||
inviter_id=inviter_id,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(successful) == 1
|
||||
assert len(failed) == 1
|
||||
assert failed[0][0] == 'existing@example.com'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_batch_fails_entirely_on_permission_error(self, org_id, inviter_id):
|
||||
"""Test that permission error fails the entire batch upfront."""
|
||||
# Arrange
|
||||
|
||||
emails = ['alice@example.com', 'bob@example.com']
|
||||
|
||||
with patch(
|
||||
'server.services.org_invitation_service.OrgStore.get_org_by_id',
|
||||
return_value=None, # Organization not found
|
||||
):
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await OrgInvitationService.create_invitations_batch(
|
||||
org_id=org_id,
|
||||
emails=emails,
|
||||
role_name='member',
|
||||
inviter_id=inviter_id,
|
||||
)
|
||||
|
||||
assert 'not found' in str(exc_info.value)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_batch_fails_on_invalid_role(
|
||||
self, org_id, inviter_id, mock_org, mock_inviter_member, mock_owner_role
|
||||
):
|
||||
"""Test that invalid role fails the entire batch."""
|
||||
# Arrange
|
||||
emails = ['alice@example.com']
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.services.org_invitation_service.OrgStore.get_org_by_id',
|
||||
return_value=mock_org,
|
||||
),
|
||||
patch(
|
||||
'server.services.org_invitation_service.OrgMemberStore.get_org_member',
|
||||
return_value=mock_inviter_member,
|
||||
),
|
||||
patch(
|
||||
'server.services.org_invitation_service.RoleStore.get_role_by_id',
|
||||
return_value=mock_owner_role,
|
||||
),
|
||||
patch(
|
||||
'server.services.org_invitation_service.RoleStore.get_role_by_name',
|
||||
return_value=None, # Invalid role
|
||||
),
|
||||
):
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await OrgInvitationService.create_invitations_batch(
|
||||
org_id=org_id,
|
||||
emails=emails,
|
||||
role_name='invalid_role',
|
||||
inviter_id=inviter_id,
|
||||
)
|
||||
|
||||
assert 'Invalid role' in str(exc_info.value)
|
||||
308
enterprise/tests/unit/test_org_invitation_store.py
Normal file
308
enterprise/tests/unit/test_org_invitation_store.py
Normal file
@@ -0,0 +1,308 @@
|
||||
"""Tests for organization invitation store."""
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from storage.org_invitation import OrgInvitation
|
||||
from storage.org_invitation_store import (
|
||||
INVITATION_TOKEN_LENGTH,
|
||||
INVITATION_TOKEN_PREFIX,
|
||||
OrgInvitationStore,
|
||||
)
|
||||
|
||||
|
||||
class TestGenerateToken:
|
||||
"""Test cases for token generation."""
|
||||
|
||||
def test_generate_token_has_correct_prefix(self):
|
||||
"""Test that generated tokens have the correct prefix."""
|
||||
token = OrgInvitationStore.generate_token()
|
||||
assert token.startswith(INVITATION_TOKEN_PREFIX)
|
||||
|
||||
def test_generate_token_has_correct_length(self):
|
||||
"""Test that generated tokens have the correct total length."""
|
||||
token = OrgInvitationStore.generate_token()
|
||||
expected_length = len(INVITATION_TOKEN_PREFIX) + INVITATION_TOKEN_LENGTH
|
||||
assert len(token) == expected_length
|
||||
|
||||
def test_generate_token_uses_alphanumeric_characters(self):
|
||||
"""Test that generated tokens use only alphanumeric characters."""
|
||||
token = OrgInvitationStore.generate_token()
|
||||
# Remove prefix and check the rest is alphanumeric
|
||||
random_part = token[len(INVITATION_TOKEN_PREFIX) :]
|
||||
assert random_part.isalnum()
|
||||
|
||||
def test_generate_token_is_unique(self):
|
||||
"""Test that generated tokens are unique (probabilistically)."""
|
||||
tokens = [OrgInvitationStore.generate_token() for _ in range(100)]
|
||||
assert len(set(tokens)) == 100
|
||||
|
||||
|
||||
class TestIsTokenExpired:
|
||||
"""Test cases for token expiration checking."""
|
||||
|
||||
def test_token_not_expired_when_future(self):
|
||||
"""Test that tokens with future expiration are not expired."""
|
||||
invitation = MagicMock(spec=OrgInvitation)
|
||||
invitation.expires_at = datetime.utcnow() + timedelta(days=1)
|
||||
|
||||
result = OrgInvitationStore.is_token_expired(invitation)
|
||||
assert result is False
|
||||
|
||||
def test_token_expired_when_past(self):
|
||||
"""Test that tokens with past expiration are expired."""
|
||||
invitation = MagicMock(spec=OrgInvitation)
|
||||
invitation.expires_at = datetime.utcnow() - timedelta(seconds=1)
|
||||
|
||||
result = OrgInvitationStore.is_token_expired(invitation)
|
||||
assert result is True
|
||||
|
||||
def test_token_expired_at_exact_boundary(self):
|
||||
"""Test that tokens at exact expiration time are expired."""
|
||||
# A token that expires "now" should be expired
|
||||
now = datetime.utcnow()
|
||||
invitation = MagicMock(spec=OrgInvitation)
|
||||
invitation.expires_at = now - timedelta(microseconds=1)
|
||||
|
||||
result = OrgInvitationStore.is_token_expired(invitation)
|
||||
assert result is True
|
||||
|
||||
|
||||
class TestCreateInvitation:
|
||||
"""Test cases for invitation creation."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_invitation_normalizes_email(self):
|
||||
"""Test that email is normalized (lowercase, stripped) on creation."""
|
||||
mock_session = AsyncMock()
|
||||
mock_session.add = MagicMock()
|
||||
mock_session.commit = AsyncMock()
|
||||
mock_session.execute = AsyncMock()
|
||||
|
||||
# Mock the result of the re-fetch query
|
||||
mock_result = MagicMock()
|
||||
mock_invitation = MagicMock()
|
||||
mock_invitation.id = 1
|
||||
mock_invitation.email = 'test@example.com'
|
||||
mock_result.scalars.return_value.first.return_value = mock_invitation
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
with patch(
|
||||
'storage.org_invitation_store.a_session_maker'
|
||||
) as mock_session_maker:
|
||||
mock_session_manager = AsyncMock()
|
||||
mock_session_manager.__aenter__.return_value = mock_session
|
||||
mock_session_manager.__aexit__.return_value = None
|
||||
mock_session_maker.return_value = mock_session_manager
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
await OrgInvitationStore.create_invitation(
|
||||
org_id=UUID('12345678-1234-5678-1234-567812345678'),
|
||||
email=' TEST@EXAMPLE.COM ',
|
||||
role_id=1,
|
||||
inviter_id=UUID('87654321-4321-8765-4321-876543218765'),
|
||||
)
|
||||
|
||||
# Verify that the OrgInvitation was created with normalized email
|
||||
add_call = mock_session.add.call_args
|
||||
created_invitation = add_call[0][0]
|
||||
assert created_invitation.email == 'test@example.com'
|
||||
|
||||
|
||||
class TestGetInvitationByToken:
|
||||
"""Test cases for getting invitation by token."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_invitation_by_token_returns_invitation(self):
|
||||
"""Test that get_invitation_by_token returns the invitation when found."""
|
||||
mock_invitation = MagicMock(spec=OrgInvitation)
|
||||
mock_invitation.token = 'inv-test-token-12345'
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.first.return_value = mock_invitation
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
|
||||
with patch(
|
||||
'storage.org_invitation_store.a_session_maker'
|
||||
) as mock_session_maker:
|
||||
mock_session_manager = AsyncMock()
|
||||
mock_session_manager.__aenter__.return_value = mock_session
|
||||
mock_session_manager.__aexit__.return_value = None
|
||||
mock_session_maker.return_value = mock_session_manager
|
||||
|
||||
result = await OrgInvitationStore.get_invitation_by_token(
|
||||
'inv-test-token-12345'
|
||||
)
|
||||
assert result == mock_invitation
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_invitation_by_token_returns_none_when_not_found(self):
|
||||
"""Test that get_invitation_by_token returns None when not found."""
|
||||
mock_session = AsyncMock()
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.first.return_value = None
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
|
||||
with patch(
|
||||
'storage.org_invitation_store.a_session_maker'
|
||||
) as mock_session_maker:
|
||||
mock_session_manager = AsyncMock()
|
||||
mock_session_manager.__aenter__.return_value = mock_session
|
||||
mock_session_manager.__aexit__.return_value = None
|
||||
mock_session_maker.return_value = mock_session_manager
|
||||
|
||||
result = await OrgInvitationStore.get_invitation_by_token(
|
||||
'inv-nonexistent-token'
|
||||
)
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestGetPendingInvitation:
|
||||
"""Test cases for getting pending invitation."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_pending_invitation_normalizes_email(self):
|
||||
"""Test that email is normalized when querying for pending invitations."""
|
||||
mock_session = AsyncMock()
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.first.return_value = None
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
|
||||
with patch(
|
||||
'storage.org_invitation_store.a_session_maker'
|
||||
) as mock_session_maker:
|
||||
mock_session_manager = AsyncMock()
|
||||
mock_session_manager.__aenter__.return_value = mock_session
|
||||
mock_session_manager.__aexit__.return_value = None
|
||||
mock_session_maker.return_value = mock_session_manager
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
await OrgInvitationStore.get_pending_invitation(
|
||||
org_id=UUID('12345678-1234-5678-1234-567812345678'),
|
||||
email=' TEST@EXAMPLE.COM ',
|
||||
)
|
||||
|
||||
# Verify the query was called (email normalization happens in the filter)
|
||||
assert mock_session.execute.called
|
||||
|
||||
|
||||
class TestUpdateInvitationStatus:
|
||||
"""Test cases for updating invitation status."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_status_sets_accepted_at_for_accepted(self):
|
||||
"""Test that accepted_at is set when status is accepted."""
|
||||
from uuid import UUID
|
||||
|
||||
mock_invitation = MagicMock(spec=OrgInvitation)
|
||||
mock_invitation.id = 1
|
||||
mock_invitation.status = OrgInvitation.STATUS_PENDING
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.first.return_value = mock_invitation
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
mock_session.commit = AsyncMock()
|
||||
mock_session.refresh = AsyncMock()
|
||||
|
||||
with patch(
|
||||
'storage.org_invitation_store.a_session_maker'
|
||||
) as mock_session_maker:
|
||||
mock_session_manager = AsyncMock()
|
||||
mock_session_manager.__aenter__.return_value = mock_session
|
||||
mock_session_manager.__aexit__.return_value = None
|
||||
mock_session_maker.return_value = mock_session_manager
|
||||
|
||||
user_id = UUID('87654321-4321-8765-4321-876543218765')
|
||||
await OrgInvitationStore.update_invitation_status(
|
||||
invitation_id=1,
|
||||
status=OrgInvitation.STATUS_ACCEPTED,
|
||||
accepted_by_user_id=user_id,
|
||||
)
|
||||
|
||||
assert mock_invitation.accepted_at is not None
|
||||
assert mock_invitation.accepted_by_user_id == user_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_status_returns_none_when_not_found(self):
|
||||
"""Test that update returns None when invitation not found."""
|
||||
mock_session = AsyncMock()
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.first.return_value = None
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
|
||||
with patch(
|
||||
'storage.org_invitation_store.a_session_maker'
|
||||
) as mock_session_maker:
|
||||
mock_session_manager = AsyncMock()
|
||||
mock_session_manager.__aenter__.return_value = mock_session
|
||||
mock_session_manager.__aexit__.return_value = None
|
||||
mock_session_maker.return_value = mock_session_manager
|
||||
|
||||
result = await OrgInvitationStore.update_invitation_status(
|
||||
invitation_id=999,
|
||||
status=OrgInvitation.STATUS_ACCEPTED,
|
||||
)
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestMarkExpiredIfNeeded:
|
||||
"""Test cases for marking expired invitations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_marks_expired_when_pending_and_past_expiry(self):
|
||||
"""Test that pending expired invitations are marked as expired."""
|
||||
mock_invitation = MagicMock(spec=OrgInvitation)
|
||||
mock_invitation.id = 1
|
||||
mock_invitation.status = OrgInvitation.STATUS_PENDING
|
||||
mock_invitation.expires_at = datetime.utcnow() - timedelta(days=1)
|
||||
|
||||
with patch.object(
|
||||
OrgInvitationStore,
|
||||
'update_invitation_status',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_update:
|
||||
result = await OrgInvitationStore.mark_expired_if_needed(mock_invitation)
|
||||
|
||||
assert result is True
|
||||
mock_update.assert_called_once_with(1, OrgInvitation.STATUS_EXPIRED)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_does_not_mark_when_not_expired(self):
|
||||
"""Test that non-expired invitations are not marked."""
|
||||
mock_invitation = MagicMock(spec=OrgInvitation)
|
||||
mock_invitation.id = 1
|
||||
mock_invitation.status = OrgInvitation.STATUS_PENDING
|
||||
mock_invitation.expires_at = datetime.utcnow() + timedelta(days=1)
|
||||
|
||||
with patch.object(
|
||||
OrgInvitationStore,
|
||||
'update_invitation_status',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_update:
|
||||
result = await OrgInvitationStore.mark_expired_if_needed(mock_invitation)
|
||||
|
||||
assert result is False
|
||||
mock_update.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_does_not_mark_when_not_pending(self):
|
||||
"""Test that non-pending invitations are not marked even if expired."""
|
||||
mock_invitation = MagicMock(spec=OrgInvitation)
|
||||
mock_invitation.id = 1
|
||||
mock_invitation.status = OrgInvitation.STATUS_ACCEPTED
|
||||
mock_invitation.expires_at = datetime.utcnow() - timedelta(days=1)
|
||||
|
||||
with patch.object(
|
||||
OrgInvitationStore,
|
||||
'update_invitation_status',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_update:
|
||||
result = await OrgInvitationStore.mark_expired_if_needed(mock_invitation)
|
||||
|
||||
assert result is False
|
||||
mock_update.assert_not_called()
|
||||
388
enterprise/tests/unit/test_org_invitations_router.py
Normal file
388
enterprise/tests/unit/test_org_invitations_router.py
Normal file
@@ -0,0 +1,388 @@
|
||||
"""Tests for organization invitations API router."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
from server.routes.org_invitation_models import (
|
||||
EmailMismatchError,
|
||||
InvitationExpiredError,
|
||||
InvitationInvalidError,
|
||||
UserAlreadyMemberError,
|
||||
)
|
||||
from server.routes.org_invitations import accept_router, invitation_router
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
"""Create a FastAPI app with the invitation routers."""
|
||||
app = FastAPI()
|
||||
app.include_router(invitation_router)
|
||||
app.include_router(accept_router)
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(app):
|
||||
"""Create a test client for the app."""
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
class TestRouterPrefixes:
|
||||
"""Test that router prefixes are configured correctly."""
|
||||
|
||||
def test_invitation_router_has_correct_prefix(self):
|
||||
"""Test that invitation_router has /api/organizations/{org_id}/members prefix."""
|
||||
assert invitation_router.prefix == '/api/organizations/{org_id}/members'
|
||||
|
||||
def test_accept_router_has_correct_prefix(self):
|
||||
"""Test that accept_router has /api/organizations/members/invite prefix."""
|
||||
assert accept_router.prefix == '/api/organizations/members/invite'
|
||||
|
||||
|
||||
class TestAcceptInvitationEndpoint:
|
||||
"""Test cases for the accept invitation endpoint."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user_auth(self):
|
||||
"""Create a mock user auth."""
|
||||
user_auth = MagicMock()
|
||||
user_auth.get_user_id = AsyncMock(
|
||||
return_value='87654321-4321-8765-4321-876543218765'
|
||||
)
|
||||
return user_auth
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_accept_unauthenticated_redirects_to_login(self, client):
|
||||
"""Test that unauthenticated users are redirected to login with invitation token."""
|
||||
with patch(
|
||||
'server.routes.org_invitations.get_user_auth',
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
):
|
||||
response = client.get(
|
||||
'/api/organizations/members/invite/accept?token=inv-test-token-123',
|
||||
follow_redirects=False,
|
||||
)
|
||||
|
||||
assert response.status_code == 302
|
||||
assert '/login?invitation_token=inv-test-token-123' in response.headers.get(
|
||||
'location', ''
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_accept_authenticated_success_redirects_home(
|
||||
self, client, mock_user_auth
|
||||
):
|
||||
"""Test that successful acceptance redirects to home page."""
|
||||
mock_invitation = MagicMock()
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.routes.org_invitations.get_user_auth',
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_user_auth,
|
||||
),
|
||||
patch(
|
||||
'server.routes.org_invitations.OrgInvitationService.accept_invitation',
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_invitation,
|
||||
),
|
||||
):
|
||||
response = client.get(
|
||||
'/api/organizations/members/invite/accept?token=inv-test-token-123',
|
||||
follow_redirects=False,
|
||||
)
|
||||
|
||||
assert response.status_code == 302
|
||||
location = response.headers.get('location', '')
|
||||
assert location.endswith('/')
|
||||
assert 'invitation_expired' not in location
|
||||
assert 'invitation_invalid' not in location
|
||||
assert 'email_mismatch' not in location
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_accept_expired_invitation_redirects_with_flag(
|
||||
self, client, mock_user_auth
|
||||
):
|
||||
"""Test that expired invitation redirects with invitation_expired=true."""
|
||||
with (
|
||||
patch(
|
||||
'server.routes.org_invitations.get_user_auth',
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_user_auth,
|
||||
),
|
||||
patch(
|
||||
'server.routes.org_invitations.OrgInvitationService.accept_invitation',
|
||||
new_callable=AsyncMock,
|
||||
side_effect=InvitationExpiredError(),
|
||||
),
|
||||
):
|
||||
response = client.get(
|
||||
'/api/organizations/members/invite/accept?token=inv-test-token-123',
|
||||
follow_redirects=False,
|
||||
)
|
||||
|
||||
assert response.status_code == 302
|
||||
assert 'invitation_expired=true' in response.headers.get('location', '')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_accept_invalid_invitation_redirects_with_flag(
|
||||
self, client, mock_user_auth
|
||||
):
|
||||
"""Test that invalid invitation redirects with invitation_invalid=true."""
|
||||
with (
|
||||
patch(
|
||||
'server.routes.org_invitations.get_user_auth',
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_user_auth,
|
||||
),
|
||||
patch(
|
||||
'server.routes.org_invitations.OrgInvitationService.accept_invitation',
|
||||
new_callable=AsyncMock,
|
||||
side_effect=InvitationInvalidError(),
|
||||
),
|
||||
):
|
||||
response = client.get(
|
||||
'/api/organizations/members/invite/accept?token=inv-test-token-123',
|
||||
follow_redirects=False,
|
||||
)
|
||||
|
||||
assert response.status_code == 302
|
||||
assert 'invitation_invalid=true' in response.headers.get('location', '')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_accept_already_member_redirects_with_flag(
|
||||
self, client, mock_user_auth
|
||||
):
|
||||
"""Test that already member error redirects with already_member=true."""
|
||||
with (
|
||||
patch(
|
||||
'server.routes.org_invitations.get_user_auth',
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_user_auth,
|
||||
),
|
||||
patch(
|
||||
'server.routes.org_invitations.OrgInvitationService.accept_invitation',
|
||||
new_callable=AsyncMock,
|
||||
side_effect=UserAlreadyMemberError(),
|
||||
),
|
||||
):
|
||||
response = client.get(
|
||||
'/api/organizations/members/invite/accept?token=inv-test-token-123',
|
||||
follow_redirects=False,
|
||||
)
|
||||
|
||||
assert response.status_code == 302
|
||||
assert 'already_member=true' in response.headers.get('location', '')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_accept_email_mismatch_redirects_with_flag(
|
||||
self, client, mock_user_auth
|
||||
):
|
||||
"""Test that email mismatch error redirects with email_mismatch=true."""
|
||||
with (
|
||||
patch(
|
||||
'server.routes.org_invitations.get_user_auth',
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_user_auth,
|
||||
),
|
||||
patch(
|
||||
'server.routes.org_invitations.OrgInvitationService.accept_invitation',
|
||||
new_callable=AsyncMock,
|
||||
side_effect=EmailMismatchError(),
|
||||
),
|
||||
):
|
||||
response = client.get(
|
||||
'/api/organizations/members/invite/accept?token=inv-test-token-123',
|
||||
follow_redirects=False,
|
||||
)
|
||||
|
||||
assert response.status_code == 302
|
||||
assert 'email_mismatch=true' in response.headers.get('location', '')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_accept_unexpected_error_redirects_with_flag(
|
||||
self, client, mock_user_auth
|
||||
):
|
||||
"""Test that unexpected errors redirect with invitation_error=true."""
|
||||
with (
|
||||
patch(
|
||||
'server.routes.org_invitations.get_user_auth',
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_user_auth,
|
||||
),
|
||||
patch(
|
||||
'server.routes.org_invitations.OrgInvitationService.accept_invitation',
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Exception('Unexpected error'),
|
||||
),
|
||||
):
|
||||
response = client.get(
|
||||
'/api/organizations/members/invite/accept?token=inv-test-token-123',
|
||||
follow_redirects=False,
|
||||
)
|
||||
|
||||
assert response.status_code == 302
|
||||
assert 'invitation_error=true' in response.headers.get('location', '')
|
||||
|
||||
|
||||
class TestCreateInvitationBatchEndpoint:
|
||||
"""Test cases for the batch invitation creation endpoint."""
|
||||
|
||||
@pytest.fixture
|
||||
def batch_app(self):
|
||||
"""Create a FastAPI app with dependency overrides for batch tests."""
|
||||
from openhands.server.user_auth import get_user_id
|
||||
|
||||
app = FastAPI()
|
||||
app.include_router(invitation_router)
|
||||
|
||||
# Override the get_user_id dependency
|
||||
app.dependency_overrides[get_user_id] = (
|
||||
lambda: '87654321-4321-8765-4321-876543218765'
|
||||
)
|
||||
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def batch_client(self, batch_app):
|
||||
"""Create a test client with dependency overrides."""
|
||||
return TestClient(batch_app)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_invitation(self):
|
||||
"""Create a mock invitation."""
|
||||
from datetime import datetime
|
||||
|
||||
invitation = MagicMock()
|
||||
invitation.id = 1
|
||||
invitation.email = 'alice@example.com'
|
||||
invitation.role = MagicMock(name='member')
|
||||
invitation.role.name = 'member'
|
||||
invitation.role_id = 3
|
||||
invitation.status = 'pending'
|
||||
invitation.created_at = datetime(2026, 2, 17, 10, 0, 0)
|
||||
invitation.expires_at = datetime(2026, 2, 24, 10, 0, 0)
|
||||
return invitation
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_batch_create_returns_successful_invitations(
|
||||
self, batch_client, mock_invitation
|
||||
):
|
||||
"""Test that batch creation returns successful invitations."""
|
||||
mock_invitation_2 = MagicMock()
|
||||
mock_invitation_2.id = 2
|
||||
mock_invitation_2.email = 'bob@example.com'
|
||||
mock_invitation_2.role = MagicMock()
|
||||
mock_invitation_2.role.name = 'member'
|
||||
mock_invitation_2.role_id = 3
|
||||
mock_invitation_2.status = 'pending'
|
||||
mock_invitation_2.created_at = mock_invitation.created_at
|
||||
mock_invitation_2.expires_at = mock_invitation.expires_at
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.routes.org_invitations.check_rate_limit_by_user_id',
|
||||
new_callable=AsyncMock,
|
||||
),
|
||||
patch(
|
||||
'server.routes.org_invitations.OrgInvitationService.create_invitations_batch',
|
||||
new_callable=AsyncMock,
|
||||
return_value=([mock_invitation, mock_invitation_2], []),
|
||||
),
|
||||
):
|
||||
response = batch_client.post(
|
||||
'/api/organizations/12345678-1234-5678-1234-567812345678/members/invite',
|
||||
json={
|
||||
'emails': ['alice@example.com', 'bob@example.com'],
|
||||
'role': 'member',
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 201
|
||||
data = response.json()
|
||||
assert len(data['successful']) == 2
|
||||
assert len(data['failed']) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_batch_create_returns_partial_success(
|
||||
self, batch_client, mock_invitation
|
||||
):
|
||||
"""Test that batch creation returns both successful and failed invitations."""
|
||||
failed_emails = [('existing@example.com', 'User is already a member')]
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.routes.org_invitations.check_rate_limit_by_user_id',
|
||||
new_callable=AsyncMock,
|
||||
),
|
||||
patch(
|
||||
'server.routes.org_invitations.OrgInvitationService.create_invitations_batch',
|
||||
new_callable=AsyncMock,
|
||||
return_value=([mock_invitation], failed_emails),
|
||||
),
|
||||
):
|
||||
response = batch_client.post(
|
||||
'/api/organizations/12345678-1234-5678-1234-567812345678/members/invite',
|
||||
json={
|
||||
'emails': ['alice@example.com', 'existing@example.com'],
|
||||
'role': 'member',
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 201
|
||||
data = response.json()
|
||||
assert len(data['successful']) == 1
|
||||
assert len(data['failed']) == 1
|
||||
assert data['failed'][0]['email'] == 'existing@example.com'
|
||||
assert 'already a member' in data['failed'][0]['error']
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_batch_create_permission_denied_returns_403(self, batch_client):
|
||||
"""Test that permission denied returns 403 for entire batch."""
|
||||
from server.routes.org_invitation_models import InsufficientPermissionError
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.routes.org_invitations.check_rate_limit_by_user_id',
|
||||
new_callable=AsyncMock,
|
||||
),
|
||||
patch(
|
||||
'server.routes.org_invitations.OrgInvitationService.create_invitations_batch',
|
||||
new_callable=AsyncMock,
|
||||
side_effect=InsufficientPermissionError(
|
||||
'Only owners and admins can invite'
|
||||
),
|
||||
),
|
||||
):
|
||||
response = batch_client.post(
|
||||
'/api/organizations/12345678-1234-5678-1234-567812345678/members/invite',
|
||||
json={'emails': ['alice@example.com'], 'role': 'member'},
|
||||
)
|
||||
|
||||
assert response.status_code == 403
|
||||
assert 'owners and admins' in response.json()['detail']
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_batch_create_invalid_role_returns_400(self, batch_client):
|
||||
"""Test that invalid role returns 400."""
|
||||
with (
|
||||
patch(
|
||||
'server.routes.org_invitations.check_rate_limit_by_user_id',
|
||||
new_callable=AsyncMock,
|
||||
),
|
||||
patch(
|
||||
'server.routes.org_invitations.OrgInvitationService.create_invitations_batch',
|
||||
new_callable=AsyncMock,
|
||||
side_effect=ValueError('Invalid role: superuser'),
|
||||
),
|
||||
):
|
||||
response = batch_client.post(
|
||||
'/api/organizations/12345678-1234-5678-1234-567812345678/members/invite',
|
||||
json={'emails': ['alice@example.com'], 'role': 'superuser'},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert 'Invalid role' in response.json()['detail']
|
||||
@@ -1,8 +1,15 @@
|
||||
import uuid
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
# Mock the database module before importing OrgMemberStore
|
||||
with patch('storage.database.engine'), patch('storage.database.a_engine'):
|
||||
with patch('storage.database.engine', create=True), patch(
|
||||
'storage.database.a_engine', create=True
|
||||
):
|
||||
from storage.base import Base
|
||||
from storage.org import Org
|
||||
from storage.org_member import OrgMember
|
||||
from storage.org_member_store import OrgMemberStore
|
||||
@@ -10,6 +17,31 @@ with patch('storage.database.engine'), patch('storage.database.a_engine'):
|
||||
from storage.user import User
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def async_engine():
|
||||
"""Create an async SQLite engine for testing."""
|
||||
engine = create_async_engine(
|
||||
'sqlite+aiosqlite:///:memory:',
|
||||
poolclass=StaticPool,
|
||||
connect_args={'check_same_thread': False},
|
||||
echo=False,
|
||||
)
|
||||
|
||||
# Create all tables
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
yield engine
|
||||
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def async_session_maker(async_engine):
|
||||
"""Create an async session maker for testing."""
|
||||
return async_sessionmaker(async_engine, class_=AsyncSession, expire_on_commit=False)
|
||||
|
||||
|
||||
def test_get_org_members(session_maker):
|
||||
# Test getting org_members by org ID
|
||||
with session_maker() as session:
|
||||
@@ -126,6 +158,57 @@ def test_get_org_member(session_maker):
|
||||
assert retrieved_org_member.llm_api_key.get_secret_value() == 'test-key'
|
||||
|
||||
|
||||
def test_get_org_member_for_current_org(session_maker):
|
||||
# Test getting org_member for user's current organization
|
||||
with session_maker() as session:
|
||||
# Create test data - user belongs to two orgs but current_org is org1
|
||||
org1 = Org(name='test-org-1')
|
||||
org2 = Org(name='test-org-2')
|
||||
session.add_all([org1, org2])
|
||||
session.flush()
|
||||
|
||||
user = User(id=uuid.uuid4(), current_org_id=org1.id)
|
||||
role = Role(name='admin', rank=1)
|
||||
session.add_all([user, role])
|
||||
session.flush()
|
||||
|
||||
org_member1 = OrgMember(
|
||||
org_id=org1.id,
|
||||
user_id=user.id,
|
||||
role_id=role.id,
|
||||
llm_api_key='test-key-1',
|
||||
status='active',
|
||||
)
|
||||
org_member2 = OrgMember(
|
||||
org_id=org2.id,
|
||||
user_id=user.id,
|
||||
role_id=role.id,
|
||||
llm_api_key='test-key-2',
|
||||
status='active',
|
||||
)
|
||||
session.add_all([org_member1, org_member2])
|
||||
session.commit()
|
||||
user_id = user.id
|
||||
org1_id = org1.id
|
||||
|
||||
# Test retrieval - should return org_member for current_org (org1)
|
||||
with patch('storage.org_member_store.session_maker', session_maker):
|
||||
retrieved_org_member = OrgMemberStore.get_org_member_for_current_org(user_id)
|
||||
assert retrieved_org_member is not None
|
||||
assert retrieved_org_member.org_id == org1_id
|
||||
assert retrieved_org_member.user_id == user_id
|
||||
assert retrieved_org_member.llm_api_key.get_secret_value() == 'test-key-1'
|
||||
|
||||
|
||||
def test_get_org_member_for_current_org_user_not_found(session_maker):
|
||||
# Test getting org_member for non-existent user
|
||||
with patch('storage.org_member_store.session_maker', session_maker):
|
||||
retrieved_org_member = OrgMemberStore.get_org_member_for_current_org(
|
||||
uuid.uuid4()
|
||||
)
|
||||
assert retrieved_org_member is None
|
||||
|
||||
|
||||
def test_add_user_to_org(session_maker):
|
||||
# Test adding a user to an org
|
||||
with session_maker() as session:
|
||||
@@ -251,3 +334,324 @@ def test_remove_user_from_org_not_found(session_maker):
|
||||
with patch('storage.org_member_store.session_maker', session_maker):
|
||||
result = OrgMemberStore.remove_user_from_org(uuid4(), 99999)
|
||||
assert result is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_org_members_paginated_basic(async_session_maker):
|
||||
"""Test basic pagination returns correct number of items."""
|
||||
# Arrange
|
||||
async with async_session_maker() as session:
|
||||
org = Org(name='test-org')
|
||||
session.add(org)
|
||||
await session.flush()
|
||||
|
||||
role = Role(name='admin', rank=1)
|
||||
session.add(role)
|
||||
await session.flush()
|
||||
|
||||
# Create 5 users
|
||||
users = [
|
||||
User(id=uuid.uuid4(), current_org_id=org.id, email=f'user{i}@example.com')
|
||||
for i in range(5)
|
||||
]
|
||||
session.add_all(users)
|
||||
await session.flush()
|
||||
|
||||
# Create org members
|
||||
org_members = [
|
||||
OrgMember(
|
||||
org_id=org.id,
|
||||
user_id=user.id,
|
||||
role_id=role.id,
|
||||
llm_api_key=f'test-key-{i}',
|
||||
status='active',
|
||||
)
|
||||
for i, user in enumerate(users)
|
||||
]
|
||||
session.add_all(org_members)
|
||||
await session.commit()
|
||||
org_id = org.id
|
||||
|
||||
# Act
|
||||
with patch('storage.org_member_store.a_session_maker', async_session_maker):
|
||||
members, has_more = await OrgMemberStore.get_org_members_paginated(
|
||||
org_id=org_id, offset=0, limit=3
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(members) == 3
|
||||
assert has_more is True
|
||||
# Verify user and role relationships are loaded
|
||||
assert all(member.user is not None for member in members)
|
||||
assert all(member.role is not None for member in members)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_org_members_paginated_no_more(async_session_maker):
|
||||
"""Test pagination when there are no more results."""
|
||||
# Arrange
|
||||
async with async_session_maker() as session:
|
||||
org = Org(name='test-org')
|
||||
session.add(org)
|
||||
await session.flush()
|
||||
|
||||
role = Role(name='admin', rank=1)
|
||||
session.add(role)
|
||||
await session.flush()
|
||||
|
||||
# Create 3 users
|
||||
users = [
|
||||
User(id=uuid.uuid4(), current_org_id=org.id, email=f'user{i}@example.com')
|
||||
for i in range(3)
|
||||
]
|
||||
session.add_all(users)
|
||||
await session.flush()
|
||||
|
||||
# Create org members
|
||||
org_members = [
|
||||
OrgMember(
|
||||
org_id=org.id,
|
||||
user_id=user.id,
|
||||
role_id=role.id,
|
||||
llm_api_key=f'test-key-{i}',
|
||||
status='active',
|
||||
)
|
||||
for i, user in enumerate(users)
|
||||
]
|
||||
session.add_all(org_members)
|
||||
await session.commit()
|
||||
org_id = org.id
|
||||
|
||||
# Act
|
||||
with patch('storage.org_member_store.a_session_maker', async_session_maker):
|
||||
members, has_more = await OrgMemberStore.get_org_members_paginated(
|
||||
org_id=org_id, offset=0, limit=5
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(members) == 3
|
||||
assert has_more is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_org_members_paginated_exact_limit(async_session_maker):
|
||||
"""Test pagination when results exactly match limit."""
|
||||
# Arrange
|
||||
async with async_session_maker() as session:
|
||||
org = Org(name='test-org')
|
||||
session.add(org)
|
||||
await session.flush()
|
||||
|
||||
role = Role(name='admin', rank=1)
|
||||
session.add(role)
|
||||
await session.flush()
|
||||
|
||||
# Create exactly 5 users
|
||||
users = [
|
||||
User(id=uuid.uuid4(), current_org_id=org.id, email=f'user{i}@example.com')
|
||||
for i in range(5)
|
||||
]
|
||||
session.add_all(users)
|
||||
await session.flush()
|
||||
|
||||
# Create org members
|
||||
org_members = [
|
||||
OrgMember(
|
||||
org_id=org.id,
|
||||
user_id=user.id,
|
||||
role_id=role.id,
|
||||
llm_api_key=f'test-key-{i}',
|
||||
status='active',
|
||||
)
|
||||
for i, user in enumerate(users)
|
||||
]
|
||||
session.add_all(org_members)
|
||||
await session.commit()
|
||||
org_id = org.id
|
||||
|
||||
# Act
|
||||
with patch('storage.org_member_store.a_session_maker', async_session_maker):
|
||||
members, has_more = await OrgMemberStore.get_org_members_paginated(
|
||||
org_id=org_id, offset=0, limit=5
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(members) == 5
|
||||
assert has_more is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_org_members_paginated_with_offset(async_session_maker):
|
||||
"""Test pagination with offset skips correct number of items."""
|
||||
# Arrange
|
||||
async with async_session_maker() as session:
|
||||
org = Org(name='test-org')
|
||||
session.add(org)
|
||||
await session.flush()
|
||||
|
||||
role = Role(name='admin', rank=1)
|
||||
session.add(role)
|
||||
await session.flush()
|
||||
|
||||
# Create 10 users
|
||||
users = [
|
||||
User(id=uuid.uuid4(), current_org_id=org.id, email=f'user{i}@example.com')
|
||||
for i in range(10)
|
||||
]
|
||||
session.add_all(users)
|
||||
await session.flush()
|
||||
|
||||
# Create org members
|
||||
org_members = [
|
||||
OrgMember(
|
||||
org_id=org.id,
|
||||
user_id=user.id,
|
||||
role_id=role.id,
|
||||
llm_api_key=f'test-key-{i}',
|
||||
status='active',
|
||||
)
|
||||
for i, user in enumerate(users)
|
||||
]
|
||||
session.add_all(org_members)
|
||||
await session.commit()
|
||||
org_id = org.id
|
||||
|
||||
# Act - Get first page
|
||||
with patch('storage.org_member_store.a_session_maker', async_session_maker):
|
||||
first_page, has_more_first = await OrgMemberStore.get_org_members_paginated(
|
||||
org_id=org_id, offset=0, limit=3
|
||||
)
|
||||
|
||||
# Get second page
|
||||
second_page, has_more_second = await OrgMemberStore.get_org_members_paginated(
|
||||
org_id=org_id, offset=3, limit=3
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(first_page) == 3
|
||||
assert has_more_first is True
|
||||
assert len(second_page) == 3
|
||||
assert has_more_second is True
|
||||
|
||||
# Verify no overlap between pages
|
||||
first_user_ids = {member.user_id for member in first_page}
|
||||
second_user_ids = {member.user_id for member in second_page}
|
||||
assert first_user_ids.isdisjoint(second_user_ids)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_org_members_paginated_empty_org(async_session_maker):
|
||||
"""Test pagination with empty organization returns empty list."""
|
||||
# Arrange
|
||||
async with async_session_maker() as session:
|
||||
org = Org(name='test-org')
|
||||
session.add(org)
|
||||
await session.commit()
|
||||
org_id = org.id
|
||||
|
||||
# Act
|
||||
with patch('storage.org_member_store.a_session_maker', async_session_maker):
|
||||
members, has_more = await OrgMemberStore.get_org_members_paginated(
|
||||
org_id=org_id, offset=0, limit=10
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(members) == 0
|
||||
assert has_more is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_org_members_paginated_ordering(async_session_maker):
|
||||
"""Test that pagination orders results by user_id."""
|
||||
# Arrange
|
||||
async with async_session_maker() as session:
|
||||
org = Org(name='test-org')
|
||||
session.add(org)
|
||||
await session.flush()
|
||||
|
||||
role = Role(name='admin', rank=1)
|
||||
session.add(role)
|
||||
await session.flush()
|
||||
|
||||
# Create users with specific IDs to test ordering
|
||||
user_ids = [uuid.uuid4() for _ in range(5)]
|
||||
user_ids.sort() # Sort to verify ordering
|
||||
|
||||
users = [
|
||||
User(id=user_id, current_org_id=org.id, email=f'user{i}@example.com')
|
||||
for i, user_id in enumerate(user_ids)
|
||||
]
|
||||
session.add_all(users)
|
||||
await session.flush()
|
||||
|
||||
# Create org members in reverse order to test that ordering works
|
||||
org_members = [
|
||||
OrgMember(
|
||||
org_id=org.id,
|
||||
user_id=user_id,
|
||||
role_id=role.id,
|
||||
llm_api_key=f'test-key-{i}',
|
||||
status='active',
|
||||
)
|
||||
for i, user_id in enumerate(reversed(user_ids))
|
||||
]
|
||||
session.add_all(org_members)
|
||||
await session.commit()
|
||||
org_id = org.id
|
||||
|
||||
# Act
|
||||
with patch('storage.org_member_store.a_session_maker', async_session_maker):
|
||||
members, has_more = await OrgMemberStore.get_org_members_paginated(
|
||||
org_id=org_id, offset=0, limit=10
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(members) == 5
|
||||
# Verify members are ordered by user_id
|
||||
member_user_ids = [member.user_id for member in members]
|
||||
assert member_user_ids == sorted(member_user_ids)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_org_members_paginated_eager_loading(async_session_maker):
|
||||
"""Test that user and role relationships are eagerly loaded."""
|
||||
# Arrange
|
||||
async with async_session_maker() as session:
|
||||
org = Org(name='test-org')
|
||||
session.add(org)
|
||||
await session.flush()
|
||||
|
||||
role = Role(name='owner', rank=10)
|
||||
session.add(role)
|
||||
await session.flush()
|
||||
|
||||
user = User(id=uuid.uuid4(), current_org_id=org.id, email='test@example.com')
|
||||
session.add(user)
|
||||
await session.flush()
|
||||
|
||||
org_member = OrgMember(
|
||||
org_id=org.id,
|
||||
user_id=user.id,
|
||||
role_id=role.id,
|
||||
llm_api_key='test-key',
|
||||
status='active',
|
||||
)
|
||||
session.add(org_member)
|
||||
await session.commit()
|
||||
org_id = org.id
|
||||
|
||||
# Act
|
||||
with patch('storage.org_member_store.a_session_maker', async_session_maker):
|
||||
members, has_more = await OrgMemberStore.get_org_members_paginated(
|
||||
org_id=org_id, offset=0, limit=10
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(members) == 1
|
||||
member = members[0]
|
||||
# Verify relationships are loaded (not lazy)
|
||||
assert member.user is not None
|
||||
assert member.user.email == 'test@example.com'
|
||||
assert member.role is not None
|
||||
assert member.role.name == 'owner'
|
||||
assert member.role.rank == 10
|
||||
|
||||
@@ -482,7 +482,7 @@ async def test_get_org_credits_success(mock_litellm_api):
|
||||
spend = 25.0
|
||||
|
||||
mock_team_info = {
|
||||
'litellm_budget_table': {'max_budget': max_budget},
|
||||
'max_budget_in_team': max_budget,
|
||||
'spend': spend,
|
||||
}
|
||||
|
||||
@@ -1535,6 +1535,118 @@ async def test_update_org_with_permissions_database_error(session_maker):
|
||||
assert 'Failed to update organization' in str(exc_info.value)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_org_with_permissions_duplicate_name_raises_org_name_exists_error(
|
||||
session_maker,
|
||||
):
|
||||
"""
|
||||
GIVEN: User updates org name to a name already used by another organization
|
||||
WHEN: update_org_with_permissions is called
|
||||
THEN: OrgNameExistsError is raised with the conflicting name
|
||||
"""
|
||||
# Arrange
|
||||
org_id = uuid.uuid4()
|
||||
other_org_id = uuid.uuid4()
|
||||
user_id = str(uuid.uuid4())
|
||||
duplicate_name = 'Existing Org Name'
|
||||
|
||||
mock_current_org = Org(
|
||||
id=org_id,
|
||||
name='My Org',
|
||||
contact_name='John Doe',
|
||||
contact_email='john@example.com',
|
||||
org_version=5,
|
||||
)
|
||||
mock_org_with_name = Org(
|
||||
id=other_org_id,
|
||||
name=duplicate_name,
|
||||
contact_name='Jane Doe',
|
||||
contact_email='jane@example.com',
|
||||
)
|
||||
|
||||
from server.routes.org_models import OrgUpdate
|
||||
|
||||
update_data = OrgUpdate(name=duplicate_name)
|
||||
|
||||
with (
|
||||
patch('storage.org_store.session_maker', session_maker),
|
||||
patch('storage.org_member_store.session_maker', session_maker),
|
||||
patch('storage.role_store.session_maker', session_maker),
|
||||
patch(
|
||||
'storage.org_service.OrgStore.get_org_by_id',
|
||||
return_value=mock_current_org,
|
||||
),
|
||||
patch('storage.org_service.OrgService.is_org_member', return_value=True),
|
||||
patch(
|
||||
'storage.org_service.OrgStore.get_org_by_name',
|
||||
return_value=mock_org_with_name,
|
||||
),
|
||||
):
|
||||
# Act & Assert
|
||||
with pytest.raises(OrgNameExistsError) as exc_info:
|
||||
await OrgService.update_org_with_permissions(
|
||||
org_id=org_id,
|
||||
update_data=update_data,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
assert duplicate_name in str(exc_info.value)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_org_with_permissions_same_name_allowed(session_maker):
|
||||
"""
|
||||
GIVEN: User updates org with name unchanged (same as current org name)
|
||||
WHEN: update_org_with_permissions is called
|
||||
THEN: No OrgNameExistsError; update proceeds (name uniqueness allows same org)
|
||||
"""
|
||||
# Arrange
|
||||
org_id = uuid.uuid4()
|
||||
user_id = str(uuid.uuid4())
|
||||
current_name = 'My Org'
|
||||
|
||||
mock_org = Org(
|
||||
id=org_id,
|
||||
name=current_name,
|
||||
contact_name='John Doe',
|
||||
contact_email='john@example.com',
|
||||
org_version=5,
|
||||
)
|
||||
|
||||
from server.routes.org_models import OrgUpdate
|
||||
|
||||
update_data = OrgUpdate(name=current_name)
|
||||
|
||||
with (
|
||||
patch('storage.org_store.session_maker', session_maker),
|
||||
patch('storage.org_member_store.session_maker', session_maker),
|
||||
patch('storage.role_store.session_maker', session_maker),
|
||||
patch(
|
||||
'storage.org_service.OrgStore.get_org_by_id',
|
||||
return_value=mock_org,
|
||||
),
|
||||
patch('storage.org_service.OrgService.is_org_member', return_value=True),
|
||||
patch(
|
||||
'storage.org_service.OrgStore.get_org_by_name',
|
||||
return_value=mock_org,
|
||||
),
|
||||
patch(
|
||||
'storage.org_service.OrgStore.update_org',
|
||||
return_value=mock_org,
|
||||
),
|
||||
):
|
||||
# Act
|
||||
result = await OrgService.update_org_with_permissions(
|
||||
org_id=org_id,
|
||||
update_data=update_data,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result.name == current_name
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_org_with_permissions_only_llm_fields(session_maker):
|
||||
"""
|
||||
@@ -1657,3 +1769,258 @@ async def test_update_org_with_permissions_only_non_llm_fields(session_maker):
|
||||
assert result.contact_name == 'Jane Doe'
|
||||
assert result.conversation_expiration == 60
|
||||
assert result.enable_proactive_conversation_starters is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_byor_export_enabled_returns_true_when_enabled():
|
||||
"""
|
||||
GIVEN: User has current_org with byor_export_enabled=True
|
||||
WHEN: check_byor_export_enabled is called
|
||||
THEN: Returns True
|
||||
"""
|
||||
# Arrange
|
||||
user_id = 'test-user-123'
|
||||
org_id = uuid.uuid4()
|
||||
|
||||
mock_user = MagicMock()
|
||||
mock_user.current_org_id = org_id
|
||||
|
||||
mock_org = MagicMock()
|
||||
mock_org.byor_export_enabled = True
|
||||
|
||||
with (
|
||||
patch(
|
||||
'storage.org_service.UserStore.get_user_by_id_async',
|
||||
AsyncMock(return_value=mock_user),
|
||||
),
|
||||
patch(
|
||||
'storage.org_service.OrgStore.get_org_by_id',
|
||||
return_value=mock_org,
|
||||
),
|
||||
):
|
||||
# Act
|
||||
result = await OrgService.check_byor_export_enabled(user_id)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_byor_export_enabled_returns_false_when_disabled():
|
||||
"""
|
||||
GIVEN: User has current_org with byor_export_enabled=False
|
||||
WHEN: check_byor_export_enabled is called
|
||||
THEN: Returns False
|
||||
"""
|
||||
# Arrange
|
||||
user_id = 'test-user-123'
|
||||
org_id = uuid.uuid4()
|
||||
|
||||
mock_user = MagicMock()
|
||||
mock_user.current_org_id = org_id
|
||||
|
||||
mock_org = MagicMock()
|
||||
mock_org.byor_export_enabled = False
|
||||
|
||||
with (
|
||||
patch(
|
||||
'storage.org_service.UserStore.get_user_by_id_async',
|
||||
AsyncMock(return_value=mock_user),
|
||||
),
|
||||
patch(
|
||||
'storage.org_service.OrgStore.get_org_by_id',
|
||||
return_value=mock_org,
|
||||
),
|
||||
):
|
||||
# Act
|
||||
result = await OrgService.check_byor_export_enabled(user_id)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_byor_export_enabled_returns_false_when_user_not_found():
|
||||
"""
|
||||
GIVEN: User does not exist
|
||||
WHEN: check_byor_export_enabled is called
|
||||
THEN: Returns False
|
||||
"""
|
||||
# Arrange
|
||||
user_id = 'nonexistent-user'
|
||||
|
||||
with patch(
|
||||
'storage.org_service.UserStore.get_user_by_id_async',
|
||||
AsyncMock(return_value=None),
|
||||
):
|
||||
# Act
|
||||
result = await OrgService.check_byor_export_enabled(user_id)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_byor_export_enabled_returns_false_when_no_current_org():
|
||||
"""
|
||||
GIVEN: User exists but has no current_org_id
|
||||
WHEN: check_byor_export_enabled is called
|
||||
THEN: Returns False
|
||||
"""
|
||||
# Arrange
|
||||
user_id = 'test-user-123'
|
||||
|
||||
mock_user = MagicMock()
|
||||
mock_user.current_org_id = None
|
||||
|
||||
with patch(
|
||||
'storage.org_service.UserStore.get_user_by_id_async',
|
||||
AsyncMock(return_value=mock_user),
|
||||
):
|
||||
# Act
|
||||
result = await OrgService.check_byor_export_enabled(user_id)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_byor_export_enabled_returns_false_when_org_not_found():
|
||||
"""
|
||||
GIVEN: User has current_org_id but org does not exist
|
||||
WHEN: check_byor_export_enabled is called
|
||||
THEN: Returns False
|
||||
"""
|
||||
# Arrange
|
||||
user_id = 'test-user-123'
|
||||
org_id = uuid.uuid4()
|
||||
|
||||
mock_user = MagicMock()
|
||||
mock_user.current_org_id = org_id
|
||||
|
||||
with (
|
||||
patch(
|
||||
'storage.org_service.UserStore.get_user_by_id_async',
|
||||
AsyncMock(return_value=mock_user),
|
||||
),
|
||||
patch(
|
||||
'storage.org_service.OrgStore.get_org_by_id',
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
# Act
|
||||
result = await OrgService.check_byor_export_enabled(user_id)
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_switch_org_success():
|
||||
"""
|
||||
GIVEN: Valid org_id and user_id where user is a member
|
||||
WHEN: switch_org is called
|
||||
THEN: User's current_org_id is updated and org is returned
|
||||
"""
|
||||
# Arrange
|
||||
org_id = uuid.uuid4()
|
||||
user_id = str(uuid.uuid4())
|
||||
mock_org = Org(
|
||||
id=org_id,
|
||||
name='Target Organization',
|
||||
contact_name='John Doe',
|
||||
contact_email='john@example.com',
|
||||
)
|
||||
mock_updated_user = User(id=uuid.UUID(user_id), current_org_id=org_id)
|
||||
|
||||
with (
|
||||
patch('storage.org_service.OrgStore.get_org_by_id', return_value=mock_org),
|
||||
patch('storage.org_service.OrgService.is_org_member', return_value=True),
|
||||
patch(
|
||||
'storage.org_service.UserStore.update_current_org',
|
||||
return_value=mock_updated_user,
|
||||
),
|
||||
):
|
||||
# Act
|
||||
result = await OrgService.switch_org(user_id, org_id)
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result.id == org_id
|
||||
assert result.name == 'Target Organization'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_switch_org_org_not_found():
|
||||
"""
|
||||
GIVEN: Organization does not exist
|
||||
WHEN: switch_org is called
|
||||
THEN: OrgNotFoundError is raised
|
||||
"""
|
||||
# Arrange
|
||||
org_id = uuid.uuid4()
|
||||
user_id = str(uuid.uuid4())
|
||||
|
||||
with patch('storage.org_service.OrgStore.get_org_by_id', return_value=None):
|
||||
# Act & Assert
|
||||
with pytest.raises(OrgNotFoundError) as exc_info:
|
||||
await OrgService.switch_org(user_id, org_id)
|
||||
|
||||
assert str(org_id) in str(exc_info.value)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_switch_org_user_not_member():
|
||||
"""
|
||||
GIVEN: User is not a member of the organization
|
||||
WHEN: switch_org is called
|
||||
THEN: OrgAuthorizationError is raised
|
||||
"""
|
||||
# Arrange
|
||||
org_id = uuid.uuid4()
|
||||
user_id = str(uuid.uuid4())
|
||||
mock_org = Org(
|
||||
id=org_id,
|
||||
name='Target Organization',
|
||||
contact_name='John Doe',
|
||||
contact_email='john@example.com',
|
||||
)
|
||||
|
||||
with (
|
||||
patch('storage.org_service.OrgStore.get_org_by_id', return_value=mock_org),
|
||||
patch('storage.org_service.OrgService.is_org_member', return_value=False),
|
||||
):
|
||||
# Act & Assert
|
||||
with pytest.raises(OrgAuthorizationError) as exc_info:
|
||||
await OrgService.switch_org(user_id, org_id)
|
||||
|
||||
assert 'member' in str(exc_info.value).lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_switch_org_user_not_found():
|
||||
"""
|
||||
GIVEN: User does not exist in database
|
||||
WHEN: switch_org is called
|
||||
THEN: OrgDatabaseError is raised
|
||||
"""
|
||||
# Arrange
|
||||
org_id = uuid.uuid4()
|
||||
user_id = str(uuid.uuid4())
|
||||
mock_org = Org(
|
||||
id=org_id,
|
||||
name='Target Organization',
|
||||
contact_name='John Doe',
|
||||
contact_email='john@example.com',
|
||||
)
|
||||
|
||||
with (
|
||||
patch('storage.org_service.OrgStore.get_org_by_id', return_value=mock_org),
|
||||
patch('storage.org_service.OrgService.is_org_member', return_value=True),
|
||||
patch('storage.org_service.UserStore.update_current_org', return_value=None),
|
||||
):
|
||||
# Act & Assert
|
||||
with pytest.raises(OrgDatabaseError) as exc_info:
|
||||
await OrgService.switch_org(user_id, org_id)
|
||||
|
||||
assert 'User not found' in str(exc_info.value)
|
||||
|
||||
@@ -786,3 +786,23 @@ def test_get_user_orgs_paginated_ordering(session_maker, mock_litellm_api):
|
||||
assert orgs[0].name == 'Apple Org'
|
||||
assert orgs[1].name == 'Banana Org'
|
||||
assert orgs[2].name == 'Zebra Org'
|
||||
|
||||
|
||||
def test_orphaned_user_error_contains_user_ids():
|
||||
"""
|
||||
GIVEN: OrphanedUserError is created with a list of user IDs
|
||||
WHEN: The error message is accessed
|
||||
THEN: Message includes the count and stores user IDs
|
||||
"""
|
||||
# Arrange
|
||||
from server.routes.org_models import OrphanedUserError
|
||||
|
||||
user_ids = [str(uuid.uuid4()), str(uuid.uuid4())]
|
||||
|
||||
# Act
|
||||
error = OrphanedUserError(user_ids)
|
||||
|
||||
# Assert
|
||||
assert error.user_ids == user_ids
|
||||
assert '2 user(s)' in str(error)
|
||||
assert 'no remaining organization' in str(error)
|
||||
|
||||
@@ -3,7 +3,9 @@ from unittest.mock import MagicMock, patch
|
||||
import pytest
|
||||
|
||||
# Mock the database module before importing
|
||||
with patch('storage.database.engine'), patch('storage.database.a_engine'):
|
||||
with patch('storage.database.engine', create=True), patch(
|
||||
'storage.database.a_engine', create=True
|
||||
):
|
||||
from integrations.github.github_view import get_user_proactive_conversation_setting
|
||||
from storage.org import Org
|
||||
|
||||
|
||||
@@ -398,6 +398,121 @@ async def test_create_user_contact_name_falls_back_to_username():
|
||||
assert org.contact_name == 'jdoe'
|
||||
|
||||
|
||||
# --- Tests for email fields in create_user() ---
|
||||
# create_user() should populate user.email and user.email_verified from the
|
||||
# Keycloak user_info, ensuring the user table has the correct email data.
|
||||
|
||||
|
||||
class _StopAfterUserCreation(Exception):
|
||||
"""Halt create_user() after User creation for email field inspection."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_user_sets_email_from_user_info():
|
||||
"""create_user() should set user.email and user.email_verified from user_info."""
|
||||
# Arrange
|
||||
user_id = str(uuid.uuid4())
|
||||
user_info = {
|
||||
'preferred_username': 'testuser',
|
||||
'email': 'testuser@example.com',
|
||||
'email_verified': True,
|
||||
}
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_sm = MagicMock()
|
||||
mock_sm.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_sm.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
mock_settings = Settings(language='en')
|
||||
mock_role = MagicMock()
|
||||
mock_role.id = 1
|
||||
|
||||
with (
|
||||
patch('storage.user_store.session_maker', mock_sm),
|
||||
patch.object(
|
||||
UserStore,
|
||||
'create_default_settings',
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_settings,
|
||||
),
|
||||
patch('storage.org_store.OrgStore.get_kwargs_from_settings', return_value={}),
|
||||
patch.object(UserStore, 'get_kwargs_from_settings', return_value={}),
|
||||
patch('storage.user_store.RoleStore.get_role_by_name', return_value=mock_role),
|
||||
patch(
|
||||
'storage.org_member_store.OrgMemberStore.get_kwargs_from_settings',
|
||||
return_value={'llm_model': None, 'llm_base_url': None},
|
||||
),
|
||||
patch.object(
|
||||
mock_session,
|
||||
'commit',
|
||||
side_effect=_StopAfterUserCreation,
|
||||
),
|
||||
):
|
||||
# Act
|
||||
with pytest.raises(_StopAfterUserCreation):
|
||||
await UserStore.create_user(user_id, user_info)
|
||||
|
||||
# Assert - User is the second object added to session (after Org)
|
||||
user = mock_session.add.call_args_list[1][0][0]
|
||||
assert isinstance(user, User)
|
||||
assert user.email == 'testuser@example.com'
|
||||
assert user.email_verified is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_user_handles_missing_email_verified():
|
||||
"""create_user() should handle missing email_verified in user_info gracefully."""
|
||||
# Arrange
|
||||
user_id = str(uuid.uuid4())
|
||||
user_info = {
|
||||
'preferred_username': 'testuser',
|
||||
'email': 'testuser@example.com',
|
||||
# email_verified is not present
|
||||
}
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_sm = MagicMock()
|
||||
mock_sm.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_sm.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
mock_settings = Settings(language='en')
|
||||
mock_role = MagicMock()
|
||||
mock_role.id = 1
|
||||
|
||||
with (
|
||||
patch('storage.user_store.session_maker', mock_sm),
|
||||
patch.object(
|
||||
UserStore,
|
||||
'create_default_settings',
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_settings,
|
||||
),
|
||||
patch('storage.org_store.OrgStore.get_kwargs_from_settings', return_value={}),
|
||||
patch.object(UserStore, 'get_kwargs_from_settings', return_value={}),
|
||||
patch('storage.user_store.RoleStore.get_role_by_name', return_value=mock_role),
|
||||
patch(
|
||||
'storage.org_member_store.OrgMemberStore.get_kwargs_from_settings',
|
||||
return_value={'llm_model': None, 'llm_base_url': None},
|
||||
),
|
||||
patch.object(
|
||||
mock_session,
|
||||
'commit',
|
||||
side_effect=_StopAfterUserCreation,
|
||||
),
|
||||
):
|
||||
# Act
|
||||
with pytest.raises(_StopAfterUserCreation):
|
||||
await UserStore.create_user(user_id, user_info)
|
||||
|
||||
# Assert - User should have email but email_verified should be None
|
||||
user = mock_session.add.call_args_list[1][0][0]
|
||||
assert isinstance(user, User)
|
||||
assert user.email == 'testuser@example.com'
|
||||
assert user.email_verified is None
|
||||
|
||||
|
||||
# --- Tests for backfill_contact_name on login ---
|
||||
# Existing users created before the resolve_display_name fix may have
|
||||
# username-style values in contact_name. The backfill updates these to
|
||||
@@ -522,3 +637,46 @@ async def test_backfill_contact_name_preserves_custom_value(session_maker):
|
||||
with session_maker() as session:
|
||||
org = session.query(Org).filter(Org.id == uuid.UUID(user_id)).first()
|
||||
assert org.contact_name == 'Custom Corp Name'
|
||||
|
||||
|
||||
def test_update_current_org_success(session_maker):
|
||||
"""
|
||||
GIVEN: User exists in database
|
||||
WHEN: update_current_org is called with new org_id
|
||||
THEN: User's current_org_id is updated and user is returned
|
||||
"""
|
||||
# Arrange
|
||||
user_id = str(uuid.uuid4())
|
||||
initial_org_id = uuid.uuid4()
|
||||
new_org_id = uuid.uuid4()
|
||||
|
||||
with session_maker() as session:
|
||||
user = User(id=uuid.UUID(user_id), current_org_id=initial_org_id)
|
||||
session.add(user)
|
||||
session.commit()
|
||||
|
||||
# Act
|
||||
with patch('storage.user_store.session_maker', session_maker):
|
||||
result = UserStore.update_current_org(user_id, new_org_id)
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result.current_org_id == new_org_id
|
||||
|
||||
|
||||
def test_update_current_org_user_not_found(session_maker):
|
||||
"""
|
||||
GIVEN: User does not exist in database
|
||||
WHEN: update_current_org is called
|
||||
THEN: None is returned
|
||||
"""
|
||||
# Arrange
|
||||
user_id = str(uuid.uuid4())
|
||||
org_id = uuid.uuid4()
|
||||
|
||||
# Act
|
||||
with patch('storage.user_store.session_maker', session_maker):
|
||||
result = UserStore.update_current_org(user_id, org_id)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
27
frontend/__tests__/api/v1-conversation-service.test.ts
Normal file
27
frontend/__tests__/api/v1-conversation-service.test.ts
Normal file
@@ -0,0 +1,27 @@
|
||||
import { describe, expect, it, vi } from "vitest";
|
||||
import V1ConversationService from "#/api/conversation-service/v1-conversation-service.api";
|
||||
|
||||
const { mockGet } = vi.hoisted(() => ({ mockGet: vi.fn() }));
|
||||
vi.mock("#/api/open-hands-axios", () => ({
|
||||
openHands: { get: mockGet },
|
||||
}));
|
||||
|
||||
describe("V1ConversationService", () => {
|
||||
describe("readConversationFile", () => {
|
||||
it("uses default plan path when filePath is not provided", async () => {
|
||||
// Arrange
|
||||
const conversationId = "conv-123";
|
||||
mockGet.mockResolvedValue({ data: "# PLAN content" });
|
||||
|
||||
// Act
|
||||
await V1ConversationService.readConversationFile(conversationId);
|
||||
|
||||
// Assert
|
||||
expect(mockGet).toHaveBeenCalledTimes(1);
|
||||
const callUrl = mockGet.mock.calls[0][0] as string;
|
||||
expect(callUrl).toContain(
|
||||
"file_path=%2Fworkspace%2Fproject%2F.agents_tmp%2FPLAN.md",
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -125,7 +125,7 @@ describe("ChatInterface - Chat Suggestions", () => {
|
||||
});
|
||||
|
||||
(useConfig as unknown as ReturnType<typeof vi.fn>).mockReturnValue({
|
||||
data: { APP_MODE: "local" },
|
||||
data: { app_mode: "local" },
|
||||
});
|
||||
(useGetTrajectory as unknown as ReturnType<typeof vi.fn>).mockReturnValue({
|
||||
mutate: vi.fn(),
|
||||
@@ -258,7 +258,7 @@ describe("ChatInterface - Empty state", () => {
|
||||
errorMessage: null,
|
||||
});
|
||||
(useConfig as unknown as ReturnType<typeof vi.fn>).mockReturnValue({
|
||||
data: { APP_MODE: "local" },
|
||||
data: { app_mode: "local" },
|
||||
});
|
||||
(useGetTrajectory as unknown as ReturnType<typeof vi.fn>).mockReturnValue({
|
||||
mutate: vi.fn(),
|
||||
|
||||
@@ -113,15 +113,15 @@ describe("ExpandableMessage", () => {
|
||||
|
||||
it("should render the out of credits message when the user is out of credits", async () => {
|
||||
const getConfigSpy = vi.spyOn(OptionService, "getConfig");
|
||||
// @ts-expect-error - We only care about the APP_MODE and FEATURE_FLAGS fields
|
||||
// @ts-expect-error - We only care about the app_mode and feature_flags fields
|
||||
getConfigSpy.mockResolvedValue({
|
||||
APP_MODE: "saas",
|
||||
FEATURE_FLAGS: {
|
||||
ENABLE_BILLING: true,
|
||||
HIDE_LLM_SETTINGS: false,
|
||||
ENABLE_JIRA: false,
|
||||
ENABLE_JIRA_DC: false,
|
||||
ENABLE_LINEAR: false,
|
||||
app_mode: "saas",
|
||||
feature_flags: {
|
||||
enable_billing: true,
|
||||
hide_llm_settings: false,
|
||||
enable_jira: false,
|
||||
enable_jira_dc: false,
|
||||
enable_linear: false,
|
||||
},
|
||||
});
|
||||
const RouterStub = createRoutesStub([
|
||||
|
||||
@@ -0,0 +1,250 @@
|
||||
import { describe, expect, it, vi, beforeEach } from "vitest";
|
||||
import { screen } from "@testing-library/react";
|
||||
import { QueryClient } from "@tanstack/react-query";
|
||||
import { MemoryRouter, Route, Routes } from "react-router";
|
||||
import { render } from "@testing-library/react";
|
||||
import { QueryClientProvider } from "@tanstack/react-query";
|
||||
import { useParamsMock, createUserMessageEvent } from "test-utils";
|
||||
import { ChatInterface } from "#/components/features/chat/chat-interface";
|
||||
import { useWsClient } from "#/context/ws-client-provider";
|
||||
import { useConversationId } from "#/hooks/use-conversation-id";
|
||||
import { useActiveConversation } from "#/hooks/query/use-active-conversation";
|
||||
import { useConversationWebSocket } from "#/contexts/conversation-websocket-context";
|
||||
import { useConfig } from "#/hooks/query/use-config";
|
||||
import { useGetTrajectory } from "#/hooks/mutation/use-get-trajectory";
|
||||
import { useUnifiedUploadFiles } from "#/hooks/mutation/use-unified-upload-files";
|
||||
import { useEventStore } from "#/stores/use-event-store";
|
||||
import { useAgentState } from "#/hooks/use-agent-state";
|
||||
import { AgentState } from "#/types/agent-state";
|
||||
import { OpenHandsAction } from "#/types/core/actions";
|
||||
|
||||
// Module-level mocks
|
||||
vi.mock("#/context/ws-client-provider");
|
||||
vi.mock("#/hooks/query/use-config");
|
||||
vi.mock("#/hooks/mutation/use-get-trajectory");
|
||||
vi.mock("#/hooks/mutation/use-unified-upload-files");
|
||||
vi.mock("#/hooks/use-conversation-id");
|
||||
vi.mock("#/hooks/query/use-active-conversation");
|
||||
vi.mock("#/contexts/conversation-websocket-context");
|
||||
|
||||
vi.mock("#/hooks/use-user-providers", () => ({
|
||||
useUserProviders: () => ({
|
||||
providers: [],
|
||||
}),
|
||||
}));
|
||||
|
||||
vi.mock("#/hooks/use-conversation-name-context-menu", () => ({
|
||||
useConversationNameContextMenu: () => ({
|
||||
isOpen: false,
|
||||
contextMenuRef: { current: null },
|
||||
handleContextMenu: vi.fn(),
|
||||
handleClose: vi.fn(),
|
||||
handleRename: vi.fn(),
|
||||
handleDelete: vi.fn(),
|
||||
}),
|
||||
}));
|
||||
|
||||
vi.mock("#/hooks/use-agent-state", () => ({
|
||||
useAgentState: vi.fn(() => ({
|
||||
curAgentState: AgentState.AWAITING_USER_INPUT,
|
||||
})),
|
||||
}));
|
||||
|
||||
// Helper to render with QueryClient and route params
|
||||
const renderWithQueryClient = (
|
||||
ui: React.ReactElement,
|
||||
queryClient: QueryClient,
|
||||
route = "/test-conversation-id",
|
||||
) =>
|
||||
render(
|
||||
<QueryClientProvider client={queryClient}>
|
||||
<MemoryRouter initialEntries={[route]}>
|
||||
<Routes>
|
||||
<Route path="/:conversationId" element={ui} />
|
||||
<Route path="/" element={ui} />
|
||||
</Routes>
|
||||
</MemoryRouter>
|
||||
</QueryClientProvider>,
|
||||
);
|
||||
|
||||
// V0 user event (numeric id, action property)
|
||||
const createV0UserEvent = (): OpenHandsAction => ({
|
||||
id: 1,
|
||||
source: "user",
|
||||
action: "message",
|
||||
args: {
|
||||
content: "Hello from V0",
|
||||
image_urls: [],
|
||||
file_urls: [],
|
||||
},
|
||||
message: "Hello from V0",
|
||||
timestamp: "2025-07-01T00:00:00Z",
|
||||
});
|
||||
|
||||
describe("ChatInterface – message display continuity (spec 3.1)", () => {
|
||||
let queryClient: QueryClient;
|
||||
|
||||
beforeEach(() => {
|
||||
queryClient = new QueryClient({
|
||||
defaultOptions: { queries: { retry: false } },
|
||||
});
|
||||
|
||||
useParamsMock.mockReturnValue({ conversationId: "test-conversation-id" });
|
||||
vi.mocked(useConversationId).mockReturnValue({
|
||||
conversationId: "test-conversation-id",
|
||||
});
|
||||
|
||||
// Default: V0, no loading, no events
|
||||
(useWsClient as unknown as ReturnType<typeof vi.fn>).mockReturnValue({
|
||||
send: vi.fn(),
|
||||
isLoadingMessages: false,
|
||||
parsedEvents: [],
|
||||
});
|
||||
|
||||
(useConfig as unknown as ReturnType<typeof vi.fn>).mockReturnValue({
|
||||
data: { app_mode: "local" },
|
||||
});
|
||||
(useGetTrajectory as unknown as ReturnType<typeof vi.fn>).mockReturnValue({
|
||||
mutate: vi.fn(),
|
||||
mutateAsync: vi.fn(),
|
||||
isLoading: false,
|
||||
});
|
||||
(
|
||||
useUnifiedUploadFiles as unknown as ReturnType<typeof vi.fn>
|
||||
).mockReturnValue({
|
||||
mutateAsync: vi
|
||||
.fn()
|
||||
.mockResolvedValue({ skipped_files: [], uploaded_files: [] }),
|
||||
isLoading: false,
|
||||
});
|
||||
|
||||
// Default: no conversation (V0 behavior)
|
||||
vi.mocked(useActiveConversation).mockReturnValue({
|
||||
data: undefined,
|
||||
} as ReturnType<typeof useActiveConversation>);
|
||||
|
||||
// Default: no websocket context
|
||||
vi.mocked(useConversationWebSocket).mockReturnValue(null);
|
||||
});
|
||||
|
||||
describe("V1 conversations", () => {
|
||||
beforeEach(() => {
|
||||
// Set up V1 conversation
|
||||
vi.mocked(useActiveConversation).mockReturnValue({
|
||||
data: { conversation_version: "V1" },
|
||||
} as ReturnType<typeof useActiveConversation>);
|
||||
});
|
||||
|
||||
it("shows messages immediately when V1 events exist in store, even while loading", () => {
|
||||
// Simulate: history is loading but events already exist in store (e.g., remount)
|
||||
vi.mocked(useConversationWebSocket).mockReturnValue({
|
||||
isLoadingHistory: true,
|
||||
connectionState: "OPEN",
|
||||
sendMessage: vi.fn(),
|
||||
});
|
||||
|
||||
// Put V1 user events in the store
|
||||
const v1UserEvent = createUserMessageEvent("evt-1");
|
||||
useEventStore.setState({
|
||||
events: [v1UserEvent],
|
||||
uiEvents: [v1UserEvent],
|
||||
});
|
||||
|
||||
renderWithQueryClient(<ChatInterface />, queryClient);
|
||||
|
||||
// AC1: Messages should display immediately without skeleton
|
||||
expect(
|
||||
screen.queryByTestId("chat-messages-skeleton"),
|
||||
).not.toBeInTheDocument();
|
||||
expect(screen.queryByTestId("loading-spinner")).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("shows skeleton when store is empty and loading", () => {
|
||||
// Simulate: first load, no events yet
|
||||
vi.mocked(useConversationWebSocket).mockReturnValue({
|
||||
isLoadingHistory: true,
|
||||
connectionState: "OPEN",
|
||||
sendMessage: vi.fn(),
|
||||
});
|
||||
|
||||
// Store is empty
|
||||
useEventStore.setState({
|
||||
events: [],
|
||||
uiEvents: [],
|
||||
});
|
||||
|
||||
renderWithQueryClient(<ChatInterface />, queryClient);
|
||||
|
||||
// AC5: Genuine first-load shows skeleton
|
||||
expect(screen.getByTestId("chat-messages-skeleton")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("shows messages when loading is already false on mount (edge case)", () => {
|
||||
// Simulate: component re-mounts when WebSocket has already finished loading
|
||||
vi.mocked(useConversationWebSocket).mockReturnValue({
|
||||
isLoadingHistory: false,
|
||||
connectionState: "OPEN",
|
||||
sendMessage: vi.fn(),
|
||||
});
|
||||
|
||||
// V1 events in store
|
||||
const v1UserEvent = createUserMessageEvent("evt-2");
|
||||
useEventStore.setState({
|
||||
events: [v1UserEvent],
|
||||
uiEvents: [v1UserEvent],
|
||||
});
|
||||
|
||||
renderWithQueryClient(<ChatInterface />, queryClient);
|
||||
|
||||
// Messages should display, no skeleton
|
||||
expect(
|
||||
screen.queryByTestId("chat-messages-skeleton"),
|
||||
).not.toBeInTheDocument();
|
||||
expect(screen.queryByTestId("loading-spinner")).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe("V0 conversations", () => {
|
||||
it("shows messages when V0 events exist in store even if isLoadingMessages is true", () => {
|
||||
// Simulate: loading flag is still true but events already exist in store (e.g., remount)
|
||||
(useWsClient as unknown as ReturnType<typeof vi.fn>).mockReturnValue({
|
||||
send: vi.fn(),
|
||||
isLoadingMessages: true,
|
||||
parsedEvents: [],
|
||||
});
|
||||
|
||||
// Put V0 user events in the store
|
||||
useEventStore.setState({
|
||||
events: [createV0UserEvent()],
|
||||
uiEvents: [],
|
||||
});
|
||||
|
||||
renderWithQueryClient(<ChatInterface />, queryClient);
|
||||
|
||||
// AC1/AC4: Messages display immediately, no skeleton
|
||||
expect(
|
||||
screen.queryByTestId("chat-messages-skeleton"),
|
||||
).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("shows skeleton when store is empty and isLoadingMessages is true", () => {
|
||||
// Simulate: genuine first load, no events yet
|
||||
(useWsClient as unknown as ReturnType<typeof vi.fn>).mockReturnValue({
|
||||
send: vi.fn(),
|
||||
isLoadingMessages: true,
|
||||
parsedEvents: [],
|
||||
});
|
||||
|
||||
// Store is empty
|
||||
useEventStore.setState({
|
||||
events: [],
|
||||
uiEvents: [],
|
||||
});
|
||||
|
||||
renderWithQueryClient(<ChatInterface />, queryClient);
|
||||
|
||||
// AC5: Genuine first-load shows skeleton
|
||||
expect(screen.getByTestId("chat-messages-skeleton")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -5,7 +5,7 @@ import { EventMessage } from "#/components/features/chat/event-message";
|
||||
|
||||
vi.mock("#/hooks/query/use-config", () => ({
|
||||
useConfig: () => ({
|
||||
data: { APP_MODE: "saas" },
|
||||
data: { app_mode: "saas" },
|
||||
}),
|
||||
}));
|
||||
|
||||
|
||||
@@ -0,0 +1,287 @@
|
||||
import { fireEvent, render, screen, within } from "@testing-library/react";
|
||||
import { afterEach, describe, expect, it, vi } from "vitest";
|
||||
import { act } from "react";
|
||||
import { MemoryRouter } from "react-router";
|
||||
import { AlertBanner } from "#/components/features/alerts/alert-banner";
|
||||
|
||||
// Mock react-i18next
|
||||
vi.mock("react-i18next", async () => {
|
||||
const actual =
|
||||
await vi.importActual<typeof import("react-i18next")>("react-i18next");
|
||||
return {
|
||||
...actual,
|
||||
useTranslation: () => ({
|
||||
t: (key: string, options?: { time?: string }) => {
|
||||
const translations: Record<string, string> = {
|
||||
MAINTENANCE$SCHEDULED_MESSAGE: `Scheduled maintenance will begin at ${options?.time || "{{time}}"}`,
|
||||
ALERT$FAULTY_MODELS_MESSAGE:
|
||||
"The following models are currently reporting errors:",
|
||||
"ERROR$TRANSLATED_KEY": "This is a translated error message",
|
||||
};
|
||||
return translations[key] || key;
|
||||
},
|
||||
}),
|
||||
};
|
||||
});
|
||||
|
||||
describe("AlertBanner", () => {
|
||||
afterEach(() => {
|
||||
localStorage.clear();
|
||||
});
|
||||
|
||||
describe("Maintenance alerts", () => {
|
||||
it("renders maintenance banner with formatted time", () => {
|
||||
const startTime = "2024-01-15T10:00:00-05:00";
|
||||
const updatedAt = "2024-01-14T10:00:00Z";
|
||||
|
||||
const { container } = render(
|
||||
<MemoryRouter>
|
||||
<AlertBanner maintenanceStartTime={startTime} updatedAt={updatedAt} />
|
||||
</MemoryRouter>,
|
||||
);
|
||||
|
||||
const banner = screen.queryByTestId("alert-banner");
|
||||
expect(banner).toBeInTheDocument();
|
||||
|
||||
const svgIcon = container.querySelector("svg");
|
||||
expect(svgIcon).toBeInTheDocument();
|
||||
|
||||
const button = within(banner!).queryByTestId("dismiss-button");
|
||||
expect(button).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("click on dismiss button removes banner", () => {
|
||||
const startTime = "2024-01-15T10:00:00-05:00";
|
||||
const updatedAt = "2024-01-14T10:00:00Z";
|
||||
|
||||
render(
|
||||
<MemoryRouter>
|
||||
<AlertBanner maintenanceStartTime={startTime} updatedAt={updatedAt} />
|
||||
</MemoryRouter>,
|
||||
);
|
||||
|
||||
const banner = screen.queryByTestId("alert-banner");
|
||||
const button = within(banner!).queryByTestId("dismiss-button");
|
||||
|
||||
act(() => {
|
||||
fireEvent.click(button!);
|
||||
});
|
||||
|
||||
expect(banner).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("banner reappears when updatedAt changes", () => {
|
||||
const startTime = "2024-01-15T10:00:00-05:00";
|
||||
const updatedAt = "2024-01-14T10:00:00Z";
|
||||
const newUpdatedAt = "2024-01-15T10:00:00Z";
|
||||
|
||||
const { rerender } = render(
|
||||
<MemoryRouter>
|
||||
<AlertBanner maintenanceStartTime={startTime} updatedAt={updatedAt} />
|
||||
</MemoryRouter>,
|
||||
);
|
||||
|
||||
const banner = screen.queryByTestId("alert-banner");
|
||||
const button = within(banner!).queryByTestId("dismiss-button");
|
||||
|
||||
act(() => {
|
||||
fireEvent.click(button!);
|
||||
});
|
||||
|
||||
expect(banner).not.toBeInTheDocument();
|
||||
|
||||
rerender(
|
||||
<MemoryRouter>
|
||||
<AlertBanner
|
||||
maintenanceStartTime={startTime}
|
||||
updatedAt={newUpdatedAt}
|
||||
/>
|
||||
</MemoryRouter>,
|
||||
);
|
||||
|
||||
expect(screen.queryByTestId("alert-banner")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe("Faulty models alerts", () => {
|
||||
it("renders banner with faulty models list", () => {
|
||||
const faultyModels = ["gpt-4", "claude-3"];
|
||||
const updatedAt = "2024-01-14T10:00:00Z";
|
||||
|
||||
render(
|
||||
<MemoryRouter>
|
||||
<AlertBanner faultyModels={faultyModels} updatedAt={updatedAt} />
|
||||
</MemoryRouter>,
|
||||
);
|
||||
|
||||
const banner = screen.queryByTestId("alert-banner");
|
||||
expect(banner).toBeInTheDocument();
|
||||
|
||||
expect(
|
||||
screen.getByText(/The following models are currently reporting errors:/),
|
||||
).toBeInTheDocument();
|
||||
// Models are displayed in the order they are provided
|
||||
expect(screen.getByText(/gpt-4/)).toBeInTheDocument();
|
||||
expect(screen.getByText(/claude-3/)).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("does not render banner when faulty models array is empty", () => {
|
||||
const updatedAt = "2024-01-14T10:00:00Z";
|
||||
|
||||
render(
|
||||
<MemoryRouter>
|
||||
<AlertBanner faultyModels={[]} updatedAt={updatedAt} />
|
||||
</MemoryRouter>,
|
||||
);
|
||||
|
||||
const banner = screen.queryByTestId("alert-banner");
|
||||
expect(banner).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("banner reappears when updatedAt changes", () => {
|
||||
const faultyModels = ["gpt-4"];
|
||||
const updatedAt = "2024-01-14T10:00:00Z";
|
||||
const newUpdatedAt = "2024-01-15T10:00:00Z";
|
||||
|
||||
const { rerender } = render(
|
||||
<MemoryRouter>
|
||||
<AlertBanner faultyModels={faultyModels} updatedAt={updatedAt} />
|
||||
</MemoryRouter>,
|
||||
);
|
||||
|
||||
const banner = screen.queryByTestId("alert-banner");
|
||||
const button = within(banner!).queryByTestId("dismiss-button");
|
||||
|
||||
act(() => {
|
||||
fireEvent.click(button!);
|
||||
});
|
||||
|
||||
expect(banner).not.toBeInTheDocument();
|
||||
|
||||
rerender(
|
||||
<MemoryRouter>
|
||||
<AlertBanner faultyModels={faultyModels} updatedAt={newUpdatedAt} />
|
||||
</MemoryRouter>,
|
||||
);
|
||||
|
||||
expect(screen.queryByTestId("alert-banner")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe("Error message alerts", () => {
|
||||
it("renders banner with translated error message", () => {
|
||||
const updatedAt = "2024-01-14T10:00:00Z";
|
||||
|
||||
render(
|
||||
<MemoryRouter>
|
||||
<AlertBanner errorMessage="ERROR$TRANSLATED_KEY" updatedAt={updatedAt} />
|
||||
</MemoryRouter>,
|
||||
);
|
||||
|
||||
const banner = screen.queryByTestId("alert-banner");
|
||||
expect(banner).toBeInTheDocument();
|
||||
expect(
|
||||
screen.getByText("This is a translated error message"),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("renders banner with raw error message when no translation exists", () => {
|
||||
const rawErrorMessage = "This is a raw error without translation";
|
||||
const updatedAt = "2024-01-14T10:00:00Z";
|
||||
|
||||
render(
|
||||
<MemoryRouter>
|
||||
<AlertBanner errorMessage={rawErrorMessage} updatedAt={updatedAt} />
|
||||
</MemoryRouter>,
|
||||
);
|
||||
|
||||
const banner = screen.queryByTestId("alert-banner");
|
||||
expect(banner).toBeInTheDocument();
|
||||
expect(screen.getByText(rawErrorMessage)).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("does not render banner when error message is empty", () => {
|
||||
const updatedAt = "2024-01-14T10:00:00Z";
|
||||
|
||||
render(
|
||||
<MemoryRouter>
|
||||
<AlertBanner errorMessage="" updatedAt={updatedAt} />
|
||||
</MemoryRouter>,
|
||||
);
|
||||
|
||||
const banner = screen.queryByTestId("alert-banner");
|
||||
expect(banner).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("does not render banner when error message is null", () => {
|
||||
const updatedAt = "2024-01-14T10:00:00Z";
|
||||
|
||||
render(
|
||||
<MemoryRouter>
|
||||
<AlertBanner errorMessage={null} updatedAt={updatedAt} />
|
||||
</MemoryRouter>,
|
||||
);
|
||||
|
||||
const banner = screen.queryByTestId("alert-banner");
|
||||
expect(banner).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe("Multiple alerts", () => {
|
||||
it("renders all alerts when multiple conditions are present", () => {
|
||||
const startTime = "2024-01-15T10:00:00-05:00";
|
||||
const faultyModels = ["gpt-4"];
|
||||
const errorMessage = "ERROR$TRANSLATED_KEY";
|
||||
const updatedAt = "2024-01-14T10:00:00Z";
|
||||
|
||||
render(
|
||||
<MemoryRouter>
|
||||
<AlertBanner
|
||||
maintenanceStartTime={startTime}
|
||||
faultyModels={faultyModels}
|
||||
errorMessage={errorMessage}
|
||||
updatedAt={updatedAt}
|
||||
/>
|
||||
</MemoryRouter>,
|
||||
);
|
||||
|
||||
const banner = screen.queryByTestId("alert-banner");
|
||||
expect(banner).toBeInTheDocument();
|
||||
|
||||
expect(
|
||||
screen.getByText(/Scheduled maintenance will begin at/),
|
||||
).toBeInTheDocument();
|
||||
expect(
|
||||
screen.getByText(/The following models are currently reporting errors:/),
|
||||
).toBeInTheDocument();
|
||||
expect(
|
||||
screen.getByText("This is a translated error message"),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("dismissing hides all alerts", () => {
|
||||
const startTime = "2024-01-15T10:00:00-05:00";
|
||||
const faultyModels = ["gpt-4"];
|
||||
const updatedAt = "2024-01-14T10:00:00Z";
|
||||
|
||||
render(
|
||||
<MemoryRouter>
|
||||
<AlertBanner
|
||||
maintenanceStartTime={startTime}
|
||||
faultyModels={faultyModels}
|
||||
updatedAt={updatedAt}
|
||||
/>
|
||||
</MemoryRouter>,
|
||||
);
|
||||
|
||||
const banner = screen.queryByTestId("alert-banner");
|
||||
const button = within(banner!).queryByTestId("dismiss-button");
|
||||
|
||||
act(() => {
|
||||
fireEvent.click(button!);
|
||||
});
|
||||
|
||||
expect(banner).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -151,8 +151,9 @@ describe("LoginContent", () => {
|
||||
await user.click(githubButton);
|
||||
|
||||
// Wait for async handleAuthRedirect to complete
|
||||
// The URL includes state parameter added by handleAuthRedirect
|
||||
await waitFor(() => {
|
||||
expect(window.location.href).toBe(mockUrl);
|
||||
expect(window.location.href).toContain(mockUrl);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -201,4 +202,103 @@ describe("LoginContent", () => {
|
||||
|
||||
expect(screen.getByTestId("terms-and-privacy-notice")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should display invitation pending message when hasInvitation is true", () => {
|
||||
render(
|
||||
<MemoryRouter>
|
||||
<LoginContent
|
||||
githubAuthUrl="https://github.com/oauth/authorize"
|
||||
appMode="saas"
|
||||
providersConfigured={["github"]}
|
||||
hasInvitation
|
||||
/>
|
||||
</MemoryRouter>,
|
||||
);
|
||||
|
||||
expect(screen.getByText("AUTH$INVITATION_PENDING")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should not display invitation pending message when hasInvitation is false", () => {
|
||||
render(
|
||||
<MemoryRouter>
|
||||
<LoginContent
|
||||
githubAuthUrl="https://github.com/oauth/authorize"
|
||||
appMode="saas"
|
||||
providersConfigured={["github"]}
|
||||
hasInvitation={false}
|
||||
/>
|
||||
</MemoryRouter>,
|
||||
);
|
||||
|
||||
expect(
|
||||
screen.queryByText("AUTH$INVITATION_PENDING"),
|
||||
).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should call buildOAuthStateData when clicking auth button", async () => {
|
||||
const user = userEvent.setup();
|
||||
const mockBuildOAuthStateData = vi.fn((baseState) => ({
|
||||
...baseState,
|
||||
invitation_token: "inv-test-token-12345",
|
||||
}));
|
||||
|
||||
render(
|
||||
<MemoryRouter>
|
||||
<LoginContent
|
||||
githubAuthUrl="https://github.com/login/oauth/authorize"
|
||||
appMode="saas"
|
||||
providersConfigured={["github"]}
|
||||
buildOAuthStateData={mockBuildOAuthStateData}
|
||||
/>
|
||||
</MemoryRouter>,
|
||||
);
|
||||
|
||||
const githubButton = screen.getByRole("button", {
|
||||
name: "GITHUB$CONNECT_TO_GITHUB",
|
||||
});
|
||||
await user.click(githubButton);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockBuildOAuthStateData).toHaveBeenCalled();
|
||||
const callArg = mockBuildOAuthStateData.mock.calls[0][0];
|
||||
expect(callArg).toHaveProperty("redirect_url");
|
||||
});
|
||||
});
|
||||
|
||||
it("should encode state with invitation token when buildOAuthStateData provides token", async () => {
|
||||
const user = userEvent.setup();
|
||||
const mockBuildOAuthStateData = vi.fn((baseState) => ({
|
||||
...baseState,
|
||||
invitation_token: "inv-test-token-12345",
|
||||
}));
|
||||
|
||||
render(
|
||||
<MemoryRouter>
|
||||
<LoginContent
|
||||
githubAuthUrl="https://github.com/login/oauth/authorize"
|
||||
appMode="saas"
|
||||
providersConfigured={["github"]}
|
||||
buildOAuthStateData={mockBuildOAuthStateData}
|
||||
/>
|
||||
</MemoryRouter>,
|
||||
);
|
||||
|
||||
const githubButton = screen.getByRole("button", {
|
||||
name: "GITHUB$CONNECT_TO_GITHUB",
|
||||
});
|
||||
await user.click(githubButton);
|
||||
|
||||
await waitFor(() => {
|
||||
const redirectUrl = window.location.href;
|
||||
// The URL should contain an encoded state parameter
|
||||
expect(redirectUrl).toContain("state=");
|
||||
// Decode and verify the state contains invitation_token
|
||||
const url = new URL(redirectUrl);
|
||||
const state = url.searchParams.get("state");
|
||||
if (state) {
|
||||
const decodedState = JSON.parse(atob(state));
|
||||
expect(decodedState.invitation_token).toBe("inv-test-token-12345");
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -92,19 +92,21 @@ describe("PlanPreview", () => {
|
||||
});
|
||||
|
||||
it("should render nothing when planContent is null", () => {
|
||||
renderPlanPreview(<PlanPreview planContent={null} />);
|
||||
// Arrange & Act
|
||||
const { container } = renderPlanPreview(<PlanPreview planContent={null} />);
|
||||
|
||||
const contentDiv = screen.getByTestId("plan-preview-content");
|
||||
expect(contentDiv).toBeInTheDocument();
|
||||
expect(contentDiv.textContent?.trim() || "").toBe("");
|
||||
// Assert
|
||||
expect(container.firstChild).toBeNull();
|
||||
});
|
||||
|
||||
it("should render nothing when planContent is undefined", () => {
|
||||
renderPlanPreview(<PlanPreview planContent={undefined} />);
|
||||
// Arrange & Act
|
||||
const { container } = renderPlanPreview(
|
||||
<PlanPreview planContent={undefined} />,
|
||||
);
|
||||
|
||||
const contentDiv = screen.getByTestId("plan-preview-content");
|
||||
expect(contentDiv).toBeInTheDocument();
|
||||
expect(contentDiv.textContent?.trim() || "").toBe("");
|
||||
// Assert
|
||||
expect(container.firstChild).toBeNull();
|
||||
});
|
||||
|
||||
it("should render markdown content when planContent is provided", () => {
|
||||
@@ -170,7 +172,7 @@ describe("PlanPreview", () => {
|
||||
// Arrange
|
||||
const user = userEvent.setup();
|
||||
const expectedPrompt =
|
||||
"Execute the plan based on the workspace/project/PLAN.md file.";
|
||||
"Execute the plan based on the .agents_tmp/PLAN.md file.";
|
||||
renderPlanPreview(<PlanPreview planContent="Plan content" />);
|
||||
const buildButton = screen.getByTestId("plan-preview-build-button");
|
||||
|
||||
@@ -201,7 +203,7 @@ describe("PlanPreview", () => {
|
||||
useOptimisticUserMessageStore.setState({ optimisticUserMessage: null });
|
||||
const user = userEvent.setup();
|
||||
const expectedPrompt =
|
||||
"Execute the plan based on the workspace/project/PLAN.md file.";
|
||||
"Execute the plan based on the .agents_tmp/PLAN.md file.";
|
||||
renderPlanPreview(<PlanPreview planContent="Plan content" />);
|
||||
const buildButton = screen.getByTestId("plan-preview-build-button");
|
||||
|
||||
|
||||
@@ -0,0 +1,112 @@
|
||||
import { render, screen } from "@testing-library/react";
|
||||
import { describe, it, expect, vi, beforeEach } from "vitest";
|
||||
|
||||
// Mutable mock state for controlling breakpoint
|
||||
let mockIsMobile = false;
|
||||
|
||||
// Track ChatInterface unmount via vi.fn()
|
||||
const chatInterfaceUnmount = vi.fn();
|
||||
|
||||
vi.mock("#/hooks/use-breakpoint", () => ({
|
||||
useBreakpoint: () => mockIsMobile,
|
||||
}));
|
||||
|
||||
vi.mock("#/hooks/use-resizable-panels", () => ({
|
||||
useResizablePanels: () => ({
|
||||
leftWidth: 50,
|
||||
rightWidth: 50,
|
||||
isDragging: false,
|
||||
containerRef: { current: null },
|
||||
handleMouseDown: vi.fn(),
|
||||
}),
|
||||
}));
|
||||
|
||||
vi.mock("#/stores/conversation-store", () => ({
|
||||
useConversationStore: () => ({
|
||||
isRightPanelShown: false,
|
||||
}),
|
||||
}));
|
||||
|
||||
// Mock ChatInterface with useEffect to track mount/unmount lifecycle
|
||||
vi.mock("#/components/features/chat/chat-interface", () => {
|
||||
// eslint-disable-next-line @typescript-eslint/no-require-imports
|
||||
const React = require("react");
|
||||
return {
|
||||
ChatInterface: () => {
|
||||
React.useEffect(() => {
|
||||
return () => chatInterfaceUnmount();
|
||||
}, []);
|
||||
return <div data-testid="chat-interface">Chat Interface</div>;
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
vi.mock(
|
||||
"#/components/features/conversation/conversation-tabs/conversation-tab-content/conversation-tab-content",
|
||||
() => ({
|
||||
ConversationTabContent: () => <div data-testid="tab-content" />,
|
||||
}),
|
||||
);
|
||||
|
||||
import { ConversationMain } from "#/components/features/conversation/conversation-main/conversation-main";
|
||||
|
||||
describe("ConversationMain - Layout Transition Stability", () => {
|
||||
beforeEach(() => {
|
||||
mockIsMobile = false;
|
||||
chatInterfaceUnmount.mockClear();
|
||||
});
|
||||
|
||||
it("renders ChatInterface at desktop width", () => {
|
||||
mockIsMobile = false;
|
||||
render(<ConversationMain />);
|
||||
expect(screen.getByTestId("chat-interface")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("renders ChatInterface at mobile width", () => {
|
||||
mockIsMobile = true;
|
||||
render(<ConversationMain />);
|
||||
expect(screen.getByTestId("chat-interface")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("does not unmount ChatInterface when crossing from desktop to mobile", () => {
|
||||
mockIsMobile = false;
|
||||
const { rerender } = render(<ConversationMain />);
|
||||
expect(chatInterfaceUnmount).not.toHaveBeenCalled();
|
||||
|
||||
// Cross the breakpoint to mobile
|
||||
mockIsMobile = true;
|
||||
rerender(<ConversationMain />);
|
||||
|
||||
// ChatInterface must NOT have been unmounted and remounted
|
||||
expect(chatInterfaceUnmount).not.toHaveBeenCalled();
|
||||
expect(screen.getByTestId("chat-interface")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("does not unmount ChatInterface when crossing from mobile to desktop", () => {
|
||||
mockIsMobile = true;
|
||||
const { rerender } = render(<ConversationMain />);
|
||||
expect(chatInterfaceUnmount).not.toHaveBeenCalled();
|
||||
|
||||
// Cross the breakpoint to desktop
|
||||
mockIsMobile = false;
|
||||
rerender(<ConversationMain />);
|
||||
|
||||
// ChatInterface must NOT have been unmounted and remounted
|
||||
expect(chatInterfaceUnmount).not.toHaveBeenCalled();
|
||||
expect(screen.getByTestId("chat-interface")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("survives rapid back-and-forth resize without unmounting ChatInterface", () => {
|
||||
mockIsMobile = false;
|
||||
const { rerender } = render(<ConversationMain />);
|
||||
|
||||
// Simulate rapid resize back and forth across the breakpoint
|
||||
for (const mobile of [true, false, true, false, true]) {
|
||||
mockIsMobile = mobile;
|
||||
rerender(<ConversationMain />);
|
||||
}
|
||||
|
||||
expect(chatInterfaceUnmount).not.toHaveBeenCalled();
|
||||
expect(screen.getByTestId("chat-interface")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
@@ -33,7 +33,7 @@ const {
|
||||
})),
|
||||
useConfigMock: vi.fn(() => ({
|
||||
data: {
|
||||
APP_MODE: "oss",
|
||||
app_mode: "oss",
|
||||
},
|
||||
})),
|
||||
}));
|
||||
@@ -659,7 +659,7 @@ describe("ConversationNameContextMenu - Share Link Functionality", () => {
|
||||
|
||||
useConfigMock.mockReturnValue({
|
||||
data: {
|
||||
APP_MODE: "saas",
|
||||
app_mode: "saas",
|
||||
},
|
||||
});
|
||||
|
||||
@@ -685,7 +685,7 @@ describe("ConversationNameContextMenu - Share Link Functionality", () => {
|
||||
|
||||
useConfigMock.mockReturnValue({
|
||||
data: {
|
||||
APP_MODE: "saas",
|
||||
app_mode: "saas",
|
||||
},
|
||||
});
|
||||
|
||||
@@ -718,7 +718,7 @@ describe("ConversationNameContextMenu - Share Link Functionality", () => {
|
||||
|
||||
useConfigMock.mockReturnValue({
|
||||
data: {
|
||||
APP_MODE: "saas",
|
||||
app_mode: "saas",
|
||||
},
|
||||
});
|
||||
|
||||
@@ -751,7 +751,7 @@ describe("ConversationNameContextMenu - Share Link Functionality", () => {
|
||||
|
||||
useConfigMock.mockReturnValue({
|
||||
data: {
|
||||
APP_MODE: "saas",
|
||||
app_mode: "saas",
|
||||
},
|
||||
});
|
||||
|
||||
@@ -781,7 +781,7 @@ describe("ConversationNameContextMenu - Share Link Functionality", () => {
|
||||
|
||||
useConfigMock.mockReturnValue({
|
||||
data: {
|
||||
APP_MODE: "saas",
|
||||
app_mode: "saas",
|
||||
},
|
||||
});
|
||||
|
||||
@@ -810,7 +810,7 @@ describe("ConversationNameContextMenu - Share Link Functionality", () => {
|
||||
|
||||
useConfigMock.mockReturnValue({
|
||||
data: {
|
||||
APP_MODE: "saas",
|
||||
app_mode: "saas",
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
@@ -177,10 +177,10 @@ describe("RepoConnector", () => {
|
||||
|
||||
it("should render the 'add github repos' link in dropdown if saas mode and github provider is set", async () => {
|
||||
const getConfiSpy = vi.spyOn(OptionService, "getConfig");
|
||||
// @ts-expect-error - only return the APP_MODE and APP_SLUG
|
||||
// @ts-expect-error - only return the app_mode and github_app_slug
|
||||
getConfiSpy.mockResolvedValue({
|
||||
APP_MODE: "saas",
|
||||
APP_SLUG: "openhands",
|
||||
app_mode: "saas",
|
||||
github_app_slug: "openhands",
|
||||
});
|
||||
|
||||
const getSettingsSpy = vi.spyOn(SettingsService, "getSettings");
|
||||
@@ -224,10 +224,10 @@ describe("RepoConnector", () => {
|
||||
|
||||
it("should not render the 'add github repos' link if github provider is not set", async () => {
|
||||
const getConfiSpy = vi.spyOn(OptionService, "getConfig");
|
||||
// @ts-expect-error - only return the APP_MODE and APP_SLUG
|
||||
// @ts-expect-error - only return the app_mode and github_app_slug for this test
|
||||
getConfiSpy.mockResolvedValue({
|
||||
APP_MODE: "saas",
|
||||
APP_SLUG: "openhands",
|
||||
app_mode: "saas",
|
||||
github_app_slug: "openhands",
|
||||
});
|
||||
|
||||
const getSettingsSpy = vi.spyOn(SettingsService, "getSettings");
|
||||
@@ -269,9 +269,9 @@ describe("RepoConnector", () => {
|
||||
|
||||
it("should not render the 'add github repos' link in dropdown if oss mode", async () => {
|
||||
const getConfiSpy = vi.spyOn(OptionService, "getConfig");
|
||||
// @ts-expect-error - only return the APP_MODE
|
||||
// @ts-expect-error - only return the app_mode
|
||||
getConfiSpy.mockResolvedValue({
|
||||
APP_MODE: "oss",
|
||||
app_mode: "oss",
|
||||
});
|
||||
|
||||
const getSettingsSpy = vi.spyOn(SettingsService, "getSettings");
|
||||
|
||||
@@ -30,7 +30,7 @@ vi.mock("#/hooks/query/use-is-authed", () => ({
|
||||
|
||||
vi.mock("#/hooks/query/use-config", () => ({
|
||||
useConfig: () => ({
|
||||
data: { APP_MODE: "saas" },
|
||||
data: { app_mode: "saas" },
|
||||
isLoading: false,
|
||||
}),
|
||||
}));
|
||||
|
||||
@@ -1,119 +0,0 @@
|
||||
import { fireEvent, render, screen, within } from "@testing-library/react";
|
||||
import { afterEach, describe, expect, it, vi } from "vitest";
|
||||
import { act } from "react";
|
||||
import { MemoryRouter } from "react-router";
|
||||
import { MaintenanceBanner } from "#/components/features/maintenance/maintenance-banner";
|
||||
|
||||
// Mock react-i18next
|
||||
vi.mock("react-i18next", async () => {
|
||||
const actual =
|
||||
await vi.importActual<typeof import("react-i18next")>("react-i18next");
|
||||
return {
|
||||
...actual,
|
||||
useTranslation: () => ({
|
||||
t: (key: string, options?: { time?: string }) => {
|
||||
const translations: Record<string, string> = {
|
||||
MAINTENANCE$SCHEDULED_MESSAGE: `Scheduled maintenance will begin at ${options?.time || "{{time}}"}`,
|
||||
};
|
||||
return translations[key] || key;
|
||||
},
|
||||
}),
|
||||
};
|
||||
});
|
||||
|
||||
describe("MaintenanceBanner", () => {
|
||||
afterEach(() => {
|
||||
localStorage.clear();
|
||||
});
|
||||
|
||||
it("renders maintenance banner with formatted time", () => {
|
||||
const startTime = "2024-01-15T10:00:00-05:00"; // EST timestamp
|
||||
|
||||
const { container } = render(
|
||||
<MemoryRouter>
|
||||
<MaintenanceBanner startTime={startTime} />
|
||||
</MemoryRouter>,
|
||||
);
|
||||
|
||||
// Check if the banner is rendered
|
||||
const banner = screen.queryByTestId("maintenance-banner");
|
||||
expect(banner).toBeInTheDocument();
|
||||
|
||||
// Check if the warning icon (SVG) is present
|
||||
const svgIcon = container.querySelector("svg");
|
||||
expect(svgIcon).toBeInTheDocument();
|
||||
|
||||
// Check if the button to close is present
|
||||
const button = within(banner!).queryByTestId("dismiss-button");
|
||||
expect(button).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("handles invalid date gracefully", () => {
|
||||
// Suppress expected console.warn for invalid date parsing
|
||||
const consoleWarnSpy = vi
|
||||
.spyOn(console, "warn")
|
||||
.mockImplementation(() => {});
|
||||
|
||||
const invalidTime = "invalid-date";
|
||||
|
||||
render(
|
||||
<MemoryRouter>
|
||||
<MaintenanceBanner startTime={invalidTime} />
|
||||
</MemoryRouter>,
|
||||
);
|
||||
|
||||
// Check if the banner is rendered
|
||||
const banner = screen.queryByTestId("maintenance-banner");
|
||||
expect(banner).not.toBeInTheDocument();
|
||||
|
||||
// Restore console.warn
|
||||
consoleWarnSpy.mockRestore();
|
||||
});
|
||||
|
||||
it("click on dismiss button removes banner", () => {
|
||||
const startTime = "2024-01-15T10:00:00-05:00"; // EST timestamp
|
||||
|
||||
render(
|
||||
<MemoryRouter>
|
||||
<MaintenanceBanner startTime={startTime} />
|
||||
</MemoryRouter>,
|
||||
);
|
||||
|
||||
// Check if the banner is rendered
|
||||
const banner = screen.queryByTestId("maintenance-banner");
|
||||
|
||||
const button = within(banner!).queryByTestId("dismiss-button");
|
||||
act(() => {
|
||||
fireEvent.click(button!);
|
||||
});
|
||||
|
||||
expect(banner).not.toBeInTheDocument();
|
||||
});
|
||||
it("banner reappears after dismissing on next maintenance event(future time)", () => {
|
||||
const startTime = "2024-01-15T10:00:00-05:00"; // EST timestamp
|
||||
const nextStartTime = "2025-01-15T10:00:00-05:00"; // EST timestamp
|
||||
|
||||
const { rerender } = render(
|
||||
<MemoryRouter>
|
||||
<MaintenanceBanner startTime={startTime} />
|
||||
</MemoryRouter>,
|
||||
);
|
||||
|
||||
// Check if the banner is rendered
|
||||
const banner = screen.queryByTestId("maintenance-banner");
|
||||
const button = within(banner!).queryByTestId("dismiss-button");
|
||||
|
||||
act(() => {
|
||||
fireEvent.click(button!);
|
||||
});
|
||||
|
||||
expect(banner).not.toBeInTheDocument();
|
||||
rerender(
|
||||
<MemoryRouter>
|
||||
<MaintenanceBanner startTime={nextStartTime} />
|
||||
</MemoryRouter>,
|
||||
);
|
||||
|
||||
expect(screen.queryByTestId("maintenance-banner")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
@@ -305,7 +305,7 @@ describe("MicroagentManagement", () => {
|
||||
|
||||
mockUseConfig.mockReturnValue({
|
||||
data: {
|
||||
APP_MODE: "oss",
|
||||
app_mode: "oss",
|
||||
},
|
||||
});
|
||||
|
||||
|
||||
@@ -27,17 +27,17 @@ describe("PaymentForm", () => {
|
||||
const renderPaymentForm = () => renderWithProviders(<PaymentForm />);
|
||||
|
||||
beforeEach(() => {
|
||||
// useBalance hook will return the balance only if the APP_MODE is "saas" and the billing feature is enabled
|
||||
// useBalance hook will return the balance only if the app_mode is "saas" and the billing feature is enabled
|
||||
// @ts-expect-error - partial mock for testing
|
||||
getConfigSpy.mockResolvedValue({
|
||||
APP_MODE: "saas",
|
||||
GITHUB_CLIENT_ID: "123",
|
||||
POSTHOG_CLIENT_KEY: "456",
|
||||
FEATURE_FLAGS: {
|
||||
ENABLE_BILLING: true,
|
||||
HIDE_LLM_SETTINGS: false,
|
||||
ENABLE_JIRA: false,
|
||||
ENABLE_JIRA_DC: false,
|
||||
ENABLE_LINEAR: false,
|
||||
app_mode: "saas",
|
||||
posthog_client_key: "456",
|
||||
feature_flags: {
|
||||
enable_billing: true,
|
||||
hide_llm_settings: false,
|
||||
enable_jira: false,
|
||||
enable_jira_dc: false,
|
||||
enable_linear: false,
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
@@ -9,27 +9,34 @@ import { Sidebar } from "#/components/features/sidebar/sidebar";
|
||||
import SettingsService from "#/api/settings-service/settings-service.api";
|
||||
import OptionService from "#/api/option-service/option-service.api";
|
||||
import { MOCK_DEFAULT_USER_SETTINGS } from "#/mocks/handlers";
|
||||
import { GetConfigResponse } from "#/api/option-service/option.types";
|
||||
import { WebClientConfig } from "#/api/option-service/option.types";
|
||||
|
||||
// Helper to create mock config with sensible defaults
|
||||
const createMockConfig = (
|
||||
overrides: Omit<Partial<GetConfigResponse>, "FEATURE_FLAGS"> & {
|
||||
FEATURE_FLAGS?: Partial<GetConfigResponse["FEATURE_FLAGS"]>;
|
||||
overrides: Omit<Partial<WebClientConfig>, "feature_flags"> & {
|
||||
feature_flags?: Partial<WebClientConfig["feature_flags"]>;
|
||||
} = {},
|
||||
): GetConfigResponse => {
|
||||
const { FEATURE_FLAGS: featureFlagOverrides, ...restOverrides } = overrides;
|
||||
): WebClientConfig => {
|
||||
const { feature_flags: featureFlagOverrides, ...restOverrides } = overrides;
|
||||
return {
|
||||
APP_MODE: "oss",
|
||||
GITHUB_CLIENT_ID: "test-client-id",
|
||||
POSTHOG_CLIENT_KEY: "test-posthog-key",
|
||||
FEATURE_FLAGS: {
|
||||
ENABLE_BILLING: false,
|
||||
HIDE_LLM_SETTINGS: false,
|
||||
ENABLE_JIRA: false,
|
||||
ENABLE_JIRA_DC: false,
|
||||
ENABLE_LINEAR: false,
|
||||
app_mode: "oss",
|
||||
posthog_client_key: "test-posthog-key",
|
||||
feature_flags: {
|
||||
enable_billing: false,
|
||||
hide_llm_settings: false,
|
||||
enable_jira: false,
|
||||
enable_jira_dc: false,
|
||||
enable_linear: false,
|
||||
...featureFlagOverrides,
|
||||
},
|
||||
providers_configured: [],
|
||||
maintenance_start_time: null,
|
||||
auth_url: null,
|
||||
recaptcha_site_key: null,
|
||||
faulty_models: [],
|
||||
error_message: null,
|
||||
updated_at: "2024-01-14T10:00:00Z",
|
||||
github_app_slug: null,
|
||||
...restOverrides,
|
||||
};
|
||||
};
|
||||
@@ -76,9 +83,9 @@ describe("Sidebar", () => {
|
||||
});
|
||||
|
||||
describe("Settings modal auto-open behavior", () => {
|
||||
it("should NOT open settings modal when HIDE_LLM_SETTINGS is true even with 404 error", async () => {
|
||||
it("should NOT open settings modal when hide_llm_settings is true even with 404 error", async () => {
|
||||
getConfigSpy.mockResolvedValue(
|
||||
createMockConfig({ FEATURE_FLAGS: { HIDE_LLM_SETTINGS: true } }),
|
||||
createMockConfig({ feature_flags: { hide_llm_settings: true } }),
|
||||
);
|
||||
getSettingsSpy.mockRejectedValue(createAxiosNotFoundErrorObject());
|
||||
|
||||
@@ -89,21 +96,21 @@ describe("Sidebar", () => {
|
||||
expect(getSettingsSpy).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
// Settings modal should NOT appear when HIDE_LLM_SETTINGS is true
|
||||
// Settings modal should NOT appear when hide_llm_settings is true
|
||||
await waitFor(() => {
|
||||
expect(screen.queryByTestId("ai-config-modal")).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
it("should open settings modal when HIDE_LLM_SETTINGS is false and 404 error in OSS mode", async () => {
|
||||
it("should open settings modal when hide_llm_settings is false and 404 error in OSS mode", async () => {
|
||||
getConfigSpy.mockResolvedValue(
|
||||
createMockConfig({ FEATURE_FLAGS: { HIDE_LLM_SETTINGS: false } }),
|
||||
createMockConfig({ feature_flags: { hide_llm_settings: false } }),
|
||||
);
|
||||
getSettingsSpy.mockRejectedValue(createAxiosNotFoundErrorObject());
|
||||
|
||||
renderSidebar();
|
||||
|
||||
// Settings modal should appear when HIDE_LLM_SETTINGS is false
|
||||
// Settings modal should appear when hide_llm_settings is false
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId("ai-config-modal")).toBeInTheDocument();
|
||||
});
|
||||
@@ -112,8 +119,8 @@ describe("Sidebar", () => {
|
||||
it("should NOT open settings modal in SaaS mode even with 404 error", async () => {
|
||||
getConfigSpy.mockResolvedValue(
|
||||
createMockConfig({
|
||||
APP_MODE: "saas",
|
||||
FEATURE_FLAGS: { HIDE_LLM_SETTINGS: false },
|
||||
app_mode: "saas",
|
||||
feature_flags: { hide_llm_settings: false },
|
||||
}),
|
||||
);
|
||||
getSettingsSpy.mockRejectedValue(createAxiosNotFoundErrorObject());
|
||||
@@ -133,7 +140,7 @@ describe("Sidebar", () => {
|
||||
|
||||
it("should NOT open settings modal when settings exist (no 404 error)", async () => {
|
||||
getConfigSpy.mockResolvedValue(
|
||||
createMockConfig({ FEATURE_FLAGS: { HIDE_LLM_SETTINGS: false } }),
|
||||
createMockConfig({ feature_flags: { hide_llm_settings: false } }),
|
||||
);
|
||||
getSettingsSpy.mockResolvedValue(MOCK_DEFAULT_USER_SETTINGS);
|
||||
|
||||
@@ -152,7 +159,7 @@ describe("Sidebar", () => {
|
||||
|
||||
it("should NOT open settings modal when on /settings path", async () => {
|
||||
getConfigSpy.mockResolvedValue(
|
||||
createMockConfig({ FEATURE_FLAGS: { HIDE_LLM_SETTINGS: false } }),
|
||||
createMockConfig({ feature_flags: { hide_llm_settings: false } }),
|
||||
);
|
||||
getSettingsSpy.mockRejectedValue(createAxiosNotFoundErrorObject());
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ describe("PostHogWrapper", () => {
|
||||
// Mock the config fetch
|
||||
// @ts-expect-error - partial mock
|
||||
vi.spyOn(OptionService, "getConfig").mockResolvedValue({
|
||||
POSTHOG_CLIENT_KEY: "test-posthog-key",
|
||||
posthog_client_key: "test-posthog-key",
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ const useIsAuthedMock = vi
|
||||
|
||||
const useConfigMock = vi
|
||||
.fn()
|
||||
.mockReturnValue({ data: { APP_MODE: "saas" }, isLoading: false });
|
||||
.mockReturnValue({ data: { app_mode: "saas" }, isLoading: false });
|
||||
|
||||
const useUserProvidersMock = vi
|
||||
.fn()
|
||||
@@ -46,7 +46,7 @@ describe("UserActions", () => {
|
||||
// Reset all mocks to default values before each test
|
||||
useIsAuthedMock.mockReturnValue({ data: true, isLoading: false });
|
||||
useConfigMock.mockReturnValue({
|
||||
data: { APP_MODE: "saas" },
|
||||
data: { app_mode: "saas" },
|
||||
isLoading: false,
|
||||
});
|
||||
useUserProvidersMock.mockReturnValue({
|
||||
@@ -89,7 +89,7 @@ describe("UserActions", () => {
|
||||
useIsAuthedMock.mockReturnValue({ data: false, isLoading: false });
|
||||
// Keep other mocks with default values
|
||||
useConfigMock.mockReturnValue({
|
||||
data: { APP_MODE: "saas" },
|
||||
data: { app_mode: "saas" },
|
||||
isLoading: false,
|
||||
});
|
||||
useUserProvidersMock.mockReturnValue({
|
||||
@@ -126,7 +126,7 @@ describe("UserActions", () => {
|
||||
useIsAuthedMock.mockReturnValue({ data: false, isLoading: false });
|
||||
// Keep other mocks with default values
|
||||
useConfigMock.mockReturnValue({
|
||||
data: { APP_MODE: "saas" },
|
||||
data: { app_mode: "saas" },
|
||||
isLoading: false,
|
||||
});
|
||||
useUserProvidersMock.mockReturnValue({
|
||||
@@ -154,7 +154,7 @@ describe("UserActions", () => {
|
||||
useIsAuthedMock.mockReturnValue({ data: false, isLoading: false });
|
||||
// Keep other mocks with default values
|
||||
useConfigMock.mockReturnValue({
|
||||
data: { APP_MODE: "saas" },
|
||||
data: { app_mode: "saas" },
|
||||
isLoading: false,
|
||||
});
|
||||
useUserProvidersMock.mockReturnValue({
|
||||
@@ -179,7 +179,7 @@ describe("UserActions", () => {
|
||||
useIsAuthedMock.mockReturnValue({ data: true, isLoading: false });
|
||||
// Ensure config and providers are set correctly
|
||||
useConfigMock.mockReturnValue({
|
||||
data: { APP_MODE: "saas" },
|
||||
data: { app_mode: "saas" },
|
||||
isLoading: false,
|
||||
});
|
||||
useUserProvidersMock.mockReturnValue({
|
||||
@@ -211,7 +211,7 @@ describe("UserActions", () => {
|
||||
// Start with authentication and providers
|
||||
useIsAuthedMock.mockReturnValue({ data: true, isLoading: false });
|
||||
useConfigMock.mockReturnValue({
|
||||
data: { APP_MODE: "saas" },
|
||||
data: { app_mode: "saas" },
|
||||
isLoading: false,
|
||||
});
|
||||
useUserProvidersMock.mockReturnValue({
|
||||
@@ -236,7 +236,7 @@ describe("UserActions", () => {
|
||||
useIsAuthedMock.mockReturnValue({ data: false, isLoading: false });
|
||||
// Keep other mocks with default values
|
||||
useConfigMock.mockReturnValue({
|
||||
data: { APP_MODE: "saas" },
|
||||
data: { app_mode: "saas" },
|
||||
isLoading: false,
|
||||
});
|
||||
useUserProvidersMock.mockReturnValue({
|
||||
@@ -265,7 +265,7 @@ describe("UserActions", () => {
|
||||
// Ensure authentication and providers are set correctly
|
||||
useIsAuthedMock.mockReturnValue({ data: true, isLoading: false });
|
||||
useConfigMock.mockReturnValue({
|
||||
data: { APP_MODE: "saas" },
|
||||
data: { app_mode: "saas" },
|
||||
isLoading: false,
|
||||
});
|
||||
useUserProvidersMock.mockReturnValue({
|
||||
|
||||
29
frontend/__tests__/helpers/mock-config.ts
Normal file
29
frontend/__tests__/helpers/mock-config.ts
Normal file
@@ -0,0 +1,29 @@
|
||||
import { WebClientConfig } from "#/api/option-service/option.types";
|
||||
|
||||
/**
|
||||
* Creates a mock WebClientConfig with all required fields.
|
||||
* Use this helper to create test config objects with sensible defaults.
|
||||
*/
|
||||
export const createMockWebClientConfig = (
|
||||
overrides: Partial<WebClientConfig> = {},
|
||||
): WebClientConfig => ({
|
||||
app_mode: "oss",
|
||||
posthog_client_key: "test-posthog-key",
|
||||
feature_flags: {
|
||||
enable_billing: false,
|
||||
hide_llm_settings: false,
|
||||
enable_jira: false,
|
||||
enable_jira_dc: false,
|
||||
enable_linear: false,
|
||||
...overrides.feature_flags,
|
||||
},
|
||||
providers_configured: [],
|
||||
maintenance_start_time: null,
|
||||
auth_url: null,
|
||||
recaptcha_site_key: null,
|
||||
faulty_models: [],
|
||||
error_message: null,
|
||||
updated_at: new Date().toISOString(),
|
||||
github_app_slug: null,
|
||||
...overrides,
|
||||
});
|
||||
@@ -0,0 +1,104 @@
|
||||
import { renderHook, waitFor } from "@testing-library/react";
|
||||
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
||||
import { describe, expect, it, vi } from "vitest";
|
||||
import V1ConversationService from "#/api/conversation-service/v1-conversation-service.api";
|
||||
import { useCreateConversation } from "#/hooks/mutation/use-create-conversation";
|
||||
import { SuggestedTask } from "#/utils/types";
|
||||
|
||||
vi.mock("#/hooks/query/use-settings", async () => {
|
||||
const actual = await vi.importActual<typeof import("#/hooks/query/use-settings")>(
|
||||
"#/hooks/query/use-settings",
|
||||
);
|
||||
return {
|
||||
...actual,
|
||||
useSettings: vi.fn().mockReturnValue({
|
||||
data: {
|
||||
v1_enabled: true,
|
||||
},
|
||||
isLoading: false,
|
||||
}),
|
||||
};
|
||||
});
|
||||
|
||||
vi.mock("#/hooks/use-tracking", () => ({
|
||||
useTracking: () => ({
|
||||
trackConversationCreated: vi.fn(),
|
||||
}),
|
||||
}));
|
||||
|
||||
describe("useCreateConversation", () => {
|
||||
it("passes suggested tasks to the V1 create conversation API", async () => {
|
||||
const createConversationSpy = vi
|
||||
.spyOn(V1ConversationService, "createConversation")
|
||||
.mockResolvedValue({
|
||||
id: "task-id",
|
||||
created_by_user_id: null,
|
||||
status: "READY",
|
||||
detail: null,
|
||||
app_conversation_id: null,
|
||||
sandbox_id: null,
|
||||
agent_server_url: "http://agent-server.local",
|
||||
request: {
|
||||
sandbox_id: null,
|
||||
initial_message: {
|
||||
role: "user",
|
||||
content: [{ type: "text", text: "Please address the comments" }],
|
||||
},
|
||||
processors: [],
|
||||
llm_model: null,
|
||||
selected_repository: null,
|
||||
selected_branch: null,
|
||||
git_provider: "github",
|
||||
suggested_task: null,
|
||||
title: null,
|
||||
trigger: null,
|
||||
pr_number: [],
|
||||
parent_conversation_id: null,
|
||||
agent_type: "default",
|
||||
},
|
||||
created_at: new Date().toISOString(),
|
||||
updated_at: new Date().toISOString(),
|
||||
});
|
||||
|
||||
const { result } = renderHook(() => useCreateConversation(), {
|
||||
wrapper: ({ children }) => (
|
||||
<QueryClientProvider client={new QueryClient()}>
|
||||
{children}
|
||||
</QueryClientProvider>
|
||||
),
|
||||
});
|
||||
|
||||
const suggestedTask: SuggestedTask = {
|
||||
git_provider: "github",
|
||||
issue_number: 42,
|
||||
repo: "owner/repo",
|
||||
title: "Resolve comments",
|
||||
task_type: "UNRESOLVED_COMMENTS",
|
||||
};
|
||||
|
||||
await result.current.mutateAsync({
|
||||
query: "Please address the comments",
|
||||
repository: {
|
||||
name: "owner/repo",
|
||||
gitProvider: "github",
|
||||
branch: "main",
|
||||
},
|
||||
conversationInstructions: "Focus on review comments",
|
||||
suggestedTask,
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(createConversationSpy).toHaveBeenCalledWith(
|
||||
"owner/repo",
|
||||
"github",
|
||||
"Please address the comments",
|
||||
"main",
|
||||
"Focus on review comments",
|
||||
suggestedTask,
|
||||
undefined,
|
||||
undefined,
|
||||
undefined,
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,4 +1,4 @@
|
||||
import { describe, it, expect, afterEach, vi } from "vitest";
|
||||
import { describe, it, expect, afterEach, beforeEach, vi } from "vitest";
|
||||
import React from "react";
|
||||
import { renderHook, waitFor } from "@testing-library/react";
|
||||
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
||||
@@ -112,3 +112,192 @@ describe("useConversationHistory", () => {
|
||||
expect(EventService.searchEventsV1).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe("useConversationHistory cache key stability", () => {
|
||||
let localQueryClient: QueryClient;
|
||||
let localWrapper: ({
|
||||
children,
|
||||
}: {
|
||||
children: React.ReactNode;
|
||||
}) => React.ReactElement;
|
||||
|
||||
beforeEach(() => {
|
||||
localQueryClient = new QueryClient({
|
||||
defaultOptions: {
|
||||
queries: {
|
||||
retry: false,
|
||||
},
|
||||
},
|
||||
});
|
||||
localWrapper = ({ children }: { children: React.ReactNode }) =>
|
||||
React.createElement(
|
||||
QueryClientProvider,
|
||||
{ client: localQueryClient },
|
||||
children,
|
||||
);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
localQueryClient.clear();
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
it("does not refetch when conversation object changes but version stays the same", async () => {
|
||||
const v1Spy = vi.spyOn(EventService, "searchEventsV1");
|
||||
v1Spy.mockResolvedValue([makeEvent()]);
|
||||
|
||||
const conv1 = makeConversation("V1");
|
||||
vi.mocked(useUserConversation).mockReturnValue({
|
||||
data: conv1,
|
||||
isLoading: false,
|
||||
isPending: false,
|
||||
isError: false,
|
||||
error: null,
|
||||
refetch: vi.fn(),
|
||||
} as any);
|
||||
|
||||
const { result, rerender } = renderHook(
|
||||
() => useConversationHistory("conv-stable"),
|
||||
{ wrapper: localWrapper },
|
||||
);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.data).toBeDefined();
|
||||
});
|
||||
|
||||
expect(v1Spy).toHaveBeenCalledTimes(1);
|
||||
|
||||
// Simulate background polling: new object reference with different mutable fields
|
||||
// but the SAME conversation_version
|
||||
const conv2: Conversation = {
|
||||
...conv1,
|
||||
last_updated_at: "2099-01-01T00:00:00Z",
|
||||
status: "STOPPED",
|
||||
runtime_status: "STATUS$STOPPED",
|
||||
};
|
||||
vi.mocked(useUserConversation).mockReturnValue({
|
||||
data: conv2,
|
||||
isLoading: false,
|
||||
isPending: false,
|
||||
isError: false,
|
||||
error: null,
|
||||
refetch: vi.fn(),
|
||||
} as any);
|
||||
|
||||
rerender();
|
||||
|
||||
// Allow any potential async refetch to trigger
|
||||
await new Promise((r) => {
|
||||
setTimeout(r, 50);
|
||||
});
|
||||
|
||||
// Must NOT refetch — version hasn't changed, only mutable fields did
|
||||
expect(v1Spy).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
// Edge case: version change MUST trigger a refetch with the correct endpoint
|
||||
it("refetches when conversation_version changes from V0 to V1", async () => {
|
||||
const v0Spy = vi.spyOn(EventService, "searchEventsV0");
|
||||
const v1Spy = vi.spyOn(EventService, "searchEventsV1");
|
||||
v0Spy.mockResolvedValue([makeEvent()]);
|
||||
v1Spy.mockResolvedValue([makeEvent()]);
|
||||
|
||||
// Start with V0
|
||||
vi.mocked(useUserConversation).mockReturnValue({
|
||||
data: makeConversation("V0"),
|
||||
isLoading: false,
|
||||
isPending: false,
|
||||
isError: false,
|
||||
error: null,
|
||||
refetch: vi.fn(),
|
||||
} as any);
|
||||
|
||||
const { result, rerender } = renderHook(
|
||||
() => useConversationHistory("conv-version-change"),
|
||||
{ wrapper: localWrapper },
|
||||
);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.data).toBeDefined();
|
||||
});
|
||||
|
||||
expect(v0Spy).toHaveBeenCalledTimes(1);
|
||||
|
||||
// Switch to V1 — new version means new cache key, must refetch
|
||||
vi.mocked(useUserConversation).mockReturnValue({
|
||||
data: makeConversation("V1"),
|
||||
isLoading: false,
|
||||
isPending: false,
|
||||
isError: false,
|
||||
error: null,
|
||||
refetch: vi.fn(),
|
||||
} as any);
|
||||
|
||||
rerender();
|
||||
|
||||
await waitFor(() => {
|
||||
expect(v1Spy).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
});
|
||||
|
||||
it("treats cached history as never stale (staleTime is Infinity)", async () => {
|
||||
const v1Spy = vi.spyOn(EventService, "searchEventsV1");
|
||||
v1Spy.mockResolvedValue([makeEvent()]);
|
||||
|
||||
vi.mocked(useUserConversation).mockReturnValue({
|
||||
data: makeConversation("V1"),
|
||||
isLoading: false,
|
||||
isPending: false,
|
||||
isError: false,
|
||||
error: null,
|
||||
refetch: vi.fn(),
|
||||
} as any);
|
||||
|
||||
const { result } = renderHook(
|
||||
() => useConversationHistory("conv-stale-check"),
|
||||
{ wrapper: localWrapper },
|
||||
);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.data).toBeDefined();
|
||||
});
|
||||
|
||||
// Check the query's staleTime option in the cache
|
||||
const queries = localQueryClient.getQueryCache().findAll({
|
||||
queryKey: ["conversation-history", "conv-stale-check"],
|
||||
});
|
||||
expect(queries).toHaveLength(1);
|
||||
expect((queries[0].options as Record<string, unknown>).staleTime).toBe(
|
||||
Infinity,
|
||||
);
|
||||
});
|
||||
|
||||
it("has gcTime of at least 30 minutes for navigation resilience", async () => {
|
||||
const v1Spy = vi.spyOn(EventService, "searchEventsV1");
|
||||
v1Spy.mockResolvedValue([makeEvent()]);
|
||||
|
||||
vi.mocked(useUserConversation).mockReturnValue({
|
||||
data: makeConversation("V1"),
|
||||
isLoading: false,
|
||||
isPending: false,
|
||||
isError: false,
|
||||
error: null,
|
||||
refetch: vi.fn(),
|
||||
} as any);
|
||||
|
||||
const { result } = renderHook(
|
||||
() => useConversationHistory("conv-gc-check"),
|
||||
{ wrapper: localWrapper },
|
||||
);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.data).toBeDefined();
|
||||
});
|
||||
|
||||
const queries = localQueryClient.getQueryCache().findAll({
|
||||
queryKey: ["conversation-history", "conv-gc-check"],
|
||||
});
|
||||
expect(queries).toHaveLength(1);
|
||||
expect(queries[0].options.gcTime).toBeGreaterThanOrEqual(30 * 60 * 1000);
|
||||
});
|
||||
});
|
||||
|
||||
180
frontend/__tests__/hooks/use-breakpoint.test.ts
Normal file
180
frontend/__tests__/hooks/use-breakpoint.test.ts
Normal file
@@ -0,0 +1,180 @@
|
||||
import { describe, expect, it, beforeEach, afterEach, vi } from "vitest";
|
||||
import { renderHook, act } from "@testing-library/react";
|
||||
import { useBreakpoint } from "#/hooks/use-breakpoint";
|
||||
|
||||
// Helper to set window.innerWidth and dispatch resize event
|
||||
function setWindowWidth(width: number) {
|
||||
Object.defineProperty(window, "innerWidth", {
|
||||
writable: true,
|
||||
configurable: true,
|
||||
value: width,
|
||||
});
|
||||
window.dispatchEvent(new Event("resize"));
|
||||
}
|
||||
|
||||
describe("useBreakpoint", () => {
|
||||
const originalInnerWidth = window.innerWidth;
|
||||
|
||||
beforeEach(() => {
|
||||
// Start at a known desktop width
|
||||
Object.defineProperty(window, "innerWidth", {
|
||||
writable: true,
|
||||
configurable: true,
|
||||
value: 1200,
|
||||
});
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
Object.defineProperty(window, "innerWidth", {
|
||||
writable: true,
|
||||
configurable: true,
|
||||
value: originalInnerWidth,
|
||||
});
|
||||
});
|
||||
|
||||
it("returns false (not mobile) when window width is above the breakpoint", () => {
|
||||
Object.defineProperty(window, "innerWidth", { value: 1200 });
|
||||
const { result } = renderHook(() => useBreakpoint());
|
||||
expect(result.current).toBe(false);
|
||||
});
|
||||
|
||||
it("returns true (mobile) when window width is at the breakpoint (1024)", () => {
|
||||
Object.defineProperty(window, "innerWidth", { value: 1024 });
|
||||
const { result } = renderHook(() => useBreakpoint());
|
||||
expect(result.current).toBe(true);
|
||||
});
|
||||
|
||||
it("returns true (mobile) when window width is below the breakpoint", () => {
|
||||
Object.defineProperty(window, "innerWidth", { value: 800 });
|
||||
const { result } = renderHook(() => useBreakpoint());
|
||||
expect(result.current).toBe(true);
|
||||
});
|
||||
|
||||
it("updates from false to true when window resizes below the breakpoint", () => {
|
||||
Object.defineProperty(window, "innerWidth", { value: 1200 });
|
||||
const { result } = renderHook(() => useBreakpoint());
|
||||
expect(result.current).toBe(false);
|
||||
|
||||
act(() => {
|
||||
setWindowWidth(800);
|
||||
});
|
||||
|
||||
expect(result.current).toBe(true);
|
||||
});
|
||||
|
||||
it("updates from true to false when window resizes above the breakpoint", () => {
|
||||
Object.defineProperty(window, "innerWidth", { value: 800 });
|
||||
const { result } = renderHook(() => useBreakpoint());
|
||||
expect(result.current).toBe(true);
|
||||
|
||||
act(() => {
|
||||
setWindowWidth(1200);
|
||||
});
|
||||
|
||||
expect(result.current).toBe(false);
|
||||
});
|
||||
|
||||
it("does NOT trigger re-render when width changes within the desktop range", () => {
|
||||
Object.defineProperty(window, "innerWidth", { value: 1200 });
|
||||
const renderCount = vi.fn();
|
||||
const { result } = renderHook(() => {
|
||||
renderCount();
|
||||
return useBreakpoint();
|
||||
});
|
||||
|
||||
expect(result.current).toBe(false);
|
||||
const initialRenderCount = renderCount.mock.calls.length;
|
||||
|
||||
// Resize within desktop range (still above 1024) — should NOT re-render
|
||||
act(() => {
|
||||
setWindowWidth(1300);
|
||||
});
|
||||
act(() => {
|
||||
setWindowWidth(1100);
|
||||
});
|
||||
act(() => {
|
||||
setWindowWidth(1025);
|
||||
});
|
||||
|
||||
expect(result.current).toBe(false);
|
||||
// No additional renders beyond the initial render
|
||||
expect(renderCount.mock.calls.length).toBe(initialRenderCount);
|
||||
});
|
||||
|
||||
it("does NOT trigger re-render when width changes within the mobile range", () => {
|
||||
Object.defineProperty(window, "innerWidth", { value: 800 });
|
||||
const renderCount = vi.fn();
|
||||
const { result } = renderHook(() => {
|
||||
renderCount();
|
||||
return useBreakpoint();
|
||||
});
|
||||
|
||||
expect(result.current).toBe(true);
|
||||
const initialRenderCount = renderCount.mock.calls.length;
|
||||
|
||||
// Resize within mobile range (still at or below 1024) — should NOT re-render
|
||||
act(() => {
|
||||
setWindowWidth(600);
|
||||
});
|
||||
act(() => {
|
||||
setWindowWidth(1024);
|
||||
});
|
||||
act(() => {
|
||||
setWindowWidth(900);
|
||||
});
|
||||
|
||||
expect(result.current).toBe(true);
|
||||
expect(renderCount.mock.calls.length).toBe(initialRenderCount);
|
||||
});
|
||||
|
||||
it("handles rapid resize across the breakpoint without issues", () => {
|
||||
Object.defineProperty(window, "innerWidth", { value: 1200 });
|
||||
const { result } = renderHook(() => useBreakpoint());
|
||||
expect(result.current).toBe(false);
|
||||
|
||||
// Rapid toggles across the breakpoint
|
||||
act(() => {
|
||||
setWindowWidth(800);
|
||||
});
|
||||
expect(result.current).toBe(true);
|
||||
|
||||
act(() => {
|
||||
setWindowWidth(1200);
|
||||
});
|
||||
expect(result.current).toBe(false);
|
||||
|
||||
act(() => {
|
||||
setWindowWidth(1024);
|
||||
});
|
||||
expect(result.current).toBe(true);
|
||||
|
||||
act(() => {
|
||||
setWindowWidth(1025);
|
||||
});
|
||||
expect(result.current).toBe(false);
|
||||
});
|
||||
|
||||
it("cleans up the resize event listener on unmount", () => {
|
||||
const removeEventListenerSpy = vi.spyOn(window, "removeEventListener");
|
||||
const { unmount } = renderHook(() => useBreakpoint());
|
||||
|
||||
unmount();
|
||||
|
||||
expect(removeEventListenerSpy).toHaveBeenCalledWith(
|
||||
"resize",
|
||||
expect.any(Function),
|
||||
);
|
||||
removeEventListenerSpy.mockRestore();
|
||||
});
|
||||
|
||||
it("accepts a custom breakpoint value", () => {
|
||||
Object.defineProperty(window, "innerWidth", { value: 768 });
|
||||
const { result } = renderHook(() => useBreakpoint(768));
|
||||
expect(result.current).toBe(true);
|
||||
|
||||
act(() => {
|
||||
setWindowWidth(769);
|
||||
});
|
||||
expect(result.current).toBe(false);
|
||||
});
|
||||
});
|
||||
383
frontend/__tests__/hooks/use-filtered-events.test.ts
Normal file
383
frontend/__tests__/hooks/use-filtered-events.test.ts
Normal file
@@ -0,0 +1,383 @@
|
||||
import { describe, expect, it, beforeEach } from "vitest";
|
||||
import { renderHook, act } from "@testing-library/react";
|
||||
import { useFilteredEvents } from "#/hooks/use-filtered-events";
|
||||
import { useEventStore } from "#/stores/use-event-store";
|
||||
import type { OpenHandsAction } from "#/types/core/actions";
|
||||
import type { ActionEvent, MessageEvent } from "#/types/v1/core";
|
||||
import { SecurityRisk } from "#/types/v1/core";
|
||||
|
||||
// --- V0 event factories ---
|
||||
|
||||
function createV0UserMessage(id: number): OpenHandsAction {
|
||||
return {
|
||||
id,
|
||||
source: "user",
|
||||
action: "message",
|
||||
args: { content: `User message ${id}`, image_urls: [], file_urls: [] },
|
||||
message: `User message ${id}`,
|
||||
timestamp: `2025-07-01T00:00:0${id}Z`,
|
||||
};
|
||||
}
|
||||
|
||||
function createV0AgentMessage(id: number): OpenHandsAction {
|
||||
return {
|
||||
id,
|
||||
source: "agent",
|
||||
action: "message",
|
||||
args: {
|
||||
thought: `Agent thought ${id}`,
|
||||
image_urls: null,
|
||||
file_urls: [],
|
||||
wait_for_response: true,
|
||||
},
|
||||
message: `Agent response ${id}`,
|
||||
timestamp: `2025-07-01T00:00:0${id}Z`,
|
||||
};
|
||||
}
|
||||
|
||||
function createV0SystemEvent(id: number): OpenHandsAction {
|
||||
return {
|
||||
id,
|
||||
source: "environment",
|
||||
action: "system",
|
||||
args: {
|
||||
content: "source .openhands/setup.sh",
|
||||
tools: null,
|
||||
openhands_version: null,
|
||||
agent_class: null,
|
||||
},
|
||||
message: "Running setup script",
|
||||
timestamp: `2025-07-01T00:00:0${id}Z`,
|
||||
};
|
||||
}
|
||||
|
||||
// --- V1 event factories ---
|
||||
|
||||
function createV1UserMessage(id: string): MessageEvent {
|
||||
return {
|
||||
id,
|
||||
timestamp: "2025-07-01T00:00:01Z",
|
||||
source: "user",
|
||||
llm_message: {
|
||||
role: "user",
|
||||
content: [{ type: "text", text: `User message ${id}` }],
|
||||
},
|
||||
activated_microagents: [],
|
||||
extended_content: [],
|
||||
};
|
||||
}
|
||||
|
||||
function createV1AgentAction(id: string): ActionEvent {
|
||||
return {
|
||||
id,
|
||||
timestamp: "2025-07-01T00:00:02Z",
|
||||
source: "agent",
|
||||
thought: [{ type: "text", text: "Agent thought" }],
|
||||
thinking_blocks: [],
|
||||
action: {
|
||||
kind: "ExecuteBashAction",
|
||||
command: "echo test",
|
||||
is_input: false,
|
||||
timeout: null,
|
||||
reset: false,
|
||||
},
|
||||
tool_name: "execute_bash",
|
||||
tool_call_id: "call-1",
|
||||
tool_call: {
|
||||
id: "call-1",
|
||||
type: "function",
|
||||
function: { name: "execute_bash", arguments: '{"command": "echo test"}' },
|
||||
},
|
||||
llm_response_id: "response-1",
|
||||
security_risk: SecurityRisk.UNKNOWN,
|
||||
};
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
// Reset the event store before each test
|
||||
useEventStore.setState({
|
||||
events: [],
|
||||
eventIds: new Set(),
|
||||
uiEvents: [],
|
||||
});
|
||||
});
|
||||
|
||||
describe("useFilteredEvents", () => {
|
||||
describe("referential stability", () => {
|
||||
it("returns the same v0Events reference when storeEvents has not changed", () => {
|
||||
const v0Event = createV0UserMessage(1);
|
||||
useEventStore.setState({
|
||||
events: [v0Event],
|
||||
eventIds: new Set([1]),
|
||||
uiEvents: [v0Event],
|
||||
});
|
||||
|
||||
const { result, rerender } = renderHook(() => useFilteredEvents());
|
||||
const firstV0Events = result.current.v0Events;
|
||||
|
||||
// Rerender without changing the store
|
||||
rerender();
|
||||
|
||||
expect(result.current.v0Events).toBe(firstV0Events);
|
||||
});
|
||||
|
||||
it("returns the same v1UiEvents reference when uiEvents has not changed", () => {
|
||||
const v1Event = createV1UserMessage("msg-1");
|
||||
useEventStore.setState({
|
||||
events: [v1Event],
|
||||
eventIds: new Set(["msg-1"]),
|
||||
uiEvents: [v1Event],
|
||||
});
|
||||
|
||||
const { result, rerender } = renderHook(() => useFilteredEvents());
|
||||
const firstV1UiEvents = result.current.v1UiEvents;
|
||||
|
||||
rerender();
|
||||
|
||||
expect(result.current.v1UiEvents).toBe(firstV1UiEvents);
|
||||
});
|
||||
|
||||
it("returns the same v1FullEvents reference when storeEvents has not changed", () => {
|
||||
const v1Event = createV1UserMessage("msg-1");
|
||||
useEventStore.setState({
|
||||
events: [v1Event],
|
||||
eventIds: new Set(["msg-1"]),
|
||||
uiEvents: [v1Event],
|
||||
});
|
||||
|
||||
const { result, rerender } = renderHook(() => useFilteredEvents());
|
||||
const firstV1FullEvents = result.current.v1FullEvents;
|
||||
|
||||
rerender();
|
||||
|
||||
expect(result.current.v1FullEvents).toBe(firstV1FullEvents);
|
||||
});
|
||||
|
||||
it("returns a new v0Events reference when storeEvents changes", () => {
|
||||
const v0Event1 = createV0UserMessage(1);
|
||||
useEventStore.setState({
|
||||
events: [v0Event1],
|
||||
eventIds: new Set([1]),
|
||||
uiEvents: [v0Event1],
|
||||
});
|
||||
|
||||
const { result } = renderHook(() => useFilteredEvents());
|
||||
const firstV0Events = result.current.v0Events;
|
||||
|
||||
// Add a new event to the store (new array reference)
|
||||
const v0Event2 = createV0AgentMessage(2);
|
||||
act(() => {
|
||||
useEventStore.setState({
|
||||
events: [v0Event1, v0Event2],
|
||||
eventIds: new Set([1, 2]),
|
||||
uiEvents: [v0Event1, v0Event2],
|
||||
});
|
||||
});
|
||||
|
||||
expect(result.current.v0Events).not.toBe(firstV0Events);
|
||||
expect(result.current.v0Events).toHaveLength(2);
|
||||
});
|
||||
});
|
||||
|
||||
describe("V0 event filtering", () => {
|
||||
it("filters V0 events through isV0Event, isActionOrObservation, and shouldRenderEvent", () => {
|
||||
const userMsg = createV0UserMessage(1);
|
||||
const agentMsg = createV0AgentMessage(2);
|
||||
|
||||
useEventStore.setState({
|
||||
events: [userMsg, agentMsg],
|
||||
eventIds: new Set([1, 2]),
|
||||
uiEvents: [userMsg, agentMsg],
|
||||
});
|
||||
|
||||
const { result } = renderHook(() => useFilteredEvents());
|
||||
|
||||
expect(result.current.v0Events).toHaveLength(2);
|
||||
expect(result.current.v0Events).toContainEqual(userMsg);
|
||||
expect(result.current.v0Events).toContainEqual(agentMsg);
|
||||
});
|
||||
|
||||
it("excludes V0 system events from v0Events", () => {
|
||||
const userMsg = createV0UserMessage(1);
|
||||
const systemEvent = createV0SystemEvent(2);
|
||||
|
||||
useEventStore.setState({
|
||||
events: [userMsg, systemEvent],
|
||||
eventIds: new Set([1, 2]),
|
||||
uiEvents: [userMsg, systemEvent],
|
||||
});
|
||||
|
||||
const { result } = renderHook(() => useFilteredEvents());
|
||||
|
||||
// System events are filtered out by shouldRenderEvent
|
||||
expect(result.current.v0Events).toHaveLength(1);
|
||||
expect(result.current.v0Events[0]).toEqual(userMsg);
|
||||
});
|
||||
|
||||
it("does not include V1 events in v0Events", () => {
|
||||
const v0Event = createV0UserMessage(1);
|
||||
const v1Event = createV1UserMessage("msg-1");
|
||||
|
||||
useEventStore.setState({
|
||||
events: [v0Event, v1Event],
|
||||
eventIds: new Set([1, "msg-1"]),
|
||||
uiEvents: [v0Event, v1Event],
|
||||
});
|
||||
|
||||
const { result } = renderHook(() => useFilteredEvents());
|
||||
|
||||
expect(result.current.v0Events).toHaveLength(1);
|
||||
expect(result.current.v0Events[0]).toEqual(v0Event);
|
||||
});
|
||||
});
|
||||
|
||||
describe("V1 event filtering", () => {
|
||||
it("filters V1 events into v1FullEvents", () => {
|
||||
const v1Event = createV1UserMessage("msg-1");
|
||||
|
||||
useEventStore.setState({
|
||||
events: [v1Event],
|
||||
eventIds: new Set(["msg-1"]),
|
||||
uiEvents: [v1Event],
|
||||
});
|
||||
|
||||
const { result } = renderHook(() => useFilteredEvents());
|
||||
|
||||
expect(result.current.v1FullEvents).toHaveLength(1);
|
||||
expect(result.current.v1FullEvents[0]).toEqual(v1Event);
|
||||
});
|
||||
|
||||
it("does not include V0 events in v1FullEvents", () => {
|
||||
const v0Event = createV0UserMessage(1);
|
||||
const v1Event = createV1UserMessage("msg-1");
|
||||
|
||||
useEventStore.setState({
|
||||
events: [v0Event, v1Event],
|
||||
eventIds: new Set([1, "msg-1"]),
|
||||
uiEvents: [v0Event, v1Event],
|
||||
});
|
||||
|
||||
const { result } = renderHook(() => useFilteredEvents());
|
||||
|
||||
expect(result.current.v1FullEvents).toHaveLength(1);
|
||||
expect(result.current.v1FullEvents[0]).toEqual(v1Event);
|
||||
});
|
||||
});
|
||||
|
||||
describe("totalEvents", () => {
|
||||
it("returns V0 event count when V0 events exist", () => {
|
||||
const v0Event1 = createV0UserMessage(1);
|
||||
const v0Event2 = createV0AgentMessage(2);
|
||||
|
||||
useEventStore.setState({
|
||||
events: [v0Event1, v0Event2],
|
||||
eventIds: new Set([1, 2]),
|
||||
uiEvents: [v0Event1, v0Event2],
|
||||
});
|
||||
|
||||
const { result } = renderHook(() => useFilteredEvents());
|
||||
expect(result.current.totalEvents).toBe(2);
|
||||
});
|
||||
|
||||
it("returns 0 when no events exist", () => {
|
||||
const { result } = renderHook(() => useFilteredEvents());
|
||||
expect(result.current.totalEvents).toBe(0);
|
||||
});
|
||||
});
|
||||
|
||||
describe("hasSubstantiveAgentActions", () => {
|
||||
it("returns false when no events exist", () => {
|
||||
const { result } = renderHook(() => useFilteredEvents());
|
||||
expect(result.current.hasSubstantiveAgentActions).toBe(false);
|
||||
});
|
||||
|
||||
it("returns false when only user events exist (V0)", () => {
|
||||
const userMsg = createV0UserMessage(1);
|
||||
|
||||
useEventStore.setState({
|
||||
events: [userMsg],
|
||||
eventIds: new Set([1]),
|
||||
uiEvents: [userMsg],
|
||||
});
|
||||
|
||||
const { result } = renderHook(() => useFilteredEvents());
|
||||
expect(result.current.hasSubstantiveAgentActions).toBe(false);
|
||||
});
|
||||
|
||||
it("returns true when V0 agent message actions exist", () => {
|
||||
const agentMsg = createV0AgentMessage(1);
|
||||
|
||||
useEventStore.setState({
|
||||
events: [agentMsg],
|
||||
eventIds: new Set([1]),
|
||||
uiEvents: [agentMsg],
|
||||
});
|
||||
|
||||
const { result } = renderHook(() => useFilteredEvents());
|
||||
expect(result.current.hasSubstantiveAgentActions).toBe(true);
|
||||
});
|
||||
|
||||
it("returns true when V1 agent action events exist", () => {
|
||||
const agentAction = createV1AgentAction("action-1");
|
||||
|
||||
useEventStore.setState({
|
||||
events: [agentAction],
|
||||
eventIds: new Set(["action-1"]),
|
||||
uiEvents: [agentAction],
|
||||
});
|
||||
|
||||
const { result } = renderHook(() => useFilteredEvents());
|
||||
expect(result.current.hasSubstantiveAgentActions).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe("userEventsExist", () => {
|
||||
it("returns false when no events exist", () => {
|
||||
const { result } = renderHook(() => useFilteredEvents());
|
||||
expect(result.current.userEventsExist).toBe(false);
|
||||
});
|
||||
|
||||
it("returns true when V0 user events exist", () => {
|
||||
const userMsg = createV0UserMessage(1);
|
||||
|
||||
useEventStore.setState({
|
||||
events: [userMsg],
|
||||
eventIds: new Set([1]),
|
||||
uiEvents: [userMsg],
|
||||
});
|
||||
|
||||
const { result } = renderHook(() => useFilteredEvents());
|
||||
expect(result.current.v0UserEventsExist).toBe(true);
|
||||
expect(result.current.userEventsExist).toBe(true);
|
||||
});
|
||||
|
||||
it("returns true when V1 user events exist", () => {
|
||||
const userMsg = createV1UserMessage("msg-1");
|
||||
|
||||
useEventStore.setState({
|
||||
events: [userMsg],
|
||||
eventIds: new Set(["msg-1"]),
|
||||
uiEvents: [userMsg],
|
||||
});
|
||||
|
||||
const { result } = renderHook(() => useFilteredEvents());
|
||||
expect(result.current.v1UserEventsExist).toBe(true);
|
||||
expect(result.current.userEventsExist).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe("empty store", () => {
|
||||
it("returns empty arrays and false flags for empty store", () => {
|
||||
const { result } = renderHook(() => useFilteredEvents());
|
||||
|
||||
expect(result.current.v0Events).toEqual([]);
|
||||
expect(result.current.v1UiEvents).toEqual([]);
|
||||
expect(result.current.v1FullEvents).toEqual([]);
|
||||
expect(result.current.totalEvents).toBe(0);
|
||||
expect(result.current.hasSubstantiveAgentActions).toBe(false);
|
||||
expect(result.current.v0UserEventsExist).toBe(false);
|
||||
expect(result.current.v1UserEventsExist).toBe(false);
|
||||
expect(result.current.userEventsExist).toBe(false);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -41,8 +41,7 @@ describe("useHandleBuildPlanClick", () => {
|
||||
(createChatMessage as unknown as ReturnType<typeof vi.fn>).mockReturnValue({
|
||||
action: "message",
|
||||
args: {
|
||||
content:
|
||||
"Execute the plan based on the workspace/project/PLAN.md file.",
|
||||
content: "Execute the plan based on the .agents_tmp/PLAN.md file.",
|
||||
image_urls: [],
|
||||
file_urls: [],
|
||||
timestamp: expect.any(String),
|
||||
@@ -78,7 +77,7 @@ describe("useHandleBuildPlanClick", () => {
|
||||
// Arrange
|
||||
const { result } = renderHook(() => useHandleBuildPlanClick());
|
||||
const expectedPrompt =
|
||||
"Execute the plan based on the workspace/project/PLAN.md file.";
|
||||
"Execute the plan based on the .agents_tmp/PLAN.md file.";
|
||||
|
||||
// Act
|
||||
act(() => {
|
||||
@@ -109,7 +108,7 @@ describe("useHandleBuildPlanClick", () => {
|
||||
useOptimisticUserMessageStore.setState({ optimisticUserMessage: null });
|
||||
const { result } = renderHook(() => useHandleBuildPlanClick());
|
||||
const expectedPrompt =
|
||||
"Execute the plan based on the workspace/project/PLAN.md file.";
|
||||
"Execute the plan based on the .agents_tmp/PLAN.md file.";
|
||||
|
||||
// Act
|
||||
act(() => {
|
||||
@@ -155,7 +154,7 @@ describe("useHandleBuildPlanClick", () => {
|
||||
expect(useConversationStore.getState().conversationMode).toBe("code");
|
||||
expect(mockSend).toHaveBeenCalledTimes(1);
|
||||
expect(useOptimisticUserMessageStore.getState().optimisticUserMessage).toBe(
|
||||
"Execute the plan based on the workspace/project/PLAN.md file.",
|
||||
"Execute the plan based on the .agents_tmp/PLAN.md file.",
|
||||
);
|
||||
});
|
||||
|
||||
|
||||
170
frontend/__tests__/hooks/use-invitation.test.ts
Normal file
170
frontend/__tests__/hooks/use-invitation.test.ts
Normal file
@@ -0,0 +1,170 @@
|
||||
import { act, renderHook } from "@testing-library/react";
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
|
||||
|
||||
const INVITATION_TOKEN_KEY = "openhands_invitation_token";
|
||||
|
||||
// Mock setSearchParams function
|
||||
const mockSetSearchParams = vi.fn();
|
||||
|
||||
// Default mock searchParams
|
||||
let mockSearchParamsData: Record<string, string> = {};
|
||||
|
||||
// Mock react-router
|
||||
vi.mock("react-router", () => ({
|
||||
useSearchParams: () => [
|
||||
{
|
||||
get: (key: string) => mockSearchParamsData[key] || null,
|
||||
has: (key: string) => key in mockSearchParamsData,
|
||||
},
|
||||
mockSetSearchParams,
|
||||
],
|
||||
}));
|
||||
|
||||
// Import after mocking
|
||||
import { useInvitation } from "#/hooks/use-invitation";
|
||||
|
||||
describe("useInvitation", () => {
|
||||
beforeEach(() => {
|
||||
// Clear localStorage before each test
|
||||
localStorage.clear();
|
||||
// Reset mock data
|
||||
mockSearchParamsData = {};
|
||||
mockSetSearchParams.mockClear();
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
describe("initialization", () => {
|
||||
it("should initialize with null token when localStorage is empty", () => {
|
||||
// Arrange - localStorage is empty (cleared in beforeEach)
|
||||
|
||||
// Act
|
||||
const { result } = renderHook(() => useInvitation());
|
||||
|
||||
// Assert
|
||||
expect(result.current.invitationToken).toBeNull();
|
||||
expect(result.current.hasInvitation).toBe(false);
|
||||
});
|
||||
|
||||
it("should initialize with token from localStorage if present", () => {
|
||||
// Arrange
|
||||
const storedToken = "inv-stored-token-12345";
|
||||
localStorage.setItem(INVITATION_TOKEN_KEY, storedToken);
|
||||
|
||||
// Act
|
||||
const { result } = renderHook(() => useInvitation());
|
||||
|
||||
// Assert
|
||||
expect(result.current.invitationToken).toBe(storedToken);
|
||||
expect(result.current.hasInvitation).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe("URL token capture", () => {
|
||||
it("should capture invitation_token from URL and store in localStorage", () => {
|
||||
// Arrange
|
||||
const urlToken = "inv-url-token-67890";
|
||||
mockSearchParamsData = { invitation_token: urlToken };
|
||||
|
||||
// Act
|
||||
renderHook(() => useInvitation());
|
||||
|
||||
// Assert
|
||||
expect(localStorage.getItem(INVITATION_TOKEN_KEY)).toBe(urlToken);
|
||||
expect(mockSetSearchParams).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe("completion cleanup", () => {
|
||||
it("should clear localStorage when email_mismatch param is present", () => {
|
||||
// Arrange
|
||||
const storedToken = "inv-token-to-clear";
|
||||
localStorage.setItem(INVITATION_TOKEN_KEY, storedToken);
|
||||
mockSearchParamsData = { email_mismatch: "true" };
|
||||
|
||||
// Act
|
||||
const { result } = renderHook(() => useInvitation());
|
||||
|
||||
// Assert
|
||||
expect(localStorage.getItem(INVITATION_TOKEN_KEY)).toBeNull();
|
||||
expect(mockSetSearchParams).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it("should clear localStorage when invitation_success param is present", () => {
|
||||
// Arrange
|
||||
const storedToken = "inv-token-to-clear";
|
||||
localStorage.setItem(INVITATION_TOKEN_KEY, storedToken);
|
||||
mockSearchParamsData = { invitation_success: "true" };
|
||||
|
||||
// Act
|
||||
renderHook(() => useInvitation());
|
||||
|
||||
// Assert
|
||||
expect(localStorage.getItem(INVITATION_TOKEN_KEY)).toBeNull();
|
||||
});
|
||||
|
||||
it("should clear localStorage when invitation_expired param is present", () => {
|
||||
// Arrange
|
||||
localStorage.setItem(INVITATION_TOKEN_KEY, "inv-token");
|
||||
mockSearchParamsData = { invitation_expired: "true" };
|
||||
|
||||
// Act
|
||||
renderHook(() => useInvitation());
|
||||
|
||||
// Assert
|
||||
expect(localStorage.getItem(INVITATION_TOKEN_KEY)).toBeNull();
|
||||
});
|
||||
});
|
||||
|
||||
describe("buildOAuthStateData", () => {
|
||||
it("should include invitation_token in OAuth state when token is present", () => {
|
||||
// Arrange
|
||||
const token = "inv-oauth-token-12345";
|
||||
localStorage.setItem(INVITATION_TOKEN_KEY, token);
|
||||
|
||||
const { result } = renderHook(() => useInvitation());
|
||||
const baseState = { redirect_url: "/dashboard" };
|
||||
|
||||
// Act
|
||||
const stateData = result.current.buildOAuthStateData(baseState);
|
||||
|
||||
// Assert
|
||||
expect(stateData.invitation_token).toBe(token);
|
||||
expect(stateData.redirect_url).toBe("/dashboard");
|
||||
});
|
||||
|
||||
it("should not include invitation_token when no token is present", () => {
|
||||
// Arrange - no token in localStorage
|
||||
|
||||
const { result } = renderHook(() => useInvitation());
|
||||
const baseState = { redirect_url: "/dashboard" };
|
||||
|
||||
// Act
|
||||
const stateData = result.current.buildOAuthStateData(baseState);
|
||||
|
||||
// Assert
|
||||
expect(stateData.invitation_token).toBeUndefined();
|
||||
expect(stateData.redirect_url).toBe("/dashboard");
|
||||
});
|
||||
});
|
||||
|
||||
describe("clearInvitation", () => {
|
||||
it("should remove token from localStorage when called", () => {
|
||||
// Arrange
|
||||
localStorage.setItem(INVITATION_TOKEN_KEY, "inv-token-to-clear");
|
||||
const { result } = renderHook(() => useInvitation());
|
||||
|
||||
// Act
|
||||
act(() => {
|
||||
result.current.clearInvitation();
|
||||
});
|
||||
|
||||
// Assert
|
||||
expect(localStorage.getItem(INVITATION_TOKEN_KEY)).toBeNull();
|
||||
expect(result.current.invitationToken).toBeNull();
|
||||
expect(result.current.hasInvitation).toBe(false);
|
||||
});
|
||||
});
|
||||
});
|
||||
126
frontend/__tests__/hooks/use-scroll-to-bottom.test.ts
Normal file
126
frontend/__tests__/hooks/use-scroll-to-bottom.test.ts
Normal file
@@ -0,0 +1,126 @@
|
||||
import { describe, expect, it, vi, beforeEach } from "vitest";
|
||||
import { renderHook, act } from "@testing-library/react";
|
||||
import { useScrollToBottom } from "#/hooks/use-scroll-to-bottom";
|
||||
import type { RefObject } from "react";
|
||||
|
||||
/**
|
||||
* Creates a mock scroll element with a trackable scrollTop setter.
|
||||
*
|
||||
* state.scrollTop can be set directly (bypassing the spy) to position
|
||||
* the element for onChatBodyScroll calls without polluting the spy.
|
||||
*/
|
||||
function createMockScrollElement(initialScrollHeight = 1000) {
|
||||
const state = {
|
||||
scrollTop: 0,
|
||||
scrollHeight: initialScrollHeight,
|
||||
clientHeight: 500,
|
||||
};
|
||||
|
||||
const scrollTopSetter = vi.fn((value: number) => {
|
||||
state.scrollTop = value;
|
||||
});
|
||||
|
||||
const element = {
|
||||
get scrollTop() {
|
||||
return state.scrollTop;
|
||||
},
|
||||
set scrollTop(value: number) {
|
||||
scrollTopSetter(value);
|
||||
},
|
||||
get scrollHeight() {
|
||||
return state.scrollHeight;
|
||||
},
|
||||
get clientHeight() {
|
||||
return state.clientHeight;
|
||||
},
|
||||
} as unknown as HTMLDivElement;
|
||||
|
||||
return { element, scrollTopSetter, state };
|
||||
}
|
||||
|
||||
describe("useScrollToBottom", () => {
|
||||
let mock: ReturnType<typeof createMockScrollElement>;
|
||||
let ref: RefObject<HTMLDivElement>;
|
||||
|
||||
beforeEach(() => {
|
||||
mock = createMockScrollElement(1000);
|
||||
ref = { current: mock.element } as RefObject<HTMLDivElement>;
|
||||
});
|
||||
|
||||
describe("no automatic scrolling on render", () => {
|
||||
it("does NOT scroll on initial render", () => {
|
||||
renderHook(() => useScrollToBottom(ref));
|
||||
|
||||
// No useLayoutEffect means no automatic scroll-to-bottom
|
||||
expect(mock.scrollTopSetter).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it("does NOT scroll when re-rendered (e.g., during resize)", () => {
|
||||
const { rerender } = renderHook(() => useScrollToBottom(ref));
|
||||
|
||||
mock.state.scrollHeight = 1500;
|
||||
rerender();
|
||||
|
||||
expect(mock.scrollTopSetter).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe("scroll position tracking", () => {
|
||||
it("tracks hitBottom correctly via onChatBodyScroll", () => {
|
||||
const { result } = renderHook(() => useScrollToBottom(ref));
|
||||
|
||||
// Position at bottom: scrollTop(480) + clientHeight(500) = 980 >= 1000 - 20
|
||||
mock.state.scrollTop = 480;
|
||||
act(() => {
|
||||
result.current.onChatBodyScroll(mock.element);
|
||||
});
|
||||
expect(result.current.hitBottom).toBe(true);
|
||||
|
||||
// Position not at bottom: scrollTop(200) + clientHeight(500) = 700 < 980
|
||||
mock.state.scrollTop = 200;
|
||||
act(() => {
|
||||
result.current.onChatBodyScroll(mock.element);
|
||||
});
|
||||
expect(result.current.hitBottom).toBe(false);
|
||||
});
|
||||
|
||||
it("disables autoScroll when user scrolls up", () => {
|
||||
const { result } = renderHook(() => useScrollToBottom(ref));
|
||||
|
||||
// First scroll to establish prevScrollTopRef
|
||||
mock.state.scrollTop = 400;
|
||||
act(() => {
|
||||
result.current.onChatBodyScroll(mock.element);
|
||||
});
|
||||
|
||||
// Scroll up (lower scrollTop than previous)
|
||||
mock.state.scrollTop = 200;
|
||||
act(() => {
|
||||
result.current.onChatBodyScroll(mock.element);
|
||||
});
|
||||
expect(result.current.autoScroll).toBe(false);
|
||||
});
|
||||
|
||||
it("re-enables autoScroll when user reaches bottom", () => {
|
||||
const { result } = renderHook(() => useScrollToBottom(ref));
|
||||
|
||||
// Scroll up to disable autoScroll
|
||||
mock.state.scrollTop = 400;
|
||||
act(() => {
|
||||
result.current.onChatBodyScroll(mock.element);
|
||||
});
|
||||
mock.state.scrollTop = 200;
|
||||
act(() => {
|
||||
result.current.onChatBodyScroll(mock.element);
|
||||
});
|
||||
expect(result.current.autoScroll).toBe(false);
|
||||
|
||||
// Scroll to bottom
|
||||
mock.state.scrollTop = 500; // 500 + 500 = 1000 >= 980
|
||||
act(() => {
|
||||
result.current.onChatBodyScroll(mock.element);
|
||||
});
|
||||
expect(result.current.autoScroll).toBe(true);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -12,8 +12,8 @@ const wrapper = ({ children }: { children: React.ReactNode }) => (
|
||||
|
||||
const mockConfig = (appMode: "saas" | "oss", hideLlmSettings = false) => {
|
||||
vi.spyOn(OptionService, "getConfig").mockResolvedValue({
|
||||
APP_MODE: appMode,
|
||||
FEATURE_FLAGS: { HIDE_LLM_SETTINGS: hideLlmSettings },
|
||||
app_mode: appMode,
|
||||
feature_flags: { hide_llm_settings: hideLlmSettings },
|
||||
} as Awaited<ReturnType<typeof OptionService.getConfig>>);
|
||||
};
|
||||
|
||||
@@ -22,7 +22,7 @@ describe("useSettingsNavItems", () => {
|
||||
queryClient.clear();
|
||||
});
|
||||
|
||||
it("should return SAAS_NAV_ITEMS when APP_MODE is 'saas'", async () => {
|
||||
it("should return SAAS_NAV_ITEMS when app_mode is 'saas'", async () => {
|
||||
mockConfig("saas");
|
||||
const { result } = renderHook(() => useSettingsNavItems(), { wrapper });
|
||||
|
||||
@@ -31,7 +31,7 @@ describe("useSettingsNavItems", () => {
|
||||
});
|
||||
});
|
||||
|
||||
it("should return OSS_NAV_ITEMS when APP_MODE is 'oss'", async () => {
|
||||
it("should return OSS_NAV_ITEMS when app_mode is 'oss'", async () => {
|
||||
mockConfig("oss");
|
||||
const { result } = renderHook(() => useSettingsNavItems(), { wrapper });
|
||||
|
||||
@@ -40,7 +40,7 @@ describe("useSettingsNavItems", () => {
|
||||
});
|
||||
});
|
||||
|
||||
it("should filter out '/settings' item when HIDE_LLM_SETTINGS feature flag is enabled", async () => {
|
||||
it("should filter out '/settings' item when hide_llm_settings feature flag is enabled", async () => {
|
||||
mockConfig("saas", true);
|
||||
const { result } = renderHook(() => useSettingsNavItems(), { wrapper });
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user