mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-04-29 03:00:45 -04:00
Compare commits
149 Commits
1.3.0
...
fix/git-ap
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d4f7f07d5d | ||
|
|
a34dc949ce | ||
|
|
d6b8d80026 | ||
|
|
1e6a92b454 | ||
|
|
b4a3e5db2f | ||
|
|
80e4fe1226 | ||
|
|
f9d553d0bb | ||
|
|
f6f6c1ab25 | ||
|
|
c511a89426 | ||
|
|
1f82ff04d9 | ||
|
|
eec17311c7 | ||
|
|
c34fdf4b37 | ||
|
|
25076ee44c | ||
|
|
baaec8473a | ||
|
|
402fa47422 | ||
|
|
8dde385843 | ||
|
|
a905e35531 | ||
|
|
1f185173b7 | ||
|
|
ddc7a78723 | ||
|
|
a29ed4d926 | ||
|
|
b8ab4bb44e | ||
|
|
ddd544f8d6 | ||
|
|
3804b66e32 | ||
|
|
b97adf392a | ||
|
|
dcb584913a | ||
|
|
d2fd54a083 | ||
|
|
112d863287 | ||
|
|
c8680caec3 | ||
|
|
d4b9fb1d03 | ||
|
|
409df1287d | ||
|
|
a92bfe6cc0 | ||
|
|
f93e3254d3 | ||
|
|
0476d57451 | ||
|
|
a4cd21e155 | ||
|
|
7f3af371d1 | ||
|
|
1421794c1b | ||
|
|
2fc689457c | ||
|
|
3161b365a8 | ||
|
|
18ab56ef4e | ||
|
|
a9c0df778c | ||
|
|
51b989b5f8 | ||
|
|
dc039d81d6 | ||
|
|
8e4559b14a | ||
|
|
b84f352b63 | ||
|
|
a0dba6124a | ||
|
|
951739f3eb | ||
|
|
0f1ad46a47 | ||
|
|
5367bef43a | ||
|
|
3afeccfe7f | ||
|
|
0677c035ff | ||
|
|
68165b52d9 | ||
|
|
dcc8217317 | ||
|
|
d1410949ff | ||
|
|
a6c0d80fe1 | ||
|
|
0efb1db85d | ||
|
|
8e0f74c92c | ||
|
|
6e1ba3d836 | ||
|
|
0ec97893d1 | ||
|
|
ddb809bc43 | ||
|
|
872f2b87f2 | ||
|
|
ee86005a3a | ||
|
|
d4aa30580b | ||
|
|
2f0e879129 | ||
|
|
3bc2ef954e | ||
|
|
32ab2a24c6 | ||
|
|
a6e148d1e6 | ||
|
|
3fc977eddd | ||
|
|
89a6890269 | ||
|
|
8927ac2230 | ||
|
|
f3429e33ca | ||
|
|
7cd219792b | ||
|
|
2aabe2ed8c | ||
|
|
731a9a813e | ||
|
|
123e556fed | ||
|
|
6676cae249 | ||
|
|
fede37b496 | ||
|
|
3bcd6f18df | ||
|
|
0da18440c2 | ||
|
|
ac76e10048 | ||
|
|
b98bae8b5f | ||
|
|
516721d1ee | ||
|
|
4d6f66ca28 | ||
|
|
b18568da0b | ||
|
|
83dd3c169c | ||
|
|
35bddb14f1 | ||
|
|
e8425218e2 | ||
|
|
0a879fa781 | ||
|
|
41e142bbab | ||
|
|
b06b9eedac | ||
|
|
a9afafa991 | ||
|
|
663ace4b39 | ||
|
|
2d085a6e0a | ||
|
|
8b7112abe8 | ||
|
|
34547ba947 | ||
|
|
5f958ab60d | ||
|
|
d7656bf1c9 | ||
|
|
2bc107564c | ||
|
|
85eb1e1504 | ||
|
|
cd235cc8c7 | ||
|
|
40f52dfabc | ||
|
|
bab7bf85e8 | ||
|
|
c856537f65 | ||
|
|
736f5b2255 | ||
|
|
c1d9d11772 | ||
|
|
85244499fe | ||
|
|
c55084e223 | ||
|
|
e3bb75deb4 | ||
|
|
1948200762 | ||
|
|
affe0af361 | ||
|
|
f20c956196 | ||
|
|
4a089a3a0d | ||
|
|
aa0b2d0b74 | ||
|
|
bef9b80b9d | ||
|
|
c4a90b1f89 | ||
|
|
0d13c57d9f | ||
|
|
b3422f1275 | ||
|
|
f139a9970b | ||
|
|
54d156122c | ||
|
|
ac072bf686 | ||
|
|
a53812c029 | ||
|
|
1d1c0925b5 | ||
|
|
872f41e3c0 | ||
|
|
d43ff82534 | ||
|
|
8cd8c011b2 | ||
|
|
5c68b10983 | ||
|
|
a97fad1976 | ||
|
|
4c3542a91c | ||
|
|
f460057f58 | ||
|
|
4fa2ad0f47 | ||
|
|
dd8be12809 | ||
|
|
89475095d9 | ||
|
|
05d5f8848a | ||
|
|
ee2885eb0b | ||
|
|
545257f870 | ||
|
|
b23ab33a01 | ||
|
|
a9ede73391 | ||
|
|
634c2439b4 | ||
|
|
a1989a40b3 | ||
|
|
e38f1283ea | ||
|
|
07eb791735 | ||
|
|
c355c4819f | ||
|
|
9d8e4c44cc | ||
|
|
25cc55e558 | ||
|
|
0e825c38d7 | ||
|
|
ce04e70b5b | ||
|
|
7b0589ad40 | ||
|
|
682465a862 | ||
|
|
1bb4c844d4 | ||
|
|
d6c11fe517 |
2
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
2
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
# disable blank issue creation
|
||||
blank_issues_enabled: false
|
||||
29
.github/workflows/enterprise-preview.yml
vendored
29
.github/workflows/enterprise-preview.yml
vendored
@@ -1,29 +0,0 @@
|
||||
# Feature branch preview for enterprise code
|
||||
name: Enterprise Preview
|
||||
|
||||
# Run on PRs labeled
|
||||
on:
|
||||
pull_request:
|
||||
types: [labeled]
|
||||
|
||||
# Match ghcr-build.yml, but don't interrupt it.
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ (github.head_ref && github.ref) || github.run_id }}
|
||||
cancel-in-progress: false
|
||||
|
||||
jobs:
|
||||
# This must happen for the PR Docker workflow when the label is present,
|
||||
# and also if it's added after the fact. Thus, it exists in both places.
|
||||
enterprise-preview:
|
||||
name: Enterprise preview
|
||||
if: github.event.label.name == 'deploy'
|
||||
runs-on: blacksmith-4vcpu-ubuntu-2204
|
||||
steps:
|
||||
# This should match the version in ghcr-build.yml
|
||||
- name: Trigger remote job
|
||||
run: |
|
||||
curl --fail-with-body -sS -X POST \
|
||||
-H "Authorization: Bearer ${{ secrets.ALLHANDS_BOT_GITHUB_PAT }}" \
|
||||
-H "Accept: application/vnd.github+json" \
|
||||
-d "{\"ref\": \"main\", \"inputs\": {\"openhandsPrNumber\": \"${{ github.event.pull_request.number }}\", \"deployEnvironment\": \"feature\", \"enterpriseImageTag\": \"pr-${{ github.event.pull_request.number }}\" }}" \
|
||||
https://api.github.com/repos/OpenHands/deploy/actions/workflows/deploy.yaml/dispatches
|
||||
16
.github/workflows/ghcr-build.yml
vendored
16
.github/workflows/ghcr-build.yml
vendored
@@ -9,6 +9,7 @@ on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
- "saas-rel-*"
|
||||
tags:
|
||||
- "*"
|
||||
pull_request:
|
||||
@@ -239,21 +240,6 @@ jobs:
|
||||
# Add build attestations for better security
|
||||
sbom: true
|
||||
|
||||
enterprise-preview:
|
||||
name: Enterprise preview
|
||||
if: github.event_name == 'pull_request' && contains(github.event.pull_request.labels.*.name, 'deploy')
|
||||
runs-on: blacksmith-4vcpu-ubuntu-2204
|
||||
needs: [ghcr_build_enterprise]
|
||||
steps:
|
||||
# This should match the version in enterprise-preview.yml
|
||||
- name: Trigger remote job
|
||||
run: |
|
||||
curl --fail-with-body -sS -X POST \
|
||||
-H "Authorization: Bearer ${{ secrets.ALLHANDS_BOT_GITHUB_PAT }}" \
|
||||
-H "Accept: application/vnd.github+json" \
|
||||
-d "{\"ref\": \"main\", \"inputs\": {\"openhandsPrNumber\": \"${{ github.event.pull_request.number }}\", \"deployEnvironment\": \"feature\", \"enterpriseImageTag\": \"pr-${{ github.event.pull_request.number }}\" }}" \
|
||||
https://api.github.com/repos/OpenHands/deploy/actions/workflows/deploy.yaml/dispatches
|
||||
|
||||
# "All Runtime Tests Passed" is a required job for PRs to merge
|
||||
# We can remove this once the config changes
|
||||
runtime_tests_check_success:
|
||||
|
||||
48
.github/workflows/pr-review-by-openhands.yml
vendored
Normal file
48
.github/workflows/pr-review-by-openhands.yml
vendored
Normal file
@@ -0,0 +1,48 @@
|
||||
---
|
||||
name: PR Review by OpenHands
|
||||
|
||||
on:
|
||||
# TEMPORARY MITIGATION (Clinejection hardening)
|
||||
#
|
||||
# We temporarily avoid `pull_request_target` here. We'll restore it after the PR review
|
||||
# workflow is fully hardened for untrusted execution.
|
||||
pull_request:
|
||||
types: [opened, ready_for_review, labeled, review_requested]
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: write
|
||||
issues: write
|
||||
|
||||
jobs:
|
||||
pr-review:
|
||||
# Note: fork PRs will not have access to repository secrets under `pull_request`.
|
||||
# Skip forks to avoid noisy failures until we restore a hardened `pull_request_target` flow.
|
||||
if: |
|
||||
github.event.pull_request.head.repo.full_name == github.repository &&
|
||||
(
|
||||
(github.event.action == 'opened' && github.event.pull_request.draft == false) ||
|
||||
github.event.action == 'ready_for_review' ||
|
||||
(github.event.action == 'labeled' && github.event.label.name == 'review-this') ||
|
||||
(
|
||||
github.event.action == 'review_requested' &&
|
||||
(
|
||||
github.event.requested_reviewer.login == 'openhands-agent' ||
|
||||
github.event.requested_reviewer.login == 'all-hands-bot'
|
||||
)
|
||||
)
|
||||
)
|
||||
concurrency:
|
||||
group: pr-review-${{ github.event.pull_request.number }}
|
||||
cancel-in-progress: true
|
||||
runs-on: ubuntu-24.04
|
||||
steps:
|
||||
- name: Run PR Review
|
||||
uses: OpenHands/extensions/plugins/pr-review@main
|
||||
with:
|
||||
llm-model: litellm_proxy/claude-sonnet-4-5-20250929
|
||||
llm-base-url: https://llm-proxy.app.all-hands.dev
|
||||
review-style: roasted
|
||||
llm-api-key: ${{ secrets.LLM_API_KEY }}
|
||||
github-token: ${{ secrets.ALLHANDS_BOT_GITHUB_PAT }}
|
||||
lmnr-api-key: ${{ secrets.LMNR_SKILLS_API_KEY }}
|
||||
85
.github/workflows/pr-review-evaluation.yml
vendored
Normal file
85
.github/workflows/pr-review-evaluation.yml
vendored
Normal file
@@ -0,0 +1,85 @@
|
||||
---
|
||||
name: PR Review Evaluation
|
||||
|
||||
# This workflow evaluates how well PR review comments were addressed.
|
||||
# It runs when a PR is closed to assess review effectiveness.
|
||||
#
|
||||
# Security note: pull_request_target is safe here because:
|
||||
# 1. Only triggers on PR close (not on code changes)
|
||||
# 2. Does not checkout PR code - only downloads artifacts from trusted workflow runs
|
||||
# 3. Runs evaluation scripts from the extensions repo, not from the PR
|
||||
|
||||
on:
|
||||
pull_request_target:
|
||||
types: [closed]
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: read
|
||||
|
||||
jobs:
|
||||
evaluate:
|
||||
runs-on: ubuntu-24.04
|
||||
env:
|
||||
PR_NUMBER: ${{ github.event.pull_request.number }}
|
||||
REPO_NAME: ${{ github.repository }}
|
||||
PR_MERGED: ${{ github.event.pull_request.merged }}
|
||||
|
||||
steps:
|
||||
- name: Download review trace artifact
|
||||
id: download-trace
|
||||
uses: dawidd6/action-download-artifact@v6
|
||||
continue-on-error: true
|
||||
with:
|
||||
workflow: pr-review-by-openhands.yml
|
||||
name: pr-review-trace-${{ github.event.pull_request.number }}
|
||||
path: trace-info
|
||||
search_artifacts: true
|
||||
if_no_artifact_found: warn
|
||||
|
||||
- name: Check if trace file exists
|
||||
id: check-trace
|
||||
run: |
|
||||
if [ -f "trace-info/laminar_trace_info.json" ]; then
|
||||
echo "trace_exists=true" >> $GITHUB_OUTPUT
|
||||
echo "Found trace file for PR #$PR_NUMBER"
|
||||
else
|
||||
echo "trace_exists=false" >> $GITHUB_OUTPUT
|
||||
echo "No trace file found for PR #$PR_NUMBER - skipping evaluation"
|
||||
fi
|
||||
|
||||
# Always checkout main branch for security - cannot test script changes in PRs
|
||||
- name: Checkout extensions repository
|
||||
if: steps.check-trace.outputs.trace_exists == 'true'
|
||||
uses: actions/checkout@v5
|
||||
with:
|
||||
repository: OpenHands/extensions
|
||||
path: extensions
|
||||
|
||||
- name: Set up Python
|
||||
if: steps.check-trace.outputs.trace_exists == 'true'
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: '3.12'
|
||||
|
||||
- name: Install dependencies
|
||||
if: steps.check-trace.outputs.trace_exists == 'true'
|
||||
run: pip install lmnr
|
||||
|
||||
- name: Run evaluation
|
||||
if: steps.check-trace.outputs.trace_exists == 'true'
|
||||
env:
|
||||
# Script expects LMNR_PROJECT_API_KEY; org secret is named LMNR_SKILLS_API_KEY
|
||||
LMNR_PROJECT_API_KEY: ${{ secrets.LMNR_SKILLS_API_KEY }}
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: |
|
||||
python extensions/plugins/pr-review/scripts/evaluate_review.py \
|
||||
--trace-file trace-info/laminar_trace_info.json
|
||||
|
||||
- name: Upload evaluation logs
|
||||
uses: actions/upload-artifact@v5
|
||||
if: always() && steps.check-trace.outputs.trace_exists == 'true'
|
||||
with:
|
||||
name: pr-review-evaluation-${{ github.event.pull_request.number }}
|
||||
path: '*.log'
|
||||
retention-days: 30
|
||||
@@ -54,7 +54,7 @@ The experience will be familiar to anyone who has used Devin or Jules.
|
||||
### OpenHands Cloud
|
||||
This is a deployment of OpenHands GUI, running on hosted infrastructure.
|
||||
|
||||
You can try it with a free $10 credit by [signing in with your GitHub or GitLab account](https://app.all-hands.dev).
|
||||
You can try it for free using the Minimax model by [signing in with your GitHub or GitLab account](https://app.all-hands.dev).
|
||||
|
||||
OpenHands Cloud comes with source-available features and integrations:
|
||||
- Integrations with Slack, Jira, and Linear
|
||||
|
||||
@@ -23,12 +23,23 @@ RUN apt-get update && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Install Python packages with security fixes
|
||||
RUN /app/.venv/bin/pip install alembic psycopg2-binary cloud-sql-python-connector pg8000 gspread stripe python-keycloak asyncpg sqlalchemy[asyncio] resend tenacity slack-sdk ddtrace "posthog>=6.0.0" "limits==5.2.0" coredis prometheus-client shap scikit-learn pandas numpy google-cloud-recaptcha-enterprise && \
|
||||
# Update packages with known CVE fixes
|
||||
/app/.venv/bin/pip install --upgrade \
|
||||
"mcp>=1.10.0" \
|
||||
"pillow>=11.3.0"
|
||||
# Install poetry and export before importing current code.
|
||||
RUN /app/.venv/bin/pip install poetry poetry-plugin-export
|
||||
|
||||
# Install Python dependencies from poetry.lock for reproducible builds
|
||||
# Copy lock files first for better Docker layer caching
|
||||
COPY --chown=openhands:openhands enterprise/pyproject.toml enterprise/poetry.lock /tmp/enterprise/
|
||||
RUN cd /tmp/enterprise && \
|
||||
# Export only main dependencies with hashes for supply chain security
|
||||
/app/.venv/bin/poetry export --only main -o requirements.txt && \
|
||||
# Remove the local path dependency (openhands-ai is already in base image)
|
||||
sed -i '/^-e /d; /openhands-ai/d' requirements.txt && \
|
||||
# Install pinned dependencies from lock file
|
||||
/app/.venv/bin/pip install -r requirements.txt && \
|
||||
# Cleanup - return to /app before removing /tmp/enterprise
|
||||
cd /app && \
|
||||
rm -rf /tmp/enterprise && \
|
||||
/app/.venv/bin/pip uninstall -y poetry poetry-plugin-export
|
||||
|
||||
WORKDIR /app
|
||||
COPY --chown=openhands:openhands --chmod=770 enterprise .
|
||||
|
||||
131
enterprise/doc/design-doc/plugin-launch-flow.md
Normal file
131
enterprise/doc/design-doc/plugin-launch-flow.md
Normal file
@@ -0,0 +1,131 @@
|
||||
# Plugin Launch Flow
|
||||
|
||||
This document describes how plugins are launched in OpenHands Saas / Enterprise, from the plugin directory through to agent execution.
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
```
|
||||
Plugin Directory ──▶ Frontend /launch ──▶ App Server ──▶ Agent Server ──▶ SDK
|
||||
(external) (modal) (API) (in sandbox) (plugin loading)
|
||||
```
|
||||
|
||||
| Component | Responsibility |
|
||||
|-----------|---------------|
|
||||
| **Plugin Directory** | Index plugins, present to user, construct launch URLs |
|
||||
| **Frontend** | Display confirmation modal, collect parameters, call API |
|
||||
| **App Server** | Validate request, pass plugin specs to agent server |
|
||||
| **Agent Server** | Run inside sandbox, delegate plugin loading to SDK |
|
||||
| **SDK** | Fetch plugins, load contents, merge skills/hooks/MCP into agent |
|
||||
|
||||
## User Experience
|
||||
|
||||
### Plugin Directory
|
||||
|
||||
The plugin directory presents users with a catalog of available plugins. For each plugin, users see:
|
||||
- Plugin name and description (from `plugin.json`)
|
||||
- Author and version information
|
||||
- A "Launch" button
|
||||
|
||||
When a user clicks "Launch", the plugin directory:
|
||||
1. Reads the plugin's `entry_command` to know which slash command to invoke
|
||||
2. Determines what parameters the plugin accepts (if any)
|
||||
3. Redirects to OpenHands with this information encoded in the URL
|
||||
|
||||
### Parameter Collection
|
||||
|
||||
If a plugin requires user input (API keys, configuration values, etc.), the frontend displays a form modal before starting the conversation. Parameters are passed in the launch URL and rendered as form fields based on their type:
|
||||
|
||||
- **String values** → Text input
|
||||
- **Number values** → Number input
|
||||
- **Boolean values** → Checkbox
|
||||
|
||||
Only primitive types are supported. Complex types (arrays, objects) are not currently supported for parameter input.
|
||||
|
||||
The user fills in required values, then clicks "Start Conversation" to proceed.
|
||||
|
||||
## Launch Flow
|
||||
|
||||
1. **Plugin Directory** (external) constructs a launch URL to the OpenHands app server when user clicks "Launch":
|
||||
```
|
||||
/launch?plugins=BASE64_JSON&message=/city-weather:now%20Tokyo
|
||||
```
|
||||
|
||||
The `plugins` parameter includes any parameter definitions with default values:
|
||||
```json
|
||||
[{
|
||||
"source": "github:owner/repo",
|
||||
"repo_path": "plugins/my-plugin",
|
||||
"parameters": {"api_key": "", "timeout": 30, "debug": false}
|
||||
}]
|
||||
```
|
||||
|
||||
2. **OpenHands Frontend** (`/launch` route, [PR #12699](https://github.com/OpenHands/OpenHands/pull/12699)) displays modal with parameter form, collects user input
|
||||
|
||||
3. **OpenHands App Server** ([PR #12338](https://github.com/OpenHands/OpenHands/pull/12338)) receives the API call:
|
||||
```
|
||||
POST /api/v1/app-conversations
|
||||
{
|
||||
"plugins": [{"source": "github:owner/repo", "repo_path": "plugins/city-weather"}],
|
||||
"initial_message": {"content": [{"type": "text", "text": "/city-weather:now Tokyo"}]}
|
||||
}
|
||||
```
|
||||
|
||||
Call stack:
|
||||
- `AppConversationRouter` receives request with `PluginSpec` list
|
||||
- `LiveStatusAppConversationService._finalize_conversation_request()` converts `PluginSpec` → `PluginSource`
|
||||
- Creates `StartConversationRequest(plugins=sdk_plugins, ...)` and sends to agent server
|
||||
|
||||
4. **Agent Server** (inside sandbox, [SDK PR #1651](https://github.com/OpenHands/software-agent-sdk/pull/1651)) stores specs, defers loading:
|
||||
|
||||
Call stack:
|
||||
- `ConversationService.start_conversation()` receives `StartConversationRequest`
|
||||
- Creates `StoredConversation` with plugin specs
|
||||
- Creates `LocalConversation(plugins=request.plugins, ...)`
|
||||
- Plugin loading deferred until first `run()` or `send_message()`
|
||||
|
||||
5. **SDK** fetches and loads plugins on first use:
|
||||
|
||||
Call stack:
|
||||
- `LocalConversation._ensure_plugins_loaded()` triggered by first message
|
||||
- For each plugin spec:
|
||||
- `Plugin.fetch(source, ref, repo_path)` → clones/caches git repo
|
||||
- `Plugin.load(path)` → parses `plugin.json`, loads commands/skills/hooks
|
||||
- `plugin.add_skills_to(context)` → merges skills into agent
|
||||
- `plugin.add_mcp_config_to(config)` → merges MCP servers
|
||||
|
||||
6. **Agent** receives message, `/city-weather:now` triggers the skill
|
||||
|
||||
## Key Design Decisions
|
||||
|
||||
### Plugin Loading in Sandbox
|
||||
|
||||
Plugins load **inside the sandbox** because:
|
||||
- Plugin hooks and scripts need isolated execution
|
||||
- MCP servers run inside the sandbox
|
||||
- Skills may reference sandbox filesystem
|
||||
|
||||
### Entry Command Handling
|
||||
|
||||
The `entry_command` field in `plugin.json` allows plugin authors to declare a default command:
|
||||
|
||||
```json
|
||||
{
|
||||
"name": "city-weather",
|
||||
"entry_command": "now"
|
||||
}
|
||||
```
|
||||
|
||||
This flows through the system:
|
||||
1. Plugin author declares `entry_command` in plugin.json
|
||||
2. Plugin directory reads it when indexing
|
||||
3. Plugin directory includes `/city-weather:now` in the launch URL's `message` parameter
|
||||
4. Message passes through to agent as `initial_message`
|
||||
|
||||
The SDK exposes this field but does not auto-invoke it—callers control the initial message.
|
||||
|
||||
## Related
|
||||
|
||||
- [OpenHands PR #12338](https://github.com/OpenHands/OpenHands/pull/12338) - App server plugin support
|
||||
- [OpenHands PR #12699](https://github.com/OpenHands/OpenHands/pull/12699) - Frontend `/launch` route
|
||||
- [SDK PR #1651](https://github.com/OpenHands/software-agent-sdk/pull/1651) - Agent server plugin loading
|
||||
- [SDK PR #1647](https://github.com/OpenHands/software-agent-sdk/pull/1647) - Plugin.fetch() for remote plugin fetching
|
||||
@@ -28,9 +28,11 @@ class SaaSExperimentManager(ExperimentManager):
|
||||
return agent
|
||||
|
||||
if EXPERIMENT_SYSTEM_PROMPT_EXPERIMENT:
|
||||
agent = agent.model_copy(
|
||||
update={'system_prompt_filename': 'system_prompt_long_horizon.j2'}
|
||||
)
|
||||
# Skip experiment for planning agents which require their specialized prompt
|
||||
if agent.system_prompt_filename != 'system_prompt_planning.j2':
|
||||
agent = agent.model_copy(
|
||||
update={'system_prompt_filename': 'system_prompt_long_horizon.j2'}
|
||||
)
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
@@ -145,11 +145,7 @@ class GithubManager(Manager):
|
||||
).get('body', ''):
|
||||
return False
|
||||
|
||||
if GithubFactory.is_eligible_for_conversation_starter(
|
||||
message
|
||||
) and self._user_has_write_access_to_repo(installation_id, repo_name, username):
|
||||
await GithubFactory.trigger_conversation_starter(message)
|
||||
|
||||
# Check event types before making expensive API calls (e.g., _user_has_write_access_to_repo)
|
||||
if not (
|
||||
GithubFactory.is_labeled_issue(message)
|
||||
or GithubFactory.is_issue_comment(message)
|
||||
@@ -159,8 +155,17 @@ class GithubManager(Manager):
|
||||
return False
|
||||
|
||||
logger.info(f'[GitHub] Checking permissions for {username} in {repo_name}')
|
||||
user_has_write_access = self._user_has_write_access_to_repo(
|
||||
installation_id, repo_name, username
|
||||
)
|
||||
|
||||
return self._user_has_write_access_to_repo(installation_id, repo_name, username)
|
||||
if (
|
||||
GithubFactory.is_eligible_for_conversation_starter(message)
|
||||
and user_has_write_access
|
||||
):
|
||||
await GithubFactory.trigger_conversation_starter(message)
|
||||
|
||||
return user_has_write_access
|
||||
|
||||
async def receive_message(self, message: Message):
|
||||
self._confirm_incoming_source_type(message)
|
||||
|
||||
@@ -167,17 +167,15 @@ async def install_webhook_on_resource(
|
||||
scopes=SCOPES,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
'Creating new webhook',
|
||||
extra={
|
||||
'webhook_id': webhook_id,
|
||||
'status': status,
|
||||
'resource_id': resource_id,
|
||||
'resource_type': resource_type,
|
||||
},
|
||||
)
|
||||
log_extra = {
|
||||
'webhook_id': webhook_id,
|
||||
'status': status,
|
||||
'resource_id': resource_id,
|
||||
'resource_type': resource_type,
|
||||
}
|
||||
|
||||
if status == WebhookStatus.RATE_LIMITED:
|
||||
logger.warning('Rate limited while creating webhook', extra=log_extra)
|
||||
raise BreakLoopException()
|
||||
|
||||
if webhook_id:
|
||||
@@ -191,9 +189,8 @@ async def install_webhook_on_resource(
|
||||
'webhook_uuid': webhook_uuid, # required to identify which webhook installation is sending payload
|
||||
},
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f'Installed webhook for {webhook.user_id} on {resource_type}:{resource_id}'
|
||||
)
|
||||
logger.info('Created new webhook', extra=log_extra)
|
||||
else:
|
||||
logger.error('Failed to create webhook', extra=log_extra)
|
||||
|
||||
return webhook_id, status
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from openhands.app_server.user.user_context import UserContext
|
||||
from openhands.app_server.user.user_models import UserInfo
|
||||
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE
|
||||
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE, ProviderHandler
|
||||
from openhands.integrations.service_types import ProviderType
|
||||
from openhands.sdk.secret import SecretSource, StaticSecret
|
||||
from openhands.server.user_auth.user_auth import UserAuth
|
||||
@@ -14,6 +14,7 @@ class ResolverUserContext(UserContext):
|
||||
saas_user_auth: UserAuth,
|
||||
):
|
||||
self.saas_user_auth = saas_user_auth
|
||||
self._provider_handler: ProviderHandler | None = None
|
||||
|
||||
async def get_user_id(self) -> str | None:
|
||||
return await self.saas_user_auth.get_user_id()
|
||||
@@ -29,12 +30,26 @@ class ResolverUserContext(UserContext):
|
||||
|
||||
return UserInfo(id=user_id)
|
||||
|
||||
async def _get_provider_handler(self) -> ProviderHandler:
|
||||
"""Get or create a ProviderHandler for git operations."""
|
||||
if self._provider_handler is None:
|
||||
provider_tokens = await self.saas_user_auth.get_provider_tokens()
|
||||
if provider_tokens is None:
|
||||
raise ValueError('No provider tokens available')
|
||||
user_id = await self.saas_user_auth.get_user_id()
|
||||
self._provider_handler = ProviderHandler(
|
||||
provider_tokens=provider_tokens, external_auth_id=user_id
|
||||
)
|
||||
return self._provider_handler
|
||||
|
||||
async def get_authenticated_git_url(
|
||||
self, repository: str, is_optional: bool = False
|
||||
) -> str:
|
||||
# This would need to be implemented based on the git provider tokens
|
||||
# For now, return a basic HTTPS URL
|
||||
return f'https://github.com/{repository}.git'
|
||||
provider_handler = await self._get_provider_handler()
|
||||
url = await provider_handler.get_authenticated_git_url(
|
||||
repository, is_optional=is_optional
|
||||
)
|
||||
return url
|
||||
|
||||
async def get_latest_token(self, provider_type: ProviderType) -> str | None:
|
||||
# Return the appropriate token string from git_provider_tokens
|
||||
|
||||
@@ -1,10 +1,15 @@
|
||||
import logging
|
||||
import os
|
||||
from logging.config import fileConfig
|
||||
|
||||
from alembic import context
|
||||
from google.cloud.sql.connector import Connector
|
||||
from sqlalchemy import create_engine
|
||||
from storage.base import Base
|
||||
# Suppress alembic.runtime.plugins INFO logs during import to prevent non-JSON logs in production
|
||||
# These plugin setup messages would otherwise appear before logging is configured
|
||||
logging.getLogger('alembic.runtime.plugins').setLevel(logging.WARNING)
|
||||
|
||||
from alembic import context # noqa: E402
|
||||
from google.cloud.sql.connector import Connector # noqa: E402
|
||||
from sqlalchemy import create_engine # noqa: E402
|
||||
from storage.base import Base # noqa: E402
|
||||
|
||||
target_metadata = Base.metadata
|
||||
|
||||
|
||||
@@ -0,0 +1,46 @@
|
||||
"""Add byor_export_enabled flag to org table.
|
||||
|
||||
Revision ID: 091
|
||||
Revises: 090
|
||||
Create Date: 2025-01-15 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '091'
|
||||
down_revision: Union[str, None] = '090'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add byor_export_enabled column to org table with default false
|
||||
op.add_column(
|
||||
'org',
|
||||
sa.Column(
|
||||
'byor_export_enabled',
|
||||
sa.Boolean,
|
||||
nullable=False,
|
||||
server_default=sa.text('false'),
|
||||
),
|
||||
)
|
||||
|
||||
# Set byor_export_enabled to true for orgs that have completed billing sessions
|
||||
op.execute(
|
||||
sa.text("""
|
||||
UPDATE org SET byor_export_enabled = TRUE
|
||||
WHERE id IN (
|
||||
SELECT DISTINCT org_id FROM billing_sessions
|
||||
WHERE status = 'completed' AND org_id IS NOT NULL
|
||||
)
|
||||
""")
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column('org', 'byor_export_enabled')
|
||||
@@ -0,0 +1,29 @@
|
||||
"""Rename 'user' role to 'member' in role table.
|
||||
|
||||
Revision ID: 092
|
||||
Revises: 091
|
||||
Create Date: 2025-02-12 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '092'
|
||||
down_revision: Union[str, None] = '091'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Rename 'user' role to 'member' for clarity
|
||||
# This avoids confusion between the 'user' role and the 'user' entity/account
|
||||
op.execute(sa.text("UPDATE role SET name = 'member' WHERE name = 'user'"))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Revert 'member' role back to 'user'
|
||||
op.execute(sa.text("UPDATE role SET name = 'user' WHERE name = 'member'"))
|
||||
@@ -0,0 +1,37 @@
|
||||
"""Add pending_free_credits flag to org table.
|
||||
|
||||
Revision ID: 093
|
||||
Revises: 092
|
||||
Create Date: 2025-02-17 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '093'
|
||||
down_revision: Union[str, None] = '092'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add pending_free_credits column to org table with default false.
|
||||
# New orgs will have this set to TRUE at creation time.
|
||||
# Existing orgs default to FALSE (not eligible - they already got $10 at signup).
|
||||
op.add_column(
|
||||
'org',
|
||||
sa.Column(
|
||||
'pending_free_credits',
|
||||
sa.Boolean,
|
||||
nullable=False,
|
||||
server_default=sa.text('false'),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column('org', 'pending_free_credits')
|
||||
@@ -0,0 +1,110 @@
|
||||
"""create org_invitation table
|
||||
|
||||
Revision ID: 094
|
||||
Revises: 093
|
||||
Create Date: 2026-02-18 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '094'
|
||||
down_revision: Union[str, None] = '093'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Create org_invitation table
|
||||
op.create_table(
|
||||
'org_invitation',
|
||||
sa.Column('id', sa.Integer, sa.Identity(), primary_key=True),
|
||||
sa.Column('token', sa.String(64), nullable=False),
|
||||
sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('email', sa.String(255), nullable=False),
|
||||
sa.Column('role_id', sa.Integer, nullable=False),
|
||||
sa.Column('inviter_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column(
|
||||
'status',
|
||||
sa.String(20),
|
||||
nullable=False,
|
||||
server_default=sa.text("'pending'"),
|
||||
),
|
||||
sa.Column(
|
||||
'created_at',
|
||||
sa.DateTime,
|
||||
nullable=False,
|
||||
server_default=sa.text('CURRENT_TIMESTAMP'),
|
||||
),
|
||||
sa.Column('expires_at', sa.DateTime, nullable=False),
|
||||
sa.Column('accepted_at', sa.DateTime, nullable=True),
|
||||
sa.Column('accepted_by_user_id', postgresql.UUID(as_uuid=True), nullable=True),
|
||||
# Foreign key constraints
|
||||
sa.ForeignKeyConstraint(
|
||||
['org_id'],
|
||||
['org.id'],
|
||||
name='org_invitation_org_fkey',
|
||||
ondelete='CASCADE',
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
['role_id'],
|
||||
['role.id'],
|
||||
name='org_invitation_role_fkey',
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
['inviter_id'],
|
||||
['user.id'],
|
||||
name='org_invitation_inviter_fkey',
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
['accepted_by_user_id'],
|
||||
['user.id'],
|
||||
name='org_invitation_accepter_fkey',
|
||||
),
|
||||
)
|
||||
|
||||
# Create indexes
|
||||
op.create_index(
|
||||
'ix_org_invitation_token',
|
||||
'org_invitation',
|
||||
['token'],
|
||||
unique=True,
|
||||
)
|
||||
op.create_index(
|
||||
'ix_org_invitation_org_id',
|
||||
'org_invitation',
|
||||
['org_id'],
|
||||
)
|
||||
op.create_index(
|
||||
'ix_org_invitation_email',
|
||||
'org_invitation',
|
||||
['email'],
|
||||
)
|
||||
op.create_index(
|
||||
'ix_org_invitation_status',
|
||||
'org_invitation',
|
||||
['status'],
|
||||
)
|
||||
# Composite index for checking pending invitations
|
||||
op.create_index(
|
||||
'ix_org_invitation_org_email_status',
|
||||
'org_invitation',
|
||||
['org_id', 'email', 'status'],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop indexes
|
||||
op.drop_index('ix_org_invitation_org_email_status', table_name='org_invitation')
|
||||
op.drop_index('ix_org_invitation_status', table_name='org_invitation')
|
||||
op.drop_index('ix_org_invitation_email', table_name='org_invitation')
|
||||
op.drop_index('ix_org_invitation_org_id', table_name='org_invitation')
|
||||
op.drop_index('ix_org_invitation_token', table_name='org_invitation')
|
||||
|
||||
# Drop table
|
||||
op.drop_table('org_invitation')
|
||||
@@ -0,0 +1,37 @@
|
||||
"""Drop pending_free_credits column from org table.
|
||||
|
||||
Revision ID: 095
|
||||
Revises: 094
|
||||
Create Date: 2025-02-18 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '095'
|
||||
down_revision: Union[str, None] = '094'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Drop the pending_free_credits column from org table.
|
||||
# This column was used for tracking free credit eligibility but is no longer needed.
|
||||
op.drop_column('org', 'pending_free_credits')
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Re-add pending_free_credits column with default false.
|
||||
op.add_column(
|
||||
'org',
|
||||
sa.Column(
|
||||
'pending_free_credits',
|
||||
sa.Boolean,
|
||||
nullable=False,
|
||||
server_default=sa.text('false'),
|
||||
),
|
||||
)
|
||||
@@ -0,0 +1,67 @@
|
||||
"""Create resend_synced_users table.
|
||||
|
||||
Revision ID: 096
|
||||
Revises: 095
|
||||
Create Date: 2025-02-17 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '096'
|
||||
down_revision: Union[str, None] = '095'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Create resend_synced_users table for tracking users synced to Resend audiences."""
|
||||
op.create_table(
|
||||
'resend_synced_users',
|
||||
sa.Column(
|
||||
'id',
|
||||
sa.UUID(as_uuid=True),
|
||||
nullable=False,
|
||||
primary_key=True,
|
||||
),
|
||||
sa.Column('email', sa.String(), nullable=False),
|
||||
sa.Column('audience_id', sa.String(), nullable=False),
|
||||
sa.Column(
|
||||
'synced_at',
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.text('CURRENT_TIMESTAMP'),
|
||||
),
|
||||
sa.Column('keycloak_user_id', sa.String(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint(
|
||||
'email', 'audience_id', name='uq_resend_synced_email_audience'
|
||||
),
|
||||
)
|
||||
|
||||
# Create index on email for fast lookups
|
||||
op.create_index(
|
||||
'ix_resend_synced_users_email',
|
||||
'resend_synced_users',
|
||||
['email'],
|
||||
)
|
||||
|
||||
# Create index on audience_id for filtering by audience
|
||||
op.create_index(
|
||||
'ix_resend_synced_users_audience_id',
|
||||
'resend_synced_users',
|
||||
['audience_id'],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Drop resend_synced_users table."""
|
||||
op.drop_index(
|
||||
'ix_resend_synced_users_audience_id', table_name='resend_synced_users'
|
||||
)
|
||||
op.drop_index('ix_resend_synced_users_email', table_name='resend_synced_users')
|
||||
op.drop_table('resend_synced_users')
|
||||
@@ -0,0 +1,41 @@
|
||||
"""Add session_api_key_hash to v1_remote_sandbox table
|
||||
|
||||
Revision ID: 097
|
||||
Revises: 096
|
||||
Create Date: 2025-02-24 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '097'
|
||||
down_revision: Union[str, None] = '096'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Add session_api_key_hash column to v1_remote_sandbox table."""
|
||||
op.add_column(
|
||||
'v1_remote_sandbox',
|
||||
sa.Column('session_api_key_hash', sa.String(), nullable=True),
|
||||
)
|
||||
op.create_index(
|
||||
op.f('ix_v1_remote_sandbox_session_api_key_hash'),
|
||||
'v1_remote_sandbox',
|
||||
['session_api_key_hash'],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Remove session_api_key_hash column from v1_remote_sandbox table."""
|
||||
op.drop_index(
|
||||
op.f('ix_v1_remote_sandbox_session_api_key_hash'),
|
||||
table_name='v1_remote_sandbox',
|
||||
)
|
||||
op.drop_column('v1_remote_sandbox', 'session_api_key_hash')
|
||||
@@ -0,0 +1,92 @@
|
||||
"""Create verified_models table.
|
||||
|
||||
Revision ID: 098
|
||||
Revises: 097
|
||||
Create Date: 2026-02-26 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '098'
|
||||
down_revision: Union[str, None] = '097'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Create verified_models table and seed with current model list."""
|
||||
op.create_table(
|
||||
'verified_models',
|
||||
sa.Column('id', sa.Integer, sa.Identity(), primary_key=True),
|
||||
sa.Column('model_name', sa.String(255), nullable=False),
|
||||
sa.Column('provider', sa.String(100), nullable=False),
|
||||
sa.Column(
|
||||
'is_enabled',
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.text('true'),
|
||||
),
|
||||
sa.Column(
|
||||
'created_at',
|
||||
sa.DateTime(),
|
||||
nullable=False,
|
||||
server_default=sa.text('CURRENT_TIMESTAMP'),
|
||||
),
|
||||
sa.Column(
|
||||
'updated_at',
|
||||
sa.DateTime(),
|
||||
nullable=False,
|
||||
server_default=sa.text('CURRENT_TIMESTAMP'),
|
||||
),
|
||||
sa.UniqueConstraint(
|
||||
'model_name', 'provider', name='uq_verified_model_provider'
|
||||
),
|
||||
)
|
||||
|
||||
op.create_index(
|
||||
'ix_verified_models_provider',
|
||||
'verified_models',
|
||||
['provider'],
|
||||
)
|
||||
op.create_index(
|
||||
'ix_verified_models_is_enabled',
|
||||
'verified_models',
|
||||
['is_enabled'],
|
||||
)
|
||||
|
||||
# Seed with current openhands provider models
|
||||
models = [
|
||||
('claude-opus-4-5-20251101', 'openhands'),
|
||||
('claude-sonnet-4-5-20250929', 'openhands'),
|
||||
('gpt-5.2-codex', 'openhands'),
|
||||
('gpt-5.2', 'openhands'),
|
||||
('minimax-m2.5', 'openhands'),
|
||||
('gemini-3-pro-preview', 'openhands'),
|
||||
('gemini-3-flash-preview', 'openhands'),
|
||||
('deepseek-chat', 'openhands'),
|
||||
('devstral-medium-2512', 'openhands'),
|
||||
('kimi-k2-0711-preview', 'openhands'),
|
||||
('qwen3-coder-480b', 'openhands'),
|
||||
]
|
||||
|
||||
for model_name, provider in models:
|
||||
op.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO verified_models (model_name, provider)
|
||||
VALUES (:model_name, :provider)
|
||||
"""
|
||||
).bindparams(model_name=model_name, provider=provider)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Drop verified_models table."""
|
||||
op.drop_index('ix_verified_models_is_enabled', table_name='verified_models')
|
||||
op.drop_index('ix_verified_models_provider', table_name='verified_models')
|
||||
op.drop_table('verified_models')
|
||||
384
enterprise/poetry.lock
generated
384
enterprise/poetry.lock
generated
@@ -1540,66 +1540,58 @@ files = [
|
||||
|
||||
[[package]]
|
||||
name = "cryptography"
|
||||
version = "46.0.3"
|
||||
version = "46.0.5"
|
||||
description = "cryptography is a package which provides cryptographic recipes and primitives to Python developers."
|
||||
optional = false
|
||||
python-versions = "!=3.9.0,!=3.9.1,>=3.8"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "cryptography-46.0.3-cp311-abi3-macosx_10_9_universal2.whl", hash = "sha256:109d4ddfadf17e8e7779c39f9b18111a09efb969a301a31e987416a0191ed93a"},
|
||||
{file = "cryptography-46.0.3-cp311-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:09859af8466b69bc3c27bdf4f5d84a665e0f7ab5088412e9e2ec49758eca5cbc"},
|
||||
{file = "cryptography-46.0.3-cp311-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:01ca9ff2885f3acc98c29f1860552e37f6d7c7d013d7334ff2a9de43a449315d"},
|
||||
{file = "cryptography-46.0.3-cp311-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:6eae65d4c3d33da080cff9c4ab1f711b15c1d9760809dad6ea763f3812d254cb"},
|
||||
{file = "cryptography-46.0.3-cp311-abi3-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:e5bf0ed4490068a2e72ac03d786693adeb909981cc596425d09032d372bcc849"},
|
||||
{file = "cryptography-46.0.3-cp311-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:5ecfccd2329e37e9b7112a888e76d9feca2347f12f37918facbb893d7bb88ee8"},
|
||||
{file = "cryptography-46.0.3-cp311-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:a2c0cd47381a3229c403062f764160d57d4d175e022c1df84e168c6251a22eec"},
|
||||
{file = "cryptography-46.0.3-cp311-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:549e234ff32571b1f4076ac269fcce7a808d3bf98b76c8dd560e42dbc66d7d91"},
|
||||
{file = "cryptography-46.0.3-cp311-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:c0a7bb1a68a5d3471880e264621346c48665b3bf1c3759d682fc0864c540bd9e"},
|
||||
{file = "cryptography-46.0.3-cp311-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:10b01676fc208c3e6feeb25a8b83d81767e8059e1fe86e1dc62d10a3018fa926"},
|
||||
{file = "cryptography-46.0.3-cp311-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:0abf1ffd6e57c67e92af68330d05760b7b7efb243aab8377e583284dbab72c71"},
|
||||
{file = "cryptography-46.0.3-cp311-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:a04bee9ab6a4da801eb9b51f1b708a1b5b5c9eb48c03f74198464c66f0d344ac"},
|
||||
{file = "cryptography-46.0.3-cp311-abi3-win32.whl", hash = "sha256:f260d0d41e9b4da1ed1e0f1ce571f97fe370b152ab18778e9e8f67d6af432018"},
|
||||
{file = "cryptography-46.0.3-cp311-abi3-win_amd64.whl", hash = "sha256:a9a3008438615669153eb86b26b61e09993921ebdd75385ddd748702c5adfddb"},
|
||||
{file = "cryptography-46.0.3-cp311-abi3-win_arm64.whl", hash = "sha256:5d7f93296ee28f68447397bf5198428c9aeeab45705a55d53a6343455dcb2c3c"},
|
||||
{file = "cryptography-46.0.3-cp314-cp314t-macosx_10_9_universal2.whl", hash = "sha256:00a5e7e87938e5ff9ff5447ab086a5706a957137e6e433841e9d24f38a065217"},
|
||||
{file = "cryptography-46.0.3-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:c8daeb2d2174beb4575b77482320303f3d39b8e81153da4f0fb08eb5fe86a6c5"},
|
||||
{file = "cryptography-46.0.3-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:39b6755623145ad5eff1dab323f4eae2a32a77a7abef2c5089a04a3d04366715"},
|
||||
{file = "cryptography-46.0.3-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:db391fa7c66df6762ee3f00c95a89e6d428f4d60e7abc8328f4fe155b5ac6e54"},
|
||||
{file = "cryptography-46.0.3-cp314-cp314t-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:78a97cf6a8839a48c49271cdcbd5cf37ca2c1d6b7fdd86cc864f302b5e9bf459"},
|
||||
{file = "cryptography-46.0.3-cp314-cp314t-manylinux_2_28_ppc64le.whl", hash = "sha256:dfb781ff7eaa91a6f7fd41776ec37c5853c795d3b358d4896fdbb5df168af422"},
|
||||
{file = "cryptography-46.0.3-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:6f61efb26e76c45c4a227835ddeae96d83624fb0d29eb5df5b96e14ed1a0afb7"},
|
||||
{file = "cryptography-46.0.3-cp314-cp314t-manylinux_2_34_aarch64.whl", hash = "sha256:23b1a8f26e43f47ceb6d6a43115f33a5a37d57df4ea0ca295b780ae8546e8044"},
|
||||
{file = "cryptography-46.0.3-cp314-cp314t-manylinux_2_34_ppc64le.whl", hash = "sha256:b419ae593c86b87014b9be7396b385491ad7f320bde96826d0dd174459e54665"},
|
||||
{file = "cryptography-46.0.3-cp314-cp314t-manylinux_2_34_x86_64.whl", hash = "sha256:50fc3343ac490c6b08c0cf0d704e881d0d660be923fd3076db3e932007e726e3"},
|
||||
{file = "cryptography-46.0.3-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:22d7e97932f511d6b0b04f2bfd818d73dcd5928db509460aaf48384778eb6d20"},
|
||||
{file = "cryptography-46.0.3-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:d55f3dffadd674514ad19451161118fd010988540cee43d8bc20675e775925de"},
|
||||
{file = "cryptography-46.0.3-cp314-cp314t-win32.whl", hash = "sha256:8a6e050cb6164d3f830453754094c086ff2d0b2f3a897a1d9820f6139a1f0914"},
|
||||
{file = "cryptography-46.0.3-cp314-cp314t-win_amd64.whl", hash = "sha256:760f83faa07f8b64e9c33fc963d790a2edb24efb479e3520c14a45741cd9b2db"},
|
||||
{file = "cryptography-46.0.3-cp314-cp314t-win_arm64.whl", hash = "sha256:516ea134e703e9fe26bcd1277a4b59ad30586ea90c365a87781d7887a646fe21"},
|
||||
{file = "cryptography-46.0.3-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:cb3d760a6117f621261d662bccc8ef5bc32ca673e037c83fbe565324f5c46936"},
|
||||
{file = "cryptography-46.0.3-cp38-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:4b7387121ac7d15e550f5cb4a43aef2559ed759c35df7336c402bb8275ac9683"},
|
||||
{file = "cryptography-46.0.3-cp38-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:15ab9b093e8f09daab0f2159bb7e47532596075139dd74365da52ecc9cb46c5d"},
|
||||
{file = "cryptography-46.0.3-cp38-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:46acf53b40ea38f9c6c229599a4a13f0d46a6c3fa9ef19fc1a124d62e338dfa0"},
|
||||
{file = "cryptography-46.0.3-cp38-abi3-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:10ca84c4668d066a9878890047f03546f3ae0a6b8b39b697457b7757aaf18dbc"},
|
||||
{file = "cryptography-46.0.3-cp38-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:36e627112085bb3b81b19fed209c05ce2a52ee8b15d161b7c643a7d5a88491f3"},
|
||||
{file = "cryptography-46.0.3-cp38-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:1000713389b75c449a6e979ffc7dcc8ac90b437048766cef052d4d30b8220971"},
|
||||
{file = "cryptography-46.0.3-cp38-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:b02cf04496f6576afffef5ddd04a0cb7d49cf6be16a9059d793a30b035f6b6ac"},
|
||||
{file = "cryptography-46.0.3-cp38-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:71e842ec9bc7abf543b47cf86b9a743baa95f4677d22baa4c7d5c69e49e9bc04"},
|
||||
{file = "cryptography-46.0.3-cp38-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:402b58fc32614f00980b66d6e56a5b4118e6cb362ae8f3fda141ba4689bd4506"},
|
||||
{file = "cryptography-46.0.3-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:ef639cb3372f69ec44915fafcd6698b6cc78fbe0c2ea41be867f6ed612811963"},
|
||||
{file = "cryptography-46.0.3-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:3b51b8ca4f1c6453d8829e1eb7299499ca7f313900dd4d89a24b8b87c0a780d4"},
|
||||
{file = "cryptography-46.0.3-cp38-abi3-win32.whl", hash = "sha256:6276eb85ef938dc035d59b87c8a7dc559a232f954962520137529d77b18ff1df"},
|
||||
{file = "cryptography-46.0.3-cp38-abi3-win_amd64.whl", hash = "sha256:416260257577718c05135c55958b674000baef9a1c7d9e8f306ec60d71db850f"},
|
||||
{file = "cryptography-46.0.3-cp38-abi3-win_arm64.whl", hash = "sha256:d89c3468de4cdc4f08a57e214384d0471911a3830fcdaf7a8cc587e42a866372"},
|
||||
{file = "cryptography-46.0.3-pp310-pypy310_pp73-macosx_10_9_x86_64.whl", hash = "sha256:a23582810fedb8c0bc47524558fb6c56aac3fc252cb306072fd2815da2a47c32"},
|
||||
{file = "cryptography-46.0.3-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:e7aec276d68421f9574040c26e2a7c3771060bc0cff408bae1dcb19d3ab1e63c"},
|
||||
{file = "cryptography-46.0.3-pp311-pypy311_pp73-macosx_10_9_x86_64.whl", hash = "sha256:7ce938a99998ed3c8aa7e7272dca1a610401ede816d36d0693907d863b10d9ea"},
|
||||
{file = "cryptography-46.0.3-pp311-pypy311_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:191bb60a7be5e6f54e30ba16fdfae78ad3a342a0599eb4193ba88e3f3d6e185b"},
|
||||
{file = "cryptography-46.0.3-pp311-pypy311_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:c70cc23f12726be8f8bc72e41d5065d77e4515efae3690326764ea1b07845cfb"},
|
||||
{file = "cryptography-46.0.3-pp311-pypy311_pp73-manylinux_2_34_aarch64.whl", hash = "sha256:9394673a9f4de09e28b5356e7fff97d778f8abad85c9d5ac4a4b7e25a0de7717"},
|
||||
{file = "cryptography-46.0.3-pp311-pypy311_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:94cd0549accc38d1494e1f8de71eca837d0509d0d44bf11d158524b0e12cebf9"},
|
||||
{file = "cryptography-46.0.3-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:6b5063083824e5509fdba180721d55909ffacccc8adbec85268b48439423d78c"},
|
||||
{file = "cryptography-46.0.3.tar.gz", hash = "sha256:a8b17438104fed022ce745b362294d9ce35b4c2e45c1d958ad4a4b019285f4a1"},
|
||||
{file = "cryptography-46.0.5-cp311-abi3-macosx_10_9_universal2.whl", hash = "sha256:351695ada9ea9618b3500b490ad54c739860883df6c1f555e088eaf25b1bbaad"},
|
||||
{file = "cryptography-46.0.5-cp311-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:c18ff11e86df2e28854939acde2d003f7984f721eba450b56a200ad90eeb0e6b"},
|
||||
{file = "cryptography-46.0.5-cp311-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4d7e3d356b8cd4ea5aff04f129d5f66ebdc7b6f8eae802b93739ed520c47c79b"},
|
||||
{file = "cryptography-46.0.5-cp311-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:50bfb6925eff619c9c023b967d5b77a54e04256c4281b0e21336a130cd7fc263"},
|
||||
{file = "cryptography-46.0.5-cp311-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:803812e111e75d1aa73690d2facc295eaefd4439be1023fefc4995eaea2af90d"},
|
||||
{file = "cryptography-46.0.5-cp311-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:3ee190460e2fbe447175cda91b88b84ae8322a104fc27766ad09428754a618ed"},
|
||||
{file = "cryptography-46.0.5-cp311-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:f145bba11b878005c496e93e257c1e88f154d278d2638e6450d17e0f31e558d2"},
|
||||
{file = "cryptography-46.0.5-cp311-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:e9251e3be159d1020c4030bd2e5f84d6a43fe54b6c19c12f51cde9542a2817b2"},
|
||||
{file = "cryptography-46.0.5-cp311-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:47fb8a66058b80e509c47118ef8a75d14c455e81ac369050f20ba0d23e77fee0"},
|
||||
{file = "cryptography-46.0.5-cp311-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:4c3341037c136030cb46e4b1e17b7418ea4cbd9dd207e4a6f3b2b24e0d4ac731"},
|
||||
{file = "cryptography-46.0.5-cp311-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:890bcb4abd5a2d3f852196437129eb3667d62630333aacc13dfd470fad3aaa82"},
|
||||
{file = "cryptography-46.0.5-cp311-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:80a8d7bfdf38f87ca30a5391c0c9ce4ed2926918e017c29ddf643d0ed2778ea1"},
|
||||
{file = "cryptography-46.0.5-cp311-abi3-win32.whl", hash = "sha256:60ee7e19e95104d4c03871d7d7dfb3d22ef8a9b9c6778c94e1c8fcc8365afd48"},
|
||||
{file = "cryptography-46.0.5-cp311-abi3-win_amd64.whl", hash = "sha256:38946c54b16c885c72c4f59846be9743d699eee2b69b6988e0a00a01f46a61a4"},
|
||||
{file = "cryptography-46.0.5-cp314-cp314t-macosx_10_9_universal2.whl", hash = "sha256:94a76daa32eb78d61339aff7952ea819b1734b46f73646a07decb40e5b3448e2"},
|
||||
{file = "cryptography-46.0.5-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:5be7bf2fb40769e05739dd0046e7b26f9d4670badc7b032d6ce4db64dddc0678"},
|
||||
{file = "cryptography-46.0.5-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:fe346b143ff9685e40192a4960938545c699054ba11d4f9029f94751e3f71d87"},
|
||||
{file = "cryptography-46.0.5-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:c69fd885df7d089548a42d5ec05be26050ebcd2283d89b3d30676eb32ff87dee"},
|
||||
{file = "cryptography-46.0.5-cp314-cp314t-manylinux_2_28_ppc64le.whl", hash = "sha256:8293f3dea7fc929ef7240796ba231413afa7b68ce38fd21da2995549f5961981"},
|
||||
{file = "cryptography-46.0.5-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:1abfdb89b41c3be0365328a410baa9df3ff8a9110fb75e7b52e66803ddabc9a9"},
|
||||
{file = "cryptography-46.0.5-cp314-cp314t-manylinux_2_31_armv7l.whl", hash = "sha256:d66e421495fdb797610a08f43b05269e0a5ea7f5e652a89bfd5a7d3c1dee3648"},
|
||||
{file = "cryptography-46.0.5-cp314-cp314t-manylinux_2_34_aarch64.whl", hash = "sha256:4e817a8920bfbcff8940ecfd60f23d01836408242b30f1a708d93198393a80b4"},
|
||||
{file = "cryptography-46.0.5-cp314-cp314t-manylinux_2_34_ppc64le.whl", hash = "sha256:68f68d13f2e1cb95163fa3b4db4bf9a159a418f5f6e7242564fc75fcae667fd0"},
|
||||
{file = "cryptography-46.0.5-cp314-cp314t-manylinux_2_34_x86_64.whl", hash = "sha256:a3d1fae9863299076f05cb8a778c467578262fae09f9dc0ee9b12eb4268ce663"},
|
||||
{file = "cryptography-46.0.5-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:c4143987a42a2397f2fc3b4d7e3a7d313fbe684f67ff443999e803dd75a76826"},
|
||||
{file = "cryptography-46.0.5-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:7d731d4b107030987fd61a7f8ab512b25b53cef8f233a97379ede116f30eb67d"},
|
||||
{file = "cryptography-46.0.5-cp314-cp314t-win32.whl", hash = "sha256:c3bcce8521d785d510b2aad26ae2c966092b7daa8f45dd8f44734a104dc0bc1a"},
|
||||
{file = "cryptography-46.0.5-cp314-cp314t-win_amd64.whl", hash = "sha256:4d8ae8659ab18c65ced284993c2265910f6c9e650189d4e3f68445ef82a810e4"},
|
||||
{file = "cryptography-46.0.5-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:4108d4c09fbbf2789d0c926eb4152ae1760d5a2d97612b92d508d96c861e4d31"},
|
||||
{file = "cryptography-46.0.5-cp38-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:7d1f30a86d2757199cb2d56e48cce14deddf1f9c95f1ef1b64ee91ea43fe2e18"},
|
||||
{file = "cryptography-46.0.5-cp38-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:039917b0dc418bb9f6edce8a906572d69e74bd330b0b3fea4f79dab7f8ddd235"},
|
||||
{file = "cryptography-46.0.5-cp38-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:ba2a27ff02f48193fc4daeadf8ad2590516fa3d0adeeb34336b96f7fa64c1e3a"},
|
||||
{file = "cryptography-46.0.5-cp38-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:61aa400dce22cb001a98014f647dc21cda08f7915ceb95df0c9eaf84b4b6af76"},
|
||||
{file = "cryptography-46.0.5-cp38-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:3ce58ba46e1bc2aac4f7d9290223cead56743fa6ab94a5d53292ffaac6a91614"},
|
||||
{file = "cryptography-46.0.5-cp38-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:420d0e909050490d04359e7fdb5ed7e667ca5c3c402b809ae2563d7e66a92229"},
|
||||
{file = "cryptography-46.0.5-cp38-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:582f5fcd2afa31622f317f80426a027f30dc792e9c80ffee87b993200ea115f1"},
|
||||
{file = "cryptography-46.0.5-cp38-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:bfd56bb4b37ed4f330b82402f6f435845a5f5648edf1ad497da51a8452d5d62d"},
|
||||
{file = "cryptography-46.0.5-cp38-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:a3d507bb6a513ca96ba84443226af944b0f7f47dcc9a399d110cd6146481d24c"},
|
||||
{file = "cryptography-46.0.5-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:9f16fbdf4da055efb21c22d81b89f155f02ba420558db21288b3d0035bafd5f4"},
|
||||
{file = "cryptography-46.0.5-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:ced80795227d70549a411a4ab66e8ce307899fad2220ce5ab2f296e687eacde9"},
|
||||
{file = "cryptography-46.0.5-cp38-abi3-win32.whl", hash = "sha256:02f547fce831f5096c9a567fd41bc12ca8f11df260959ecc7c3202555cc47a72"},
|
||||
{file = "cryptography-46.0.5-cp38-abi3-win_amd64.whl", hash = "sha256:556e106ee01aa13484ce9b0239bca667be5004efb0aabbed28d353df86445595"},
|
||||
{file = "cryptography-46.0.5-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:3b4995dc971c9fb83c25aa44cf45f02ba86f71ee600d81091c2f0cbae116b06c"},
|
||||
{file = "cryptography-46.0.5-pp311-pypy311_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:bc84e875994c3b445871ea7181d424588171efec3e185dced958dad9e001950a"},
|
||||
{file = "cryptography-46.0.5-pp311-pypy311_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:2ae6971afd6246710480e3f15824ed3029a60fc16991db250034efd0b9fb4356"},
|
||||
{file = "cryptography-46.0.5.tar.gz", hash = "sha256:abace499247268e3757271b2f1e244b36b06f8515cf27c4d49468fc9eb16e93d"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -1612,7 +1604,7 @@ nox = ["nox[uv] (>=2024.4.15)"]
|
||||
pep8test = ["check-sdist", "click (>=8.0.1)", "mypy (>=1.14)", "ruff (>=0.11.11)"]
|
||||
sdist = ["build (>=1.0.0)"]
|
||||
ssh = ["bcrypt (>=3.1.5)"]
|
||||
test = ["certifi (>=2024)", "cryptography-vectors (==46.0.3)", "pretend (>=0.7)", "pytest (>=7.4.0)", "pytest-benchmark (>=4.0)", "pytest-cov (>=2.10.1)", "pytest-xdist (>=3.5.0)"]
|
||||
test = ["certifi (>=2024)", "cryptography-vectors (==46.0.5)", "pretend (>=0.7)", "pytest (>=7.4.0)", "pytest-benchmark (>=4.0)", "pytest-cov (>=2.10.1)", "pytest-xdist (>=3.5.0)"]
|
||||
test-randomorder = ["pytest-randomly"]
|
||||
|
||||
[[package]]
|
||||
@@ -5754,14 +5746,14 @@ test = ["flaky", "ipykernel (>=6.19.3)", "ipython", "ipywidgets", "nbconvert (>=
|
||||
|
||||
[[package]]
|
||||
name = "nbconvert"
|
||||
version = "7.16.6"
|
||||
description = "Converting Jupyter Notebooks (.ipynb files) to other formats. Output formats include asciidoc, html, latex, markdown, pdf, py, rst, script. nbconvert can be used both as a Python library (`import nbconvert`) or as a command line tool (invoked as `jupyter nbconvert ...`)."
|
||||
version = "7.17.0"
|
||||
description = "Convert Jupyter Notebooks (.ipynb files) to other formats."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "nbconvert-7.16.6-py3-none-any.whl", hash = "sha256:1375a7b67e0c2883678c48e506dc320febb57685e5ee67faa51b18a90f3a712b"},
|
||||
{file = "nbconvert-7.16.6.tar.gz", hash = "sha256:576a7e37c6480da7b8465eefa66c17844243816ce1ccc372633c6b71c3c0f582"},
|
||||
{file = "nbconvert-7.17.0-py3-none-any.whl", hash = "sha256:4f99a63b337b9a23504347afdab24a11faa7d86b405e5c8f9881cd313336d518"},
|
||||
{file = "nbconvert-7.17.0.tar.gz", hash = "sha256:1b2696f1b5be12309f6c7d707c24af604b87dfaf6d950794c7b07acab96dda78"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -5781,8 +5773,8 @@ pygments = ">=2.4.1"
|
||||
traitlets = ">=5.1"
|
||||
|
||||
[package.extras]
|
||||
all = ["flaky", "ipykernel", "ipython", "ipywidgets (>=7.5)", "myst-parser", "nbsphinx (>=0.2.12)", "playwright", "pydata-sphinx-theme", "pyqtwebengine (>=5.15)", "pytest (>=7)", "sphinx (==5.0.2)", "sphinxcontrib-spelling", "tornado (>=6.1)"]
|
||||
docs = ["ipykernel", "ipython", "myst-parser", "nbsphinx (>=0.2.12)", "pydata-sphinx-theme", "sphinx (==5.0.2)", "sphinxcontrib-spelling"]
|
||||
all = ["flaky", "intersphinx-registry", "ipykernel", "ipython", "ipywidgets (>=7.5)", "myst-parser", "nbsphinx (>=0.2.12)", "playwright", "pydata-sphinx-theme", "pyqtwebengine (>=5.15)", "pytest (>=7)", "sphinx (>=5.0.2)", "sphinxcontrib-spelling", "tornado (>=6.1)"]
|
||||
docs = ["intersphinx-registry", "ipykernel", "ipython", "myst-parser", "nbsphinx (>=0.2.12)", "pydata-sphinx-theme", "sphinx (>=5.0.2)", "sphinxcontrib-spelling"]
|
||||
qtpdf = ["pyqtwebengine (>=5.15)"]
|
||||
qtpng = ["pyqtwebengine (>=5.15)"]
|
||||
serve = ["tornado (>=6.1)"]
|
||||
@@ -6102,14 +6094,14 @@ llama = ["llama-index (>=0.12.29,<0.13.0)", "llama-index-core (>=0.12.29,<0.13.0
|
||||
|
||||
[[package]]
|
||||
name = "openhands-agent-server"
|
||||
version = "1.10.0"
|
||||
version = "1.11.5"
|
||||
description = "OpenHands Agent Server - REST/WebSocket interface for OpenHands AI Agent"
|
||||
optional = false
|
||||
python-versions = ">=3.12"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "openhands_agent_server-1.10.0-py3-none-any.whl", hash = "sha256:2e21076fff5e7cf9d03a3b011e2c90a6a3a46d2da3f18db9f7553ac413229c22"},
|
||||
{file = "openhands_agent_server-1.10.0.tar.gz", hash = "sha256:2062da2496a98a6c23201d086f124e02329d6c6d9d1b47be55921c084a29f55a"},
|
||||
{file = "openhands_agent_server-1.11.5-py3-none-any.whl", hash = "sha256:8bae7063f232791d58a5c31919f58b557f7cce60e6295773985c7dadc556cb9e"},
|
||||
{file = "openhands_agent_server-1.11.5.tar.gz", hash = "sha256:b61366d727c61ab9b7fcd66faab53f230f8ef0928c1177a388d2c5c4be6ebbd0"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -6126,7 +6118,7 @@ wsproto = ">=1.2.0"
|
||||
|
||||
[[package]]
|
||||
name = "openhands-ai"
|
||||
version = "1.2.1"
|
||||
version = "1.4.0"
|
||||
description = "OpenHands: Code Less, Make More"
|
||||
optional = false
|
||||
python-versions = "^3.12,<3.14"
|
||||
@@ -6168,9 +6160,9 @@ memory-profiler = ">=0.61"
|
||||
numpy = "*"
|
||||
openai = "2.8"
|
||||
openhands-aci = "0.3.2"
|
||||
openhands-agent-server = "1.10"
|
||||
openhands-sdk = "1.10"
|
||||
openhands-tools = "1.10"
|
||||
openhands-agent-server = "1.11.5"
|
||||
openhands-sdk = "1.11.5"
|
||||
openhands-tools = "1.11.5"
|
||||
opentelemetry-api = ">=1.33.1"
|
||||
opentelemetry-exporter-otlp-proto-grpc = ">=1.33.1"
|
||||
pathspec = ">=0.12.1"
|
||||
@@ -6194,7 +6186,7 @@ python-jose = {version = ">=3.3", extras = ["cryptography"]}
|
||||
python-json-logger = ">=3.2.1"
|
||||
python-multipart = "*"
|
||||
python-pptx = "*"
|
||||
python-socketio = "5.13"
|
||||
python-socketio = "5.14"
|
||||
pythonnet = "*"
|
||||
pyyaml = ">=6.0.2"
|
||||
qtconsole = ">=5.6.1"
|
||||
@@ -6225,14 +6217,14 @@ url = ".."
|
||||
|
||||
[[package]]
|
||||
name = "openhands-sdk"
|
||||
version = "1.10.0"
|
||||
version = "1.11.5"
|
||||
description = "OpenHands SDK - Core functionality for building AI agents"
|
||||
optional = false
|
||||
python-versions = ">=3.12"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "openhands_sdk-1.10.0-py3-none-any.whl", hash = "sha256:5c8875f2a07d7fabe3449914639572bef9003821207cb06aa237a239e964eed5"},
|
||||
{file = "openhands_sdk-1.10.0.tar.gz", hash = "sha256:93371b1af4532266ad2d225b9d7d3d711c745df31888efe643970673f62bdef9"},
|
||||
{file = "openhands_sdk-1.11.5-py3-none-any.whl", hash = "sha256:f949cd540cbecc339d90fb0cca2a5f29e1b62566b82b5aee82ef40f259d14e60"},
|
||||
{file = "openhands_sdk-1.11.5.tar.gz", hash = "sha256:dd6225876b7b8dbb6c608559f2718c3d0bf44d0bb741e990b185c6cdc5150c5a"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -6253,14 +6245,14 @@ boto3 = ["boto3 (>=1.35.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "openhands-tools"
|
||||
version = "1.10.0"
|
||||
version = "1.11.5"
|
||||
description = "OpenHands Tools - Runtime tools for AI agents"
|
||||
optional = false
|
||||
python-versions = ">=3.12"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "openhands_tools-1.10.0-py3-none-any.whl", hash = "sha256:1d5d2d1e34cc4ceb02c0ff1f008b06883ad48a8e7236ab8dd61ece64fbf8e2ed"},
|
||||
{file = "openhands_tools-1.10.0.tar.gz", hash = "sha256:7ed38cb13545ec2c4a35c26ece725d5b35788d30597db8b1904619c043ec1194"},
|
||||
{file = "openhands_tools-1.11.5-py3-none-any.whl", hash = "sha256:1e981e1e7f3544184fe946cee8eb6bd287010cdef77d83ebac945c9f42df3baf"},
|
||||
{file = "openhands_tools-1.11.5.tar.gz", hash = "sha256:d7b1163f6505a51b07147e7d8972062c129ecc46571a71f28d5470355e06650e"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -6851,103 +6843,103 @@ scramp = ">=1.4.5"
|
||||
|
||||
[[package]]
|
||||
name = "pillow"
|
||||
version = "12.1.0"
|
||||
version = "12.1.1"
|
||||
description = "Python Imaging Library (fork)"
|
||||
optional = false
|
||||
python-versions = ">=3.10"
|
||||
groups = ["main", "test"]
|
||||
files = [
|
||||
{file = "pillow-12.1.0-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:fb125d860738a09d363a88daa0f59c4533529a90e564785e20fe875b200b6dbd"},
|
||||
{file = "pillow-12.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:cad302dc10fac357d3467a74a9561c90609768a6f73a1923b0fd851b6486f8b0"},
|
||||
{file = "pillow-12.1.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:a40905599d8079e09f25027423aed94f2823adaf2868940de991e53a449e14a8"},
|
||||
{file = "pillow-12.1.0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:92a7fe4225365c5e3a8e598982269c6d6698d3e783b3b1ae979e7819f9cd55c1"},
|
||||
{file = "pillow-12.1.0-cp310-cp310-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f10c98f49227ed8383d28174ee95155a675c4ed7f85e2e573b04414f7e371bda"},
|
||||
{file = "pillow-12.1.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8637e29d13f478bc4f153d8daa9ffb16455f0a6cb287da1b432fdad2bfbd66c7"},
|
||||
{file = "pillow-12.1.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:21e686a21078b0f9cb8c8a961d99e6a4ddb88e0fc5ea6e130172ddddc2e5221a"},
|
||||
{file = "pillow-12.1.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:2415373395a831f53933c23ce051021e79c8cd7979822d8cc478547a3f4da8ef"},
|
||||
{file = "pillow-12.1.0-cp310-cp310-win32.whl", hash = "sha256:e75d3dba8fc1ddfec0cd752108f93b83b4f8d6ab40e524a95d35f016b9683b09"},
|
||||
{file = "pillow-12.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:64efdf00c09e31efd754448a383ea241f55a994fd079866b92d2bbff598aad91"},
|
||||
{file = "pillow-12.1.0-cp310-cp310-win_arm64.whl", hash = "sha256:f188028b5af6b8fb2e9a76ac0f841a575bd1bd396e46ef0840d9b88a48fdbcea"},
|
||||
{file = "pillow-12.1.0-cp311-cp311-macosx_10_10_x86_64.whl", hash = "sha256:a83e0850cb8f5ac975291ebfc4170ba481f41a28065277f7f735c202cd8e0af3"},
|
||||
{file = "pillow-12.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:b6e53e82ec2db0717eabb276aa56cf4e500c9a7cec2c2e189b55c24f65a3e8c0"},
|
||||
{file = "pillow-12.1.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:40a8e3b9e8773876d6e30daed22f016509e3987bab61b3b7fe309d7019a87451"},
|
||||
{file = "pillow-12.1.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:800429ac32c9b72909c671aaf17ecd13110f823ddb7db4dfef412a5587c2c24e"},
|
||||
{file = "pillow-12.1.0-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0b022eaaf709541b391ee069f0022ee5b36c709df71986e3f7be312e46f42c84"},
|
||||
{file = "pillow-12.1.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1f345e7bc9d7f368887c712aa5054558bad44d2a301ddf9248599f4161abc7c0"},
|
||||
{file = "pillow-12.1.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d70347c8a5b7ccd803ec0c85c8709f036e6348f1e6a5bf048ecd9c64d3550b8b"},
|
||||
{file = "pillow-12.1.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:1fcc52d86ce7a34fd17cb04e87cfdb164648a3662a6f20565910a99653d66c18"},
|
||||
{file = "pillow-12.1.0-cp311-cp311-win32.whl", hash = "sha256:3ffaa2f0659e2f740473bcf03c702c39a8d4b2b7ffc629052028764324842c64"},
|
||||
{file = "pillow-12.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:806f3987ffe10e867bab0ddad45df1148a2b98221798457fa097ad85d6e8bc75"},
|
||||
{file = "pillow-12.1.0-cp311-cp311-win_arm64.whl", hash = "sha256:9f5fefaca968e700ad1a4a9de98bf0869a94e397fe3524c4c9450c1445252304"},
|
||||
{file = "pillow-12.1.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:a332ac4ccb84b6dde65dbace8431f3af08874bf9770719d32a635c4ef411b18b"},
|
||||
{file = "pillow-12.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:907bfa8a9cb790748a9aa4513e37c88c59660da3bcfffbd24a7d9e6abf224551"},
|
||||
{file = "pillow-12.1.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:efdc140e7b63b8f739d09a99033aa430accce485ff78e6d311973a67b6bf3208"},
|
||||
{file = "pillow-12.1.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:bef9768cab184e7ae6e559c032e95ba8d07b3023c289f79a2bd36e8bf85605a5"},
|
||||
{file = "pillow-12.1.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:742aea052cf5ab5034a53c3846165bc3ce88d7c38e954120db0ab867ca242661"},
|
||||
{file = "pillow-12.1.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a6dfc2af5b082b635af6e08e0d1f9f1c4e04d17d4e2ca0ef96131e85eda6eb17"},
|
||||
{file = "pillow-12.1.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:609e89d9f90b581c8d16358c9087df76024cf058fa693dd3e1e1620823f39670"},
|
||||
{file = "pillow-12.1.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:43b4899cfd091a9693a1278c4982f3e50f7fb7cff5153b05174b4afc9593b616"},
|
||||
{file = "pillow-12.1.0-cp312-cp312-win32.whl", hash = "sha256:aa0c9cc0b82b14766a99fbe6084409972266e82f459821cd26997a488a7261a7"},
|
||||
{file = "pillow-12.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:d70534cea9e7966169ad29a903b99fc507e932069a881d0965a1a84bb57f6c6d"},
|
||||
{file = "pillow-12.1.0-cp312-cp312-win_arm64.whl", hash = "sha256:65b80c1ee7e14a87d6a068dd3b0aea268ffcabfe0498d38661b00c5b4b22e74c"},
|
||||
{file = "pillow-12.1.0-cp313-cp313-ios_13_0_arm64_iphoneos.whl", hash = "sha256:7b5dd7cbae20285cdb597b10eb5a2c13aa9de6cde9bb64a3c1317427b1db1ae1"},
|
||||
{file = "pillow-12.1.0-cp313-cp313-ios_13_0_arm64_iphonesimulator.whl", hash = "sha256:29a4cef9cb672363926f0470afc516dbf7305a14d8c54f7abbb5c199cd8f8179"},
|
||||
{file = "pillow-12.1.0-cp313-cp313-ios_13_0_x86_64_iphonesimulator.whl", hash = "sha256:681088909d7e8fa9e31b9799aaa59ba5234c58e5e4f1951b4c4d1082a2e980e0"},
|
||||
{file = "pillow-12.1.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:983976c2ab753166dc66d36af6e8ec15bb511e4a25856e2227e5f7e00a160587"},
|
||||
{file = "pillow-12.1.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:db44d5c160a90df2d24a24760bbd37607d53da0b34fb546c4c232af7192298ac"},
|
||||
{file = "pillow-12.1.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:6b7a9d1db5dad90e2991645874f708e87d9a3c370c243c2d7684d28f7e133e6b"},
|
||||
{file = "pillow-12.1.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6258f3260986990ba2fa8a874f8b6e808cf5abb51a94015ca3dc3c68aa4f30ea"},
|
||||
{file = "pillow-12.1.0-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e115c15e3bc727b1ca3e641a909f77f8ca72a64fff150f666fcc85e57701c26c"},
|
||||
{file = "pillow-12.1.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6741e6f3074a35e47c77b23a4e4f2d90db3ed905cb1c5e6e0d49bff2045632bc"},
|
||||
{file = "pillow-12.1.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:935b9d1aed48fcfb3f838caac506f38e29621b44ccc4f8a64d575cb1b2a88644"},
|
||||
{file = "pillow-12.1.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:5fee4c04aad8932da9f8f710af2c1a15a83582cfb884152a9caa79d4efcdbf9c"},
|
||||
{file = "pillow-12.1.0-cp313-cp313-win32.whl", hash = "sha256:a786bf667724d84aa29b5db1c61b7bfdde380202aaca12c3461afd6b71743171"},
|
||||
{file = "pillow-12.1.0-cp313-cp313-win_amd64.whl", hash = "sha256:461f9dfdafa394c59cd6d818bdfdbab4028b83b02caadaff0ffd433faf4c9a7a"},
|
||||
{file = "pillow-12.1.0-cp313-cp313-win_arm64.whl", hash = "sha256:9212d6b86917a2300669511ed094a9406888362e085f2431a7da985a6b124f45"},
|
||||
{file = "pillow-12.1.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:00162e9ca6d22b7c3ee8e61faa3c3253cd19b6a37f126cad04f2f88b306f557d"},
|
||||
{file = "pillow-12.1.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:7d6daa89a00b58c37cb1747ec9fb7ac3bc5ffd5949f5888657dfddde6d1312e0"},
|
||||
{file = "pillow-12.1.0-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:e2479c7f02f9d505682dc47df8c0ea1fc5e264c4d1629a5d63fe3e2334b89554"},
|
||||
{file = "pillow-12.1.0-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:f188d580bd870cda1e15183790d1cc2fa78f666e76077d103edf048eed9c356e"},
|
||||
{file = "pillow-12.1.0-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0fde7ec5538ab5095cc02df38ee99b0443ff0e1c847a045554cf5f9af1f4aa82"},
|
||||
{file = "pillow-12.1.0-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0ed07dca4a8464bada6139ab38f5382f83e5f111698caf3191cb8dbf27d908b4"},
|
||||
{file = "pillow-12.1.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:f45bd71d1fa5e5749587613037b172e0b3b23159d1c00ef2fc920da6f470e6f0"},
|
||||
{file = "pillow-12.1.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:277518bf4fe74aa91489e1b20577473b19ee70fb97c374aa50830b279f25841b"},
|
||||
{file = "pillow-12.1.0-cp313-cp313t-win32.whl", hash = "sha256:7315f9137087c4e0ee73a761b163fc9aa3b19f5f606a7fc08d83fd3e4379af65"},
|
||||
{file = "pillow-12.1.0-cp313-cp313t-win_amd64.whl", hash = "sha256:0ddedfaa8b5f0b4ffbc2fa87b556dc59f6bb4ecb14a53b33f9189713ae8053c0"},
|
||||
{file = "pillow-12.1.0-cp313-cp313t-win_arm64.whl", hash = "sha256:80941e6d573197a0c28f394753de529bb436b1ca990ed6e765cf42426abc39f8"},
|
||||
{file = "pillow-12.1.0-cp314-cp314-ios_13_0_arm64_iphoneos.whl", hash = "sha256:5cb7bc1966d031aec37ddb9dcf15c2da5b2e9f7cc3ca7c54473a20a927e1eb91"},
|
||||
{file = "pillow-12.1.0-cp314-cp314-ios_13_0_arm64_iphonesimulator.whl", hash = "sha256:97e9993d5ed946aba26baf9c1e8cf18adbab584b99f452ee72f7ee8acb882796"},
|
||||
{file = "pillow-12.1.0-cp314-cp314-ios_13_0_x86_64_iphonesimulator.whl", hash = "sha256:414b9a78e14ffeb98128863314e62c3f24b8a86081066625700b7985b3f529bd"},
|
||||
{file = "pillow-12.1.0-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:e6bdb408f7c9dd2a5ff2b14a3b0bb6d4deb29fb9961e6eb3ae2031ae9a5cec13"},
|
||||
{file = "pillow-12.1.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:3413c2ae377550f5487991d444428f1a8ae92784aac79caa8b1e3b89b175f77e"},
|
||||
{file = "pillow-12.1.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:e5dcbe95016e88437ecf33544ba5db21ef1b8dd6e1b434a2cb2a3d605299e643"},
|
||||
{file = "pillow-12.1.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:d0a7735df32ccbcc98b98a1ac785cc4b19b580be1bdf0aeb5c03223220ea09d5"},
|
||||
{file = "pillow-12.1.0-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0c27407a2d1b96774cbc4a7594129cc027339fd800cd081e44497722ea1179de"},
|
||||
{file = "pillow-12.1.0-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:15c794d74303828eaa957ff8070846d0efe8c630901a1c753fdc63850e19ecd9"},
|
||||
{file = "pillow-12.1.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:c990547452ee2800d8506c4150280757f88532f3de2a58e3022e9b179107862a"},
|
||||
{file = "pillow-12.1.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:b63e13dd27da389ed9475b3d28510f0f954bca0041e8e551b2a4eb1eab56a39a"},
|
||||
{file = "pillow-12.1.0-cp314-cp314-win32.whl", hash = "sha256:1a949604f73eb07a8adab38c4fe50791f9919344398bdc8ac6b307f755fc7030"},
|
||||
{file = "pillow-12.1.0-cp314-cp314-win_amd64.whl", hash = "sha256:4f9f6a650743f0ddee5593ac9e954ba1bdbc5e150bc066586d4f26127853ab94"},
|
||||
{file = "pillow-12.1.0-cp314-cp314-win_arm64.whl", hash = "sha256:808b99604f7873c800c4840f55ff389936ef1948e4e87645eaf3fccbc8477ac4"},
|
||||
{file = "pillow-12.1.0-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:bc11908616c8a283cf7d664f77411a5ed2a02009b0097ff8abbba5e79128ccf2"},
|
||||
{file = "pillow-12.1.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:896866d2d436563fa2a43a9d72f417874f16b5545955c54a64941e87c1376c61"},
|
||||
{file = "pillow-12.1.0-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:8e178e3e99d3c0ea8fc64b88447f7cac8ccf058af422a6cedc690d0eadd98c51"},
|
||||
{file = "pillow-12.1.0-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:079af2fb0c599c2ec144ba2c02766d1b55498e373b3ac64687e43849fbbef5bc"},
|
||||
{file = "pillow-12.1.0-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:bdec5e43377761c5dbca620efb69a77f6855c5a379e32ac5b158f54c84212b14"},
|
||||
{file = "pillow-12.1.0-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:565c986f4b45c020f5421a4cea13ef294dde9509a8577f29b2fc5edc7587fff8"},
|
||||
{file = "pillow-12.1.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:43aca0a55ce1eefc0aefa6253661cb54571857b1a7b2964bd8a1e3ef4b729924"},
|
||||
{file = "pillow-12.1.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:0deedf2ea233722476b3a81e8cdfbad786f7adbed5d848469fa59fe52396e4ef"},
|
||||
{file = "pillow-12.1.0-cp314-cp314t-win32.whl", hash = "sha256:b17fbdbe01c196e7e159aacb889e091f28e61020a8abeac07b68079b6e626988"},
|
||||
{file = "pillow-12.1.0-cp314-cp314t-win_amd64.whl", hash = "sha256:27b9baecb428899db6c0de572d6d305cfaf38ca1596b5c0542a5182e3e74e8c6"},
|
||||
{file = "pillow-12.1.0-cp314-cp314t-win_arm64.whl", hash = "sha256:f61333d817698bdcdd0f9d7793e365ac3d2a21c1f1eb02b32ad6aefb8d8ea831"},
|
||||
{file = "pillow-12.1.0-pp311-pypy311_pp73-macosx_10_15_x86_64.whl", hash = "sha256:ca94b6aac0d7af2a10ba08c0f888b3d5114439b6b3ef39968378723622fed377"},
|
||||
{file = "pillow-12.1.0-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:351889afef0f485b84078ea40fe33727a0492b9af3904661b0abbafee0355b72"},
|
||||
{file = "pillow-12.1.0-pp311-pypy311_pp73-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:bb0984b30e973f7e2884362b7d23d0a348c7143ee559f38ef3eaab640144204c"},
|
||||
{file = "pillow-12.1.0-pp311-pypy311_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:84cabc7095dd535ca934d57e9ce2a72ffd216e435a84acb06b2277b1de2689bd"},
|
||||
{file = "pillow-12.1.0-pp311-pypy311_pp73-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:53d8b764726d3af1a138dd353116f774e3862ec7e3794e0c8781e30db0f35dfc"},
|
||||
{file = "pillow-12.1.0-pp311-pypy311_pp73-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:5da841d81b1a05ef940a8567da92decaa15bc4d7dedb540a8c219ad83d91808a"},
|
||||
{file = "pillow-12.1.0-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:75af0b4c229ac519b155028fa1be632d812a519abba9b46b20e50c6caa184f19"},
|
||||
{file = "pillow-12.1.0.tar.gz", hash = "sha256:5c5ae0a06e9ea030ab786b0251b32c7e4ce10e58d983c0d5c56029455180b5b9"},
|
||||
{file = "pillow-12.1.1-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:1f1625b72740fdda5d77b4def688eb8fd6490975d06b909fd19f13f391e077e0"},
|
||||
{file = "pillow-12.1.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:178aa072084bd88ec759052feca8e56cbb14a60b39322b99a049e58090479713"},
|
||||
{file = "pillow-12.1.1-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:b66e95d05ba806247aaa1561f080abc7975daf715c30780ff92a20e4ec546e1b"},
|
||||
{file = "pillow-12.1.1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:89c7e895002bbe49cdc5426150377cbbc04767d7547ed145473f496dfa40408b"},
|
||||
{file = "pillow-12.1.1-cp310-cp310-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3a5cbdcddad0af3da87cb16b60d23648bc3b51967eb07223e9fed77a82b457c4"},
|
||||
{file = "pillow-12.1.1-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9f51079765661884a486727f0729d29054242f74b46186026582b4e4769918e4"},
|
||||
{file = "pillow-12.1.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:99c1506ea77c11531d75e3a412832a13a71c7ebc8192ab9e4b2e355555920e3e"},
|
||||
{file = "pillow-12.1.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:36341d06738a9f66c8287cf8b876d24b18db9bd8740fa0672c74e259ad408cff"},
|
||||
{file = "pillow-12.1.1-cp310-cp310-win32.whl", hash = "sha256:6c52f062424c523d6c4db85518774cc3d50f5539dd6eed32b8f6229b26f24d40"},
|
||||
{file = "pillow-12.1.1-cp310-cp310-win_amd64.whl", hash = "sha256:c6008de247150668a705a6338156efb92334113421ceecf7438a12c9a12dab23"},
|
||||
{file = "pillow-12.1.1-cp310-cp310-win_arm64.whl", hash = "sha256:1a9b0ee305220b392e1124a764ee4265bd063e54a751a6b62eff69992f457fa9"},
|
||||
{file = "pillow-12.1.1-cp311-cp311-macosx_10_10_x86_64.whl", hash = "sha256:e879bb6cd5c73848ef3b2b48b8af9ff08c5b71ecda8048b7dd22d8a33f60be32"},
|
||||
{file = "pillow-12.1.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:365b10bb9417dd4498c0e3b128018c4a624dc11c7b97d8cc54effe3b096f4c38"},
|
||||
{file = "pillow-12.1.1-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:d4ce8e329c93845720cd2014659ca67eac35f6433fd3050393d85f3ecef0dad5"},
|
||||
{file = "pillow-12.1.1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:fc354a04072b765eccf2204f588a7a532c9511e8b9c7f900e1b64e3e33487090"},
|
||||
{file = "pillow-12.1.1-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:7e7976bf1910a8116b523b9f9f58bf410f3e8aa330cd9a2bb2953f9266ab49af"},
|
||||
{file = "pillow-12.1.1-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:597bd9c8419bc7c6af5604e55847789b69123bbe25d65cc6ad3012b4f3c98d8b"},
|
||||
{file = "pillow-12.1.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:2c1fc0f2ca5f96a3c8407e41cca26a16e46b21060fe6d5b099d2cb01412222f5"},
|
||||
{file = "pillow-12.1.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:578510d88c6229d735855e1f278aa305270438d36a05031dfaae5067cc8eb04d"},
|
||||
{file = "pillow-12.1.1-cp311-cp311-win32.whl", hash = "sha256:7311c0a0dcadb89b36b7025dfd8326ecfa36964e29913074d47382706e516a7c"},
|
||||
{file = "pillow-12.1.1-cp311-cp311-win_amd64.whl", hash = "sha256:fbfa2a7c10cc2623f412753cddf391c7f971c52ca40a3f65dc5039b2939e8563"},
|
||||
{file = "pillow-12.1.1-cp311-cp311-win_arm64.whl", hash = "sha256:b81b5e3511211631b3f672a595e3221252c90af017e399056d0faabb9538aa80"},
|
||||
{file = "pillow-12.1.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:ab323b787d6e18b3d91a72fc99b1a2c28651e4358749842b8f8dfacd28ef2052"},
|
||||
{file = "pillow-12.1.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:adebb5bee0f0af4909c30db0d890c773d1a92ffe83da908e2e9e720f8edf3984"},
|
||||
{file = "pillow-12.1.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:bb66b7cc26f50977108790e2456b7921e773f23db5630261102233eb355a3b79"},
|
||||
{file = "pillow-12.1.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:aee2810642b2898bb187ced9b349e95d2a7272930796e022efaf12e99dccd293"},
|
||||
{file = "pillow-12.1.1-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a0b1cd6232e2b618adcc54d9882e4e662a089d5768cd188f7c245b4c8c44a397"},
|
||||
{file = "pillow-12.1.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7aac39bcf8d4770d089588a2e1dd111cbaa42df5a94be3114222057d68336bd0"},
|
||||
{file = "pillow-12.1.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:ab174cd7d29a62dd139c44bf74b698039328f45cb03b4596c43473a46656b2f3"},
|
||||
{file = "pillow-12.1.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:339ffdcb7cbeaa08221cd401d517d4b1fe7a9ed5d400e4a8039719238620ca35"},
|
||||
{file = "pillow-12.1.1-cp312-cp312-win32.whl", hash = "sha256:5d1f9575a12bed9e9eedd9a4972834b08c97a352bd17955ccdebfeca5913fa0a"},
|
||||
{file = "pillow-12.1.1-cp312-cp312-win_amd64.whl", hash = "sha256:21329ec8c96c6e979cd0dfd29406c40c1d52521a90544463057d2aaa937d66a6"},
|
||||
{file = "pillow-12.1.1-cp312-cp312-win_arm64.whl", hash = "sha256:af9a332e572978f0218686636610555ae3defd1633597be015ed50289a03c523"},
|
||||
{file = "pillow-12.1.1-cp313-cp313-ios_13_0_arm64_iphoneos.whl", hash = "sha256:d242e8ac078781f1de88bf823d70c1a9b3c7950a44cdf4b7c012e22ccbcd8e4e"},
|
||||
{file = "pillow-12.1.1-cp313-cp313-ios_13_0_arm64_iphonesimulator.whl", hash = "sha256:02f84dfad02693676692746df05b89cf25597560db2857363a208e393429f5e9"},
|
||||
{file = "pillow-12.1.1-cp313-cp313-ios_13_0_x86_64_iphonesimulator.whl", hash = "sha256:e65498daf4b583091ccbb2556c7000abf0f3349fcd57ef7adc9a84a394ed29f6"},
|
||||
{file = "pillow-12.1.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:6c6db3b84c87d48d0088943bf33440e0c42370b99b1c2a7989216f7b42eede60"},
|
||||
{file = "pillow-12.1.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:8b7e5304e34942bf62e15184219a7b5ad4ff7f3bb5cca4d984f37df1a0e1aee2"},
|
||||
{file = "pillow-12.1.1-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:18e5bddd742a44b7e6b1e773ab5db102bd7a94c32555ba656e76d319d19c3850"},
|
||||
{file = "pillow-12.1.1-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:fc44ef1f3de4f45b50ccf9136999d71abb99dca7706bc75d222ed350b9fd2289"},
|
||||
{file = "pillow-12.1.1-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5a8eb7ed8d4198bccbd07058416eeec51686b498e784eda166395a23eb99138e"},
|
||||
{file = "pillow-12.1.1-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:47b94983da0c642de92ced1702c5b6c292a84bd3a8e1d1702ff923f183594717"},
|
||||
{file = "pillow-12.1.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:518a48c2aab7ce596d3bf79d0e275661b846e86e4d0e7dec34712c30fe07f02a"},
|
||||
{file = "pillow-12.1.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:a550ae29b95c6dc13cf69e2c9dc5747f814c54eeb2e32d683e5e93af56caa029"},
|
||||
{file = "pillow-12.1.1-cp313-cp313-win32.whl", hash = "sha256:a003d7422449f6d1e3a34e3dd4110c22148336918ddbfc6a32581cd54b2e0b2b"},
|
||||
{file = "pillow-12.1.1-cp313-cp313-win_amd64.whl", hash = "sha256:344cf1e3dab3be4b1fa08e449323d98a2a3f819ad20f4b22e77a0ede31f0faa1"},
|
||||
{file = "pillow-12.1.1-cp313-cp313-win_arm64.whl", hash = "sha256:5c0dd1636633e7e6a0afe7bf6a51a14992b7f8e60de5789018ebbdfae55b040a"},
|
||||
{file = "pillow-12.1.1-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:0330d233c1a0ead844fc097a7d16c0abff4c12e856c0b325f231820fee1f39da"},
|
||||
{file = "pillow-12.1.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:5dae5f21afb91322f2ff791895ddd8889e5e947ff59f71b46041c8ce6db790bc"},
|
||||
{file = "pillow-12.1.1-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:2e0c664be47252947d870ac0d327fea7e63985a08794758aa8af5b6cb6ec0c9c"},
|
||||
{file = "pillow-12.1.1-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:691ab2ac363b8217f7d31b3497108fb1f50faab2f75dfb03284ec2f217e87bf8"},
|
||||
{file = "pillow-12.1.1-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e9e8064fb1cc019296958595f6db671fba95209e3ceb0c4734c9baf97de04b20"},
|
||||
{file = "pillow-12.1.1-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:472a8d7ded663e6162dafdf20015c486a7009483ca671cece7a9279b512fcb13"},
|
||||
{file = "pillow-12.1.1-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:89b54027a766529136a06cfebeecb3a04900397a3590fd252160b888479517bf"},
|
||||
{file = "pillow-12.1.1-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:86172b0831b82ce4f7877f280055892b31179e1576aa00d0df3bb1bbf8c3e524"},
|
||||
{file = "pillow-12.1.1-cp313-cp313t-win32.whl", hash = "sha256:44ce27545b6efcf0fdbdceb31c9a5bdea9333e664cda58a7e674bb74608b3986"},
|
||||
{file = "pillow-12.1.1-cp313-cp313t-win_amd64.whl", hash = "sha256:a285e3eb7a5a45a2ff504e31f4a8d1b12ef62e84e5411c6804a42197c1cf586c"},
|
||||
{file = "pillow-12.1.1-cp313-cp313t-win_arm64.whl", hash = "sha256:cc7d296b5ea4d29e6570dabeaed58d31c3fea35a633a69679fb03d7664f43fb3"},
|
||||
{file = "pillow-12.1.1-cp314-cp314-ios_13_0_arm64_iphoneos.whl", hash = "sha256:417423db963cb4be8bac3fc1204fe61610f6abeed1580a7a2cbb2fbda20f12af"},
|
||||
{file = "pillow-12.1.1-cp314-cp314-ios_13_0_arm64_iphonesimulator.whl", hash = "sha256:b957b71c6b2387610f556a7eb0828afbe40b4a98036fc0d2acfa5a44a0c2036f"},
|
||||
{file = "pillow-12.1.1-cp314-cp314-ios_13_0_x86_64_iphonesimulator.whl", hash = "sha256:097690ba1f2efdeb165a20469d59d8bb03c55fb6621eb2041a060ae8ea3e9642"},
|
||||
{file = "pillow-12.1.1-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:2815a87ab27848db0321fb78c7f0b2c8649dee134b7f2b80c6a45c6831d75ccd"},
|
||||
{file = "pillow-12.1.1-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:f7ed2c6543bad5a7d5530eb9e78c53132f93dfa44a28492db88b41cdab885202"},
|
||||
{file = "pillow-12.1.1-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:652a2c9ccfb556235b2b501a3a7cf3742148cd22e04b5625c5fe057ea3e3191f"},
|
||||
{file = "pillow-12.1.1-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:d6e4571eedf43af33d0fc233a382a76e849badbccdf1ac438841308652a08e1f"},
|
||||
{file = "pillow-12.1.1-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b574c51cf7d5d62e9be37ba446224b59a2da26dc4c1bb2ecbe936a4fb1a7cb7f"},
|
||||
{file = "pillow-12.1.1-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a37691702ed687799de29a518d63d4682d9016932db66d4e90c345831b02fb4e"},
|
||||
{file = "pillow-12.1.1-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:f95c00d5d6700b2b890479664a06e754974848afaae5e21beb4d83c106923fd0"},
|
||||
{file = "pillow-12.1.1-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:559b38da23606e68681337ad74622c4dbba02254fc9cb4488a305dd5975c7eeb"},
|
||||
{file = "pillow-12.1.1-cp314-cp314-win32.whl", hash = "sha256:03edcc34d688572014ff223c125a3f77fb08091e4607e7745002fc214070b35f"},
|
||||
{file = "pillow-12.1.1-cp314-cp314-win_amd64.whl", hash = "sha256:50480dcd74fa63b8e78235957d302d98d98d82ccbfac4c7e12108ba9ecbdba15"},
|
||||
{file = "pillow-12.1.1-cp314-cp314-win_arm64.whl", hash = "sha256:5cb1785d97b0c3d1d1a16bc1d710c4a0049daefc4935f3a8f31f827f4d3d2e7f"},
|
||||
{file = "pillow-12.1.1-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:1f90cff8aa76835cba5769f0b3121a22bd4eb9e6884cfe338216e557a9a548b8"},
|
||||
{file = "pillow-12.1.1-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:1f1be78ce9466a7ee64bfda57bdba0f7cc499d9794d518b854816c41bf0aa4e9"},
|
||||
{file = "pillow-12.1.1-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:42fc1f4677106188ad9a55562bbade416f8b55456f522430fadab3cef7cd4e60"},
|
||||
{file = "pillow-12.1.1-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:98edb152429ab62a1818039744d8fbb3ccab98a7c29fc3d5fcef158f3f1f68b7"},
|
||||
{file = "pillow-12.1.1-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d470ab1178551dd17fdba0fef463359c41aaa613cdcd7ff8373f54be629f9f8f"},
|
||||
{file = "pillow-12.1.1-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6408a7b064595afcab0a49393a413732a35788f2a5092fdc6266952ed67de586"},
|
||||
{file = "pillow-12.1.1-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:5d8c41325b382c07799a3682c1c258469ea2ff97103c53717b7893862d0c98ce"},
|
||||
{file = "pillow-12.1.1-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:c7697918b5be27424e9ce568193efd13d925c4481dd364e43f5dff72d33e10f8"},
|
||||
{file = "pillow-12.1.1-cp314-cp314t-win32.whl", hash = "sha256:d2912fd8114fc5545aa3a4b5576512f64c55a03f3ebcca4c10194d593d43ea36"},
|
||||
{file = "pillow-12.1.1-cp314-cp314t-win_amd64.whl", hash = "sha256:4ceb838d4bd9dab43e06c363cab2eebf63846d6a4aeaea283bbdfd8f1a8ed58b"},
|
||||
{file = "pillow-12.1.1-cp314-cp314t-win_arm64.whl", hash = "sha256:7b03048319bfc6170e93bd60728a1af51d3dd7704935feb228c4d4faab35d334"},
|
||||
{file = "pillow-12.1.1-pp311-pypy311_pp73-macosx_10_15_x86_64.whl", hash = "sha256:600fd103672b925fe62ed08e0d874ea34d692474df6f4bf7ebe148b30f89f39f"},
|
||||
{file = "pillow-12.1.1-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:665e1b916b043cef294bc54d47bf02d87e13f769bc4bc5fa225a24b3a6c5aca9"},
|
||||
{file = "pillow-12.1.1-pp311-pypy311_pp73-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:495c302af3aad1ca67420ddd5c7bd480c8867ad173528767d906428057a11f0e"},
|
||||
{file = "pillow-12.1.1-pp311-pypy311_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:8fd420ef0c52c88b5a035a0886f367748c72147b2b8f384c9d12656678dfdfa9"},
|
||||
{file = "pillow-12.1.1-pp311-pypy311_pp73-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f975aa7ef9684ce7e2c18a3aa8f8e2106ce1e46b94ab713d156b2898811651d3"},
|
||||
{file = "pillow-12.1.1-pp311-pypy311_pp73-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8089c852a56c2966cf18835db62d9b34fef7ba74c726ad943928d494fa7f4735"},
|
||||
{file = "pillow-12.1.1-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:cb9bb857b2d057c6dfc72ac5f3b44836924ba15721882ef103cecb40d002d80e"},
|
||||
{file = "pillow-12.1.1.tar.gz", hash = "sha256:9ad8fa5937ab05218e2b6a4cff30295ad35afd2f83ac592e68c0d871bb0fdbc4"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
@@ -7323,23 +7315,23 @@ testing = ["google-api-core (>=1.31.5)"]
|
||||
|
||||
[[package]]
|
||||
name = "protobuf"
|
||||
version = "5.29.5"
|
||||
version = "5.29.6"
|
||||
description = ""
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "protobuf-5.29.5-cp310-abi3-win32.whl", hash = "sha256:3f1c6468a2cfd102ff4703976138844f78ebd1fb45f49011afc5139e9e283079"},
|
||||
{file = "protobuf-5.29.5-cp310-abi3-win_amd64.whl", hash = "sha256:3f76e3a3675b4a4d867b52e4a5f5b78a2ef9565549d4037e06cf7b0942b1d3fc"},
|
||||
{file = "protobuf-5.29.5-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:e38c5add5a311f2a6eb0340716ef9b039c1dfa428b28f25a7838ac329204a671"},
|
||||
{file = "protobuf-5.29.5-cp38-abi3-manylinux2014_aarch64.whl", hash = "sha256:fa18533a299d7ab6c55a238bf8629311439995f2e7eca5caaff08663606e9015"},
|
||||
{file = "protobuf-5.29.5-cp38-abi3-manylinux2014_x86_64.whl", hash = "sha256:63848923da3325e1bf7e9003d680ce6e14b07e55d0473253a690c3a8b8fd6e61"},
|
||||
{file = "protobuf-5.29.5-cp38-cp38-win32.whl", hash = "sha256:ef91363ad4faba7b25d844ef1ada59ff1604184c0bcd8b39b8a6bef15e1af238"},
|
||||
{file = "protobuf-5.29.5-cp38-cp38-win_amd64.whl", hash = "sha256:7318608d56b6402d2ea7704ff1e1e4597bee46d760e7e4dd42a3d45e24b87f2e"},
|
||||
{file = "protobuf-5.29.5-cp39-cp39-win32.whl", hash = "sha256:6f642dc9a61782fa72b90878af134c5afe1917c89a568cd3476d758d3c3a0736"},
|
||||
{file = "protobuf-5.29.5-cp39-cp39-win_amd64.whl", hash = "sha256:470f3af547ef17847a28e1f47200a1cbf0ba3ff57b7de50d22776607cd2ea353"},
|
||||
{file = "protobuf-5.29.5-py3-none-any.whl", hash = "sha256:6cf42630262c59b2d8de33954443d94b746c952b01434fc58a417fdbd2e84bd5"},
|
||||
{file = "protobuf-5.29.5.tar.gz", hash = "sha256:bc1463bafd4b0929216c35f437a8e28731a2b7fe3d98bb77a600efced5a15c84"},
|
||||
{file = "protobuf-5.29.6-cp310-abi3-win32.whl", hash = "sha256:62e8a3114992c7c647bce37dcc93647575fc52d50e48de30c6fcb28a6a291eb1"},
|
||||
{file = "protobuf-5.29.6-cp310-abi3-win_amd64.whl", hash = "sha256:7e6ad413275be172f67fdee0f43484b6de5a904cc1c3ea9804cb6fe2ff366eda"},
|
||||
{file = "protobuf-5.29.6-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:b5a169e664b4057183a34bdc424540e86eea47560f3c123a0d64de4e137f9269"},
|
||||
{file = "protobuf-5.29.6-cp38-abi3-manylinux2014_aarch64.whl", hash = "sha256:a8866b2cff111f0f863c1b3b9e7572dc7eaea23a7fae27f6fc613304046483e6"},
|
||||
{file = "protobuf-5.29.6-cp38-abi3-manylinux2014_x86_64.whl", hash = "sha256:e3387f44798ac1106af0233c04fb8abf543772ff241169946f698b3a9a3d3ab9"},
|
||||
{file = "protobuf-5.29.6-cp38-cp38-win32.whl", hash = "sha256:36ade6ff88212e91aef4e687a971a11d7d24d6948a66751abc1b3238648f5d05"},
|
||||
{file = "protobuf-5.29.6-cp38-cp38-win_amd64.whl", hash = "sha256:831e2da16b6cc9d8f1654c041dd594eda43391affd3c03a91bea7f7f6da106d6"},
|
||||
{file = "protobuf-5.29.6-cp39-cp39-win32.whl", hash = "sha256:cb4c86de9cd8a7f3a256b9744220d87b847371c6b2f10bde87768918ef33ba49"},
|
||||
{file = "protobuf-5.29.6-cp39-cp39-win_amd64.whl", hash = "sha256:76e07e6567f8baf827137e8d5b8204b6c7b6488bbbff1bf0a72b383f77999c18"},
|
||||
{file = "protobuf-5.29.6-py3-none-any.whl", hash = "sha256:6b9edb641441b2da9fa8f428760fc136a49cf97a52076010cf22a2ff73438a86"},
|
||||
{file = "protobuf-5.29.6.tar.gz", hash = "sha256:da9ee6a5424b6b30fd5e45c5ea663aef540ca95f9ad99d1e887e819cdf9b8723"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -7562,14 +7554,14 @@ typing-extensions = ">=4.15.0"
|
||||
|
||||
[[package]]
|
||||
name = "pyasn1"
|
||||
version = "0.6.1"
|
||||
version = "0.6.2"
|
||||
description = "Pure-Python implementation of ASN.1 types and DER/BER/CER codecs (X.208)"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "pyasn1-0.6.1-py3-none-any.whl", hash = "sha256:0d632f46f2ba09143da3a8afe9e33fb6f92fa2320ab7e886e2d0f7672af84629"},
|
||||
{file = "pyasn1-0.6.1.tar.gz", hash = "sha256:6f580d2bdd84365380830acf45550f2511469f673cb4a5ae3857a3170128b034"},
|
||||
{file = "pyasn1-0.6.2-py3-none-any.whl", hash = "sha256:1eb26d860996a18e9b6ed05e7aae0e9fc21619fcee6af91cca9bad4fbea224bf"},
|
||||
{file = "pyasn1-0.6.2.tar.gz", hash = "sha256:9b59a2b25ba7e4f8197db7686c09fb33e658b98339fadb826e9512629017833b"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -11578,20 +11570,20 @@ diagrams = ["jinja2", "railroad-diagrams"]
|
||||
|
||||
[[package]]
|
||||
name = "pypdf"
|
||||
version = "6.6.0"
|
||||
version = "6.7.3"
|
||||
description = "A pure-python PDF library capable of splitting, merging, cropping, and transforming PDF files"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "pypdf-6.6.0-py3-none-any.whl", hash = "sha256:bca9091ef6de36c7b1a81e09327c554b7ce51e88dad68f5890c2b4a4417f1fd7"},
|
||||
{file = "pypdf-6.6.0.tar.gz", hash = "sha256:4c887ef2ea38d86faded61141995a3c7d068c9d6ae8477be7ae5de8a8e16592f"},
|
||||
{file = "pypdf-6.7.3-py3-none-any.whl", hash = "sha256:cd25ac508f20b554a9fafd825186e3ba29591a69b78c156783c5d8a2d63a1c0a"},
|
||||
{file = "pypdf-6.7.3.tar.gz", hash = "sha256:eca55c78d0ec7baa06f9288e2be5c4e8242d5cbb62c7a4b94f2716f8e50076d2"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
crypto = ["cryptography"]
|
||||
cryptodome = ["PyCryptodome"]
|
||||
dev = ["black", "flit", "pip-tools", "pre-commit", "pytest-cov", "pytest-socket", "pytest-timeout", "pytest-xdist", "wheel"]
|
||||
dev = ["flit", "pip-tools", "pre-commit", "pytest-cov", "pytest-socket", "pytest-timeout", "pytest-xdist", "wheel"]
|
||||
docs = ["myst_parser", "sphinx", "sphinx_rtd_theme"]
|
||||
full = ["Pillow (>=8.0.0)", "cryptography"]
|
||||
image = ["Pillow (>=8.0.0)"]
|
||||
@@ -11886,14 +11878,14 @@ requests-toolbelt = ">=0.6.0"
|
||||
|
||||
[[package]]
|
||||
name = "python-multipart"
|
||||
version = "0.0.21"
|
||||
version = "0.0.22"
|
||||
description = "A streaming multipart parser for Python"
|
||||
optional = false
|
||||
python-versions = ">=3.10"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "python_multipart-0.0.21-py3-none-any.whl", hash = "sha256:cf7a6713e01c87aa35387f4774e812c4361150938d20d232800f75ffcf266090"},
|
||||
{file = "python_multipart-0.0.21.tar.gz", hash = "sha256:7137ebd4d3bbf70ea1622998f902b97a29434a9e8dc40eb203bbcf7c2a2cba92"},
|
||||
{file = "python_multipart-0.0.22-py3-none-any.whl", hash = "sha256:2b2cd894c83d21bf49d702499531c7bafd057d730c201782048f7945d82de155"},
|
||||
{file = "python_multipart-0.0.22.tar.gz", hash = "sha256:7340bef99a7e0032613f56dc36027b959fd3b30a787ed62d310e951f7c3a3a58"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -11916,14 +11908,14 @@ XlsxWriter = ">=0.5.7"
|
||||
|
||||
[[package]]
|
||||
name = "python-socketio"
|
||||
version = "5.13.0"
|
||||
version = "5.14.0"
|
||||
description = "Socket.IO server and client for Python"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "python_socketio-5.13.0-py3-none-any.whl", hash = "sha256:51f68d6499f2df8524668c24bcec13ba1414117cfb3a90115c559b601ab10caf"},
|
||||
{file = "python_socketio-5.13.0.tar.gz", hash = "sha256:ac4e19a0302ae812e23b712ec8b6427ca0521f7c582d6abb096e36e24a263029"},
|
||||
{file = "python_socketio-5.14.0-py3-none-any.whl", hash = "sha256:7de5ad8a55efc33e17897f6cf91d20168d3d259f98c38d38e2940af83136d6f8"},
|
||||
{file = "python_socketio-5.14.0.tar.gz", hash = "sha256:d057737f658b3948392ff452a5c865c5ccc969859c37cf095a73393ce755f98e"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -11969,7 +11961,7 @@ description = "Python for Window Extensions"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
groups = ["main"]
|
||||
markers = "sys_platform == \"win32\" or platform_system == \"Windows\""
|
||||
markers = "platform_system == \"Windows\" or sys_platform == \"win32\""
|
||||
files = [
|
||||
{file = "pywin32-311-cp310-cp310-win32.whl", hash = "sha256:d03ff496d2a0cd4a5893504789d4a15399133fe82517455e78bad62efbb7f0a3"},
|
||||
{file = "pywin32-311-cp310-cp310-win_amd64.whl", hash = "sha256:797c2772017851984b97180b0bebe4b620bb86328e8a884bb626156295a63b3b"},
|
||||
@@ -14917,4 +14909,4 @@ cffi = ["cffi (>=1.17,<2.0) ; platform_python_implementation != \"PyPy\" and pyt
|
||||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = "^3.12,<3.14"
|
||||
content-hash = "b5cbb1e25176845ac9f95650a802667e2f8be1a536e3e55a9269b5af5a42e3fc"
|
||||
content-hash = "1cad6029269393af67155e930c72eae2c03da02e4b3a3699823f6168c14a4218"
|
||||
|
||||
@@ -44,6 +44,12 @@ httpx = "*"
|
||||
scikit-learn = "^1.7.0"
|
||||
shap = "^0.48.0"
|
||||
google-cloud-recaptcha-enterprise = "^1.24.0"
|
||||
# Dependencies previously only in Dockerfile, now managed via poetry.lock
|
||||
prometheus-client = "^0.24.0"
|
||||
pandas = "^2.2.0"
|
||||
numpy = "^2.2.0"
|
||||
mcp = "^1.10.0"
|
||||
pillow = "^12.1.0"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
ruff = "0.8.3"
|
||||
|
||||
@@ -38,15 +38,28 @@ from server.routes.integration.linear import linear_integration_router # noqa:
|
||||
from server.routes.integration.slack import slack_router # noqa: E402
|
||||
from server.routes.mcp_patch import patch_mcp_server # noqa: E402
|
||||
from server.routes.oauth_device import oauth_device_router # noqa: E402
|
||||
from server.routes.org_invitations import ( # noqa: E402
|
||||
accept_router as invitation_accept_router,
|
||||
)
|
||||
from server.routes.org_invitations import ( # noqa: E402
|
||||
invitation_router,
|
||||
)
|
||||
from server.routes.orgs import org_router # noqa: E402
|
||||
from server.routes.readiness import readiness_router # noqa: E402
|
||||
from server.routes.user import saas_user_router # noqa: E402
|
||||
from server.routes.user_app_settings import user_app_settings_router # noqa: E402
|
||||
from server.sharing.shared_conversation_router import ( # noqa: E402
|
||||
router as shared_conversation_router,
|
||||
)
|
||||
from server.sharing.shared_event_router import ( # noqa: E402
|
||||
router as shared_event_router,
|
||||
)
|
||||
from server.verified_models.verified_model_router import ( # noqa: E402
|
||||
api_router as verified_models_router,
|
||||
)
|
||||
from server.verified_models.verified_model_router import ( # noqa: E402
|
||||
override_llm_models_dependency,
|
||||
)
|
||||
|
||||
from openhands.server.app import app as base_app # noqa: E402
|
||||
from openhands.server.listen_socket import sio # noqa: E402
|
||||
@@ -70,6 +83,7 @@ base_app.include_router(api_router) # Add additional route for github auth
|
||||
base_app.include_router(oauth_router) # Add additional route for oauth callback
|
||||
base_app.include_router(oauth_device_router) # Add OAuth 2.0 Device Flow routes
|
||||
base_app.include_router(saas_user_router) # Add additional route SAAS user calls
|
||||
base_app.include_router(user_app_settings_router) # Add routes for user app settings
|
||||
base_app.include_router(
|
||||
billing_router
|
||||
) # Add routes for credit management and Stripe payment integration
|
||||
@@ -78,8 +92,15 @@ base_app.include_router(shared_event_router)
|
||||
|
||||
# Add GitHub integration router only if GITHUB_APP_CLIENT_ID is set
|
||||
if GITHUB_APP_CLIENT_ID:
|
||||
# Make sure that the callback processor is loaded here so we don't get an error when deserializing
|
||||
from integrations.github.github_v1_callback_processor import ( # noqa: E402
|
||||
GithubV1CallbackProcessor,
|
||||
)
|
||||
from server.routes.integration.github import github_integration_router # noqa: E402
|
||||
|
||||
# Bludgeon mypy into not deleting my import
|
||||
logger.debug(f'Loaded {GithubV1CallbackProcessor.__name__}')
|
||||
|
||||
base_app.include_router(
|
||||
github_integration_router
|
||||
) # Add additional route for integration webhook events
|
||||
@@ -92,6 +113,16 @@ if GITLAB_APP_CLIENT_ID:
|
||||
|
||||
base_app.include_router(api_keys_router) # Add routes for API key management
|
||||
base_app.include_router(org_router) # Add routes for organization management
|
||||
base_app.include_router(
|
||||
verified_models_router
|
||||
) # Add routes for verified models management
|
||||
|
||||
# Override the default LLM models implementation with SaaS version
|
||||
# This must happen after all routers are included
|
||||
override_llm_models_dependency(base_app)
|
||||
|
||||
base_app.include_router(invitation_router) # Add routes for org invitation management
|
||||
base_app.include_router(invitation_accept_router) # Add route for accepting invitations
|
||||
add_github_proxy_routes(base_app)
|
||||
add_debugging_routes(
|
||||
base_app
|
||||
|
||||
@@ -38,3 +38,9 @@ class ExpiredError(AuthError):
|
||||
"""Error when a token has expired (Usually the refresh token)"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class TokenRefreshError(AuthError):
|
||||
"""Error when token refresh fails due to timeout or lock contention"""
|
||||
|
||||
pass
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
import os
|
||||
|
||||
from server.auth.sheets_client import GoogleSheetsClient
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
|
||||
@@ -9,12 +7,9 @@ class UserVerifier:
|
||||
def __init__(self) -> None:
|
||||
logger.debug('Initializing UserVerifier')
|
||||
self.file_users: list[str] | None = None
|
||||
self.sheets_client: GoogleSheetsClient | None = None
|
||||
self.spreadsheet_id: str | None = None
|
||||
|
||||
# Initialize from environment variables
|
||||
self._init_file_users()
|
||||
self._init_sheets_client()
|
||||
|
||||
def _init_file_users(self) -> None:
|
||||
"""Load users from text file if configured."""
|
||||
@@ -36,23 +31,11 @@ class UserVerifier:
|
||||
except Exception:
|
||||
logger.exception(f'Error reading user list file {waitlist}')
|
||||
|
||||
def _init_sheets_client(self) -> None:
|
||||
"""Initialize Google Sheets client if configured."""
|
||||
sheet_id = os.getenv('GITHUB_USERS_SHEET_ID')
|
||||
|
||||
if not sheet_id:
|
||||
logger.debug('GITHUB_USERS_SHEET_ID not configured')
|
||||
return
|
||||
|
||||
logger.debug('Initializing Google Sheets integration')
|
||||
self.sheets_client = GoogleSheetsClient()
|
||||
self.spreadsheet_id = sheet_id
|
||||
|
||||
def is_active(self) -> bool:
|
||||
if os.getenv('DISABLE_WAITLIST', '').lower() == 'true':
|
||||
logger.info('Waitlist disabled via DISABLE_WAITLIST env var')
|
||||
return False
|
||||
return bool(self.file_users or (self.sheets_client and self.spreadsheet_id))
|
||||
return bool(self.file_users)
|
||||
|
||||
def is_user_allowed(self, username: str) -> bool:
|
||||
"""Check if user is allowed based on file and/or sheet configuration."""
|
||||
@@ -63,15 +46,6 @@ class UserVerifier:
|
||||
return True
|
||||
logger.debug(f'User {username} not found in text file allowlist')
|
||||
|
||||
if self.sheets_client and self.spreadsheet_id:
|
||||
sheet_users = [
|
||||
u.lower() for u in self.sheets_client.get_usernames(self.spreadsheet_id)
|
||||
]
|
||||
if username.lower() in sheet_users:
|
||||
logger.debug(f'User {username} found in Google Sheets allowlist')
|
||||
return True
|
||||
logger.debug(f'User {username} not found in Google Sheets allowlist')
|
||||
|
||||
logger.debug(f'User {username} not found in any allowlist')
|
||||
return False
|
||||
|
||||
|
||||
306
enterprise/server/auth/authorization.py
Normal file
306
enterprise/server/auth/authorization.py
Normal file
@@ -0,0 +1,306 @@
|
||||
"""
|
||||
Permission-based authorization dependencies for API endpoints.
|
||||
|
||||
This module provides FastAPI dependencies for checking user permissions
|
||||
within organizations. It uses a permission-based authorization model where
|
||||
roles (owner, admin, member) are mapped to specific permissions.
|
||||
|
||||
Permissions are defined in the Permission enum and mapped to roles via
|
||||
ROLE_PERMISSIONS. This allows fine-grained access control while maintaining
|
||||
the familiar role-based hierarchy.
|
||||
|
||||
Usage:
|
||||
from server.auth.authorization import (
|
||||
Permission,
|
||||
require_permission,
|
||||
)
|
||||
|
||||
@router.get('/{org_id}/settings')
|
||||
async def get_settings(
|
||||
org_id: UUID,
|
||||
user_id: str = Depends(require_permission(Permission.VIEW_LLM_SETTINGS)),
|
||||
):
|
||||
# Only users with VIEW_LLM_SETTINGS permission can access
|
||||
...
|
||||
|
||||
@router.patch('/{org_id}/settings')
|
||||
async def update_settings(
|
||||
org_id: UUID,
|
||||
user_id: str = Depends(require_permission(Permission.EDIT_LLM_SETTINGS)),
|
||||
):
|
||||
# Only users with EDIT_LLM_SETTINGS permission can access
|
||||
...
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from storage.org_member_store import OrgMemberStore
|
||||
from storage.role import Role
|
||||
from storage.role_store import RoleStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.server.user_auth import get_user_id
|
||||
|
||||
|
||||
class Permission(str, Enum):
|
||||
"""Permissions that can be assigned to roles."""
|
||||
|
||||
# Secrets
|
||||
MANAGE_SECRETS = 'manage_secrets'
|
||||
|
||||
# MCP
|
||||
MANAGE_MCP = 'manage_mcp'
|
||||
|
||||
# Integrations
|
||||
MANAGE_INTEGRATIONS = 'manage_integrations'
|
||||
|
||||
# Application Settings
|
||||
MANAGE_APPLICATION_SETTINGS = 'manage_application_settings'
|
||||
|
||||
# API Keys
|
||||
MANAGE_API_KEYS = 'manage_api_keys'
|
||||
|
||||
# LLM Settings
|
||||
VIEW_LLM_SETTINGS = 'view_llm_settings'
|
||||
EDIT_LLM_SETTINGS = 'edit_llm_settings'
|
||||
|
||||
# Billing
|
||||
VIEW_BILLING = 'view_billing'
|
||||
ADD_CREDITS = 'add_credits'
|
||||
|
||||
# Organization Members
|
||||
INVITE_USER_TO_ORGANIZATION = 'invite_user_to_organization'
|
||||
CHANGE_USER_ROLE_MEMBER = 'change_user_role:member'
|
||||
CHANGE_USER_ROLE_ADMIN = 'change_user_role:admin'
|
||||
CHANGE_USER_ROLE_OWNER = 'change_user_role:owner'
|
||||
|
||||
# Organization Management
|
||||
VIEW_ORG_SETTINGS = 'view_org_settings'
|
||||
CHANGE_ORGANIZATION_NAME = 'change_organization_name'
|
||||
DELETE_ORGANIZATION = 'delete_organization'
|
||||
|
||||
# Temporary permissions until we finish the API updates.
|
||||
EDIT_ORG_SETTINGS = 'edit_org_settings'
|
||||
|
||||
|
||||
class RoleName(str, Enum):
|
||||
"""Role names used in the system."""
|
||||
|
||||
OWNER = 'owner'
|
||||
ADMIN = 'admin'
|
||||
MEMBER = 'member'
|
||||
|
||||
|
||||
# Permission mappings for each role
|
||||
ROLE_PERMISSIONS: dict[RoleName, frozenset[Permission]] = {
|
||||
RoleName.OWNER: frozenset(
|
||||
[
|
||||
# Settings (Full access)
|
||||
Permission.MANAGE_SECRETS,
|
||||
Permission.MANAGE_MCP,
|
||||
Permission.MANAGE_INTEGRATIONS,
|
||||
Permission.MANAGE_APPLICATION_SETTINGS,
|
||||
Permission.MANAGE_API_KEYS,
|
||||
Permission.VIEW_LLM_SETTINGS,
|
||||
Permission.EDIT_LLM_SETTINGS,
|
||||
Permission.VIEW_BILLING,
|
||||
Permission.ADD_CREDITS,
|
||||
# Organization Members
|
||||
Permission.INVITE_USER_TO_ORGANIZATION,
|
||||
Permission.CHANGE_USER_ROLE_MEMBER,
|
||||
Permission.CHANGE_USER_ROLE_ADMIN,
|
||||
Permission.CHANGE_USER_ROLE_OWNER,
|
||||
# Organization Management
|
||||
Permission.VIEW_ORG_SETTINGS,
|
||||
Permission.EDIT_ORG_SETTINGS,
|
||||
# Organization Management (Owner only)
|
||||
Permission.CHANGE_ORGANIZATION_NAME,
|
||||
Permission.DELETE_ORGANIZATION,
|
||||
]
|
||||
),
|
||||
RoleName.ADMIN: frozenset(
|
||||
[
|
||||
# Settings (Full access)
|
||||
Permission.MANAGE_SECRETS,
|
||||
Permission.MANAGE_MCP,
|
||||
Permission.MANAGE_INTEGRATIONS,
|
||||
Permission.MANAGE_APPLICATION_SETTINGS,
|
||||
Permission.MANAGE_API_KEYS,
|
||||
Permission.VIEW_LLM_SETTINGS,
|
||||
Permission.EDIT_LLM_SETTINGS,
|
||||
Permission.VIEW_BILLING,
|
||||
Permission.ADD_CREDITS,
|
||||
# Organization Members
|
||||
Permission.INVITE_USER_TO_ORGANIZATION,
|
||||
Permission.CHANGE_USER_ROLE_MEMBER,
|
||||
Permission.CHANGE_USER_ROLE_ADMIN,
|
||||
# Organization Management
|
||||
Permission.VIEW_ORG_SETTINGS,
|
||||
Permission.EDIT_ORG_SETTINGS,
|
||||
]
|
||||
),
|
||||
RoleName.MEMBER: frozenset(
|
||||
[
|
||||
# Settings (Full access)
|
||||
Permission.MANAGE_SECRETS,
|
||||
Permission.MANAGE_MCP,
|
||||
Permission.MANAGE_INTEGRATIONS,
|
||||
Permission.MANAGE_APPLICATION_SETTINGS,
|
||||
Permission.MANAGE_API_KEYS,
|
||||
# Settings (View only)
|
||||
Permission.VIEW_ORG_SETTINGS,
|
||||
Permission.VIEW_LLM_SETTINGS,
|
||||
]
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def get_user_org_role(user_id: str, org_id: UUID | None) -> Role | None:
|
||||
"""
|
||||
Get the user's role in an organization (synchronous version).
|
||||
|
||||
Args:
|
||||
user_id: User ID (string that will be converted to UUID)
|
||||
org_id: Organization ID, or None to use the user's current organization
|
||||
|
||||
Returns:
|
||||
Role object if user is a member, None otherwise
|
||||
"""
|
||||
from uuid import UUID as parse_uuid
|
||||
|
||||
if org_id is None:
|
||||
org_member = OrgMemberStore.get_org_member_for_current_org(parse_uuid(user_id))
|
||||
else:
|
||||
org_member = OrgMemberStore.get_org_member(org_id, parse_uuid(user_id))
|
||||
if not org_member:
|
||||
return None
|
||||
|
||||
return RoleStore.get_role_by_id(org_member.role_id)
|
||||
|
||||
|
||||
async def get_user_org_role_async(user_id: str, org_id: UUID | None) -> Role | None:
|
||||
"""
|
||||
Get the user's role in an organization (async version).
|
||||
|
||||
Args:
|
||||
user_id: User ID (string that will be converted to UUID)
|
||||
org_id: Organization ID, or None to use the user's current organization
|
||||
|
||||
Returns:
|
||||
Role object if user is a member, None otherwise
|
||||
"""
|
||||
from uuid import UUID as parse_uuid
|
||||
|
||||
if org_id is None:
|
||||
org_member = await OrgMemberStore.get_org_member_for_current_org_async(
|
||||
parse_uuid(user_id)
|
||||
)
|
||||
else:
|
||||
org_member = await OrgMemberStore.get_org_member_async(
|
||||
org_id, parse_uuid(user_id)
|
||||
)
|
||||
if not org_member:
|
||||
return None
|
||||
|
||||
return await RoleStore.get_role_by_id_async(org_member.role_id)
|
||||
|
||||
|
||||
def get_role_permissions(role_name: str) -> frozenset[Permission]:
|
||||
"""
|
||||
Get the permissions for a role.
|
||||
|
||||
Args:
|
||||
role_name: Name of the role
|
||||
|
||||
Returns:
|
||||
Set of permissions for the role
|
||||
"""
|
||||
try:
|
||||
role_enum = RoleName(role_name)
|
||||
return ROLE_PERMISSIONS.get(role_enum, frozenset())
|
||||
except ValueError:
|
||||
return frozenset()
|
||||
|
||||
|
||||
def has_permission(user_role: Role, permission: Permission) -> bool:
|
||||
"""
|
||||
Check if a role has a specific permission.
|
||||
|
||||
Args:
|
||||
user_role: User's Role object
|
||||
permission: Permission to check
|
||||
|
||||
Returns:
|
||||
True if the role has the permission
|
||||
"""
|
||||
permissions = get_role_permissions(user_role.name)
|
||||
return permission in permissions
|
||||
|
||||
|
||||
def require_permission(permission: Permission):
|
||||
"""
|
||||
Factory function that creates a dependency to require a specific permission.
|
||||
|
||||
This creates a FastAPI dependency that:
|
||||
1. Extracts org_id from the path parameter
|
||||
2. Gets the authenticated user_id
|
||||
3. Checks if the user has the required permission in the organization
|
||||
4. Returns the user_id if authorized, raises HTTPException otherwise
|
||||
|
||||
Usage:
|
||||
@router.get('/{org_id}/settings')
|
||||
async def get_settings(
|
||||
org_id: UUID,
|
||||
user_id: str = Depends(require_permission(Permission.VIEW_LLM_SETTINGS)),
|
||||
):
|
||||
...
|
||||
|
||||
Args:
|
||||
permission: The permission required to access the endpoint
|
||||
|
||||
Returns:
|
||||
Dependency function that validates permission and returns user_id
|
||||
"""
|
||||
|
||||
async def permission_checker(
|
||||
org_id: UUID | None = None,
|
||||
user_id: str | None = Depends(get_user_id),
|
||||
) -> str:
|
||||
if not user_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail='User not authenticated',
|
||||
)
|
||||
|
||||
user_role = await get_user_org_role_async(user_id, org_id)
|
||||
|
||||
if not user_role:
|
||||
logger.warning(
|
||||
'User not a member of organization',
|
||||
extra={'user_id': user_id, 'org_id': str(org_id)},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail='User is not a member of this organization',
|
||||
)
|
||||
|
||||
if not has_permission(user_role, permission):
|
||||
logger.warning(
|
||||
'Insufficient permissions',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
'user_role': user_role.name,
|
||||
'required_permission': permission.value,
|
||||
},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f'Requires {permission.value} permission',
|
||||
)
|
||||
|
||||
return user_id
|
||||
|
||||
return permission_checker
|
||||
@@ -1,87 +1,11 @@
|
||||
import os
|
||||
|
||||
from integrations.github.github_service import SaaSGitHubService
|
||||
from pydantic import SecretStr
|
||||
from server.auth.sheets_client import GoogleSheetsClient
|
||||
|
||||
from enterprise.server.auth.auth_utils import user_verifier
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.github.github_types import GitHubUser
|
||||
|
||||
|
||||
class UserVerifier:
|
||||
def __init__(self) -> None:
|
||||
logger.debug('Initializing UserVerifier')
|
||||
self.file_users: list[str] | None = None
|
||||
self.sheets_client: GoogleSheetsClient | None = None
|
||||
self.spreadsheet_id: str | None = None
|
||||
|
||||
# Initialize from environment variables
|
||||
self._init_file_users()
|
||||
self._init_sheets_client()
|
||||
|
||||
def _init_file_users(self) -> None:
|
||||
"""Load users from text file if configured"""
|
||||
waitlist = os.getenv('GITHUB_USER_LIST_FILE')
|
||||
if not waitlist:
|
||||
logger.debug('GITHUB_USER_LIST_FILE not configured')
|
||||
return
|
||||
|
||||
if not os.path.exists(waitlist):
|
||||
logger.error(f'User list file not found: {waitlist}')
|
||||
raise FileNotFoundError(f'User list file not found: {waitlist}')
|
||||
|
||||
try:
|
||||
with open(waitlist, 'r') as f:
|
||||
self.file_users = [line.strip().lower() for line in f if line.strip()]
|
||||
logger.info(
|
||||
f'Successfully loaded {len(self.file_users)} users from {waitlist}'
|
||||
)
|
||||
except Exception:
|
||||
logger.error(f'Error reading user list file {waitlist}', exc_info=True)
|
||||
|
||||
def _init_sheets_client(self) -> None:
|
||||
"""Initialize Google Sheets client if configured"""
|
||||
sheet_id = os.getenv('GITHUB_USERS_SHEET_ID')
|
||||
|
||||
if not sheet_id:
|
||||
logger.debug('GITHUB_USERS_SHEET_ID not configured')
|
||||
return
|
||||
|
||||
logger.debug('Initializing Google Sheets integration')
|
||||
self.sheets_client = GoogleSheetsClient()
|
||||
self.spreadsheet_id = sheet_id
|
||||
|
||||
def is_active(self) -> bool:
|
||||
if os.getenv('DISABLE_WAITLIST', '').lower() == 'true':
|
||||
logger.info('Waitlist disabled via DISABLE_WAITLIST env var')
|
||||
return False
|
||||
return bool(self.file_users or (self.sheets_client and self.spreadsheet_id))
|
||||
|
||||
def is_user_allowed(self, username: str) -> bool:
|
||||
"""Check if user is allowed based on file and/or sheet configuration"""
|
||||
logger.debug(f'Checking if GitHub user {username} is allowed')
|
||||
if self.file_users:
|
||||
if username.lower() in self.file_users:
|
||||
logger.debug(f'User {username} found in text file allowlist')
|
||||
return True
|
||||
logger.debug(f'User {username} not found in text file allowlist')
|
||||
|
||||
if self.sheets_client and self.spreadsheet_id:
|
||||
sheet_users = [
|
||||
u.lower() for u in self.sheets_client.get_usernames(self.spreadsheet_id)
|
||||
]
|
||||
if username.lower() in sheet_users:
|
||||
logger.debug(f'User {username} found in Google Sheets allowlist')
|
||||
return True
|
||||
logger.debug(f'User {username} not found in Google Sheets allowlist')
|
||||
|
||||
logger.debug(f'User {username} not found in any allowlist')
|
||||
return False
|
||||
|
||||
|
||||
user_verifier = UserVerifier()
|
||||
|
||||
|
||||
def is_user_allowed(user_login: str):
|
||||
if user_verifier.is_active() and not user_verifier.is_user_allowed(user_login):
|
||||
logger.warning(f'GitHub user {user_login} not in allow list')
|
||||
|
||||
@@ -18,6 +18,7 @@ from openhands.core.logger import openhands_logger as logger
|
||||
class AssessmentResult:
|
||||
"""Result of a reCAPTCHA Enterprise assessment."""
|
||||
|
||||
name: str
|
||||
score: float
|
||||
valid: bool
|
||||
action_valid: bool
|
||||
@@ -63,6 +64,7 @@ class RecaptchaService:
|
||||
user_ip: str,
|
||||
user_agent: str,
|
||||
email: str | None = None,
|
||||
user_id: str | None = None,
|
||||
) -> AssessmentResult:
|
||||
"""Create a reCAPTCHA Enterprise assessment.
|
||||
|
||||
@@ -72,6 +74,7 @@ class RecaptchaService:
|
||||
user_ip: The user's IP address.
|
||||
user_agent: The user's browser user agent.
|
||||
email: Optional email for Account Defender hashing.
|
||||
user_id: Optional Keycloak user ID for logging correlation.
|
||||
|
||||
Returns:
|
||||
AssessmentResult with score, validity, and allowed status.
|
||||
@@ -100,6 +103,10 @@ class RecaptchaService:
|
||||
|
||||
response = self.client.create_assessment(request)
|
||||
|
||||
# Capture assessment name for potential annotation later
|
||||
# Format: projects/{project_id}/assessments/{assessment_id}
|
||||
assessment_name = response.name
|
||||
|
||||
token_properties = response.token_properties
|
||||
risk_analysis = response.risk_analysis
|
||||
|
||||
@@ -129,6 +136,7 @@ class RecaptchaService:
|
||||
logger.info(
|
||||
'recaptcha_assessment',
|
||||
extra={
|
||||
'assessment_name': assessment_name,
|
||||
'score': score,
|
||||
'valid': valid,
|
||||
'action_valid': action_valid,
|
||||
@@ -137,10 +145,13 @@ class RecaptchaService:
|
||||
'has_suspicious_labels': has_suspicious_labels,
|
||||
'allowed': allowed,
|
||||
'user_ip': user_ip,
|
||||
'user_id': user_id,
|
||||
'email': email,
|
||||
},
|
||||
)
|
||||
|
||||
return AssessmentResult(
|
||||
name=assessment_name,
|
||||
score=score,
|
||||
valid=valid,
|
||||
action_valid=action_valid,
|
||||
|
||||
@@ -49,6 +49,10 @@ from openhands.integrations.service_types import ProviderType
|
||||
from openhands.server.types import SessionExpiredError
|
||||
from openhands.utils.http_session import httpx_verify_option
|
||||
|
||||
# HTTP timeout for external IDP calls (in seconds)
|
||||
# This prevents indefinite blocking if an IDP is slow or unresponsive
|
||||
IDP_HTTP_TIMEOUT = 15.0
|
||||
|
||||
|
||||
def _before_sleep_callback(retry_state: RetryCallState) -> None:
|
||||
logger.info(f'Retry attempt {retry_state.attempt_number} for Keycloak operation')
|
||||
@@ -202,7 +206,9 @@ class TokenManager:
|
||||
access_token: str,
|
||||
idp: ProviderType,
|
||||
) -> dict[str, str | int]:
|
||||
async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
|
||||
async with httpx.AsyncClient(
|
||||
verify=httpx_verify_option(), timeout=IDP_HTTP_TIMEOUT
|
||||
) as client:
|
||||
base_url = KEYCLOAK_SERVER_URL_EXT if self.external else KEYCLOAK_SERVER_URL
|
||||
url = f'{base_url}/realms/{KEYCLOAK_REALM_NAME}/broker/{idp.value}/token'
|
||||
headers = {
|
||||
@@ -361,7 +367,9 @@ class TokenManager:
|
||||
'refresh_token': refresh_token,
|
||||
'grant_type': 'refresh_token',
|
||||
}
|
||||
async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
|
||||
async with httpx.AsyncClient(
|
||||
verify=httpx_verify_option(), timeout=IDP_HTTP_TIMEOUT
|
||||
) as client:
|
||||
response = await client.post(url, data=payload)
|
||||
response.raise_for_status()
|
||||
logger.info('Successfully refreshed GitHub token')
|
||||
@@ -387,7 +395,9 @@ class TokenManager:
|
||||
'refresh_token': refresh_token,
|
||||
'grant_type': 'refresh_token',
|
||||
}
|
||||
async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
|
||||
async with httpx.AsyncClient(
|
||||
verify=httpx_verify_option(), timeout=IDP_HTTP_TIMEOUT
|
||||
) as client:
|
||||
response = await client.post(url, data=payload)
|
||||
response.raise_for_status()
|
||||
logger.info('Successfully refreshed GitLab token')
|
||||
@@ -415,7 +425,9 @@ class TokenManager:
|
||||
'refresh_token': refresh_token,
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
|
||||
async with httpx.AsyncClient(
|
||||
verify=httpx_verify_option(), timeout=IDP_HTTP_TIMEOUT
|
||||
) as client:
|
||||
response = await client.post(url, data=data, headers=headers)
|
||||
response.raise_for_status()
|
||||
logger.info('Successfully refreshed Bitbucket token')
|
||||
|
||||
@@ -15,6 +15,11 @@ IS_FEATURE_ENV = (
|
||||
) # Does not include the staging deployment
|
||||
IS_LOCAL_ENV = bool(HOST == 'localhost')
|
||||
|
||||
# Role name constants
|
||||
ROLE_OWNER = 'owner'
|
||||
ROLE_ADMIN = 'admin'
|
||||
ROLE_MEMBER = 'member'
|
||||
|
||||
# Deprecated - billing margins are now handled internally in litellm
|
||||
DEFAULT_BILLING_MARGIN = float(os.environ.get('DEFAULT_BILLING_MARGIN', '1.0'))
|
||||
|
||||
@@ -25,7 +30,9 @@ PERSONAL_WORKSPACE_VERSION_TO_MODEL = {
|
||||
2: 'claude-3-7-sonnet-20250219',
|
||||
3: 'claude-sonnet-4-20250514',
|
||||
4: 'claude-sonnet-4-20250514',
|
||||
5: 'claude-opus-4-5-20251101',
|
||||
# Minimax is now the default as it gives results close to claude in terms of quality
|
||||
# but at a much lower price
|
||||
5: 'minimax-m2.5',
|
||||
}
|
||||
|
||||
LITELLM_DEFAULT_MODEL = os.getenv('LITELLM_DEFAULT_MODEL')
|
||||
@@ -54,7 +61,6 @@ SUBSCRIPTION_PRICE_DATA = {
|
||||
},
|
||||
}
|
||||
|
||||
DEFAULT_INITIAL_BUDGET = float(os.environ.get('DEFAULT_INITIAL_BUDGET', '10'))
|
||||
STRIPE_API_KEY = os.environ.get('STRIPE_API_KEY', None)
|
||||
REQUIRE_PAYMENT = os.environ.get('REQUIRE_PAYMENT', '0') in ('1', 'true')
|
||||
|
||||
|
||||
@@ -51,6 +51,14 @@ def custom_json_serializer(obj, **kwargs):
|
||||
obj['stack_info'] = format_stack(stack_info)
|
||||
|
||||
result = json.dumps(obj, **kwargs)
|
||||
|
||||
# Swap out newlines to make things easier to read. This will produce
|
||||
# invalid json but means we can have similar logs in local development
|
||||
# to production, making things easier to correlate. Obviously,
|
||||
# LOG_JSON_FOR_CONSOLE should not be used in production environments.
|
||||
if LOG_JSON_FOR_CONSOLE:
|
||||
result = result.replace('\\n', '\n')
|
||||
|
||||
return result
|
||||
|
||||
|
||||
|
||||
@@ -103,11 +103,13 @@ class SetAuthCookieMiddleware:
|
||||
keycloak_auth_cookie = request.cookies.get('keycloak_auth')
|
||||
auth_header = request.headers.get('Authorization')
|
||||
mcp_auth_header = request.headers.get('X-Session-API-Key')
|
||||
api_auth_header = request.headers.get('X-Access-Token')
|
||||
accepted_tos: bool | None = False
|
||||
if (
|
||||
keycloak_auth_cookie is None
|
||||
and (auth_header is None or not auth_header.startswith('Bearer '))
|
||||
and mcp_auth_header is None
|
||||
and api_auth_header is None
|
||||
):
|
||||
raise NoCredentialsError
|
||||
|
||||
@@ -160,10 +162,10 @@ class SetAuthCookieMiddleware:
|
||||
'/api/billing/customer-setup-success',
|
||||
'/api/billing/stripe-webhook',
|
||||
'/api/email/resend',
|
||||
'/api/organizations/members/invite/accept',
|
||||
'/oauth/device/authorize',
|
||||
'/oauth/device/token',
|
||||
'/api/v1/web-client/config',
|
||||
'/api/v1/webhooks/secrets',
|
||||
)
|
||||
if path in ignore_paths:
|
||||
return False
|
||||
@@ -174,6 +176,10 @@ class SetAuthCookieMiddleware:
|
||||
):
|
||||
return False
|
||||
|
||||
# Webhooks access is controlled using separate API keys
|
||||
if path.startswith('/api/v1/webhooks/'):
|
||||
return False
|
||||
|
||||
is_mcp = path.startswith('/mcp')
|
||||
is_api_route = path.startswith('/api')
|
||||
return is_api_route or is_mcp
|
||||
|
||||
@@ -2,10 +2,12 @@ from datetime import UTC, datetime
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from pydantic import BaseModel, field_validator
|
||||
from storage.api_key import ApiKey
|
||||
from storage.api_key_store import ApiKeyStore
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
from storage.org_member import OrgMember
|
||||
from storage.org_member_store import OrgMemberStore
|
||||
from storage.org_service import OrgService
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
@@ -52,7 +54,6 @@ async def store_byor_key_in_db(user_id: str, key: str) -> None:
|
||||
|
||||
async def generate_byor_key(user_id: str) -> str | None:
|
||||
"""Generate a new BYOR key for a user."""
|
||||
|
||||
try:
|
||||
user = await UserStore.get_user_by_id_async(user_id)
|
||||
if not user:
|
||||
@@ -135,9 +136,9 @@ class ApiKeyCreate(BaseModel):
|
||||
class ApiKeyResponse(BaseModel):
|
||||
id: int
|
||||
name: str | None = None
|
||||
created_at: str
|
||||
last_used_at: str | None = None
|
||||
expires_at: str | None = None
|
||||
created_at: datetime
|
||||
last_used_at: datetime | None = None
|
||||
expires_at: datetime | None = None
|
||||
|
||||
|
||||
class ApiKeyCreateResponse(ApiKeyResponse):
|
||||
@@ -148,8 +149,47 @@ class LlmApiKeyResponse(BaseModel):
|
||||
key: str | None
|
||||
|
||||
|
||||
@api_router.post('', response_model=ApiKeyCreateResponse)
|
||||
async def create_api_key(key_data: ApiKeyCreate, user_id: str = Depends(get_user_id)):
|
||||
class ByorPermittedResponse(BaseModel):
|
||||
permitted: bool
|
||||
|
||||
|
||||
class MessageResponse(BaseModel):
|
||||
message: str
|
||||
|
||||
|
||||
def api_key_to_response(key: ApiKey) -> ApiKeyResponse:
|
||||
"""Convert an ApiKey model to an ApiKeyResponse."""
|
||||
return ApiKeyResponse(
|
||||
id=key.id,
|
||||
name=key.name,
|
||||
created_at=key.created_at,
|
||||
last_used_at=key.last_used_at,
|
||||
expires_at=key.expires_at,
|
||||
)
|
||||
|
||||
|
||||
@api_router.get('/llm/byor/permitted', tags=['Keys'])
|
||||
async def check_byor_permitted(
|
||||
user_id: str = Depends(get_user_id),
|
||||
) -> ByorPermittedResponse:
|
||||
"""Check if BYOR key export is permitted for the user's current org."""
|
||||
try:
|
||||
permitted = await OrgService.check_byor_export_enabled(user_id)
|
||||
return ByorPermittedResponse(permitted=permitted)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'Error checking BYOR export permission', extra={'error': str(e)}
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to check BYOR export permission',
|
||||
)
|
||||
|
||||
|
||||
@api_router.post('', tags=['Keys'])
|
||||
async def create_api_key(
|
||||
key_data: ApiKeyCreate, user_id: str = Depends(get_user_id)
|
||||
) -> ApiKeyCreateResponse:
|
||||
"""Create a new API key for the authenticated user."""
|
||||
try:
|
||||
api_key = await api_key_store.create_api_key(
|
||||
@@ -158,48 +198,29 @@ async def create_api_key(key_data: ApiKeyCreate, user_id: str = Depends(get_user
|
||||
# Get the created key details
|
||||
keys = await api_key_store.list_api_keys(user_id)
|
||||
for key in keys:
|
||||
if key['name'] == key_data.name:
|
||||
return {
|
||||
**key,
|
||||
'key': api_key,
|
||||
'created_at': (
|
||||
key['created_at'].isoformat() if key['created_at'] else None
|
||||
),
|
||||
'last_used_at': (
|
||||
key['last_used_at'].isoformat() if key['last_used_at'] else None
|
||||
),
|
||||
'expires_at': (
|
||||
key['expires_at'].isoformat() if key['expires_at'] else None
|
||||
),
|
||||
}
|
||||
if key.name == key_data.name:
|
||||
return ApiKeyCreateResponse(
|
||||
id=key.id,
|
||||
name=key.name,
|
||||
key=api_key,
|
||||
created_at=key.created_at,
|
||||
last_used_at=key.last_used_at,
|
||||
expires_at=key.expires_at,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception('Error creating API key')
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to create API key',
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to create API key',
|
||||
)
|
||||
|
||||
|
||||
@api_router.get('', response_model=list[ApiKeyResponse])
|
||||
async def list_api_keys(user_id: str = Depends(get_user_id)):
|
||||
@api_router.get('', tags=['Keys'])
|
||||
async def list_api_keys(user_id: str = Depends(get_user_id)) -> list[ApiKeyResponse]:
|
||||
"""List all API keys for the authenticated user."""
|
||||
try:
|
||||
keys = await api_key_store.list_api_keys(user_id)
|
||||
return [
|
||||
{
|
||||
**key,
|
||||
'created_at': (
|
||||
key['created_at'].isoformat() if key['created_at'] else None
|
||||
),
|
||||
'last_used_at': (
|
||||
key['last_used_at'].isoformat() if key['last_used_at'] else None
|
||||
),
|
||||
'expires_at': (
|
||||
key['expires_at'].isoformat() if key['expires_at'] else None
|
||||
),
|
||||
}
|
||||
for key in keys
|
||||
]
|
||||
return [api_key_to_response(key) for key in keys]
|
||||
except Exception:
|
||||
logger.exception('Error listing API keys')
|
||||
raise HTTPException(
|
||||
@@ -208,8 +229,10 @@ async def list_api_keys(user_id: str = Depends(get_user_id)):
|
||||
)
|
||||
|
||||
|
||||
@api_router.delete('/{key_id}')
|
||||
async def delete_api_key(key_id: int, user_id: str = Depends(get_user_id)):
|
||||
@api_router.delete('/{key_id}', tags=['Keys'])
|
||||
async def delete_api_key(
|
||||
key_id: int, user_id: str = Depends(get_user_id)
|
||||
) -> MessageResponse:
|
||||
"""Delete an API key."""
|
||||
try:
|
||||
# First, verify the key belongs to the user
|
||||
@@ -217,7 +240,7 @@ async def delete_api_key(key_id: int, user_id: str = Depends(get_user_id)):
|
||||
key_to_delete = None
|
||||
|
||||
for key in keys:
|
||||
if key['id'] == key_id:
|
||||
if key.id == key_id:
|
||||
key_to_delete = key
|
||||
break
|
||||
|
||||
@@ -235,7 +258,7 @@ async def delete_api_key(key_id: int, user_id: str = Depends(get_user_id)):
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to delete API key',
|
||||
)
|
||||
return {'message': 'API key deleted successfully'}
|
||||
return MessageResponse(message='API key deleted successfully')
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception:
|
||||
@@ -246,22 +269,33 @@ async def delete_api_key(key_id: int, user_id: str = Depends(get_user_id)):
|
||||
)
|
||||
|
||||
|
||||
@api_router.get('/llm/byor', response_model=LlmApiKeyResponse)
|
||||
async def get_llm_api_key_for_byor(user_id: str = Depends(get_user_id)):
|
||||
@api_router.get('/llm/byor', tags=['Keys'])
|
||||
async def get_llm_api_key_for_byor(
|
||||
user_id: str = Depends(get_user_id),
|
||||
) -> LlmApiKeyResponse:
|
||||
"""Get the LLM API key for BYOR (Bring Your Own Runtime) for the authenticated user.
|
||||
|
||||
This endpoint validates that the key exists in LiteLLM before returning it.
|
||||
If validation fails, it automatically generates a new key to ensure users
|
||||
always receive a working key.
|
||||
|
||||
Returns 402 Payment Required if BYOR export is not enabled for the user's org.
|
||||
"""
|
||||
try:
|
||||
# Check if BYOR export is enabled for the user's org
|
||||
if not await OrgService.check_byor_export_enabled(user_id):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||
detail='BYOR key export is not enabled. Purchase credits to enable this feature.',
|
||||
)
|
||||
|
||||
# Check if the BYOR key exists in the database
|
||||
byor_key = await get_byor_key_from_db(user_id)
|
||||
if byor_key:
|
||||
# Validate that the key is actually registered in LiteLLM
|
||||
is_valid = await LiteLlmManager.verify_key(byor_key, user_id)
|
||||
if is_valid:
|
||||
return {'key': byor_key}
|
||||
return LlmApiKeyResponse(key=byor_key)
|
||||
else:
|
||||
# Key exists in DB but is invalid in LiteLLM - regenerate it
|
||||
logger.warning(
|
||||
@@ -286,7 +320,7 @@ async def get_llm_api_key_for_byor(user_id: str = Depends(get_user_id)):
|
||||
'Successfully generated and stored new BYOR key',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
return {'key': key}
|
||||
return LlmApiKeyResponse(key=key)
|
||||
else:
|
||||
logger.error(
|
||||
'Failed to generate new BYOR LLM API key',
|
||||
@@ -308,12 +342,24 @@ async def get_llm_api_key_for_byor(user_id: str = Depends(get_user_id)):
|
||||
)
|
||||
|
||||
|
||||
@api_router.post('/llm/byor/refresh', response_model=LlmApiKeyResponse)
|
||||
async def refresh_llm_api_key_for_byor(user_id: str = Depends(get_user_id)):
|
||||
"""Refresh the LLM API key for BYOR (Bring Your Own Runtime) for the authenticated user."""
|
||||
@api_router.post('/llm/byor/refresh', tags=['Keys'])
|
||||
async def refresh_llm_api_key_for_byor(
|
||||
user_id: str = Depends(get_user_id),
|
||||
) -> LlmApiKeyResponse:
|
||||
"""Refresh the LLM API key for BYOR (Bring Your Own Runtime) for the authenticated user.
|
||||
|
||||
Returns 402 Payment Required if BYOR export is not enabled for the user's org.
|
||||
"""
|
||||
logger.info('Starting BYOR LLM API key refresh', extra={'user_id': user_id})
|
||||
|
||||
try:
|
||||
# Check if BYOR export is enabled for the user's org
|
||||
if not await OrgService.check_byor_export_enabled(user_id):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||
detail='BYOR key export is not enabled. Purchase credits to enable this feature.',
|
||||
)
|
||||
|
||||
# Get the existing BYOR key from the database
|
||||
existing_byor_key = await get_byor_key_from_db(user_id)
|
||||
|
||||
@@ -352,7 +398,7 @@ async def refresh_llm_api_key_for_byor(user_id: str = Depends(get_user_id)):
|
||||
'BYOR LLM API key refresh completed successfully',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
return {'key': key}
|
||||
return LlmApiKeyResponse(key=key)
|
||||
except HTTPException as he:
|
||||
logger.error(
|
||||
'HTTP exception during BYOR LLM API key refresh',
|
||||
|
||||
@@ -5,6 +5,7 @@ import warnings
|
||||
from datetime import datetime, timezone
|
||||
from typing import Annotated, Literal, Optional
|
||||
from urllib.parse import quote
|
||||
from uuid import UUID as parse_uuid
|
||||
|
||||
import posthog
|
||||
from fastapi import APIRouter, Header, HTTPException, Request, Response, status
|
||||
@@ -26,6 +27,13 @@ from server.auth.token_manager import TokenManager
|
||||
from server.config import sign_token
|
||||
from server.constants import IS_FEATURE_ENV
|
||||
from server.routes.event_webhook import _get_session_api_key, _get_user_id
|
||||
from server.services.org_invitation_service import (
|
||||
EmailMismatchError,
|
||||
InvitationExpiredError,
|
||||
InvitationInvalidError,
|
||||
OrgInvitationService,
|
||||
UserAlreadyMemberError,
|
||||
)
|
||||
from storage.database import session_maker
|
||||
from storage.user import User
|
||||
from storage.user_store import UserStore
|
||||
@@ -104,22 +112,40 @@ def get_cookie_samesite(request: Request) -> Literal['lax', 'strict']:
|
||||
)
|
||||
|
||||
|
||||
def _extract_oauth_state(state: str | None) -> tuple[str, str | None, str | None]:
|
||||
"""Extract redirect URL, reCAPTCHA token, and invitation token from OAuth state.
|
||||
|
||||
Returns:
|
||||
Tuple of (redirect_url, recaptcha_token, invitation_token).
|
||||
Tokens may be None.
|
||||
"""
|
||||
if not state:
|
||||
return '', None, None
|
||||
|
||||
try:
|
||||
# Try to decode as JSON (new format with reCAPTCHA and/or invitation)
|
||||
state_data = json.loads(base64.urlsafe_b64decode(state.encode()).decode())
|
||||
return (
|
||||
state_data.get('redirect_url', ''),
|
||||
state_data.get('recaptcha_token'),
|
||||
state_data.get('invitation_token'),
|
||||
)
|
||||
except Exception:
|
||||
# Old format - state is just the redirect URL
|
||||
return state, None, None
|
||||
|
||||
|
||||
# Keep alias for backward compatibility
|
||||
def _extract_recaptcha_state(state: str | None) -> tuple[str, str | None]:
|
||||
"""Extract redirect URL and reCAPTCHA token from OAuth state.
|
||||
|
||||
Deprecated: Use _extract_oauth_state instead.
|
||||
|
||||
Returns:
|
||||
Tuple of (redirect_url, recaptcha_token). Token may be None.
|
||||
"""
|
||||
if not state:
|
||||
return '', None
|
||||
|
||||
try:
|
||||
# Try to decode as JSON (new format with reCAPTCHA)
|
||||
state_data = json.loads(base64.urlsafe_b64decode(state.encode()).decode())
|
||||
return state_data.get('redirect_url', ''), state_data.get('recaptcha_token')
|
||||
except Exception:
|
||||
# Old format - state is just the redirect URL
|
||||
return state, None
|
||||
redirect_url, recaptcha_token, _ = _extract_oauth_state(state)
|
||||
return redirect_url, recaptcha_token
|
||||
|
||||
|
||||
@oauth_router.get('/keycloak/callback')
|
||||
@@ -130,8 +156,8 @@ async def keycloak_callback(
|
||||
error: Optional[str] = None,
|
||||
error_description: Optional[str] = None,
|
||||
):
|
||||
# Extract redirect URL and reCAPTCHA token from state
|
||||
redirect_url, recaptcha_token = _extract_recaptcha_state(state)
|
||||
# Extract redirect URL, reCAPTCHA token, and invitation token from state
|
||||
redirect_url, recaptcha_token, invitation_token = _extract_oauth_state(state)
|
||||
if not redirect_url:
|
||||
redirect_url = str(request.base_url)
|
||||
|
||||
@@ -179,6 +205,10 @@ async def keycloak_callback(
|
||||
user = await UserStore.get_user_by_id_async(user_id)
|
||||
if not user:
|
||||
user = await UserStore.create_user(user_id, user_info)
|
||||
else:
|
||||
# Existing user — gradually backfill contact_name if it still has a username-style value
|
||||
await UserStore.backfill_contact_name(user_id, user_info)
|
||||
await UserStore.backfill_user_email(user_id, user_info)
|
||||
|
||||
if not user:
|
||||
logger.error(f'Failed to authenticate user {user_info["preferred_username"]}')
|
||||
@@ -219,6 +249,7 @@ async def keycloak_callback(
|
||||
user_ip=user_ip,
|
||||
user_agent=user_agent,
|
||||
email=email,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
if not result.allowed:
|
||||
@@ -298,8 +329,13 @@ async def keycloak_callback(
|
||||
from server.routes.email import verify_email
|
||||
|
||||
await verify_email(request=request, user_id=user_id, is_auth_flow=True)
|
||||
redirect_url = f'{request.base_url}login?email_verification_required=true&user_id={user_id}'
|
||||
response = RedirectResponse(redirect_url, status_code=302)
|
||||
verification_redirect_url = f'{request.base_url}login?email_verification_required=true&user_id={user_id}'
|
||||
# Preserve invitation token so it can be included in OAuth state after verification
|
||||
if invitation_token:
|
||||
verification_redirect_url = (
|
||||
f'{verification_redirect_url}&invitation_token={invitation_token}'
|
||||
)
|
||||
response = RedirectResponse(verification_redirect_url, status_code=302)
|
||||
return response
|
||||
|
||||
# default to github IDP for now.
|
||||
@@ -377,14 +413,90 @@ async def keycloak_callback(
|
||||
)
|
||||
|
||||
has_accepted_tos = user.accepted_tos is not None
|
||||
|
||||
# Process invitation token if present (after email verification but before TOS)
|
||||
if invitation_token:
|
||||
try:
|
||||
logger.info(
|
||||
'Processing invitation token during auth callback',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'invitation_token_prefix': invitation_token[:10] + '...',
|
||||
},
|
||||
)
|
||||
|
||||
await OrgInvitationService.accept_invitation(
|
||||
invitation_token, parse_uuid(user_id)
|
||||
)
|
||||
logger.info(
|
||||
'Invitation accepted during auth callback',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
|
||||
except InvitationExpiredError:
|
||||
logger.warning(
|
||||
'Invitation expired during auth callback',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
# Add query param to redirect URL
|
||||
if '?' in redirect_url:
|
||||
redirect_url = f'{redirect_url}&invitation_expired=true'
|
||||
else:
|
||||
redirect_url = f'{redirect_url}?invitation_expired=true'
|
||||
|
||||
except InvitationInvalidError as e:
|
||||
logger.warning(
|
||||
'Invalid invitation during auth callback',
|
||||
extra={'user_id': user_id, 'error': str(e)},
|
||||
)
|
||||
if '?' in redirect_url:
|
||||
redirect_url = f'{redirect_url}&invitation_invalid=true'
|
||||
else:
|
||||
redirect_url = f'{redirect_url}?invitation_invalid=true'
|
||||
|
||||
except UserAlreadyMemberError:
|
||||
logger.info(
|
||||
'User already member during invitation acceptance',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
if '?' in redirect_url:
|
||||
redirect_url = f'{redirect_url}&already_member=true'
|
||||
else:
|
||||
redirect_url = f'{redirect_url}?already_member=true'
|
||||
|
||||
except EmailMismatchError as e:
|
||||
logger.warning(
|
||||
'Email mismatch during auth callback invitation acceptance',
|
||||
extra={'user_id': user_id, 'error': str(e)},
|
||||
)
|
||||
if '?' in redirect_url:
|
||||
redirect_url = f'{redirect_url}&email_mismatch=true'
|
||||
else:
|
||||
redirect_url = f'{redirect_url}?email_mismatch=true'
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'Unexpected error processing invitation during auth callback',
|
||||
extra={'user_id': user_id, 'error': str(e)},
|
||||
)
|
||||
# Don't fail the login if invitation processing fails
|
||||
if '?' in redirect_url:
|
||||
redirect_url = f'{redirect_url}&invitation_error=true'
|
||||
else:
|
||||
redirect_url = f'{redirect_url}?invitation_error=true'
|
||||
|
||||
# If the user hasn't accepted the TOS, redirect to the TOS page
|
||||
if not has_accepted_tos:
|
||||
encoded_redirect_url = quote(redirect_url, safe='')
|
||||
tos_redirect_url = (
|
||||
f'{request.base_url}accept-tos?redirect_url={encoded_redirect_url}'
|
||||
)
|
||||
if invitation_token:
|
||||
tos_redirect_url = f'{tos_redirect_url}&invitation_success=true'
|
||||
response = RedirectResponse(tos_redirect_url, status_code=302)
|
||||
else:
|
||||
if invitation_token:
|
||||
redirect_url = f'{redirect_url}&invitation_success=true'
|
||||
response = RedirectResponse(redirect_url, status_code=302)
|
||||
|
||||
set_response_cookie(
|
||||
@@ -438,7 +550,10 @@ async def keycloak_offline_callback(code: str, state: str, request: Request):
|
||||
user_id=user_info['sub'], offline_token=keycloak_refresh_token
|
||||
)
|
||||
|
||||
return RedirectResponse(state if state else request.base_url, status_code=302)
|
||||
redirect_url, _, _ = _extract_oauth_state(state)
|
||||
return RedirectResponse(
|
||||
redirect_url if redirect_url else request.base_url, status_code=302
|
||||
)
|
||||
|
||||
|
||||
@oauth_router.get('/github/callback')
|
||||
|
||||
@@ -9,14 +9,13 @@ from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from fastapi.responses import RedirectResponse
|
||||
from integrations import stripe_service
|
||||
from pydantic import BaseModel
|
||||
from server.constants import (
|
||||
STRIPE_API_KEY,
|
||||
)
|
||||
from server.constants import STRIPE_API_KEY
|
||||
from server.logger import logger
|
||||
from starlette.datastructures import URL
|
||||
from storage.billing_session import BillingSession
|
||||
from storage.database import session_maker
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
from storage.org import Org
|
||||
from storage.subscription_access import SubscriptionAccess
|
||||
from storage.user_store import UserStore
|
||||
|
||||
@@ -24,7 +23,7 @@ from openhands.app_server.config import get_global_config
|
||||
from openhands.server.user_auth import get_user_id
|
||||
|
||||
stripe.api_key = STRIPE_API_KEY
|
||||
billing_router = APIRouter(prefix='/api/billing')
|
||||
billing_router = APIRouter(prefix='/api/billing', tags=['Billing'])
|
||||
|
||||
|
||||
async def validate_billing_enabled() -> None:
|
||||
@@ -94,9 +93,9 @@ async def get_credits(user_id: str = Depends(get_user_id)) -> GetCreditsResponse
|
||||
user_team_info = await LiteLlmManager.get_user_team_info(
|
||||
user_id, str(user.current_org_id)
|
||||
)
|
||||
# Update to use calculate_credits
|
||||
spend = user_team_info.get('spend', 0)
|
||||
max_budget = (user_team_info.get('litellm_budget_table') or {}).get('max_budget', 0)
|
||||
max_budget, spend = LiteLlmManager.get_budget_from_team_info(
|
||||
user_team_info, user_id, str(user.current_org_id)
|
||||
)
|
||||
credits = max(max_budget - spend, 0)
|
||||
return GetCreditsResponse(credits=Decimal('{:.2f}'.format(credits)))
|
||||
|
||||
@@ -148,7 +147,7 @@ async def create_customer_setup_session(
|
||||
customer=customer_info['customer_id'],
|
||||
mode='setup',
|
||||
payment_method_types=['card'],
|
||||
success_url=f'{base_url}?free_credits=success',
|
||||
success_url=f'{base_url}?setup=success',
|
||||
cancel_url=f'{base_url}',
|
||||
)
|
||||
return CreateBillingSessionResponse(redirect_url=checkout_session.url)
|
||||
@@ -250,15 +249,21 @@ async def success_callback(session_id: str, request: Request):
|
||||
)
|
||||
amount_subtotal = stripe_session.amount_subtotal or 0
|
||||
add_credits = amount_subtotal / 100
|
||||
max_budget = (user_team_info.get('litellm_budget_table') or {}).get(
|
||||
'max_budget', 0
|
||||
max_budget, _ = LiteLlmManager.get_budget_from_team_info(
|
||||
user_team_info, billing_session.user_id, str(user.current_org_id)
|
||||
)
|
||||
|
||||
org = session.query(Org).filter(Org.id == user.current_org_id).first()
|
||||
new_max_budget = max_budget + add_credits
|
||||
|
||||
await LiteLlmManager.update_team_and_users_budget(
|
||||
str(user.current_org_id), new_max_budget
|
||||
)
|
||||
|
||||
# Enable BYOR export for the org now that they've purchased credits
|
||||
if org:
|
||||
org.byor_export_enabled = True
|
||||
|
||||
# Store transaction status
|
||||
billing_session.status = 'completed'
|
||||
billing_session.price = add_credits
|
||||
|
||||
@@ -8,6 +8,7 @@ from server.auth.keycloak_manager import get_keycloak_admin
|
||||
from server.auth.saas_user_auth import SaasUserAuth
|
||||
from server.routes.auth import set_response_cookie
|
||||
from server.utils.rate_limit_utils import check_rate_limit_by_user_id
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.server.user_auth import get_user_id
|
||||
@@ -62,6 +63,10 @@ async def update_email(
|
||||
},
|
||||
)
|
||||
|
||||
await UserStore.update_user_email(
|
||||
user_id=user_id, email=email, email_verified=False
|
||||
)
|
||||
|
||||
user_auth: SaasUserAuth = await get_user_auth(request)
|
||||
await user_auth.refresh() # refresh so access token has updated email
|
||||
user_auth.email = email
|
||||
@@ -144,6 +149,7 @@ async def verified_email(request: Request):
|
||||
user_auth: SaasUserAuth = await get_user_auth(request)
|
||||
await user_auth.refresh() # refresh so access token has updated email
|
||||
user_auth.email_verified = True
|
||||
await UserStore.update_user_email(user_id=user_auth.user_id, email_verified=True)
|
||||
scheme = 'http' if request.url.hostname == 'localhost' else 'https'
|
||||
redirect_uri = f'{scheme}://{request.url.netloc}/settings/user'
|
||||
response = RedirectResponse(redirect_uri, status_code=302)
|
||||
|
||||
@@ -8,11 +8,18 @@ from storage.feedback import ConversationFeedback
|
||||
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
|
||||
|
||||
from openhands.events.event_store import EventStore
|
||||
from openhands.server.dependencies import get_dependencies
|
||||
from openhands.server.shared import file_store
|
||||
from openhands.server.user_auth import get_user_id
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
|
||||
router = APIRouter(prefix='/feedback', tags=['feedback'])
|
||||
# We use the get_dependencies method here to signal to the OpenAPI docs that this endpoint
|
||||
# is protected. The actual protection is provided by SetAuthCookieMiddleware
|
||||
# TODO: It may be an error by you can actually post feedback to a conversation you don't
|
||||
# own right now - maybe this is useful in the context of public shared conversations?
|
||||
router = APIRouter(
|
||||
prefix='/feedback', tags=['feedback'], dependencies=get_dependencies()
|
||||
)
|
||||
|
||||
|
||||
async def get_event_ids(conversation_id: str, user_id: str) -> List[int]:
|
||||
|
||||
@@ -4,7 +4,7 @@ import json
|
||||
import os
|
||||
import re
|
||||
import uuid
|
||||
from urllib.parse import urlparse
|
||||
from urllib.parse import urlencode, urlparse
|
||||
|
||||
import requests
|
||||
from fastapi import APIRouter, BackgroundTasks, Header, HTTPException, Request, status
|
||||
@@ -371,9 +371,7 @@ async def create_jira_workspace(request: Request, workspace_data: JiraWorkspaceC
|
||||
'prompt': 'consent',
|
||||
}
|
||||
|
||||
auth_url = (
|
||||
f"{JIRA_AUTH_URL}?{'&'.join([f'{k}={v}' for k, v in auth_params.items()])}"
|
||||
)
|
||||
auth_url = f'{JIRA_AUTH_URL}?{urlencode(auth_params)}'
|
||||
|
||||
return JSONResponse(
|
||||
content={
|
||||
@@ -432,9 +430,7 @@ async def create_workspace_link(request: Request, link_data: JiraLinkCreate):
|
||||
'response_type': 'code',
|
||||
'prompt': 'consent',
|
||||
}
|
||||
auth_url = (
|
||||
f"{JIRA_AUTH_URL}?{'&'.join([f'{k}={v}' for k, v in auth_params.items()])}"
|
||||
)
|
||||
auth_url = f'{JIRA_AUTH_URL}?{urlencode(auth_params)}'
|
||||
|
||||
return JSONResponse(
|
||||
content={
|
||||
|
||||
@@ -2,7 +2,7 @@ import json
|
||||
import os
|
||||
import re
|
||||
import uuid
|
||||
from urllib.parse import urlparse
|
||||
from urllib.parse import urlencode, urlparse
|
||||
|
||||
import requests
|
||||
from fastapi import (
|
||||
@@ -316,7 +316,7 @@ async def create_jira_dc_workspace(
|
||||
'response_type': 'code',
|
||||
}
|
||||
|
||||
auth_url = f"{JIRA_DC_AUTH_URL}?{'&'.join([f'{k}={v}' for k, v in auth_params.items()])}"
|
||||
auth_url = f'{JIRA_DC_AUTH_URL}?{urlencode(auth_params)}'
|
||||
|
||||
return JSONResponse(
|
||||
content={
|
||||
@@ -436,7 +436,7 @@ async def create_workspace_link(request: Request, link_data: JiraDcLinkCreate):
|
||||
'state': state,
|
||||
'response_type': 'code',
|
||||
}
|
||||
auth_url = f"{JIRA_DC_AUTH_URL}?{'&'.join([f'{k}={v}' for k, v in auth_params.items()])}"
|
||||
auth_url = f'{JIRA_DC_AUTH_URL}?{urlencode(auth_params)}'
|
||||
|
||||
return JSONResponse(
|
||||
content={
|
||||
|
||||
122
enterprise/server/routes/org_invitation_models.py
Normal file
122
enterprise/server/routes/org_invitation_models.py
Normal file
@@ -0,0 +1,122 @@
|
||||
"""
|
||||
Pydantic models and custom exceptions for organization invitations.
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel, EmailStr
|
||||
from storage.org_invitation import OrgInvitation
|
||||
from storage.role_store import RoleStore
|
||||
|
||||
|
||||
class InvitationError(Exception):
|
||||
"""Base exception for invitation errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class InvitationAlreadyExistsError(InvitationError):
|
||||
"""Raised when a pending invitation already exists for the email."""
|
||||
|
||||
def __init__(
|
||||
self, message: str = 'A pending invitation already exists for this email'
|
||||
):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class UserAlreadyMemberError(InvitationError):
|
||||
"""Raised when the user is already a member of the organization."""
|
||||
|
||||
def __init__(self, message: str = 'User is already a member of this organization'):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class InvitationExpiredError(InvitationError):
|
||||
"""Raised when the invitation has expired."""
|
||||
|
||||
def __init__(self, message: str = 'Invitation has expired'):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class InvitationInvalidError(InvitationError):
|
||||
"""Raised when the invitation is invalid or revoked."""
|
||||
|
||||
def __init__(self, message: str = 'Invitation is no longer valid'):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class InsufficientPermissionError(InvitationError):
|
||||
"""Raised when the user lacks permission to perform the action."""
|
||||
|
||||
def __init__(self, message: str = 'Insufficient permission'):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class EmailMismatchError(InvitationError):
|
||||
"""Raised when the accepting user's email doesn't match the invitation email."""
|
||||
|
||||
def __init__(self, message: str = 'Your email does not match the invitation'):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class InvitationCreate(BaseModel):
|
||||
"""Request model for creating invitation(s)."""
|
||||
|
||||
emails: list[EmailStr]
|
||||
role: str = 'member' # Default to member role
|
||||
|
||||
|
||||
class InvitationResponse(BaseModel):
|
||||
"""Response model for invitation details."""
|
||||
|
||||
id: int
|
||||
email: str
|
||||
role: str
|
||||
status: str
|
||||
created_at: str
|
||||
expires_at: str
|
||||
inviter_email: str | None = None
|
||||
|
||||
@classmethod
|
||||
def from_invitation(
|
||||
cls,
|
||||
invitation: OrgInvitation,
|
||||
inviter_email: str | None = None,
|
||||
) -> 'InvitationResponse':
|
||||
"""Create an InvitationResponse from an OrgInvitation entity.
|
||||
|
||||
Args:
|
||||
invitation: The invitation entity to convert
|
||||
inviter_email: Optional email of the inviter
|
||||
|
||||
Returns:
|
||||
InvitationResponse: The response model instance
|
||||
"""
|
||||
role_name = ''
|
||||
if invitation.role:
|
||||
role_name = invitation.role.name
|
||||
elif invitation.role_id:
|
||||
role = RoleStore.get_role_by_id(invitation.role_id)
|
||||
role_name = role.name if role else ''
|
||||
|
||||
return cls(
|
||||
id=invitation.id,
|
||||
email=invitation.email,
|
||||
role=role_name,
|
||||
status=invitation.status,
|
||||
created_at=invitation.created_at.isoformat(),
|
||||
expires_at=invitation.expires_at.isoformat(),
|
||||
inviter_email=inviter_email,
|
||||
)
|
||||
|
||||
|
||||
class InvitationFailure(BaseModel):
|
||||
"""Response model for a failed invitation."""
|
||||
|
||||
email: str
|
||||
error: str
|
||||
|
||||
|
||||
class BatchInvitationResponse(BaseModel):
|
||||
"""Response model for batch invitation creation."""
|
||||
|
||||
successful: list[InvitationResponse]
|
||||
failed: list[InvitationFailure]
|
||||
226
enterprise/server/routes/org_invitations.py
Normal file
226
enterprise/server/routes/org_invitations.py
Normal file
@@ -0,0 +1,226 @@
|
||||
"""API routes for organization invitations."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from fastapi.responses import RedirectResponse
|
||||
from server.routes.org_invitation_models import (
|
||||
BatchInvitationResponse,
|
||||
EmailMismatchError,
|
||||
InsufficientPermissionError,
|
||||
InvitationCreate,
|
||||
InvitationExpiredError,
|
||||
InvitationFailure,
|
||||
InvitationInvalidError,
|
||||
InvitationResponse,
|
||||
UserAlreadyMemberError,
|
||||
)
|
||||
from server.services.org_invitation_service import OrgInvitationService
|
||||
from server.utils.rate_limit_utils import check_rate_limit_by_user_id
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.server.user_auth import get_user_id
|
||||
from openhands.server.user_auth.user_auth import get_user_auth
|
||||
|
||||
# Router for invitation operations on an organization (requires org_id)
|
||||
invitation_router = APIRouter(prefix='/api/organizations/{org_id}/members')
|
||||
|
||||
# Router for accepting invitations (no org_id required)
|
||||
accept_router = APIRouter(prefix='/api/organizations/members/invite')
|
||||
|
||||
|
||||
@invitation_router.post(
|
||||
'/invite',
|
||||
response_model=BatchInvitationResponse,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
)
|
||||
async def create_invitation(
|
||||
org_id: UUID,
|
||||
invitation_data: InvitationCreate,
|
||||
request: Request,
|
||||
user_id: str = Depends(get_user_id),
|
||||
):
|
||||
"""Create organization invitations for multiple email addresses.
|
||||
|
||||
Sends emails to invitees with secure links to join the organization.
|
||||
Supports batch invitations - some may succeed while others fail.
|
||||
|
||||
Permission rules:
|
||||
- Only owners and admins can create invitations
|
||||
- Admins can only invite with 'member' or 'admin' role (not 'owner')
|
||||
- Owners can invite with any role
|
||||
|
||||
Args:
|
||||
org_id: Organization UUID
|
||||
invitation_data: Invitation details (emails array, role)
|
||||
request: FastAPI request
|
||||
user_id: Authenticated user ID (from dependency)
|
||||
|
||||
Returns:
|
||||
BatchInvitationResponse: Lists of successful and failed invitations
|
||||
|
||||
Raises:
|
||||
HTTPException 400: Invalid role or organization not found
|
||||
HTTPException 403: User lacks permission to invite
|
||||
HTTPException 429: Rate limit exceeded
|
||||
"""
|
||||
# Rate limit: 10 invitations per minute per user (6 seconds between requests)
|
||||
await check_rate_limit_by_user_id(
|
||||
request=request,
|
||||
key_prefix='org_invitation_create',
|
||||
user_id=user_id,
|
||||
user_rate_limit_seconds=6,
|
||||
)
|
||||
|
||||
try:
|
||||
successful, failed = await OrgInvitationService.create_invitations_batch(
|
||||
org_id=org_id,
|
||||
emails=[str(email) for email in invitation_data.emails],
|
||||
role_name=invitation_data.role,
|
||||
inviter_id=UUID(user_id),
|
||||
)
|
||||
|
||||
logger.info(
|
||||
'Batch organization invitations created',
|
||||
extra={
|
||||
'org_id': str(org_id),
|
||||
'total_emails': len(invitation_data.emails),
|
||||
'successful': len(successful),
|
||||
'failed': len(failed),
|
||||
'inviter_id': user_id,
|
||||
},
|
||||
)
|
||||
|
||||
return BatchInvitationResponse(
|
||||
successful=[InvitationResponse.from_invitation(inv) for inv in successful],
|
||||
failed=[
|
||||
InvitationFailure(email=email, error=error) for email, error in failed
|
||||
],
|
||||
)
|
||||
|
||||
except InsufficientPermissionError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=str(e),
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'Unexpected error creating batch invitations',
|
||||
extra={'org_id': str(org_id), 'error': str(e)},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='An unexpected error occurred',
|
||||
)
|
||||
|
||||
|
||||
@accept_router.get('/accept')
|
||||
async def accept_invitation(
|
||||
token: str,
|
||||
request: Request,
|
||||
):
|
||||
"""Accept an organization invitation via token.
|
||||
|
||||
This endpoint is accessed via the link in the invitation email.
|
||||
|
||||
Flow:
|
||||
1. If user is authenticated: Accept invitation directly and redirect to home
|
||||
2. If user is not authenticated: Redirect to login page with invitation token
|
||||
- Frontend stores token and includes it in OAuth state during login
|
||||
- After authentication, keycloak_callback processes the invitation
|
||||
|
||||
Args:
|
||||
token: The invitation token from the email link
|
||||
request: FastAPI request
|
||||
|
||||
Returns:
|
||||
RedirectResponse: Redirect to home page on success, or login page if not authenticated,
|
||||
or home page with error query params on failure
|
||||
"""
|
||||
base_url = str(request.base_url).rstrip('/')
|
||||
|
||||
# Try to get user_id from auth (may not be authenticated)
|
||||
user_id = None
|
||||
try:
|
||||
user_auth = await get_user_auth(request)
|
||||
if user_auth:
|
||||
user_id = await user_auth.get_user_id()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if not user_id:
|
||||
# User not authenticated - redirect to login page with invitation token
|
||||
# Frontend will store the token and include it in OAuth state during login
|
||||
logger.info(
|
||||
'Invitation accept: redirecting unauthenticated user to login',
|
||||
extra={'token_prefix': token[:10] + '...'},
|
||||
)
|
||||
login_url = f'{base_url}/login?invitation_token={token}'
|
||||
return RedirectResponse(login_url, status_code=302)
|
||||
|
||||
# User is authenticated - process the invitation directly
|
||||
try:
|
||||
await OrgInvitationService.accept_invitation(token, UUID(user_id))
|
||||
|
||||
logger.info(
|
||||
'Invitation accepted successfully',
|
||||
extra={
|
||||
'token_prefix': token[:10] + '...',
|
||||
'user_id': user_id,
|
||||
},
|
||||
)
|
||||
|
||||
# Redirect to home page on success
|
||||
return RedirectResponse(f'{base_url}/', status_code=302)
|
||||
|
||||
except InvitationExpiredError:
|
||||
logger.warning(
|
||||
'Invitation accept failed: expired',
|
||||
extra={'token_prefix': token[:10] + '...', 'user_id': user_id},
|
||||
)
|
||||
return RedirectResponse(f'{base_url}/?invitation_expired=true', status_code=302)
|
||||
|
||||
except InvitationInvalidError as e:
|
||||
logger.warning(
|
||||
'Invitation accept failed: invalid',
|
||||
extra={
|
||||
'token_prefix': token[:10] + '...',
|
||||
'user_id': user_id,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
return RedirectResponse(f'{base_url}/?invitation_invalid=true', status_code=302)
|
||||
|
||||
except UserAlreadyMemberError:
|
||||
logger.info(
|
||||
'Invitation accept: user already member',
|
||||
extra={'token_prefix': token[:10] + '...', 'user_id': user_id},
|
||||
)
|
||||
return RedirectResponse(f'{base_url}/?already_member=true', status_code=302)
|
||||
|
||||
except EmailMismatchError as e:
|
||||
logger.warning(
|
||||
'Invitation accept failed: email mismatch',
|
||||
extra={
|
||||
'token_prefix': token[:10] + '...',
|
||||
'user_id': user_id,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
return RedirectResponse(f'{base_url}/?email_mismatch=true', status_code=302)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'Unexpected error accepting invitation',
|
||||
extra={
|
||||
'token_prefix': token[:10] + '...',
|
||||
'user_id': user_id,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
return RedirectResponse(f'{base_url}/?invitation_error=true', status_code=302)
|
||||
@@ -1,7 +1,16 @@
|
||||
from typing import Annotated
|
||||
|
||||
from pydantic import BaseModel, EmailStr, Field, StringConstraints
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
EmailStr,
|
||||
Field,
|
||||
SecretStr,
|
||||
StringConstraints,
|
||||
field_validator,
|
||||
)
|
||||
from storage.org import Org
|
||||
from storage.org_member import OrgMember
|
||||
from storage.role import Role
|
||||
|
||||
|
||||
class OrgCreationError(Exception):
|
||||
@@ -43,6 +52,16 @@ class OrgAuthorizationError(OrgDeletionError):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class OrphanedUserError(OrgDeletionError):
|
||||
"""Raised when deleting an org would leave users without any organization."""
|
||||
|
||||
def __init__(self, user_ids: list[str]):
|
||||
self.user_ids = user_ids
|
||||
super().__init__(
|
||||
f'Cannot delete organization: {len(user_ids)} user(s) would have no remaining organization'
|
||||
)
|
||||
|
||||
|
||||
class OrgNotFoundError(Exception):
|
||||
"""Raised when organization is not found or user doesn't have access."""
|
||||
|
||||
@@ -51,6 +70,61 @@ class OrgNotFoundError(Exception):
|
||||
super().__init__(f'Organization with id "{org_id}" not found')
|
||||
|
||||
|
||||
class OrgMemberNotFoundError(Exception):
|
||||
"""Raised when a member is not found in an organization."""
|
||||
|
||||
def __init__(self, org_id: str, user_id: str):
|
||||
self.org_id = org_id
|
||||
self.user_id = user_id
|
||||
super().__init__(f'Member "{user_id}" not found in organization "{org_id}"')
|
||||
|
||||
|
||||
class RoleNotFoundError(Exception):
|
||||
"""Raised when a role is not found."""
|
||||
|
||||
def __init__(self, role_id: int):
|
||||
self.role_id = role_id
|
||||
super().__init__(f'Role with id "{role_id}" not found')
|
||||
|
||||
|
||||
class InvalidRoleError(Exception):
|
||||
"""Raised when an invalid role name is specified."""
|
||||
|
||||
def __init__(self, role_name: str):
|
||||
self.role_name = role_name
|
||||
super().__init__(f'Invalid role: "{role_name}"')
|
||||
|
||||
|
||||
class InsufficientPermissionError(Exception):
|
||||
"""Raised when user lacks permission to perform an operation."""
|
||||
|
||||
def __init__(self, message: str = 'Insufficient permission'):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class CannotModifySelfError(Exception):
|
||||
"""Raised when user attempts to modify their own membership."""
|
||||
|
||||
def __init__(self, action: str = 'modify'):
|
||||
self.action = action
|
||||
super().__init__(f'Cannot {action} your own membership')
|
||||
|
||||
|
||||
class LastOwnerError(Exception):
|
||||
"""Raised when attempting to remove or demote the last owner."""
|
||||
|
||||
def __init__(self, action: str = 'remove'):
|
||||
self.action = action
|
||||
super().__init__(f'Cannot {action} the last owner of an organization')
|
||||
|
||||
|
||||
class MemberUpdateError(Exception):
|
||||
"""Raised when member update operation fails."""
|
||||
|
||||
def __init__(self, message: str = 'Failed to update member'):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class OrgCreate(BaseModel):
|
||||
"""Request model for creating a new organization."""
|
||||
|
||||
@@ -91,14 +165,18 @@ class OrgResponse(BaseModel):
|
||||
enable_solvability_analysis: bool | None = None
|
||||
v1_enabled: bool | None = None
|
||||
credits: float | None = None
|
||||
is_personal: bool = False
|
||||
|
||||
@classmethod
|
||||
def from_org(cls, org: Org, credits: float | None = None) -> 'OrgResponse':
|
||||
def from_org(
|
||||
cls, org: Org, credits: float | None = None, user_id: str | None = None
|
||||
) -> 'OrgResponse':
|
||||
"""Create an OrgResponse from an Org entity.
|
||||
|
||||
Args:
|
||||
org: The organization entity to convert
|
||||
credits: Optional credits value (defaults to None)
|
||||
user_id: Optional user ID to determine if org is personal (defaults to None)
|
||||
|
||||
Returns:
|
||||
OrgResponse: The response model instance
|
||||
@@ -134,6 +212,7 @@ class OrgResponse(BaseModel):
|
||||
enable_solvability_analysis=org.enable_solvability_analysis,
|
||||
v1_enabled=org.v1_enabled,
|
||||
credits=credits,
|
||||
is_personal=str(org.id) == user_id if user_id else False,
|
||||
)
|
||||
|
||||
|
||||
@@ -142,12 +221,17 @@ class OrgPage(BaseModel):
|
||||
|
||||
items: list[OrgResponse]
|
||||
next_page_id: str | None = None
|
||||
current_org_id: str | None = None
|
||||
|
||||
|
||||
class OrgUpdate(BaseModel):
|
||||
"""Request model for updating an organization."""
|
||||
|
||||
# Basic organization information (any authenticated user can update)
|
||||
name: Annotated[
|
||||
str | None,
|
||||
StringConstraints(strip_whitespace=True, min_length=1, max_length=255),
|
||||
] = None
|
||||
contact_name: str | None = None
|
||||
contact_email: EmailStr | None = None
|
||||
conversation_expiration: int | None = None
|
||||
@@ -173,3 +257,230 @@ class OrgUpdate(BaseModel):
|
||||
confirmation_mode: bool | None = None
|
||||
enable_default_condenser: bool | None = None
|
||||
condenser_max_size: int | None = Field(default=None, ge=20)
|
||||
|
||||
|
||||
class OrgLLMSettingsResponse(BaseModel):
|
||||
"""Response model for organization LLM settings."""
|
||||
|
||||
default_llm_model: str | None = None
|
||||
default_llm_base_url: str | None = None
|
||||
search_api_key: str | None = None # Masked in response
|
||||
agent: str | None = None
|
||||
confirmation_mode: bool | None = None
|
||||
security_analyzer: str | None = None
|
||||
enable_default_condenser: bool = True
|
||||
condenser_max_size: int | None = None
|
||||
default_max_iterations: int | None = None
|
||||
|
||||
@staticmethod
|
||||
def _mask_key(secret: SecretStr | None) -> str | None:
|
||||
"""Mask an API key, showing only last 4 characters."""
|
||||
if secret is None:
|
||||
return None
|
||||
raw = secret.get_secret_value()
|
||||
if not raw:
|
||||
return None
|
||||
if len(raw) <= 4:
|
||||
return '****'
|
||||
return '****' + raw[-4:]
|
||||
|
||||
@classmethod
|
||||
def from_org(cls, org: Org) -> 'OrgLLMSettingsResponse':
|
||||
"""Create response from Org entity."""
|
||||
return cls(
|
||||
default_llm_model=org.default_llm_model,
|
||||
default_llm_base_url=org.default_llm_base_url,
|
||||
search_api_key=cls._mask_key(org.search_api_key),
|
||||
agent=org.agent,
|
||||
confirmation_mode=org.confirmation_mode,
|
||||
security_analyzer=org.security_analyzer,
|
||||
enable_default_condenser=org.enable_default_condenser
|
||||
if org.enable_default_condenser is not None
|
||||
else True,
|
||||
condenser_max_size=org.condenser_max_size,
|
||||
default_max_iterations=org.default_max_iterations,
|
||||
)
|
||||
|
||||
|
||||
class OrgMemberLLMSettings(BaseModel):
|
||||
"""LLM settings to propagate to organization members.
|
||||
|
||||
Field names match OrgMember DB columns.
|
||||
"""
|
||||
|
||||
llm_model: str | None = None
|
||||
llm_base_url: str | None = None
|
||||
max_iterations: int | None = None
|
||||
llm_api_key: str | None = None
|
||||
|
||||
def has_updates(self) -> bool:
|
||||
"""Check if any field is set (not None)."""
|
||||
return any(getattr(self, field) is not None for field in self.model_fields)
|
||||
|
||||
|
||||
class OrgLLMSettingsUpdate(BaseModel):
|
||||
"""Request model for updating organization LLM settings.
|
||||
|
||||
Field names match Org DB columns exactly.
|
||||
"""
|
||||
|
||||
default_llm_model: str | None = None
|
||||
default_llm_base_url: str | None = None
|
||||
search_api_key: str | None = None
|
||||
agent: str | None = None
|
||||
confirmation_mode: bool | None = None
|
||||
security_analyzer: str | None = None
|
||||
enable_default_condenser: bool | None = None
|
||||
condenser_max_size: int | None = Field(default=None, ge=20)
|
||||
default_max_iterations: int | None = Field(default=None, gt=0)
|
||||
llm_api_key: str | None = None
|
||||
|
||||
def has_updates(self) -> bool:
|
||||
"""Check if any field is set (not None)."""
|
||||
return any(getattr(self, field) is not None for field in self.model_fields)
|
||||
|
||||
def apply_to_org(self, org: Org) -> None:
|
||||
"""Apply non-None settings to the organization model.
|
||||
|
||||
Args:
|
||||
org: Organization entity to update in place
|
||||
"""
|
||||
for field_name in self.model_fields:
|
||||
value = getattr(self, field_name)
|
||||
# Skip llm_api_key - it's only for member propagation, not org-level
|
||||
if value is not None and field_name != 'llm_api_key':
|
||||
setattr(org, field_name, value)
|
||||
|
||||
def get_member_updates(self) -> OrgMemberLLMSettings | None:
|
||||
"""Get updates that need to be propagated to org members.
|
||||
|
||||
Returns:
|
||||
OrgMemberLLMSettings with mapped field values, or None if no member updates needed.
|
||||
Maps: default_llm_model → llm_model, default_llm_base_url → llm_base_url,
|
||||
default_max_iterations → max_iterations, llm_api_key → llm_api_key
|
||||
"""
|
||||
member_settings = OrgMemberLLMSettings(
|
||||
llm_model=self.default_llm_model,
|
||||
llm_base_url=self.default_llm_base_url,
|
||||
max_iterations=self.default_max_iterations,
|
||||
llm_api_key=self.llm_api_key,
|
||||
)
|
||||
return member_settings if member_settings.has_updates() else None
|
||||
|
||||
|
||||
class OrgMemberResponse(BaseModel):
|
||||
"""Response model for a single organization member."""
|
||||
|
||||
user_id: str
|
||||
email: str | None
|
||||
role_id: int
|
||||
role: str
|
||||
role_rank: int
|
||||
status: str | None
|
||||
|
||||
|
||||
class OrgMemberPage(BaseModel):
|
||||
"""Paginated response for organization members."""
|
||||
|
||||
items: list[OrgMemberResponse]
|
||||
current_page: int = 1
|
||||
per_page: int = 10
|
||||
|
||||
|
||||
class OrgMemberUpdate(BaseModel):
|
||||
"""Request model for updating an organization member."""
|
||||
|
||||
role: str | None = None # Role name: 'owner', 'admin', or 'member'
|
||||
|
||||
|
||||
class MeResponse(BaseModel):
|
||||
"""Response model for the current user's membership in an organization."""
|
||||
|
||||
org_id: str
|
||||
user_id: str
|
||||
email: str
|
||||
role: str
|
||||
llm_api_key: str
|
||||
max_iterations: int | None = None
|
||||
llm_model: str | None = None
|
||||
llm_api_key_for_byor: str | None = None
|
||||
llm_base_url: str | None = None
|
||||
status: str | None = None
|
||||
|
||||
@staticmethod
|
||||
def _mask_key(secret: SecretStr | None) -> str:
|
||||
"""Mask an API key, showing only last 4 characters."""
|
||||
if secret is None:
|
||||
return ''
|
||||
raw = secret.get_secret_value()
|
||||
if not raw:
|
||||
return ''
|
||||
if len(raw) <= 4:
|
||||
return '****'
|
||||
return '****' + raw[-4:]
|
||||
|
||||
@classmethod
|
||||
def from_org_member(cls, member: OrgMember, role: Role, email: str) -> 'MeResponse':
|
||||
"""Create a MeResponse from an OrgMember, Role, and user email.
|
||||
|
||||
Args:
|
||||
member: The OrgMember entity
|
||||
role: The Role entity (provides role name)
|
||||
email: The user's email address
|
||||
|
||||
Returns:
|
||||
MeResponse with masked API keys
|
||||
"""
|
||||
return cls(
|
||||
org_id=str(member.org_id),
|
||||
user_id=str(member.user_id),
|
||||
email=email,
|
||||
role=role.name,
|
||||
llm_api_key=cls._mask_key(member.llm_api_key),
|
||||
max_iterations=member.max_iterations,
|
||||
llm_model=member.llm_model,
|
||||
llm_api_key_for_byor=cls._mask_key(member.llm_api_key_for_byor) or None,
|
||||
llm_base_url=member.llm_base_url,
|
||||
status=member.status,
|
||||
)
|
||||
|
||||
|
||||
class OrgAppSettingsResponse(BaseModel):
|
||||
"""Response model for organization app settings."""
|
||||
|
||||
enable_proactive_conversation_starters: bool = True
|
||||
enable_solvability_analysis: bool | None = None
|
||||
max_budget_per_task: float | None = None
|
||||
|
||||
@classmethod
|
||||
def from_org(cls, org: Org) -> 'OrgAppSettingsResponse':
|
||||
"""Create an OrgAppSettingsResponse from an Org entity.
|
||||
|
||||
Args:
|
||||
org: The organization entity
|
||||
|
||||
Returns:
|
||||
OrgAppSettingsResponse with app settings
|
||||
"""
|
||||
return cls(
|
||||
enable_proactive_conversation_starters=org.enable_proactive_conversation_starters
|
||||
if org.enable_proactive_conversation_starters is not None
|
||||
else True,
|
||||
enable_solvability_analysis=org.enable_solvability_analysis,
|
||||
max_budget_per_task=org.max_budget_per_task,
|
||||
)
|
||||
|
||||
|
||||
class OrgAppSettingsUpdate(BaseModel):
|
||||
"""Request model for updating organization app settings."""
|
||||
|
||||
enable_proactive_conversation_starters: bool | None = None
|
||||
enable_solvability_analysis: bool | None = None
|
||||
max_budget_per_task: float | None = None
|
||||
|
||||
@field_validator('max_budget_per_task')
|
||||
@classmethod
|
||||
def validate_max_budget_per_task(cls, v: float | None) -> float | None:
|
||||
if v is not None and v <= 0:
|
||||
raise ValueError('max_budget_per_task must be greater than 0')
|
||||
return v
|
||||
|
||||
@@ -2,25 +2,62 @@ from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from server.auth.authorization import (
|
||||
Permission,
|
||||
require_permission,
|
||||
)
|
||||
from server.email_validation import get_admin_user_id
|
||||
from server.routes.org_models import (
|
||||
CannotModifySelfError,
|
||||
InsufficientPermissionError,
|
||||
InvalidRoleError,
|
||||
LastOwnerError,
|
||||
LiteLLMIntegrationError,
|
||||
MemberUpdateError,
|
||||
MeResponse,
|
||||
OrgAppSettingsResponse,
|
||||
OrgAppSettingsUpdate,
|
||||
OrgAuthorizationError,
|
||||
OrgCreate,
|
||||
OrgDatabaseError,
|
||||
OrgLLMSettingsResponse,
|
||||
OrgLLMSettingsUpdate,
|
||||
OrgMemberNotFoundError,
|
||||
OrgMemberPage,
|
||||
OrgMemberResponse,
|
||||
OrgMemberUpdate,
|
||||
OrgNameExistsError,
|
||||
OrgNotFoundError,
|
||||
OrgPage,
|
||||
OrgResponse,
|
||||
OrgUpdate,
|
||||
OrphanedUserError,
|
||||
RoleNotFoundError,
|
||||
)
|
||||
from server.services.org_app_settings_service import (
|
||||
OrgAppSettingsService,
|
||||
OrgAppSettingsServiceInjector,
|
||||
)
|
||||
from server.services.org_llm_settings_service import (
|
||||
OrgLLMSettingsService,
|
||||
OrgLLMSettingsServiceInjector,
|
||||
)
|
||||
from server.services.org_member_service import OrgMemberService
|
||||
from storage.org_service import OrgService
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.server.user_auth import get_user_id
|
||||
|
||||
# Initialize API router
|
||||
org_router = APIRouter(prefix='/api/organizations')
|
||||
org_router = APIRouter(prefix='/api/organizations', tags=['Orgs'])
|
||||
|
||||
# Create injector instance and dependency for LLM settings
|
||||
_org_llm_settings_injector = OrgLLMSettingsServiceInjector()
|
||||
org_llm_settings_service_dependency = Depends(_org_llm_settings_injector.depends)
|
||||
# Create injector instance and dependency at module level
|
||||
_org_app_settings_injector = OrgAppSettingsServiceInjector()
|
||||
org_app_settings_service_dependency = Depends(_org_app_settings_injector.depends)
|
||||
|
||||
|
||||
@org_router.get('', response_model=OrgPage)
|
||||
@@ -61,6 +98,12 @@ async def list_user_orgs(
|
||||
)
|
||||
|
||||
try:
|
||||
# Fetch user to get current_org_id
|
||||
user = await UserStore.get_user_by_id_async(user_id)
|
||||
current_org_id = (
|
||||
str(user.current_org_id) if user and user.current_org_id else None
|
||||
)
|
||||
|
||||
# Fetch organizations from service layer
|
||||
orgs, next_page_id = OrgService.get_user_orgs_paginated(
|
||||
user_id=user_id,
|
||||
@@ -69,7 +112,9 @@ async def list_user_orgs(
|
||||
)
|
||||
|
||||
# Convert Org entities to OrgResponse objects
|
||||
org_responses = [OrgResponse.from_org(org, credits=None) for org in orgs]
|
||||
org_responses = [
|
||||
OrgResponse.from_org(org, credits=None, user_id=user_id) for org in orgs
|
||||
]
|
||||
|
||||
logger.info(
|
||||
'Successfully retrieved organizations',
|
||||
@@ -80,7 +125,11 @@ async def list_user_orgs(
|
||||
},
|
||||
)
|
||||
|
||||
return OrgPage(items=org_responses, next_page_id=next_page_id)
|
||||
return OrgPage(
|
||||
items=org_responses,
|
||||
next_page_id=next_page_id,
|
||||
current_org_id=current_org_id,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
@@ -136,7 +185,7 @@ async def create_org(
|
||||
# Retrieve credits from LiteLLM
|
||||
credits = await OrgService.get_org_credits(user_id, org.id)
|
||||
|
||||
return OrgResponse.from_org(org, credits=credits)
|
||||
return OrgResponse.from_org(org, credits=credits, user_id=user_id)
|
||||
except OrgNameExistsError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
@@ -171,26 +220,218 @@ async def create_org(
|
||||
)
|
||||
|
||||
|
||||
@org_router.get(
|
||||
'/llm',
|
||||
response_model=OrgLLMSettingsResponse,
|
||||
dependencies=[Depends(require_permission(Permission.VIEW_LLM_SETTINGS))],
|
||||
)
|
||||
async def get_org_llm_settings(
|
||||
service: OrgLLMSettingsService = org_llm_settings_service_dependency,
|
||||
) -> OrgLLMSettingsResponse:
|
||||
"""Get LLM settings for the user's current organization.
|
||||
|
||||
This endpoint retrieves the LLM configuration settings for the
|
||||
authenticated user's current organization. All organization members
|
||||
can view these settings.
|
||||
|
||||
Args:
|
||||
service: OrgLLMSettingsService (injected by dependency)
|
||||
|
||||
Returns:
|
||||
OrgLLMSettingsResponse: The organization's LLM settings
|
||||
|
||||
Raises:
|
||||
HTTPException: 401 if not authenticated
|
||||
HTTPException: 403 if not a member of any organization
|
||||
HTTPException: 404 if current organization not found
|
||||
HTTPException: 500 if retrieval fails
|
||||
"""
|
||||
try:
|
||||
return await service.get_org_llm_settings()
|
||||
except OrgNotFoundError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=str(e),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'Error getting organization LLM settings',
|
||||
extra={'error': str(e)},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to retrieve LLM settings',
|
||||
)
|
||||
|
||||
|
||||
@org_router.post(
|
||||
'/llm',
|
||||
response_model=OrgLLMSettingsResponse,
|
||||
dependencies=[Depends(require_permission(Permission.EDIT_LLM_SETTINGS))],
|
||||
)
|
||||
async def update_org_llm_settings(
|
||||
settings: OrgLLMSettingsUpdate,
|
||||
service: OrgLLMSettingsService = org_llm_settings_service_dependency,
|
||||
) -> OrgLLMSettingsResponse:
|
||||
"""Update LLM settings for the user's current organization.
|
||||
|
||||
This endpoint updates the LLM configuration settings for the
|
||||
authenticated user's current organization. Only admins and owners
|
||||
can update these settings.
|
||||
|
||||
Args:
|
||||
settings: The LLM settings to update (only non-None fields are updated)
|
||||
service: OrgLLMSettingsService (injected by dependency)
|
||||
|
||||
Returns:
|
||||
OrgLLMSettingsResponse: The updated organization's LLM settings
|
||||
|
||||
Raises:
|
||||
HTTPException: 401 if not authenticated
|
||||
HTTPException: 403 if user lacks EDIT_LLM_SETTINGS permission
|
||||
HTTPException: 404 if current organization not found
|
||||
HTTPException: 500 if update fails
|
||||
"""
|
||||
try:
|
||||
return await service.update_org_llm_settings(settings)
|
||||
except OrgNotFoundError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=str(e),
|
||||
)
|
||||
except OrgDatabaseError as e:
|
||||
logger.error(
|
||||
'Database error updating LLM settings',
|
||||
extra={'error': str(e)},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to update LLM settings',
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'Error updating organization LLM settings',
|
||||
extra={'error': str(e)},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to update LLM settings',
|
||||
)
|
||||
|
||||
|
||||
@org_router.get(
|
||||
'/app',
|
||||
response_model=OrgAppSettingsResponse,
|
||||
dependencies=[Depends(require_permission(Permission.MANAGE_APPLICATION_SETTINGS))],
|
||||
)
|
||||
async def get_org_app_settings(
|
||||
service: OrgAppSettingsService = org_app_settings_service_dependency,
|
||||
) -> OrgAppSettingsResponse:
|
||||
"""Get organization app settings for the user's current organization.
|
||||
|
||||
This endpoint retrieves application settings for the authenticated user's
|
||||
current organization. Access requires the MANAGE_APPLICATION_SETTINGS permission,
|
||||
which is granted to all organization members (member, admin, and owner roles).
|
||||
|
||||
Args:
|
||||
service: OrgAppSettingsService (injected by dependency)
|
||||
|
||||
Returns:
|
||||
OrgAppSettingsResponse: The organization app settings
|
||||
|
||||
Raises:
|
||||
HTTPException: 401 if user is not authenticated
|
||||
HTTPException: 403 if user lacks MANAGE_APPLICATION_SETTINGS permission
|
||||
HTTPException: 404 if current organization not found
|
||||
"""
|
||||
try:
|
||||
return await service.get_org_app_settings()
|
||||
except OrgNotFoundError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail='Current organization not found',
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'Unexpected error retrieving organization app settings',
|
||||
extra={'error': str(e)},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='An unexpected error occurred',
|
||||
)
|
||||
|
||||
|
||||
@org_router.post(
|
||||
'/app',
|
||||
response_model=OrgAppSettingsResponse,
|
||||
dependencies=[Depends(require_permission(Permission.MANAGE_APPLICATION_SETTINGS))],
|
||||
)
|
||||
async def update_org_app_settings(
|
||||
update_data: OrgAppSettingsUpdate,
|
||||
service: OrgAppSettingsService = org_app_settings_service_dependency,
|
||||
) -> OrgAppSettingsResponse:
|
||||
"""Update organization app settings for the user's current organization.
|
||||
|
||||
This endpoint updates application settings for the authenticated user's
|
||||
current organization. Access requires the MANAGE_APPLICATION_SETTINGS permission,
|
||||
which is granted to all organization members (member, admin, and owner roles).
|
||||
|
||||
Args:
|
||||
update_data: App settings update data
|
||||
service: OrgAppSettingsService (injected by dependency)
|
||||
|
||||
Returns:
|
||||
OrgAppSettingsResponse: The updated organization app settings
|
||||
|
||||
Raises:
|
||||
HTTPException: 401 if user is not authenticated
|
||||
HTTPException: 403 if user lacks MANAGE_APPLICATION_SETTINGS permission
|
||||
HTTPException: 404 if current organization not found
|
||||
HTTPException: 422 if validation errors occur (handled by FastAPI)
|
||||
HTTPException: 500 if update fails
|
||||
"""
|
||||
try:
|
||||
return await service.update_org_app_settings(update_data)
|
||||
except OrgNotFoundError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail='Current organization not found',
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'Unexpected error updating organization app settings',
|
||||
extra={'error': str(e)},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='An unexpected error occurred',
|
||||
)
|
||||
|
||||
|
||||
@org_router.get('/{org_id}', response_model=OrgResponse, status_code=status.HTTP_200_OK)
|
||||
async def get_org(
|
||||
org_id: UUID,
|
||||
user_id: str = Depends(get_user_id),
|
||||
user_id: str = Depends(require_permission(Permission.VIEW_ORG_SETTINGS)),
|
||||
) -> OrgResponse:
|
||||
"""Get organization details by ID.
|
||||
|
||||
This endpoint allows authenticated users who are members of an organization
|
||||
to retrieve its details. Only members of the organization can access this endpoint.
|
||||
This endpoint retrieves details for a specific organization. Access requires
|
||||
the VIEW_ORG_SETTINGS permission, which is granted to all organization members
|
||||
(member, admin, and owner roles).
|
||||
|
||||
Args:
|
||||
org_id: Organization ID (UUID)
|
||||
user_id: Authenticated user ID (injected by dependency)
|
||||
user_id: Authenticated user ID (injected by require_permission dependency)
|
||||
|
||||
Returns:
|
||||
OrgResponse: The organization details
|
||||
|
||||
Raises:
|
||||
HTTPException: 401 if user is not authenticated
|
||||
HTTPException: 403 if user lacks VIEW_ORG_SETTINGS permission
|
||||
HTTPException: 404 if organization not found
|
||||
HTTPException: 422 if org_id is not a valid UUID (handled by FastAPI)
|
||||
HTTPException: 404 if organization not found or user is not a member
|
||||
HTTPException: 500 if retrieval fails
|
||||
"""
|
||||
logger.info(
|
||||
@@ -211,7 +452,7 @@ async def get_org(
|
||||
# Retrieve credits from LiteLLM
|
||||
credits = await OrgService.get_org_credits(user_id, org.id)
|
||||
|
||||
return OrgResponse.from_org(org, credits=credits)
|
||||
return OrgResponse.from_org(org, credits=credits, user_id=user_id)
|
||||
except OrgNotFoundError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
@@ -228,26 +469,86 @@ async def get_org(
|
||||
)
|
||||
|
||||
|
||||
@org_router.get('/{org_id}/me', response_model=MeResponse)
|
||||
async def get_me(
|
||||
org_id: UUID,
|
||||
user_id: str = Depends(get_user_id),
|
||||
) -> MeResponse:
|
||||
"""Get the current user's membership record for an organization.
|
||||
|
||||
Returns the authenticated user's role, status, email, and LLM override
|
||||
fields (with masked API keys) within the specified organization.
|
||||
|
||||
Args:
|
||||
org_id: Organization ID (UUID)
|
||||
user_id: Authenticated user ID (injected by dependency)
|
||||
|
||||
Returns:
|
||||
MeResponse: The user's membership data
|
||||
|
||||
Raises:
|
||||
HTTPException: 404 if user is not a member or org doesn't exist
|
||||
HTTPException: 500 if retrieval fails
|
||||
"""
|
||||
logger.info(
|
||||
'Retrieving current member details',
|
||||
extra={'user_id': user_id, 'org_id': str(org_id)},
|
||||
)
|
||||
|
||||
try:
|
||||
user_uuid = UUID(user_id)
|
||||
return OrgMemberService.get_me(org_id, user_uuid)
|
||||
|
||||
except OrgMemberNotFoundError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f'Organization with id "{org_id}" not found',
|
||||
)
|
||||
except RoleNotFoundError as e:
|
||||
logger.exception(
|
||||
'Role not found for org member',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
'role_id': e.role_id,
|
||||
},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='An unexpected error occurred',
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'Unexpected error retrieving member details',
|
||||
extra={'user_id': user_id, 'org_id': str(org_id), 'error': str(e)},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='An unexpected error occurred',
|
||||
)
|
||||
|
||||
|
||||
@org_router.delete('/{org_id}', status_code=status.HTTP_200_OK)
|
||||
async def delete_org(
|
||||
org_id: UUID,
|
||||
user_id: str = Depends(get_admin_user_id),
|
||||
user_id: str = Depends(require_permission(Permission.DELETE_ORGANIZATION)),
|
||||
) -> dict:
|
||||
"""Delete an organization.
|
||||
|
||||
This endpoint allows authenticated organization owners to delete their organization.
|
||||
All associated data including organization members, conversations, billing data,
|
||||
and external LiteLLM team resources will be permanently removed.
|
||||
This endpoint permanently deletes an organization and all associated data including
|
||||
organization members, conversations, billing data, and external LiteLLM team resources.
|
||||
Access requires the DELETE_ORGANIZATION permission, which is granted only to owners.
|
||||
|
||||
Args:
|
||||
org_id: Organization ID to delete
|
||||
user_id: Authenticated user ID (injected by dependency)
|
||||
org_id: Organization ID to delete (UUID)
|
||||
user_id: Authenticated user ID (injected by require_permission dependency)
|
||||
|
||||
Returns:
|
||||
dict: Confirmation message with deleted organization details
|
||||
|
||||
Raises:
|
||||
HTTPException: 403 if user is not the organization owner
|
||||
HTTPException: 401 if user is not authenticated
|
||||
HTTPException: 403 if user lacks DELETE_ORGANIZATION permission
|
||||
HTTPException: 404 if organization not found
|
||||
HTTPException: 500 if deletion fails
|
||||
"""
|
||||
@@ -303,6 +604,19 @@ async def delete_org(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=str(e),
|
||||
)
|
||||
except OrphanedUserError as e:
|
||||
logger.warning(
|
||||
'Cannot delete organization: users would be orphaned',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
'orphaned_users': e.user_ids,
|
||||
},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e),
|
||||
)
|
||||
except OrgDatabaseError as e:
|
||||
logger.error(
|
||||
'Database error during organization deletion',
|
||||
@@ -327,25 +641,26 @@ async def delete_org(
|
||||
async def update_org(
|
||||
org_id: UUID,
|
||||
update_data: OrgUpdate,
|
||||
user_id: str = Depends(get_user_id),
|
||||
user_id: str = Depends(require_permission(Permission.EDIT_ORG_SETTINGS)),
|
||||
) -> OrgResponse:
|
||||
"""Update an existing organization.
|
||||
|
||||
This endpoint allows authenticated users to update organization settings.
|
||||
LLM-related settings require admin or owner role in the organization.
|
||||
This endpoint updates organization settings. Access requires the EDIT_ORG_SETTINGS
|
||||
permission, which is granted to admin and owner roles.
|
||||
|
||||
Args:
|
||||
org_id: Organization ID to update (UUID validated by FastAPI)
|
||||
org_id: Organization ID to update (UUID)
|
||||
update_data: Organization update data
|
||||
user_id: Authenticated user ID (injected by dependency)
|
||||
user_id: Authenticated user ID (injected by require_permission dependency)
|
||||
|
||||
Returns:
|
||||
OrgResponse: The updated organization details
|
||||
|
||||
Raises:
|
||||
HTTPException: 400 if org_id is invalid UUID format (handled by FastAPI)
|
||||
HTTPException: 403 if user lacks permission for LLM settings
|
||||
HTTPException: 401 if user is not authenticated
|
||||
HTTPException: 403 if user lacks EDIT_ORG_SETTINGS permission
|
||||
HTTPException: 404 if organization not found
|
||||
HTTPException: 409 if organization name already exists
|
||||
HTTPException: 422 if validation errors occur (handled by FastAPI)
|
||||
HTTPException: 500 if update fails
|
||||
"""
|
||||
@@ -368,7 +683,7 @@ async def update_org(
|
||||
# Retrieve credits from LiteLLM (following same pattern as create endpoint)
|
||||
credits = await OrgService.get_org_credits(user_id, updated_org.id)
|
||||
|
||||
return OrgResponse.from_org(updated_org, credits=credits)
|
||||
return OrgResponse.from_org(updated_org, credits=credits, user_id=user_id)
|
||||
|
||||
except ValueError as e:
|
||||
# Organization not found
|
||||
@@ -376,6 +691,11 @@ async def update_org(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=str(e),
|
||||
)
|
||||
except OrgNameExistsError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail=str(e),
|
||||
)
|
||||
except PermissionError as e:
|
||||
# User lacks permission for LLM settings
|
||||
raise HTTPException(
|
||||
@@ -400,3 +720,384 @@ async def update_org(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='An unexpected error occurred',
|
||||
)
|
||||
|
||||
|
||||
@org_router.get('/{org_id}/members')
|
||||
async def get_org_members(
|
||||
org_id: UUID,
|
||||
page_id: Annotated[
|
||||
str | None,
|
||||
Query(title='Optional page offset for pagination'),
|
||||
] = None,
|
||||
limit: Annotated[
|
||||
int,
|
||||
Query(
|
||||
title='The max number of results in the page',
|
||||
gt=0,
|
||||
lte=100,
|
||||
),
|
||||
] = 10,
|
||||
email: Annotated[
|
||||
str | None,
|
||||
Query(
|
||||
title='Filter members by email (case-insensitive partial match)',
|
||||
min_length=1,
|
||||
max_length=255,
|
||||
),
|
||||
] = None,
|
||||
user_id: str = Depends(require_permission(Permission.VIEW_ORG_SETTINGS)),
|
||||
) -> OrgMemberPage:
|
||||
"""Get all members of an organization with pagination and optional email filter.
|
||||
|
||||
This endpoint retrieves a paginated list of organization members. Access requires
|
||||
the VIEW_ORG_SETTINGS permission, which is granted to all organization members
|
||||
(member, admin, and owner roles).
|
||||
|
||||
Args:
|
||||
org_id: Organization ID (UUID)
|
||||
page_id: Optional page offset for pagination
|
||||
limit: Maximum number of members to return (1-100, default 10)
|
||||
email: Optional email filter (case-insensitive partial match)
|
||||
user_id: Authenticated user ID (injected by require_permission dependency)
|
||||
|
||||
Returns:
|
||||
OrgMemberPage: Paginated list of organization members with
|
||||
current_page and per_page metadata. Use the /count endpoint
|
||||
to get the total count separately.
|
||||
|
||||
Raises:
|
||||
HTTPException: 401 if user is not authenticated
|
||||
HTTPException: 403 if user lacks VIEW_ORG_SETTINGS permission
|
||||
HTTPException: 400 if org_id or page_id format is invalid
|
||||
HTTPException: 500 if retrieval fails
|
||||
"""
|
||||
try:
|
||||
success, error_code, data = await OrgMemberService.get_org_members(
|
||||
org_id=org_id,
|
||||
current_user_id=UUID(user_id),
|
||||
page_id=page_id,
|
||||
limit=limit,
|
||||
email_filter=email,
|
||||
)
|
||||
|
||||
if not success:
|
||||
error_map = {
|
||||
'not_a_member': (
|
||||
status.HTTP_403_FORBIDDEN,
|
||||
'You are not a member of this organization',
|
||||
),
|
||||
'invalid_page_id': (
|
||||
status.HTTP_400_BAD_REQUEST,
|
||||
'Invalid page_id format',
|
||||
),
|
||||
}
|
||||
status_code, detail = error_map.get(
|
||||
error_code, (status.HTTP_500_INTERNAL_SERVER_ERROR, 'An error occurred')
|
||||
)
|
||||
raise HTTPException(status_code=status_code, detail=detail)
|
||||
|
||||
if data is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to retrieve members',
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except ValueError:
|
||||
logger.exception('Invalid UUID format')
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail='Invalid organization ID format',
|
||||
)
|
||||
except Exception:
|
||||
logger.exception('Error retrieving organization members')
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to retrieve members',
|
||||
)
|
||||
|
||||
|
||||
@org_router.get('/{org_id}/members/count')
|
||||
async def get_org_members_count(
|
||||
org_id: UUID,
|
||||
email: Annotated[
|
||||
str | None,
|
||||
Query(
|
||||
title='Filter members by email (case-insensitive partial match)',
|
||||
min_length=1,
|
||||
max_length=255,
|
||||
),
|
||||
] = None,
|
||||
user_id: str = Depends(require_permission(Permission.VIEW_ORG_SETTINGS)),
|
||||
) -> int:
|
||||
"""Get count of organization members with optional email filter.
|
||||
|
||||
This endpoint returns the total count of organization members matching
|
||||
the filter criteria. Access requires the VIEW_ORG_SETTINGS permission,
|
||||
which is granted to all organization members (member, admin, and owner roles).
|
||||
|
||||
Args:
|
||||
org_id: Organization ID (UUID)
|
||||
email: Optional email filter (case-insensitive partial match)
|
||||
user_id: Authenticated user ID (injected by require_permission dependency)
|
||||
|
||||
Returns:
|
||||
int: Total count of organization members matching the filter
|
||||
|
||||
Raises:
|
||||
HTTPException: 401 if user is not authenticated
|
||||
HTTPException: 403 if user lacks VIEW_ORG_SETTINGS permission or is not a member
|
||||
HTTPException: 400 if org_id format is invalid
|
||||
HTTPException: 500 if retrieval fails
|
||||
"""
|
||||
try:
|
||||
return await OrgMemberService.get_org_members_count(
|
||||
org_id=org_id,
|
||||
current_user_id=UUID(user_id),
|
||||
email_filter=email,
|
||||
)
|
||||
except OrgMemberNotFoundError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail='You are not a member of this organization',
|
||||
)
|
||||
except ValueError:
|
||||
logger.exception('Invalid UUID format')
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail='Invalid organization ID format',
|
||||
)
|
||||
except Exception:
|
||||
logger.exception('Error retrieving organization member count')
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to retrieve member count',
|
||||
)
|
||||
|
||||
|
||||
@org_router.delete('/{org_id}/members/{user_id}')
|
||||
async def remove_org_member(
|
||||
org_id: UUID,
|
||||
user_id: str,
|
||||
current_user_id: str = Depends(get_user_id),
|
||||
):
|
||||
"""Remove a member from an organization.
|
||||
|
||||
Only owners and admins can remove members:
|
||||
- Owners can remove admins and regular users
|
||||
- Admins can only remove regular users
|
||||
|
||||
Users cannot remove themselves. The last owner cannot be removed.
|
||||
"""
|
||||
try:
|
||||
success, error = await OrgMemberService.remove_org_member(
|
||||
org_id=org_id,
|
||||
target_user_id=UUID(user_id),
|
||||
current_user_id=UUID(current_user_id),
|
||||
)
|
||||
|
||||
if not success:
|
||||
error_map = {
|
||||
'not_a_member': (
|
||||
status.HTTP_403_FORBIDDEN,
|
||||
'You are not a member of this organization',
|
||||
),
|
||||
'cannot_remove_self': (
|
||||
status.HTTP_403_FORBIDDEN,
|
||||
'Cannot remove yourself from an organization',
|
||||
),
|
||||
'member_not_found': (
|
||||
status.HTTP_404_NOT_FOUND,
|
||||
'Member not found in this organization',
|
||||
),
|
||||
'insufficient_permission': (
|
||||
status.HTTP_403_FORBIDDEN,
|
||||
'You do not have permission to remove this member',
|
||||
),
|
||||
'cannot_remove_last_owner': (
|
||||
status.HTTP_400_BAD_REQUEST,
|
||||
'Cannot remove the last owner of an organization',
|
||||
),
|
||||
'removal_failed': (
|
||||
status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
'Failed to remove member',
|
||||
),
|
||||
}
|
||||
status_code, detail = error_map.get(
|
||||
error, (status.HTTP_500_INTERNAL_SERVER_ERROR, 'An error occurred')
|
||||
)
|
||||
raise HTTPException(status_code=status_code, detail=detail)
|
||||
|
||||
return {'message': 'Member removed successfully'}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except ValueError:
|
||||
logger.exception('Invalid UUID format')
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail='Invalid organization or user ID format',
|
||||
)
|
||||
except Exception:
|
||||
logger.exception('Error removing organization member')
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to remove member',
|
||||
)
|
||||
|
||||
|
||||
@org_router.post(
|
||||
'/{org_id}/switch', response_model=OrgResponse, status_code=status.HTTP_200_OK
|
||||
)
|
||||
async def switch_org(
|
||||
org_id: UUID,
|
||||
user_id: str = Depends(get_user_id),
|
||||
) -> OrgResponse:
|
||||
"""Switch to a different organization.
|
||||
|
||||
This endpoint allows authenticated users to switch their current active
|
||||
organization. The user must be a member of the target organization.
|
||||
|
||||
Args:
|
||||
org_id: Organization ID to switch to (UUID)
|
||||
user_id: Authenticated user ID (injected by dependency)
|
||||
|
||||
Returns:
|
||||
OrgResponse: The organization details that was switched to
|
||||
|
||||
Raises:
|
||||
HTTPException: 422 if org_id is not a valid UUID (handled by FastAPI)
|
||||
HTTPException: 403 if user is not a member of the organization
|
||||
HTTPException: 404 if organization not found
|
||||
HTTPException: 500 if switch fails
|
||||
"""
|
||||
logger.info(
|
||||
'Switching organization',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
# Use service layer to switch organization with membership validation
|
||||
org = await OrgService.switch_org(
|
||||
user_id=user_id,
|
||||
org_id=org_id,
|
||||
)
|
||||
|
||||
# Retrieve credits from LiteLLM for the new current org
|
||||
credits = await OrgService.get_org_credits(user_id, org.id)
|
||||
|
||||
return OrgResponse.from_org(org, credits=credits, user_id=user_id)
|
||||
|
||||
except OrgNotFoundError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=str(e),
|
||||
)
|
||||
except OrgAuthorizationError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=str(e),
|
||||
)
|
||||
except OrgDatabaseError as e:
|
||||
logger.error(
|
||||
'Database operation failed during organization switch',
|
||||
extra={'user_id': user_id, 'org_id': str(org_id), 'error': str(e)},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to switch organization',
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'Unexpected error switching organization',
|
||||
extra={'user_id': user_id, 'org_id': str(org_id), 'error': str(e)},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='An unexpected error occurred',
|
||||
)
|
||||
|
||||
|
||||
@org_router.patch('/{org_id}/members/{user_id}', response_model=OrgMemberResponse)
|
||||
async def update_org_member(
|
||||
org_id: UUID,
|
||||
user_id: str,
|
||||
update_data: OrgMemberUpdate,
|
||||
current_user_id: str = Depends(get_user_id),
|
||||
) -> OrgMemberResponse:
|
||||
"""Update a member's role in an organization.
|
||||
|
||||
Permission rules:
|
||||
- Admins can change roles of regular members to Admin or Member
|
||||
- Admins cannot modify other Admins or Owners
|
||||
- Owners can change roles of Admins and Members to any role (Owner, Admin, Member)
|
||||
- Owners cannot modify other Owners
|
||||
|
||||
Members cannot modify their own role. The last owner cannot be demoted.
|
||||
"""
|
||||
try:
|
||||
return await OrgMemberService.update_org_member(
|
||||
org_id=org_id,
|
||||
target_user_id=UUID(user_id),
|
||||
current_user_id=UUID(current_user_id),
|
||||
update_data=update_data,
|
||||
)
|
||||
except OrgMemberNotFoundError as e:
|
||||
# Distinguish between requester not being a member vs target not found
|
||||
if str(current_user_id) in str(e):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail='You are not a member of this organization',
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail='Member not found in this organization',
|
||||
)
|
||||
except CannotModifySelfError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail='Cannot modify your own role',
|
||||
)
|
||||
except RoleNotFoundError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Role configuration error',
|
||||
)
|
||||
except InvalidRoleError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail='Invalid role specified',
|
||||
)
|
||||
except InsufficientPermissionError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail='You do not have permission to modify this member',
|
||||
)
|
||||
except LastOwnerError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail='Cannot demote the last owner of an organization',
|
||||
)
|
||||
except MemberUpdateError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to update member',
|
||||
)
|
||||
except ValueError:
|
||||
logger.exception('Invalid UUID format')
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail='Invalid organization or user ID format',
|
||||
)
|
||||
except Exception:
|
||||
logger.exception('Error updating organization member')
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to update member',
|
||||
)
|
||||
|
||||
@@ -4,6 +4,8 @@ from fastapi import APIRouter, Depends, Query, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import SecretStr
|
||||
from server.auth.token_manager import TokenManager
|
||||
from storage.user_store import UserStore
|
||||
from utils.identity import resolve_display_name
|
||||
|
||||
from openhands.integrations.provider import (
|
||||
PROVIDER_TOKEN_TYPE,
|
||||
@@ -114,13 +116,23 @@ async def saas_get_user(
|
||||
content='Failed to retrieve user_info.',
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
)
|
||||
# Prefer email from DB; fall back to Keycloak if not yet persisted
|
||||
email = user_info.get('email') if user_info else None
|
||||
sub = user_info.get('sub') if user_info else ''
|
||||
if sub:
|
||||
db_user = await UserStore.get_user_by_id_async(sub)
|
||||
if db_user and db_user.email is not None:
|
||||
email = db_user.email
|
||||
|
||||
retval = await _check_idp(
|
||||
access_token=access_token,
|
||||
default_value=User(
|
||||
id=(user_info.get('sub') if user_info else '') or '',
|
||||
id=sub,
|
||||
login=(user_info.get('preferred_username') if user_info else '') or '',
|
||||
avatar_url='',
|
||||
email=user_info.get('email') if user_info else None,
|
||||
email=email,
|
||||
name=resolve_display_name(user_info) if user_info else None,
|
||||
company=user_info.get('company') if user_info else None,
|
||||
),
|
||||
user_info=user_info,
|
||||
)
|
||||
|
||||
115
enterprise/server/routes/user_app_settings.py
Normal file
115
enterprise/server/routes/user_app_settings.py
Normal file
@@ -0,0 +1,115 @@
|
||||
"""Routes for user app settings API.
|
||||
|
||||
Provides endpoints for managing user-level app preferences:
|
||||
- GET /api/users/app - Retrieve current user's app settings
|
||||
- POST /api/users/app - Update current user's app settings
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from server.routes.user_app_settings_models import (
|
||||
UserAppSettingsResponse,
|
||||
UserAppSettingsUpdate,
|
||||
UserNotFoundError,
|
||||
)
|
||||
from server.services.user_app_settings_service import (
|
||||
UserAppSettingsService,
|
||||
UserAppSettingsServiceInjector,
|
||||
)
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
user_app_settings_router = APIRouter(prefix='/api/users')
|
||||
|
||||
# Create injector instance and dependency at module level
|
||||
_injector = UserAppSettingsServiceInjector()
|
||||
user_app_settings_service_dependency = Depends(_injector.depends)
|
||||
|
||||
|
||||
@user_app_settings_router.get('/app', response_model=UserAppSettingsResponse)
|
||||
async def get_user_app_settings(
|
||||
service: UserAppSettingsService = user_app_settings_service_dependency,
|
||||
) -> UserAppSettingsResponse:
|
||||
"""Get the current user's app settings.
|
||||
|
||||
Returns language, analytics consent, sound notifications, and git config.
|
||||
|
||||
Args:
|
||||
service: UserAppSettingsService (injected by dependency)
|
||||
|
||||
Returns:
|
||||
UserAppSettingsResponse: The user's app settings
|
||||
|
||||
Raises:
|
||||
HTTPException: 401 if user is not authenticated
|
||||
HTTPException: 404 if user not found
|
||||
HTTPException: 500 if retrieval fails
|
||||
"""
|
||||
try:
|
||||
return await service.get_user_app_settings()
|
||||
|
||||
except ValueError as e:
|
||||
# User not authenticated
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=str(e),
|
||||
)
|
||||
except UserNotFoundError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=str(e),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'Unexpected error retrieving user app settings',
|
||||
extra={'error': str(e)},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to retrieve user app settings',
|
||||
)
|
||||
|
||||
|
||||
@user_app_settings_router.post('/app', response_model=UserAppSettingsResponse)
|
||||
async def update_user_app_settings(
|
||||
update_data: UserAppSettingsUpdate,
|
||||
service: UserAppSettingsService = user_app_settings_service_dependency,
|
||||
) -> UserAppSettingsResponse:
|
||||
"""Update the current user's app settings (partial update).
|
||||
|
||||
Only provided fields will be updated. Pass null to clear a field.
|
||||
|
||||
Args:
|
||||
update_data: Fields to update
|
||||
service: UserAppSettingsService (injected by dependency)
|
||||
|
||||
Returns:
|
||||
UserAppSettingsResponse: The updated user's app settings
|
||||
|
||||
Raises:
|
||||
HTTPException: 401 if user is not authenticated
|
||||
HTTPException: 404 if user not found
|
||||
HTTPException: 500 if update fails
|
||||
"""
|
||||
try:
|
||||
return await service.update_user_app_settings(update_data)
|
||||
|
||||
except ValueError as e:
|
||||
# User not authenticated
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=str(e),
|
||||
)
|
||||
except UserNotFoundError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=str(e),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'Failed to update user app settings',
|
||||
extra={'error': str(e)},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to update user app settings',
|
||||
)
|
||||
57
enterprise/server/routes/user_app_settings_models.py
Normal file
57
enterprise/server/routes/user_app_settings_models.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""
|
||||
Pydantic models for user app settings API.
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel, EmailStr
|
||||
from storage.user import User
|
||||
|
||||
|
||||
class UserAppSettingsError(Exception):
|
||||
"""Base exception for user app settings errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class UserNotFoundError(UserAppSettingsError):
|
||||
"""Raised when user is not found."""
|
||||
|
||||
def __init__(self, user_id: str):
|
||||
self.user_id = user_id
|
||||
super().__init__(f'User with id "{user_id}" not found')
|
||||
|
||||
|
||||
class UserAppSettingsUpdateError(UserAppSettingsError):
|
||||
"""Raised when user app settings update fails."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class UserAppSettingsResponse(BaseModel):
|
||||
"""Response model for user app settings."""
|
||||
|
||||
language: str | None = None
|
||||
user_consents_to_analytics: bool | None = None
|
||||
enable_sound_notifications: bool | None = None
|
||||
git_user_name: str | None = None
|
||||
git_user_email: EmailStr | None = None
|
||||
|
||||
@classmethod
|
||||
def from_user(cls, user: User) -> 'UserAppSettingsResponse':
|
||||
"""Create response from User entity."""
|
||||
return cls(
|
||||
language=user.language,
|
||||
user_consents_to_analytics=user.user_consents_to_analytics,
|
||||
enable_sound_notifications=user.enable_sound_notifications,
|
||||
git_user_name=user.git_user_name,
|
||||
git_user_email=user.git_user_email,
|
||||
)
|
||||
|
||||
|
||||
class UserAppSettingsUpdate(BaseModel):
|
||||
"""Request model for updating user app settings (partial update)."""
|
||||
|
||||
language: str | None = None
|
||||
user_consents_to_analytics: bool | None = None
|
||||
enable_sound_notifications: bool | None = None
|
||||
git_user_name: str | None = None
|
||||
git_user_email: EmailStr | None = None
|
||||
@@ -1139,6 +1139,71 @@ class SaasNestedConversationManager(ConversationManager):
|
||||
}
|
||||
update_conversation_metadata(conversation_id, metadata_content)
|
||||
|
||||
async def list_files(self, sid: str, path: str | None = None) -> list[str]:
|
||||
"""List files in the workspace for a conversation.
|
||||
|
||||
Delegates to the nested container's list-files endpoint.
|
||||
|
||||
Args:
|
||||
sid: The session/conversation ID.
|
||||
path: Optional path to list files from. If None, lists from workspace root.
|
||||
|
||||
Returns:
|
||||
A list of file paths.
|
||||
|
||||
Raises:
|
||||
ValueError: If the conversation is not running.
|
||||
httpx.HTTPError: If there's an error communicating with the nested runtime.
|
||||
"""
|
||||
runtime = await self._get_runtime(sid)
|
||||
if runtime is None or runtime.get('status') != 'running':
|
||||
raise ValueError(f'Conversation {sid} is not running')
|
||||
|
||||
nested_url = self._get_nested_url_for_runtime(runtime['runtime_id'], sid)
|
||||
session_api_key = runtime.get('session_api_key')
|
||||
|
||||
return await self._fetch_list_files_from_nested(
|
||||
sid, nested_url, session_api_key, path
|
||||
)
|
||||
|
||||
async def select_file(self, sid: str, file: str) -> tuple[str | None, str | None]:
|
||||
"""Read a file from the workspace via nested container.
|
||||
|
||||
Raises:
|
||||
ValueError: If the conversation is not running.
|
||||
httpx.HTTPError: If there's an error communicating with the nested runtime.
|
||||
"""
|
||||
runtime = await self._get_runtime(sid)
|
||||
if runtime is None or runtime.get('status') != 'running':
|
||||
raise ValueError(f'Conversation {sid} is not running')
|
||||
|
||||
nested_url = self._get_nested_url_for_runtime(runtime['runtime_id'], sid)
|
||||
session_api_key = runtime.get('session_api_key')
|
||||
|
||||
return await self._fetch_select_file_from_nested(
|
||||
sid, nested_url, session_api_key, file
|
||||
)
|
||||
|
||||
async def upload_files(
|
||||
self, sid: str, files: list[tuple[str, bytes]]
|
||||
) -> tuple[list[str], list[dict[str, str]]]:
|
||||
"""Upload files to the workspace via nested container.
|
||||
|
||||
Raises:
|
||||
ValueError: If the conversation is not running.
|
||||
httpx.HTTPError: If there's an error communicating with the nested runtime.
|
||||
"""
|
||||
runtime = await self._get_runtime(sid)
|
||||
if runtime is None or runtime.get('status') != 'running':
|
||||
raise ValueError(f'Conversation {sid} is not running')
|
||||
|
||||
nested_url = self._get_nested_url_for_runtime(runtime['runtime_id'], sid)
|
||||
session_api_key = runtime.get('session_api_key')
|
||||
|
||||
return await self._fetch_upload_files_to_nested(
|
||||
sid, nested_url, session_api_key, files
|
||||
)
|
||||
|
||||
|
||||
def _last_updated_at_key(conversation: ConversationMetadata) -> float:
|
||||
last_updated_at = conversation.last_updated_at
|
||||
|
||||
131
enterprise/server/services/email_service.py
Normal file
131
enterprise/server/services/email_service.py
Normal file
@@ -0,0 +1,131 @@
|
||||
"""Email service for sending transactional emails via Resend."""
|
||||
|
||||
import os
|
||||
|
||||
try:
|
||||
import resend
|
||||
|
||||
RESEND_AVAILABLE = True
|
||||
except ImportError:
|
||||
RESEND_AVAILABLE = False
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
DEFAULT_FROM_EMAIL = 'OpenHands <no-reply@openhands.dev>'
|
||||
DEFAULT_WEB_HOST = 'https://app.all-hands.dev'
|
||||
|
||||
|
||||
class EmailService:
|
||||
"""Service for sending transactional emails."""
|
||||
|
||||
@staticmethod
|
||||
def _get_resend_client() -> bool:
|
||||
"""Initialize and return the Resend client.
|
||||
|
||||
Returns:
|
||||
bool: True if client is ready, False otherwise
|
||||
"""
|
||||
if not RESEND_AVAILABLE:
|
||||
logger.warning('Resend library not installed, skipping email')
|
||||
return False
|
||||
|
||||
resend_api_key = os.environ.get('RESEND_API_KEY')
|
||||
if not resend_api_key:
|
||||
logger.warning('RESEND_API_KEY not configured, skipping email')
|
||||
return False
|
||||
|
||||
resend.api_key = resend_api_key
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def send_invitation_email(
|
||||
to_email: str,
|
||||
org_name: str,
|
||||
inviter_name: str,
|
||||
role_name: str,
|
||||
invitation_token: str,
|
||||
invitation_id: int,
|
||||
) -> None:
|
||||
"""Send an organization invitation email.
|
||||
|
||||
Args:
|
||||
to_email: Recipient's email address
|
||||
org_name: Name of the organization
|
||||
inviter_name: Display name of the person who sent the invite
|
||||
role_name: Role being offered (e.g., 'member', 'admin')
|
||||
invitation_token: The secure invitation token
|
||||
invitation_id: The invitation ID for logging
|
||||
"""
|
||||
if not EmailService._get_resend_client():
|
||||
return
|
||||
|
||||
# Build invitation URL
|
||||
web_host = os.environ.get('WEB_HOST', DEFAULT_WEB_HOST)
|
||||
invitation_url = f'{web_host}/api/organizations/members/invite/accept?token={invitation_token}'
|
||||
|
||||
from_email = os.environ.get('RESEND_FROM_EMAIL', DEFAULT_FROM_EMAIL)
|
||||
|
||||
params = {
|
||||
'from': from_email,
|
||||
'to': [to_email],
|
||||
'subject': f"You're invited to join {org_name} on OpenHands",
|
||||
'html': f"""
|
||||
<div style="font-family: Arial, sans-serif; max-width: 600px; margin: 0 auto;">
|
||||
<p>Hi,</p>
|
||||
|
||||
<p><strong>{inviter_name}</strong> has invited you to join <strong>{org_name}</strong> on OpenHands as a <strong>{role_name}</strong>.</p>
|
||||
|
||||
<p>Click the button below to accept the invitation:</p>
|
||||
|
||||
<p style="margin: 30px 0;">
|
||||
<a href="{invitation_url}"
|
||||
style="background-color: #c9b974; color: #0D0F11; padding: 8px 16px;
|
||||
text-decoration: none; border-radius: 8px; display: inline-block;
|
||||
font-size: 14px; font-weight: 600;">
|
||||
Accept Invitation
|
||||
</a>
|
||||
</p>
|
||||
|
||||
<p style="color: #666; font-size: 14px;">
|
||||
Or copy and paste this link into your browser:<br>
|
||||
<a href="{invitation_url}" style="color: #c9b974; font-weight: 600;">{invitation_url}</a>
|
||||
</p>
|
||||
|
||||
<p style="color: #666; font-size: 14px;">
|
||||
This invitation will expire in 7 days.
|
||||
</p>
|
||||
|
||||
<p style="color: #666; font-size: 14px;">
|
||||
If you weren't expecting this invitation, you can safely ignore this email.
|
||||
</p>
|
||||
|
||||
<hr style="border: none; border-top: 1px solid #eee; margin: 30px 0;">
|
||||
|
||||
<p style="color: #999; font-size: 12px;">
|
||||
Best,<br>
|
||||
The OpenHands Team
|
||||
</p>
|
||||
</div>
|
||||
""",
|
||||
}
|
||||
|
||||
try:
|
||||
response = resend.Emails.send(params)
|
||||
logger.info(
|
||||
'Invitation email sent',
|
||||
extra={
|
||||
'invitation_id': invitation_id,
|
||||
'email': to_email,
|
||||
'response_id': response.get('id') if response else None,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'Failed to send invitation email',
|
||||
extra={
|
||||
'invitation_id': invitation_id,
|
||||
'email': to_email,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
raise
|
||||
130
enterprise/server/services/org_app_settings_service.py
Normal file
130
enterprise/server/services/org_app_settings_service.py
Normal file
@@ -0,0 +1,130 @@
|
||||
"""Service class for managing organization app settings.
|
||||
|
||||
Separates business logic from route handlers.
|
||||
Uses dependency injection for db_session and user_context.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from fastapi import Request
|
||||
from server.routes.org_models import (
|
||||
OrgAppSettingsResponse,
|
||||
OrgAppSettingsUpdate,
|
||||
OrgNotFoundError,
|
||||
)
|
||||
from storage.org_app_settings_store import OrgAppSettingsStore
|
||||
|
||||
from openhands.app_server.services.injector import Injector, InjectorState
|
||||
from openhands.app_server.user.user_context import UserContext
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
|
||||
@dataclass
|
||||
class OrgAppSettingsService:
|
||||
"""Service for organization app settings with injected dependencies."""
|
||||
|
||||
store: OrgAppSettingsStore
|
||||
user_context: UserContext
|
||||
|
||||
async def get_org_app_settings(self) -> OrgAppSettingsResponse:
|
||||
"""Get organization app settings.
|
||||
|
||||
User ID is obtained from the injected user_context.
|
||||
|
||||
Returns:
|
||||
OrgAppSettingsResponse: The organization's app settings
|
||||
|
||||
Raises:
|
||||
OrgNotFoundError: If current organization is not found
|
||||
"""
|
||||
user_id = await self.user_context.get_user_id()
|
||||
|
||||
logger.info(
|
||||
'Getting organization app settings',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
|
||||
org = await self.store.get_current_org_by_user_id(user_id)
|
||||
|
||||
if not org:
|
||||
raise OrgNotFoundError('current')
|
||||
|
||||
return OrgAppSettingsResponse.from_org(org)
|
||||
|
||||
async def update_org_app_settings(
|
||||
self,
|
||||
update_data: OrgAppSettingsUpdate,
|
||||
) -> OrgAppSettingsResponse:
|
||||
"""Update organization app settings.
|
||||
|
||||
Only updates fields that are explicitly provided in update_data.
|
||||
User ID is obtained from the injected user_context.
|
||||
Session auto-commits at request end via DbSessionInjector.
|
||||
|
||||
Args:
|
||||
update_data: The update data from the request
|
||||
|
||||
Returns:
|
||||
OrgAppSettingsResponse: The updated organization's app settings
|
||||
|
||||
Raises:
|
||||
OrgNotFoundError: If current organization is not found
|
||||
"""
|
||||
user_id = await self.user_context.get_user_id()
|
||||
|
||||
logger.info(
|
||||
'Updating organization app settings',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
|
||||
# Get current org first
|
||||
org = await self.store.get_current_org_by_user_id(user_id)
|
||||
|
||||
if not org:
|
||||
raise OrgNotFoundError('current')
|
||||
|
||||
# Check if any fields are provided
|
||||
update_dict = update_data.model_dump(exclude_unset=True)
|
||||
|
||||
if not update_dict:
|
||||
# No fields to update, just return current settings
|
||||
logger.info(
|
||||
'No fields to update in app settings',
|
||||
extra={'user_id': user_id, 'org_id': str(org.id)},
|
||||
)
|
||||
return OrgAppSettingsResponse.from_org(org)
|
||||
|
||||
updated_org = await self.store.update_org_app_settings(
|
||||
org_id=org.id,
|
||||
update_data=update_data,
|
||||
)
|
||||
|
||||
if not updated_org:
|
||||
raise OrgNotFoundError('current')
|
||||
|
||||
logger.info(
|
||||
'Organization app settings updated successfully',
|
||||
extra={'user_id': user_id, 'updated_fields': list(update_dict.keys())},
|
||||
)
|
||||
|
||||
return OrgAppSettingsResponse.from_org(updated_org)
|
||||
|
||||
|
||||
class OrgAppSettingsServiceInjector(Injector[OrgAppSettingsService]):
|
||||
"""Injector that composes store and user_context for OrgAppSettingsService."""
|
||||
|
||||
async def inject(
|
||||
self, state: InjectorState, request: Request | None = None
|
||||
) -> AsyncGenerator[OrgAppSettingsService, None]:
|
||||
# Local imports to avoid circular dependencies
|
||||
from openhands.app_server.config import get_db_session, get_user_context
|
||||
|
||||
async with (
|
||||
get_user_context(state, request) as user_context,
|
||||
get_db_session(state, request) as db_session,
|
||||
):
|
||||
store = OrgAppSettingsStore(db_session=db_session)
|
||||
yield OrgAppSettingsService(store=store, user_context=user_context)
|
||||
397
enterprise/server/services/org_invitation_service.py
Normal file
397
enterprise/server/services/org_invitation_service.py
Normal file
@@ -0,0 +1,397 @@
|
||||
"""Service for managing organization invitations."""
|
||||
|
||||
import asyncio
|
||||
from uuid import UUID
|
||||
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.constants import ROLE_ADMIN, ROLE_OWNER
|
||||
from server.routes.org_invitation_models import (
|
||||
EmailMismatchError,
|
||||
InsufficientPermissionError,
|
||||
InvitationExpiredError,
|
||||
InvitationInvalidError,
|
||||
UserAlreadyMemberError,
|
||||
)
|
||||
from server.services.email_service import EmailService
|
||||
from storage.org_invitation import OrgInvitation
|
||||
from storage.org_invitation_store import OrgInvitationStore
|
||||
from storage.org_member_store import OrgMemberStore
|
||||
from storage.org_service import OrgService
|
||||
from storage.org_store import OrgStore
|
||||
from storage.role_store import RoleStore
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
|
||||
class OrgInvitationService:
|
||||
"""Service for organization invitation operations."""
|
||||
|
||||
@staticmethod
|
||||
async def create_invitation(
|
||||
org_id: UUID,
|
||||
email: str,
|
||||
role_name: str,
|
||||
inviter_id: UUID,
|
||||
) -> OrgInvitation:
|
||||
"""Create a new organization invitation.
|
||||
|
||||
This method:
|
||||
1. Validates the organization exists
|
||||
2. Validates this is not a personal workspace
|
||||
3. Checks inviter has owner/admin role
|
||||
4. Validates role assignment permissions
|
||||
5. Checks if user is already a member
|
||||
6. Creates the invitation
|
||||
7. Sends the invitation email
|
||||
|
||||
Args:
|
||||
org_id: Organization UUID
|
||||
email: Invitee's email address
|
||||
role_name: Role to assign on acceptance (owner, admin, member)
|
||||
inviter_id: User ID of the person creating the invitation
|
||||
|
||||
Returns:
|
||||
OrgInvitation: The created invitation
|
||||
|
||||
Raises:
|
||||
ValueError: If organization or role not found
|
||||
InsufficientPermissionError: If inviter lacks permission
|
||||
UserAlreadyMemberError: If email is already a member
|
||||
InvitationAlreadyExistsError: If pending invitation exists
|
||||
"""
|
||||
email = email.lower().strip()
|
||||
|
||||
logger.info(
|
||||
'Creating organization invitation',
|
||||
extra={
|
||||
'org_id': str(org_id),
|
||||
'email': email,
|
||||
'role_name': role_name,
|
||||
'inviter_id': str(inviter_id),
|
||||
},
|
||||
)
|
||||
|
||||
# Step 1: Validate organization exists
|
||||
org = OrgStore.get_org_by_id(org_id)
|
||||
if not org:
|
||||
raise ValueError(f'Organization {org_id} not found')
|
||||
|
||||
# Step 2: Check this is not a personal workspace
|
||||
# A personal workspace has org_id matching the user's id
|
||||
if str(org_id) == str(inviter_id):
|
||||
raise InsufficientPermissionError(
|
||||
'Cannot invite users to a personal workspace'
|
||||
)
|
||||
|
||||
# Step 3: Check inviter is a member and has permission
|
||||
inviter_member = OrgMemberStore.get_org_member(org_id, inviter_id)
|
||||
if not inviter_member:
|
||||
raise InsufficientPermissionError(
|
||||
'You are not a member of this organization'
|
||||
)
|
||||
|
||||
inviter_role = RoleStore.get_role_by_id(inviter_member.role_id)
|
||||
if not inviter_role or inviter_role.name not in [ROLE_OWNER, ROLE_ADMIN]:
|
||||
raise InsufficientPermissionError('Only owners and admins can invite users')
|
||||
|
||||
# Step 4: Validate role assignment permissions
|
||||
role_name_lower = role_name.lower()
|
||||
if role_name_lower == ROLE_OWNER and inviter_role.name != ROLE_OWNER:
|
||||
raise InsufficientPermissionError('Only owners can invite with owner role')
|
||||
|
||||
# Get the target role
|
||||
target_role = RoleStore.get_role_by_name(role_name_lower)
|
||||
if not target_role:
|
||||
raise ValueError(f'Invalid role: {role_name}')
|
||||
|
||||
# Step 5: Check if user is already a member (by email)
|
||||
existing_user = await UserStore.get_user_by_email_async(email)
|
||||
if existing_user:
|
||||
existing_member = OrgMemberStore.get_org_member(org_id, existing_user.id)
|
||||
if existing_member:
|
||||
raise UserAlreadyMemberError(
|
||||
'User is already a member of this organization'
|
||||
)
|
||||
|
||||
# Step 6: Create the invitation
|
||||
invitation = await OrgInvitationStore.create_invitation(
|
||||
org_id=org_id,
|
||||
email=email,
|
||||
role_id=target_role.id,
|
||||
inviter_id=inviter_id,
|
||||
)
|
||||
|
||||
# Step 7: Send invitation email
|
||||
try:
|
||||
# Get inviter info for the email
|
||||
inviter_user = UserStore.get_user_by_id(str(inviter_member.user_id))
|
||||
inviter_name = 'A team member'
|
||||
if inviter_user and inviter_user.email:
|
||||
inviter_name = inviter_user.email.split('@')[0]
|
||||
|
||||
EmailService.send_invitation_email(
|
||||
to_email=email,
|
||||
org_name=org.name,
|
||||
inviter_name=inviter_name,
|
||||
role_name=target_role.name,
|
||||
invitation_token=invitation.token,
|
||||
invitation_id=invitation.id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'Failed to send invitation email',
|
||||
extra={
|
||||
'invitation_id': invitation.id,
|
||||
'email': email,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
# Don't fail the invitation creation if email fails
|
||||
# The user can still access via direct link
|
||||
|
||||
return invitation
|
||||
|
||||
@staticmethod
|
||||
async def create_invitations_batch(
|
||||
org_id: UUID,
|
||||
emails: list[str],
|
||||
role_name: str,
|
||||
inviter_id: UUID,
|
||||
) -> tuple[list[OrgInvitation], list[tuple[str, str]]]:
|
||||
"""Create multiple organization invitations concurrently.
|
||||
|
||||
Validates permissions once upfront, then creates invitations in parallel.
|
||||
|
||||
Args:
|
||||
org_id: Organization UUID
|
||||
emails: List of invitee email addresses
|
||||
role_name: Role to assign on acceptance (owner, admin, member)
|
||||
inviter_id: User ID of the person creating the invitations
|
||||
|
||||
Returns:
|
||||
Tuple of (successful_invitations, failed_emails_with_errors)
|
||||
|
||||
Raises:
|
||||
ValueError: If organization or role not found
|
||||
InsufficientPermissionError: If inviter lacks permission
|
||||
"""
|
||||
logger.info(
|
||||
'Creating batch organization invitations',
|
||||
extra={
|
||||
'org_id': str(org_id),
|
||||
'email_count': len(emails),
|
||||
'role_name': role_name,
|
||||
'inviter_id': str(inviter_id),
|
||||
},
|
||||
)
|
||||
|
||||
# Step 1: Validate permissions upfront (shared for all emails)
|
||||
org = OrgStore.get_org_by_id(org_id)
|
||||
if not org:
|
||||
raise ValueError(f'Organization {org_id} not found')
|
||||
|
||||
if str(org_id) == str(inviter_id):
|
||||
raise InsufficientPermissionError(
|
||||
'Cannot invite users to a personal workspace'
|
||||
)
|
||||
|
||||
inviter_member = OrgMemberStore.get_org_member(org_id, inviter_id)
|
||||
if not inviter_member:
|
||||
raise InsufficientPermissionError(
|
||||
'You are not a member of this organization'
|
||||
)
|
||||
|
||||
inviter_role = RoleStore.get_role_by_id(inviter_member.role_id)
|
||||
if not inviter_role or inviter_role.name not in [ROLE_OWNER, ROLE_ADMIN]:
|
||||
raise InsufficientPermissionError('Only owners and admins can invite users')
|
||||
|
||||
role_name_lower = role_name.lower()
|
||||
if role_name_lower == ROLE_OWNER and inviter_role.name != ROLE_OWNER:
|
||||
raise InsufficientPermissionError('Only owners can invite with owner role')
|
||||
|
||||
target_role = RoleStore.get_role_by_name(role_name_lower)
|
||||
if not target_role:
|
||||
raise ValueError(f'Invalid role: {role_name}')
|
||||
|
||||
# Step 2: Create invitations concurrently
|
||||
async def create_single(
|
||||
email: str,
|
||||
) -> tuple[str, OrgInvitation | None, str | None]:
|
||||
"""Create single invitation, return (email, invitation, error)."""
|
||||
try:
|
||||
invitation = await OrgInvitationService.create_invitation(
|
||||
org_id=org_id,
|
||||
email=email,
|
||||
role_name=role_name,
|
||||
inviter_id=inviter_id,
|
||||
)
|
||||
return (email, invitation, None)
|
||||
except (UserAlreadyMemberError, ValueError) as e:
|
||||
return (email, None, str(e))
|
||||
|
||||
results = await asyncio.gather(*[create_single(email) for email in emails])
|
||||
|
||||
# Step 3: Separate successes and failures
|
||||
successful: list[OrgInvitation] = []
|
||||
failed: list[tuple[str, str]] = []
|
||||
for email, invitation, error in results:
|
||||
if invitation:
|
||||
successful.append(invitation)
|
||||
elif error:
|
||||
failed.append((email, error))
|
||||
|
||||
logger.info(
|
||||
'Batch invitation creation completed',
|
||||
extra={
|
||||
'org_id': str(org_id),
|
||||
'successful': len(successful),
|
||||
'failed': len(failed),
|
||||
},
|
||||
)
|
||||
|
||||
return successful, failed
|
||||
|
||||
@staticmethod
|
||||
async def accept_invitation(token: str, user_id: UUID) -> OrgInvitation:
|
||||
"""Accept an organization invitation.
|
||||
|
||||
This method:
|
||||
1. Validates the token and invitation status
|
||||
2. Checks expiration
|
||||
3. Verifies user is not already a member
|
||||
4. Creates LiteLLM integration
|
||||
5. Adds user to the organization
|
||||
6. Marks invitation as accepted
|
||||
|
||||
Args:
|
||||
token: The invitation token
|
||||
user_id: The user accepting the invitation
|
||||
|
||||
Returns:
|
||||
OrgInvitation: The accepted invitation
|
||||
|
||||
Raises:
|
||||
InvitationInvalidError: If token is invalid or invitation not pending
|
||||
InvitationExpiredError: If invitation has expired
|
||||
UserAlreadyMemberError: If user is already a member
|
||||
"""
|
||||
logger.info(
|
||||
'Accepting organization invitation',
|
||||
extra={
|
||||
'token_prefix': token[:10] + '...' if len(token) > 10 else token,
|
||||
'user_id': str(user_id),
|
||||
},
|
||||
)
|
||||
|
||||
# Step 1: Get and validate invitation
|
||||
invitation = await OrgInvitationStore.get_invitation_by_token(token)
|
||||
|
||||
if not invitation:
|
||||
raise InvitationInvalidError('Invalid invitation token')
|
||||
|
||||
if invitation.status != OrgInvitation.STATUS_PENDING:
|
||||
if invitation.status == OrgInvitation.STATUS_ACCEPTED:
|
||||
raise InvitationInvalidError('Invitation has already been accepted')
|
||||
elif invitation.status == OrgInvitation.STATUS_REVOKED:
|
||||
raise InvitationInvalidError('Invitation has been revoked')
|
||||
else:
|
||||
raise InvitationInvalidError('Invitation is no longer valid')
|
||||
|
||||
# Step 2: Check expiration
|
||||
if OrgInvitationStore.is_token_expired(invitation):
|
||||
await OrgInvitationStore.update_invitation_status(
|
||||
invitation.id, OrgInvitation.STATUS_EXPIRED
|
||||
)
|
||||
raise InvitationExpiredError('Invitation has expired')
|
||||
|
||||
# Step 2.5: Verify user email matches invitation email
|
||||
user = await UserStore.get_user_by_id_async(str(user_id))
|
||||
if not user:
|
||||
raise InvitationInvalidError('User not found')
|
||||
|
||||
user_email = user.email
|
||||
# Fallback: fetch email from Keycloak if not in database (for existing users)
|
||||
if not user_email:
|
||||
token_manager = TokenManager()
|
||||
user_info = await token_manager.get_user_info_from_user_id(str(user_id))
|
||||
user_email = user_info.get('email') if user_info else None
|
||||
|
||||
if not user_email:
|
||||
raise EmailMismatchError('Your account does not have an email address')
|
||||
|
||||
user_email = user_email.lower().strip()
|
||||
invitation_email = invitation.email.lower().strip()
|
||||
|
||||
if user_email != invitation_email:
|
||||
logger.warning(
|
||||
'Email mismatch during invitation acceptance',
|
||||
extra={
|
||||
'user_id': str(user_id),
|
||||
'user_email': user_email,
|
||||
'invitation_email': invitation_email,
|
||||
'invitation_id': invitation.id,
|
||||
},
|
||||
)
|
||||
raise EmailMismatchError()
|
||||
|
||||
# Step 3: Check if user is already a member
|
||||
existing_member = OrgMemberStore.get_org_member(invitation.org_id, user_id)
|
||||
if existing_member:
|
||||
raise UserAlreadyMemberError(
|
||||
'You are already a member of this organization'
|
||||
)
|
||||
|
||||
# Step 4: Create LiteLLM integration for the user in the new org
|
||||
try:
|
||||
settings = await OrgService.create_litellm_integration(
|
||||
invitation.org_id, str(user_id)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'Failed to create LiteLLM integration for invitation acceptance',
|
||||
extra={
|
||||
'invitation_id': invitation.id,
|
||||
'user_id': str(user_id),
|
||||
'org_id': str(invitation.org_id),
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
raise InvitationInvalidError(
|
||||
'Failed to set up organization access. Please try again.'
|
||||
)
|
||||
|
||||
# Step 5: Add user to organization
|
||||
from storage.org_member_store import OrgMemberStore as OMS
|
||||
|
||||
org_member_kwargs = OMS.get_kwargs_from_settings(settings)
|
||||
# Don't override with org defaults - use invitation-specified role
|
||||
org_member_kwargs.pop('llm_model', None)
|
||||
org_member_kwargs.pop('llm_base_url', None)
|
||||
|
||||
OrgMemberStore.add_user_to_org(
|
||||
org_id=invitation.org_id,
|
||||
user_id=user_id,
|
||||
role_id=invitation.role_id,
|
||||
llm_api_key=settings.llm_api_key,
|
||||
status='active',
|
||||
)
|
||||
|
||||
# Step 6: Mark invitation as accepted
|
||||
updated_invitation = await OrgInvitationStore.update_invitation_status(
|
||||
invitation.id,
|
||||
OrgInvitation.STATUS_ACCEPTED,
|
||||
accepted_by_user_id=user_id,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
'Organization invitation accepted',
|
||||
extra={
|
||||
'invitation_id': invitation.id,
|
||||
'user_id': str(user_id),
|
||||
'org_id': str(invitation.org_id),
|
||||
'role_id': invitation.role_id,
|
||||
},
|
||||
)
|
||||
|
||||
return updated_invitation
|
||||
130
enterprise/server/services/org_llm_settings_service.py
Normal file
130
enterprise/server/services/org_llm_settings_service.py
Normal file
@@ -0,0 +1,130 @@
|
||||
"""Service class for managing organization LLM settings.
|
||||
|
||||
Separates business logic from route handlers.
|
||||
Uses dependency injection for db_session and user_context.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from fastapi import Request
|
||||
from server.routes.org_models import (
|
||||
OrgLLMSettingsResponse,
|
||||
OrgLLMSettingsUpdate,
|
||||
OrgNotFoundError,
|
||||
)
|
||||
from storage.org_llm_settings_store import OrgLLMSettingsStore
|
||||
|
||||
from openhands.app_server.services.injector import Injector, InjectorState
|
||||
from openhands.app_server.user.user_context import UserContext
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
|
||||
@dataclass
|
||||
class OrgLLMSettingsService:
|
||||
"""Service for org LLM settings with injected dependencies."""
|
||||
|
||||
store: OrgLLMSettingsStore
|
||||
user_context: UserContext
|
||||
|
||||
async def get_org_llm_settings(self) -> OrgLLMSettingsResponse:
|
||||
"""Get LLM settings for user's current organization.
|
||||
|
||||
User ID is obtained from the injected user_context.
|
||||
|
||||
Returns:
|
||||
OrgLLMSettingsResponse: The organization's LLM settings
|
||||
|
||||
Raises:
|
||||
ValueError: If user is not authenticated
|
||||
OrgNotFoundError: If current organization not found
|
||||
"""
|
||||
user_id = await self.user_context.get_user_id()
|
||||
if not user_id:
|
||||
raise ValueError('User is not authenticated')
|
||||
|
||||
logger.info(
|
||||
'Getting organization LLM settings',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
|
||||
org = await self.store.get_current_org_by_user_id(user_id)
|
||||
|
||||
if not org:
|
||||
raise OrgNotFoundError('No current organization')
|
||||
|
||||
return OrgLLMSettingsResponse.from_org(org)
|
||||
|
||||
async def update_org_llm_settings(
|
||||
self,
|
||||
update_data: OrgLLMSettingsUpdate,
|
||||
) -> OrgLLMSettingsResponse:
|
||||
"""Update LLM settings for user's current organization.
|
||||
|
||||
Only updates fields that are explicitly provided in update_data.
|
||||
User ID is obtained from the injected user_context.
|
||||
Session auto-commits at request end via DbSessionInjector.
|
||||
|
||||
Args:
|
||||
update_data: The update data from the request
|
||||
|
||||
Returns:
|
||||
OrgLLMSettingsResponse: The updated organization's LLM settings
|
||||
|
||||
Raises:
|
||||
ValueError: If user is not authenticated
|
||||
OrgNotFoundError: If current organization not found
|
||||
"""
|
||||
user_id = await self.user_context.get_user_id()
|
||||
if not user_id:
|
||||
raise ValueError('User is not authenticated')
|
||||
|
||||
logger.info(
|
||||
'Updating organization LLM settings',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
|
||||
# Check if any fields are provided
|
||||
if not update_data.has_updates():
|
||||
# No fields to update, just return current settings
|
||||
return await self.get_org_llm_settings()
|
||||
|
||||
# Get user's current org first
|
||||
org = await self.store.get_current_org_by_user_id(user_id)
|
||||
if not org:
|
||||
raise OrgNotFoundError('No current organization')
|
||||
|
||||
# Update the org LLM settings
|
||||
updated_org = await self.store.update_org_llm_settings(
|
||||
org_id=org.id,
|
||||
update_data=update_data,
|
||||
)
|
||||
|
||||
if not updated_org:
|
||||
raise OrgNotFoundError(str(org.id))
|
||||
|
||||
logger.info(
|
||||
'Organization LLM settings updated successfully',
|
||||
extra={'user_id': user_id, 'org_id': str(org.id)},
|
||||
)
|
||||
|
||||
return OrgLLMSettingsResponse.from_org(updated_org)
|
||||
|
||||
|
||||
class OrgLLMSettingsServiceInjector(Injector[OrgLLMSettingsService]):
|
||||
"""Injector that composes store and user_context for OrgLLMSettingsService."""
|
||||
|
||||
async def inject(
|
||||
self, state: InjectorState, request: Request | None = None
|
||||
) -> AsyncGenerator[OrgLLMSettingsService, None]:
|
||||
# Local imports to avoid circular dependencies
|
||||
from openhands.app_server.config import get_db_session, get_user_context
|
||||
|
||||
async with (
|
||||
get_user_context(state, request) as user_context,
|
||||
get_db_session(state, request) as db_session,
|
||||
):
|
||||
store = OrgLLMSettingsStore(db_session=db_session)
|
||||
yield OrgLLMSettingsService(store=store, user_context=user_context)
|
||||
417
enterprise/server/services/org_member_service.py
Normal file
417
enterprise/server/services/org_member_service.py
Normal file
@@ -0,0 +1,417 @@
|
||||
"""Service for managing organization members."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from server.constants import ROLE_ADMIN, ROLE_OWNER
|
||||
from server.routes.org_models import (
|
||||
CannotModifySelfError,
|
||||
InsufficientPermissionError,
|
||||
InvalidRoleError,
|
||||
LastOwnerError,
|
||||
MemberUpdateError,
|
||||
MeResponse,
|
||||
OrgMemberNotFoundError,
|
||||
OrgMemberPage,
|
||||
OrgMemberResponse,
|
||||
OrgMemberUpdate,
|
||||
RoleNotFoundError,
|
||||
)
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
from storage.org_member_store import OrgMemberStore
|
||||
from storage.role_store import RoleStore
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
|
||||
|
||||
class OrgMemberService:
|
||||
"""Service for organization member operations."""
|
||||
|
||||
@staticmethod
|
||||
def get_me(org_id: UUID, user_id: UUID) -> MeResponse:
|
||||
"""Get the current user's membership record for an organization.
|
||||
|
||||
Retrieves the authenticated user's role, status, email, and LLM override
|
||||
fields (with masked API keys) within the specified organization.
|
||||
|
||||
Args:
|
||||
org_id: Organization ID (UUID)
|
||||
user_id: User ID (UUID)
|
||||
|
||||
Returns:
|
||||
MeResponse: The user's membership data with masked API keys
|
||||
|
||||
Raises:
|
||||
OrgMemberNotFoundError: If user is not a member of the organization
|
||||
RoleNotFoundError: If the role associated with the member is not found
|
||||
"""
|
||||
# Look up the user's membership in this org
|
||||
org_member = OrgMemberStore.get_org_member(org_id, user_id)
|
||||
if org_member is None:
|
||||
raise OrgMemberNotFoundError(str(org_id), str(user_id))
|
||||
|
||||
# Resolve role name from role_id
|
||||
role = RoleStore.get_role_by_id(org_member.role_id)
|
||||
if role is None:
|
||||
raise RoleNotFoundError(org_member.role_id)
|
||||
|
||||
# Get user email
|
||||
user = UserStore.get_user_by_id(str(user_id))
|
||||
email = user.email if user and user.email else ''
|
||||
|
||||
return MeResponse.from_org_member(org_member, role, email)
|
||||
|
||||
@staticmethod
|
||||
async def get_org_members(
|
||||
org_id: UUID,
|
||||
current_user_id: UUID,
|
||||
page_id: str | None = None,
|
||||
limit: int = 10,
|
||||
email_filter: str | None = None,
|
||||
) -> tuple[bool, str | None, OrgMemberPage | None]:
|
||||
"""Get organization members with authorization check.
|
||||
|
||||
Args:
|
||||
org_id: Organization UUID.
|
||||
current_user_id: Requesting user's UUID.
|
||||
page_id: Offset encoded as string (e.g., "0", "10", "20").
|
||||
limit: Items per page (default 10).
|
||||
email_filter: Optional case-insensitive partial email match.
|
||||
|
||||
Returns:
|
||||
Tuple of (success, error_code, data). If success is True, error_code is None.
|
||||
"""
|
||||
# Verify current user is a member of the organization
|
||||
requester_membership = OrgMemberStore.get_org_member(org_id, current_user_id)
|
||||
if not requester_membership:
|
||||
return False, 'not_a_member', None
|
||||
|
||||
# Parse page_id to get offset (page_id is offset encoded as string)
|
||||
offset = 0
|
||||
if page_id is not None:
|
||||
try:
|
||||
offset = int(page_id)
|
||||
if offset < 0:
|
||||
return False, 'invalid_page_id', None
|
||||
except ValueError:
|
||||
return False, 'invalid_page_id', None
|
||||
|
||||
# Call store to get paginated members
|
||||
members, _ = await OrgMemberStore.get_org_members_paginated(
|
||||
org_id=org_id,
|
||||
offset=offset,
|
||||
limit=limit,
|
||||
email_filter=email_filter,
|
||||
)
|
||||
|
||||
# Transform data to response format
|
||||
items = []
|
||||
for member in members:
|
||||
# Access user and role relationships (eagerly loaded)
|
||||
user = member.user
|
||||
role = member.role
|
||||
|
||||
items.append(
|
||||
OrgMemberResponse(
|
||||
user_id=str(member.user_id),
|
||||
email=user.email if user else None,
|
||||
role_id=member.role_id,
|
||||
role=role.name if role else '',
|
||||
role_rank=role.rank if role else 0,
|
||||
status=member.status,
|
||||
)
|
||||
)
|
||||
|
||||
# Calculate current page (1-indexed)
|
||||
current_page = (offset // limit) + 1
|
||||
|
||||
return (
|
||||
True,
|
||||
None,
|
||||
OrgMemberPage(
|
||||
items=items,
|
||||
current_page=current_page,
|
||||
per_page=limit,
|
||||
),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def get_org_members_count(
|
||||
org_id: UUID,
|
||||
current_user_id: UUID,
|
||||
email_filter: str | None = None,
|
||||
) -> int:
|
||||
"""Get count of organization members with authorization check.
|
||||
|
||||
Args:
|
||||
org_id: Organization UUID.
|
||||
current_user_id: Requesting user's UUID.
|
||||
email_filter: Optional case-insensitive partial email match.
|
||||
|
||||
Returns:
|
||||
int: Count of organization members matching the filter.
|
||||
|
||||
Raises:
|
||||
OrgMemberNotFoundError: If requesting user is not a member of the organization.
|
||||
"""
|
||||
# Verify current user is a member of the organization
|
||||
requester_membership = OrgMemberStore.get_org_member(org_id, current_user_id)
|
||||
if not requester_membership:
|
||||
raise OrgMemberNotFoundError(str(org_id), str(current_user_id))
|
||||
|
||||
return await OrgMemberStore.get_org_members_count(
|
||||
org_id=org_id,
|
||||
email_filter=email_filter,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def remove_org_member(
|
||||
org_id: UUID,
|
||||
target_user_id: UUID,
|
||||
current_user_id: UUID,
|
||||
) -> tuple[bool, str | None]:
|
||||
"""Remove a member from an organization.
|
||||
|
||||
Returns:
|
||||
Tuple of (success, error_message). If success is True, error_message is None.
|
||||
"""
|
||||
|
||||
def _remove_member():
|
||||
# Get current user's membership in the org
|
||||
requester_membership = OrgMemberStore.get_org_member(
|
||||
org_id, current_user_id
|
||||
)
|
||||
if not requester_membership:
|
||||
return False, 'not_a_member'
|
||||
|
||||
# Check if trying to remove self
|
||||
if str(current_user_id) == str(target_user_id):
|
||||
return False, 'cannot_remove_self'
|
||||
|
||||
# Get target user's membership
|
||||
target_membership = OrgMemberStore.get_org_member(org_id, target_user_id)
|
||||
if not target_membership:
|
||||
return False, 'member_not_found'
|
||||
|
||||
requester_role = RoleStore.get_role_by_id(requester_membership.role_id)
|
||||
target_role = RoleStore.get_role_by_id(target_membership.role_id)
|
||||
|
||||
if not requester_role or not target_role:
|
||||
return False, 'role_not_found'
|
||||
|
||||
# Check permission based on roles
|
||||
if not OrgMemberService._can_remove_member(
|
||||
requester_role.name, target_role.name
|
||||
):
|
||||
return False, 'insufficient_permission'
|
||||
|
||||
# Check if removing the last owner
|
||||
if target_role.name == ROLE_OWNER:
|
||||
if OrgMemberService._is_last_owner(org_id, target_user_id):
|
||||
return False, 'cannot_remove_last_owner'
|
||||
|
||||
# Perform the removal
|
||||
success = OrgMemberStore.remove_user_from_org(org_id, target_user_id)
|
||||
if not success:
|
||||
return False, 'removal_failed'
|
||||
|
||||
# Update user's current_org_id if it points to the org they were removed from
|
||||
user = UserStore.get_user_by_id(str(target_user_id))
|
||||
if user and user.current_org_id == org_id:
|
||||
# Set current_org_id to personal workspace (org.id == user.id)
|
||||
UserStore.update_current_org(str(target_user_id), target_user_id)
|
||||
|
||||
return True, None
|
||||
|
||||
success, error = await call_sync_from_async(_remove_member)
|
||||
|
||||
# If database removal succeeded, also remove from LiteLLM team
|
||||
if success:
|
||||
try:
|
||||
await LiteLlmManager.remove_user_from_team(
|
||||
str(target_user_id), str(org_id)
|
||||
)
|
||||
logger.info(
|
||||
'Successfully removed user from LiteLLM team',
|
||||
extra={
|
||||
'user_id': str(target_user_id),
|
||||
'org_id': str(org_id),
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
# Log but don't fail the operation - database removal already succeeded
|
||||
# LiteLLM state will be eventually consistent
|
||||
logger.warning(
|
||||
'Failed to remove user from LiteLLM team',
|
||||
extra={
|
||||
'user_id': str(target_user_id),
|
||||
'org_id': str(org_id),
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
|
||||
return success, error
|
||||
|
||||
@staticmethod
|
||||
async def update_org_member(
|
||||
org_id: UUID,
|
||||
target_user_id: UUID,
|
||||
current_user_id: UUID,
|
||||
update_data: OrgMemberUpdate,
|
||||
) -> OrgMemberResponse:
|
||||
"""Update a member's role in an organization.
|
||||
|
||||
Permission rules:
|
||||
- Owners can modify anyone (including other owners), can set any role
|
||||
- Admins can modify other admins and users
|
||||
- Admins can only set admin or user roles (not owner)
|
||||
|
||||
Args:
|
||||
org_id: Organization ID
|
||||
target_user_id: User ID of the member to update
|
||||
current_user_id: User ID of the requester
|
||||
update_data: Update data containing fields to modify
|
||||
|
||||
Returns:
|
||||
OrgMemberResponse: The updated member data
|
||||
|
||||
Raises:
|
||||
OrgMemberNotFoundError: If requester or target is not a member
|
||||
CannotModifySelfError: If trying to modify self
|
||||
RoleNotFoundError: If role configuration is invalid
|
||||
InvalidRoleError: If new_role_name is not a valid role
|
||||
InsufficientPermissionError: If requester lacks permission
|
||||
LastOwnerError: If trying to demote the last owner
|
||||
MemberUpdateError: If update operation fails
|
||||
"""
|
||||
new_role_name = update_data.role
|
||||
|
||||
def _update_member():
|
||||
# Get current user's membership in the org
|
||||
requester_membership = OrgMemberStore.get_org_member(
|
||||
org_id, current_user_id
|
||||
)
|
||||
if not requester_membership:
|
||||
raise OrgMemberNotFoundError(str(org_id), str(current_user_id))
|
||||
|
||||
# Check if trying to modify self
|
||||
if str(current_user_id) == str(target_user_id):
|
||||
raise CannotModifySelfError('modify')
|
||||
|
||||
# Get target user's membership
|
||||
target_membership = OrgMemberStore.get_org_member(org_id, target_user_id)
|
||||
if not target_membership:
|
||||
raise OrgMemberNotFoundError(str(org_id), str(target_user_id))
|
||||
|
||||
# Get roles
|
||||
requester_role = RoleStore.get_role_by_id(requester_membership.role_id)
|
||||
target_role = RoleStore.get_role_by_id(target_membership.role_id)
|
||||
|
||||
if not requester_role:
|
||||
raise RoleNotFoundError(requester_membership.role_id)
|
||||
if not target_role:
|
||||
raise RoleNotFoundError(target_membership.role_id)
|
||||
|
||||
# If no role change requested, return current state
|
||||
if new_role_name is None:
|
||||
user = UserStore.get_user_by_id(str(target_user_id))
|
||||
return OrgMemberResponse(
|
||||
user_id=str(target_membership.user_id),
|
||||
email=user.email if user else None,
|
||||
role_id=target_membership.role_id,
|
||||
role=target_role.name,
|
||||
role_rank=target_role.rank,
|
||||
status=target_membership.status,
|
||||
)
|
||||
|
||||
# Validate new role exists
|
||||
new_role = RoleStore.get_role_by_name(new_role_name.lower())
|
||||
if not new_role:
|
||||
raise InvalidRoleError(new_role_name)
|
||||
|
||||
# Check permission to modify target
|
||||
if not OrgMemberService._can_update_member_role(
|
||||
requester_role.name, target_role.name, new_role.name
|
||||
):
|
||||
raise InsufficientPermissionError(
|
||||
'You do not have permission to modify this member'
|
||||
)
|
||||
|
||||
# Check if demoting the last owner
|
||||
if (
|
||||
target_role.name == ROLE_OWNER
|
||||
and new_role.name != ROLE_OWNER
|
||||
and OrgMemberService._is_last_owner(org_id, target_user_id)
|
||||
):
|
||||
raise LastOwnerError('demote')
|
||||
|
||||
# Perform the update
|
||||
updated_member = OrgMemberStore.update_user_role_in_org(
|
||||
org_id, target_user_id, new_role.id
|
||||
)
|
||||
if not updated_member:
|
||||
raise MemberUpdateError('Failed to update member')
|
||||
|
||||
# Get user email for response
|
||||
user = UserStore.get_user_by_id(str(target_user_id))
|
||||
|
||||
return OrgMemberResponse(
|
||||
user_id=str(updated_member.user_id),
|
||||
email=user.email if user else None,
|
||||
role_id=updated_member.role_id,
|
||||
role=new_role.name,
|
||||
role_rank=new_role.rank,
|
||||
status=updated_member.status,
|
||||
)
|
||||
|
||||
return await call_sync_from_async(_update_member)
|
||||
|
||||
@staticmethod
|
||||
def _can_update_member_role(
|
||||
requester_role_name: str, target_role_name: str, new_role_name: str
|
||||
) -> bool:
|
||||
"""Check if requester can change target's role to new_role.
|
||||
|
||||
Permission rules:
|
||||
- Owners can modify anyone (including other owners), can set any role
|
||||
- Admins can modify other admins and users
|
||||
- Admins can only set admin or user roles (not owner)
|
||||
"""
|
||||
is_requester_owner = requester_role_name == ROLE_OWNER
|
||||
is_requester_admin = requester_role_name == ROLE_ADMIN
|
||||
is_target_owner = target_role_name == ROLE_OWNER
|
||||
is_new_role_owner = new_role_name == ROLE_OWNER
|
||||
|
||||
if is_requester_owner:
|
||||
# Owners can modify anyone (including other owners)
|
||||
return True
|
||||
elif is_requester_admin:
|
||||
# Admins cannot modify owners
|
||||
if is_target_owner:
|
||||
return False
|
||||
# Admins can only set admin or user roles (not owner)
|
||||
return not is_new_role_owner
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _can_remove_member(requester_role_name: str, target_role_name: str) -> bool:
|
||||
"""Check if requester can remove target based on roles."""
|
||||
if requester_role_name == ROLE_OWNER:
|
||||
return True
|
||||
elif requester_role_name == ROLE_ADMIN:
|
||||
# Admins can remove admins and members (not owners)
|
||||
return target_role_name != ROLE_OWNER
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _is_last_owner(org_id: UUID, user_id: UUID) -> bool:
|
||||
"""Check if user is the last owner of the organization."""
|
||||
members = OrgMemberStore.get_org_members(org_id)
|
||||
owners = []
|
||||
for m in members:
|
||||
# Use role_id (column) instead of role (relationship) to avoid DetachedInstanceError
|
||||
role = RoleStore.get_role_by_id(m.role_id)
|
||||
if role and role.name == ROLE_OWNER:
|
||||
owners.append(m)
|
||||
return len(owners) == 1 and str(owners[0].user_id) == str(user_id)
|
||||
126
enterprise/server/services/user_app_settings_service.py
Normal file
126
enterprise/server/services/user_app_settings_service.py
Normal file
@@ -0,0 +1,126 @@
|
||||
"""Service class for managing user app settings.
|
||||
|
||||
Separates business logic from route handlers.
|
||||
Uses dependency injection for db_session and user_context.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from fastapi import Request
|
||||
from server.routes.user_app_settings_models import (
|
||||
UserAppSettingsResponse,
|
||||
UserAppSettingsUpdate,
|
||||
UserNotFoundError,
|
||||
)
|
||||
from storage.user_app_settings_store import UserAppSettingsStore
|
||||
|
||||
from openhands.app_server.services.injector import Injector, InjectorState
|
||||
from openhands.app_server.user.user_context import UserContext
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
|
||||
@dataclass
|
||||
class UserAppSettingsService:
|
||||
"""Service for user app settings with injected dependencies."""
|
||||
|
||||
store: UserAppSettingsStore
|
||||
user_context: UserContext
|
||||
|
||||
async def get_user_app_settings(self) -> UserAppSettingsResponse:
|
||||
"""Get user app settings.
|
||||
|
||||
User ID is obtained from the injected user_context.
|
||||
|
||||
Returns:
|
||||
UserAppSettingsResponse: The user's app settings
|
||||
|
||||
Raises:
|
||||
ValueError: If user is not authenticated
|
||||
UserNotFoundError: If user is not found
|
||||
"""
|
||||
user_id = await self.user_context.get_user_id()
|
||||
if not user_id:
|
||||
raise ValueError('User is not authenticated')
|
||||
|
||||
logger.info(
|
||||
'Getting user app settings',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
|
||||
user = await self.store.get_user_by_id(user_id)
|
||||
|
||||
if not user:
|
||||
raise UserNotFoundError(user_id)
|
||||
|
||||
return UserAppSettingsResponse.from_user(user)
|
||||
|
||||
async def update_user_app_settings(
|
||||
self,
|
||||
update_data: UserAppSettingsUpdate,
|
||||
) -> UserAppSettingsResponse:
|
||||
"""Update user app settings.
|
||||
|
||||
Only updates fields that are explicitly provided in update_data.
|
||||
User ID is obtained from the injected user_context.
|
||||
Session auto-commits at request end via DbSessionInjector.
|
||||
|
||||
Args:
|
||||
update_data: The update data from the request
|
||||
|
||||
Returns:
|
||||
UserAppSettingsResponse: The updated user's app settings
|
||||
|
||||
Raises:
|
||||
ValueError: If user is not authenticated
|
||||
UserNotFoundError: If user is not found
|
||||
"""
|
||||
user_id = await self.user_context.get_user_id()
|
||||
if not user_id:
|
||||
raise ValueError('User is not authenticated')
|
||||
|
||||
logger.info(
|
||||
'Updating user app settings',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
|
||||
# Check if any fields are provided
|
||||
update_dict = update_data.model_dump(exclude_unset=True)
|
||||
|
||||
if not update_dict:
|
||||
# No fields to update, just return current settings
|
||||
return await self.get_user_app_settings()
|
||||
|
||||
user = await self.store.update_user_app_settings(
|
||||
user_id=user_id,
|
||||
update_data=update_data,
|
||||
)
|
||||
|
||||
if not user:
|
||||
raise UserNotFoundError(user_id)
|
||||
|
||||
logger.info(
|
||||
'User app settings updated successfully',
|
||||
extra={'user_id': user_id, 'updated_fields': list(update_dict.keys())},
|
||||
)
|
||||
|
||||
return UserAppSettingsResponse.from_user(user)
|
||||
|
||||
|
||||
class UserAppSettingsServiceInjector(Injector[UserAppSettingsService]):
|
||||
"""Injector that composes store and user_context for UserAppSettingsService."""
|
||||
|
||||
async def inject(
|
||||
self, state: InjectorState, request: Request | None = None
|
||||
) -> AsyncGenerator[UserAppSettingsService, None]:
|
||||
# Local imports to avoid circular dependencies
|
||||
from openhands.app_server.config import get_db_session, get_user_context
|
||||
|
||||
async with (
|
||||
get_user_context(state, request) as user_context,
|
||||
get_db_session(state, request) as db_session,
|
||||
):
|
||||
store = UserAppSettingsStore(db_session=db_session)
|
||||
yield UserAppSettingsService(store=store, user_context=user_context)
|
||||
@@ -22,11 +22,70 @@ from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
|
||||
SQLAppConversationInfoService,
|
||||
)
|
||||
from openhands.app_server.errors import AuthError
|
||||
from openhands.app_server.services.injector import InjectorState
|
||||
from openhands.app_server.user.specifiy_user_context import ADMIN
|
||||
|
||||
|
||||
class SaasSQLAppConversationInfoService(SQLAppConversationInfoService):
|
||||
"""Extended SQLAppConversationInfoService with user-based filtering and SAAS metadata handling."""
|
||||
"""Extended SQLAppConversationInfoService with user and organization-based filtering and SAAS metadata handling."""
|
||||
|
||||
async def _get_current_user(self) -> User | None:
|
||||
"""Get the current user using the existing db_session.
|
||||
|
||||
Uses self.db_session to avoid opening a separate database session.
|
||||
|
||||
Returns:
|
||||
User object or None if no user_id is available
|
||||
"""
|
||||
user_id_str = await self.user_context.get_user_id()
|
||||
if not user_id_str:
|
||||
return None
|
||||
|
||||
user_id_uuid = UUID(user_id_str)
|
||||
result = await self.db_session.execute(
|
||||
select(User).where(User.id == user_id_uuid)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
async def _apply_user_and_org_filter(self, query):
|
||||
"""Apply user_id and org_id filters to ensure conversation isolation.
|
||||
|
||||
Filters conversations by:
|
||||
- user_id: Only show conversations belonging to the current user
|
||||
- org_id: Only show conversations belonging to the user's current organization
|
||||
|
||||
Args:
|
||||
query: SQLAlchemy query to apply filters to
|
||||
|
||||
Returns:
|
||||
Query with user and organization filters applied
|
||||
|
||||
Raises:
|
||||
AuthError: If no user_id is available (secure default: deny access)
|
||||
"""
|
||||
# For internal operations such as getting a conversation by session_api_key
|
||||
# we need a mode that does not have filtering. The dependency `as_admin()`
|
||||
# is used to enable it
|
||||
if self.user_context == ADMIN:
|
||||
return query
|
||||
|
||||
user_id_str = await self.user_context.get_user_id()
|
||||
if not user_id_str:
|
||||
# Secure default: no user means no access, not "show everything"
|
||||
raise AuthError('User authentication required')
|
||||
|
||||
user_id_uuid = UUID(user_id_str)
|
||||
query = query.where(StoredConversationMetadataSaas.user_id == user_id_uuid)
|
||||
|
||||
# Filter by organization ID to ensure conversations are isolated per organization
|
||||
user = await self._get_current_user()
|
||||
if user and user.current_org_id is not None:
|
||||
query = query.where(
|
||||
StoredConversationMetadataSaas.org_id == user.current_org_id
|
||||
)
|
||||
|
||||
return query
|
||||
|
||||
async def _secure_select(self):
|
||||
query = (
|
||||
@@ -38,13 +97,7 @@ class SaasSQLAppConversationInfoService(SQLAppConversationInfoService):
|
||||
)
|
||||
.where(StoredConversationMetadata.conversation_version == 'V1')
|
||||
)
|
||||
|
||||
user_id_str = await self.user_context.get_user_id()
|
||||
if user_id_str:
|
||||
user_id_uuid = UUID(user_id_str)
|
||||
query = query.where(StoredConversationMetadataSaas.user_id == user_id_uuid)
|
||||
|
||||
return query
|
||||
return await self._apply_user_and_org_filter(query)
|
||||
|
||||
async def _secure_select_with_saas_metadata(self):
|
||||
"""Select query that includes SAAS metadata for retrieving user_id."""
|
||||
@@ -57,13 +110,7 @@ class SaasSQLAppConversationInfoService(SQLAppConversationInfoService):
|
||||
)
|
||||
.where(StoredConversationMetadata.conversation_version == 'V1')
|
||||
)
|
||||
|
||||
user_id_str = await self.user_context.get_user_id()
|
||||
if user_id_str:
|
||||
user_id_uuid = UUID(user_id_str)
|
||||
query = query.where(StoredConversationMetadataSaas.user_id == user_id_uuid)
|
||||
|
||||
return query
|
||||
return await self._apply_user_and_org_filter(query)
|
||||
|
||||
async def search_app_conversation_info(
|
||||
self,
|
||||
@@ -155,21 +202,16 @@ class SaasSQLAppConversationInfoService(SQLAppConversationInfoService):
|
||||
"""Count conversations matching the given filters with SAAS metadata."""
|
||||
query = (
|
||||
select(func.count(StoredConversationMetadata.conversation_id))
|
||||
.select_from(
|
||||
StoredConversationMetadata.join(
|
||||
StoredConversationMetadataSaas,
|
||||
StoredConversationMetadata.conversation_id
|
||||
== StoredConversationMetadataSaas.conversation_id,
|
||||
)
|
||||
.join(
|
||||
StoredConversationMetadataSaas,
|
||||
StoredConversationMetadata.conversation_id
|
||||
== StoredConversationMetadataSaas.conversation_id,
|
||||
)
|
||||
.where(StoredConversationMetadata.conversation_version == 'V1')
|
||||
)
|
||||
|
||||
# Apply user filtering
|
||||
user_id_str = await self.user_context.get_user_id()
|
||||
if user_id_str:
|
||||
user_id_uuid = UUID(user_id_str)
|
||||
query = query.where(StoredConversationMetadataSaas.user_id == user_id_uuid)
|
||||
# Apply user and organization filtering
|
||||
query = await self._apply_user_and_org_filter(query)
|
||||
|
||||
query = self._apply_filters_with_saas_metadata(
|
||||
query=query,
|
||||
@@ -233,7 +275,13 @@ class SaasSQLAppConversationInfoService(SQLAppConversationInfoService):
|
||||
result = result_set.first()
|
||||
if result:
|
||||
stored_metadata, saas_metadata = result
|
||||
return self._to_info_with_user_id(stored_metadata, saas_metadata)
|
||||
# Fetch sub-conversation IDs
|
||||
sub_conversation_ids = await self.get_sub_conversation_ids(conversation_id)
|
||||
return self._to_info_with_user_id(
|
||||
stored_metadata,
|
||||
saas_metadata,
|
||||
sub_conversation_ids=sub_conversation_ids,
|
||||
)
|
||||
return None
|
||||
|
||||
async def batch_get_app_conversation_info(
|
||||
@@ -262,8 +310,16 @@ class SaasSQLAppConversationInfoService(SQLAppConversationInfoService):
|
||||
for conversation_id in conversation_id_strs:
|
||||
if conversation_id in info_by_id:
|
||||
stored_metadata, saas_metadata = info_by_id[conversation_id]
|
||||
# Fetch sub-conversation IDs for each conversation
|
||||
sub_conversation_ids = await self.get_sub_conversation_ids(
|
||||
UUID(conversation_id)
|
||||
)
|
||||
results.append(
|
||||
self._to_info_with_user_id(stored_metadata, saas_metadata)
|
||||
self._to_info_with_user_id(
|
||||
stored_metadata,
|
||||
saas_metadata,
|
||||
sub_conversation_ids=sub_conversation_ids,
|
||||
)
|
||||
)
|
||||
else:
|
||||
results.append(None)
|
||||
@@ -316,10 +372,11 @@ class SaasSQLAppConversationInfoService(SQLAppConversationInfoService):
|
||||
self,
|
||||
stored: StoredConversationMetadata,
|
||||
saas_metadata: StoredConversationMetadataSaas,
|
||||
sub_conversation_ids: list[UUID] | None = None,
|
||||
) -> AppConversationInfo:
|
||||
"""Convert stored metadata to AppConversationInfo with user_id from SAAS metadata."""
|
||||
# Use the base _to_info method to get the basic info
|
||||
info = self._to_info(stored)
|
||||
info = self._to_info(stored, sub_conversation_ids=sub_conversation_ids)
|
||||
|
||||
# Override the created_by_user_id with the user_id from SAAS metadata
|
||||
info.created_by_user_id = (
|
||||
|
||||
33
enterprise/server/verified_models/verified_model_models.py
Normal file
33
enterprise/server/verified_models/verified_model_models.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from datetime import datetime
|
||||
from typing import Annotated
|
||||
|
||||
from pydantic import BaseModel, StringConstraints
|
||||
|
||||
|
||||
class VerifiedModelCreate(BaseModel):
|
||||
model_name: Annotated[
|
||||
str,
|
||||
StringConstraints(max_length=255),
|
||||
]
|
||||
provider: Annotated[
|
||||
str,
|
||||
StringConstraints(max_length=100),
|
||||
]
|
||||
is_enabled: bool = True
|
||||
|
||||
|
||||
class VerifiedModel(VerifiedModelCreate):
|
||||
id: int
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
class VerifiedModelUpdate(BaseModel):
|
||||
is_enabled: bool | None = None
|
||||
|
||||
|
||||
class VerifiedModelPage(BaseModel):
|
||||
"""Paginated response model for verified model list."""
|
||||
|
||||
items: list[VerifiedModel]
|
||||
next_page_id: str | None = None
|
||||
143
enterprise/server/verified_models/verified_model_router.py
Normal file
143
enterprise/server/verified_models/verified_model_router.py
Normal file
@@ -0,0 +1,143 @@
|
||||
"""API routes for managing verified LLM models (admin only)."""
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
|
||||
from server.email_validation import get_admin_user_id
|
||||
from server.verified_models.verified_model_models import (
|
||||
VerifiedModel,
|
||||
VerifiedModelCreate,
|
||||
VerifiedModelPage,
|
||||
VerifiedModelUpdate,
|
||||
)
|
||||
from server.verified_models.verified_model_service import (
|
||||
VerifiedModelService,
|
||||
verified_model_store_dependency,
|
||||
)
|
||||
|
||||
from openhands.app_server.config import get_db_session
|
||||
from openhands.server.routes import public
|
||||
from openhands.utils.llm import get_supported_llm_models
|
||||
|
||||
api_router = APIRouter(prefix='/api/admin/verified-models', tags=['Verified Models'])
|
||||
|
||||
|
||||
@api_router.get('')
|
||||
async def search_verified_models(
|
||||
provider: str | None = None,
|
||||
page_id: Annotated[
|
||||
str | None,
|
||||
Query(title='Optional next_page_id from the previously returned page'),
|
||||
] = None,
|
||||
limit: Annotated[
|
||||
int, Query(title='The max number of results in the page', gt=0, le=100)
|
||||
] = 100,
|
||||
user_id: str = Depends(get_admin_user_id),
|
||||
verified_model_service: VerifiedModelService = Depends(
|
||||
verified_model_store_dependency
|
||||
),
|
||||
) -> VerifiedModelPage:
|
||||
"""List all verified models, optionally filtered by provider."""
|
||||
# Use SQL-level filtering and pagination
|
||||
result = await verified_model_service.search_verified_models(
|
||||
provider=provider,
|
||||
enabled_only=False, # Admin sees all models including disabled
|
||||
page_id=page_id,
|
||||
limit=limit,
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
@api_router.post('', status_code=201)
|
||||
async def create_verified_model(
|
||||
data: VerifiedModelCreate,
|
||||
user_id: str = Depends(get_admin_user_id),
|
||||
verified_model_service: VerifiedModelService = Depends(
|
||||
verified_model_store_dependency
|
||||
),
|
||||
) -> VerifiedModel:
|
||||
"""Create a new verified model."""
|
||||
try:
|
||||
model = await verified_model_service.create_verified_model(
|
||||
model_name=data.model_name,
|
||||
provider=data.provider,
|
||||
is_enabled=data.is_enabled,
|
||||
)
|
||||
return model
|
||||
except ValueError as ex:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(ex),
|
||||
)
|
||||
|
||||
|
||||
@api_router.put('/{provider}/{model_name:path}')
|
||||
async def update_verified_model(
|
||||
provider: str,
|
||||
model_name: str,
|
||||
data: VerifiedModelUpdate,
|
||||
user_id: str = Depends(get_admin_user_id),
|
||||
verified_model_service: VerifiedModelService = Depends(
|
||||
verified_model_store_dependency
|
||||
),
|
||||
) -> VerifiedModel:
|
||||
"""Update a verified model by provider and model name."""
|
||||
model = await verified_model_service.update_verified_model(
|
||||
model_name=model_name,
|
||||
provider=provider,
|
||||
is_enabled=data.is_enabled,
|
||||
)
|
||||
if not model:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f'Model {provider}/{model_name} not found',
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
@api_router.delete('/{provider}/{model_name:path}')
|
||||
async def delete_verified_model(
|
||||
provider: str,
|
||||
model_name: str,
|
||||
user_id: str = Depends(get_admin_user_id),
|
||||
verified_model_service: VerifiedModelService = Depends(
|
||||
verified_model_store_dependency
|
||||
),
|
||||
) -> bool:
|
||||
"""Delete a verified model by provider and model name."""
|
||||
try:
|
||||
await verified_model_service.delete_verified_model(
|
||||
model_name=model_name, provider=provider
|
||||
)
|
||||
return True
|
||||
except ValueError as ex:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=str(ex),
|
||||
)
|
||||
|
||||
|
||||
async def get_saas_llm_models_dependency(request: Request) -> list[str]:
|
||||
"""SaaS implementation for the LLM models endpoint."""
|
||||
async with get_db_session(request.state, request) as db_session:
|
||||
# Prevent circular import
|
||||
from openhands.server.shared import config
|
||||
|
||||
verified_model_service = VerifiedModelService(db_session)
|
||||
page = await verified_model_service.search_verified_models(enabled_only=True)
|
||||
if page.next_page_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Too many models defined in database',
|
||||
)
|
||||
verified_models = [f'{m.provider}/{m.model_name}' for m in page.items]
|
||||
return get_supported_llm_models(config, verified_models)
|
||||
|
||||
|
||||
# Override the default implementation with SaaS implementation
|
||||
# This must be called after the app is created in saas_server.py
|
||||
def override_llm_models_dependency(app):
|
||||
"""Override the default LLM models implementation with SaaS version."""
|
||||
app.dependency_overrides[public.get_llm_models_dependency] = (
|
||||
get_saas_llm_models_dependency
|
||||
)
|
||||
242
enterprise/server/verified_models/verified_model_service.py
Normal file
242
enterprise/server/verified_models/verified_model_service.py
Normal file
@@ -0,0 +1,242 @@
|
||||
"""Store for managing verified LLM models in the database."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from sqlalchemy import (
|
||||
Boolean,
|
||||
Column,
|
||||
DateTime,
|
||||
Identity,
|
||||
Integer,
|
||||
String,
|
||||
UniqueConstraint,
|
||||
and_,
|
||||
func,
|
||||
select,
|
||||
text,
|
||||
)
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from storage.base import Base
|
||||
|
||||
from enterprise.server.verified_models.verified_model_models import (
|
||||
VerifiedModel,
|
||||
VerifiedModelPage,
|
||||
)
|
||||
from openhands.app_server.config import depends_db_session
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
|
||||
class StoredVerifiedModel(Base): # type: ignore
|
||||
"""A verified LLM model available in the model selector.
|
||||
|
||||
The composite unique constraint on (model_name, provider) allows the same
|
||||
model name to exist under different providers (e.g. 'claude-sonnet' under
|
||||
both 'openhands' and 'anthropic').
|
||||
"""
|
||||
|
||||
__tablename__ = 'verified_models'
|
||||
__table_args__ = (
|
||||
UniqueConstraint('model_name', 'provider', name='uq_verified_model_provider'),
|
||||
)
|
||||
|
||||
id = Column(Integer, Identity(), primary_key=True)
|
||||
model_name = Column(String(255), nullable=False)
|
||||
provider = Column(String(100), nullable=False, index=True)
|
||||
is_enabled = Column(
|
||||
Boolean, nullable=False, default=True, server_default=text('true')
|
||||
)
|
||||
created_at = Column(DateTime, nullable=False, server_default=func.now())
|
||||
updated_at = Column(
|
||||
DateTime, nullable=False, server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
|
||||
|
||||
def verified_model(result: StoredVerifiedModel) -> VerifiedModel:
|
||||
return VerifiedModel(
|
||||
id=result.id,
|
||||
model_name=result.model_name,
|
||||
provider=result.provider,
|
||||
is_enabled=result.is_enabled,
|
||||
created_at=result.created_at,
|
||||
updated_at=result.updated_at,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class VerifiedModelService:
|
||||
"""Store for CRUD operations on verified models.
|
||||
|
||||
Follows the async pattern with db_session as an attribute.
|
||||
"""
|
||||
|
||||
db_session: AsyncSession
|
||||
|
||||
async def search_verified_models(
|
||||
self,
|
||||
provider: str | None = None,
|
||||
enabled_only: bool = True,
|
||||
page_id: str | None = None,
|
||||
limit: int = 100,
|
||||
) -> VerifiedModelPage:
|
||||
"""Search for verified models with optional filtering and pagination.
|
||||
|
||||
Args:
|
||||
provider: Optional provider name to filter by (e.g., 'openhands', 'anthropic')
|
||||
enabled_only: If True, only return enabled models (default: True)
|
||||
page_id: Page id for pagination
|
||||
limit: Maximum number of records to return
|
||||
|
||||
Returns:
|
||||
SearchModelsResult containing items list and has_more flag
|
||||
"""
|
||||
query = select(StoredVerifiedModel)
|
||||
|
||||
# Build filters
|
||||
filters = []
|
||||
if provider:
|
||||
filters.append(StoredVerifiedModel.provider == provider)
|
||||
if enabled_only:
|
||||
filters.append(StoredVerifiedModel.is_enabled.is_(True))
|
||||
|
||||
if filters:
|
||||
query = query.where(and_(*filters))
|
||||
|
||||
# Order by provider, then model_name
|
||||
query = query.order_by(
|
||||
StoredVerifiedModel.provider, StoredVerifiedModel.model_name
|
||||
)
|
||||
|
||||
# Fetch limit + 1 to check if there are more results
|
||||
offset = int(page_id or '0')
|
||||
query = query.offset(offset).limit(limit + 1)
|
||||
|
||||
result = await self.db_session.execute(query)
|
||||
results = list(result.scalars().all())
|
||||
has_more = len(results) > limit
|
||||
next_page_id = None
|
||||
|
||||
# Return only the requested number of results
|
||||
if has_more:
|
||||
next_page_id = str(offset + limit)
|
||||
results.pop()
|
||||
|
||||
items = [verified_model(result) for result in results]
|
||||
return VerifiedModelPage(items=items, next_page_id=next_page_id)
|
||||
|
||||
async def get_model(self, model_name: str, provider: str) -> VerifiedModel | None:
|
||||
"""Get a model by its composite key (model_name, provider).
|
||||
|
||||
Args:
|
||||
model_name: The model identifier
|
||||
provider: The provider name
|
||||
"""
|
||||
query = select(StoredVerifiedModel).where(
|
||||
and_(
|
||||
StoredVerifiedModel.model_name == model_name,
|
||||
StoredVerifiedModel.provider == provider,
|
||||
)
|
||||
)
|
||||
result = await self.db_session.execute(query)
|
||||
return result.scalars().first()
|
||||
|
||||
async def create_verified_model(
|
||||
self,
|
||||
model_name: str,
|
||||
provider: str,
|
||||
is_enabled: bool = True,
|
||||
) -> VerifiedModel:
|
||||
"""Create a new verified model.
|
||||
|
||||
Args:
|
||||
model_name: The model identifier
|
||||
provider: The provider name
|
||||
is_enabled: Whether the model is enabled (default True)
|
||||
|
||||
Raises:
|
||||
ValueError: If a model with the same (model_name, provider) already exists
|
||||
"""
|
||||
existing_query = select(StoredVerifiedModel).where(
|
||||
and_(
|
||||
StoredVerifiedModel.model_name == model_name,
|
||||
StoredVerifiedModel.provider == provider,
|
||||
)
|
||||
)
|
||||
result = await self.db_session.execute(existing_query)
|
||||
existing = result.scalars().first()
|
||||
if existing:
|
||||
raise ValueError(f'Model {provider}/{model_name} already exists')
|
||||
|
||||
model = StoredVerifiedModel(
|
||||
model_name=model_name,
|
||||
provider=provider,
|
||||
is_enabled=is_enabled,
|
||||
)
|
||||
self.db_session.add(model)
|
||||
await self.db_session.commit()
|
||||
await self.db_session.refresh(model)
|
||||
logger.info(f'Created verified model: {provider}/{model_name}')
|
||||
return verified_model(model)
|
||||
|
||||
async def update_verified_model(
|
||||
self,
|
||||
model_name: str,
|
||||
provider: str,
|
||||
is_enabled: bool | None = None,
|
||||
) -> VerifiedModel | None:
|
||||
"""Update an existing verified model.
|
||||
|
||||
Args:
|
||||
model_name: The model name to update
|
||||
provider: The provider name
|
||||
is_enabled: New enabled state (optional)
|
||||
|
||||
Returns:
|
||||
The updated model if found, None otherwise
|
||||
"""
|
||||
query = select(StoredVerifiedModel).where(
|
||||
and_(
|
||||
StoredVerifiedModel.model_name == model_name,
|
||||
StoredVerifiedModel.provider == provider,
|
||||
)
|
||||
)
|
||||
result = await self.db_session.execute(query)
|
||||
model = result.scalars().first()
|
||||
if not model:
|
||||
return None
|
||||
|
||||
if is_enabled is not None:
|
||||
model.is_enabled = is_enabled
|
||||
|
||||
await self.db_session.commit()
|
||||
await self.db_session.refresh(model)
|
||||
logger.info(f'Updated verified model: {provider}/{model_name}')
|
||||
return verified_model(model)
|
||||
|
||||
async def delete_verified_model(self, model_name: str, provider: str):
|
||||
"""Delete a verified model.
|
||||
|
||||
Args:
|
||||
model_name: The model name to delete
|
||||
provider: The provider name
|
||||
|
||||
Returns:
|
||||
True if deleted, False if not found
|
||||
"""
|
||||
query = select(StoredVerifiedModel).where(
|
||||
and_(
|
||||
StoredVerifiedModel.model_name == model_name,
|
||||
StoredVerifiedModel.provider == provider,
|
||||
)
|
||||
)
|
||||
result = await self.db_session.execute(query)
|
||||
model = result.scalars().first()
|
||||
if not model:
|
||||
raise ValueError('Unknown model')
|
||||
|
||||
await self.db_session.delete(model)
|
||||
await self.db_session.commit()
|
||||
logger.info(f'Deleted verified model: {provider}/{model_name}')
|
||||
|
||||
|
||||
def verified_model_store_dependency(db_session: AsyncSession = depends_db_session()):
|
||||
return VerifiedModelService(db_session)
|
||||
@@ -20,8 +20,10 @@ from storage.linear_workspace import LinearWorkspace
|
||||
from storage.maintenance_task import MaintenanceTask, MaintenanceTaskStatus
|
||||
from storage.openhands_pr import OpenhandsPR
|
||||
from storage.org import Org
|
||||
from storage.org_invitation import OrgInvitation
|
||||
from storage.org_member import OrgMember
|
||||
from storage.proactive_convos import ProactiveConversation
|
||||
from storage.resend_synced_user import ResendSyncedUser
|
||||
from storage.role import Role
|
||||
from storage.slack_conversation import SlackConversation
|
||||
from storage.slack_team import SlackTeam
|
||||
@@ -65,8 +67,10 @@ __all__ = [
|
||||
'MaintenanceTaskStatus',
|
||||
'OpenhandsPR',
|
||||
'Org',
|
||||
'OrgInvitation',
|
||||
'OrgMember',
|
||||
'ProactiveConversation',
|
||||
'ResendSyncedUser',
|
||||
'Role',
|
||||
'SlackConversation',
|
||||
'SlackTeam',
|
||||
|
||||
@@ -126,7 +126,7 @@ class ApiKeyStore:
|
||||
|
||||
return True
|
||||
|
||||
async def list_api_keys(self, user_id: str) -> list[dict]:
|
||||
async def list_api_keys(self, user_id: str) -> list[ApiKey]:
|
||||
"""List all API keys for a user."""
|
||||
user = await UserStore.get_user_by_id_async(user_id)
|
||||
org_id = user.current_org_id
|
||||
@@ -134,24 +134,14 @@ class ApiKeyStore:
|
||||
|
||||
def _list_api_keys_from_db(self, user_id: str, org_id: str) -> list[ApiKey]:
|
||||
with self.session_maker() as session:
|
||||
keys = (
|
||||
keys: list[ApiKey] = (
|
||||
session.query(ApiKey)
|
||||
.filter(ApiKey.user_id == user_id)
|
||||
.filter(ApiKey.org_id == org_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
return [
|
||||
{
|
||||
'id': key.id,
|
||||
'name': key.name,
|
||||
'created_at': key.created_at,
|
||||
'last_used_at': key.last_used_at,
|
||||
'expires_at': key.expires_at,
|
||||
}
|
||||
for key in keys
|
||||
if 'MCP_API_KEY' != key.name
|
||||
]
|
||||
return [key for key in keys if key.name != 'MCP_API_KEY']
|
||||
|
||||
async def retrieve_mcp_api_key(self, user_id: str) -> str | None:
|
||||
user = await UserStore.get_user_by_id_async(user_id)
|
||||
|
||||
@@ -4,7 +4,9 @@ import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Awaitable, Callable, Dict
|
||||
|
||||
from sqlalchemy import select, update
|
||||
from server.auth.auth_error import TokenRefreshError
|
||||
from sqlalchemy import select, text, update
|
||||
from sqlalchemy.exc import OperationalError
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from storage.auth_tokens import AuthTokens
|
||||
from storage.database import a_session_maker
|
||||
@@ -12,6 +14,14 @@ from storage.database import a_session_maker
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.service_types import ProviderType
|
||||
|
||||
# Time buffer (in seconds) before actual expiration to consider token expired
|
||||
# This ensures tokens are refreshed before they actually expire. The
|
||||
# github default is 8 hours, so 15 minutes leeway is ~3% of this.
|
||||
ACCESS_TOKEN_EXPIRY_BUFFER = 900 # 15 minutes
|
||||
|
||||
# Database lock timeout to prevent indefinite blocking
|
||||
LOCK_TIMEOUT_SECONDS = 5
|
||||
|
||||
|
||||
@dataclass
|
||||
class AuthTokenStore:
|
||||
@@ -23,6 +33,31 @@ class AuthTokenStore:
|
||||
def identity_provider_value(self) -> str:
|
||||
return self.idp.value
|
||||
|
||||
def _is_token_expired(
|
||||
self, access_token_expires_at: int, refresh_token_expires_at: int
|
||||
) -> tuple[bool, bool]:
|
||||
"""Check if access and refresh tokens are expired.
|
||||
|
||||
Args:
|
||||
access_token_expires_at: Expiration time for access token (seconds since epoch)
|
||||
refresh_token_expires_at: Expiration time for refresh token (seconds since epoch)
|
||||
|
||||
Returns:
|
||||
Tuple of (access_expired, refresh_expired)
|
||||
"""
|
||||
current_time = int(time.time())
|
||||
access_expired = (
|
||||
False
|
||||
if access_token_expires_at == 0
|
||||
else access_token_expires_at < current_time + ACCESS_TOKEN_EXPIRY_BUFFER
|
||||
)
|
||||
refresh_expired = (
|
||||
False
|
||||
if refresh_token_expires_at == 0
|
||||
else refresh_token_expires_at < current_time
|
||||
)
|
||||
return access_expired, refresh_expired
|
||||
|
||||
async def store_tokens(
|
||||
self,
|
||||
access_token: str,
|
||||
@@ -73,87 +108,149 @@ class AuthTokenStore:
|
||||
]
|
||||
| None = None,
|
||||
) -> Dict[str, str | int] | None:
|
||||
"""
|
||||
Load authentication tokens from the database and refresh them if necessary.
|
||||
"""Load authentication tokens from the database and refresh them if necessary.
|
||||
|
||||
This method retrieves the current authentication tokens for the user and checks if they have expired.
|
||||
It uses the provided `check_expiration_and_refresh` function to determine if the tokens need
|
||||
to be refreshed and to refresh the tokens if needed.
|
||||
This method uses a double-checked locking pattern to minimize lock contention:
|
||||
1. First, check if the token is valid WITHOUT acquiring a lock (fast path)
|
||||
2. If refresh is needed, acquire a lock with a timeout
|
||||
3. Double-check if refresh is still needed (another request may have refreshed)
|
||||
4. Perform the refresh if still needed
|
||||
|
||||
The method ensures that only one refresh operation is performed per refresh token by using a
|
||||
row-level lock on the token record.
|
||||
|
||||
The method is designed to handle race conditions where multiple requests might attempt to refresh
|
||||
the same token simultaneously, ensuring that only one refresh call occurs per refresh token.
|
||||
The row-level lock ensures that only one refresh operation is performed per
|
||||
refresh token, which is important because most IDPs invalidate the old refresh
|
||||
token after it's used once.
|
||||
|
||||
Args:
|
||||
check_expiration_and_refresh (Callable, optional): A function that checks if the tokens have expired
|
||||
and attempts to refresh them. It should return a dictionary containing the new access_token, refresh_token,
|
||||
and their respective expiration timestamps. If no refresh is needed, it should return `None`.
|
||||
check_expiration_and_refresh: A function that checks if the tokens have
|
||||
expired and attempts to refresh them. It should return a dictionary
|
||||
containing the new access_token, refresh_token, and their respective
|
||||
expiration timestamps. If no refresh is needed, it should return None.
|
||||
|
||||
Returns:
|
||||
Dict[str, str | int] | None:
|
||||
A dictionary containing the access_token, refresh_token, access_token_expires_at,
|
||||
and refresh_token_expires_at. If no token record is found, returns `None`.
|
||||
A dictionary containing the access_token, refresh_token,
|
||||
access_token_expires_at, and refresh_token_expires_at.
|
||||
If no token record is found, returns None.
|
||||
|
||||
Raises:
|
||||
TokenRefreshError: If the lock cannot be acquired within the timeout
|
||||
period. This typically means another request is holding the lock
|
||||
for an extended period. Callers should handle this by returning
|
||||
a 401 response to prompt the user to re-authenticate.
|
||||
"""
|
||||
# FAST PATH: Check without lock first to avoid unnecessary lock contention
|
||||
async with self.a_session_maker() as session:
|
||||
async with session.begin(): # Ensures transaction management
|
||||
# Lock the row while we check if we need to refresh the tokens.
|
||||
# There is a race condition where 2 or more calls can load tokens simultaneously.
|
||||
# If it turns out the loaded tokens are expired, then there will be multiple
|
||||
# refresh token calls with the same refresh token. Most IDPs only allow one refresh
|
||||
# per refresh token. This lock ensure that only one refresh call occurs per refresh token
|
||||
result = await session.execute(
|
||||
select(AuthTokens)
|
||||
.filter(
|
||||
AuthTokens.keycloak_user_id == self.keycloak_user_id,
|
||||
AuthTokens.identity_provider == self.identity_provider_value,
|
||||
)
|
||||
.with_for_update()
|
||||
result = await session.execute(
|
||||
select(AuthTokens).filter(
|
||||
AuthTokens.keycloak_user_id == self.keycloak_user_id,
|
||||
AuthTokens.identity_provider == self.identity_provider_value,
|
||||
)
|
||||
token_record = result.scalars().one_or_none()
|
||||
)
|
||||
token_record = result.scalars().one_or_none()
|
||||
|
||||
if not token_record:
|
||||
return None
|
||||
if not token_record:
|
||||
return None
|
||||
|
||||
token_refresh = (
|
||||
await check_expiration_and_refresh(
|
||||
# Check if token needs refresh
|
||||
access_expired, _ = self._is_token_expired(
|
||||
token_record.access_token_expires_at,
|
||||
token_record.refresh_token_expires_at,
|
||||
)
|
||||
|
||||
# If token is still valid, return it without acquiring a lock
|
||||
if not access_expired or check_expiration_and_refresh is None:
|
||||
return {
|
||||
'access_token': token_record.access_token,
|
||||
'refresh_token': token_record.refresh_token,
|
||||
'access_token_expires_at': token_record.access_token_expires_at,
|
||||
'refresh_token_expires_at': token_record.refresh_token_expires_at,
|
||||
}
|
||||
|
||||
# SLOW PATH: Token needs refresh, acquire lock
|
||||
try:
|
||||
async with self.a_session_maker() as session:
|
||||
async with session.begin():
|
||||
# Set a lock timeout to prevent indefinite blocking
|
||||
# This ensures we don't hold connections forever if something goes wrong
|
||||
await session.execute(
|
||||
text(f"SET LOCAL lock_timeout = '{LOCK_TIMEOUT_SECONDS}s'")
|
||||
)
|
||||
|
||||
# Acquire row-level lock to prevent concurrent refresh attempts
|
||||
result = await session.execute(
|
||||
select(AuthTokens)
|
||||
.filter(
|
||||
AuthTokens.keycloak_user_id == self.keycloak_user_id,
|
||||
AuthTokens.identity_provider
|
||||
== self.identity_provider_value,
|
||||
)
|
||||
.with_for_update()
|
||||
)
|
||||
token_record = result.scalars().one_or_none()
|
||||
|
||||
if not token_record:
|
||||
return None
|
||||
|
||||
# Double-check: another request may have refreshed while we waited for the lock
|
||||
access_expired, _ = self._is_token_expired(
|
||||
token_record.access_token_expires_at,
|
||||
token_record.refresh_token_expires_at,
|
||||
)
|
||||
|
||||
if not access_expired:
|
||||
# Token was refreshed by another request while we waited
|
||||
logger.debug(
|
||||
'Token was refreshed by another request while waiting for lock'
|
||||
)
|
||||
return {
|
||||
'access_token': token_record.access_token,
|
||||
'refresh_token': token_record.refresh_token,
|
||||
'access_token_expires_at': token_record.access_token_expires_at,
|
||||
'refresh_token_expires_at': token_record.refresh_token_expires_at,
|
||||
}
|
||||
|
||||
# We're the one doing the refresh
|
||||
token_refresh = await check_expiration_and_refresh(
|
||||
self.idp,
|
||||
token_record.refresh_token,
|
||||
token_record.access_token_expires_at,
|
||||
token_record.refresh_token_expires_at,
|
||||
)
|
||||
if check_expiration_and_refresh
|
||||
else None
|
||||
)
|
||||
|
||||
if token_refresh:
|
||||
await session.execute(
|
||||
update(AuthTokens)
|
||||
.where(AuthTokens.id == token_record.id)
|
||||
.values(
|
||||
access_token=token_refresh['access_token'],
|
||||
refresh_token=token_refresh['refresh_token'],
|
||||
access_token_expires_at=token_refresh[
|
||||
'access_token_expires_at'
|
||||
],
|
||||
refresh_token_expires_at=token_refresh[
|
||||
'refresh_token_expires_at'
|
||||
],
|
||||
if token_refresh:
|
||||
await session.execute(
|
||||
update(AuthTokens)
|
||||
.where(AuthTokens.id == token_record.id)
|
||||
.values(
|
||||
access_token=token_refresh['access_token'],
|
||||
refresh_token=token_refresh['refresh_token'],
|
||||
access_token_expires_at=token_refresh[
|
||||
'access_token_expires_at'
|
||||
],
|
||||
refresh_token_expires_at=token_refresh[
|
||||
'refresh_token_expires_at'
|
||||
],
|
||||
)
|
||||
)
|
||||
)
|
||||
await session.commit()
|
||||
await session.commit()
|
||||
|
||||
return (
|
||||
token_refresh
|
||||
if token_refresh
|
||||
else {
|
||||
'access_token': token_record.access_token,
|
||||
'refresh_token': token_record.refresh_token,
|
||||
'access_token_expires_at': token_record.access_token_expires_at,
|
||||
'refresh_token_expires_at': token_record.refresh_token_expires_at,
|
||||
}
|
||||
)
|
||||
return (
|
||||
token_refresh
|
||||
if token_refresh
|
||||
else {
|
||||
'access_token': token_record.access_token,
|
||||
'refresh_token': token_record.refresh_token,
|
||||
'access_token_expires_at': token_record.access_token_expires_at,
|
||||
'refresh_token_expires_at': token_record.refresh_token_expires_at,
|
||||
}
|
||||
)
|
||||
except OperationalError as e:
|
||||
# Lock timeout - another request is holding the lock for too long
|
||||
logger.warning(
|
||||
f'Token refresh lock timeout for user {self.keycloak_user_id}: {e}'
|
||||
)
|
||||
raise TokenRefreshError(
|
||||
'Unable to refresh token due to lock timeout. Please try again.'
|
||||
) from e
|
||||
|
||||
async def is_access_token_valid(self) -> bool:
|
||||
"""Check if the access token is still valid.
|
||||
@@ -194,8 +291,8 @@ class AuthTokenStore:
|
||||
"""Get an instance of the AuthTokenStore.
|
||||
|
||||
Args:
|
||||
config: The application configuration
|
||||
keycloak_user_id: The Keycloak user ID
|
||||
idp: The identity provider type
|
||||
|
||||
Returns:
|
||||
An instance of AuthTokenStore
|
||||
|
||||
@@ -18,17 +18,17 @@ def _get_db_session_injector():
|
||||
return _config.db_session
|
||||
|
||||
|
||||
def session_maker():
|
||||
def session_maker(**kwargs):
|
||||
db_session_injector = _get_db_session_injector()
|
||||
session_maker = db_session_injector.get_session_maker()
|
||||
return session_maker()
|
||||
factory = db_session_injector.get_session_maker()
|
||||
return factory(**kwargs)
|
||||
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def a_session_maker():
|
||||
async def a_session_maker(**kwargs):
|
||||
db_session_injector = _get_db_session_injector()
|
||||
a_session_maker = await db_session_injector.get_async_session_maker()
|
||||
async with a_session_maker() as session:
|
||||
factory = await db_session_injector.get_async_session_maker()
|
||||
async with factory(**kwargs) as session:
|
||||
yield session
|
||||
|
||||
|
||||
|
||||
@@ -10,7 +10,6 @@ import httpx
|
||||
from pydantic import SecretStr
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.constants import (
|
||||
DEFAULT_INITIAL_BUDGET,
|
||||
LITE_LLM_API_KEY,
|
||||
LITE_LLM_API_URL,
|
||||
LITE_LLM_TEAM_ID,
|
||||
@@ -24,16 +23,54 @@ from storage.user_settings import UserSettings
|
||||
from openhands.server.settings import Settings
|
||||
from openhands.utils.http_session import httpx_verify_option
|
||||
|
||||
# Timeout in seconds for BYOR key verification requests to LiteLLM
|
||||
BYOR_KEY_VERIFICATION_TIMEOUT = 5.0
|
||||
# Timeout in seconds for key verification requests to LiteLLM
|
||||
KEY_VERIFICATION_TIMEOUT = 5.0
|
||||
|
||||
# A very large number to represent "unlimited" until LiteLLM fixes their unlimited update bug.
|
||||
UNLIMITED_BUDGET_SETTING = 1000000000.0
|
||||
|
||||
|
||||
def get_openhands_cloud_key_alias(keycloak_user_id: str, org_id: str) -> str:
|
||||
"""Generate the key alias for OpenHands Cloud managed keys."""
|
||||
return f'OpenHands Cloud - user {keycloak_user_id} - org {org_id}'
|
||||
|
||||
|
||||
def get_byor_key_alias(keycloak_user_id: str, org_id: str) -> str:
|
||||
"""Generate the key alias for BYOR (Bring Your Own Runtime) keys."""
|
||||
return f'BYOR Key - user {keycloak_user_id}, org {org_id}'
|
||||
|
||||
|
||||
class LiteLlmManager:
|
||||
"""Manage LiteLLM interactions."""
|
||||
|
||||
@staticmethod
|
||||
def get_budget_from_team_info(
|
||||
user_team_info: dict | None, user_id: str, org_id: str
|
||||
) -> tuple[float, float]:
|
||||
"""Extract max_budget and spend from user team info.
|
||||
|
||||
For personal orgs (user_id == org_id), uses litellm_budget_table.max_budget.
|
||||
For team orgs, uses max_budget_in_team (populated by get_user_team_info).
|
||||
|
||||
Args:
|
||||
user_team_info: The response from get_user_team_info
|
||||
user_id: The user's ID
|
||||
org_id: The organization's ID
|
||||
|
||||
Returns:
|
||||
Tuple of (max_budget, spend)
|
||||
"""
|
||||
if not user_team_info:
|
||||
return 0, 0
|
||||
spend = user_team_info.get('spend', 0)
|
||||
if user_id == org_id:
|
||||
max_budget = (user_team_info.get('litellm_budget_table') or {}).get(
|
||||
'max_budget', 0
|
||||
)
|
||||
else:
|
||||
max_budget = user_team_info.get('max_budget_in_team') or 0
|
||||
return max_budget, spend
|
||||
|
||||
@staticmethod
|
||||
async def create_entries(
|
||||
org_id: str,
|
||||
@@ -62,8 +99,33 @@ class LiteLlmManager:
|
||||
'x-goog-api-key': LITE_LLM_API_KEY,
|
||||
}
|
||||
) as client:
|
||||
# Check if team already exists and get its budget
|
||||
# New users joining existing orgs should inherit the team's budget
|
||||
team_budget = 0.0
|
||||
try:
|
||||
existing_team = await LiteLlmManager._get_team(client, org_id)
|
||||
if existing_team:
|
||||
team_info = existing_team.get('team_info', {})
|
||||
team_budget = team_info.get('max_budget', 0.0) or 0.0
|
||||
logger.info(
|
||||
'LiteLlmManager:create_entries:existing_team_budget',
|
||||
extra={
|
||||
'org_id': org_id,
|
||||
'user_id': keycloak_user_id,
|
||||
'team_budget': team_budget,
|
||||
},
|
||||
)
|
||||
except httpx.HTTPStatusError as e:
|
||||
# Team doesn't exist yet (404) - this is expected for first user
|
||||
if e.response.status_code != 404:
|
||||
raise
|
||||
logger.info(
|
||||
'LiteLlmManager:create_entries:no_existing_team',
|
||||
extra={'org_id': org_id, 'user_id': keycloak_user_id},
|
||||
)
|
||||
|
||||
await LiteLlmManager._create_team(
|
||||
client, keycloak_user_id, org_id, DEFAULT_INITIAL_BUDGET
|
||||
client, keycloak_user_id, org_id, team_budget
|
||||
)
|
||||
|
||||
if create_user:
|
||||
@@ -72,14 +134,14 @@ class LiteLlmManager:
|
||||
)
|
||||
|
||||
await LiteLlmManager._add_user_to_team(
|
||||
client, keycloak_user_id, org_id, DEFAULT_INITIAL_BUDGET
|
||||
client, keycloak_user_id, org_id, team_budget
|
||||
)
|
||||
|
||||
key = await LiteLlmManager._generate_key(
|
||||
client,
|
||||
keycloak_user_id,
|
||||
org_id,
|
||||
f'OpenHands Cloud - user {keycloak_user_id} - org {org_id}',
|
||||
get_openhands_cloud_key_alias(keycloak_user_id, org_id),
|
||||
None,
|
||||
)
|
||||
|
||||
@@ -251,7 +313,7 @@ class LiteLlmManager:
|
||||
client,
|
||||
keycloak_user_id,
|
||||
org_id,
|
||||
f'OpenHands Cloud - user {keycloak_user_id} - org {org_id}',
|
||||
get_openhands_cloud_key_alias(keycloak_user_id, org_id),
|
||||
None,
|
||||
)
|
||||
if new_key:
|
||||
@@ -884,21 +946,31 @@ class LiteLlmManager:
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return None
|
||||
team_info = await LiteLlmManager._get_team(client, team_id)
|
||||
if not team_info:
|
||||
team_response = await LiteLlmManager._get_team(client, team_id)
|
||||
if not team_response:
|
||||
return None
|
||||
|
||||
# Filter team_memberships based on team_id and keycloak_user_id
|
||||
user_membership = next(
|
||||
(
|
||||
membership
|
||||
for membership in team_info.get('team_memberships', [])
|
||||
for membership in team_response.get('team_memberships', [])
|
||||
if membership.get('user_id') == keycloak_user_id
|
||||
and membership.get('team_id') == team_id
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if not user_membership:
|
||||
return None
|
||||
|
||||
# For team orgs (user_id != team_id), include team-level budget info
|
||||
# The team's max_budget and spend are shared across all members
|
||||
if keycloak_user_id != team_id:
|
||||
team_info = team_response.get('team_info', {})
|
||||
user_membership['max_budget_in_team'] = team_info.get('max_budget')
|
||||
user_membership['spend'] = team_info.get('spend', 0)
|
||||
|
||||
return user_membership
|
||||
|
||||
@staticmethod
|
||||
@@ -1044,7 +1116,7 @@ class LiteLlmManager:
|
||||
try:
|
||||
async with httpx.AsyncClient(
|
||||
verify=httpx_verify_option(),
|
||||
timeout=BYOR_KEY_VERIFICATION_TIMEOUT,
|
||||
timeout=KEY_VERIFICATION_TIMEOUT,
|
||||
) as client:
|
||||
# Make a lightweight request to verify the key
|
||||
# Using /v1/models endpoint as it's lightweight and requires authentication
|
||||
@@ -1058,7 +1130,7 @@ class LiteLlmManager:
|
||||
# Only 200 status code indicates valid key
|
||||
if response.status_code == 200:
|
||||
logger.debug(
|
||||
'BYOR key verification successful',
|
||||
'Key verification successful',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
return True
|
||||
@@ -1066,7 +1138,7 @@ class LiteLlmManager:
|
||||
# All other status codes (401, 403, 500, etc.) are treated as invalid
|
||||
# This includes authentication errors and server errors
|
||||
logger.warning(
|
||||
'BYOR key verification failed - treating as invalid',
|
||||
'Key verification failed - treating as invalid',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'status_code': response.status_code,
|
||||
@@ -1079,7 +1151,7 @@ class LiteLlmManager:
|
||||
# Any exception (timeout, network error, etc.) means we can't verify
|
||||
# Return False to trigger regeneration rather than returning potentially invalid key
|
||||
logger.warning(
|
||||
'BYOR key verification error - treating as invalid to ensure key validity',
|
||||
'Key verification error - treating as invalid to ensure key validity',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'error': str(e),
|
||||
@@ -1123,6 +1195,103 @@ class LiteLlmManager:
|
||||
'key_spend': key_info.get('spend'),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
async def _get_all_keys_for_user(
|
||||
client: httpx.AsyncClient,
|
||||
keycloak_user_id: str,
|
||||
) -> list[dict]:
|
||||
"""Get all keys for a user from LiteLLM.
|
||||
|
||||
Returns a list of key info dictionaries containing:
|
||||
- token: the key value (hashed or partial)
|
||||
- key_alias: the alias for the key
|
||||
- key_name: the name of the key
|
||||
- spend: the amount spent on this key
|
||||
- max_budget: the max budget for this key
|
||||
- team_id: the team the key belongs to
|
||||
- metadata: any metadata associated with the key
|
||||
|
||||
Returns an empty list if no keys found or on error.
|
||||
"""
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return []
|
||||
|
||||
try:
|
||||
response = await client.get(
|
||||
f'{LITE_LLM_API_URL}/user/info?user_id={keycloak_user_id}',
|
||||
headers={'x-goog-api-key': LITE_LLM_API_KEY},
|
||||
)
|
||||
response.raise_for_status()
|
||||
user_json = response.json()
|
||||
# The user/info endpoint returns keys in the 'keys' field
|
||||
return user_json.get('keys', [])
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
'LiteLlmManager:_get_all_keys_for_user:error',
|
||||
extra={
|
||||
'user_id': keycloak_user_id,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
async def _verify_existing_key(
|
||||
client: httpx.AsyncClient,
|
||||
key_value: str,
|
||||
keycloak_user_id: str,
|
||||
org_id: str,
|
||||
openhands_type: bool = False,
|
||||
) -> bool:
|
||||
"""Check if an existing key exists for the user/org in LiteLLM.
|
||||
|
||||
Verifies the provided key_value matches a key registered in LiteLLM for
|
||||
the given user and organization. For openhands_type=True, looks for keys
|
||||
with metadata type='openhands' and matching team_id. For openhands_type=False,
|
||||
looks for keys with matching alias and team_id.
|
||||
|
||||
Returns True if the key is found and valid, False otherwise.
|
||||
"""
|
||||
found = False
|
||||
keys = await LiteLlmManager._get_all_keys_for_user(client, keycloak_user_id)
|
||||
for key_info in keys:
|
||||
metadata = key_info.get('metadata') or {}
|
||||
team_id = key_info.get('team_id')
|
||||
key_alias = key_info.get('key_alias')
|
||||
token = None
|
||||
if (
|
||||
openhands_type
|
||||
and metadata.get('type') == 'openhands'
|
||||
and team_id == org_id
|
||||
):
|
||||
# Found an existing OpenHands key for this org
|
||||
key_name = key_info.get('key_name')
|
||||
token = key_name[-4:] if key_name else None # last 4 digits of key
|
||||
if token and key_value.endswith(
|
||||
token
|
||||
): # check if this is our current key
|
||||
found = True
|
||||
break
|
||||
if (
|
||||
not openhands_type
|
||||
and team_id == org_id
|
||||
and (
|
||||
key_alias == get_openhands_cloud_key_alias(keycloak_user_id, org_id)
|
||||
or key_alias == get_byor_key_alias(keycloak_user_id, org_id)
|
||||
)
|
||||
):
|
||||
# Found an existing key for this org (regardless of type)
|
||||
key_name = key_info.get('key_name')
|
||||
token = key_name[-4:] if key_name else None # last 4 digits of key
|
||||
if token and key_value.endswith(
|
||||
token
|
||||
): # check if this is our current key
|
||||
found = True
|
||||
break
|
||||
|
||||
return found
|
||||
|
||||
@staticmethod
|
||||
async def _delete_key_by_alias(
|
||||
client: httpx.AsyncClient,
|
||||
@@ -1220,6 +1389,8 @@ class LiteLlmManager:
|
||||
update_user_in_team = staticmethod(with_http_client(_update_user_in_team))
|
||||
generate_key = staticmethod(with_http_client(_generate_key))
|
||||
get_key_info = staticmethod(with_http_client(_get_key_info))
|
||||
verify_existing_key = staticmethod(with_http_client(_verify_existing_key))
|
||||
delete_key = staticmethod(with_http_client(_delete_key))
|
||||
get_user_keys = staticmethod(with_http_client(_get_user_keys))
|
||||
delete_key_by_alias = staticmethod(with_http_client(_delete_key_by_alias))
|
||||
update_user_keys = staticmethod(with_http_client(_update_user_keys))
|
||||
|
||||
@@ -46,10 +46,14 @@ class Org(Base): # type: ignore
|
||||
v1_enabled = Column(Boolean, nullable=True)
|
||||
conversation_expiration = Column(Integer, nullable=True)
|
||||
condenser_max_size = Column(Integer, nullable=True)
|
||||
byor_export_enabled = Column(Boolean, nullable=False, default=False)
|
||||
|
||||
# Relationships
|
||||
org_members = relationship('OrgMember', back_populates='org')
|
||||
current_users = relationship('User', back_populates='current_org')
|
||||
invitations = relationship(
|
||||
'OrgInvitation', back_populates='org', passive_deletes=True
|
||||
)
|
||||
billing_sessions = relationship('BillingSession', back_populates='org')
|
||||
stored_conversation_metadata_saas = relationship(
|
||||
'StoredConversationMetadataSaas', back_populates='org'
|
||||
|
||||
105
enterprise/storage/org_app_settings_store.py
Normal file
105
enterprise/storage/org_app_settings_store.py
Normal file
@@ -0,0 +1,105 @@
|
||||
"""Store class for managing organization app settings."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from uuid import UUID
|
||||
|
||||
from server.constants import (
|
||||
LITE_LLM_API_URL,
|
||||
ORG_SETTINGS_VERSION,
|
||||
get_default_litellm_model,
|
||||
)
|
||||
from server.routes.org_models import OrgAppSettingsUpdate
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from storage.org import Org
|
||||
from storage.user import User
|
||||
|
||||
|
||||
@dataclass
|
||||
class OrgAppSettingsStore:
|
||||
"""Store for organization app settings with injected db_session."""
|
||||
|
||||
db_session: AsyncSession
|
||||
|
||||
async def get_current_org_by_user_id(self, user_id: str) -> Org | None:
|
||||
"""Get the current organization for a user.
|
||||
|
||||
Args:
|
||||
user_id: The user's ID (Keycloak user ID)
|
||||
|
||||
Returns:
|
||||
Org: The organization object, or None if not found
|
||||
"""
|
||||
# Get user with their current_org_id
|
||||
result = await self.db_session.execute(
|
||||
select(User).filter(User.id == UUID(user_id))
|
||||
)
|
||||
user = result.scalars().first()
|
||||
|
||||
if not user:
|
||||
return None
|
||||
|
||||
org_id = user.current_org_id
|
||||
if not org_id:
|
||||
return None
|
||||
|
||||
# Get the organization
|
||||
result = await self.db_session.execute(select(Org).filter(Org.id == org_id))
|
||||
org = result.scalars().first()
|
||||
|
||||
if not org:
|
||||
return None
|
||||
|
||||
return await self._validate_org_version(org)
|
||||
|
||||
async def _validate_org_version(self, org: Org) -> Org:
|
||||
"""Check if we need to update org version.
|
||||
|
||||
Args:
|
||||
org: The organization to validate
|
||||
|
||||
Returns:
|
||||
Org: The validated (and potentially updated) organization
|
||||
"""
|
||||
if org.org_version < ORG_SETTINGS_VERSION:
|
||||
org.org_version = ORG_SETTINGS_VERSION
|
||||
org.default_llm_model = get_default_litellm_model()
|
||||
org.llm_base_url = LITE_LLM_API_URL
|
||||
await self.db_session.flush()
|
||||
await self.db_session.refresh(org)
|
||||
|
||||
return org
|
||||
|
||||
async def update_org_app_settings(
|
||||
self, org_id: UUID, update_data: OrgAppSettingsUpdate
|
||||
) -> Org | None:
|
||||
"""Update organization app settings.
|
||||
|
||||
Only updates fields that are explicitly provided in update_data.
|
||||
Uses flush() - commit happens at request end via DbSessionInjector.
|
||||
|
||||
Args:
|
||||
org_id: The organization's ID
|
||||
update_data: Pydantic model with fields to update
|
||||
|
||||
Returns:
|
||||
Org: The updated organization object, or None if not found
|
||||
"""
|
||||
result = await self.db_session.execute(
|
||||
select(Org).filter(Org.id == org_id).with_for_update()
|
||||
)
|
||||
org = result.scalars().first()
|
||||
|
||||
if not org:
|
||||
return None
|
||||
|
||||
# Update only explicitly provided fields
|
||||
for field, value in update_data.model_dump(exclude_unset=True).items():
|
||||
setattr(org, field, value)
|
||||
|
||||
# flush instead of commit - DbSessionInjector auto-commits at request end
|
||||
await self.db_session.flush()
|
||||
await self.db_session.refresh(org)
|
||||
return org
|
||||
59
enterprise/storage/org_invitation.py
Normal file
59
enterprise/storage/org_invitation.py
Normal file
@@ -0,0 +1,59 @@
|
||||
"""
|
||||
SQLAlchemy model for Organization Invitation.
|
||||
"""
|
||||
|
||||
from sqlalchemy import UUID, Column, DateTime, ForeignKey, Integer, String, text
|
||||
from sqlalchemy.orm import relationship
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
class OrgInvitation(Base): # type: ignore
|
||||
"""Organization invitation model.
|
||||
|
||||
Represents an invitation for a user to join an organization.
|
||||
Invitations are created by organization owners/admins and contain
|
||||
a secure token that can be used to accept the invitation.
|
||||
"""
|
||||
|
||||
__tablename__ = 'org_invitation'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
token = Column(String(64), nullable=False, unique=True, index=True)
|
||||
org_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey('org.id', ondelete='CASCADE'),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
email = Column(String(255), nullable=False, index=True)
|
||||
role_id = Column(Integer, ForeignKey('role.id'), nullable=False)
|
||||
inviter_id = Column(UUID(as_uuid=True), ForeignKey('user.id'), nullable=False)
|
||||
status = Column(
|
||||
String(20),
|
||||
nullable=False,
|
||||
server_default=text("'pending'"),
|
||||
)
|
||||
created_at = Column(
|
||||
DateTime,
|
||||
nullable=False,
|
||||
server_default=text('CURRENT_TIMESTAMP'),
|
||||
)
|
||||
expires_at = Column(DateTime, nullable=False)
|
||||
accepted_at = Column(DateTime, nullable=True)
|
||||
accepted_by_user_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey('user.id'),
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
# Relationships
|
||||
org = relationship('Org', back_populates='invitations')
|
||||
role = relationship('Role')
|
||||
inviter = relationship('User', foreign_keys=[inviter_id])
|
||||
accepted_by_user = relationship('User', foreign_keys=[accepted_by_user_id])
|
||||
|
||||
# Status constants
|
||||
STATUS_PENDING = 'pending'
|
||||
STATUS_ACCEPTED = 'accepted'
|
||||
STATUS_REVOKED = 'revoked'
|
||||
STATUS_EXPIRED = 'expired'
|
||||
227
enterprise/storage/org_invitation_store.py
Normal file
227
enterprise/storage/org_invitation_store.py
Normal file
@@ -0,0 +1,227 @@
|
||||
"""
|
||||
Store class for managing organization invitations.
|
||||
"""
|
||||
|
||||
import secrets
|
||||
import string
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import and_, select
|
||||
from sqlalchemy.orm import joinedload
|
||||
from storage.database import a_session_maker
|
||||
from storage.org_invitation import OrgInvitation
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
# Invitation token configuration
|
||||
INVITATION_TOKEN_PREFIX = 'inv-'
|
||||
INVITATION_TOKEN_LENGTH = 48 # Total length will be 52 with prefix
|
||||
DEFAULT_EXPIRATION_DAYS = 7
|
||||
|
||||
|
||||
class OrgInvitationStore:
|
||||
"""Store for managing organization invitations."""
|
||||
|
||||
@staticmethod
|
||||
def generate_token(length: int = INVITATION_TOKEN_LENGTH) -> str:
|
||||
"""Generate a secure invitation token.
|
||||
|
||||
Uses cryptographically secure random generation for tokens.
|
||||
Pattern from api_key_store.py.
|
||||
|
||||
Args:
|
||||
length: Length of the random part of the token
|
||||
|
||||
Returns:
|
||||
str: Token with prefix (e.g., 'inv-aBcDeF123...')
|
||||
"""
|
||||
alphabet = string.ascii_letters + string.digits
|
||||
random_part = ''.join(secrets.choice(alphabet) for _ in range(length))
|
||||
return f'{INVITATION_TOKEN_PREFIX}{random_part}'
|
||||
|
||||
@staticmethod
|
||||
async def create_invitation(
|
||||
org_id: UUID,
|
||||
email: str,
|
||||
role_id: int,
|
||||
inviter_id: UUID,
|
||||
expiration_days: int = DEFAULT_EXPIRATION_DAYS,
|
||||
) -> OrgInvitation:
|
||||
"""Create a new organization invitation.
|
||||
|
||||
Args:
|
||||
org_id: Organization UUID
|
||||
email: Invitee's email address
|
||||
role_id: Role ID to assign on acceptance
|
||||
inviter_id: User ID of the person creating the invitation
|
||||
expiration_days: Days until the invitation expires
|
||||
|
||||
Returns:
|
||||
OrgInvitation: The created invitation record
|
||||
"""
|
||||
async with a_session_maker() as session:
|
||||
token = OrgInvitationStore.generate_token()
|
||||
# Use timezone-naive datetime for database compatibility
|
||||
expires_at = datetime.utcnow() + timedelta(days=expiration_days)
|
||||
|
||||
invitation = OrgInvitation(
|
||||
token=token,
|
||||
org_id=org_id,
|
||||
email=email.lower().strip(),
|
||||
role_id=role_id,
|
||||
inviter_id=inviter_id,
|
||||
status=OrgInvitation.STATUS_PENDING,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
session.add(invitation)
|
||||
await session.commit()
|
||||
|
||||
# Re-fetch with eagerly loaded relationships to avoid DetachedInstanceError
|
||||
result = await session.execute(
|
||||
select(OrgInvitation)
|
||||
.options(joinedload(OrgInvitation.role))
|
||||
.filter(OrgInvitation.id == invitation.id)
|
||||
)
|
||||
invitation = result.scalars().first()
|
||||
|
||||
logger.info(
|
||||
'Created organization invitation',
|
||||
extra={
|
||||
'invitation_id': invitation.id,
|
||||
'org_id': str(org_id),
|
||||
'email': email,
|
||||
'inviter_id': str(inviter_id),
|
||||
'expires_at': expires_at.isoformat(),
|
||||
},
|
||||
)
|
||||
|
||||
return invitation
|
||||
|
||||
@staticmethod
|
||||
async def get_invitation_by_token(token: str) -> Optional[OrgInvitation]:
|
||||
"""Get an invitation by its token.
|
||||
|
||||
Args:
|
||||
token: The invitation token
|
||||
|
||||
Returns:
|
||||
OrgInvitation or None if not found
|
||||
"""
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(OrgInvitation)
|
||||
.options(joinedload(OrgInvitation.org), joinedload(OrgInvitation.role))
|
||||
.filter(OrgInvitation.token == token)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
@staticmethod
|
||||
async def get_pending_invitation(
|
||||
org_id: UUID, email: str
|
||||
) -> Optional[OrgInvitation]:
|
||||
"""Get a pending invitation for an email in an organization.
|
||||
|
||||
Args:
|
||||
org_id: Organization UUID
|
||||
email: Email address to check
|
||||
|
||||
Returns:
|
||||
OrgInvitation or None if no pending invitation exists
|
||||
"""
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(OrgInvitation).filter(
|
||||
and_(
|
||||
OrgInvitation.org_id == org_id,
|
||||
OrgInvitation.email == email.lower().strip(),
|
||||
OrgInvitation.status == OrgInvitation.STATUS_PENDING,
|
||||
)
|
||||
)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
@staticmethod
|
||||
async def update_invitation_status(
|
||||
invitation_id: int,
|
||||
status: str,
|
||||
accepted_by_user_id: Optional[UUID] = None,
|
||||
) -> Optional[OrgInvitation]:
|
||||
"""Update an invitation's status.
|
||||
|
||||
Args:
|
||||
invitation_id: The invitation ID
|
||||
status: New status (pending, accepted, revoked, expired)
|
||||
accepted_by_user_id: User ID who accepted (only for 'accepted' status)
|
||||
|
||||
Returns:
|
||||
Updated OrgInvitation or None if not found
|
||||
"""
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(OrgInvitation).filter(OrgInvitation.id == invitation_id)
|
||||
)
|
||||
invitation = result.scalars().first()
|
||||
|
||||
if not invitation:
|
||||
return None
|
||||
|
||||
old_status = invitation.status
|
||||
invitation.status = status
|
||||
|
||||
if status == OrgInvitation.STATUS_ACCEPTED and accepted_by_user_id:
|
||||
# Use timezone-naive datetime for database compatibility
|
||||
invitation.accepted_at = datetime.utcnow()
|
||||
invitation.accepted_by_user_id = accepted_by_user_id
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(invitation)
|
||||
|
||||
logger.info(
|
||||
'Updated invitation status',
|
||||
extra={
|
||||
'invitation_id': invitation_id,
|
||||
'old_status': old_status,
|
||||
'new_status': status,
|
||||
'accepted_by_user_id': (
|
||||
str(accepted_by_user_id) if accepted_by_user_id else None
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
return invitation
|
||||
|
||||
@staticmethod
|
||||
def is_token_expired(invitation: OrgInvitation) -> bool:
|
||||
"""Check if an invitation token has expired.
|
||||
|
||||
Args:
|
||||
invitation: The invitation to check
|
||||
|
||||
Returns:
|
||||
bool: True if expired, False otherwise
|
||||
"""
|
||||
# Use timezone-naive datetime for comparison (database stores without timezone)
|
||||
now = datetime.utcnow()
|
||||
return invitation.expires_at < now
|
||||
|
||||
@staticmethod
|
||||
async def mark_expired_if_needed(invitation: OrgInvitation) -> bool:
|
||||
"""Check if invitation is expired and update status if needed.
|
||||
|
||||
Args:
|
||||
invitation: The invitation to check
|
||||
|
||||
Returns:
|
||||
bool: True if invitation was marked as expired, False otherwise
|
||||
"""
|
||||
if (
|
||||
invitation.status == OrgInvitation.STATUS_PENDING
|
||||
and OrgInvitationStore.is_token_expired(invitation)
|
||||
):
|
||||
await OrgInvitationStore.update_invitation_status(
|
||||
invitation.id, OrgInvitation.STATUS_EXPIRED
|
||||
)
|
||||
return True
|
||||
return False
|
||||
83
enterprise/storage/org_llm_settings_store.py
Normal file
83
enterprise/storage/org_llm_settings_store.py
Normal file
@@ -0,0 +1,83 @@
|
||||
"""Store class for managing organization LLM settings."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from uuid import UUID
|
||||
|
||||
from server.routes.org_models import OrgLLMSettingsUpdate
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from storage.org import Org
|
||||
from storage.org_member_store import OrgMemberStore
|
||||
from storage.user import User
|
||||
|
||||
|
||||
@dataclass
|
||||
class OrgLLMSettingsStore:
|
||||
"""Store for org LLM settings with injected db_session."""
|
||||
|
||||
db_session: AsyncSession
|
||||
|
||||
async def get_current_org_by_user_id(self, user_id: str) -> Org | None:
|
||||
"""Get the user's current organization.
|
||||
|
||||
Args:
|
||||
user_id: The user's ID (Keycloak user ID)
|
||||
|
||||
Returns:
|
||||
Org: The user's current organization, or None if not found
|
||||
"""
|
||||
# First get the user to find their current_org_id
|
||||
result = await self.db_session.execute(
|
||||
select(User).filter(User.id == uuid.UUID(user_id))
|
||||
)
|
||||
user = result.scalars().first()
|
||||
|
||||
if not user or not user.current_org_id:
|
||||
return None
|
||||
|
||||
# Then get the org
|
||||
result = await self.db_session.execute(
|
||||
select(Org).filter(Org.id == user.current_org_id)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
async def update_org_llm_settings(
|
||||
self, org_id: UUID, update_data: OrgLLMSettingsUpdate
|
||||
) -> Org | None:
|
||||
"""Update organization LLM settings.
|
||||
|
||||
Also propagates relevant settings to all org members.
|
||||
Uses flush() - commit happens at request end via DbSessionInjector.
|
||||
|
||||
Args:
|
||||
org_id: The organization's ID
|
||||
update_data: Pydantic model with fields to update
|
||||
|
||||
Returns:
|
||||
Org: The updated organization, or None if org not found
|
||||
"""
|
||||
result = await self.db_session.execute(
|
||||
select(Org).filter(Org.id == org_id).with_for_update()
|
||||
)
|
||||
org = result.scalars().first()
|
||||
|
||||
if not org:
|
||||
return None
|
||||
|
||||
# Apply updates to org (excludes llm_api_key which is member-only)
|
||||
update_data.apply_to_org(org)
|
||||
|
||||
# Propagate relevant settings to all org members
|
||||
member_updates = update_data.get_member_updates()
|
||||
if member_updates:
|
||||
await OrgMemberStore.update_all_members_llm_settings_async(
|
||||
self.db_session, org_id, member_updates
|
||||
)
|
||||
|
||||
# flush instead of commit - DbSessionInjector auto-commits at request end
|
||||
await self.db_session.flush()
|
||||
await self.db_session.refresh(org)
|
||||
return org
|
||||
@@ -5,8 +5,14 @@ Store class for managing organization-member relationships.
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from storage.database import session_maker
|
||||
from server.routes.org_models import OrgMemberLLMSettings
|
||||
from sqlalchemy import func, select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import joinedload
|
||||
from storage.database import a_session_maker, session_maker
|
||||
from storage.encrypt_utils import encrypt_value
|
||||
from storage.org_member import OrgMember
|
||||
from storage.user import User
|
||||
from storage.user_settings import UserSettings
|
||||
|
||||
from openhands.storage.data_models.settings import Settings
|
||||
@@ -38,7 +44,7 @@ class OrgMemberStore:
|
||||
return org_member
|
||||
|
||||
@staticmethod
|
||||
def get_org_member(org_id: UUID, user_id: int) -> Optional[OrgMember]:
|
||||
def get_org_member(org_id: UUID, user_id: UUID) -> Optional[OrgMember]:
|
||||
"""Get organization-user relationship."""
|
||||
with session_maker() as session:
|
||||
return (
|
||||
@@ -48,7 +54,63 @@ class OrgMemberStore:
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_user_orgs(user_id: int) -> list[OrgMember]:
|
||||
async def get_org_member_async(org_id: UUID, user_id: UUID) -> Optional[OrgMember]:
|
||||
"""Get organization-user relationship."""
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(OrgMember).filter(
|
||||
OrgMember.org_id == org_id, OrgMember.user_id == user_id
|
||||
)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
@staticmethod
|
||||
def get_org_member_for_current_org(user_id: UUID) -> Optional[OrgMember]:
|
||||
"""Get the org member for a user's current organization.
|
||||
|
||||
Args:
|
||||
user_id: The user's UUID.
|
||||
|
||||
Returns:
|
||||
The OrgMember for the user's current organization, or None if not found.
|
||||
"""
|
||||
with session_maker() as session:
|
||||
result = (
|
||||
session.query(OrgMember)
|
||||
.join(User, User.id == OrgMember.user_id)
|
||||
.filter(
|
||||
User.id == user_id,
|
||||
OrgMember.org_id == User.current_org_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
async def get_org_member_for_current_org_async(
|
||||
user_id: UUID,
|
||||
) -> Optional[OrgMember]:
|
||||
"""Get the org member for a user's current organization (async version).
|
||||
|
||||
Args:
|
||||
user_id: The user's UUID.
|
||||
|
||||
Returns:
|
||||
The OrgMember for the user's current organization, or None if not found.
|
||||
"""
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(OrgMember)
|
||||
.join(User, User.id == OrgMember.user_id)
|
||||
.filter(
|
||||
User.id == user_id,
|
||||
OrgMember.org_id == User.current_org_id,
|
||||
)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
@staticmethod
|
||||
def get_user_orgs(user_id: UUID) -> list[OrgMember]:
|
||||
"""Get all organizations for a user."""
|
||||
with session_maker() as session:
|
||||
return session.query(OrgMember).filter(OrgMember.user_id == user_id).all()
|
||||
@@ -68,7 +130,7 @@ class OrgMemberStore:
|
||||
|
||||
@staticmethod
|
||||
def update_user_role_in_org(
|
||||
org_id: UUID, user_id: int, role_id: int, status: Optional[str] = None
|
||||
org_id: UUID, user_id: UUID, role_id: int, status: Optional[str] = None
|
||||
) -> Optional[OrgMember]:
|
||||
"""Update user's role in an organization."""
|
||||
with session_maker() as session:
|
||||
@@ -90,7 +152,7 @@ class OrgMemberStore:
|
||||
return org_member
|
||||
|
||||
@staticmethod
|
||||
def remove_user_from_org(org_id: UUID, user_id: int) -> bool:
|
||||
def remove_user_from_org(org_id: UUID, user_id: UUID) -> bool:
|
||||
"""Remove a user from an organization."""
|
||||
with session_maker() as session:
|
||||
org_member = (
|
||||
@@ -123,3 +185,100 @@ class OrgMemberStore:
|
||||
if (normalized := c.name.lstrip('_')) and hasattr(user_settings, normalized)
|
||||
}
|
||||
return kwargs
|
||||
|
||||
@staticmethod
|
||||
async def get_org_members_count(
|
||||
org_id: UUID,
|
||||
email_filter: str | None = None,
|
||||
) -> int:
|
||||
"""Get total count of organization members, optionally filtered by email.
|
||||
|
||||
Args:
|
||||
org_id: Organization UUID.
|
||||
email_filter: Optional case-insensitive partial email match.
|
||||
|
||||
Returns:
|
||||
Total count of matching members.
|
||||
"""
|
||||
async with a_session_maker() as session:
|
||||
query = select(func.count(OrgMember.user_id)).filter(
|
||||
OrgMember.org_id == org_id
|
||||
)
|
||||
|
||||
if email_filter:
|
||||
query = query.join(User, User.id == OrgMember.user_id).filter(
|
||||
User.email.ilike(f'%{email_filter}%')
|
||||
)
|
||||
|
||||
result = await session.execute(query)
|
||||
return result.scalar() or 0
|
||||
|
||||
@staticmethod
|
||||
async def get_org_members_paginated(
|
||||
org_id: UUID,
|
||||
offset: int = 0,
|
||||
limit: int = 100,
|
||||
email_filter: str | None = None,
|
||||
) -> tuple[list[OrgMember], bool]:
|
||||
"""Get paginated list of organization members with user and role info.
|
||||
|
||||
Args:
|
||||
org_id: Organization UUID.
|
||||
offset: Number of records to skip.
|
||||
limit: Maximum number of records to return.
|
||||
email_filter: Optional case-insensitive partial email match.
|
||||
|
||||
Returns:
|
||||
Tuple of (members_list, has_more) where has_more indicates if there are more results.
|
||||
"""
|
||||
async with a_session_maker() as session:
|
||||
# Query for limit + 1 items to determine if there are more results
|
||||
# Order by user_id for consistent pagination
|
||||
query = (
|
||||
select(OrgMember)
|
||||
.options(joinedload(OrgMember.user), joinedload(OrgMember.role))
|
||||
.join(User, User.id == OrgMember.user_id)
|
||||
.filter(OrgMember.org_id == org_id)
|
||||
)
|
||||
|
||||
# Apply email filter if provided
|
||||
if email_filter:
|
||||
query = query.filter(User.email.ilike(f'%{email_filter}%'))
|
||||
|
||||
query = query.order_by(OrgMember.user_id).offset(offset).limit(limit + 1)
|
||||
|
||||
result = await session.execute(query)
|
||||
members = list(result.unique().scalars().all())
|
||||
|
||||
# Check if there are more results
|
||||
has_more = len(members) > limit
|
||||
if has_more:
|
||||
# Remove the extra item
|
||||
members = members[:limit]
|
||||
|
||||
return members, has_more
|
||||
|
||||
@staticmethod
|
||||
async def update_all_members_llm_settings_async(
|
||||
session: AsyncSession,
|
||||
org_id: UUID,
|
||||
member_settings: OrgMemberLLMSettings,
|
||||
) -> None:
|
||||
"""Update LLM settings for all members of an organization.
|
||||
|
||||
Args:
|
||||
session: Database session (passed from caller for transaction)
|
||||
org_id: Organization ID
|
||||
member_settings: Typed LLM settings to apply to all members
|
||||
"""
|
||||
# Build update values from non-None fields
|
||||
values = member_settings.model_dump(exclude_none=True)
|
||||
|
||||
# Handle encrypted llm_api_key field - map to _llm_api_key column with encryption
|
||||
if 'llm_api_key' in values:
|
||||
raw_key = values.pop('llm_api_key')
|
||||
values['_llm_api_key'] = encrypt_value(raw_key)
|
||||
|
||||
if values:
|
||||
stmt = update(OrgMember).where(OrgMember.org_id == org_id).values(**values)
|
||||
await session.execute(stmt)
|
||||
|
||||
@@ -521,6 +521,7 @@ class OrgService:
|
||||
Raises:
|
||||
ValueError: If organization not found
|
||||
PermissionError: If user is not a member, or lacks admin/owner role for LLM settings
|
||||
OrgNameExistsError: If new name already exists for another organization
|
||||
OrgDatabaseError: If database update fails
|
||||
"""
|
||||
logger.info(
|
||||
@@ -550,6 +551,24 @@ class OrgService:
|
||||
'User must be a member of the organization to update it'
|
||||
)
|
||||
|
||||
# Check if name is being updated and validate uniqueness
|
||||
if update_data.name is not None:
|
||||
# Check if new name conflicts with another org
|
||||
existing_org_with_name = OrgStore.get_org_by_name(update_data.name)
|
||||
if (
|
||||
existing_org_with_name is not None
|
||||
and existing_org_with_name.id != org_id
|
||||
):
|
||||
logger.warning(
|
||||
'Attempted to update organization with duplicate name',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
'attempted_name': update_data.name,
|
||||
},
|
||||
)
|
||||
raise OrgNameExistsError(update_data.name)
|
||||
|
||||
# Check if update contains any LLM settings
|
||||
llm_fields_being_updated = OrgService._has_llm_settings_updates(update_data)
|
||||
if llm_fields_being_updated:
|
||||
@@ -637,10 +656,9 @@ class OrgService:
|
||||
)
|
||||
return None
|
||||
|
||||
max_budget = (user_team_info.get('litellm_budget_table') or {}).get(
|
||||
'max_budget', 0
|
||||
max_budget, spend = LiteLlmManager.get_budget_from_team_info(
|
||||
user_team_info, user_id, str(org_id)
|
||||
)
|
||||
spend = user_team_info.get('spend', 0)
|
||||
credits = max(max_budget - spend, 0)
|
||||
|
||||
logger.debug(
|
||||
@@ -842,3 +860,94 @@ class OrgService:
|
||||
extra={'user_id': user_id, 'org_id': str(org_id), 'error': str(e)},
|
||||
)
|
||||
raise OrgDatabaseError(f'Failed to delete organization: {str(e)}')
|
||||
|
||||
@staticmethod
|
||||
async def check_byor_export_enabled(user_id: str) -> bool:
|
||||
"""Check if BYOR export is enabled for the user's current org.
|
||||
|
||||
Returns True if the user's current org has byor_export_enabled set to True.
|
||||
Returns False if the user is not found, has no current org, or the flag is False.
|
||||
|
||||
Args:
|
||||
user_id: User ID to check
|
||||
|
||||
Returns:
|
||||
bool: True if BYOR export is enabled, False otherwise
|
||||
"""
|
||||
user = await UserStore.get_user_by_id_async(user_id)
|
||||
if not user or not user.current_org_id:
|
||||
return False
|
||||
|
||||
org = OrgStore.get_org_by_id(user.current_org_id)
|
||||
if not org:
|
||||
return False
|
||||
|
||||
return org.byor_export_enabled
|
||||
|
||||
@staticmethod
|
||||
async def switch_org(user_id: str, org_id: UUID) -> Org:
|
||||
"""
|
||||
Switch user's current organization to the specified organization.
|
||||
|
||||
This method:
|
||||
1. Validates that the organization exists
|
||||
2. Validates that the user is a member of the organization
|
||||
3. Updates the user's current_org_id
|
||||
|
||||
Args:
|
||||
user_id: User ID (string that will be converted to UUID)
|
||||
org_id: Organization ID to switch to
|
||||
|
||||
Returns:
|
||||
Org: The organization that was switched to
|
||||
|
||||
Raises:
|
||||
OrgNotFoundError: If organization doesn't exist
|
||||
OrgAuthorizationError: If user is not a member of the organization
|
||||
OrgDatabaseError: If database update fails
|
||||
"""
|
||||
logger.info(
|
||||
'Switching user organization',
|
||||
extra={'user_id': user_id, 'org_id': str(org_id)},
|
||||
)
|
||||
|
||||
# Step 1: Check if organization exists
|
||||
org = OrgStore.get_org_by_id(org_id)
|
||||
if not org:
|
||||
raise OrgNotFoundError(str(org_id))
|
||||
|
||||
# Step 2: Validate user is a member of the organization
|
||||
if not OrgService.is_org_member(user_id, org_id):
|
||||
logger.warning(
|
||||
'User attempted to switch to organization they are not a member of',
|
||||
extra={'user_id': user_id, 'org_id': str(org_id)},
|
||||
)
|
||||
raise OrgAuthorizationError(
|
||||
'User must be a member of the organization to switch to it'
|
||||
)
|
||||
|
||||
# Step 3: Update user's current_org_id
|
||||
try:
|
||||
updated_user = UserStore.update_current_org(user_id, org_id)
|
||||
if not updated_user:
|
||||
raise OrgDatabaseError('User not found')
|
||||
|
||||
logger.info(
|
||||
'Successfully switched user organization',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
'org_name': org.name,
|
||||
},
|
||||
)
|
||||
|
||||
return org
|
||||
|
||||
except OrgDatabaseError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'Failed to switch user organization',
|
||||
extra={'user_id': user_id, 'org_id': str(org_id), 'error': str(e)},
|
||||
)
|
||||
raise OrgDatabaseError(f'Failed to switch organization: {str(e)}')
|
||||
|
||||
@@ -10,9 +10,10 @@ from server.constants import (
|
||||
ORG_SETTINGS_VERSION,
|
||||
get_default_litellm_model,
|
||||
)
|
||||
from sqlalchemy import text
|
||||
from server.routes.org_models import OrgLLMSettingsUpdate, OrphanedUserError
|
||||
from sqlalchemy import select, text
|
||||
from sqlalchemy.orm import joinedload
|
||||
from storage.database import session_maker
|
||||
from storage.database import a_session_maker, session_maker
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
from storage.org import Org
|
||||
from storage.org_member import OrgMember
|
||||
@@ -320,17 +321,41 @@ class OrgStore:
|
||||
{'org_id': str(org_id)},
|
||||
)
|
||||
|
||||
# 3. Delete organization memberships
|
||||
# 3. Handle users with this as current_org_id BEFORE deleting memberships
|
||||
# Single query to find orphaned users (those with no alternative org)
|
||||
orphaned_users = session.execute(
|
||||
text("""
|
||||
SELECT u.id
|
||||
FROM "user" u
|
||||
WHERE u.current_org_id = :org_id
|
||||
AND NOT EXISTS (
|
||||
SELECT 1 FROM org_member om
|
||||
WHERE om.user_id = u.id AND om.org_id != :org_id
|
||||
)
|
||||
"""),
|
||||
{'org_id': str(org_id)},
|
||||
).fetchall()
|
||||
|
||||
if orphaned_users:
|
||||
raise OrphanedUserError([str(row[0]) for row in orphaned_users])
|
||||
|
||||
# Batch update: reassign current_org_id to an alternative org for all affected users
|
||||
session.execute(
|
||||
text('DELETE FROM org_member WHERE org_id = :org_id'),
|
||||
text("""
|
||||
UPDATE "user" u
|
||||
SET current_org_id = (
|
||||
SELECT om.org_id FROM org_member om
|
||||
WHERE om.user_id = u.id AND om.org_id != :org_id
|
||||
LIMIT 1
|
||||
)
|
||||
WHERE u.current_org_id = :org_id
|
||||
"""),
|
||||
{'org_id': str(org_id)},
|
||||
)
|
||||
|
||||
# 4. Handle users with this as current_org_id
|
||||
# 4. Delete organization memberships (now safe)
|
||||
session.execute(
|
||||
text(
|
||||
'UPDATE "user" SET current_org_id = NULL WHERE current_org_id = :org_id'
|
||||
),
|
||||
text('DELETE FROM org_member WHERE org_id = :org_id'),
|
||||
{'org_id': str(org_id)},
|
||||
)
|
||||
|
||||
@@ -361,3 +386,47 @@ class OrgStore:
|
||||
extra={'org_id': str(org_id), 'error': str(e)},
|
||||
)
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
async def get_org_by_id_async(org_id: UUID) -> Org | None:
|
||||
"""Get organization by ID (async version)."""
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(select(Org).filter(Org.id == org_id))
|
||||
org = result.scalars().first()
|
||||
return OrgStore._validate_org_version(org) if org else None
|
||||
|
||||
@staticmethod
|
||||
async def update_org_llm_settings_async(
|
||||
org_id: UUID,
|
||||
llm_settings: OrgLLMSettingsUpdate,
|
||||
) -> Org | None:
|
||||
"""Update organization LLM settings and propagate to members (async version).
|
||||
|
||||
Args:
|
||||
org_id: Organization ID
|
||||
llm_settings: Typed LLM settings update model
|
||||
|
||||
Returns:
|
||||
Updated Org or None if not found
|
||||
"""
|
||||
from storage.org_member_store import OrgMemberStore
|
||||
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(select(Org).filter(Org.id == org_id))
|
||||
org = result.scalars().first()
|
||||
if not org:
|
||||
return None
|
||||
|
||||
# Apply updates to org
|
||||
llm_settings.apply_to_org(org)
|
||||
|
||||
# Propagate relevant settings to all org members
|
||||
member_updates = llm_settings.get_member_updates()
|
||||
if member_updates:
|
||||
await OrgMemberStore.update_all_members_llm_settings_async(
|
||||
session, org_id, member_updates
|
||||
)
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(org)
|
||||
return org
|
||||
|
||||
35
enterprise/storage/resend_synced_user.py
Normal file
35
enterprise/storage/resend_synced_user.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""SQLAlchemy model for tracking users synced to Resend audiences."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import Column, DateTime, String, UniqueConstraint
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
class ResendSyncedUser(Base): # type: ignore
|
||||
"""Tracks users that have been synced to a Resend audience.
|
||||
|
||||
This table ensures that once a user is synced to a Resend audience,
|
||||
they won't be re-added even if they are later deleted from the
|
||||
Resend UI. This respects manual deletions/unsubscribes.
|
||||
"""
|
||||
|
||||
__tablename__ = 'resend_synced_users'
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid4)
|
||||
email = Column(String, nullable=False, index=True)
|
||||
audience_id = Column(String, nullable=False, index=True)
|
||||
synced_at = Column(
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(UTC),
|
||||
nullable=False,
|
||||
)
|
||||
keycloak_user_id = Column(String, nullable=True)
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
'email', 'audience_id', name='uq_resend_synced_email_audience'
|
||||
),
|
||||
)
|
||||
125
enterprise/storage/resend_synced_user_store.py
Normal file
125
enterprise/storage/resend_synced_user_store.py
Normal file
@@ -0,0 +1,125 @@
|
||||
"""Store class for managing Resend synced users."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from typing import Optional, Set
|
||||
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from storage.resend_synced_user import ResendSyncedUser
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResendSyncedUserStore:
|
||||
"""Store for tracking users synced to Resend audiences."""
|
||||
|
||||
session_maker: sessionmaker
|
||||
|
||||
def is_user_synced(self, email: str, audience_id: str) -> bool:
|
||||
"""Check if a user has been synced to a specific audience.
|
||||
|
||||
Args:
|
||||
email: The email address to check.
|
||||
audience_id: The Resend audience ID.
|
||||
|
||||
Returns:
|
||||
True if the user has been synced, False otherwise.
|
||||
"""
|
||||
with self.session_maker() as session:
|
||||
stmt = select(ResendSyncedUser).where(
|
||||
ResendSyncedUser.email == email.lower(),
|
||||
ResendSyncedUser.audience_id == audience_id,
|
||||
)
|
||||
result = session.execute(stmt).first()
|
||||
return result is not None
|
||||
|
||||
def get_synced_emails_for_audience(self, audience_id: str) -> Set[str]:
|
||||
"""Get all synced email addresses for a specific audience.
|
||||
|
||||
Args:
|
||||
audience_id: The Resend audience ID.
|
||||
|
||||
Returns:
|
||||
A set of lowercase email addresses that have been synced.
|
||||
"""
|
||||
with self.session_maker() as session:
|
||||
stmt = select(ResendSyncedUser.email).where(
|
||||
ResendSyncedUser.audience_id == audience_id,
|
||||
)
|
||||
result = session.execute(stmt).scalars().all()
|
||||
return set(result)
|
||||
|
||||
def mark_user_synced(
|
||||
self,
|
||||
email: str,
|
||||
audience_id: str,
|
||||
keycloak_user_id: Optional[str] = None,
|
||||
) -> ResendSyncedUser:
|
||||
"""Mark a user as synced to a specific audience.
|
||||
|
||||
Uses upsert to handle race conditions - if the user is already
|
||||
marked as synced, this is a no-op.
|
||||
|
||||
Args:
|
||||
email: The email address of the user.
|
||||
audience_id: The Resend audience ID.
|
||||
keycloak_user_id: Optional Keycloak user ID.
|
||||
|
||||
Returns:
|
||||
The ResendSyncedUser record.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the record could not be created or retrieved.
|
||||
"""
|
||||
with self.session_maker() as session:
|
||||
stmt = (
|
||||
insert(ResendSyncedUser)
|
||||
.values(
|
||||
email=email.lower(),
|
||||
audience_id=audience_id,
|
||||
keycloak_user_id=keycloak_user_id,
|
||||
synced_at=datetime.now(UTC),
|
||||
)
|
||||
.on_conflict_do_nothing(constraint='uq_resend_synced_email_audience')
|
||||
.returning(ResendSyncedUser)
|
||||
)
|
||||
result = session.execute(stmt)
|
||||
session.commit()
|
||||
|
||||
row = result.first()
|
||||
if row:
|
||||
return row[0]
|
||||
|
||||
# on_conflict_do_nothing triggered, fetch the existing record
|
||||
existing = session.execute(
|
||||
select(ResendSyncedUser).where(
|
||||
ResendSyncedUser.email == email.lower(),
|
||||
ResendSyncedUser.audience_id == audience_id,
|
||||
)
|
||||
).first()
|
||||
if existing:
|
||||
return existing[0]
|
||||
|
||||
raise RuntimeError(
|
||||
f'Failed to create or retrieve synced user record for {email}'
|
||||
)
|
||||
|
||||
def remove_synced_user(self, email: str, audience_id: str) -> bool:
|
||||
"""Remove a user's synced status for a specific audience.
|
||||
|
||||
Args:
|
||||
email: The email address of the user.
|
||||
audience_id: The Resend audience ID.
|
||||
|
||||
Returns:
|
||||
True if a record was deleted, False if no record existed.
|
||||
"""
|
||||
with self.session_maker() as session:
|
||||
stmt = delete(ResendSyncedUser).where(
|
||||
ResendSyncedUser.email == email.lower(),
|
||||
ResendSyncedUser.audience_id == audience_id,
|
||||
)
|
||||
result = session.execute(stmt)
|
||||
session.commit()
|
||||
return result.rowcount > 0
|
||||
@@ -29,6 +29,20 @@ class RoleStore:
|
||||
with session_maker() as session:
|
||||
return session.query(Role).filter(Role.id == role_id).first()
|
||||
|
||||
@staticmethod
|
||||
async def get_role_by_id_async(
|
||||
role_id: int,
|
||||
session: Optional[AsyncSession] = None,
|
||||
) -> Optional[Role]:
|
||||
"""Get role by ID (async version)."""
|
||||
if session is not None:
|
||||
result = await session.execute(select(Role).where(Role.id == role_id))
|
||||
return result.scalars().first()
|
||||
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(select(Role).where(Role.id == role_id))
|
||||
return result.scalars().first()
|
||||
|
||||
@staticmethod
|
||||
def get_role_by_name(name: str) -> Optional[Role]:
|
||||
"""Get role by name."""
|
||||
|
||||
@@ -8,10 +8,11 @@ from dataclasses import dataclass
|
||||
|
||||
from cryptography.fernet import Fernet
|
||||
from pydantic import SecretStr
|
||||
from server.constants import LITE_LLM_API_URL
|
||||
from server.logger import logger
|
||||
from sqlalchemy.orm import joinedload, sessionmaker
|
||||
from storage.database import session_maker
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
from storage.lite_llm_manager import LiteLlmManager, get_openhands_cloud_key_alias
|
||||
from storage.org import Org
|
||||
from storage.org_member import OrgMember
|
||||
from storage.org_store import OrgStore
|
||||
@@ -23,6 +24,7 @@ from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
from openhands.server.settings import Settings
|
||||
from openhands.storage.settings.settings_store import SettingsStore
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
from openhands.utils.llm import is_openhands_model
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -143,23 +145,28 @@ class SaasSettingsStore(SettingsStore):
|
||||
return None
|
||||
|
||||
org_id = user.current_org_id
|
||||
# Check if provider is OpenHands and generate API key if needed
|
||||
if self._is_openhands_provider(item):
|
||||
await self._ensure_openhands_api_key(item, str(org_id))
|
||||
org_member = None
|
||||
|
||||
org_member: OrgMember = None
|
||||
for om in user.org_members:
|
||||
if om.org_id == org_id:
|
||||
org_member = om
|
||||
break
|
||||
if not org_member or not org_member.llm_api_key:
|
||||
return None
|
||||
org = session.query(Org).filter(Org.id == org_id).first()
|
||||
|
||||
org: Org = session.query(Org).filter(Org.id == org_id).first()
|
||||
if not org:
|
||||
logger.error(
|
||||
f'Org not found for ID {org_id} as the current org for user {self.user_id}'
|
||||
)
|
||||
return None
|
||||
|
||||
# Check if we need to generate an LLM key.
|
||||
if item.llm_base_url == LITE_LLM_API_URL:
|
||||
await self._ensure_api_key(
|
||||
item, str(org_id), openhands_type=is_openhands_model(item.llm_model)
|
||||
)
|
||||
|
||||
kwargs = item.model_dump(context={'expose_secrets': True})
|
||||
for model in (user, org, org_member):
|
||||
for key, value in kwargs.items():
|
||||
@@ -223,41 +230,49 @@ class SaasSettingsStore(SettingsStore):
|
||||
fernet_key = b64encode(hashlib.sha256(jwt_secret.encode()).digest())
|
||||
return Fernet(fernet_key)
|
||||
|
||||
def _is_openhands_provider(self, item: Settings) -> bool:
|
||||
"""Check if the settings use the OpenHands provider."""
|
||||
return bool(item.llm_model and item.llm_model.startswith('openhands/'))
|
||||
|
||||
async def _ensure_openhands_api_key(self, item: Settings, org_id: str) -> None:
|
||||
async def _ensure_api_key(
|
||||
self, item: Settings, org_id: str, openhands_type: bool = False
|
||||
) -> None:
|
||||
"""Generate and set the OpenHands API key for the given settings.
|
||||
|
||||
First checks if an existing key exists for the user and reuses it
|
||||
if found. Otherwise, generates a new key.
|
||||
First checks if an existing key exists for the user and verifies it
|
||||
is valid in LiteLLM. If valid, reuses it. Otherwise, generates a new key.
|
||||
"""
|
||||
# Check if user already has keys in LiteLLM
|
||||
existing_keys = await LiteLlmManager.get_user_keys(self.user_id)
|
||||
if existing_keys:
|
||||
logger.info(
|
||||
'saas_settings_store:store:user_already_has_keys',
|
||||
extra={'user_id': self.user_id, 'key_count': len(existing_keys)},
|
||||
)
|
||||
return
|
||||
|
||||
# Generate new key only if none exists
|
||||
generated_key = await LiteLlmManager.generate_key(
|
||||
# First, check if our current key is valid
|
||||
if item.llm_api_key and not await LiteLlmManager.verify_existing_key(
|
||||
item.llm_api_key.get_secret_value(),
|
||||
self.user_id,
|
||||
org_id,
|
||||
None,
|
||||
{'type': 'openhands'},
|
||||
)
|
||||
openhands_type=openhands_type,
|
||||
):
|
||||
generated_key = None
|
||||
if openhands_type:
|
||||
generated_key = await LiteLlmManager.generate_key(
|
||||
self.user_id,
|
||||
org_id,
|
||||
None,
|
||||
{'type': 'openhands'},
|
||||
)
|
||||
else:
|
||||
# Must delete any existing key with the same alias first
|
||||
key_alias = get_openhands_cloud_key_alias(self.user_id, org_id)
|
||||
await LiteLlmManager.delete_key_by_alias(key_alias=key_alias)
|
||||
generated_key = await LiteLlmManager.generate_key(
|
||||
self.user_id,
|
||||
org_id,
|
||||
key_alias,
|
||||
None,
|
||||
)
|
||||
|
||||
if generated_key:
|
||||
item.llm_api_key = SecretStr(generated_key)
|
||||
logger.info(
|
||||
'saas_settings_store:store:generated_openhands_key',
|
||||
extra={'user_id': self.user_id},
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
'saas_settings_store:store:failed_to_generate_openhands_key',
|
||||
extra={'user_id': self.user_id},
|
||||
)
|
||||
if generated_key:
|
||||
item.llm_api_key = SecretStr(generated_key)
|
||||
logger.info(
|
||||
'saas_settings_store:store:generated_openhands_key',
|
||||
extra={'user_id': self.user_id},
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
'saas_settings_store:store:failed_to_generate_openhands_key',
|
||||
extra={'user_id': self.user_id},
|
||||
)
|
||||
|
||||
64
enterprise/storage/user_app_settings_store.py
Normal file
64
enterprise/storage/user_app_settings_store.py
Normal file
@@ -0,0 +1,64 @@
|
||||
"""Store class for managing user app settings."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
|
||||
from server.routes.user_app_settings_models import UserAppSettingsUpdate
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from storage.user import User
|
||||
|
||||
|
||||
@dataclass
|
||||
class UserAppSettingsStore:
|
||||
"""Store for user app settings with injected db_session."""
|
||||
|
||||
db_session: AsyncSession
|
||||
|
||||
async def get_user_by_id(self, user_id: str) -> User | None:
|
||||
"""Get user by ID.
|
||||
|
||||
Args:
|
||||
user_id: The user's ID (Keycloak user ID)
|
||||
|
||||
Returns:
|
||||
User: The user object, or None if not found
|
||||
"""
|
||||
result = await self.db_session.execute(
|
||||
select(User).filter(User.id == uuid.UUID(user_id))
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
async def update_user_app_settings(
|
||||
self, user_id: str, update_data: UserAppSettingsUpdate
|
||||
) -> User | None:
|
||||
"""Update user app settings.
|
||||
|
||||
Only updates fields that are explicitly provided in update_data.
|
||||
Uses flush() - commit happens at request end via DbSessionInjector.
|
||||
|
||||
Args:
|
||||
user_id: The user's ID (Keycloak user ID)
|
||||
update_data: Pydantic model with fields to update
|
||||
|
||||
Returns:
|
||||
User: The updated user object, or None if user not found
|
||||
"""
|
||||
result = await self.db_session.execute(
|
||||
select(User).filter(User.id == uuid.UUID(user_id)).with_for_update()
|
||||
)
|
||||
user = result.scalars().first()
|
||||
|
||||
if not user:
|
||||
return None
|
||||
|
||||
# Update only explicitly provided fields
|
||||
for field, value in update_data.model_dump(exclude_unset=True).items():
|
||||
setattr(user, field, value)
|
||||
|
||||
# flush instead of commit - DbSessionInjector auto-commits at request end
|
||||
await self.db_session.flush()
|
||||
await self.db_session.refresh(user)
|
||||
return user
|
||||
@@ -5,6 +5,7 @@ Store class for managing users.
|
||||
import asyncio
|
||||
import uuid
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.constants import (
|
||||
@@ -27,6 +28,7 @@ from storage.org_member import OrgMember
|
||||
from storage.role_store import RoleStore
|
||||
from storage.user import User
|
||||
from storage.user_settings import UserSettings
|
||||
from utils.identity import resolve_display_name
|
||||
|
||||
from openhands.utils.async_utils import GENERAL_TIMEOUT, call_async_from_sync
|
||||
|
||||
@@ -53,7 +55,8 @@ class UserStore:
|
||||
org = Org(
|
||||
id=uuid.UUID(user_id),
|
||||
name=f'user_{user_id}_org',
|
||||
contact_name=user_info['preferred_username'],
|
||||
contact_name=resolve_display_name(user_info)
|
||||
or user_info.get('preferred_username', ''),
|
||||
contact_email=user_info['email'],
|
||||
v1_enabled=True,
|
||||
)
|
||||
@@ -80,6 +83,8 @@ class UserStore:
|
||||
role_id=role_id,
|
||||
**user_kwargs,
|
||||
)
|
||||
user.email = user_info.get('email')
|
||||
user.email_verified = user_info.get('email_verified')
|
||||
session.add(user)
|
||||
|
||||
role = RoleStore.get_role_by_name('owner')
|
||||
@@ -170,13 +175,28 @@ class UserStore:
|
||||
)
|
||||
decrypted_user_settings = UserSettings(**kwargs)
|
||||
with session_maker() as session:
|
||||
# Check if user has completed billing sessions to enable BYOR export
|
||||
from storage.billing_session import BillingSession
|
||||
|
||||
has_completed_billing = (
|
||||
session.query(BillingSession)
|
||||
.filter(
|
||||
BillingSession.user_id == user_id,
|
||||
BillingSession.status == 'completed',
|
||||
)
|
||||
.first()
|
||||
is not None
|
||||
)
|
||||
|
||||
# create personal org
|
||||
org = Org(
|
||||
id=uuid.UUID(user_id),
|
||||
name=f'user_{user_id}_org',
|
||||
org_version=user_settings.user_version,
|
||||
contact_name=user_info['username'],
|
||||
contact_name=resolve_display_name(user_info)
|
||||
or user_info.get('username', ''),
|
||||
contact_email=user_info['email'],
|
||||
byor_export_enabled=has_completed_billing,
|
||||
)
|
||||
session.add(org)
|
||||
|
||||
@@ -750,12 +770,187 @@ class UserStore:
|
||||
finally:
|
||||
await UserStore._release_user_creation_lock(user_id)
|
||||
|
||||
@staticmethod
|
||||
async def get_user_by_email_async(email: str) -> Optional[User]:
|
||||
"""Get user by email address (async version).
|
||||
|
||||
This method looks up a user by their email address. Note that email
|
||||
addresses may not be unique across all users in rare cases.
|
||||
|
||||
Args:
|
||||
email: The email address to search for
|
||||
|
||||
Returns:
|
||||
User: The user with the matching email, or None if not found
|
||||
"""
|
||||
if not email:
|
||||
return None
|
||||
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(User)
|
||||
.options(joinedload(User.org_members))
|
||||
.filter(User.email == email.lower().strip())
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
@staticmethod
|
||||
def list_users() -> list[User]:
|
||||
"""List all users."""
|
||||
with session_maker() as session:
|
||||
return session.query(User).all()
|
||||
|
||||
@staticmethod
|
||||
def update_current_org(user_id: str, org_id: UUID) -> Optional[User]:
|
||||
"""Update the user's current organization.
|
||||
|
||||
Args:
|
||||
user_id: The user's ID (Keycloak user ID)
|
||||
org_id: The organization ID to set as current
|
||||
|
||||
Returns:
|
||||
User: The updated user object, or None if user not found
|
||||
"""
|
||||
with session_maker() as session:
|
||||
user = (
|
||||
session.query(User)
|
||||
.filter(User.id == uuid.UUID(user_id))
|
||||
.with_for_update()
|
||||
.first()
|
||||
)
|
||||
if not user:
|
||||
return None
|
||||
|
||||
user.current_org_id = org_id
|
||||
session.commit()
|
||||
session.refresh(user)
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
async def backfill_contact_name(user_id: str, user_info: dict) -> None:
|
||||
"""Update contact_name on the personal org if it still has a username-style value.
|
||||
|
||||
Called during login to gradually fix existing users whose contact_name
|
||||
was stored as their username (before the resolve_display_name fix).
|
||||
Preserves custom values that were set via the PATCH endpoint.
|
||||
"""
|
||||
real_name = resolve_display_name(user_info)
|
||||
if not real_name:
|
||||
logger.debug(
|
||||
'backfill_contact_name:no_real_name',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
return
|
||||
|
||||
preferred_username = user_info.get('preferred_username', '')
|
||||
username = user_info.get('username', '')
|
||||
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(Org).filter(Org.id == uuid.UUID(user_id))
|
||||
)
|
||||
org = result.scalars().first()
|
||||
if not org:
|
||||
logger.debug(
|
||||
'backfill_contact_name:org_not_found',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
return
|
||||
|
||||
if org.contact_name in (preferred_username, username):
|
||||
logger.info(
|
||||
'backfill_contact_name:updated',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'old': org.contact_name,
|
||||
'new': real_name,
|
||||
},
|
||||
)
|
||||
org.contact_name = real_name
|
||||
await session.commit()
|
||||
|
||||
@staticmethod
|
||||
async def update_user_email(
|
||||
user_id: str,
|
||||
email: str | None = None,
|
||||
email_verified: bool | None = None,
|
||||
) -> None:
|
||||
"""Unconditionally update User.email and/or email_verified.
|
||||
|
||||
Unlike backfill_user_email(), this overwrites existing values.
|
||||
No-op when both arguments are None.
|
||||
Missing user is logged as a warning and ignored.
|
||||
"""
|
||||
if email is None and email_verified is None:
|
||||
return
|
||||
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(User).filter(User.id == uuid.UUID(user_id))
|
||||
)
|
||||
user = result.scalars().first()
|
||||
if not user:
|
||||
logger.warning(
|
||||
'update_user_email:user_not_found',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
return
|
||||
|
||||
if email is not None:
|
||||
user.email = email
|
||||
if email_verified is not None:
|
||||
user.email_verified = email_verified
|
||||
|
||||
logger.info(
|
||||
'update_user_email:updated',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'email_set': email is not None,
|
||||
'email_verified_set': email_verified is not None,
|
||||
},
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
@staticmethod
|
||||
async def backfill_user_email(user_id: str, user_info: dict) -> None:
|
||||
"""Set User.email and email_verified from IDP if they are still NULL.
|
||||
|
||||
Called during login to gradually fix existing users whose email
|
||||
was never persisted on the User record. Preserves non-NULL values
|
||||
(e.g. if a user manually changed their email).
|
||||
"""
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(User).filter(User.id == uuid.UUID(user_id))
|
||||
)
|
||||
user = result.scalars().first()
|
||||
if not user:
|
||||
logger.debug(
|
||||
'backfill_user_email:user_not_found',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
return
|
||||
|
||||
updated = False
|
||||
if user.email is None:
|
||||
user.email = user_info.get('email')
|
||||
updated = True
|
||||
|
||||
if user.email_verified is None:
|
||||
user.email_verified = user_info.get('email_verified', False)
|
||||
updated = True
|
||||
|
||||
if updated:
|
||||
logger.info(
|
||||
'backfill_user_email:updated',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'email_set': user.email is not None,
|
||||
'email_verified_set': user.email_verified is not None,
|
||||
},
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
# Prevent circular imports
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
|
||||
@@ -1,11 +1,19 @@
|
||||
import asyncio
|
||||
import asyncio # noqa: I001
|
||||
|
||||
from storage.proactive_conversation_store import ProactiveConversationStore
|
||||
# This must be before the import of storage
|
||||
# to set up logging and prevent alembic from
|
||||
# running its mouth.
|
||||
from openhands.core.logger import openhands_logger
|
||||
|
||||
from storage.proactive_conversation_store import (
|
||||
ProactiveConversationStore,
|
||||
)
|
||||
|
||||
OLDER_THAN = 30 # 30 minutes
|
||||
|
||||
|
||||
async def main():
|
||||
openhands_logger.info('clean_proactive_convo_table')
|
||||
convo_store = ProactiveConversationStore()
|
||||
await convo_store.clean_old_convos(older_than_minutes=OLDER_THAN)
|
||||
|
||||
|
||||
@@ -26,6 +26,7 @@ Optional environment variables:
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional
|
||||
@@ -34,6 +35,7 @@ import resend
|
||||
from keycloak.exceptions import KeycloakError
|
||||
from resend.exceptions import ResendError
|
||||
from server.auth.token_manager import get_keycloak_admin
|
||||
from storage.resend_synced_user_store import ResendSyncedUserStore
|
||||
from tenacity import (
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
@@ -68,9 +70,6 @@ RATE_LIMIT = float(os.environ.get('RATE_LIMIT', '2')) # Requests per second
|
||||
# Set up Resend API
|
||||
resend.api_key = RESEND_API_KEY
|
||||
|
||||
print('resend module', resend)
|
||||
print('has contacts', hasattr(resend, 'Contacts'))
|
||||
|
||||
|
||||
class ResendSyncError(Exception):
|
||||
"""Base exception for Resend sync errors."""
|
||||
@@ -90,6 +89,31 @@ class ResendAPIError(ResendSyncError):
|
||||
pass
|
||||
|
||||
|
||||
# Email validation regex pattern - matches standard email format
|
||||
# This pattern is intentionally strict to avoid Resend API validation errors
|
||||
# It rejects special characters like ! that some email providers technically allow
|
||||
# but Resend's API does not accept
|
||||
EMAIL_REGEX = re.compile(r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$')
|
||||
|
||||
|
||||
def is_valid_email(email: Optional[str]) -> bool:
|
||||
"""Validate an email address format.
|
||||
|
||||
This uses a regex pattern that matches most valid email addresses
|
||||
while rejecting addresses with special characters that Resend's API
|
||||
does not accept (e.g., exclamation marks).
|
||||
|
||||
Args:
|
||||
email: The email address to validate, or None.
|
||||
|
||||
Returns:
|
||||
True if the email is valid, False otherwise (including for None).
|
||||
"""
|
||||
if not email:
|
||||
return False
|
||||
return bool(EMAIL_REGEX.match(email))
|
||||
|
||||
|
||||
def get_keycloak_users(offset: int = 0, limit: int = 100) -> List[Dict[str, Any]]:
|
||||
"""Get users from Keycloak using the admin client.
|
||||
|
||||
@@ -173,8 +197,6 @@ def get_resend_contacts(audience_id: str) -> Dict[str, Dict[str, Any]]:
|
||||
Raises:
|
||||
ResendAPIError: If the API call fails.
|
||||
"""
|
||||
print('getting resend contacts')
|
||||
print('has resend contacts', hasattr(resend, 'Contacts'))
|
||||
try:
|
||||
contacts = resend.Contacts.list(audience_id).get('data', [])
|
||||
# Create a dictionary mapping email addresses to contact data for
|
||||
@@ -229,6 +251,15 @@ def add_contact_to_resend(
|
||||
raise
|
||||
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(MAX_RETRIES),
|
||||
wait=wait_exponential(
|
||||
multiplier=INITIAL_BACKOFF_SECONDS,
|
||||
max=MAX_BACKOFF_SECONDS,
|
||||
exp_base=BACKOFF_FACTOR,
|
||||
),
|
||||
retry=retry_if_exception_type(ResendError),
|
||||
)
|
||||
def send_welcome_email(
|
||||
email: str,
|
||||
first_name: Optional[str] = None,
|
||||
@@ -245,7 +276,7 @@ def send_welcome_email(
|
||||
The API response.
|
||||
|
||||
Raises:
|
||||
ResendError: If the API call fails.
|
||||
ResendError: If the API call fails after retries.
|
||||
"""
|
||||
try:
|
||||
# Prepare the recipient name
|
||||
@@ -291,8 +322,84 @@ def send_welcome_email(
|
||||
raise
|
||||
|
||||
|
||||
def _get_resend_synced_user_store() -> ResendSyncedUserStore:
|
||||
"""Get the ResendSyncedUserStore instance.
|
||||
|
||||
This is separated into a function to allow for easier testing/mocking.
|
||||
"""
|
||||
from openhands.app_server.config import get_global_config
|
||||
|
||||
config = get_global_config()
|
||||
db_session_injector = config.db_session
|
||||
return ResendSyncedUserStore(session_maker=db_session_injector.get_session_maker())
|
||||
|
||||
|
||||
def _backfill_existing_resend_contacts(
|
||||
synced_user_store: ResendSyncedUserStore,
|
||||
audience_id: str,
|
||||
) -> int:
|
||||
"""Backfill the synced_users table with contacts already in Resend.
|
||||
|
||||
This ensures that users who were added to Resend before the tracking
|
||||
table existed are properly recorded, preventing duplicate welcome emails.
|
||||
|
||||
Args:
|
||||
synced_user_store: The store for tracking synced users.
|
||||
audience_id: The Resend audience ID.
|
||||
|
||||
Returns:
|
||||
The number of contacts backfilled.
|
||||
"""
|
||||
logger.info('Starting backfill of existing Resend contacts...')
|
||||
|
||||
try:
|
||||
resend_contacts = get_resend_contacts(audience_id)
|
||||
logger.info(f'Found {len(resend_contacts)} contacts in Resend audience')
|
||||
|
||||
already_synced_emails = synced_user_store.get_synced_emails_for_audience(
|
||||
audience_id
|
||||
)
|
||||
logger.info(
|
||||
f'Found {len(already_synced_emails)} already synced emails in database'
|
||||
)
|
||||
|
||||
backfilled_count = 0
|
||||
for email in resend_contacts:
|
||||
if email.lower() not in already_synced_emails:
|
||||
synced_user_store.mark_user_synced(
|
||||
email=email,
|
||||
audience_id=audience_id,
|
||||
keycloak_user_id=None, # We don't have this info during backfill
|
||||
)
|
||||
backfilled_count += 1
|
||||
logger.debug(f'Backfilled existing Resend contact: {email}')
|
||||
|
||||
logger.info(
|
||||
f'Backfill completed: {backfilled_count} contacts added to tracking'
|
||||
)
|
||||
return backfilled_count
|
||||
|
||||
except Exception:
|
||||
logger.exception('Error during backfill of existing Resend contacts')
|
||||
# Don't fail the entire sync if backfill fails - just log and continue
|
||||
return 0
|
||||
|
||||
|
||||
def sync_users_to_resend():
|
||||
"""Sync users from Keycloak to Resend."""
|
||||
"""Sync users from Keycloak to Resend.
|
||||
|
||||
This function syncs users from Keycloak to a Resend audience. It tracks
|
||||
which users have been synced in the database to ensure that:
|
||||
1. Users are only added once (even across multiple sync runs)
|
||||
2. Users who are manually deleted from Resend are not re-added
|
||||
|
||||
The tracking is done via the resend_synced_users table, which records
|
||||
each email/audience_id combination that has been synced.
|
||||
|
||||
On first run (or when new contacts exist in Resend), it will backfill
|
||||
the tracking table with existing Resend contacts to avoid sending
|
||||
duplicate welcome emails.
|
||||
"""
|
||||
# Check required environment variables
|
||||
required_vars = {
|
||||
'RESEND_API_KEY': RESEND_API_KEY,
|
||||
@@ -318,27 +425,36 @@ def sync_users_to_resend():
|
||||
)
|
||||
|
||||
try:
|
||||
# Get the store for tracking synced users
|
||||
synced_user_store = _get_resend_synced_user_store()
|
||||
|
||||
# Backfill existing Resend contacts into our tracking table
|
||||
# This ensures users already in Resend don't get duplicate welcome emails
|
||||
backfilled_count = _backfill_existing_resend_contacts(
|
||||
synced_user_store, RESEND_AUDIENCE_ID
|
||||
)
|
||||
|
||||
# Get the total number of users
|
||||
total_users = get_total_keycloak_users()
|
||||
logger.info(
|
||||
f'Found {total_users} users in Keycloak realm {KEYCLOAK_REALM_NAME}'
|
||||
)
|
||||
|
||||
# Get contacts from Resend
|
||||
resend_contacts = get_resend_contacts(RESEND_AUDIENCE_ID)
|
||||
logger.info(
|
||||
f'Found {len(resend_contacts)} contacts in Resend audience '
|
||||
f'{RESEND_AUDIENCE_ID}'
|
||||
)
|
||||
|
||||
# Stats
|
||||
stats = {
|
||||
'total_users': total_users,
|
||||
'existing_contacts': len(resend_contacts),
|
||||
'backfilled_contacts': backfilled_count,
|
||||
'already_synced': 0,
|
||||
'added_contacts': 0,
|
||||
'skipped_invalid_emails': 0,
|
||||
'errors': 0,
|
||||
}
|
||||
|
||||
synced_emails = synced_user_store.get_synced_emails_for_audience(
|
||||
RESEND_AUDIENCE_ID
|
||||
)
|
||||
logger.info(f'Found {len(synced_emails)} already synced emails in database')
|
||||
|
||||
# Process users in batches
|
||||
offset = 0
|
||||
while offset < total_users:
|
||||
@@ -351,39 +467,65 @@ def sync_users_to_resend():
|
||||
continue
|
||||
|
||||
email = email.lower()
|
||||
if email in resend_contacts:
|
||||
logger.debug(f'User {email} already exists in Resend, skipping')
|
||||
|
||||
if email in synced_emails:
|
||||
logger.debug(
|
||||
f'User {email} was already synced to this audience, skipping'
|
||||
)
|
||||
stats['already_synced'] += 1
|
||||
continue
|
||||
|
||||
# Validate email format before attempting to add to Resend
|
||||
if not is_valid_email(email):
|
||||
logger.warning(f'Skipping user with invalid email format: {email}')
|
||||
stats['skipped_invalid_emails'] += 1
|
||||
continue
|
||||
|
||||
first_name = user.get('first_name')
|
||||
last_name = user.get('last_name')
|
||||
keycloak_user_id = user.get('id')
|
||||
|
||||
# Mark as synced first (optimistic) to ensure consistency.
|
||||
# If Resend API fails, we remove the record.
|
||||
try:
|
||||
synced_user_store.mark_user_synced(
|
||||
email=email,
|
||||
audience_id=RESEND_AUDIENCE_ID,
|
||||
keycloak_user_id=keycloak_user_id,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(f'Failed to mark user {email} as synced')
|
||||
stats['errors'] += 1
|
||||
continue
|
||||
|
||||
try:
|
||||
first_name = user.get('first_name')
|
||||
last_name = user.get('last_name')
|
||||
|
||||
# Add the contact to the Resend audience
|
||||
add_contact_to_resend(
|
||||
RESEND_AUDIENCE_ID, email, first_name, last_name
|
||||
)
|
||||
logger.info(f'Added user {email} to Resend')
|
||||
stats['added_contacts'] += 1
|
||||
|
||||
# Sleep to respect rate limit after first API call
|
||||
time.sleep(1 / RATE_LIMIT)
|
||||
|
||||
# Send a welcome email to the newly added contact
|
||||
try:
|
||||
send_welcome_email(email, first_name, last_name)
|
||||
logger.info(f'Sent welcome email to {email}')
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f'Failed to send welcome email to {email}, but contact was added to audience'
|
||||
)
|
||||
# Continue with the sync process even if sending the welcome email fails
|
||||
|
||||
# Sleep to respect rate limit after second API call
|
||||
time.sleep(1 / RATE_LIMIT)
|
||||
except Exception:
|
||||
logger.exception(f'Error adding user {email} to Resend')
|
||||
synced_user_store.remove_synced_user(email, RESEND_AUDIENCE_ID)
|
||||
stats['errors'] += 1
|
||||
continue
|
||||
|
||||
synced_emails.add(email)
|
||||
stats['added_contacts'] += 1
|
||||
|
||||
# Sleep to respect rate limit after first API call
|
||||
time.sleep(1 / RATE_LIMIT)
|
||||
|
||||
# Send a welcome email to the newly added contact
|
||||
try:
|
||||
send_welcome_email(email, first_name, last_name)
|
||||
logger.info(f'Sent welcome email to {email}')
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f'Failed to send welcome email to {email}, but contact was added to audience'
|
||||
)
|
||||
|
||||
# Sleep to respect rate limit after second API call
|
||||
time.sleep(1 / RATE_LIMIT)
|
||||
|
||||
offset += BATCH_SIZE
|
||||
|
||||
|
||||
@@ -4,6 +4,9 @@ from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from server.constants import ORG_SETTINGS_VERSION
|
||||
from server.verified_models.verified_model_service import (
|
||||
StoredVerifiedModel, # noqa: F401
|
||||
)
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from storage.base import Base
|
||||
@@ -15,6 +18,7 @@ from storage.device_code import DeviceCode # noqa: F401
|
||||
from storage.feedback import Feedback
|
||||
from storage.github_app_installation import GithubAppInstallation
|
||||
from storage.org import Org
|
||||
from storage.org_invitation import OrgInvitation # noqa: F401
|
||||
from storage.org_member import OrgMember
|
||||
from storage.role import Role
|
||||
from storage.stored_conversation_metadata import StoredConversationMetadata
|
||||
|
||||
@@ -126,3 +126,24 @@ def test_run_agent_variant_tests_v1_calls_handler_and_sets_system_prompt(monkeyp
|
||||
# Should be a different instance than the original (copied after handler runs)
|
||||
assert result is not agent
|
||||
assert result.system_prompt_filename == 'system_prompt_long_horizon.j2'
|
||||
|
||||
|
||||
@patch('experiments.experiment_manager.ENABLE_EXPERIMENT_MANAGER', True)
|
||||
@patch('experiments.experiment_manager.EXPERIMENT_SYSTEM_PROMPT_EXPERIMENT', True)
|
||||
def test_run_agent_variant_tests_v1_preserves_planning_agent_system_prompt():
|
||||
"""Planning agents should retain their specialized system prompt and not be overwritten by the experiment."""
|
||||
# Arrange
|
||||
planning_agent = make_agent().model_copy(
|
||||
update={'system_prompt_filename': 'system_prompt_planning.j2'}
|
||||
)
|
||||
conv_id = uuid4()
|
||||
|
||||
# Act
|
||||
result: Agent = SaaSExperimentManager.run_agent_variant_tests__v1(
|
||||
user_id='user-planning',
|
||||
conversation_id=conv_id,
|
||||
agent=planning_agent,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result.system_prompt_filename == 'system_prompt_planning.j2'
|
||||
|
||||
@@ -141,12 +141,14 @@ def test_custom_to_static_conversion():
|
||||
|
||||
def create_provider_tokens(
|
||||
tokens_dict: dict[ProviderType, str],
|
||||
) -> dict[ProviderType, ProviderToken]:
|
||||
"""Helper to create provider tokens dictionary."""
|
||||
return {
|
||||
provider_type: ProviderToken(token=SecretStr(token_value))
|
||||
for provider_type, token_value in tokens_dict.items()
|
||||
}
|
||||
) -> MappingProxyType:
|
||||
"""Helper to create provider tokens as MappingProxyType."""
|
||||
return MappingProxyType(
|
||||
{
|
||||
provider_type: ProviderToken(token=SecretStr(token_value))
|
||||
for provider_type, token_value in tokens_dict.items()
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -264,3 +266,63 @@ async def test_get_latest_token_can_be_used_with_static_secret(
|
||||
# Assert - this should NOT raise a ValidationError
|
||||
static_secret = StaticSecret(value=token, description='GITHUB authentication token')
|
||||
assert static_secret.get_value() == token_value
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests for get_authenticated_git_url - ensuring proper authenticated URLs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_authenticated_git_url_raises_when_no_tokens(
|
||||
resolver_context, mock_saas_user_auth
|
||||
):
|
||||
"""Test that get_authenticated_git_url raises error when no provider tokens available."""
|
||||
# Arrange
|
||||
mock_saas_user_auth.get_provider_tokens = AsyncMock(return_value=None)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match='No provider tokens available'):
|
||||
await resolver_context.get_authenticated_git_url('owner/repo')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_provider_handler_caches_instance(
|
||||
resolver_context, mock_saas_user_auth
|
||||
):
|
||||
"""Test that _get_provider_handler caches the handler instance."""
|
||||
# Arrange
|
||||
token_value = 'ghp_test_token'
|
||||
provider_tokens = create_provider_tokens({ProviderType.GITHUB: token_value})
|
||||
mock_saas_user_auth.get_provider_tokens = AsyncMock(return_value=provider_tokens)
|
||||
mock_saas_user_auth.get_user_id = AsyncMock(return_value='test-user-id')
|
||||
|
||||
# Act - call _get_provider_handler twice
|
||||
handler1 = await resolver_context._get_provider_handler()
|
||||
handler2 = await resolver_context._get_provider_handler()
|
||||
|
||||
# Assert - should be the same instance (cached)
|
||||
assert handler1 is handler2
|
||||
# get_provider_tokens should only be called once
|
||||
assert mock_saas_user_auth.get_provider_tokens.call_count == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_provider_handler_creates_handler_with_correct_params(
|
||||
resolver_context, mock_saas_user_auth
|
||||
):
|
||||
"""Test that _get_provider_handler creates ProviderHandler with correct parameters."""
|
||||
# Arrange
|
||||
token_value = 'ghp_test_token'
|
||||
provider_tokens = create_provider_tokens({ProviderType.GITHUB: token_value})
|
||||
mock_saas_user_auth.get_provider_tokens = AsyncMock(return_value=provider_tokens)
|
||||
mock_saas_user_auth.get_user_id = AsyncMock(return_value='test-user-id')
|
||||
|
||||
# Act
|
||||
handler = await resolver_context._get_provider_handler()
|
||||
|
||||
# Assert
|
||||
from openhands.integrations.provider import ProviderHandler
|
||||
|
||||
assert isinstance(handler, ProviderHandler)
|
||||
assert handler.provider_tokens == provider_tokens
|
||||
|
||||
@@ -6,6 +6,9 @@ import httpx
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
from server.routes.api_keys import (
|
||||
ByorPermittedResponse,
|
||||
LlmApiKeyResponse,
|
||||
check_byor_permitted,
|
||||
delete_byor_key_from_litellm,
|
||||
get_llm_api_key_for_byor,
|
||||
)
|
||||
@@ -182,16 +185,18 @@ class TestGetLlmApiKeyForByor:
|
||||
"""Test the get_llm_api_key_for_byor endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.org_service.OrgService.check_byor_export_enabled')
|
||||
@patch('server.routes.api_keys.store_byor_key_in_db')
|
||||
@patch('server.routes.api_keys.generate_byor_key')
|
||||
@patch('server.routes.api_keys.get_byor_key_from_db')
|
||||
async def test_no_key_in_database_generates_new(
|
||||
self, mock_get_key, mock_generate_key, mock_store_key
|
||||
self, mock_get_key, mock_generate_key, mock_store_key, mock_check_enabled
|
||||
):
|
||||
"""Test that when no key exists in database, a new one is generated."""
|
||||
# Arrange
|
||||
user_id = 'user-123'
|
||||
new_key = 'sk-new-generated-key'
|
||||
mock_check_enabled.return_value = True
|
||||
mock_get_key.return_value = None
|
||||
mock_generate_key.return_value = new_key
|
||||
mock_store_key.return_value = None
|
||||
@@ -200,21 +205,24 @@ class TestGetLlmApiKeyForByor:
|
||||
result = await get_llm_api_key_for_byor(user_id=user_id)
|
||||
|
||||
# Assert
|
||||
assert result == {'key': new_key}
|
||||
assert result == LlmApiKeyResponse(key=new_key)
|
||||
mock_check_enabled.assert_called_once_with(user_id)
|
||||
mock_get_key.assert_called_once_with(user_id)
|
||||
mock_generate_key.assert_called_once_with(user_id)
|
||||
mock_store_key.assert_called_once_with(user_id, new_key)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.org_service.OrgService.check_byor_export_enabled')
|
||||
@patch('storage.lite_llm_manager.LiteLlmManager.verify_key')
|
||||
@patch('server.routes.api_keys.get_byor_key_from_db')
|
||||
async def test_valid_key_in_database_returns_key(
|
||||
self, mock_get_key, mock_verify_key
|
||||
self, mock_get_key, mock_verify_key, mock_check_enabled
|
||||
):
|
||||
"""Test that when a valid key exists in database, it is returned."""
|
||||
# Arrange
|
||||
user_id = 'user-123'
|
||||
existing_key = 'sk-existing-valid-key'
|
||||
mock_check_enabled.return_value = True
|
||||
mock_get_key.return_value = existing_key
|
||||
mock_verify_key.return_value = True
|
||||
|
||||
@@ -222,11 +230,13 @@ class TestGetLlmApiKeyForByor:
|
||||
result = await get_llm_api_key_for_byor(user_id=user_id)
|
||||
|
||||
# Assert
|
||||
assert result == {'key': existing_key}
|
||||
assert result == LlmApiKeyResponse(key=existing_key)
|
||||
mock_check_enabled.assert_called_once_with(user_id)
|
||||
mock_get_key.assert_called_once_with(user_id)
|
||||
mock_verify_key.assert_called_once_with(existing_key, user_id)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.org_service.OrgService.check_byor_export_enabled')
|
||||
@patch('server.routes.api_keys.store_byor_key_in_db')
|
||||
@patch('server.routes.api_keys.generate_byor_key')
|
||||
@patch('server.routes.api_keys.delete_byor_key_from_litellm')
|
||||
@@ -239,12 +249,14 @@ class TestGetLlmApiKeyForByor:
|
||||
mock_delete_key,
|
||||
mock_generate_key,
|
||||
mock_store_key,
|
||||
mock_check_enabled,
|
||||
):
|
||||
"""Test that when an invalid key exists in database, it is regenerated."""
|
||||
# Arrange
|
||||
user_id = 'user-123'
|
||||
invalid_key = 'sk-invalid-key'
|
||||
new_key = 'sk-new-generated-key'
|
||||
mock_check_enabled.return_value = True
|
||||
mock_get_key.return_value = invalid_key
|
||||
mock_verify_key.return_value = False
|
||||
mock_delete_key.return_value = True
|
||||
@@ -255,7 +267,8 @@ class TestGetLlmApiKeyForByor:
|
||||
result = await get_llm_api_key_for_byor(user_id=user_id)
|
||||
|
||||
# Assert
|
||||
assert result == {'key': new_key}
|
||||
assert result == LlmApiKeyResponse(key=new_key)
|
||||
mock_check_enabled.assert_called_once_with(user_id)
|
||||
mock_get_key.assert_called_once_with(user_id)
|
||||
mock_verify_key.assert_called_once_with(invalid_key, user_id)
|
||||
mock_delete_key.assert_called_once_with(user_id, invalid_key)
|
||||
@@ -263,6 +276,7 @@ class TestGetLlmApiKeyForByor:
|
||||
mock_store_key.assert_called_once_with(user_id, new_key)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.org_service.OrgService.check_byor_export_enabled')
|
||||
@patch('server.routes.api_keys.store_byor_key_in_db')
|
||||
@patch('server.routes.api_keys.generate_byor_key')
|
||||
@patch('server.routes.api_keys.delete_byor_key_from_litellm')
|
||||
@@ -275,12 +289,14 @@ class TestGetLlmApiKeyForByor:
|
||||
mock_delete_key,
|
||||
mock_generate_key,
|
||||
mock_store_key,
|
||||
mock_check_enabled,
|
||||
):
|
||||
"""Test that even if deletion fails, regeneration still proceeds."""
|
||||
# Arrange
|
||||
user_id = 'user-123'
|
||||
invalid_key = 'sk-invalid-key'
|
||||
new_key = 'sk-new-generated-key'
|
||||
mock_check_enabled.return_value = True
|
||||
mock_get_key.return_value = invalid_key
|
||||
mock_verify_key.return_value = False
|
||||
mock_delete_key.return_value = False # Deletion fails
|
||||
@@ -291,20 +307,23 @@ class TestGetLlmApiKeyForByor:
|
||||
result = await get_llm_api_key_for_byor(user_id=user_id)
|
||||
|
||||
# Assert
|
||||
assert result == {'key': new_key}
|
||||
assert result == LlmApiKeyResponse(key=new_key)
|
||||
mock_check_enabled.assert_called_once_with(user_id)
|
||||
mock_delete_key.assert_called_once_with(user_id, invalid_key)
|
||||
mock_generate_key.assert_called_once_with(user_id)
|
||||
mock_store_key.assert_called_once_with(user_id, new_key)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.org_service.OrgService.check_byor_export_enabled')
|
||||
@patch('server.routes.api_keys.generate_byor_key')
|
||||
@patch('server.routes.api_keys.get_byor_key_from_db')
|
||||
async def test_key_generation_failure_raises_exception(
|
||||
self, mock_get_key, mock_generate_key
|
||||
self, mock_get_key, mock_generate_key, mock_check_enabled
|
||||
):
|
||||
"""Test that when key generation fails, an HTTPException is raised."""
|
||||
# Arrange
|
||||
user_id = 'user-123'
|
||||
mock_check_enabled.return_value = True
|
||||
mock_get_key.return_value = None
|
||||
mock_generate_key.return_value = None
|
||||
|
||||
@@ -316,11 +335,15 @@ class TestGetLlmApiKeyForByor:
|
||||
assert 'Failed to generate new BYOR LLM API key' in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.org_service.OrgService.check_byor_export_enabled')
|
||||
@patch('server.routes.api_keys.get_byor_key_from_db')
|
||||
async def test_database_error_raises_exception(self, mock_get_key):
|
||||
async def test_database_error_raises_exception(
|
||||
self, mock_get_key, mock_check_enabled
|
||||
):
|
||||
"""Test that database errors are properly handled."""
|
||||
# Arrange
|
||||
user_id = 'user-123'
|
||||
mock_check_enabled.return_value = True
|
||||
mock_get_key.side_effect = Exception('Database connection error')
|
||||
|
||||
# Act & Assert
|
||||
@@ -330,6 +353,21 @@ class TestGetLlmApiKeyForByor:
|
||||
assert exc_info.value.status_code == 500
|
||||
assert 'Failed to retrieve BYOR LLM API key' in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.org_service.OrgService.check_byor_export_enabled')
|
||||
async def test_byor_export_disabled_returns_402(self, mock_check_enabled):
|
||||
"""Test that when BYOR export is disabled, 402 is returned."""
|
||||
# Arrange
|
||||
user_id = 'user-123'
|
||||
mock_check_enabled.return_value = False
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_llm_api_key_for_byor(user_id=user_id)
|
||||
|
||||
assert exc_info.value.status_code == 402
|
||||
assert 'BYOR key export is not enabled' in exc_info.value.detail
|
||||
|
||||
|
||||
class TestDeleteByorKeyFromLitellm:
|
||||
"""Test the delete_byor_key_from_litellm function with alias cleanup."""
|
||||
@@ -425,3 +463,52 @@ class TestDeleteByorKeyFromLitellm:
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestCheckByorPermitted:
|
||||
"""Test the check_byor_permitted endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.org_service.OrgService.check_byor_export_enabled')
|
||||
async def test_permitted_when_enabled(self, mock_check_enabled):
|
||||
"""Test that permitted=True is returned when BYOR export is enabled."""
|
||||
# Arrange
|
||||
user_id = 'user-123'
|
||||
mock_check_enabled.return_value = True
|
||||
|
||||
# Act
|
||||
result = await check_byor_permitted(user_id=user_id)
|
||||
|
||||
# Assert
|
||||
assert result == ByorPermittedResponse(permitted=True)
|
||||
mock_check_enabled.assert_called_once_with(user_id)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.org_service.OrgService.check_byor_export_enabled')
|
||||
async def test_not_permitted_when_disabled(self, mock_check_enabled):
|
||||
"""Test that permitted=False is returned when BYOR export is disabled."""
|
||||
# Arrange
|
||||
user_id = 'user-123'
|
||||
mock_check_enabled.return_value = False
|
||||
|
||||
# Act
|
||||
result = await check_byor_permitted(user_id=user_id)
|
||||
|
||||
# Assert
|
||||
assert result == ByorPermittedResponse(permitted=False)
|
||||
mock_check_enabled.assert_called_once_with(user_id)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.org_service.OrgService.check_byor_export_enabled')
|
||||
async def test_error_raises_500(self, mock_check_enabled):
|
||||
"""Test that an exception raises 500 error."""
|
||||
# Arrange
|
||||
user_id = 'user-123'
|
||||
mock_check_enabled.side_effect = Exception('Database error')
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await check_byor_permitted(user_id=user_id)
|
||||
|
||||
assert exc_info.value.status_code == 500
|
||||
assert 'Failed to check BYOR export permission' in exc_info.value.detail
|
||||
|
||||
@@ -6,8 +6,10 @@ from fastapi.responses import JSONResponse, RedirectResponse
|
||||
from pydantic import SecretStr
|
||||
from server.auth.saas_user_auth import SaasUserAuth
|
||||
from server.routes.email import (
|
||||
EmailUpdate,
|
||||
ResendEmailVerificationRequest,
|
||||
resend_email_verification,
|
||||
update_email,
|
||||
verified_email,
|
||||
verify_email,
|
||||
)
|
||||
@@ -116,12 +118,15 @@ async def test_verified_email_default_redirect(mock_request, mock_user_auth):
|
||||
"""Test verified_email redirects to /settings/user by default."""
|
||||
# Arrange
|
||||
mock_request.query_params.get.return_value = None
|
||||
mock_user_auth.user_id = 'test-user-id'
|
||||
|
||||
# Act
|
||||
with (
|
||||
patch('server.routes.email.get_user_auth', return_value=mock_user_auth),
|
||||
patch('server.routes.email.set_response_cookie') as mock_set_cookie,
|
||||
patch('server.routes.email.UserStore') as mock_user_store,
|
||||
):
|
||||
mock_user_store.update_user_email = AsyncMock()
|
||||
result = await verified_email(mock_request)
|
||||
|
||||
# Assert
|
||||
@@ -140,12 +145,15 @@ async def test_verified_email_https_scheme(mock_request, mock_user_auth):
|
||||
mock_request.url.hostname = 'example.com'
|
||||
mock_request.url.netloc = 'example.com'
|
||||
mock_request.query_params.get.return_value = None
|
||||
mock_user_auth.user_id = 'test-user-id'
|
||||
|
||||
# Act
|
||||
with (
|
||||
patch('server.routes.email.get_user_auth', return_value=mock_user_auth),
|
||||
patch('server.routes.email.set_response_cookie') as mock_set_cookie,
|
||||
patch('server.routes.email.UserStore') as mock_user_store,
|
||||
):
|
||||
mock_user_store.update_user_email = AsyncMock()
|
||||
result = await verified_email(mock_request)
|
||||
|
||||
# Assert
|
||||
@@ -327,6 +335,62 @@ async def test_resend_email_verification_with_is_auth_flow_false(mock_request):
|
||||
assert '/api/email/verified' in call_args.kwargs['redirect_uri']
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_email_calls_update_user_email(mock_request, mock_user_auth):
|
||||
"""POST /api/email should call UserStore.update_user_email with new email and email_verified=False."""
|
||||
user_id = 'test-user-id'
|
||||
new_email = 'new@example.com'
|
||||
email_data = EmailUpdate(email=new_email)
|
||||
|
||||
mock_keycloak_admin = MagicMock()
|
||||
mock_keycloak_admin.get_user.return_value = {
|
||||
'enabled': True,
|
||||
'username': 'testuser',
|
||||
}
|
||||
mock_keycloak_admin.a_update_user = AsyncMock()
|
||||
mock_user_store = MagicMock()
|
||||
mock_user_store.update_user_email = AsyncMock()
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.routes.email.get_keycloak_admin', return_value=mock_keycloak_admin
|
||||
),
|
||||
patch('server.routes.email.get_user_auth', return_value=mock_user_auth),
|
||||
patch('server.routes.email.set_response_cookie'),
|
||||
patch('server.routes.email.verify_email', new_callable=AsyncMock),
|
||||
patch('server.routes.email.UserStore', mock_user_store),
|
||||
):
|
||||
result = await update_email(
|
||||
email_data=email_data, request=mock_request, user_id=user_id
|
||||
)
|
||||
|
||||
assert result.status_code == status.HTTP_200_OK
|
||||
mock_user_store.update_user_email.assert_awaited_once_with(
|
||||
user_id=user_id, email=new_email, email_verified=False
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verified_email_calls_update_user_email(mock_request, mock_user_auth):
|
||||
"""GET /api/email/verified should call UserStore.update_user_email with email_verified=True."""
|
||||
mock_user_auth.user_id = 'test-user-id'
|
||||
|
||||
mock_user_store = MagicMock()
|
||||
mock_user_store.update_user_email = AsyncMock()
|
||||
|
||||
with (
|
||||
patch('server.routes.email.get_user_auth', return_value=mock_user_auth),
|
||||
patch('server.routes.email.set_response_cookie'),
|
||||
patch('server.routes.email.UserStore', mock_user_store),
|
||||
):
|
||||
result = await verified_email(mock_request)
|
||||
|
||||
assert result.status_code == 302
|
||||
mock_user_store.update_user_email.assert_awaited_once_with(
|
||||
user_id='test-user-id', email_verified=True
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resend_email_verification_body_none_uses_auth(mock_request):
|
||||
"""Test resend_email_verification uses auth when body is None."""
|
||||
|
||||
@@ -1220,3 +1220,60 @@ async def test_validate_workspace_update_permissions_no_current_link(mock_manage
|
||||
|
||||
result = await _validate_workspace_update_permissions('user1', 'test-workspace')
|
||||
assert result == mock_workspace
|
||||
|
||||
|
||||
# Tests for OAuth URL encoding
|
||||
class TestJiraDcOAuthUrlEncoding:
|
||||
"""Tests to verify OAuth authorization URLs are properly URL-encoded."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('server.routes.integration.jira_dc.get_user_auth')
|
||||
@patch('server.routes.integration.jira_dc.redis_client')
|
||||
@patch('server.routes.integration.jira_dc.JIRA_DC_ENABLE_OAUTH', True)
|
||||
async def test_create_jira_dc_workspace_url_encoding(
|
||||
self, mock_redis, mock_get_auth, mock_request, mock_user_auth
|
||||
):
|
||||
"""Test that create_jira_dc_workspace properly URL-encodes the authorization URL."""
|
||||
mock_get_auth.return_value = mock_user_auth
|
||||
mock_redis.setex.return_value = True
|
||||
workspace_data = JiraDcWorkspaceCreate(
|
||||
workspace_name='test-workspace',
|
||||
webhook_secret='secret',
|
||||
svc_acc_email='svc@test.com',
|
||||
svc_acc_api_key='key',
|
||||
is_active=True,
|
||||
)
|
||||
|
||||
response = await create_jira_dc_workspace(mock_request, workspace_data)
|
||||
content = json.loads(response.body)
|
||||
|
||||
auth_url = content['authorizationUrl']
|
||||
# Verify no raw spaces in the URL (spaces should be encoded as + or %20)
|
||||
assert ' ' not in auth_url
|
||||
# Verify scope parameter contains encoded scopes (+ is valid URL encoding for space)
|
||||
assert 'scope=read%3Ame+read%3Ajira-user+read%3Ajira-work' in auth_url
|
||||
# Verify redirect_uri is properly encoded
|
||||
assert 'redirect_uri=https%3A%2F%2F' in auth_url
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('server.routes.integration.jira_dc.get_user_auth')
|
||||
@patch('server.routes.integration.jira_dc.redis_client')
|
||||
@patch('server.routes.integration.jira_dc.JIRA_DC_ENABLE_OAUTH', True)
|
||||
async def test_create_workspace_link_url_encoding(
|
||||
self, mock_redis, mock_get_auth, mock_request, mock_user_auth
|
||||
):
|
||||
"""Test that create_workspace_link properly URL-encodes the authorization URL."""
|
||||
mock_get_auth.return_value = mock_user_auth
|
||||
mock_redis.setex.return_value = True
|
||||
link_data = JiraDcLinkCreate(workspace_name='test-workspace')
|
||||
|
||||
response = await create_workspace_link(mock_request, link_data)
|
||||
content = json.loads(response.body)
|
||||
|
||||
auth_url = content['authorizationUrl']
|
||||
# Verify no raw spaces in the URL (spaces should be encoded as + or %20)
|
||||
assert ' ' not in auth_url
|
||||
# Verify scope parameter contains encoded scopes (+ is valid URL encoding for space)
|
||||
assert 'scope=read%3Ame+read%3Ajira-user+read%3Ajira-work' in auth_url
|
||||
# Verify redirect_uri is properly encoded
|
||||
assert 'redirect_uri=https%3A%2F%2F' in auth_url
|
||||
|
||||
@@ -1323,3 +1323,58 @@ async def test_validate_workspace_update_permissions_no_current_link(mock_manage
|
||||
|
||||
result = await _validate_workspace_update_permissions('user1', 'test-workspace')
|
||||
assert result == mock_workspace
|
||||
|
||||
|
||||
# Tests for OAuth URL encoding
|
||||
class TestJiraOAuthUrlEncoding:
|
||||
"""Tests to verify OAuth authorization URLs are properly URL-encoded."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('server.routes.integration.jira.get_user_auth')
|
||||
@patch('server.routes.integration.jira.redis_client')
|
||||
async def test_create_jira_workspace_url_encoding(
|
||||
self, mock_redis, mock_get_auth, mock_request, mock_user_auth
|
||||
):
|
||||
"""Test that create_jira_workspace properly URL-encodes the authorization URL."""
|
||||
mock_get_auth.return_value = mock_user_auth
|
||||
mock_redis.setex.return_value = True
|
||||
workspace_data = JiraWorkspaceCreate(
|
||||
workspace_name='test-workspace',
|
||||
webhook_secret='secret',
|
||||
svc_acc_email='svc@test.com',
|
||||
svc_acc_api_key='key',
|
||||
is_active=True,
|
||||
)
|
||||
|
||||
response = await create_jira_workspace(mock_request, workspace_data)
|
||||
content = json.loads(response.body)
|
||||
|
||||
auth_url = content['authorizationUrl']
|
||||
# Verify no raw spaces in the URL (spaces should be encoded as + or %20)
|
||||
assert ' ' not in auth_url
|
||||
# Verify scope parameter contains encoded scopes (+ is valid URL encoding for space)
|
||||
assert 'scope=read%3Ame+read%3Ajira-user+read%3Ajira-work' in auth_url
|
||||
# Verify redirect_uri is properly encoded
|
||||
assert 'redirect_uri=https%3A%2F%2F' in auth_url
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('server.routes.integration.jira.get_user_auth')
|
||||
@patch('server.routes.integration.jira.redis_client')
|
||||
async def test_create_workspace_link_url_encoding(
|
||||
self, mock_redis, mock_get_auth, mock_request, mock_user_auth
|
||||
):
|
||||
"""Test that create_workspace_link properly URL-encodes the authorization URL."""
|
||||
mock_get_auth.return_value = mock_user_auth
|
||||
mock_redis.setex.return_value = True
|
||||
link_data = JiraLinkCreate(workspace_name='test-workspace')
|
||||
|
||||
response = await create_workspace_link(mock_request, link_data)
|
||||
content = json.loads(response.body)
|
||||
|
||||
auth_url = content['authorizationUrl']
|
||||
# Verify no raw spaces in the URL (spaces should be encoded as + or %20)
|
||||
assert ' ' not in auth_url
|
||||
# Verify scope parameter contains encoded scopes (+ is valid URL encoding for space)
|
||||
assert 'scope=read%3Ame+read%3Ajira-user+read%3Ajira-work' in auth_url
|
||||
# Verify redirect_uri is properly encoded
|
||||
assert 'redirect_uri=https%3A%2F%2F' in auth_url
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
207
enterprise/tests/unit/server/routes/test_user_app_settings.py
Normal file
207
enterprise/tests/unit/server/routes/test_user_app_settings.py
Normal file
@@ -0,0 +1,207 @@
|
||||
"""
|
||||
Unit tests for user app settings API routes.
|
||||
|
||||
Tests the GET and POST /api/users/app endpoints.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI, status
|
||||
from fastapi.testclient import TestClient
|
||||
from server.routes.user_app_settings import user_app_settings_router
|
||||
from server.routes.user_app_settings_models import (
|
||||
UserAppSettingsResponse,
|
||||
UserNotFoundError,
|
||||
)
|
||||
|
||||
from openhands.server.user_auth import get_user_id
|
||||
|
||||
TEST_USER_ID = str(uuid.uuid4())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app():
|
||||
"""Create a test FastAPI app with user app settings routes and mocked auth."""
|
||||
app = FastAPI()
|
||||
app.include_router(user_app_settings_router)
|
||||
|
||||
def mock_get_user_id():
|
||||
return TEST_USER_ID
|
||||
|
||||
app.dependency_overrides[get_user_id] = mock_get_user_id
|
||||
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app_unauthenticated():
|
||||
"""Create a test FastAPI app with no authenticated user."""
|
||||
app = FastAPI()
|
||||
app.include_router(user_app_settings_router)
|
||||
|
||||
def mock_get_user_id():
|
||||
return None
|
||||
|
||||
app.dependency_overrides[get_user_id] = mock_get_user_id
|
||||
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_settings_response():
|
||||
"""Create a mock user app settings response."""
|
||||
return UserAppSettingsResponse(
|
||||
language='en',
|
||||
user_consents_to_analytics=True,
|
||||
enable_sound_notifications=False,
|
||||
git_user_name='testuser',
|
||||
git_user_email='test@example.com',
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_app_settings_success(mock_app, mock_settings_response):
|
||||
"""
|
||||
GIVEN: An authenticated user with app settings
|
||||
WHEN: GET /api/users/app is called
|
||||
THEN: User's app settings are returned with 200 status
|
||||
"""
|
||||
# Arrange
|
||||
with patch(
|
||||
'server.routes.user_app_settings.UserAppSettingsService.get_user_app_settings',
|
||||
AsyncMock(return_value=mock_settings_response),
|
||||
):
|
||||
client = TestClient(mock_app)
|
||||
|
||||
# Act
|
||||
response = client.get('/api/users/app')
|
||||
|
||||
# Assert
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data['language'] == 'en'
|
||||
assert data['user_consents_to_analytics'] is True
|
||||
assert data['enable_sound_notifications'] is False
|
||||
assert data['git_user_name'] == 'testuser'
|
||||
assert data['git_user_email'] == 'test@example.com'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_app_settings_not_authenticated(mock_app_unauthenticated):
|
||||
"""
|
||||
GIVEN: An unauthenticated request
|
||||
WHEN: GET /api/users/app is called
|
||||
THEN: 401 Unauthorized is returned
|
||||
"""
|
||||
# Arrange
|
||||
client = TestClient(mock_app_unauthenticated)
|
||||
|
||||
# Act
|
||||
response = client.get('/api/users/app')
|
||||
|
||||
# Assert
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
assert 'not authenticated' in response.json()['detail'].lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_app_settings_user_not_found(mock_app):
|
||||
"""
|
||||
GIVEN: An authenticated user that doesn't exist in the database
|
||||
WHEN: GET /api/users/app is called
|
||||
THEN: 404 Not Found is returned
|
||||
"""
|
||||
# Arrange
|
||||
with patch(
|
||||
'server.routes.user_app_settings.UserAppSettingsService.get_user_app_settings',
|
||||
AsyncMock(side_effect=UserNotFoundError(TEST_USER_ID)),
|
||||
):
|
||||
client = TestClient(mock_app)
|
||||
|
||||
# Act
|
||||
response = client.get('/api/users/app')
|
||||
|
||||
# Assert
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
assert 'not found' in response.json()['detail'].lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_user_app_settings_success(mock_app):
|
||||
"""
|
||||
GIVEN: An authenticated user
|
||||
WHEN: POST /api/users/app is called with update data
|
||||
THEN: Updated settings are returned with 200 status
|
||||
"""
|
||||
# Arrange
|
||||
updated_response = UserAppSettingsResponse(
|
||||
language='es',
|
||||
user_consents_to_analytics=False,
|
||||
enable_sound_notifications=True,
|
||||
git_user_name='newuser',
|
||||
git_user_email='new@example.com',
|
||||
)
|
||||
request_data = {
|
||||
'language': 'es',
|
||||
'user_consents_to_analytics': False,
|
||||
}
|
||||
|
||||
with patch(
|
||||
'server.routes.user_app_settings.UserAppSettingsService.update_user_app_settings',
|
||||
AsyncMock(return_value=updated_response),
|
||||
):
|
||||
client = TestClient(mock_app)
|
||||
|
||||
# Act
|
||||
response = client.post('/api/users/app', json=request_data)
|
||||
|
||||
# Assert
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert data['language'] == 'es'
|
||||
assert data['user_consents_to_analytics'] is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_user_app_settings_not_authenticated(mock_app_unauthenticated):
|
||||
"""
|
||||
GIVEN: An unauthenticated request
|
||||
WHEN: POST /api/users/app is called
|
||||
THEN: 401 Unauthorized is returned
|
||||
"""
|
||||
# Arrange
|
||||
request_data = {'language': 'en'}
|
||||
client = TestClient(mock_app_unauthenticated)
|
||||
|
||||
# Act
|
||||
response = client.post('/api/users/app', json=request_data)
|
||||
|
||||
# Assert
|
||||
assert response.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
assert 'not authenticated' in response.json()['detail'].lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_user_app_settings_user_not_found(mock_app):
|
||||
"""
|
||||
GIVEN: An authenticated user that doesn't exist in the database
|
||||
WHEN: POST /api/users/app is called
|
||||
THEN: 404 Not Found is returned
|
||||
"""
|
||||
# Arrange
|
||||
request_data = {'language': 'en'}
|
||||
|
||||
with patch(
|
||||
'server.routes.user_app_settings.UserAppSettingsService.update_user_app_settings',
|
||||
AsyncMock(side_effect=UserNotFoundError(TEST_USER_ID)),
|
||||
):
|
||||
client = TestClient(mock_app)
|
||||
|
||||
# Act
|
||||
response = client.post('/api/users/app', json=request_data)
|
||||
|
||||
# Assert
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
assert 'not found' in response.json()['detail'].lower()
|
||||
@@ -0,0 +1,173 @@
|
||||
"""
|
||||
Unit tests for OrgAppSettingsService.
|
||||
|
||||
Tests the service layer for organization app settings operations.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from server.routes.org_models import (
|
||||
OrgAppSettingsResponse,
|
||||
OrgAppSettingsUpdate,
|
||||
OrgNotFoundError,
|
||||
)
|
||||
from server.services.org_app_settings_service import OrgAppSettingsService
|
||||
from storage.org import Org
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def user_id():
|
||||
"""Create a test user ID."""
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_org():
|
||||
"""Create a mock organization with app settings."""
|
||||
org = MagicMock(spec=Org)
|
||||
org.id = uuid.uuid4()
|
||||
org.enable_proactive_conversation_starters = True
|
||||
org.enable_solvability_analysis = False
|
||||
org.max_budget_per_task = 25.0
|
||||
return org
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_store():
|
||||
"""Create a mock OrgAppSettingsStore."""
|
||||
return MagicMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user_context(user_id):
|
||||
"""Create a mock UserContext that returns the user_id."""
|
||||
context = MagicMock()
|
||||
context.get_user_id = AsyncMock(return_value=user_id)
|
||||
return context
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_org_app_settings_success(
|
||||
user_id, mock_org, mock_store, mock_user_context
|
||||
):
|
||||
"""
|
||||
GIVEN: A user's current organization exists
|
||||
WHEN: get_org_app_settings is called
|
||||
THEN: OrgAppSettingsResponse is returned with correct data
|
||||
"""
|
||||
# Arrange
|
||||
mock_store.get_current_org_by_user_id = AsyncMock(return_value=mock_org)
|
||||
service = OrgAppSettingsService(store=mock_store, user_context=mock_user_context)
|
||||
|
||||
# Act
|
||||
result = await service.get_org_app_settings()
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, OrgAppSettingsResponse)
|
||||
assert result.enable_proactive_conversation_starters is True
|
||||
assert result.enable_solvability_analysis is False
|
||||
assert result.max_budget_per_task == 25.0
|
||||
mock_store.get_current_org_by_user_id.assert_called_once_with(user_id)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_org_app_settings_org_not_found(
|
||||
user_id, mock_store, mock_user_context
|
||||
):
|
||||
"""
|
||||
GIVEN: A user has no current organization
|
||||
WHEN: get_org_app_settings is called
|
||||
THEN: OrgNotFoundError is raised
|
||||
"""
|
||||
# Arrange
|
||||
mock_store.get_current_org_by_user_id = AsyncMock(return_value=None)
|
||||
service = OrgAppSettingsService(store=mock_store, user_context=mock_user_context)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(OrgNotFoundError) as exc_info:
|
||||
await service.get_org_app_settings()
|
||||
|
||||
assert 'current' in str(exc_info.value)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_org_app_settings_success(
|
||||
user_id, mock_org, mock_store, mock_user_context
|
||||
):
|
||||
"""
|
||||
GIVEN: A user's current organization exists
|
||||
WHEN: update_org_app_settings is called with new values
|
||||
THEN: OrgAppSettingsResponse is returned with updated data
|
||||
"""
|
||||
# Arrange
|
||||
mock_org.enable_proactive_conversation_starters = False
|
||||
mock_org.max_budget_per_task = 50.0
|
||||
|
||||
update_data = OrgAppSettingsUpdate(
|
||||
enable_proactive_conversation_starters=False,
|
||||
max_budget_per_task=50.0,
|
||||
)
|
||||
|
||||
mock_store.get_current_org_by_user_id = AsyncMock(return_value=mock_org)
|
||||
mock_store.update_org_app_settings = AsyncMock(return_value=mock_org)
|
||||
service = OrgAppSettingsService(store=mock_store, user_context=mock_user_context)
|
||||
|
||||
# Act
|
||||
result = await service.update_org_app_settings(update_data)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, OrgAppSettingsResponse)
|
||||
assert result.enable_proactive_conversation_starters is False
|
||||
assert result.max_budget_per_task == 50.0
|
||||
mock_store.update_org_app_settings.assert_called_once_with(
|
||||
org_id=mock_org.id, update_data=update_data
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_org_app_settings_no_changes(
|
||||
user_id, mock_org, mock_store, mock_user_context
|
||||
):
|
||||
"""
|
||||
GIVEN: A user's current organization exists
|
||||
WHEN: update_org_app_settings is called with no fields
|
||||
THEN: Current settings are returned without calling update
|
||||
"""
|
||||
# Arrange
|
||||
update_data = OrgAppSettingsUpdate() # No fields set
|
||||
|
||||
mock_store.get_current_org_by_user_id = AsyncMock(return_value=mock_org)
|
||||
mock_store.update_org_app_settings = AsyncMock()
|
||||
service = OrgAppSettingsService(store=mock_store, user_context=mock_user_context)
|
||||
|
||||
# Act
|
||||
result = await service.update_org_app_settings(update_data)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, OrgAppSettingsResponse)
|
||||
mock_store.get_current_org_by_user_id.assert_called_once_with(user_id)
|
||||
mock_store.update_org_app_settings.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_org_app_settings_org_not_found(
|
||||
user_id, mock_store, mock_user_context
|
||||
):
|
||||
"""
|
||||
GIVEN: A user has no current organization
|
||||
WHEN: update_org_app_settings is called
|
||||
THEN: OrgNotFoundError is raised
|
||||
"""
|
||||
# Arrange
|
||||
update_data = OrgAppSettingsUpdate(enable_proactive_conversation_starters=False)
|
||||
|
||||
mock_store.get_current_org_by_user_id = AsyncMock(return_value=None)
|
||||
service = OrgAppSettingsService(store=mock_store, user_context=mock_user_context)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(OrgNotFoundError) as exc_info:
|
||||
await service.update_org_app_settings(update_data)
|
||||
|
||||
assert 'current' in str(exc_info.value)
|
||||
@@ -0,0 +1,215 @@
|
||||
"""
|
||||
Unit tests for OrgLLMSettingsService.
|
||||
|
||||
Tests the service layer for organization LLM settings operations.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from server.routes.org_models import (
|
||||
OrgLLMSettingsResponse,
|
||||
OrgLLMSettingsUpdate,
|
||||
OrgNotFoundError,
|
||||
)
|
||||
from server.services.org_llm_settings_service import OrgLLMSettingsService
|
||||
from storage.org import Org
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def user_id():
|
||||
"""Create a test user ID."""
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def org_id():
|
||||
"""Create a test org ID."""
|
||||
return uuid.uuid4()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_org(org_id):
|
||||
"""Create a mock organization with LLM settings."""
|
||||
org = MagicMock(spec=Org)
|
||||
org.id = org_id
|
||||
org.default_llm_model = 'claude-3'
|
||||
org.default_llm_base_url = 'https://api.anthropic.com'
|
||||
org.search_api_key = None
|
||||
org.agent = 'CodeActAgent'
|
||||
org.confirmation_mode = True
|
||||
org.security_analyzer = None
|
||||
org.enable_default_condenser = True
|
||||
org.condenser_max_size = None
|
||||
org.default_max_iterations = 50
|
||||
return org
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_store():
|
||||
"""Create a mock OrgLLMSettingsStore."""
|
||||
return MagicMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user_context(user_id):
|
||||
"""Create a mock UserContext that returns the user_id."""
|
||||
context = MagicMock()
|
||||
context.get_user_id = AsyncMock(return_value=user_id)
|
||||
return context
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_org_llm_settings_success(
|
||||
user_id, mock_org, mock_store, mock_user_context
|
||||
):
|
||||
"""
|
||||
GIVEN: A user with a current organization
|
||||
WHEN: get_org_llm_settings is called
|
||||
THEN: OrgLLMSettingsResponse is returned with correct data
|
||||
"""
|
||||
# Arrange
|
||||
mock_store.get_current_org_by_user_id = AsyncMock(return_value=mock_org)
|
||||
service = OrgLLMSettingsService(store=mock_store, user_context=mock_user_context)
|
||||
|
||||
# Act
|
||||
result = await service.get_org_llm_settings()
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, OrgLLMSettingsResponse)
|
||||
assert result.default_llm_model == 'claude-3'
|
||||
assert result.agent == 'CodeActAgent'
|
||||
mock_store.get_current_org_by_user_id.assert_called_once_with(user_id)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_org_llm_settings_user_not_authenticated(mock_store):
|
||||
"""
|
||||
GIVEN: A user is not authenticated
|
||||
WHEN: get_org_llm_settings is called
|
||||
THEN: ValueError is raised
|
||||
"""
|
||||
# Arrange
|
||||
mock_user_context = MagicMock()
|
||||
mock_user_context.get_user_id = AsyncMock(return_value=None)
|
||||
service = OrgLLMSettingsService(store=mock_store, user_context=mock_user_context)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await service.get_org_llm_settings()
|
||||
|
||||
assert 'not authenticated' in str(exc_info.value)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_org_llm_settings_org_not_found(
|
||||
user_id, mock_store, mock_user_context
|
||||
):
|
||||
"""
|
||||
GIVEN: A user has no current organization
|
||||
WHEN: get_org_llm_settings is called
|
||||
THEN: OrgNotFoundError is raised
|
||||
"""
|
||||
# Arrange
|
||||
mock_store.get_current_org_by_user_id = AsyncMock(return_value=None)
|
||||
service = OrgLLMSettingsService(store=mock_store, user_context=mock_user_context)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(OrgNotFoundError) as exc_info:
|
||||
await service.get_org_llm_settings()
|
||||
|
||||
assert 'No current organization' in str(exc_info.value)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_org_llm_settings_success(
|
||||
user_id, mock_org, mock_store, mock_user_context
|
||||
):
|
||||
"""
|
||||
GIVEN: A user with a current organization
|
||||
WHEN: update_org_llm_settings is called with new values
|
||||
THEN: OrgLLMSettingsResponse is returned with updated data
|
||||
"""
|
||||
# Arrange
|
||||
updated_org = MagicMock(spec=Org)
|
||||
updated_org.id = mock_org.id
|
||||
updated_org.default_llm_model = 'new-model'
|
||||
updated_org.default_llm_base_url = None
|
||||
updated_org.search_api_key = None
|
||||
updated_org.agent = 'CodeActAgent'
|
||||
updated_org.confirmation_mode = False
|
||||
updated_org.security_analyzer = None
|
||||
updated_org.enable_default_condenser = True
|
||||
updated_org.condenser_max_size = None
|
||||
updated_org.default_max_iterations = 100
|
||||
|
||||
update_data = OrgLLMSettingsUpdate(
|
||||
default_llm_model='new-model',
|
||||
confirmation_mode=False,
|
||||
default_max_iterations=100,
|
||||
)
|
||||
|
||||
mock_store.get_current_org_by_user_id = AsyncMock(return_value=mock_org)
|
||||
mock_store.update_org_llm_settings = AsyncMock(return_value=updated_org)
|
||||
service = OrgLLMSettingsService(store=mock_store, user_context=mock_user_context)
|
||||
|
||||
# Act
|
||||
result = await service.update_org_llm_settings(update_data)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, OrgLLMSettingsResponse)
|
||||
assert result.default_llm_model == 'new-model'
|
||||
assert result.confirmation_mode is False
|
||||
assert result.default_max_iterations == 100
|
||||
mock_store.update_org_llm_settings.assert_called_once_with(
|
||||
org_id=mock_org.id,
|
||||
update_data=update_data,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_org_llm_settings_no_changes(
|
||||
user_id, mock_org, mock_store, mock_user_context
|
||||
):
|
||||
"""
|
||||
GIVEN: A user with a current organization
|
||||
WHEN: update_org_llm_settings is called with no fields
|
||||
THEN: Current settings are returned without calling update
|
||||
"""
|
||||
# Arrange
|
||||
update_data = OrgLLMSettingsUpdate() # No fields set
|
||||
|
||||
mock_store.get_current_org_by_user_id = AsyncMock(return_value=mock_org)
|
||||
mock_store.update_org_llm_settings = AsyncMock()
|
||||
service = OrgLLMSettingsService(store=mock_store, user_context=mock_user_context)
|
||||
|
||||
# Act
|
||||
result = await service.update_org_llm_settings(update_data)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, OrgLLMSettingsResponse)
|
||||
assert result.default_llm_model == 'claude-3'
|
||||
mock_store.update_org_llm_settings.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_org_llm_settings_org_not_found(
|
||||
user_id, mock_store, mock_user_context
|
||||
):
|
||||
"""
|
||||
GIVEN: A user has no current organization
|
||||
WHEN: update_org_llm_settings is called
|
||||
THEN: OrgNotFoundError is raised
|
||||
"""
|
||||
# Arrange
|
||||
update_data = OrgLLMSettingsUpdate(default_llm_model='new-model')
|
||||
|
||||
mock_store.get_current_org_by_user_id = AsyncMock(return_value=None)
|
||||
service = OrgLLMSettingsService(store=mock_store, user_context=mock_user_context)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(OrgNotFoundError) as exc_info:
|
||||
await service.update_org_llm_settings(update_data)
|
||||
|
||||
assert 'No current organization' in str(exc_info.value)
|
||||
2223
enterprise/tests/unit/server/services/test_org_member_service.py
Normal file
2223
enterprise/tests/unit/server/services/test_org_member_service.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,176 @@
|
||||
"""
|
||||
Unit tests for UserAppSettingsService.
|
||||
|
||||
Tests the service layer for user app settings operations.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from server.routes.user_app_settings_models import (
|
||||
UserAppSettingsResponse,
|
||||
UserAppSettingsUpdate,
|
||||
UserNotFoundError,
|
||||
)
|
||||
from server.services.user_app_settings_service import UserAppSettingsService
|
||||
from storage.user import User
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def user_id():
|
||||
"""Create a test user ID."""
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user(user_id):
|
||||
"""Create a mock user with app settings."""
|
||||
user = MagicMock(spec=User)
|
||||
user.id = uuid.UUID(user_id)
|
||||
user.language = 'en'
|
||||
user.user_consents_to_analytics = True
|
||||
user.enable_sound_notifications = False
|
||||
user.git_user_name = 'testuser'
|
||||
user.git_user_email = 'test@example.com'
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_store():
|
||||
"""Create a mock UserAppSettingsStore."""
|
||||
return MagicMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user_context(user_id):
|
||||
"""Create a mock UserContext that returns the user_id."""
|
||||
context = MagicMock()
|
||||
context.get_user_id = AsyncMock(return_value=user_id)
|
||||
return context
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_app_settings_success(
|
||||
user_id, mock_user, mock_store, mock_user_context
|
||||
):
|
||||
"""
|
||||
GIVEN: A user exists in the database
|
||||
WHEN: get_user_app_settings is called
|
||||
THEN: UserAppSettingsResponse is returned with correct data
|
||||
"""
|
||||
# Arrange
|
||||
mock_store.get_user_by_id = AsyncMock(return_value=mock_user)
|
||||
service = UserAppSettingsService(store=mock_store, user_context=mock_user_context)
|
||||
|
||||
# Act
|
||||
result = await service.get_user_app_settings()
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, UserAppSettingsResponse)
|
||||
assert result.language == 'en'
|
||||
assert result.user_consents_to_analytics is True
|
||||
assert result.enable_sound_notifications is False
|
||||
assert result.git_user_name == 'testuser'
|
||||
assert result.git_user_email == 'test@example.com'
|
||||
mock_store.get_user_by_id.assert_called_once_with(user_id)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_app_settings_user_not_found(
|
||||
user_id, mock_store, mock_user_context
|
||||
):
|
||||
"""
|
||||
GIVEN: A user does not exist in the database
|
||||
WHEN: get_user_app_settings is called
|
||||
THEN: UserNotFoundError is raised
|
||||
"""
|
||||
# Arrange
|
||||
mock_store.get_user_by_id = AsyncMock(return_value=None)
|
||||
service = UserAppSettingsService(store=mock_store, user_context=mock_user_context)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(UserNotFoundError) as exc_info:
|
||||
await service.get_user_app_settings()
|
||||
|
||||
assert user_id in str(exc_info.value)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_user_app_settings_success(
|
||||
user_id, mock_user, mock_store, mock_user_context
|
||||
):
|
||||
"""
|
||||
GIVEN: A user exists in the database
|
||||
WHEN: update_user_app_settings is called with new values
|
||||
THEN: UserAppSettingsResponse is returned with updated data
|
||||
"""
|
||||
# Arrange
|
||||
mock_user.language = 'es'
|
||||
mock_user.user_consents_to_analytics = False
|
||||
|
||||
update_data = UserAppSettingsUpdate(
|
||||
language='es',
|
||||
user_consents_to_analytics=False,
|
||||
)
|
||||
|
||||
mock_store.update_user_app_settings = AsyncMock(return_value=mock_user)
|
||||
service = UserAppSettingsService(store=mock_store, user_context=mock_user_context)
|
||||
|
||||
# Act
|
||||
result = await service.update_user_app_settings(update_data)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, UserAppSettingsResponse)
|
||||
assert result.language == 'es'
|
||||
assert result.user_consents_to_analytics is False
|
||||
mock_store.update_user_app_settings.assert_called_once_with(
|
||||
user_id=user_id, update_data=update_data
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_user_app_settings_no_changes(
|
||||
user_id, mock_user, mock_store, mock_user_context
|
||||
):
|
||||
"""
|
||||
GIVEN: A user exists in the database
|
||||
WHEN: update_user_app_settings is called with no fields
|
||||
THEN: Current settings are returned without calling update
|
||||
"""
|
||||
# Arrange
|
||||
update_data = UserAppSettingsUpdate() # No fields set
|
||||
|
||||
mock_store.get_user_by_id = AsyncMock(return_value=mock_user)
|
||||
mock_store.update_user_app_settings = AsyncMock()
|
||||
service = UserAppSettingsService(store=mock_store, user_context=mock_user_context)
|
||||
|
||||
# Act
|
||||
result = await service.update_user_app_settings(update_data)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, UserAppSettingsResponse)
|
||||
mock_store.get_user_by_id.assert_called_once_with(user_id)
|
||||
mock_store.update_user_app_settings.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_user_app_settings_user_not_found(
|
||||
user_id, mock_store, mock_user_context
|
||||
):
|
||||
"""
|
||||
GIVEN: A user does not exist in the database
|
||||
WHEN: update_user_app_settings is called
|
||||
THEN: UserNotFoundError is raised
|
||||
"""
|
||||
# Arrange
|
||||
update_data = UserAppSettingsUpdate(language='en')
|
||||
|
||||
mock_store.update_user_app_settings = AsyncMock(return_value=None)
|
||||
service = UserAppSettingsService(store=mock_store, user_context=mock_user_context)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(UserNotFoundError) as exc_info:
|
||||
await service.update_user_app_settings(update_data)
|
||||
|
||||
assert user_id in str(exc_info.value)
|
||||
661
enterprise/tests/unit/storage/test_auth_token_store.py
Normal file
661
enterprise/tests/unit/storage/test_auth_token_store.py
Normal file
@@ -0,0 +1,661 @@
|
||||
"""Unit tests for AuthTokenStore."""
|
||||
|
||||
import time
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Dict
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from server.auth.auth_error import TokenRefreshError
|
||||
from sqlalchemy.exc import OperationalError
|
||||
from storage.auth_token_store import (
|
||||
ACCESS_TOKEN_EXPIRY_BUFFER,
|
||||
LOCK_TIMEOUT_SECONDS,
|
||||
AuthTokenStore,
|
||||
)
|
||||
|
||||
from openhands.integrations.service_types import ProviderType
|
||||
|
||||
|
||||
def create_mock_session():
|
||||
"""Create a mock async session with properly configured context managers."""
|
||||
session = AsyncMock()
|
||||
|
||||
# Create async context manager for begin()
|
||||
@asynccontextmanager
|
||||
async def begin_context():
|
||||
yield
|
||||
|
||||
session.begin = begin_context
|
||||
return session
|
||||
|
||||
|
||||
def create_mock_session_maker(mock_session):
|
||||
"""Create a mock async session maker."""
|
||||
|
||||
@asynccontextmanager
|
||||
async def session_context():
|
||||
yield mock_session
|
||||
|
||||
# Return a callable that returns the context manager
|
||||
return lambda: session_context()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session():
|
||||
"""Create mock async session."""
|
||||
return create_mock_session()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session_maker(mock_session):
|
||||
"""Create mock async session maker."""
|
||||
return create_mock_session_maker(mock_session)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def auth_token_store(mock_session_maker):
|
||||
"""Create AuthTokenStore instance with mocked session maker."""
|
||||
return AuthTokenStore(
|
||||
keycloak_user_id='test-user-123',
|
||||
idp=ProviderType.GITHUB,
|
||||
a_session_maker=mock_session_maker,
|
||||
)
|
||||
|
||||
|
||||
class TestIsTokenExpired:
|
||||
"""Tests for _is_token_expired method."""
|
||||
|
||||
def test_both_tokens_valid(self, auth_token_store):
|
||||
"""Test when both tokens are valid (not expired)."""
|
||||
current_time = int(time.time())
|
||||
access_expires = current_time + ACCESS_TOKEN_EXPIRY_BUFFER + 1000
|
||||
refresh_expires = current_time + 1000
|
||||
|
||||
access_expired, refresh_expired = auth_token_store._is_token_expired(
|
||||
access_expires, refresh_expires
|
||||
)
|
||||
|
||||
assert access_expired is False
|
||||
assert refresh_expired is False
|
||||
|
||||
def test_access_token_expired(self, auth_token_store):
|
||||
"""Test when access token is expired but within buffer."""
|
||||
current_time = int(time.time())
|
||||
# Access token expires within buffer period
|
||||
access_expires = current_time + ACCESS_TOKEN_EXPIRY_BUFFER - 100
|
||||
refresh_expires = current_time + 10000
|
||||
|
||||
access_expired, refresh_expired = auth_token_store._is_token_expired(
|
||||
access_expires, refresh_expires
|
||||
)
|
||||
|
||||
assert access_expired is True
|
||||
assert refresh_expired is False
|
||||
|
||||
def test_refresh_token_expired(self, auth_token_store):
|
||||
"""Test when refresh token is expired."""
|
||||
current_time = int(time.time())
|
||||
access_expires = current_time + ACCESS_TOKEN_EXPIRY_BUFFER + 1000
|
||||
refresh_expires = current_time - 100 # Already expired
|
||||
|
||||
access_expired, refresh_expired = auth_token_store._is_token_expired(
|
||||
access_expires, refresh_expires
|
||||
)
|
||||
|
||||
assert access_expired is False
|
||||
assert refresh_expired is True
|
||||
|
||||
def test_both_tokens_expired(self, auth_token_store):
|
||||
"""Test when both tokens are expired."""
|
||||
current_time = int(time.time())
|
||||
access_expires = current_time - 100
|
||||
refresh_expires = current_time - 100
|
||||
|
||||
access_expired, refresh_expired = auth_token_store._is_token_expired(
|
||||
access_expires, refresh_expires
|
||||
)
|
||||
|
||||
assert access_expired is True
|
||||
assert refresh_expired is True
|
||||
|
||||
def test_zero_expiration_treated_as_never_expires(self, auth_token_store):
|
||||
"""Test that 0 expiration time is treated as never expires."""
|
||||
access_expired, refresh_expired = auth_token_store._is_token_expired(0, 0)
|
||||
|
||||
assert access_expired is False
|
||||
assert refresh_expired is False
|
||||
|
||||
|
||||
class TestLoadTokensFastPath:
|
||||
"""Tests for load_tokens fast path (no lock needed)."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fast_path_token_not_found(
|
||||
self, auth_token_store, mock_session_maker, mock_session
|
||||
):
|
||||
"""Test fast path returns None when no token record exists."""
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.one_or_none.return_value = None
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
|
||||
result = await auth_token_store.load_tokens()
|
||||
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fast_path_valid_token_no_refresh_needed(
|
||||
self, auth_token_store, mock_session_maker, mock_session
|
||||
):
|
||||
"""Test fast path returns tokens when they are still valid."""
|
||||
current_time = int(time.time())
|
||||
mock_token = MagicMock()
|
||||
mock_token.access_token = 'valid-access-token'
|
||||
mock_token.refresh_token = 'valid-refresh-token'
|
||||
mock_token.access_token_expires_at = (
|
||||
current_time + ACCESS_TOKEN_EXPIRY_BUFFER + 1000
|
||||
)
|
||||
mock_token.refresh_token_expires_at = current_time + 10000
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.one_or_none.return_value = mock_token
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
|
||||
result = await auth_token_store.load_tokens()
|
||||
|
||||
assert result is not None
|
||||
assert result['access_token'] == 'valid-access-token'
|
||||
assert result['refresh_token'] == 'valid-refresh-token'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fast_path_no_refresh_callback_provided(
|
||||
self, auth_token_store, mock_session_maker, mock_session
|
||||
):
|
||||
"""Test fast path returns existing tokens when no refresh callback is provided."""
|
||||
current_time = int(time.time())
|
||||
mock_token = MagicMock()
|
||||
mock_token.access_token = 'expired-access-token'
|
||||
mock_token.refresh_token = 'valid-refresh-token'
|
||||
# Expired access token
|
||||
mock_token.access_token_expires_at = current_time - 100
|
||||
mock_token.refresh_token_expires_at = current_time + 10000
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.one_or_none.return_value = mock_token
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
|
||||
result = await auth_token_store.load_tokens(check_expiration_and_refresh=None)
|
||||
|
||||
assert result is not None
|
||||
assert result['access_token'] == 'expired-access-token'
|
||||
|
||||
|
||||
class TestLoadTokensSlowPath:
|
||||
"""Tests for load_tokens slow path (lock required for refresh)."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_slow_path_successful_refresh(self):
|
||||
"""Test slow path successfully refreshes expired tokens."""
|
||||
current_time = int(time.time())
|
||||
mock_session = create_mock_session()
|
||||
|
||||
# First call (fast path) - returns expired token
|
||||
# Second call (slow path) - returns same token for update
|
||||
expired_token = MagicMock()
|
||||
expired_token.id = 1
|
||||
expired_token.access_token = 'expired-access-token'
|
||||
expired_token.refresh_token = 'valid-refresh-token'
|
||||
expired_token.access_token_expires_at = current_time - 100 # Expired
|
||||
expired_token.refresh_token_expires_at = current_time + 10000
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.one_or_none.return_value = expired_token
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
mock_session.commit = AsyncMock()
|
||||
|
||||
mock_session_maker = create_mock_session_maker(mock_session)
|
||||
|
||||
auth_store = AuthTokenStore(
|
||||
keycloak_user_id='test-user-123',
|
||||
idp=ProviderType.GITHUB,
|
||||
a_session_maker=mock_session_maker,
|
||||
)
|
||||
|
||||
async def mock_refresh(
|
||||
idp: ProviderType, refresh_token: str, access_exp: int, refresh_exp: int
|
||||
) -> Dict[str, str | int]:
|
||||
return {
|
||||
'access_token': 'new-access-token',
|
||||
'refresh_token': 'new-refresh-token',
|
||||
'access_token_expires_at': current_time + 3600,
|
||||
'refresh_token_expires_at': current_time + 86400,
|
||||
}
|
||||
|
||||
result = await auth_store.load_tokens(check_expiration_and_refresh=mock_refresh)
|
||||
|
||||
assert result is not None
|
||||
assert result['access_token'] == 'new-access-token'
|
||||
assert result['refresh_token'] == 'new-refresh-token'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_slow_path_double_check_avoids_refresh(self):
|
||||
"""Test double-check locking: token was refreshed by another request."""
|
||||
current_time = int(time.time())
|
||||
mock_session = create_mock_session()
|
||||
|
||||
# Simulate scenario:
|
||||
# 1. Fast path sees expired token
|
||||
# 2. While waiting for lock, another request refreshes
|
||||
# 3. Slow path sees fresh token, skips refresh
|
||||
|
||||
call_count = [0]
|
||||
|
||||
def create_token():
|
||||
call_count[0] += 1
|
||||
token = MagicMock()
|
||||
token.id = 1
|
||||
token.access_token = 'fresh-access-token'
|
||||
token.refresh_token = 'fresh-refresh-token'
|
||||
if call_count[0] == 1:
|
||||
# First call (fast path) - expired
|
||||
token.access_token_expires_at = current_time - 100
|
||||
else:
|
||||
# Second call (slow path) - already refreshed
|
||||
token.access_token_expires_at = (
|
||||
current_time + ACCESS_TOKEN_EXPIRY_BUFFER + 1000
|
||||
)
|
||||
token.refresh_token_expires_at = current_time + 86400
|
||||
return token
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.one_or_none.side_effect = (
|
||||
lambda: create_token()
|
||||
)
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
mock_session.commit = AsyncMock()
|
||||
|
||||
mock_session_maker = create_mock_session_maker(mock_session)
|
||||
|
||||
auth_store = AuthTokenStore(
|
||||
keycloak_user_id='test-user-123',
|
||||
idp=ProviderType.GITHUB,
|
||||
a_session_maker=mock_session_maker,
|
||||
)
|
||||
|
||||
refresh_called = [False]
|
||||
|
||||
async def mock_refresh(
|
||||
idp: ProviderType, refresh_token: str, access_exp: int, refresh_exp: int
|
||||
) -> Dict[str, str | int]:
|
||||
refresh_called[0] = True
|
||||
return {
|
||||
'access_token': 'should-not-be-used',
|
||||
'refresh_token': 'should-not-be-used',
|
||||
'access_token_expires_at': current_time + 3600,
|
||||
'refresh_token_expires_at': current_time + 86400,
|
||||
}
|
||||
|
||||
result = await auth_store.load_tokens(check_expiration_and_refresh=mock_refresh)
|
||||
|
||||
# The refresh callback should not be called because double-check
|
||||
# found the token was already refreshed
|
||||
assert result is not None
|
||||
assert result['access_token'] == 'fresh-access-token'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_slow_path_token_not_found_after_lock(self):
|
||||
"""Test slow path returns None if token record disappears after lock."""
|
||||
current_time = int(time.time())
|
||||
mock_session = create_mock_session()
|
||||
|
||||
# First call (fast path) - token exists but expired
|
||||
# Second call (slow path with lock) - token no longer exists
|
||||
call_count = [0]
|
||||
|
||||
def get_token():
|
||||
call_count[0] += 1
|
||||
if call_count[0] == 1:
|
||||
token = MagicMock()
|
||||
token.access_token_expires_at = current_time - 100 # Expired
|
||||
token.refresh_token_expires_at = current_time + 10000
|
||||
return token
|
||||
return None
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.one_or_none.side_effect = get_token
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
|
||||
mock_session_maker = create_mock_session_maker(mock_session)
|
||||
|
||||
auth_store = AuthTokenStore(
|
||||
keycloak_user_id='test-user-123',
|
||||
idp=ProviderType.GITHUB,
|
||||
a_session_maker=mock_session_maker,
|
||||
)
|
||||
|
||||
async def mock_refresh(*args) -> Dict[str, str | int]:
|
||||
return {
|
||||
'access_token': 'new-token',
|
||||
'refresh_token': 'new-refresh',
|
||||
'access_token_expires_at': current_time + 3600,
|
||||
'refresh_token_expires_at': current_time + 86400,
|
||||
}
|
||||
|
||||
result = await auth_store.load_tokens(check_expiration_and_refresh=mock_refresh)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestLoadTokensLockTimeout:
|
||||
"""Tests for lock timeout handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lock_timeout_raises_token_refresh_error(self):
|
||||
"""Test that lock timeout raises TokenRefreshError."""
|
||||
current_time = int(time.time())
|
||||
mock_session = create_mock_session()
|
||||
|
||||
# First call (fast path) - returns expired token
|
||||
expired_token = MagicMock()
|
||||
expired_token.access_token_expires_at = current_time - 100
|
||||
expired_token.refresh_token_expires_at = current_time + 10000
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.one_or_none.return_value = expired_token
|
||||
|
||||
# First execute for fast path succeeds
|
||||
# Second execute (for slow path) raises OperationalError
|
||||
call_count = [0]
|
||||
|
||||
async def execute_side_effect(*args, **kwargs):
|
||||
call_count[0] += 1
|
||||
if call_count[0] <= 1:
|
||||
return mock_result
|
||||
# Simulate lock timeout
|
||||
raise OperationalError(
|
||||
'canceling statement due to lock timeout', None, None
|
||||
)
|
||||
|
||||
mock_session.execute = execute_side_effect
|
||||
|
||||
mock_session_maker = create_mock_session_maker(mock_session)
|
||||
|
||||
auth_store = AuthTokenStore(
|
||||
keycloak_user_id='test-user-123',
|
||||
idp=ProviderType.GITHUB,
|
||||
a_session_maker=mock_session_maker,
|
||||
)
|
||||
|
||||
async def mock_refresh(*args) -> Dict[str, str | int]:
|
||||
return {
|
||||
'access_token': 'new-token',
|
||||
'refresh_token': 'new-refresh',
|
||||
'access_token_expires_at': current_time + 3600,
|
||||
'refresh_token_expires_at': current_time + 86400,
|
||||
}
|
||||
|
||||
with pytest.raises(TokenRefreshError) as exc_info:
|
||||
await auth_store.load_tokens(check_expiration_and_refresh=mock_refresh)
|
||||
|
||||
assert 'lock timeout' in str(exc_info.value).lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lock_timeout_preserves_original_exception(self):
|
||||
"""Test that TokenRefreshError preserves the original OperationalError."""
|
||||
current_time = int(time.time())
|
||||
mock_session = create_mock_session()
|
||||
|
||||
expired_token = MagicMock()
|
||||
expired_token.access_token_expires_at = current_time - 100
|
||||
expired_token.refresh_token_expires_at = current_time + 10000
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.one_or_none.return_value = expired_token
|
||||
|
||||
original_error = OperationalError(
|
||||
'canceling statement due to lock timeout', None, None
|
||||
)
|
||||
|
||||
call_count = [0]
|
||||
|
||||
async def execute_side_effect(*args, **kwargs):
|
||||
call_count[0] += 1
|
||||
if call_count[0] <= 1:
|
||||
return mock_result
|
||||
raise original_error
|
||||
|
||||
mock_session.execute = execute_side_effect
|
||||
|
||||
mock_session_maker = create_mock_session_maker(mock_session)
|
||||
|
||||
auth_store = AuthTokenStore(
|
||||
keycloak_user_id='test-user-123',
|
||||
idp=ProviderType.GITHUB,
|
||||
a_session_maker=mock_session_maker,
|
||||
)
|
||||
|
||||
async def mock_refresh(*args) -> Dict[str, str | int]:
|
||||
return {
|
||||
'access_token': 'new-token',
|
||||
'refresh_token': 'new-refresh',
|
||||
'access_token_expires_at': current_time + 3600,
|
||||
'refresh_token_expires_at': current_time + 86400,
|
||||
}
|
||||
|
||||
with pytest.raises(TokenRefreshError) as exc_info:
|
||||
await auth_store.load_tokens(check_expiration_and_refresh=mock_refresh)
|
||||
|
||||
# Verify the original exception is chained
|
||||
assert exc_info.value.__cause__ is original_error
|
||||
|
||||
|
||||
class TestLoadTokensRefreshCallbackBehavior:
|
||||
"""Tests for refresh callback return values."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_callback_returns_none(self):
|
||||
"""Test behavior when refresh callback returns None (no refresh performed)."""
|
||||
current_time = int(time.time())
|
||||
mock_session = create_mock_session()
|
||||
|
||||
expired_token = MagicMock()
|
||||
expired_token.id = 1
|
||||
expired_token.access_token = 'old-access-token'
|
||||
expired_token.refresh_token = 'old-refresh-token'
|
||||
expired_token.access_token_expires_at = current_time - 100 # Expired
|
||||
expired_token.refresh_token_expires_at = current_time + 10000
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.one_or_none.return_value = expired_token
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
mock_session.commit = AsyncMock()
|
||||
|
||||
mock_session_maker = create_mock_session_maker(mock_session)
|
||||
|
||||
auth_store = AuthTokenStore(
|
||||
keycloak_user_id='test-user-123',
|
||||
idp=ProviderType.GITHUB,
|
||||
a_session_maker=mock_session_maker,
|
||||
)
|
||||
|
||||
async def mock_refresh_returns_none(
|
||||
idp: ProviderType, refresh_token: str, access_exp: int, refresh_exp: int
|
||||
) -> Dict[str, str | int] | None:
|
||||
return None
|
||||
|
||||
result = await auth_store.load_tokens(
|
||||
check_expiration_and_refresh=mock_refresh_returns_none
|
||||
)
|
||||
|
||||
# Should return the old tokens when refresh returns None
|
||||
assert result is not None
|
||||
assert result['access_token'] == 'old-access-token'
|
||||
assert result['refresh_token'] == 'old-refresh-token'
|
||||
|
||||
|
||||
class TestStoreTokens:
|
||||
"""Tests for store_tokens method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_tokens_creates_new_record(self):
|
||||
"""Test storing tokens when no existing record."""
|
||||
mock_session = create_mock_session()
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.first.return_value = None
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
mock_session.add = MagicMock()
|
||||
mock_session.commit = AsyncMock()
|
||||
|
||||
mock_session_maker = create_mock_session_maker(mock_session)
|
||||
|
||||
auth_store = AuthTokenStore(
|
||||
keycloak_user_id='test-user-123',
|
||||
idp=ProviderType.GITHUB,
|
||||
a_session_maker=mock_session_maker,
|
||||
)
|
||||
|
||||
await auth_store.store_tokens(
|
||||
access_token='new-access-token',
|
||||
refresh_token='new-refresh-token',
|
||||
access_token_expires_at=1234567890,
|
||||
refresh_token_expires_at=1234657890,
|
||||
)
|
||||
|
||||
mock_session.add.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_tokens_updates_existing_record(self):
|
||||
"""Test storing tokens updates existing record."""
|
||||
mock_session = create_mock_session()
|
||||
existing_token = MagicMock()
|
||||
existing_token.access_token = 'old-access'
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.first.return_value = existing_token
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
mock_session.commit = AsyncMock()
|
||||
|
||||
mock_session_maker = create_mock_session_maker(mock_session)
|
||||
|
||||
auth_store = AuthTokenStore(
|
||||
keycloak_user_id='test-user-123',
|
||||
idp=ProviderType.GITHUB,
|
||||
a_session_maker=mock_session_maker,
|
||||
)
|
||||
|
||||
await auth_store.store_tokens(
|
||||
access_token='new-access-token',
|
||||
refresh_token='new-refresh-token',
|
||||
access_token_expires_at=1234567890,
|
||||
refresh_token_expires_at=1234657890,
|
||||
)
|
||||
|
||||
assert existing_token.access_token == 'new-access-token'
|
||||
assert existing_token.refresh_token == 'new-refresh-token'
|
||||
|
||||
|
||||
class TestIsAccessTokenValid:
|
||||
"""Tests for is_access_token_valid method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_access_token_valid_returns_false_when_no_tokens(
|
||||
self, auth_token_store, mock_session_maker, mock_session
|
||||
):
|
||||
"""Test returns False when no tokens found."""
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.one_or_none.return_value = None
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
|
||||
result = await auth_token_store.is_access_token_valid()
|
||||
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_access_token_valid_returns_true_for_valid_token(
|
||||
self, auth_token_store, mock_session_maker, mock_session
|
||||
):
|
||||
"""Test returns True when token is valid."""
|
||||
current_time = int(time.time())
|
||||
mock_token = MagicMock()
|
||||
mock_token.access_token = 'valid-access'
|
||||
mock_token.refresh_token = 'valid-refresh'
|
||||
mock_token.access_token_expires_at = current_time + 1000
|
||||
mock_token.refresh_token_expires_at = current_time + 10000
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.one_or_none.return_value = mock_token
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
|
||||
result = await auth_token_store.is_access_token_valid()
|
||||
|
||||
assert result is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_access_token_valid_returns_false_for_expired_token(
|
||||
self, auth_token_store, mock_session_maker, mock_session
|
||||
):
|
||||
"""Test returns False when token is expired."""
|
||||
current_time = int(time.time())
|
||||
mock_token = MagicMock()
|
||||
mock_token.access_token = 'expired-access'
|
||||
mock_token.refresh_token = 'valid-refresh'
|
||||
mock_token.access_token_expires_at = current_time - 100 # Expired
|
||||
mock_token.refresh_token_expires_at = current_time + 10000
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.one_or_none.return_value = mock_token
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
|
||||
result = await auth_token_store.is_access_token_valid()
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestGetInstance:
|
||||
"""Tests for get_instance class method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_instance_creates_auth_token_store(self):
|
||||
"""Test get_instance creates an AuthTokenStore with correct params."""
|
||||
with patch('storage.auth_token_store.a_session_maker') as mock_a_session_maker:
|
||||
store = await AuthTokenStore.get_instance(
|
||||
keycloak_user_id='user-123', idp=ProviderType.GITHUB
|
||||
)
|
||||
|
||||
assert store.keycloak_user_id == 'user-123'
|
||||
assert store.idp == ProviderType.GITHUB
|
||||
assert store.a_session_maker is mock_a_session_maker
|
||||
|
||||
|
||||
class TestIdentityProviderValue:
|
||||
"""Tests for identity_provider_value property."""
|
||||
|
||||
def test_identity_provider_value_returns_idp_value(self, auth_token_store):
|
||||
"""Test that identity_provider_value returns the enum value."""
|
||||
assert auth_token_store.identity_provider_value == ProviderType.GITHUB.value
|
||||
|
||||
def test_identity_provider_value_for_different_providers(self):
|
||||
"""Test identity_provider_value for different providers."""
|
||||
for provider in [
|
||||
ProviderType.GITHUB,
|
||||
ProviderType.GITLAB,
|
||||
ProviderType.BITBUCKET,
|
||||
]:
|
||||
store = AuthTokenStore(
|
||||
keycloak_user_id='test-user',
|
||||
idp=provider,
|
||||
a_session_maker=MagicMock(),
|
||||
)
|
||||
assert store.identity_provider_value == provider.value
|
||||
|
||||
|
||||
class TestConstants:
|
||||
"""Tests for module constants."""
|
||||
|
||||
def test_access_token_expiry_buffer_value(self):
|
||||
"""Test ACCESS_TOKEN_EXPIRY_BUFFER is set to 15 minutes."""
|
||||
assert ACCESS_TOKEN_EXPIRY_BUFFER == 900
|
||||
|
||||
def test_lock_timeout_seconds_value(self):
|
||||
"""Test LOCK_TIMEOUT_SECONDS is set to 5 seconds."""
|
||||
assert LOCK_TIMEOUT_SECONDS == 5
|
||||
99
enterprise/tests/unit/storage/test_database.py
Normal file
99
enterprise/tests/unit/storage/test_database.py
Normal file
@@ -0,0 +1,99 @@
|
||||
"""Tests for the enterprise storage.database module.
|
||||
|
||||
These tests verify that the session_maker function properly forwards
|
||||
keyword arguments to the underlying session maker for backward compatibility.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
|
||||
class TestSessionMaker:
|
||||
"""Test cases for the session_maker function."""
|
||||
|
||||
@patch('enterprise.storage.database._get_db_session_injector')
|
||||
def test_session_maker_without_args(self, mock_get_injector):
|
||||
"""Test that session_maker works without any arguments."""
|
||||
from enterprise.storage.database import session_maker
|
||||
|
||||
# Set up mock
|
||||
mock_injector = MagicMock()
|
||||
mock_inner_session_maker = MagicMock()
|
||||
mock_session = MagicMock()
|
||||
mock_inner_session_maker.return_value = mock_session
|
||||
mock_injector.get_session_maker.return_value = mock_inner_session_maker
|
||||
mock_get_injector.return_value = mock_injector
|
||||
|
||||
# Call session_maker without arguments
|
||||
result = session_maker()
|
||||
|
||||
# Verify the inner session maker was called without arguments
|
||||
mock_inner_session_maker.assert_called_once_with()
|
||||
assert result == mock_session
|
||||
|
||||
@patch('enterprise.storage.database._get_db_session_injector')
|
||||
def test_session_maker_with_expire_on_commit_false(self, mock_get_injector):
|
||||
"""Test that session_maker accepts expire_on_commit keyword argument.
|
||||
|
||||
This is a critical backward compatibility test - the session_maker
|
||||
must accept keyword arguments like expire_on_commit=False which is
|
||||
used in slack.py and potentially other integration modules.
|
||||
"""
|
||||
from enterprise.storage.database import session_maker
|
||||
|
||||
# Set up mock
|
||||
mock_injector = MagicMock()
|
||||
mock_inner_session_maker = MagicMock()
|
||||
mock_session = MagicMock()
|
||||
mock_inner_session_maker.return_value = mock_session
|
||||
mock_injector.get_session_maker.return_value = mock_inner_session_maker
|
||||
mock_get_injector.return_value = mock_injector
|
||||
|
||||
# Call session_maker with expire_on_commit=False
|
||||
# This is the exact call pattern used in slack.py line 242
|
||||
result = session_maker(expire_on_commit=False)
|
||||
|
||||
# Verify the inner session maker was called with the keyword argument
|
||||
mock_inner_session_maker.assert_called_once_with(expire_on_commit=False)
|
||||
assert result == mock_session
|
||||
|
||||
@patch('enterprise.storage.database._get_db_session_injector')
|
||||
def test_session_maker_with_multiple_kwargs(self, mock_get_injector):
|
||||
"""Test that session_maker passes through multiple keyword arguments."""
|
||||
from enterprise.storage.database import session_maker
|
||||
|
||||
# Set up mock
|
||||
mock_injector = MagicMock()
|
||||
mock_inner_session_maker = MagicMock()
|
||||
mock_session = MagicMock()
|
||||
mock_inner_session_maker.return_value = mock_session
|
||||
mock_injector.get_session_maker.return_value = mock_inner_session_maker
|
||||
mock_get_injector.return_value = mock_injector
|
||||
|
||||
# Call with multiple kwargs
|
||||
result = session_maker(
|
||||
expire_on_commit=False, autoflush=False, autocommit=False
|
||||
)
|
||||
|
||||
# Verify all kwargs were passed through
|
||||
mock_inner_session_maker.assert_called_once_with(
|
||||
expire_on_commit=False, autoflush=False, autocommit=False
|
||||
)
|
||||
assert result == mock_session
|
||||
|
||||
@patch('enterprise.storage.database._get_db_session_injector')
|
||||
def test_session_maker_returns_correct_session(self, mock_get_injector):
|
||||
"""Test that session_maker returns the session from the inner session maker."""
|
||||
from enterprise.storage.database import session_maker
|
||||
|
||||
# Set up mock
|
||||
mock_injector = MagicMock()
|
||||
mock_inner_session_maker = MagicMock()
|
||||
mock_session = MagicMock()
|
||||
mock_inner_session_maker.return_value = mock_session
|
||||
mock_injector.get_session_maker.return_value = mock_inner_session_maker
|
||||
mock_get_injector.return_value = mock_injector
|
||||
|
||||
result = session_maker()
|
||||
|
||||
# Verify the returned session is from the inner session maker
|
||||
assert result is mock_session
|
||||
189
enterprise/tests/unit/storage/test_org_app_settings_store.py
Normal file
189
enterprise/tests/unit/storage/test_org_app_settings_store.py
Normal file
@@ -0,0 +1,189 @@
|
||||
"""
|
||||
Unit tests for OrgAppSettingsStore.
|
||||
|
||||
Tests the async database operations for organization app settings.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
# Mock the database module before importing
|
||||
with patch('storage.database.engine', create=True), patch(
|
||||
'storage.database.a_engine', create=True
|
||||
):
|
||||
from server.routes.org_models import OrgAppSettingsUpdate
|
||||
from storage.base import Base
|
||||
from storage.org import Org
|
||||
from storage.org_app_settings_store import OrgAppSettingsStore
|
||||
from storage.user import User
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def async_engine():
|
||||
"""Create an async SQLite engine for testing."""
|
||||
engine = create_async_engine(
|
||||
'sqlite+aiosqlite:///:memory:',
|
||||
poolclass=StaticPool,
|
||||
connect_args={'check_same_thread': False},
|
||||
echo=False,
|
||||
)
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
yield engine
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def async_session_maker(async_engine):
|
||||
"""Create an async session maker for testing."""
|
||||
return async_sessionmaker(async_engine, class_=AsyncSession, expire_on_commit=False)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_org_by_user_id_success(async_session_maker):
|
||||
"""
|
||||
GIVEN: A user exists with a current organization
|
||||
WHEN: get_current_org_by_user_id is called with the user's ID
|
||||
THEN: The organization is returned with correct data
|
||||
"""
|
||||
# Arrange
|
||||
async with async_session_maker() as session:
|
||||
org = Org(
|
||||
name='test-org',
|
||||
enable_proactive_conversation_starters=True,
|
||||
enable_solvability_analysis=False,
|
||||
max_budget_per_task=25.0,
|
||||
)
|
||||
session.add(org)
|
||||
await session.flush()
|
||||
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
current_org_id=org.id,
|
||||
)
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
user_id = str(user.id)
|
||||
|
||||
# Act
|
||||
store = OrgAppSettingsStore(db_session=session)
|
||||
result = await store.get_current_org_by_user_id(user_id)
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result.name == 'test-org'
|
||||
assert result.enable_proactive_conversation_starters is True
|
||||
assert result.enable_solvability_analysis is False
|
||||
assert result.max_budget_per_task == 25.0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_org_by_user_id_user_not_found(async_session_maker):
|
||||
"""
|
||||
GIVEN: A user does not exist in the database
|
||||
WHEN: get_current_org_by_user_id is called with a non-existent ID
|
||||
THEN: None is returned
|
||||
"""
|
||||
# Arrange
|
||||
non_existent_id = str(uuid.uuid4())
|
||||
|
||||
# Act
|
||||
async with async_session_maker() as session:
|
||||
store = OrgAppSettingsStore(db_session=session)
|
||||
result = await store.get_current_org_by_user_id(non_existent_id)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_org_app_settings_success(async_session_maker):
|
||||
"""
|
||||
GIVEN: An organization exists in the database
|
||||
WHEN: update_org_app_settings is called with new values
|
||||
THEN: The organization's settings are updated and returned
|
||||
"""
|
||||
# Arrange
|
||||
async with async_session_maker() as session:
|
||||
org = Org(
|
||||
name='test-org',
|
||||
enable_proactive_conversation_starters=True,
|
||||
enable_solvability_analysis=False,
|
||||
max_budget_per_task=10.0,
|
||||
)
|
||||
session.add(org)
|
||||
await session.commit()
|
||||
org_id = org.id
|
||||
|
||||
update_data = OrgAppSettingsUpdate(
|
||||
enable_proactive_conversation_starters=False,
|
||||
enable_solvability_analysis=True,
|
||||
max_budget_per_task=50.0,
|
||||
)
|
||||
|
||||
# Act
|
||||
store = OrgAppSettingsStore(db_session=session)
|
||||
result = await store.update_org_app_settings(org_id, update_data)
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result.enable_proactive_conversation_starters is False
|
||||
assert result.enable_solvability_analysis is True
|
||||
assert result.max_budget_per_task == 50.0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_org_app_settings_partial(async_session_maker):
|
||||
"""
|
||||
GIVEN: An organization exists with existing settings
|
||||
WHEN: update_org_app_settings is called with only some fields
|
||||
THEN: Only the provided fields are updated, others remain unchanged
|
||||
"""
|
||||
# Arrange
|
||||
async with async_session_maker() as session:
|
||||
org = Org(
|
||||
name='test-org',
|
||||
enable_proactive_conversation_starters=True,
|
||||
enable_solvability_analysis=False,
|
||||
max_budget_per_task=10.0,
|
||||
)
|
||||
session.add(org)
|
||||
await session.commit()
|
||||
org_id = org.id
|
||||
|
||||
# Only update max_budget_per_task
|
||||
update_data = OrgAppSettingsUpdate(max_budget_per_task=100.0)
|
||||
|
||||
# Act
|
||||
store = OrgAppSettingsStore(db_session=session)
|
||||
result = await store.update_org_app_settings(org_id, update_data)
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result.max_budget_per_task == 100.0
|
||||
assert result.enable_proactive_conversation_starters is True # Unchanged
|
||||
assert result.enable_solvability_analysis is False # Unchanged
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_org_app_settings_org_not_found(async_session_maker):
|
||||
"""
|
||||
GIVEN: An organization does not exist in the database
|
||||
WHEN: update_org_app_settings is called
|
||||
THEN: None is returned
|
||||
"""
|
||||
# Arrange
|
||||
non_existent_id = uuid.uuid4()
|
||||
update_data = OrgAppSettingsUpdate(enable_proactive_conversation_starters=False)
|
||||
|
||||
# Act
|
||||
async with async_session_maker() as session:
|
||||
store = OrgAppSettingsStore(db_session=session)
|
||||
result = await store.update_org_app_settings(non_existent_id, update_data)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
180
enterprise/tests/unit/storage/test_org_llm_settings_store.py
Normal file
180
enterprise/tests/unit/storage/test_org_llm_settings_store.py
Normal file
@@ -0,0 +1,180 @@
|
||||
"""
|
||||
Unit tests for OrgLLMSettingsStore.
|
||||
|
||||
Tests the async database operations for organization LLM settings.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
# Mock the database module before importing
|
||||
with patch('storage.database.engine', create=True), patch(
|
||||
'storage.database.a_engine', create=True
|
||||
):
|
||||
from server.routes.org_models import OrgLLMSettingsUpdate
|
||||
from storage.base import Base
|
||||
from storage.org import Org
|
||||
from storage.org_llm_settings_store import OrgLLMSettingsStore
|
||||
from storage.user import User
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def async_engine():
|
||||
"""Create an async SQLite engine for testing."""
|
||||
engine = create_async_engine(
|
||||
'sqlite+aiosqlite:///:memory:',
|
||||
poolclass=StaticPool,
|
||||
connect_args={'check_same_thread': False},
|
||||
echo=False,
|
||||
)
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
yield engine
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def async_session_maker(async_engine):
|
||||
"""Create an async session maker for testing."""
|
||||
return async_sessionmaker(async_engine, class_=AsyncSession, expire_on_commit=False)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_org_by_user_id_success(async_session_maker):
|
||||
"""
|
||||
GIVEN: A user exists with a current_org_id
|
||||
WHEN: get_current_org_by_user_id is called
|
||||
THEN: The user's current organization is returned
|
||||
"""
|
||||
# Arrange
|
||||
async with async_session_maker() as session:
|
||||
org = Org(name='test-org', default_llm_model='claude-3')
|
||||
session.add(org)
|
||||
await session.flush()
|
||||
|
||||
user = User(id=uuid.uuid4(), current_org_id=org.id)
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
user_id = str(user.id)
|
||||
|
||||
# Act
|
||||
store = OrgLLMSettingsStore(db_session=session)
|
||||
result = await store.get_current_org_by_user_id(user_id)
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result.name == 'test-org'
|
||||
assert result.default_llm_model == 'claude-3'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_current_org_by_user_id_user_not_found(async_session_maker):
|
||||
"""
|
||||
GIVEN: A user does not exist in the database
|
||||
WHEN: get_current_org_by_user_id is called
|
||||
THEN: None is returned
|
||||
"""
|
||||
# Arrange
|
||||
non_existent_id = str(uuid.uuid4())
|
||||
|
||||
# Act
|
||||
async with async_session_maker() as session:
|
||||
store = OrgLLMSettingsStore(db_session=session)
|
||||
result = await store.get_current_org_by_user_id(non_existent_id)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_org_llm_settings_success(async_session_maker):
|
||||
"""
|
||||
GIVEN: An organization exists in the database
|
||||
WHEN: update_org_llm_settings is called with new values
|
||||
THEN: The organization's LLM settings are updated and returned
|
||||
"""
|
||||
# Arrange
|
||||
async with async_session_maker() as session:
|
||||
org = Org(name='test-org', default_llm_model='old-model')
|
||||
session.add(org)
|
||||
await session.commit()
|
||||
org_id = org.id
|
||||
|
||||
update_data = OrgLLMSettingsUpdate(
|
||||
default_llm_model='new-model',
|
||||
agent='CodeActAgent',
|
||||
confirmation_mode=True,
|
||||
)
|
||||
|
||||
# Act
|
||||
store = OrgLLMSettingsStore(db_session=session)
|
||||
with patch(
|
||||
'storage.org_llm_settings_store.OrgMemberStore.update_all_members_llm_settings_async',
|
||||
AsyncMock(),
|
||||
):
|
||||
result = await store.update_org_llm_settings(org_id, update_data)
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result.default_llm_model == 'new-model'
|
||||
assert result.agent == 'CodeActAgent'
|
||||
assert result.confirmation_mode is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_org_llm_settings_org_not_found(async_session_maker):
|
||||
"""
|
||||
GIVEN: An organization does not exist in the database
|
||||
WHEN: update_org_llm_settings is called
|
||||
THEN: None is returned
|
||||
"""
|
||||
# Arrange
|
||||
non_existent_org_id = uuid.uuid4()
|
||||
update_data = OrgLLMSettingsUpdate(default_llm_model='new-model')
|
||||
|
||||
# Act
|
||||
async with async_session_maker() as session:
|
||||
store = OrgLLMSettingsStore(db_session=session)
|
||||
result = await store.update_org_llm_settings(non_existent_org_id, update_data)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_org_llm_settings_propagates_to_members(async_session_maker):
|
||||
"""
|
||||
GIVEN: An organization exists with update data containing member-relevant settings
|
||||
WHEN: update_org_llm_settings is called
|
||||
THEN: Member settings are propagated via OrgMemberStore
|
||||
"""
|
||||
# Arrange
|
||||
async with async_session_maker() as session:
|
||||
org = Org(name='test-org', default_llm_model='old-model')
|
||||
session.add(org)
|
||||
await session.commit()
|
||||
org_id = org.id
|
||||
|
||||
update_data = OrgLLMSettingsUpdate(
|
||||
default_llm_model='new-model',
|
||||
llm_api_key='new-api-key',
|
||||
)
|
||||
|
||||
# Act
|
||||
store = OrgLLMSettingsStore(db_session=session)
|
||||
with patch(
|
||||
'storage.org_llm_settings_store.OrgMemberStore.update_all_members_llm_settings_async',
|
||||
AsyncMock(),
|
||||
) as mock_update_members:
|
||||
await store.update_org_llm_settings(org_id, update_data)
|
||||
|
||||
# Assert
|
||||
mock_update_members.assert_called_once()
|
||||
call_args = mock_update_members.call_args
|
||||
member_settings = call_args[0][2]
|
||||
assert member_settings.llm_model == 'new-model'
|
||||
assert member_settings.llm_api_key == 'new-api-key'
|
||||
158
enterprise/tests/unit/storage/test_resend_synced_user_store.py
Normal file
158
enterprise/tests/unit/storage/test_resend_synced_user_store.py
Normal file
@@ -0,0 +1,158 @@
|
||||
"""Unit tests for ResendSyncedUserStore."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
# Import directly from the module files to avoid loading all of storage/__init__.py
|
||||
# which has many dependencies
|
||||
from storage.resend_synced_user import ResendSyncedUser
|
||||
from storage.resend_synced_user_store import ResendSyncedUserStore
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session():
|
||||
"""Mock database session."""
|
||||
session = MagicMock()
|
||||
return session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session_maker(mock_session):
|
||||
"""Mock session maker."""
|
||||
session_maker = MagicMock()
|
||||
session_maker.return_value.__enter__.return_value = mock_session
|
||||
session_maker.return_value.__exit__.return_value = None
|
||||
return session_maker
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def store(mock_session_maker):
|
||||
"""Create ResendSyncedUserStore instance."""
|
||||
return ResendSyncedUserStore(session_maker=mock_session_maker)
|
||||
|
||||
|
||||
class TestResendSyncedUserStore:
|
||||
"""Test cases for ResendSyncedUserStore."""
|
||||
|
||||
def test_is_user_synced_returns_true_when_exists(self, store, mock_session):
|
||||
"""Test is_user_synced returns True when user exists in database."""
|
||||
email = 'test@example.com'
|
||||
audience_id = 'test-audience-123'
|
||||
|
||||
mock_row = MagicMock()
|
||||
mock_session.execute.return_value.first.return_value = mock_row
|
||||
|
||||
result = store.is_user_synced(email, audience_id)
|
||||
|
||||
assert result is True
|
||||
mock_session.execute.assert_called_once()
|
||||
|
||||
def test_is_user_synced_returns_false_when_not_exists(self, store, mock_session):
|
||||
"""Test is_user_synced returns False when user doesn't exist."""
|
||||
email = 'test@example.com'
|
||||
audience_id = 'test-audience-123'
|
||||
|
||||
mock_session.execute.return_value.first.return_value = None
|
||||
|
||||
result = store.is_user_synced(email, audience_id)
|
||||
|
||||
assert result is False
|
||||
|
||||
def test_is_user_synced_normalizes_email_to_lowercase(self, store, mock_session):
|
||||
"""Test that is_user_synced normalizes email to lowercase."""
|
||||
email = 'TEST@EXAMPLE.COM'
|
||||
audience_id = 'test-audience-123'
|
||||
|
||||
mock_session.execute.return_value.first.return_value = None
|
||||
|
||||
store.is_user_synced(email, audience_id)
|
||||
|
||||
# Verify the query was called (we can't easily check the exact SQL)
|
||||
mock_session.execute.assert_called_once()
|
||||
|
||||
def test_mark_user_synced_creates_new_record(self, store, mock_session):
|
||||
"""Test that mark_user_synced creates a new record."""
|
||||
email = 'test@example.com'
|
||||
audience_id = 'test-audience-123'
|
||||
keycloak_user_id = 'kc-user-123'
|
||||
|
||||
mock_synced_user = MagicMock(spec=ResendSyncedUser)
|
||||
mock_result = MagicMock()
|
||||
mock_result.first.return_value = (mock_synced_user,)
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
result = store.mark_user_synced(email, audience_id, keycloak_user_id)
|
||||
|
||||
assert result == mock_synced_user
|
||||
mock_session.execute.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
def test_mark_user_synced_handles_existing_record(self, store, mock_session):
|
||||
"""Test that mark_user_synced handles conflict (existing record)."""
|
||||
email = 'test@example.com'
|
||||
audience_id = 'test-audience-123'
|
||||
|
||||
# First execute (insert) returns None (conflict occurred)
|
||||
# Second execute (select existing) returns the record
|
||||
mock_existing_user = MagicMock(spec=ResendSyncedUser)
|
||||
mock_result_insert = MagicMock()
|
||||
mock_result_insert.first.return_value = None
|
||||
|
||||
mock_result_select = MagicMock()
|
||||
mock_result_select.first.return_value = (mock_existing_user,)
|
||||
|
||||
mock_session.execute.side_effect = [mock_result_insert, mock_result_select]
|
||||
|
||||
result = store.mark_user_synced(email, audience_id)
|
||||
|
||||
assert result == mock_existing_user
|
||||
assert mock_session.execute.call_count == 2
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
def test_mark_user_synced_normalizes_email_to_lowercase(self, store, mock_session):
|
||||
"""Test that mark_user_synced normalizes email to lowercase."""
|
||||
email = 'TEST@EXAMPLE.COM'
|
||||
audience_id = 'test-audience-123'
|
||||
|
||||
mock_synced_user = MagicMock(spec=ResendSyncedUser)
|
||||
mock_result = MagicMock()
|
||||
mock_result.first.return_value = (mock_synced_user,)
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
store.mark_user_synced(email, audience_id)
|
||||
|
||||
# Verify execute was called (the email normalization happens in the SQL)
|
||||
mock_session.execute.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
def test_mark_user_synced_without_keycloak_user_id(self, store, mock_session):
|
||||
"""Test that mark_user_synced works without keycloak_user_id."""
|
||||
email = 'test@example.com'
|
||||
audience_id = 'test-audience-123'
|
||||
|
||||
mock_synced_user = MagicMock(spec=ResendSyncedUser)
|
||||
mock_result = MagicMock()
|
||||
mock_result.first.return_value = (mock_synced_user,)
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
result = store.mark_user_synced(email, audience_id)
|
||||
|
||||
assert result == mock_synced_user
|
||||
mock_session.execute.assert_called_once()
|
||||
|
||||
|
||||
class TestResendSyncedUser:
|
||||
"""Test cases for ResendSyncedUser model."""
|
||||
|
||||
def test_model_has_required_fields(self):
|
||||
"""Test that the model has all required fields."""
|
||||
assert hasattr(ResendSyncedUser, 'id')
|
||||
assert hasattr(ResendSyncedUser, 'email')
|
||||
assert hasattr(ResendSyncedUser, 'audience_id')
|
||||
assert hasattr(ResendSyncedUser, 'synced_at')
|
||||
assert hasattr(ResendSyncedUser, 'keycloak_user_id')
|
||||
|
||||
def test_model_table_name(self):
|
||||
"""Test the model's table name."""
|
||||
assert ResendSyncedUser.__tablename__ == 'resend_synced_users'
|
||||
@@ -10,8 +10,12 @@ from unittest.mock import AsyncMock, MagicMock
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.pool import StaticPool
|
||||
from storage.base import Base
|
||||
from storage.org import Org
|
||||
from storage.user import User
|
||||
|
||||
from enterprise.server.utils.saas_app_conversation_info_injector import (
|
||||
SaasSQLAppConversationInfoService,
|
||||
@@ -20,10 +24,15 @@ from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
AppConversationInfo,
|
||||
)
|
||||
from openhands.app_server.user.specifiy_user_context import SpecifyUserContext
|
||||
from openhands.app_server.utils.sql_utils import Base
|
||||
from openhands.integrations.service_types import ProviderType
|
||||
from openhands.storage.data_models.conversation_metadata import ConversationTrigger
|
||||
|
||||
# Test UUIDs
|
||||
USER1_ID = UUID('a1111111-1111-1111-1111-111111111111')
|
||||
USER2_ID = UUID('b2222222-2222-2222-2222-222222222222')
|
||||
ORG1_ID = UUID('c1111111-1111-1111-1111-111111111111')
|
||||
ORG2_ID = UUID('d2222222-2222-2222-2222-222222222222')
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def async_engine():
|
||||
@@ -55,6 +64,41 @@ async def async_session(async_engine) -> AsyncGenerator[AsyncSession, None]:
|
||||
yield db_session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def async_session_with_users(async_engine) -> AsyncGenerator[AsyncSession, None]:
|
||||
"""Create an async session with pre-populated Org and User rows for testing."""
|
||||
async_session_maker = async_sessionmaker(
|
||||
async_engine, class_=AsyncSession, expire_on_commit=False
|
||||
)
|
||||
|
||||
async with async_session_maker() as db_session:
|
||||
# Insert Orgs first (required for User foreign key)
|
||||
org1 = Org(
|
||||
id=ORG1_ID,
|
||||
name='test-org-1',
|
||||
enable_default_condenser=True,
|
||||
enable_proactive_conversation_starters=True,
|
||||
)
|
||||
org2 = Org(
|
||||
id=ORG2_ID,
|
||||
name='test-org-2',
|
||||
enable_default_condenser=True,
|
||||
enable_proactive_conversation_starters=True,
|
||||
)
|
||||
db_session.add(org1)
|
||||
db_session.add(org2)
|
||||
await db_session.flush()
|
||||
|
||||
# Insert Users
|
||||
user1 = User(id=USER1_ID, current_org_id=ORG1_ID)
|
||||
user2 = User(id=USER2_ID, current_org_id=ORG2_ID)
|
||||
db_session.add(user1)
|
||||
db_session.add(user2)
|
||||
await db_session.commit()
|
||||
|
||||
yield db_session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def service(async_session) -> SaasSQLAppConversationInfoService:
|
||||
"""Create a SQLAppConversationInfoService instance for testing."""
|
||||
@@ -178,15 +222,26 @@ class TestSaasSQLAppConversationInfoService:
|
||||
assert user1_id != user2_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_secure_select_includes_user_filtering(
|
||||
async def test_secure_select_includes_user_and_org_filtering(
|
||||
self,
|
||||
saas_service_user1: SaasSQLAppConversationInfoService,
|
||||
async_session_with_users: AsyncSession,
|
||||
):
|
||||
"""Test that _secure_select method includes user filtering."""
|
||||
# This test verifies that the _secure_select method exists and can be called
|
||||
# The actual SQL generation is tested implicitly through integration
|
||||
query = await saas_service_user1._secure_select()
|
||||
assert query is not None
|
||||
"""Test that _secure_select method includes both user_id and org_id filtering."""
|
||||
service = SaasSQLAppConversationInfoService(
|
||||
db_session=async_session_with_users,
|
||||
user_context=SpecifyUserContext(user_id=str(USER1_ID)),
|
||||
)
|
||||
|
||||
query = await service._secure_select()
|
||||
|
||||
# Convert query to string to verify filters are present
|
||||
query_str = str(query.compile(compile_kwargs={'literal_binds': True}))
|
||||
|
||||
# Verify user_id filter is present
|
||||
assert str(USER1_ID) in query_str or str(USER1_ID).replace('-', '') in query_str
|
||||
|
||||
# Verify org_id filter is present (user1 is in org1)
|
||||
assert str(ORG1_ID) in query_str or str(ORG1_ID).replace('-', '') in query_str
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_to_info_with_user_id_functionality(
|
||||
@@ -241,100 +296,32 @@ class TestSaasSQLAppConversationInfoService:
|
||||
assert result.sandbox_id == 'test-sandbox'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_isolation(
|
||||
async def test_user_isolation_different_users(
|
||||
self,
|
||||
async_session: AsyncSession,
|
||||
multiple_conversation_infos: list[AppConversationInfo],
|
||||
async_session_with_users: AsyncSession,
|
||||
):
|
||||
"""Test that user isolation works correctly."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from storage.user import User
|
||||
|
||||
# Mock the database session execute method to return mock users
|
||||
# This mock intercepts User queries and returns a mock user object
|
||||
# with user_id and org_id the same as the user_id_uuid from the query
|
||||
original_execute = async_session.execute
|
||||
|
||||
async def mock_execute(query):
|
||||
query_str = str(query)
|
||||
|
||||
# Check if this is a User query
|
||||
if '"user"' in query_str.lower() and '"user".id' in query_str.lower():
|
||||
# Extract the UUID from the query parameters
|
||||
# The query will have bound parameters, we need to get the UUID value
|
||||
if hasattr(query, 'compile'):
|
||||
try:
|
||||
compiled = query.compile(compile_kwargs={'literal_binds': True})
|
||||
query_with_params = str(compiled)
|
||||
|
||||
# Extract UUID from the query string
|
||||
import re
|
||||
|
||||
# Try both formats: with dashes and without dashes
|
||||
uuid_pattern_with_dashes = r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}'
|
||||
uuid_pattern_without_dashes = r'[a-f0-9]{32}'
|
||||
|
||||
uuid_match = re.search(
|
||||
uuid_pattern_with_dashes, query_with_params
|
||||
)
|
||||
if not uuid_match:
|
||||
uuid_match = re.search(
|
||||
uuid_pattern_without_dashes, query_with_params
|
||||
)
|
||||
|
||||
if uuid_match:
|
||||
user_id_str = uuid_match.group(0)
|
||||
# If the UUID doesn't have dashes, add them
|
||||
if len(user_id_str) == 32 and '-' not in user_id_str:
|
||||
# Convert from 'a1111111111111111111111111111111' to 'a1111111-1111-1111-1111-111111111111'
|
||||
user_id_str = f'{user_id_str[:8]}-{user_id_str[8:12]}-{user_id_str[12:16]}-{user_id_str[16:20]}-{user_id_str[20:]}'
|
||||
user_id_uuid = UUID(user_id_str)
|
||||
|
||||
# Create a mock user with user_id and org_id the same as user_id_uuid
|
||||
mock_user = MagicMock(spec=User)
|
||||
mock_user.id = user_id_uuid
|
||||
mock_user.current_org_id = user_id_uuid
|
||||
|
||||
# Create a mock result
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalar_one_or_none.return_value = mock_user
|
||||
return mock_result
|
||||
except Exception:
|
||||
# If there's any error in parsing, fall back to original execute
|
||||
pass
|
||||
|
||||
# For all other queries, use the original execute method
|
||||
return await original_execute(query)
|
||||
|
||||
# Apply the mock
|
||||
async_session.execute = mock_execute
|
||||
|
||||
"""Test that different users cannot see each other's conversations."""
|
||||
# Create services for different users
|
||||
user1_service = SaasSQLAppConversationInfoService(
|
||||
db_session=async_session,
|
||||
user_context=SpecifyUserContext(
|
||||
user_id='a1111111-1111-1111-1111-111111111111'
|
||||
),
|
||||
db_session=async_session_with_users,
|
||||
user_context=SpecifyUserContext(user_id=str(USER1_ID)),
|
||||
)
|
||||
user2_service = SaasSQLAppConversationInfoService(
|
||||
db_session=async_session,
|
||||
user_context=SpecifyUserContext(
|
||||
user_id='b2222222-2222-2222-2222-222222222222'
|
||||
),
|
||||
db_session=async_session_with_users,
|
||||
user_context=SpecifyUserContext(user_id=str(USER2_ID)),
|
||||
)
|
||||
|
||||
# Create conversations for different users
|
||||
user1_info = AppConversationInfo(
|
||||
id=uuid4(),
|
||||
created_by_user_id='a1111111-1111-1111-1111-111111111111',
|
||||
created_by_user_id=str(USER1_ID),
|
||||
sandbox_id='sandbox_user1',
|
||||
title='User 1 Conversation',
|
||||
)
|
||||
|
||||
user2_info = AppConversationInfo(
|
||||
id=uuid4(),
|
||||
created_by_user_id='b2222222-2222-2222-2222-222222222222',
|
||||
created_by_user_id=str(USER2_ID),
|
||||
sandbox_id='sandbox_user2',
|
||||
title='User 2 Conversation',
|
||||
)
|
||||
@@ -346,18 +333,12 @@ class TestSaasSQLAppConversationInfoService:
|
||||
# User 1 should only see their conversation
|
||||
user1_page = await user1_service.search_app_conversation_info()
|
||||
assert len(user1_page.items) == 1
|
||||
assert (
|
||||
user1_page.items[0].created_by_user_id
|
||||
== 'a1111111-1111-1111-1111-111111111111'
|
||||
)
|
||||
assert user1_page.items[0].created_by_user_id == str(USER1_ID)
|
||||
|
||||
# User 2 should only see their conversation
|
||||
user2_page = await user2_service.search_app_conversation_info()
|
||||
assert len(user2_page.items) == 1
|
||||
assert (
|
||||
user2_page.items[0].created_by_user_id
|
||||
== 'b2222222-2222-2222-2222-222222222222'
|
||||
)
|
||||
assert user2_page.items[0].created_by_user_id == str(USER2_ID)
|
||||
|
||||
# User 1 should not be able to get user 2's conversation
|
||||
user2_from_user1 = await user1_service.get_app_conversation_info(user2_info.id)
|
||||
@@ -366,3 +347,319 @@ class TestSaasSQLAppConversationInfoService:
|
||||
# User 2 should not be able to get user 1's conversation
|
||||
user1_from_user2 = await user2_service.get_app_conversation_info(user1_info.id)
|
||||
assert user1_from_user2 is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_same_user_org_switching_isolation(
|
||||
self,
|
||||
async_session_with_users: AsyncSession,
|
||||
):
|
||||
"""Test that the same user switching orgs cannot see conversations from other orgs.
|
||||
|
||||
This tests the actual bug scenario: a user creates a conversation in org1,
|
||||
then switches to org2, and should NOT see org1's conversations.
|
||||
"""
|
||||
# Create service for user1 in org1
|
||||
user1_service_org1 = SaasSQLAppConversationInfoService(
|
||||
db_session=async_session_with_users,
|
||||
user_context=SpecifyUserContext(user_id=str(USER1_ID)),
|
||||
)
|
||||
|
||||
# Create a conversation while user is in org1
|
||||
conv_in_org1 = AppConversationInfo(
|
||||
id=uuid4(),
|
||||
created_by_user_id=str(USER1_ID),
|
||||
sandbox_id='sandbox_org1',
|
||||
title='Conversation in Org 1',
|
||||
)
|
||||
await user1_service_org1.save_app_conversation_info(conv_in_org1)
|
||||
|
||||
# Verify user can see the conversation in org1
|
||||
page_in_org1 = await user1_service_org1.search_app_conversation_info()
|
||||
assert len(page_in_org1.items) == 1
|
||||
assert page_in_org1.items[0].title == 'Conversation in Org 1'
|
||||
|
||||
# Simulate user switching to org2 by updating current_org_id using ORM
|
||||
result = await async_session_with_users.execute(
|
||||
select(User).where(User.id == USER1_ID)
|
||||
)
|
||||
user_to_update = result.scalars().first()
|
||||
user_to_update.current_org_id = ORG2_ID
|
||||
await async_session_with_users.commit()
|
||||
# Clear SQLAlchemy's identity map cache to simulate a new request
|
||||
async_session_with_users.expire_all()
|
||||
|
||||
# Create new service instance (simulating a new request after org switch)
|
||||
user1_service_org2 = SaasSQLAppConversationInfoService(
|
||||
db_session=async_session_with_users,
|
||||
user_context=SpecifyUserContext(user_id=str(USER1_ID)),
|
||||
)
|
||||
|
||||
# User should NOT see org1's conversations after switching to org2
|
||||
page_in_org2 = await user1_service_org2.search_app_conversation_info()
|
||||
assert (
|
||||
len(page_in_org2.items) == 0
|
||||
), 'User should not see conversations from org1 after switching to org2'
|
||||
|
||||
# User should not be able to get the specific conversation from org1
|
||||
conv_from_org2 = await user1_service_org2.get_app_conversation_info(
|
||||
conv_in_org1.id
|
||||
)
|
||||
assert (
|
||||
conv_from_org2 is None
|
||||
), 'User should not be able to access org1 conversation from org2'
|
||||
|
||||
# Now create a conversation in org2
|
||||
conv_in_org2 = AppConversationInfo(
|
||||
id=uuid4(),
|
||||
created_by_user_id=str(USER1_ID),
|
||||
sandbox_id='sandbox_org2',
|
||||
title='Conversation in Org 2',
|
||||
)
|
||||
await user1_service_org2.save_app_conversation_info(conv_in_org2)
|
||||
|
||||
# User should only see org2's conversation
|
||||
page_in_org2_after = await user1_service_org2.search_app_conversation_info()
|
||||
assert len(page_in_org2_after.items) == 1
|
||||
assert page_in_org2_after.items[0].title == 'Conversation in Org 2'
|
||||
|
||||
# Switch back to org1 and verify isolation works both ways
|
||||
result = await async_session_with_users.execute(
|
||||
select(User).where(User.id == USER1_ID)
|
||||
)
|
||||
user_to_update = result.scalars().first()
|
||||
user_to_update.current_org_id = ORG1_ID
|
||||
await async_session_with_users.commit()
|
||||
async_session_with_users.expire_all()
|
||||
|
||||
user1_service_back_to_org1 = SaasSQLAppConversationInfoService(
|
||||
db_session=async_session_with_users,
|
||||
user_context=SpecifyUserContext(user_id=str(USER1_ID)),
|
||||
)
|
||||
|
||||
# User should only see org1's conversation now
|
||||
page_back_in_org1 = (
|
||||
await user1_service_back_to_org1.search_app_conversation_info()
|
||||
)
|
||||
assert len(page_back_in_org1.items) == 1
|
||||
assert page_back_in_org1.items[0].title == 'Conversation in Org 1'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_count_respects_org_isolation(
|
||||
self,
|
||||
async_session_with_users: AsyncSession,
|
||||
):
|
||||
"""Test that count_app_conversation_info respects org isolation."""
|
||||
# Create service for user1 in org1
|
||||
user1_service = SaasSQLAppConversationInfoService(
|
||||
db_session=async_session_with_users,
|
||||
user_context=SpecifyUserContext(user_id=str(USER1_ID)),
|
||||
)
|
||||
|
||||
# Create conversations in org1
|
||||
for i in range(3):
|
||||
conv = AppConversationInfo(
|
||||
id=uuid4(),
|
||||
created_by_user_id=str(USER1_ID),
|
||||
sandbox_id=f'sandbox_org1_{i}',
|
||||
title=f'Org1 Conversation {i}',
|
||||
)
|
||||
await user1_service.save_app_conversation_info(conv)
|
||||
|
||||
# Count should be 3
|
||||
count_org1 = await user1_service.count_app_conversation_info()
|
||||
assert count_org1 == 3
|
||||
|
||||
# Switch to org2 using ORM
|
||||
result = await async_session_with_users.execute(
|
||||
select(User).where(User.id == USER1_ID)
|
||||
)
|
||||
user_to_update = result.scalars().first()
|
||||
user_to_update.current_org_id = ORG2_ID
|
||||
await async_session_with_users.commit()
|
||||
async_session_with_users.expire_all()
|
||||
|
||||
user1_service_org2 = SaasSQLAppConversationInfoService(
|
||||
db_session=async_session_with_users,
|
||||
user_context=SpecifyUserContext(user_id=str(USER1_ID)),
|
||||
)
|
||||
|
||||
# Count should be 0 in org2
|
||||
count_org2 = await user1_service_org2.count_app_conversation_info()
|
||||
assert count_org2 == 0
|
||||
|
||||
|
||||
class TestSaasSQLAppConversationInfoServiceAdminContext:
|
||||
"""Test suite for SaasSQLAppConversationInfoService with ADMIN context."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_context_returns_unfiltered_data(
|
||||
self,
|
||||
async_session_with_users: AsyncSession,
|
||||
):
|
||||
"""Test that ADMIN context returns unfiltered data (no user/org filtering)."""
|
||||
# Create conversations for different users
|
||||
user1_service = SaasSQLAppConversationInfoService(
|
||||
db_session=async_session_with_users,
|
||||
user_context=SpecifyUserContext(user_id=str(USER1_ID)),
|
||||
)
|
||||
|
||||
# Create conversations for user1 in org1
|
||||
for i in range(3):
|
||||
conv = AppConversationInfo(
|
||||
id=uuid4(),
|
||||
created_by_user_id=str(USER1_ID),
|
||||
sandbox_id=f'sandbox_user1_{i}',
|
||||
title=f'User1 Conversation {i}',
|
||||
)
|
||||
await user1_service.save_app_conversation_info(conv)
|
||||
|
||||
# Now create an ADMIN service
|
||||
from openhands.app_server.user.specifiy_user_context import ADMIN
|
||||
|
||||
admin_service = SaasSQLAppConversationInfoService(
|
||||
db_session=async_session_with_users,
|
||||
user_context=ADMIN,
|
||||
)
|
||||
|
||||
# ADMIN should see ALL conversations (unfiltered)
|
||||
admin_page = await admin_service.search_app_conversation_info()
|
||||
assert (
|
||||
len(admin_page.items) == 3
|
||||
), 'ADMIN context should see all conversations without filtering'
|
||||
|
||||
# ADMIN count should return total count (3)
|
||||
admin_count = await admin_service.count_app_conversation_info()
|
||||
assert (
|
||||
admin_count == 3
|
||||
), 'ADMIN context should count all conversations without filtering'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_context_can_access_any_conversation(
|
||||
self,
|
||||
async_session_with_users: AsyncSession,
|
||||
):
|
||||
"""Test that ADMIN context can access any conversation regardless of owner."""
|
||||
from openhands.app_server.user.specifiy_user_context import ADMIN
|
||||
|
||||
# Create a conversation as user1
|
||||
user1_service = SaasSQLAppConversationInfoService(
|
||||
db_session=async_session_with_users,
|
||||
user_context=SpecifyUserContext(user_id=str(USER1_ID)),
|
||||
)
|
||||
|
||||
conv = AppConversationInfo(
|
||||
id=uuid4(),
|
||||
created_by_user_id=str(USER1_ID),
|
||||
sandbox_id='sandbox_user1',
|
||||
title='User1 Private Conversation',
|
||||
)
|
||||
await user1_service.save_app_conversation_info(conv)
|
||||
|
||||
# Create a service as user2 in org2 - should not see user1's conversation
|
||||
user2_service = SaasSQLAppConversationInfoService(
|
||||
db_session=async_session_with_users,
|
||||
user_context=SpecifyUserContext(user_id=str(USER2_ID)),
|
||||
)
|
||||
|
||||
user2_page = await user2_service.search_app_conversation_info()
|
||||
assert len(user2_page.items) == 0, 'User2 should not see User1 conversation'
|
||||
|
||||
# But ADMIN should see ALL conversations including user1's
|
||||
admin_service = SaasSQLAppConversationInfoService(
|
||||
db_session=async_session_with_users,
|
||||
user_context=ADMIN,
|
||||
)
|
||||
|
||||
admin_page = await admin_service.search_app_conversation_info()
|
||||
assert len(admin_page.items) == 1
|
||||
assert admin_page.items[0].id == conv.id
|
||||
|
||||
# ADMIN should also be able to get specific conversation by ID
|
||||
admin_get_conv = await admin_service.get_app_conversation_info(conv.id)
|
||||
assert admin_get_conv is not None
|
||||
assert admin_get_conv.id == conv.id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_secure_select_admin_bypasses_filtering(
|
||||
self,
|
||||
async_session_with_users: AsyncSession,
|
||||
):
|
||||
"""Test that _secure_select returns unfiltered query for ADMIN context."""
|
||||
from openhands.app_server.user.specifiy_user_context import ADMIN
|
||||
|
||||
# Create an ADMIN service
|
||||
admin_service = SaasSQLAppConversationInfoService(
|
||||
db_session=async_session_with_users,
|
||||
user_context=ADMIN,
|
||||
)
|
||||
|
||||
# Get the secure select query
|
||||
query = await admin_service._secure_select()
|
||||
|
||||
# Convert query to string to verify NO filters are present
|
||||
query_str = str(query.compile(compile_kwargs={'literal_binds': True}))
|
||||
|
||||
# For ADMIN, there should be no user_id or org_id filtering
|
||||
# The query should not contain filters for user_id or org_id
|
||||
assert str(USER1_ID) not in query_str.replace(
|
||||
'-', ''
|
||||
), 'ADMIN context should not filter by user_id'
|
||||
assert str(USER2_ID) not in query_str.replace(
|
||||
'-', ''
|
||||
), 'ADMIN context should not filter by user_id'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_regular_user_context_filters_correctly(
|
||||
self,
|
||||
async_session_with_users: AsyncSession,
|
||||
):
|
||||
"""Test that regular user context properly filters data (control test)."""
|
||||
from openhands.app_server.user.specifiy_user_context import ADMIN
|
||||
|
||||
# Create conversations for different users
|
||||
user1_service = SaasSQLAppConversationInfoService(
|
||||
db_session=async_session_with_users,
|
||||
user_context=SpecifyUserContext(user_id=str(USER1_ID)),
|
||||
)
|
||||
|
||||
# Create 3 conversations for user1
|
||||
for i in range(3):
|
||||
conv = AppConversationInfo(
|
||||
id=uuid4(),
|
||||
created_by_user_id=str(USER1_ID),
|
||||
sandbox_id=f'sandbox_user1_{i}',
|
||||
title=f'User1 Conversation {i}',
|
||||
)
|
||||
await user1_service.save_app_conversation_info(conv)
|
||||
|
||||
# Create 2 conversations for user2
|
||||
user2_service = SaasSQLAppConversationInfoService(
|
||||
db_session=async_session_with_users,
|
||||
user_context=SpecifyUserContext(user_id=str(USER2_ID)),
|
||||
)
|
||||
|
||||
for i in range(2):
|
||||
conv = AppConversationInfo(
|
||||
id=uuid4(),
|
||||
created_by_user_id=str(USER2_ID),
|
||||
sandbox_id=f'sandbox_user2_{i}',
|
||||
title=f'User2 Conversation {i}',
|
||||
)
|
||||
await user2_service.save_app_conversation_info(conv)
|
||||
|
||||
# User1 should only see their 3 conversations
|
||||
user1_page = await user1_service.search_app_conversation_info()
|
||||
assert len(user1_page.items) == 3
|
||||
|
||||
# User2 should only see their 2 conversations
|
||||
user2_page = await user2_service.search_app_conversation_info()
|
||||
assert len(user2_page.items) == 2
|
||||
|
||||
# But ADMIN should see all 5 conversations
|
||||
admin_service = SaasSQLAppConversationInfoService(
|
||||
db_session=async_session_with_users,
|
||||
user_context=ADMIN,
|
||||
)
|
||||
|
||||
admin_page = await admin_service.search_app_conversation_info()
|
||||
assert len(admin_page.items) == 5
|
||||
|
||||
204
enterprise/tests/unit/storage/test_user_app_settings_store.py
Normal file
204
enterprise/tests/unit/storage/test_user_app_settings_store.py
Normal file
@@ -0,0 +1,204 @@
|
||||
"""
|
||||
Unit tests for UserAppSettingsStore.
|
||||
|
||||
Tests the async database operations for user app settings.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
# Mock the database module before importing
|
||||
with patch('storage.database.engine', create=True), patch(
|
||||
'storage.database.a_engine', create=True
|
||||
):
|
||||
from server.routes.user_app_settings_models import UserAppSettingsUpdate
|
||||
from storage.base import Base
|
||||
from storage.org import Org
|
||||
from storage.user import User
|
||||
from storage.user_app_settings_store import UserAppSettingsStore
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def async_engine():
|
||||
"""Create an async SQLite engine for testing."""
|
||||
engine = create_async_engine(
|
||||
'sqlite+aiosqlite:///:memory:',
|
||||
poolclass=StaticPool,
|
||||
connect_args={'check_same_thread': False},
|
||||
echo=False,
|
||||
)
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
yield engine
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def async_session_maker(async_engine):
|
||||
"""Create an async session maker for testing."""
|
||||
return async_sessionmaker(async_engine, class_=AsyncSession, expire_on_commit=False)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_by_id_success(async_session_maker):
|
||||
"""
|
||||
GIVEN: A user exists in the database
|
||||
WHEN: get_user_by_id is called with the user's ID
|
||||
THEN: The user is returned with correct data
|
||||
"""
|
||||
# Arrange
|
||||
async with async_session_maker() as session:
|
||||
org = Org(name='test-org')
|
||||
session.add(org)
|
||||
await session.flush()
|
||||
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
current_org_id=org.id,
|
||||
language='en',
|
||||
user_consents_to_analytics=True,
|
||||
enable_sound_notifications=False,
|
||||
git_user_name='testuser',
|
||||
git_user_email='test@example.com',
|
||||
)
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
user_id = str(user.id)
|
||||
|
||||
# Act - create store with the session
|
||||
store = UserAppSettingsStore(db_session=session)
|
||||
result = await store.get_user_by_id(user_id)
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert str(result.id) == user_id
|
||||
assert result.language == 'en'
|
||||
assert result.user_consents_to_analytics is True
|
||||
assert result.enable_sound_notifications is False
|
||||
assert result.git_user_name == 'testuser'
|
||||
assert result.git_user_email == 'test@example.com'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_by_id_not_found(async_session_maker):
|
||||
"""
|
||||
GIVEN: A user does not exist in the database
|
||||
WHEN: get_user_by_id is called with a non-existent ID
|
||||
THEN: None is returned
|
||||
"""
|
||||
# Arrange
|
||||
non_existent_id = str(uuid.uuid4())
|
||||
|
||||
# Act
|
||||
async with async_session_maker() as session:
|
||||
store = UserAppSettingsStore(db_session=session)
|
||||
result = await store.get_user_by_id(non_existent_id)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_user_app_settings_success(async_session_maker):
|
||||
"""
|
||||
GIVEN: A user exists in the database
|
||||
WHEN: update_user_app_settings is called with new values
|
||||
THEN: The user's settings are updated and returned
|
||||
"""
|
||||
# Arrange
|
||||
async with async_session_maker() as session:
|
||||
org = Org(name='test-org')
|
||||
session.add(org)
|
||||
await session.flush()
|
||||
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
current_org_id=org.id,
|
||||
language='en',
|
||||
user_consents_to_analytics=False,
|
||||
)
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
user_id = str(user.id)
|
||||
|
||||
update_data = UserAppSettingsUpdate(
|
||||
language='es',
|
||||
user_consents_to_analytics=True,
|
||||
enable_sound_notifications=True,
|
||||
git_user_name='newuser',
|
||||
git_user_email='new@example.com',
|
||||
)
|
||||
|
||||
# Act - create store with the session
|
||||
store = UserAppSettingsStore(db_session=session)
|
||||
result = await store.update_user_app_settings(user_id, update_data)
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result.language == 'es'
|
||||
assert result.user_consents_to_analytics is True
|
||||
assert result.enable_sound_notifications is True
|
||||
assert result.git_user_name == 'newuser'
|
||||
assert result.git_user_email == 'new@example.com'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_user_app_settings_partial(async_session_maker):
|
||||
"""
|
||||
GIVEN: A user exists with existing settings
|
||||
WHEN: update_user_app_settings is called with only some fields
|
||||
THEN: Only the provided fields are updated, others remain unchanged
|
||||
"""
|
||||
# Arrange
|
||||
async with async_session_maker() as session:
|
||||
org = Org(name='test-org')
|
||||
session.add(org)
|
||||
await session.flush()
|
||||
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
current_org_id=org.id,
|
||||
language='en',
|
||||
user_consents_to_analytics=True,
|
||||
git_user_name='original',
|
||||
)
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
user_id = str(user.id)
|
||||
|
||||
# Only update language
|
||||
update_data = UserAppSettingsUpdate(language='fr')
|
||||
|
||||
# Act - create store with the session
|
||||
store = UserAppSettingsStore(db_session=session)
|
||||
result = await store.update_user_app_settings(user_id, update_data)
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result.language == 'fr'
|
||||
assert result.user_consents_to_analytics is True # Unchanged
|
||||
assert result.git_user_name == 'original' # Unchanged
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_user_app_settings_user_not_found(async_session_maker):
|
||||
"""
|
||||
GIVEN: A user does not exist in the database
|
||||
WHEN: update_user_app_settings is called
|
||||
THEN: None is returned
|
||||
"""
|
||||
# Arrange
|
||||
non_existent_id = str(uuid.uuid4())
|
||||
update_data = UserAppSettingsUpdate(language='en')
|
||||
|
||||
# Act
|
||||
async with async_session_maker() as session:
|
||||
store = UserAppSettingsStore(db_session=session)
|
||||
result = await store.update_user_app_settings(non_existent_id, update_data)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
267
enterprise/tests/unit/sync/test_resend_keycloak.py
Normal file
267
enterprise/tests/unit/sync/test_resend_keycloak.py
Normal file
@@ -0,0 +1,267 @@
|
||||
"""Tests for Resend Keycloak sync functionality."""
|
||||
|
||||
import os
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from resend.exceptions import ResendError
|
||||
from tenacity import RetryError
|
||||
|
||||
# Set required environment variables before importing the module
|
||||
# that reads them at import time
|
||||
os.environ['RESEND_API_KEY'] = 'test_api_key'
|
||||
os.environ['RESEND_AUDIENCE_ID'] = 'test_audience_id'
|
||||
os.environ['KEYCLOAK_SERVER_URL'] = 'http://localhost:8080'
|
||||
os.environ['KEYCLOAK_REALM_NAME'] = 'test_realm'
|
||||
os.environ['KEYCLOAK_ADMIN_PASSWORD'] = 'test_password'
|
||||
|
||||
from enterprise.sync.resend_keycloak import ( # noqa: E402
|
||||
add_contact_to_resend,
|
||||
is_valid_email,
|
||||
send_welcome_email,
|
||||
)
|
||||
|
||||
|
||||
class TestIsValidEmail:
|
||||
"""Test cases for is_valid_email function."""
|
||||
|
||||
def test_valid_simple_email(self):
|
||||
"""Test that a simple valid email passes validation."""
|
||||
assert is_valid_email('user@example.com') is True
|
||||
|
||||
def test_valid_email_with_plus(self):
|
||||
"""Test that email with + modifier passes validation."""
|
||||
assert is_valid_email('user+tag@example.com') is True
|
||||
|
||||
def test_valid_email_with_dots(self):
|
||||
"""Test that email with dots in local part passes validation."""
|
||||
assert is_valid_email('first.last@example.com') is True
|
||||
|
||||
def test_valid_email_with_numbers(self):
|
||||
"""Test that email with numbers passes validation."""
|
||||
assert is_valid_email('user123@example.com') is True
|
||||
|
||||
def test_valid_email_with_subdomain(self):
|
||||
"""Test that email with subdomain passes validation."""
|
||||
assert is_valid_email('user@mail.example.com') is True
|
||||
|
||||
def test_valid_email_with_hyphen_domain(self):
|
||||
"""Test that email with hyphen in domain passes validation."""
|
||||
assert is_valid_email('user@example-site.com') is True
|
||||
|
||||
def test_valid_email_with_underscore(self):
|
||||
"""Test that email with underscore passes validation."""
|
||||
assert is_valid_email('user_name@example.com') is True
|
||||
|
||||
def test_valid_email_with_percent(self):
|
||||
"""Test that email with percent sign passes validation."""
|
||||
assert is_valid_email('user%name@example.com') is True
|
||||
|
||||
def test_invalid_email_with_exclamation(self):
|
||||
"""Test that email with exclamation mark fails validation.
|
||||
|
||||
This is the specific case from the bug report:
|
||||
ethanjames3713+!@gmail.com
|
||||
"""
|
||||
assert is_valid_email('ethanjames3713+!@gmail.com') is False
|
||||
|
||||
def test_invalid_email_with_special_chars(self):
|
||||
"""Test that email with other special characters fails validation."""
|
||||
assert is_valid_email('user!name@example.com') is False
|
||||
assert is_valid_email('user#name@example.com') is False
|
||||
assert is_valid_email('user$name@example.com') is False
|
||||
assert is_valid_email('user&name@example.com') is False
|
||||
assert is_valid_email("user'name@example.com") is False
|
||||
assert is_valid_email('user*name@example.com') is False
|
||||
assert is_valid_email('user=name@example.com') is False
|
||||
assert is_valid_email('user^name@example.com') is False
|
||||
assert is_valid_email('user`name@example.com') is False
|
||||
assert is_valid_email('user{name@example.com') is False
|
||||
assert is_valid_email('user|name@example.com') is False
|
||||
assert is_valid_email('user}name@example.com') is False
|
||||
assert is_valid_email('user~name@example.com') is False
|
||||
|
||||
def test_invalid_email_no_at_symbol(self):
|
||||
"""Test that email without @ symbol fails validation."""
|
||||
assert is_valid_email('userexample.com') is False
|
||||
|
||||
def test_invalid_email_no_domain(self):
|
||||
"""Test that email without domain fails validation."""
|
||||
assert is_valid_email('user@') is False
|
||||
|
||||
def test_invalid_email_no_local_part(self):
|
||||
"""Test that email without local part fails validation."""
|
||||
assert is_valid_email('@example.com') is False
|
||||
|
||||
def test_invalid_email_no_tld(self):
|
||||
"""Test that email without TLD fails validation."""
|
||||
assert is_valid_email('user@example') is False
|
||||
|
||||
def test_invalid_email_single_char_tld(self):
|
||||
"""Test that email with single character TLD fails validation."""
|
||||
assert is_valid_email('user@example.c') is False
|
||||
|
||||
def test_invalid_email_empty_string(self):
|
||||
"""Test that empty string fails validation."""
|
||||
assert is_valid_email('') is False
|
||||
|
||||
def test_invalid_email_none(self):
|
||||
"""Test that None fails validation."""
|
||||
assert is_valid_email(None) is False
|
||||
|
||||
def test_invalid_email_whitespace(self):
|
||||
"""Test that email with whitespace fails validation."""
|
||||
assert is_valid_email('user @example.com') is False
|
||||
assert is_valid_email('user@ example.com') is False
|
||||
assert is_valid_email(' user@example.com') is False
|
||||
assert is_valid_email('user@example.com ') is False
|
||||
|
||||
def test_invalid_email_double_at(self):
|
||||
"""Test that email with double @ fails validation."""
|
||||
assert is_valid_email('user@@example.com') is False
|
||||
|
||||
def test_email_double_dot_domain(self):
|
||||
"""Test email with double dot in domain.
|
||||
|
||||
Note: The regex allows this as it's technically valid in some edge cases,
|
||||
and Resend's API may accept it. The main goal is to reject special
|
||||
characters like ! that Resend definitely rejects.
|
||||
"""
|
||||
# This is allowed by our regex - Resend may or may not accept it
|
||||
assert is_valid_email('user@example..com') is True
|
||||
|
||||
def test_case_insensitive_validation(self):
|
||||
"""Test that validation works for uppercase emails."""
|
||||
assert is_valid_email('USER@EXAMPLE.COM') is True
|
||||
assert is_valid_email('User@Example.Com') is True
|
||||
|
||||
|
||||
class TestSendWelcomeEmail:
|
||||
"""Tests for send_welcome_email function."""
|
||||
|
||||
@patch('enterprise.sync.resend_keycloak.resend.Emails.send')
|
||||
def test_send_welcome_email_success(self, mock_send: MagicMock) -> None:
|
||||
"""Test successful welcome email sending."""
|
||||
mock_send.return_value = {'id': 'email_123'}
|
||||
|
||||
result = send_welcome_email(
|
||||
email='test@example.com',
|
||||
first_name='John',
|
||||
last_name='Doe',
|
||||
)
|
||||
|
||||
assert result == {'id': 'email_123'}
|
||||
mock_send.assert_called_once()
|
||||
call_args = mock_send.call_args[0][0]
|
||||
assert call_args['to'] == ['test@example.com']
|
||||
assert call_args['subject'] == 'Welcome to OpenHands Cloud'
|
||||
assert 'Hi John Doe,' in call_args['html']
|
||||
|
||||
@patch('enterprise.sync.resend_keycloak.resend.Emails.send')
|
||||
def test_send_welcome_email_retries_on_rate_limit(
|
||||
self, mock_send: MagicMock
|
||||
) -> None:
|
||||
"""Test that send_welcome_email retries on rate limit errors."""
|
||||
# First two calls raise rate limit error, third succeeds
|
||||
mock_send.side_effect = [
|
||||
ResendError(
|
||||
code=429,
|
||||
message='Too many requests',
|
||||
error_type='rate_limit_exceeded',
|
||||
suggested_action='',
|
||||
),
|
||||
ResendError(
|
||||
code=429,
|
||||
message='Too many requests',
|
||||
error_type='rate_limit_exceeded',
|
||||
suggested_action='',
|
||||
),
|
||||
{'id': 'email_123'},
|
||||
]
|
||||
|
||||
result = send_welcome_email(
|
||||
email='test@example.com',
|
||||
first_name='John',
|
||||
last_name='Doe',
|
||||
)
|
||||
|
||||
assert result == {'id': 'email_123'}
|
||||
assert mock_send.call_count == 3
|
||||
|
||||
@patch('enterprise.sync.resend_keycloak.resend.Emails.send')
|
||||
def test_send_welcome_email_fails_after_max_retries(
|
||||
self, mock_send: MagicMock
|
||||
) -> None:
|
||||
"""Test that send_welcome_email fails after max retries."""
|
||||
# All calls raise rate limit error
|
||||
mock_send.side_effect = ResendError(
|
||||
code=429,
|
||||
message='Too many requests',
|
||||
error_type='rate_limit_exceeded',
|
||||
suggested_action='',
|
||||
)
|
||||
|
||||
# Tenacity wraps the final exception in RetryError
|
||||
with pytest.raises(RetryError):
|
||||
send_welcome_email(
|
||||
email='test@example.com',
|
||||
first_name='John',
|
||||
last_name='Doe',
|
||||
)
|
||||
|
||||
# Default MAX_RETRIES is 3
|
||||
assert mock_send.call_count == 3
|
||||
|
||||
@patch('enterprise.sync.resend_keycloak.resend.Emails.send')
|
||||
def test_send_welcome_email_no_name(self, mock_send: MagicMock) -> None:
|
||||
"""Test welcome email with no name provided."""
|
||||
mock_send.return_value = {'id': 'email_123'}
|
||||
|
||||
result = send_welcome_email(email='test@example.com')
|
||||
|
||||
assert result == {'id': 'email_123'}
|
||||
call_args = mock_send.call_args[0][0]
|
||||
assert 'Hi there,' in call_args['html']
|
||||
|
||||
|
||||
class TestAddContactToResend:
|
||||
"""Tests for add_contact_to_resend function."""
|
||||
|
||||
@patch('enterprise.sync.resend_keycloak.resend.Contacts.create')
|
||||
def test_add_contact_to_resend_success(self, mock_create: MagicMock) -> None:
|
||||
"""Test successful contact addition."""
|
||||
mock_create.return_value = {'id': 'contact_123'}
|
||||
|
||||
result = add_contact_to_resend(
|
||||
audience_id='test_audience',
|
||||
email='test@example.com',
|
||||
first_name='John',
|
||||
last_name='Doe',
|
||||
)
|
||||
|
||||
assert result == {'id': 'contact_123'}
|
||||
mock_create.assert_called_once()
|
||||
|
||||
@patch('enterprise.sync.resend_keycloak.resend.Contacts.create')
|
||||
def test_add_contact_to_resend_retries_on_rate_limit(
|
||||
self, mock_create: MagicMock
|
||||
) -> None:
|
||||
"""Test that add_contact_to_resend retries on rate limit errors."""
|
||||
# First call raises rate limit error, second succeeds
|
||||
mock_create.side_effect = [
|
||||
ResendError(
|
||||
code=429,
|
||||
message='Too many requests',
|
||||
error_type='rate_limit_exceeded',
|
||||
suggested_action='',
|
||||
),
|
||||
{'id': 'contact_123'},
|
||||
]
|
||||
|
||||
result = add_contact_to_resend(
|
||||
audience_id='test_audience',
|
||||
email='test@example.com',
|
||||
)
|
||||
|
||||
assert result == {'id': 'contact_123'}
|
||||
assert mock_create.call_count == 2
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user