mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-04-29 03:00:45 -04:00
Compare commits
144 Commits
fix-async-
...
self-hoste
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
818f1ef6aa | ||
|
|
4d407aefdc | ||
|
|
faacfd48d0 | ||
|
|
64ad0afff3 | ||
|
|
aea51b3b47 | ||
|
|
60f921dbc0 | ||
|
|
bcdd8bb9e2 | ||
|
|
bf9472511c | ||
|
|
ab0ff4c1ac | ||
|
|
4fcf30e76b | ||
|
|
27708b98aa | ||
|
|
f3dd2024b2 | ||
|
|
42c3527e73 | ||
|
|
87a8778210 | ||
|
|
3640e1dadd | ||
|
|
c9cf433fea | ||
|
|
8784681772 | ||
|
|
38e65351a5 | ||
|
|
d8d522bb1e | ||
|
|
d1aed4cfc1 | ||
|
|
36bb4d9e30 | ||
|
|
c336a79727 | ||
|
|
a57db2b5b2 | ||
|
|
929dcc39eb | ||
|
|
872f2b87f2 | ||
|
|
ee86005a3a | ||
|
|
d4aa30580b | ||
|
|
2f0e879129 | ||
|
|
3bc2ef954e | ||
|
|
32ab2a24c6 | ||
|
|
a6e148d1e6 | ||
|
|
3fc977eddd | ||
|
|
89a6890269 | ||
|
|
8927ac2230 | ||
|
|
f3429e33ca | ||
|
|
7cd219792b | ||
|
|
2aabe2ed8c | ||
|
|
731a9a813e | ||
|
|
123e556fed | ||
|
|
6676cae249 | ||
|
|
fede37b496 | ||
|
|
3bcd6f18df | ||
|
|
0da18440c2 | ||
|
|
ac76e10048 | ||
|
|
b98bae8b5f | ||
|
|
516721d1ee | ||
|
|
4d6f66ca28 | ||
|
|
b18568da0b | ||
|
|
83dd3c169c | ||
|
|
35bddb14f1 | ||
|
|
e8425218e2 | ||
|
|
0a879fa781 | ||
|
|
41e142bbab | ||
|
|
b06b9eedac | ||
|
|
a9afafa991 | ||
|
|
663ace4b39 | ||
|
|
2d085a6e0a | ||
|
|
8b7112abe8 | ||
|
|
34547ba947 | ||
|
|
5f958ab60d | ||
|
|
d7656bf1c9 | ||
|
|
2bc107564c | ||
|
|
85eb1e1504 | ||
|
|
cd235cc8c7 | ||
|
|
40f52dfabc | ||
|
|
bab7bf85e8 | ||
|
|
c856537f65 | ||
|
|
736f5b2255 | ||
|
|
c1d9d11772 | ||
|
|
85244499fe | ||
|
|
c55084e223 | ||
|
|
e3bb75deb4 | ||
|
|
1948200762 | ||
|
|
affe0af361 | ||
|
|
f20c956196 | ||
|
|
4a089a3a0d | ||
|
|
aa0b2d0b74 | ||
|
|
bef9b80b9d | ||
|
|
c4a90b1f89 | ||
|
|
0d13c57d9f | ||
|
|
b3422f1275 | ||
|
|
f139a9970b | ||
|
|
54d156122c | ||
|
|
ac072bf686 | ||
|
|
a53812c029 | ||
|
|
1d1c0925b5 | ||
|
|
872f41e3c0 | ||
|
|
d43ff82534 | ||
|
|
8cd8c011b2 | ||
|
|
5c68b10983 | ||
|
|
a97fad1976 | ||
|
|
4c3542a91c | ||
|
|
f460057f58 | ||
|
|
4fa2ad0f47 | ||
|
|
dd8be12809 | ||
|
|
89475095d9 | ||
|
|
05d5f8848a | ||
|
|
ee2885eb0b | ||
|
|
545257f870 | ||
|
|
b23ab33a01 | ||
|
|
a9ede73391 | ||
|
|
634c2439b4 | ||
|
|
a1989a40b3 | ||
|
|
e38f1283ea | ||
|
|
07eb791735 | ||
|
|
c355c4819f | ||
|
|
9d8e4c44cc | ||
|
|
25cc55e558 | ||
|
|
0e825c38d7 | ||
|
|
ce04e70b5b | ||
|
|
7b0589ad40 | ||
|
|
682465a862 | ||
|
|
1bb4c844d4 | ||
|
|
d6c11fe517 | ||
|
|
b088d4857e | ||
|
|
0f05898d55 | ||
|
|
d1f0a01a57 | ||
|
|
f5a9d28999 | ||
|
|
afa0417608 | ||
|
|
e688fba761 | ||
|
|
d1ec5cbdf6 | ||
|
|
f42625f789 | ||
|
|
fe28519677 | ||
|
|
e62ceafa4a | ||
|
|
0b8c69fad2 | ||
|
|
37d9b672a4 | ||
|
|
c8b867a634 | ||
|
|
59834beba7 | ||
|
|
d2eced9cff | ||
|
|
7836136ff8 | ||
|
|
fdb04dfe5d | ||
|
|
3d4cb89441 | ||
|
|
9fb9efd3d2 | ||
|
|
5511c01c2e | ||
|
|
02825fb5bb | ||
|
|
876e773589 | ||
|
|
9e1ae86191 | ||
|
|
df47b7b79d | ||
|
|
7d1c105b55 | ||
|
|
db6a9e8895 | ||
|
|
d76ac44dc3 | ||
|
|
c483c80a3c | ||
|
|
570ab904f6 | ||
|
|
00a74731ae |
137
.github/ISSUE_TEMPLATE/bug_template.yml
vendored
137
.github/ISSUE_TEMPLATE/bug_template.yml
vendored
@@ -5,52 +5,113 @@ labels: ['bug']
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: Thank you for taking the time to fill out this bug report. Please provide as much information as possible
|
||||
to help us understand and address the issue effectively.
|
||||
value: |
|
||||
## Thank you for reporting a bug! 🐛
|
||||
|
||||
**Please fill out all required fields.** Issues missing critical information (version, installation method, reproduction steps, etc.) will be delayed or closed until complete details are provided.
|
||||
|
||||
Clear, detailed reports help us resolve issues faster.
|
||||
|
||||
- type: checkboxes
|
||||
attributes:
|
||||
label: Is there an existing issue for the same bug? (If one exists, thumbs up or comment on the issue instead).
|
||||
description: Please check if an issue already exists for the bug you encountered.
|
||||
label: Is there an existing issue for the same bug?
|
||||
description: Please search existing issues before creating a new one. If found, react or comment to the duplicate issue instead of making a new one.
|
||||
options:
|
||||
- label: I have checked the existing issues.
|
||||
- label: I have searched existing issues and this is not a duplicate.
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: bug-description
|
||||
attributes:
|
||||
label: Describe the bug and reproduction steps
|
||||
description: Provide a description of the issue along with any reproduction steps.
|
||||
label: Bug Description
|
||||
description: Clearly describe what went wrong. Be specific and concise.
|
||||
placeholder: Example - "When I run a Python task, OpenHands crashes after 30 seconds with a connection timeout error."
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: expected-behavior
|
||||
attributes:
|
||||
label: Expected Behavior
|
||||
description: What did you expect to happen?
|
||||
placeholder: Example - "OpenHands should execute the Python script and return results."
|
||||
validations:
|
||||
required: false
|
||||
|
||||
- type: textarea
|
||||
id: actual-behavior
|
||||
attributes:
|
||||
label: Actual Behavior
|
||||
description: What actually happened?
|
||||
placeholder: Example - "Connection timed out after 30 seconds, task failed with error code 500."
|
||||
validations:
|
||||
required: false
|
||||
|
||||
- type: textarea
|
||||
id: reproduction-steps
|
||||
attributes:
|
||||
label: Steps to Reproduce
|
||||
description: Provide clear, step-by-step instructions to reproduce the bug.
|
||||
placeholder: |
|
||||
1. Install OpenHands using Docker
|
||||
2. Configure with Claude 3.5 Sonnet
|
||||
3. Run command: `openhands run "write a python script"`
|
||||
4. Wait 30 seconds
|
||||
5. Error appears
|
||||
validations:
|
||||
required: false
|
||||
|
||||
- type: dropdown
|
||||
id: installation
|
||||
attributes:
|
||||
label: OpenHands Installation
|
||||
label: OpenHands Installation Method
|
||||
description: How are you running OpenHands?
|
||||
options:
|
||||
- Docker command in README
|
||||
- GitHub resolver
|
||||
- CLI (uv tool install)
|
||||
- CLI (executable binary)
|
||||
- CLI (Docker)
|
||||
- Local GUI (Docker web interface)
|
||||
- OpenHands Cloud (app.all-hands.dev)
|
||||
- SDK (Python library)
|
||||
- Development workflow
|
||||
- CLI
|
||||
- app.all-hands.dev
|
||||
- Other
|
||||
default: 0
|
||||
validations:
|
||||
required: false
|
||||
|
||||
- type: input
|
||||
id: installation-other
|
||||
attributes:
|
||||
label: If you selected "Other", please specify
|
||||
description: Describe your installation method
|
||||
placeholder: ex. Custom Kubernetes deployment, pip install from source, etc.
|
||||
|
||||
- type: input
|
||||
id: openhands-version
|
||||
attributes:
|
||||
label: OpenHands Version
|
||||
description: What version of OpenHands are you using?
|
||||
placeholder: ex. 0.9.8, main, etc.
|
||||
description: What version are you using? Find this in settings or by running `openhands --version`
|
||||
placeholder: ex. 0.9.8, main, commit hash, etc.
|
||||
validations:
|
||||
required: false
|
||||
|
||||
- type: checkboxes
|
||||
id: version-confirmation
|
||||
attributes:
|
||||
label: Version Confirmation
|
||||
description: Bugs on older versions may already be fixed. Please upgrade before submitting.
|
||||
options:
|
||||
- label: "I have confirmed this bug exists on the LATEST version of OpenHands"
|
||||
required: false
|
||||
|
||||
- type: input
|
||||
id: model-name
|
||||
attributes:
|
||||
label: Model Name
|
||||
description: What model are you using?
|
||||
placeholder: ex. gpt-4o, claude-3-5-sonnet, openrouter/deepseek-r1, etc.
|
||||
description: Which LLM model are you using?
|
||||
placeholder: ex. gpt-4o, claude-3-5-sonnet-20241022, openrouter/deepseek-r1, etc.
|
||||
validations:
|
||||
required: false
|
||||
|
||||
- type: dropdown
|
||||
id: os
|
||||
@@ -60,12 +121,46 @@ body:
|
||||
- MacOS
|
||||
- Linux
|
||||
- WSL on Windows
|
||||
- Windows (Docker Desktop)
|
||||
- Other
|
||||
validations:
|
||||
required: false
|
||||
|
||||
- type: input
|
||||
id: browser
|
||||
attributes:
|
||||
label: Browser (if using web UI)
|
||||
description: |
|
||||
If applicable, which browser and version?
|
||||
|
||||
placeholder: ex. Chrome 131, Firefox 133, Safari 17.2
|
||||
|
||||
- type: textarea
|
||||
id: logs
|
||||
attributes:
|
||||
label: Logs and Error Messages
|
||||
description: |
|
||||
**Paste relevant logs, error messages, or stack traces.** Use code blocks (```) for formatting.
|
||||
|
||||
LLM logs are in `logs/llm/default/`. Include timestamps if errors occurred at a specific time.
|
||||
placeholder: |
|
||||
```
|
||||
Paste error logs here
|
||||
```
|
||||
|
||||
- type: textarea
|
||||
id: additional-context
|
||||
attributes:
|
||||
label: Logs, Errors, Screenshots, and Additional Context
|
||||
description: Please provide any additional information you think might help. If you want to share the chat history
|
||||
you can click the thumbs-down (👎) button above the input field and you will get a shareable link
|
||||
(you can also click thumbs up when things are going well of course!). LLM logs will be stored in the
|
||||
`logs/llm/default` folder. Please add any additional context about the problem here.
|
||||
label: Screenshots and Additional Context
|
||||
description: |
|
||||
Add screenshots, videos, runtime environment, or other context that helps explain the issue.
|
||||
|
||||
💡 **Share conversation history:** In the OpenHands chat UI, click the 👎 or 👍 button (above the message input) to generate a shareable link to your conversation.
|
||||
|
||||
placeholder: Drag and drop screenshots here, paste links, or add additional context.
|
||||
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
---
|
||||
**Note:** Issues with incomplete information may be closed or deprioritized. Maintainers and community members have limited bandwidth and prioritize well-documented bugs that are easier to reproduce and fix. Thank you for your understanding!
|
||||
|
||||
17
.github/ISSUE_TEMPLATE/feature_request.md
vendored
17
.github/ISSUE_TEMPLATE/feature_request.md
vendored
@@ -1,17 +0,0 @@
|
||||
---
|
||||
name: Feature Request or Enhancement
|
||||
about: Suggest an idea for an OpenHands feature or enhancement
|
||||
title: ''
|
||||
labels: 'enhancement'
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
**What problem or use case are you trying to solve?**
|
||||
|
||||
**Describe the UX or technical implementation you have in mind**
|
||||
|
||||
**Additional context**
|
||||
|
||||
|
||||
### If you find this feature request or enhancement useful, make sure to add a 👍 to the issue
|
||||
105
.github/ISSUE_TEMPLATE/feature_request.yml
vendored
Normal file
105
.github/ISSUE_TEMPLATE/feature_request.yml
vendored
Normal file
@@ -0,0 +1,105 @@
|
||||
name: Feature Request or Enhancement
|
||||
description: Suggest a new feature or improvement for OpenHands
|
||||
title: '[Feature]: '
|
||||
labels: ['enhancement']
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
## Thank you for suggesting a feature! 💡
|
||||
|
||||
**Please provide detailed information.** Vague or low-effort requests may be closed. Well-documented feature requests with strong community support are more likely to be added to the roadmap.
|
||||
|
||||
- type: checkboxes
|
||||
attributes:
|
||||
label: Is there an existing feature request for this?
|
||||
description: Please search existing issues and feature requests before creating a new one. If found, react or comment to the duplicate issue instead of making a new one.
|
||||
options:
|
||||
- label: I have searched existing issues and feature requests, and this is not a duplicate.
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: problem-statement
|
||||
attributes:
|
||||
label: Problem or Use Case
|
||||
description: What problem are you trying to solve? What use case would this feature enable?
|
||||
placeholder: |
|
||||
Example - "As a developer working on large codebases, I need to search across multiple files simultaneously. Currently, I have to search file-by-file which is time-consuming and inefficient."
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: proposed-solution
|
||||
attributes:
|
||||
label: Proposed Solution
|
||||
description: Describe your ideal solution. What should this feature do? How should it work?
|
||||
placeholder: |
|
||||
Example - "Add a global search feature that allows searching across all files in the workspace. Results should show file name, line number, and context around matches. Include regex support and filtering options."
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: alternatives
|
||||
attributes:
|
||||
label: Alternatives Considered
|
||||
description: Have you considered any alternative solutions or workarounds? What are their limitations?
|
||||
placeholder: Example - "I tried using grep in the terminal, but it's not integrated with the UI and doesn't provide click-to-navigate functionality."
|
||||
|
||||
- type: dropdown
|
||||
id: priority
|
||||
attributes:
|
||||
label: Priority / Severity
|
||||
description: How important is this feature to your workflow?
|
||||
options:
|
||||
- "Critical - Blocking my work, no workaround available"
|
||||
- "High - Significant impact on productivity"
|
||||
- "Medium - Would improve experience"
|
||||
- "Low - Nice to have"
|
||||
default: 2
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: dropdown
|
||||
id: scope
|
||||
attributes:
|
||||
label: Estimated Scope
|
||||
description: To the best of your knowledge, how complex do you think this feature would be to implement?
|
||||
options:
|
||||
- "Small - UI tweak, config option, or minor change"
|
||||
- "Medium - New feature with moderate complexity"
|
||||
- "Large - Significant feature requiring architecture changes"
|
||||
- "Unknown - Not sure about the technical complexity"
|
||||
default: 3
|
||||
|
||||
- type: dropdown
|
||||
id: feature-area
|
||||
attributes:
|
||||
label: Feature Area
|
||||
description: Which part of OpenHands does this feature relate to? If you select "Other", please specify the area in the Additional Context section below.
|
||||
options:
|
||||
- "Agent / AI behavior"
|
||||
- "User Interface / UX"
|
||||
- "CLI / Command-line interface"
|
||||
- "File system / Workspace management"
|
||||
- "Configuration / Settings"
|
||||
- "Integrations (GitHub, GitLab, etc.)"
|
||||
- "Performance / Optimization"
|
||||
- "Documentation"
|
||||
- "Other"
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: technical-details
|
||||
attributes:
|
||||
label: Technical Implementation Ideas (Optional)
|
||||
description: If you have technical expertise, share implementation ideas, API suggestions, or relevant technical details.
|
||||
placeholder: |
|
||||
Example - "Could use ripgrep library for fast search. Expose results via /api/search endpoint. Frontend can use virtualized list for rendering large result sets."
|
||||
|
||||
- type: textarea
|
||||
id: additional-context
|
||||
attributes:
|
||||
label: Additional Context
|
||||
description: Add any other context, screenshots, mockups, or examples that help illustrate this feature request.
|
||||
placeholder: Drag and drop screenshots, mockups, or links here.
|
||||
2
.github/pull_request_template.md
vendored
2
.github/pull_request_template.md
vendored
@@ -1,4 +1,4 @@
|
||||
<!-- Ideally you should open a PR when it is ready for review. Draft PRs will not be reviewed -->
|
||||
<!-- If you are still working on the PR, please mark it as draft. Maintainers will review PRs marked ready for review, which leads to lost time if your PR is actually not ready yet. Keep the PR marked as draft until it is finally ready for review -->
|
||||
|
||||
## Summary of PR
|
||||
|
||||
|
||||
29
.github/workflows/enterprise-preview.yml
vendored
29
.github/workflows/enterprise-preview.yml
vendored
@@ -1,29 +0,0 @@
|
||||
# Feature branch preview for enterprise code
|
||||
name: Enterprise Preview
|
||||
|
||||
# Run on PRs labeled
|
||||
on:
|
||||
pull_request:
|
||||
types: [labeled]
|
||||
|
||||
# Match ghcr-build.yml, but don't interrupt it.
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ (github.head_ref && github.ref) || github.run_id }}
|
||||
cancel-in-progress: false
|
||||
|
||||
jobs:
|
||||
# This must happen for the PR Docker workflow when the label is present,
|
||||
# and also if it's added after the fact. Thus, it exists in both places.
|
||||
enterprise-preview:
|
||||
name: Enterprise preview
|
||||
if: github.event.label.name == 'deploy'
|
||||
runs-on: blacksmith-4vcpu-ubuntu-2204
|
||||
steps:
|
||||
# This should match the version in ghcr-build.yml
|
||||
- name: Trigger remote job
|
||||
run: |
|
||||
curl --fail-with-body -sS -X POST \
|
||||
-H "Authorization: Bearer ${{ secrets.ALLHANDS_BOT_GITHUB_PAT }}" \
|
||||
-H "Accept: application/vnd.github+json" \
|
||||
-d "{\"ref\": \"main\", \"inputs\": {\"openhandsPrNumber\": \"${{ github.event.pull_request.number }}\", \"deployEnvironment\": \"feature\", \"enterpriseImageTag\": \"pr-${{ github.event.pull_request.number }}\" }}" \
|
||||
https://api.github.com/repos/OpenHands/deploy/actions/workflows/deploy.yaml/dispatches
|
||||
16
.github/workflows/ghcr-build.yml
vendored
16
.github/workflows/ghcr-build.yml
vendored
@@ -9,6 +9,7 @@ on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
- "saas-rel-*"
|
||||
tags:
|
||||
- "*"
|
||||
pull_request:
|
||||
@@ -239,21 +240,6 @@ jobs:
|
||||
# Add build attestations for better security
|
||||
sbom: true
|
||||
|
||||
enterprise-preview:
|
||||
name: Enterprise preview
|
||||
if: github.event_name == 'pull_request' && contains(github.event.pull_request.labels.*.name, 'deploy')
|
||||
runs-on: blacksmith-4vcpu-ubuntu-2204
|
||||
needs: [ghcr_build_enterprise]
|
||||
steps:
|
||||
# This should match the version in enterprise-preview.yml
|
||||
- name: Trigger remote job
|
||||
run: |
|
||||
curl --fail-with-body -sS -X POST \
|
||||
-H "Authorization: Bearer ${{ secrets.ALLHANDS_BOT_GITHUB_PAT }}" \
|
||||
-H "Accept: application/vnd.github+json" \
|
||||
-d "{\"ref\": \"main\", \"inputs\": {\"openhandsPrNumber\": \"${{ github.event.pull_request.number }}\", \"deployEnvironment\": \"feature\", \"enterpriseImageTag\": \"pr-${{ github.event.pull_request.number }}\" }}" \
|
||||
https://api.github.com/repos/OpenHands/deploy/actions/workflows/deploy.yaml/dispatches
|
||||
|
||||
# "All Runtime Tests Passed" is a required job for PRs to merge
|
||||
# We can remove this once the config changes
|
||||
runtime_tests_check_success:
|
||||
|
||||
48
.github/workflows/pr-review-by-openhands.yml
vendored
Normal file
48
.github/workflows/pr-review-by-openhands.yml
vendored
Normal file
@@ -0,0 +1,48 @@
|
||||
---
|
||||
name: PR Review by OpenHands
|
||||
|
||||
on:
|
||||
# TEMPORARY MITIGATION (Clinejection hardening)
|
||||
#
|
||||
# We temporarily avoid `pull_request_target` here. We'll restore it after the PR review
|
||||
# workflow is fully hardened for untrusted execution.
|
||||
pull_request:
|
||||
types: [opened, ready_for_review, labeled, review_requested]
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: write
|
||||
issues: write
|
||||
|
||||
jobs:
|
||||
pr-review:
|
||||
# Note: fork PRs will not have access to repository secrets under `pull_request`.
|
||||
# Skip forks to avoid noisy failures until we restore a hardened `pull_request_target` flow.
|
||||
if: |
|
||||
github.event.pull_request.head.repo.full_name == github.repository &&
|
||||
(
|
||||
(github.event.action == 'opened' && github.event.pull_request.draft == false) ||
|
||||
github.event.action == 'ready_for_review' ||
|
||||
(github.event.action == 'labeled' && github.event.label.name == 'review-this') ||
|
||||
(
|
||||
github.event.action == 'review_requested' &&
|
||||
(
|
||||
github.event.requested_reviewer.login == 'openhands-agent' ||
|
||||
github.event.requested_reviewer.login == 'all-hands-bot'
|
||||
)
|
||||
)
|
||||
)
|
||||
concurrency:
|
||||
group: pr-review-${{ github.event.pull_request.number }}
|
||||
cancel-in-progress: true
|
||||
runs-on: ubuntu-24.04
|
||||
steps:
|
||||
- name: Run PR Review
|
||||
uses: OpenHands/extensions/plugins/pr-review@main
|
||||
with:
|
||||
llm-model: litellm_proxy/claude-sonnet-4-5-20250929
|
||||
llm-base-url: https://llm-proxy.app.all-hands.dev
|
||||
review-style: roasted
|
||||
llm-api-key: ${{ secrets.LLM_API_KEY }}
|
||||
github-token: ${{ secrets.ALLHANDS_BOT_GITHUB_PAT }}
|
||||
lmnr-api-key: ${{ secrets.LMNR_SKILLS_API_KEY }}
|
||||
85
.github/workflows/pr-review-evaluation.yml
vendored
Normal file
85
.github/workflows/pr-review-evaluation.yml
vendored
Normal file
@@ -0,0 +1,85 @@
|
||||
---
|
||||
name: PR Review Evaluation
|
||||
|
||||
# This workflow evaluates how well PR review comments were addressed.
|
||||
# It runs when a PR is closed to assess review effectiveness.
|
||||
#
|
||||
# Security note: pull_request_target is safe here because:
|
||||
# 1. Only triggers on PR close (not on code changes)
|
||||
# 2. Does not checkout PR code - only downloads artifacts from trusted workflow runs
|
||||
# 3. Runs evaluation scripts from the extensions repo, not from the PR
|
||||
|
||||
on:
|
||||
pull_request_target:
|
||||
types: [closed]
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: read
|
||||
|
||||
jobs:
|
||||
evaluate:
|
||||
runs-on: ubuntu-24.04
|
||||
env:
|
||||
PR_NUMBER: ${{ github.event.pull_request.number }}
|
||||
REPO_NAME: ${{ github.repository }}
|
||||
PR_MERGED: ${{ github.event.pull_request.merged }}
|
||||
|
||||
steps:
|
||||
- name: Download review trace artifact
|
||||
id: download-trace
|
||||
uses: dawidd6/action-download-artifact@v6
|
||||
continue-on-error: true
|
||||
with:
|
||||
workflow: pr-review-by-openhands.yml
|
||||
name: pr-review-trace-${{ github.event.pull_request.number }}
|
||||
path: trace-info
|
||||
search_artifacts: true
|
||||
if_no_artifact_found: warn
|
||||
|
||||
- name: Check if trace file exists
|
||||
id: check-trace
|
||||
run: |
|
||||
if [ -f "trace-info/laminar_trace_info.json" ]; then
|
||||
echo "trace_exists=true" >> $GITHUB_OUTPUT
|
||||
echo "Found trace file for PR #$PR_NUMBER"
|
||||
else
|
||||
echo "trace_exists=false" >> $GITHUB_OUTPUT
|
||||
echo "No trace file found for PR #$PR_NUMBER - skipping evaluation"
|
||||
fi
|
||||
|
||||
# Always checkout main branch for security - cannot test script changes in PRs
|
||||
- name: Checkout extensions repository
|
||||
if: steps.check-trace.outputs.trace_exists == 'true'
|
||||
uses: actions/checkout@v5
|
||||
with:
|
||||
repository: OpenHands/extensions
|
||||
path: extensions
|
||||
|
||||
- name: Set up Python
|
||||
if: steps.check-trace.outputs.trace_exists == 'true'
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: '3.12'
|
||||
|
||||
- name: Install dependencies
|
||||
if: steps.check-trace.outputs.trace_exists == 'true'
|
||||
run: pip install lmnr
|
||||
|
||||
- name: Run evaluation
|
||||
if: steps.check-trace.outputs.trace_exists == 'true'
|
||||
env:
|
||||
# Script expects LMNR_PROJECT_API_KEY; org secret is named LMNR_SKILLS_API_KEY
|
||||
LMNR_PROJECT_API_KEY: ${{ secrets.LMNR_SKILLS_API_KEY }}
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: |
|
||||
python extensions/plugins/pr-review/scripts/evaluate_review.py \
|
||||
--trace-file trace-info/laminar_trace_info.json
|
||||
|
||||
- name: Upload evaluation logs
|
||||
uses: actions/upload-artifact@v5
|
||||
if: always() && steps.check-trace.outputs.trace_exists == 'true'
|
||||
with:
|
||||
name: pr-review-evaluation-${{ github.event.pull_request.number }}
|
||||
path: '*.log'
|
||||
retention-days: 30
|
||||
@@ -6,11 +6,12 @@ Thanks for your interest in contributing to OpenHands! We welcome and appreciate
|
||||
|
||||
To understand the codebase, please refer to the README in each module:
|
||||
- [frontend](./frontend/README.md)
|
||||
- [evaluation](./evaluation/README.md)
|
||||
- [openhands](./openhands/README.md)
|
||||
- [agenthub](./openhands/agenthub/README.md)
|
||||
- [server](./openhands/server/README.md)
|
||||
|
||||
For benchmarks and evaluation, see the [OpenHands/benchmarks](https://github.com/OpenHands/benchmarks) repository.
|
||||
|
||||
## Setting up Your Development Environment
|
||||
|
||||
We have a separate doc [Development.md](https://github.com/OpenHands/OpenHands/blob/main/Development.md) that tells
|
||||
|
||||
@@ -200,7 +200,7 @@ Here's a guide to the important documentation files in the repository:
|
||||
- [/frontend/README.md](./frontend/README.md): Frontend React application setup and development guide
|
||||
- [/containers/README.md](./containers/README.md): Information about Docker containers and deployment
|
||||
- [/tests/unit/README.md](./tests/unit/README.md): Guide to writing and running unit tests
|
||||
- [/evaluation/README.md](./evaluation/README.md): Documentation for the evaluation framework and benchmarks
|
||||
- [OpenHands/benchmarks](https://github.com/OpenHands/benchmarks): Documentation for the evaluation framework and benchmarks
|
||||
- [/skills/README.md](./skills/README.md): Information about the skills architecture and implementation
|
||||
- [/openhands/server/README.md](./openhands/server/README.md): Server implementation details and API documentation
|
||||
- [/openhands/runtime/README.md](./openhands/runtime/README.md): Documentation for the runtime environment and execution model
|
||||
|
||||
@@ -54,7 +54,7 @@ The experience will be familiar to anyone who has used Devin or Jules.
|
||||
### OpenHands Cloud
|
||||
This is a deployment of OpenHands GUI, running on hosted infrastructure.
|
||||
|
||||
You can try it with a free $10 credit by [signing in with your GitHub or GitLab account](https://app.all-hands.dev).
|
||||
You can try it for free using the Minimax model by [signing in with your GitHub or GitLab account](https://app.all-hands.dev).
|
||||
|
||||
OpenHands Cloud comes with source-available features and integrations:
|
||||
- Integrations with Slack, Jira, and Linear
|
||||
|
||||
@@ -440,12 +440,6 @@ type = "noop"
|
||||
#temperature = 0.1
|
||||
#max_input_tokens = 1024
|
||||
|
||||
#################################### Eval ####################################
|
||||
# Configuration for the evaluation, please refer to the specific evaluation
|
||||
# plugin for the available options
|
||||
##############################################################################
|
||||
|
||||
|
||||
########################### Kubernetes #######################################
|
||||
# Kubernetes configuration when using the Kubernetes runtime
|
||||
##############################################################################
|
||||
|
||||
@@ -23,12 +23,23 @@ RUN apt-get update && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Install Python packages with security fixes
|
||||
RUN /app/.venv/bin/pip install alembic psycopg2-binary cloud-sql-python-connector pg8000 gspread stripe python-keycloak asyncpg sqlalchemy[asyncio] resend tenacity slack-sdk ddtrace "posthog>=6.0.0" "limits==5.2.0" coredis prometheus-client shap scikit-learn pandas numpy google-cloud-recaptcha-enterprise && \
|
||||
# Update packages with known CVE fixes
|
||||
/app/.venv/bin/pip install --upgrade \
|
||||
"mcp>=1.10.0" \
|
||||
"pillow>=11.3.0"
|
||||
# Install poetry and export before importing current code.
|
||||
RUN /app/.venv/bin/pip install poetry poetry-plugin-export
|
||||
|
||||
# Install Python dependencies from poetry.lock for reproducible builds
|
||||
# Copy lock files first for better Docker layer caching
|
||||
COPY --chown=openhands:openhands enterprise/pyproject.toml enterprise/poetry.lock /tmp/enterprise/
|
||||
RUN cd /tmp/enterprise && \
|
||||
# Export only main dependencies with hashes for supply chain security
|
||||
/app/.venv/bin/poetry export --only main -o requirements.txt && \
|
||||
# Remove the local path dependency (openhands-ai is already in base image)
|
||||
sed -i '/^-e /d; /openhands-ai/d' requirements.txt && \
|
||||
# Install pinned dependencies from lock file
|
||||
/app/.venv/bin/pip install -r requirements.txt && \
|
||||
# Cleanup - return to /app before removing /tmp/enterprise
|
||||
cd /app && \
|
||||
rm -rf /tmp/enterprise && \
|
||||
/app/.venv/bin/pip uninstall -y poetry poetry-plugin-export
|
||||
|
||||
WORKDIR /app
|
||||
COPY --chown=openhands:openhands --chmod=770 enterprise .
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
This script can be removed once orgs is established - probably after Feb 15 2026
|
||||
|
||||
Downgrade script for migrated users.
|
||||
|
||||
This script identifies users who have been migrated (already_migrated=True)
|
||||
|
||||
@@ -28,9 +28,11 @@ class SaaSExperimentManager(ExperimentManager):
|
||||
return agent
|
||||
|
||||
if EXPERIMENT_SYSTEM_PROMPT_EXPERIMENT:
|
||||
agent = agent.model_copy(
|
||||
update={'system_prompt_filename': 'system_prompt_long_horizon.j2'}
|
||||
)
|
||||
# Skip experiment for planning agents which require their specialized prompt
|
||||
if agent.system_prompt_filename != 'system_prompt_planning.j2':
|
||||
agent = agent.model_copy(
|
||||
update={'system_prompt_filename': 'system_prompt_long_horizon.j2'}
|
||||
)
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
@@ -145,11 +145,7 @@ class GithubManager(Manager):
|
||||
).get('body', ''):
|
||||
return False
|
||||
|
||||
if GithubFactory.is_eligible_for_conversation_starter(
|
||||
message
|
||||
) and self._user_has_write_access_to_repo(installation_id, repo_name, username):
|
||||
await GithubFactory.trigger_conversation_starter(message)
|
||||
|
||||
# Check event types before making expensive API calls (e.g., _user_has_write_access_to_repo)
|
||||
if not (
|
||||
GithubFactory.is_labeled_issue(message)
|
||||
or GithubFactory.is_issue_comment(message)
|
||||
@@ -159,8 +155,17 @@ class GithubManager(Manager):
|
||||
return False
|
||||
|
||||
logger.info(f'[GitHub] Checking permissions for {username} in {repo_name}')
|
||||
user_has_write_access = self._user_has_write_access_to_repo(
|
||||
installation_id, repo_name, username
|
||||
)
|
||||
|
||||
return self._user_has_write_access_to_repo(installation_id, repo_name, username)
|
||||
if (
|
||||
GithubFactory.is_eligible_for_conversation_starter(message)
|
||||
and user_has_write_access
|
||||
):
|
||||
await GithubFactory.trigger_conversation_starter(message)
|
||||
|
||||
return user_has_write_access
|
||||
|
||||
async def receive_message(self, message: Message):
|
||||
self._confirm_incoming_source_type(message)
|
||||
|
||||
@@ -167,17 +167,15 @@ async def install_webhook_on_resource(
|
||||
scopes=SCOPES,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
'Creating new webhook',
|
||||
extra={
|
||||
'webhook_id': webhook_id,
|
||||
'status': status,
|
||||
'resource_id': resource_id,
|
||||
'resource_type': resource_type,
|
||||
},
|
||||
)
|
||||
log_extra = {
|
||||
'webhook_id': webhook_id,
|
||||
'status': status,
|
||||
'resource_id': resource_id,
|
||||
'resource_type': resource_type,
|
||||
}
|
||||
|
||||
if status == WebhookStatus.RATE_LIMITED:
|
||||
logger.warning('Rate limited while creating webhook', extra=log_extra)
|
||||
raise BreakLoopException()
|
||||
|
||||
if webhook_id:
|
||||
@@ -191,9 +189,8 @@ async def install_webhook_on_resource(
|
||||
'webhook_uuid': webhook_uuid, # required to identify which webhook installation is sending payload
|
||||
},
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f'Installed webhook for {webhook.user_id} on {resource_type}:{resource_id}'
|
||||
)
|
||||
logger.info('Created new webhook', extra=log_extra)
|
||||
else:
|
||||
logger.error('Failed to create webhook', extra=log_extra)
|
||||
|
||||
return webhook_id, status
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from openhands.app_server.user.user_context import UserContext
|
||||
from openhands.app_server.user.user_models import UserInfo
|
||||
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE
|
||||
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE, ProviderHandler
|
||||
from openhands.integrations.service_types import ProviderType
|
||||
from openhands.sdk.secret import SecretSource, StaticSecret
|
||||
from openhands.server.user_auth.user_auth import UserAuth
|
||||
@@ -14,6 +14,7 @@ class ResolverUserContext(UserContext):
|
||||
saas_user_auth: UserAuth,
|
||||
):
|
||||
self.saas_user_auth = saas_user_auth
|
||||
self._provider_handler: ProviderHandler | None = None
|
||||
|
||||
async def get_user_id(self) -> str | None:
|
||||
return await self.saas_user_auth.get_user_id()
|
||||
@@ -29,12 +30,26 @@ class ResolverUserContext(UserContext):
|
||||
|
||||
return UserInfo(id=user_id)
|
||||
|
||||
async def _get_provider_handler(self) -> ProviderHandler:
|
||||
"""Get or create a ProviderHandler for git operations."""
|
||||
if self._provider_handler is None:
|
||||
provider_tokens = await self.saas_user_auth.get_provider_tokens()
|
||||
if provider_tokens is None:
|
||||
raise ValueError('No provider tokens available')
|
||||
user_id = await self.saas_user_auth.get_user_id()
|
||||
self._provider_handler = ProviderHandler(
|
||||
provider_tokens=provider_tokens, external_auth_id=user_id
|
||||
)
|
||||
return self._provider_handler
|
||||
|
||||
async def get_authenticated_git_url(
|
||||
self, repository: str, is_optional: bool = False
|
||||
) -> str:
|
||||
# This would need to be implemented based on the git provider tokens
|
||||
# For now, return a basic HTTPS URL
|
||||
return f'https://github.com/{repository}.git'
|
||||
provider_handler = await self._get_provider_handler()
|
||||
url = await provider_handler.get_authenticated_git_url(
|
||||
repository, is_optional=is_optional
|
||||
)
|
||||
return url
|
||||
|
||||
async def get_latest_token(self, provider_type: ProviderType) -> str | None:
|
||||
# Return the appropriate token string from git_provider_tokens
|
||||
|
||||
@@ -1,10 +1,15 @@
|
||||
import logging
|
||||
import os
|
||||
from logging.config import fileConfig
|
||||
|
||||
from alembic import context
|
||||
from google.cloud.sql.connector import Connector
|
||||
from sqlalchemy import create_engine
|
||||
from storage.base import Base
|
||||
# Suppress alembic.runtime.plugins INFO logs during import to prevent non-JSON logs in production
|
||||
# These plugin setup messages would otherwise appear before logging is configured
|
||||
logging.getLogger('alembic.runtime.plugins').setLevel(logging.WARNING)
|
||||
|
||||
from alembic import context # noqa: E402
|
||||
from google.cloud.sql.connector import Connector # noqa: E402
|
||||
from sqlalchemy import create_engine # noqa: E402
|
||||
from storage.base import Base # noqa: E402
|
||||
|
||||
target_metadata = Base.metadata
|
||||
|
||||
|
||||
@@ -0,0 +1,46 @@
|
||||
"""Add byor_export_enabled flag to org table.
|
||||
|
||||
Revision ID: 091
|
||||
Revises: 090
|
||||
Create Date: 2025-01-15 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '091'
|
||||
down_revision: Union[str, None] = '090'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add byor_export_enabled column to org table with default false
|
||||
op.add_column(
|
||||
'org',
|
||||
sa.Column(
|
||||
'byor_export_enabled',
|
||||
sa.Boolean,
|
||||
nullable=False,
|
||||
server_default=sa.text('false'),
|
||||
),
|
||||
)
|
||||
|
||||
# Set byor_export_enabled to true for orgs that have completed billing sessions
|
||||
op.execute(
|
||||
sa.text("""
|
||||
UPDATE org SET byor_export_enabled = TRUE
|
||||
WHERE id IN (
|
||||
SELECT DISTINCT org_id FROM billing_sessions
|
||||
WHERE status = 'completed' AND org_id IS NOT NULL
|
||||
)
|
||||
""")
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column('org', 'byor_export_enabled')
|
||||
@@ -0,0 +1,29 @@
|
||||
"""Rename 'user' role to 'member' in role table.
|
||||
|
||||
Revision ID: 092
|
||||
Revises: 091
|
||||
Create Date: 2025-02-12 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '092'
|
||||
down_revision: Union[str, None] = '091'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Rename 'user' role to 'member' for clarity
|
||||
# This avoids confusion between the 'user' role and the 'user' entity/account
|
||||
op.execute(sa.text("UPDATE role SET name = 'member' WHERE name = 'user'"))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Revert 'member' role back to 'user'
|
||||
op.execute(sa.text("UPDATE role SET name = 'user' WHERE name = 'member'"))
|
||||
@@ -0,0 +1,37 @@
|
||||
"""Add pending_free_credits flag to org table.
|
||||
|
||||
Revision ID: 093
|
||||
Revises: 092
|
||||
Create Date: 2025-02-17 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '093'
|
||||
down_revision: Union[str, None] = '092'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add pending_free_credits column to org table with default false.
|
||||
# New orgs will have this set to TRUE at creation time.
|
||||
# Existing orgs default to FALSE (not eligible - they already got $10 at signup).
|
||||
op.add_column(
|
||||
'org',
|
||||
sa.Column(
|
||||
'pending_free_credits',
|
||||
sa.Boolean,
|
||||
nullable=False,
|
||||
server_default=sa.text('false'),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column('org', 'pending_free_credits')
|
||||
@@ -0,0 +1,110 @@
|
||||
"""create org_invitation table
|
||||
|
||||
Revision ID: 094
|
||||
Revises: 093
|
||||
Create Date: 2026-02-18 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '094'
|
||||
down_revision: Union[str, None] = '093'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Create org_invitation table
|
||||
op.create_table(
|
||||
'org_invitation',
|
||||
sa.Column('id', sa.Integer, sa.Identity(), primary_key=True),
|
||||
sa.Column('token', sa.String(64), nullable=False),
|
||||
sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('email', sa.String(255), nullable=False),
|
||||
sa.Column('role_id', sa.Integer, nullable=False),
|
||||
sa.Column('inviter_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column(
|
||||
'status',
|
||||
sa.String(20),
|
||||
nullable=False,
|
||||
server_default=sa.text("'pending'"),
|
||||
),
|
||||
sa.Column(
|
||||
'created_at',
|
||||
sa.DateTime,
|
||||
nullable=False,
|
||||
server_default=sa.text('CURRENT_TIMESTAMP'),
|
||||
),
|
||||
sa.Column('expires_at', sa.DateTime, nullable=False),
|
||||
sa.Column('accepted_at', sa.DateTime, nullable=True),
|
||||
sa.Column('accepted_by_user_id', postgresql.UUID(as_uuid=True), nullable=True),
|
||||
# Foreign key constraints
|
||||
sa.ForeignKeyConstraint(
|
||||
['org_id'],
|
||||
['org.id'],
|
||||
name='org_invitation_org_fkey',
|
||||
ondelete='CASCADE',
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
['role_id'],
|
||||
['role.id'],
|
||||
name='org_invitation_role_fkey',
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
['inviter_id'],
|
||||
['user.id'],
|
||||
name='org_invitation_inviter_fkey',
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
['accepted_by_user_id'],
|
||||
['user.id'],
|
||||
name='org_invitation_accepter_fkey',
|
||||
),
|
||||
)
|
||||
|
||||
# Create indexes
|
||||
op.create_index(
|
||||
'ix_org_invitation_token',
|
||||
'org_invitation',
|
||||
['token'],
|
||||
unique=True,
|
||||
)
|
||||
op.create_index(
|
||||
'ix_org_invitation_org_id',
|
||||
'org_invitation',
|
||||
['org_id'],
|
||||
)
|
||||
op.create_index(
|
||||
'ix_org_invitation_email',
|
||||
'org_invitation',
|
||||
['email'],
|
||||
)
|
||||
op.create_index(
|
||||
'ix_org_invitation_status',
|
||||
'org_invitation',
|
||||
['status'],
|
||||
)
|
||||
# Composite index for checking pending invitations
|
||||
op.create_index(
|
||||
'ix_org_invitation_org_email_status',
|
||||
'org_invitation',
|
||||
['org_id', 'email', 'status'],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop indexes
|
||||
op.drop_index('ix_org_invitation_org_email_status', table_name='org_invitation')
|
||||
op.drop_index('ix_org_invitation_status', table_name='org_invitation')
|
||||
op.drop_index('ix_org_invitation_email', table_name='org_invitation')
|
||||
op.drop_index('ix_org_invitation_org_id', table_name='org_invitation')
|
||||
op.drop_index('ix_org_invitation_token', table_name='org_invitation')
|
||||
|
||||
# Drop table
|
||||
op.drop_table('org_invitation')
|
||||
@@ -0,0 +1,37 @@
|
||||
"""Drop pending_free_credits column from org table.
|
||||
|
||||
Revision ID: 095
|
||||
Revises: 094
|
||||
Create Date: 2025-02-18 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '095'
|
||||
down_revision: Union[str, None] = '094'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Drop the pending_free_credits column from org table.
|
||||
# This column was used for tracking free credit eligibility but is no longer needed.
|
||||
op.drop_column('org', 'pending_free_credits')
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Re-add pending_free_credits column with default false.
|
||||
op.add_column(
|
||||
'org',
|
||||
sa.Column(
|
||||
'pending_free_credits',
|
||||
sa.Boolean,
|
||||
nullable=False,
|
||||
server_default=sa.text('false'),
|
||||
),
|
||||
)
|
||||
@@ -0,0 +1,67 @@
|
||||
"""Create resend_synced_users table.
|
||||
|
||||
Revision ID: 096
|
||||
Revises: 095
|
||||
Create Date: 2025-02-17 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '096'
|
||||
down_revision: Union[str, None] = '095'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Create resend_synced_users table for tracking users synced to Resend audiences."""
|
||||
op.create_table(
|
||||
'resend_synced_users',
|
||||
sa.Column(
|
||||
'id',
|
||||
sa.UUID(as_uuid=True),
|
||||
nullable=False,
|
||||
primary_key=True,
|
||||
),
|
||||
sa.Column('email', sa.String(), nullable=False),
|
||||
sa.Column('audience_id', sa.String(), nullable=False),
|
||||
sa.Column(
|
||||
'synced_at',
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.text('CURRENT_TIMESTAMP'),
|
||||
),
|
||||
sa.Column('keycloak_user_id', sa.String(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint(
|
||||
'email', 'audience_id', name='uq_resend_synced_email_audience'
|
||||
),
|
||||
)
|
||||
|
||||
# Create index on email for fast lookups
|
||||
op.create_index(
|
||||
'ix_resend_synced_users_email',
|
||||
'resend_synced_users',
|
||||
['email'],
|
||||
)
|
||||
|
||||
# Create index on audience_id for filtering by audience
|
||||
op.create_index(
|
||||
'ix_resend_synced_users_audience_id',
|
||||
'resend_synced_users',
|
||||
['audience_id'],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Drop resend_synced_users table."""
|
||||
op.drop_index(
|
||||
'ix_resend_synced_users_audience_id', table_name='resend_synced_users'
|
||||
)
|
||||
op.drop_index('ix_resend_synced_users_email', table_name='resend_synced_users')
|
||||
op.drop_table('resend_synced_users')
|
||||
@@ -0,0 +1,41 @@
|
||||
"""Add session_api_key_hash to v1_remote_sandbox table
|
||||
|
||||
Revision ID: 097
|
||||
Revises: 096
|
||||
Create Date: 2025-02-24 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '097'
|
||||
down_revision: Union[str, None] = '096'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Add session_api_key_hash column to v1_remote_sandbox table."""
|
||||
op.add_column(
|
||||
'v1_remote_sandbox',
|
||||
sa.Column('session_api_key_hash', sa.String(), nullable=True),
|
||||
)
|
||||
op.create_index(
|
||||
op.f('ix_v1_remote_sandbox_session_api_key_hash'),
|
||||
'v1_remote_sandbox',
|
||||
['session_api_key_hash'],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Remove session_api_key_hash column from v1_remote_sandbox table."""
|
||||
op.drop_index(
|
||||
op.f('ix_v1_remote_sandbox_session_api_key_hash'),
|
||||
table_name='v1_remote_sandbox',
|
||||
)
|
||||
op.drop_column('v1_remote_sandbox', 'session_api_key_hash')
|
||||
212
enterprise/poetry.lock
generated
212
enterprise/poetry.lock
generated
@@ -6102,14 +6102,14 @@ llama = ["llama-index (>=0.12.29,<0.13.0)", "llama-index-core (>=0.12.29,<0.13.0
|
||||
|
||||
[[package]]
|
||||
name = "openhands-agent-server"
|
||||
version = "1.10.0"
|
||||
version = "1.11.5"
|
||||
description = "OpenHands Agent Server - REST/WebSocket interface for OpenHands AI Agent"
|
||||
optional = false
|
||||
python-versions = ">=3.12"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "openhands_agent_server-1.10.0-py3-none-any.whl", hash = "sha256:2e21076fff5e7cf9d03a3b011e2c90a6a3a46d2da3f18db9f7553ac413229c22"},
|
||||
{file = "openhands_agent_server-1.10.0.tar.gz", hash = "sha256:2062da2496a98a6c23201d086f124e02329d6c6d9d1b47be55921c084a29f55a"},
|
||||
{file = "openhands_agent_server-1.11.5-py3-none-any.whl", hash = "sha256:8bae7063f232791d58a5c31919f58b557f7cce60e6295773985c7dadc556cb9e"},
|
||||
{file = "openhands_agent_server-1.11.5.tar.gz", hash = "sha256:b61366d727c61ab9b7fcd66faab53f230f8ef0928c1177a388d2c5c4be6ebbd0"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -6126,7 +6126,7 @@ wsproto = ">=1.2.0"
|
||||
|
||||
[[package]]
|
||||
name = "openhands-ai"
|
||||
version = "1.2.1"
|
||||
version = "1.4.0"
|
||||
description = "OpenHands: Code Less, Make More"
|
||||
optional = false
|
||||
python-versions = "^3.12,<3.14"
|
||||
@@ -6168,9 +6168,9 @@ memory-profiler = ">=0.61"
|
||||
numpy = "*"
|
||||
openai = "2.8"
|
||||
openhands-aci = "0.3.2"
|
||||
openhands-agent-server = "1.10"
|
||||
openhands-sdk = "1.10"
|
||||
openhands-tools = "1.10"
|
||||
openhands-agent-server = "1.11.5"
|
||||
openhands-sdk = "1.11.5"
|
||||
openhands-tools = "1.11.5"
|
||||
opentelemetry-api = ">=1.33.1"
|
||||
opentelemetry-exporter-otlp-proto-grpc = ">=1.33.1"
|
||||
pathspec = ">=0.12.1"
|
||||
@@ -6225,14 +6225,14 @@ url = ".."
|
||||
|
||||
[[package]]
|
||||
name = "openhands-sdk"
|
||||
version = "1.10.0"
|
||||
version = "1.11.5"
|
||||
description = "OpenHands SDK - Core functionality for building AI agents"
|
||||
optional = false
|
||||
python-versions = ">=3.12"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "openhands_sdk-1.10.0-py3-none-any.whl", hash = "sha256:5c8875f2a07d7fabe3449914639572bef9003821207cb06aa237a239e964eed5"},
|
||||
{file = "openhands_sdk-1.10.0.tar.gz", hash = "sha256:93371b1af4532266ad2d225b9d7d3d711c745df31888efe643970673f62bdef9"},
|
||||
{file = "openhands_sdk-1.11.5-py3-none-any.whl", hash = "sha256:f949cd540cbecc339d90fb0cca2a5f29e1b62566b82b5aee82ef40f259d14e60"},
|
||||
{file = "openhands_sdk-1.11.5.tar.gz", hash = "sha256:dd6225876b7b8dbb6c608559f2718c3d0bf44d0bb741e990b185c6cdc5150c5a"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -6253,14 +6253,14 @@ boto3 = ["boto3 (>=1.35.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "openhands-tools"
|
||||
version = "1.10.0"
|
||||
version = "1.11.5"
|
||||
description = "OpenHands Tools - Runtime tools for AI agents"
|
||||
optional = false
|
||||
python-versions = ">=3.12"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "openhands_tools-1.10.0-py3-none-any.whl", hash = "sha256:1d5d2d1e34cc4ceb02c0ff1f008b06883ad48a8e7236ab8dd61ece64fbf8e2ed"},
|
||||
{file = "openhands_tools-1.10.0.tar.gz", hash = "sha256:7ed38cb13545ec2c4a35c26ece725d5b35788d30597db8b1904619c043ec1194"},
|
||||
{file = "openhands_tools-1.11.5-py3-none-any.whl", hash = "sha256:1e981e1e7f3544184fe946cee8eb6bd287010cdef77d83ebac945c9f42df3baf"},
|
||||
{file = "openhands_tools-1.11.5.tar.gz", hash = "sha256:d7b1163f6505a51b07147e7d8972062c129ecc46571a71f28d5470355e06650e"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -6851,103 +6851,103 @@ scramp = ">=1.4.5"
|
||||
|
||||
[[package]]
|
||||
name = "pillow"
|
||||
version = "12.1.0"
|
||||
version = "12.1.1"
|
||||
description = "Python Imaging Library (fork)"
|
||||
optional = false
|
||||
python-versions = ">=3.10"
|
||||
groups = ["main", "test"]
|
||||
files = [
|
||||
{file = "pillow-12.1.0-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:fb125d860738a09d363a88daa0f59c4533529a90e564785e20fe875b200b6dbd"},
|
||||
{file = "pillow-12.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:cad302dc10fac357d3467a74a9561c90609768a6f73a1923b0fd851b6486f8b0"},
|
||||
{file = "pillow-12.1.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:a40905599d8079e09f25027423aed94f2823adaf2868940de991e53a449e14a8"},
|
||||
{file = "pillow-12.1.0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:92a7fe4225365c5e3a8e598982269c6d6698d3e783b3b1ae979e7819f9cd55c1"},
|
||||
{file = "pillow-12.1.0-cp310-cp310-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f10c98f49227ed8383d28174ee95155a675c4ed7f85e2e573b04414f7e371bda"},
|
||||
{file = "pillow-12.1.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8637e29d13f478bc4f153d8daa9ffb16455f0a6cb287da1b432fdad2bfbd66c7"},
|
||||
{file = "pillow-12.1.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:21e686a21078b0f9cb8c8a961d99e6a4ddb88e0fc5ea6e130172ddddc2e5221a"},
|
||||
{file = "pillow-12.1.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:2415373395a831f53933c23ce051021e79c8cd7979822d8cc478547a3f4da8ef"},
|
||||
{file = "pillow-12.1.0-cp310-cp310-win32.whl", hash = "sha256:e75d3dba8fc1ddfec0cd752108f93b83b4f8d6ab40e524a95d35f016b9683b09"},
|
||||
{file = "pillow-12.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:64efdf00c09e31efd754448a383ea241f55a994fd079866b92d2bbff598aad91"},
|
||||
{file = "pillow-12.1.0-cp310-cp310-win_arm64.whl", hash = "sha256:f188028b5af6b8fb2e9a76ac0f841a575bd1bd396e46ef0840d9b88a48fdbcea"},
|
||||
{file = "pillow-12.1.0-cp311-cp311-macosx_10_10_x86_64.whl", hash = "sha256:a83e0850cb8f5ac975291ebfc4170ba481f41a28065277f7f735c202cd8e0af3"},
|
||||
{file = "pillow-12.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:b6e53e82ec2db0717eabb276aa56cf4e500c9a7cec2c2e189b55c24f65a3e8c0"},
|
||||
{file = "pillow-12.1.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:40a8e3b9e8773876d6e30daed22f016509e3987bab61b3b7fe309d7019a87451"},
|
||||
{file = "pillow-12.1.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:800429ac32c9b72909c671aaf17ecd13110f823ddb7db4dfef412a5587c2c24e"},
|
||||
{file = "pillow-12.1.0-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0b022eaaf709541b391ee069f0022ee5b36c709df71986e3f7be312e46f42c84"},
|
||||
{file = "pillow-12.1.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1f345e7bc9d7f368887c712aa5054558bad44d2a301ddf9248599f4161abc7c0"},
|
||||
{file = "pillow-12.1.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d70347c8a5b7ccd803ec0c85c8709f036e6348f1e6a5bf048ecd9c64d3550b8b"},
|
||||
{file = "pillow-12.1.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:1fcc52d86ce7a34fd17cb04e87cfdb164648a3662a6f20565910a99653d66c18"},
|
||||
{file = "pillow-12.1.0-cp311-cp311-win32.whl", hash = "sha256:3ffaa2f0659e2f740473bcf03c702c39a8d4b2b7ffc629052028764324842c64"},
|
||||
{file = "pillow-12.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:806f3987ffe10e867bab0ddad45df1148a2b98221798457fa097ad85d6e8bc75"},
|
||||
{file = "pillow-12.1.0-cp311-cp311-win_arm64.whl", hash = "sha256:9f5fefaca968e700ad1a4a9de98bf0869a94e397fe3524c4c9450c1445252304"},
|
||||
{file = "pillow-12.1.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:a332ac4ccb84b6dde65dbace8431f3af08874bf9770719d32a635c4ef411b18b"},
|
||||
{file = "pillow-12.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:907bfa8a9cb790748a9aa4513e37c88c59660da3bcfffbd24a7d9e6abf224551"},
|
||||
{file = "pillow-12.1.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:efdc140e7b63b8f739d09a99033aa430accce485ff78e6d311973a67b6bf3208"},
|
||||
{file = "pillow-12.1.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:bef9768cab184e7ae6e559c032e95ba8d07b3023c289f79a2bd36e8bf85605a5"},
|
||||
{file = "pillow-12.1.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:742aea052cf5ab5034a53c3846165bc3ce88d7c38e954120db0ab867ca242661"},
|
||||
{file = "pillow-12.1.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a6dfc2af5b082b635af6e08e0d1f9f1c4e04d17d4e2ca0ef96131e85eda6eb17"},
|
||||
{file = "pillow-12.1.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:609e89d9f90b581c8d16358c9087df76024cf058fa693dd3e1e1620823f39670"},
|
||||
{file = "pillow-12.1.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:43b4899cfd091a9693a1278c4982f3e50f7fb7cff5153b05174b4afc9593b616"},
|
||||
{file = "pillow-12.1.0-cp312-cp312-win32.whl", hash = "sha256:aa0c9cc0b82b14766a99fbe6084409972266e82f459821cd26997a488a7261a7"},
|
||||
{file = "pillow-12.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:d70534cea9e7966169ad29a903b99fc507e932069a881d0965a1a84bb57f6c6d"},
|
||||
{file = "pillow-12.1.0-cp312-cp312-win_arm64.whl", hash = "sha256:65b80c1ee7e14a87d6a068dd3b0aea268ffcabfe0498d38661b00c5b4b22e74c"},
|
||||
{file = "pillow-12.1.0-cp313-cp313-ios_13_0_arm64_iphoneos.whl", hash = "sha256:7b5dd7cbae20285cdb597b10eb5a2c13aa9de6cde9bb64a3c1317427b1db1ae1"},
|
||||
{file = "pillow-12.1.0-cp313-cp313-ios_13_0_arm64_iphonesimulator.whl", hash = "sha256:29a4cef9cb672363926f0470afc516dbf7305a14d8c54f7abbb5c199cd8f8179"},
|
||||
{file = "pillow-12.1.0-cp313-cp313-ios_13_0_x86_64_iphonesimulator.whl", hash = "sha256:681088909d7e8fa9e31b9799aaa59ba5234c58e5e4f1951b4c4d1082a2e980e0"},
|
||||
{file = "pillow-12.1.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:983976c2ab753166dc66d36af6e8ec15bb511e4a25856e2227e5f7e00a160587"},
|
||||
{file = "pillow-12.1.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:db44d5c160a90df2d24a24760bbd37607d53da0b34fb546c4c232af7192298ac"},
|
||||
{file = "pillow-12.1.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:6b7a9d1db5dad90e2991645874f708e87d9a3c370c243c2d7684d28f7e133e6b"},
|
||||
{file = "pillow-12.1.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6258f3260986990ba2fa8a874f8b6e808cf5abb51a94015ca3dc3c68aa4f30ea"},
|
||||
{file = "pillow-12.1.0-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e115c15e3bc727b1ca3e641a909f77f8ca72a64fff150f666fcc85e57701c26c"},
|
||||
{file = "pillow-12.1.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6741e6f3074a35e47c77b23a4e4f2d90db3ed905cb1c5e6e0d49bff2045632bc"},
|
||||
{file = "pillow-12.1.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:935b9d1aed48fcfb3f838caac506f38e29621b44ccc4f8a64d575cb1b2a88644"},
|
||||
{file = "pillow-12.1.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:5fee4c04aad8932da9f8f710af2c1a15a83582cfb884152a9caa79d4efcdbf9c"},
|
||||
{file = "pillow-12.1.0-cp313-cp313-win32.whl", hash = "sha256:a786bf667724d84aa29b5db1c61b7bfdde380202aaca12c3461afd6b71743171"},
|
||||
{file = "pillow-12.1.0-cp313-cp313-win_amd64.whl", hash = "sha256:461f9dfdafa394c59cd6d818bdfdbab4028b83b02caadaff0ffd433faf4c9a7a"},
|
||||
{file = "pillow-12.1.0-cp313-cp313-win_arm64.whl", hash = "sha256:9212d6b86917a2300669511ed094a9406888362e085f2431a7da985a6b124f45"},
|
||||
{file = "pillow-12.1.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:00162e9ca6d22b7c3ee8e61faa3c3253cd19b6a37f126cad04f2f88b306f557d"},
|
||||
{file = "pillow-12.1.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:7d6daa89a00b58c37cb1747ec9fb7ac3bc5ffd5949f5888657dfddde6d1312e0"},
|
||||
{file = "pillow-12.1.0-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:e2479c7f02f9d505682dc47df8c0ea1fc5e264c4d1629a5d63fe3e2334b89554"},
|
||||
{file = "pillow-12.1.0-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:f188d580bd870cda1e15183790d1cc2fa78f666e76077d103edf048eed9c356e"},
|
||||
{file = "pillow-12.1.0-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0fde7ec5538ab5095cc02df38ee99b0443ff0e1c847a045554cf5f9af1f4aa82"},
|
||||
{file = "pillow-12.1.0-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0ed07dca4a8464bada6139ab38f5382f83e5f111698caf3191cb8dbf27d908b4"},
|
||||
{file = "pillow-12.1.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:f45bd71d1fa5e5749587613037b172e0b3b23159d1c00ef2fc920da6f470e6f0"},
|
||||
{file = "pillow-12.1.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:277518bf4fe74aa91489e1b20577473b19ee70fb97c374aa50830b279f25841b"},
|
||||
{file = "pillow-12.1.0-cp313-cp313t-win32.whl", hash = "sha256:7315f9137087c4e0ee73a761b163fc9aa3b19f5f606a7fc08d83fd3e4379af65"},
|
||||
{file = "pillow-12.1.0-cp313-cp313t-win_amd64.whl", hash = "sha256:0ddedfaa8b5f0b4ffbc2fa87b556dc59f6bb4ecb14a53b33f9189713ae8053c0"},
|
||||
{file = "pillow-12.1.0-cp313-cp313t-win_arm64.whl", hash = "sha256:80941e6d573197a0c28f394753de529bb436b1ca990ed6e765cf42426abc39f8"},
|
||||
{file = "pillow-12.1.0-cp314-cp314-ios_13_0_arm64_iphoneos.whl", hash = "sha256:5cb7bc1966d031aec37ddb9dcf15c2da5b2e9f7cc3ca7c54473a20a927e1eb91"},
|
||||
{file = "pillow-12.1.0-cp314-cp314-ios_13_0_arm64_iphonesimulator.whl", hash = "sha256:97e9993d5ed946aba26baf9c1e8cf18adbab584b99f452ee72f7ee8acb882796"},
|
||||
{file = "pillow-12.1.0-cp314-cp314-ios_13_0_x86_64_iphonesimulator.whl", hash = "sha256:414b9a78e14ffeb98128863314e62c3f24b8a86081066625700b7985b3f529bd"},
|
||||
{file = "pillow-12.1.0-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:e6bdb408f7c9dd2a5ff2b14a3b0bb6d4deb29fb9961e6eb3ae2031ae9a5cec13"},
|
||||
{file = "pillow-12.1.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:3413c2ae377550f5487991d444428f1a8ae92784aac79caa8b1e3b89b175f77e"},
|
||||
{file = "pillow-12.1.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:e5dcbe95016e88437ecf33544ba5db21ef1b8dd6e1b434a2cb2a3d605299e643"},
|
||||
{file = "pillow-12.1.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:d0a7735df32ccbcc98b98a1ac785cc4b19b580be1bdf0aeb5c03223220ea09d5"},
|
||||
{file = "pillow-12.1.0-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0c27407a2d1b96774cbc4a7594129cc027339fd800cd081e44497722ea1179de"},
|
||||
{file = "pillow-12.1.0-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:15c794d74303828eaa957ff8070846d0efe8c630901a1c753fdc63850e19ecd9"},
|
||||
{file = "pillow-12.1.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:c990547452ee2800d8506c4150280757f88532f3de2a58e3022e9b179107862a"},
|
||||
{file = "pillow-12.1.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:b63e13dd27da389ed9475b3d28510f0f954bca0041e8e551b2a4eb1eab56a39a"},
|
||||
{file = "pillow-12.1.0-cp314-cp314-win32.whl", hash = "sha256:1a949604f73eb07a8adab38c4fe50791f9919344398bdc8ac6b307f755fc7030"},
|
||||
{file = "pillow-12.1.0-cp314-cp314-win_amd64.whl", hash = "sha256:4f9f6a650743f0ddee5593ac9e954ba1bdbc5e150bc066586d4f26127853ab94"},
|
||||
{file = "pillow-12.1.0-cp314-cp314-win_arm64.whl", hash = "sha256:808b99604f7873c800c4840f55ff389936ef1948e4e87645eaf3fccbc8477ac4"},
|
||||
{file = "pillow-12.1.0-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:bc11908616c8a283cf7d664f77411a5ed2a02009b0097ff8abbba5e79128ccf2"},
|
||||
{file = "pillow-12.1.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:896866d2d436563fa2a43a9d72f417874f16b5545955c54a64941e87c1376c61"},
|
||||
{file = "pillow-12.1.0-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:8e178e3e99d3c0ea8fc64b88447f7cac8ccf058af422a6cedc690d0eadd98c51"},
|
||||
{file = "pillow-12.1.0-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:079af2fb0c599c2ec144ba2c02766d1b55498e373b3ac64687e43849fbbef5bc"},
|
||||
{file = "pillow-12.1.0-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:bdec5e43377761c5dbca620efb69a77f6855c5a379e32ac5b158f54c84212b14"},
|
||||
{file = "pillow-12.1.0-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:565c986f4b45c020f5421a4cea13ef294dde9509a8577f29b2fc5edc7587fff8"},
|
||||
{file = "pillow-12.1.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:43aca0a55ce1eefc0aefa6253661cb54571857b1a7b2964bd8a1e3ef4b729924"},
|
||||
{file = "pillow-12.1.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:0deedf2ea233722476b3a81e8cdfbad786f7adbed5d848469fa59fe52396e4ef"},
|
||||
{file = "pillow-12.1.0-cp314-cp314t-win32.whl", hash = "sha256:b17fbdbe01c196e7e159aacb889e091f28e61020a8abeac07b68079b6e626988"},
|
||||
{file = "pillow-12.1.0-cp314-cp314t-win_amd64.whl", hash = "sha256:27b9baecb428899db6c0de572d6d305cfaf38ca1596b5c0542a5182e3e74e8c6"},
|
||||
{file = "pillow-12.1.0-cp314-cp314t-win_arm64.whl", hash = "sha256:f61333d817698bdcdd0f9d7793e365ac3d2a21c1f1eb02b32ad6aefb8d8ea831"},
|
||||
{file = "pillow-12.1.0-pp311-pypy311_pp73-macosx_10_15_x86_64.whl", hash = "sha256:ca94b6aac0d7af2a10ba08c0f888b3d5114439b6b3ef39968378723622fed377"},
|
||||
{file = "pillow-12.1.0-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:351889afef0f485b84078ea40fe33727a0492b9af3904661b0abbafee0355b72"},
|
||||
{file = "pillow-12.1.0-pp311-pypy311_pp73-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:bb0984b30e973f7e2884362b7d23d0a348c7143ee559f38ef3eaab640144204c"},
|
||||
{file = "pillow-12.1.0-pp311-pypy311_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:84cabc7095dd535ca934d57e9ce2a72ffd216e435a84acb06b2277b1de2689bd"},
|
||||
{file = "pillow-12.1.0-pp311-pypy311_pp73-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:53d8b764726d3af1a138dd353116f774e3862ec7e3794e0c8781e30db0f35dfc"},
|
||||
{file = "pillow-12.1.0-pp311-pypy311_pp73-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:5da841d81b1a05ef940a8567da92decaa15bc4d7dedb540a8c219ad83d91808a"},
|
||||
{file = "pillow-12.1.0-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:75af0b4c229ac519b155028fa1be632d812a519abba9b46b20e50c6caa184f19"},
|
||||
{file = "pillow-12.1.0.tar.gz", hash = "sha256:5c5ae0a06e9ea030ab786b0251b32c7e4ce10e58d983c0d5c56029455180b5b9"},
|
||||
{file = "pillow-12.1.1-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:1f1625b72740fdda5d77b4def688eb8fd6490975d06b909fd19f13f391e077e0"},
|
||||
{file = "pillow-12.1.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:178aa072084bd88ec759052feca8e56cbb14a60b39322b99a049e58090479713"},
|
||||
{file = "pillow-12.1.1-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:b66e95d05ba806247aaa1561f080abc7975daf715c30780ff92a20e4ec546e1b"},
|
||||
{file = "pillow-12.1.1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:89c7e895002bbe49cdc5426150377cbbc04767d7547ed145473f496dfa40408b"},
|
||||
{file = "pillow-12.1.1-cp310-cp310-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3a5cbdcddad0af3da87cb16b60d23648bc3b51967eb07223e9fed77a82b457c4"},
|
||||
{file = "pillow-12.1.1-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9f51079765661884a486727f0729d29054242f74b46186026582b4e4769918e4"},
|
||||
{file = "pillow-12.1.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:99c1506ea77c11531d75e3a412832a13a71c7ebc8192ab9e4b2e355555920e3e"},
|
||||
{file = "pillow-12.1.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:36341d06738a9f66c8287cf8b876d24b18db9bd8740fa0672c74e259ad408cff"},
|
||||
{file = "pillow-12.1.1-cp310-cp310-win32.whl", hash = "sha256:6c52f062424c523d6c4db85518774cc3d50f5539dd6eed32b8f6229b26f24d40"},
|
||||
{file = "pillow-12.1.1-cp310-cp310-win_amd64.whl", hash = "sha256:c6008de247150668a705a6338156efb92334113421ceecf7438a12c9a12dab23"},
|
||||
{file = "pillow-12.1.1-cp310-cp310-win_arm64.whl", hash = "sha256:1a9b0ee305220b392e1124a764ee4265bd063e54a751a6b62eff69992f457fa9"},
|
||||
{file = "pillow-12.1.1-cp311-cp311-macosx_10_10_x86_64.whl", hash = "sha256:e879bb6cd5c73848ef3b2b48b8af9ff08c5b71ecda8048b7dd22d8a33f60be32"},
|
||||
{file = "pillow-12.1.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:365b10bb9417dd4498c0e3b128018c4a624dc11c7b97d8cc54effe3b096f4c38"},
|
||||
{file = "pillow-12.1.1-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:d4ce8e329c93845720cd2014659ca67eac35f6433fd3050393d85f3ecef0dad5"},
|
||||
{file = "pillow-12.1.1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:fc354a04072b765eccf2204f588a7a532c9511e8b9c7f900e1b64e3e33487090"},
|
||||
{file = "pillow-12.1.1-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:7e7976bf1910a8116b523b9f9f58bf410f3e8aa330cd9a2bb2953f9266ab49af"},
|
||||
{file = "pillow-12.1.1-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:597bd9c8419bc7c6af5604e55847789b69123bbe25d65cc6ad3012b4f3c98d8b"},
|
||||
{file = "pillow-12.1.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:2c1fc0f2ca5f96a3c8407e41cca26a16e46b21060fe6d5b099d2cb01412222f5"},
|
||||
{file = "pillow-12.1.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:578510d88c6229d735855e1f278aa305270438d36a05031dfaae5067cc8eb04d"},
|
||||
{file = "pillow-12.1.1-cp311-cp311-win32.whl", hash = "sha256:7311c0a0dcadb89b36b7025dfd8326ecfa36964e29913074d47382706e516a7c"},
|
||||
{file = "pillow-12.1.1-cp311-cp311-win_amd64.whl", hash = "sha256:fbfa2a7c10cc2623f412753cddf391c7f971c52ca40a3f65dc5039b2939e8563"},
|
||||
{file = "pillow-12.1.1-cp311-cp311-win_arm64.whl", hash = "sha256:b81b5e3511211631b3f672a595e3221252c90af017e399056d0faabb9538aa80"},
|
||||
{file = "pillow-12.1.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:ab323b787d6e18b3d91a72fc99b1a2c28651e4358749842b8f8dfacd28ef2052"},
|
||||
{file = "pillow-12.1.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:adebb5bee0f0af4909c30db0d890c773d1a92ffe83da908e2e9e720f8edf3984"},
|
||||
{file = "pillow-12.1.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:bb66b7cc26f50977108790e2456b7921e773f23db5630261102233eb355a3b79"},
|
||||
{file = "pillow-12.1.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:aee2810642b2898bb187ced9b349e95d2a7272930796e022efaf12e99dccd293"},
|
||||
{file = "pillow-12.1.1-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a0b1cd6232e2b618adcc54d9882e4e662a089d5768cd188f7c245b4c8c44a397"},
|
||||
{file = "pillow-12.1.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7aac39bcf8d4770d089588a2e1dd111cbaa42df5a94be3114222057d68336bd0"},
|
||||
{file = "pillow-12.1.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:ab174cd7d29a62dd139c44bf74b698039328f45cb03b4596c43473a46656b2f3"},
|
||||
{file = "pillow-12.1.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:339ffdcb7cbeaa08221cd401d517d4b1fe7a9ed5d400e4a8039719238620ca35"},
|
||||
{file = "pillow-12.1.1-cp312-cp312-win32.whl", hash = "sha256:5d1f9575a12bed9e9eedd9a4972834b08c97a352bd17955ccdebfeca5913fa0a"},
|
||||
{file = "pillow-12.1.1-cp312-cp312-win_amd64.whl", hash = "sha256:21329ec8c96c6e979cd0dfd29406c40c1d52521a90544463057d2aaa937d66a6"},
|
||||
{file = "pillow-12.1.1-cp312-cp312-win_arm64.whl", hash = "sha256:af9a332e572978f0218686636610555ae3defd1633597be015ed50289a03c523"},
|
||||
{file = "pillow-12.1.1-cp313-cp313-ios_13_0_arm64_iphoneos.whl", hash = "sha256:d242e8ac078781f1de88bf823d70c1a9b3c7950a44cdf4b7c012e22ccbcd8e4e"},
|
||||
{file = "pillow-12.1.1-cp313-cp313-ios_13_0_arm64_iphonesimulator.whl", hash = "sha256:02f84dfad02693676692746df05b89cf25597560db2857363a208e393429f5e9"},
|
||||
{file = "pillow-12.1.1-cp313-cp313-ios_13_0_x86_64_iphonesimulator.whl", hash = "sha256:e65498daf4b583091ccbb2556c7000abf0f3349fcd57ef7adc9a84a394ed29f6"},
|
||||
{file = "pillow-12.1.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:6c6db3b84c87d48d0088943bf33440e0c42370b99b1c2a7989216f7b42eede60"},
|
||||
{file = "pillow-12.1.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:8b7e5304e34942bf62e15184219a7b5ad4ff7f3bb5cca4d984f37df1a0e1aee2"},
|
||||
{file = "pillow-12.1.1-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:18e5bddd742a44b7e6b1e773ab5db102bd7a94c32555ba656e76d319d19c3850"},
|
||||
{file = "pillow-12.1.1-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:fc44ef1f3de4f45b50ccf9136999d71abb99dca7706bc75d222ed350b9fd2289"},
|
||||
{file = "pillow-12.1.1-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5a8eb7ed8d4198bccbd07058416eeec51686b498e784eda166395a23eb99138e"},
|
||||
{file = "pillow-12.1.1-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:47b94983da0c642de92ced1702c5b6c292a84bd3a8e1d1702ff923f183594717"},
|
||||
{file = "pillow-12.1.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:518a48c2aab7ce596d3bf79d0e275661b846e86e4d0e7dec34712c30fe07f02a"},
|
||||
{file = "pillow-12.1.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:a550ae29b95c6dc13cf69e2c9dc5747f814c54eeb2e32d683e5e93af56caa029"},
|
||||
{file = "pillow-12.1.1-cp313-cp313-win32.whl", hash = "sha256:a003d7422449f6d1e3a34e3dd4110c22148336918ddbfc6a32581cd54b2e0b2b"},
|
||||
{file = "pillow-12.1.1-cp313-cp313-win_amd64.whl", hash = "sha256:344cf1e3dab3be4b1fa08e449323d98a2a3f819ad20f4b22e77a0ede31f0faa1"},
|
||||
{file = "pillow-12.1.1-cp313-cp313-win_arm64.whl", hash = "sha256:5c0dd1636633e7e6a0afe7bf6a51a14992b7f8e60de5789018ebbdfae55b040a"},
|
||||
{file = "pillow-12.1.1-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:0330d233c1a0ead844fc097a7d16c0abff4c12e856c0b325f231820fee1f39da"},
|
||||
{file = "pillow-12.1.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:5dae5f21afb91322f2ff791895ddd8889e5e947ff59f71b46041c8ce6db790bc"},
|
||||
{file = "pillow-12.1.1-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:2e0c664be47252947d870ac0d327fea7e63985a08794758aa8af5b6cb6ec0c9c"},
|
||||
{file = "pillow-12.1.1-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:691ab2ac363b8217f7d31b3497108fb1f50faab2f75dfb03284ec2f217e87bf8"},
|
||||
{file = "pillow-12.1.1-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e9e8064fb1cc019296958595f6db671fba95209e3ceb0c4734c9baf97de04b20"},
|
||||
{file = "pillow-12.1.1-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:472a8d7ded663e6162dafdf20015c486a7009483ca671cece7a9279b512fcb13"},
|
||||
{file = "pillow-12.1.1-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:89b54027a766529136a06cfebeecb3a04900397a3590fd252160b888479517bf"},
|
||||
{file = "pillow-12.1.1-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:86172b0831b82ce4f7877f280055892b31179e1576aa00d0df3bb1bbf8c3e524"},
|
||||
{file = "pillow-12.1.1-cp313-cp313t-win32.whl", hash = "sha256:44ce27545b6efcf0fdbdceb31c9a5bdea9333e664cda58a7e674bb74608b3986"},
|
||||
{file = "pillow-12.1.1-cp313-cp313t-win_amd64.whl", hash = "sha256:a285e3eb7a5a45a2ff504e31f4a8d1b12ef62e84e5411c6804a42197c1cf586c"},
|
||||
{file = "pillow-12.1.1-cp313-cp313t-win_arm64.whl", hash = "sha256:cc7d296b5ea4d29e6570dabeaed58d31c3fea35a633a69679fb03d7664f43fb3"},
|
||||
{file = "pillow-12.1.1-cp314-cp314-ios_13_0_arm64_iphoneos.whl", hash = "sha256:417423db963cb4be8bac3fc1204fe61610f6abeed1580a7a2cbb2fbda20f12af"},
|
||||
{file = "pillow-12.1.1-cp314-cp314-ios_13_0_arm64_iphonesimulator.whl", hash = "sha256:b957b71c6b2387610f556a7eb0828afbe40b4a98036fc0d2acfa5a44a0c2036f"},
|
||||
{file = "pillow-12.1.1-cp314-cp314-ios_13_0_x86_64_iphonesimulator.whl", hash = "sha256:097690ba1f2efdeb165a20469d59d8bb03c55fb6621eb2041a060ae8ea3e9642"},
|
||||
{file = "pillow-12.1.1-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:2815a87ab27848db0321fb78c7f0b2c8649dee134b7f2b80c6a45c6831d75ccd"},
|
||||
{file = "pillow-12.1.1-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:f7ed2c6543bad5a7d5530eb9e78c53132f93dfa44a28492db88b41cdab885202"},
|
||||
{file = "pillow-12.1.1-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:652a2c9ccfb556235b2b501a3a7cf3742148cd22e04b5625c5fe057ea3e3191f"},
|
||||
{file = "pillow-12.1.1-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:d6e4571eedf43af33d0fc233a382a76e849badbccdf1ac438841308652a08e1f"},
|
||||
{file = "pillow-12.1.1-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b574c51cf7d5d62e9be37ba446224b59a2da26dc4c1bb2ecbe936a4fb1a7cb7f"},
|
||||
{file = "pillow-12.1.1-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a37691702ed687799de29a518d63d4682d9016932db66d4e90c345831b02fb4e"},
|
||||
{file = "pillow-12.1.1-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:f95c00d5d6700b2b890479664a06e754974848afaae5e21beb4d83c106923fd0"},
|
||||
{file = "pillow-12.1.1-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:559b38da23606e68681337ad74622c4dbba02254fc9cb4488a305dd5975c7eeb"},
|
||||
{file = "pillow-12.1.1-cp314-cp314-win32.whl", hash = "sha256:03edcc34d688572014ff223c125a3f77fb08091e4607e7745002fc214070b35f"},
|
||||
{file = "pillow-12.1.1-cp314-cp314-win_amd64.whl", hash = "sha256:50480dcd74fa63b8e78235957d302d98d98d82ccbfac4c7e12108ba9ecbdba15"},
|
||||
{file = "pillow-12.1.1-cp314-cp314-win_arm64.whl", hash = "sha256:5cb1785d97b0c3d1d1a16bc1d710c4a0049daefc4935f3a8f31f827f4d3d2e7f"},
|
||||
{file = "pillow-12.1.1-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:1f90cff8aa76835cba5769f0b3121a22bd4eb9e6884cfe338216e557a9a548b8"},
|
||||
{file = "pillow-12.1.1-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:1f1be78ce9466a7ee64bfda57bdba0f7cc499d9794d518b854816c41bf0aa4e9"},
|
||||
{file = "pillow-12.1.1-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:42fc1f4677106188ad9a55562bbade416f8b55456f522430fadab3cef7cd4e60"},
|
||||
{file = "pillow-12.1.1-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:98edb152429ab62a1818039744d8fbb3ccab98a7c29fc3d5fcef158f3f1f68b7"},
|
||||
{file = "pillow-12.1.1-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d470ab1178551dd17fdba0fef463359c41aaa613cdcd7ff8373f54be629f9f8f"},
|
||||
{file = "pillow-12.1.1-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6408a7b064595afcab0a49393a413732a35788f2a5092fdc6266952ed67de586"},
|
||||
{file = "pillow-12.1.1-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:5d8c41325b382c07799a3682c1c258469ea2ff97103c53717b7893862d0c98ce"},
|
||||
{file = "pillow-12.1.1-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:c7697918b5be27424e9ce568193efd13d925c4481dd364e43f5dff72d33e10f8"},
|
||||
{file = "pillow-12.1.1-cp314-cp314t-win32.whl", hash = "sha256:d2912fd8114fc5545aa3a4b5576512f64c55a03f3ebcca4c10194d593d43ea36"},
|
||||
{file = "pillow-12.1.1-cp314-cp314t-win_amd64.whl", hash = "sha256:4ceb838d4bd9dab43e06c363cab2eebf63846d6a4aeaea283bbdfd8f1a8ed58b"},
|
||||
{file = "pillow-12.1.1-cp314-cp314t-win_arm64.whl", hash = "sha256:7b03048319bfc6170e93bd60728a1af51d3dd7704935feb228c4d4faab35d334"},
|
||||
{file = "pillow-12.1.1-pp311-pypy311_pp73-macosx_10_15_x86_64.whl", hash = "sha256:600fd103672b925fe62ed08e0d874ea34d692474df6f4bf7ebe148b30f89f39f"},
|
||||
{file = "pillow-12.1.1-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:665e1b916b043cef294bc54d47bf02d87e13f769bc4bc5fa225a24b3a6c5aca9"},
|
||||
{file = "pillow-12.1.1-pp311-pypy311_pp73-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:495c302af3aad1ca67420ddd5c7bd480c8867ad173528767d906428057a11f0e"},
|
||||
{file = "pillow-12.1.1-pp311-pypy311_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:8fd420ef0c52c88b5a035a0886f367748c72147b2b8f384c9d12656678dfdfa9"},
|
||||
{file = "pillow-12.1.1-pp311-pypy311_pp73-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f975aa7ef9684ce7e2c18a3aa8f8e2106ce1e46b94ab713d156b2898811651d3"},
|
||||
{file = "pillow-12.1.1-pp311-pypy311_pp73-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8089c852a56c2966cf18835db62d9b34fef7ba74c726ad943928d494fa7f4735"},
|
||||
{file = "pillow-12.1.1-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:cb9bb857b2d057c6dfc72ac5f3b44836924ba15721882ef103cecb40d002d80e"},
|
||||
{file = "pillow-12.1.1.tar.gz", hash = "sha256:9ad8fa5937ab05218e2b6a4cff30295ad35afd2f83ac592e68c0d871bb0fdbc4"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
@@ -14917,4 +14917,4 @@ cffi = ["cffi (>=1.17,<2.0) ; platform_python_implementation != \"PyPy\" and pyt
|
||||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = "^3.12,<3.14"
|
||||
content-hash = "b5cbb1e25176845ac9f95650a802667e2f8be1a536e3e55a9269b5af5a42e3fc"
|
||||
content-hash = "1cad6029269393af67155e930c72eae2c03da02e4b3a3699823f6168c14a4218"
|
||||
|
||||
@@ -44,6 +44,12 @@ httpx = "*"
|
||||
scikit-learn = "^1.7.0"
|
||||
shap = "^0.48.0"
|
||||
google-cloud-recaptcha-enterprise = "^1.24.0"
|
||||
# Dependencies previously only in Dockerfile, now managed via poetry.lock
|
||||
prometheus-client = "^0.24.0"
|
||||
pandas = "^2.2.0"
|
||||
numpy = "^2.2.0"
|
||||
mcp = "^1.10.0"
|
||||
pillow = "^12.1.0"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
ruff = "0.8.3"
|
||||
|
||||
@@ -38,6 +38,12 @@ from server.routes.integration.linear import linear_integration_router # noqa:
|
||||
from server.routes.integration.slack import slack_router # noqa: E402
|
||||
from server.routes.mcp_patch import patch_mcp_server # noqa: E402
|
||||
from server.routes.oauth_device import oauth_device_router # noqa: E402
|
||||
from server.routes.org_invitations import ( # noqa: E402
|
||||
accept_router as invitation_accept_router,
|
||||
)
|
||||
from server.routes.org_invitations import ( # noqa: E402
|
||||
invitation_router,
|
||||
)
|
||||
from server.routes.orgs import org_router # noqa: E402
|
||||
from server.routes.readiness import readiness_router # noqa: E402
|
||||
from server.routes.user import saas_user_router # noqa: E402
|
||||
@@ -78,8 +84,15 @@ base_app.include_router(shared_event_router)
|
||||
|
||||
# Add GitHub integration router only if GITHUB_APP_CLIENT_ID is set
|
||||
if GITHUB_APP_CLIENT_ID:
|
||||
# Make sure that the callback processor is loaded here so we don't get an error when deserializing
|
||||
from integrations.github.github_v1_callback_processor import ( # noqa: E402
|
||||
GithubV1CallbackProcessor,
|
||||
)
|
||||
from server.routes.integration.github import github_integration_router # noqa: E402
|
||||
|
||||
# Bludgeon mypy into not deleting my import
|
||||
logger.debug(f'Loaded {GithubV1CallbackProcessor.__name__}')
|
||||
|
||||
base_app.include_router(
|
||||
github_integration_router
|
||||
) # Add additional route for integration webhook events
|
||||
@@ -92,6 +105,8 @@ if GITLAB_APP_CLIENT_ID:
|
||||
|
||||
base_app.include_router(api_keys_router) # Add routes for API key management
|
||||
base_app.include_router(org_router) # Add routes for organization management
|
||||
base_app.include_router(invitation_router) # Add routes for org invitation management
|
||||
base_app.include_router(invitation_accept_router) # Add route for accepting invitations
|
||||
add_github_proxy_routes(base_app)
|
||||
add_debugging_routes(
|
||||
base_app
|
||||
|
||||
@@ -38,3 +38,9 @@ class ExpiredError(AuthError):
|
||||
"""Error when a token has expired (Usually the refresh token)"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class TokenRefreshError(AuthError):
|
||||
"""Error when token refresh fails due to timeout or lock contention"""
|
||||
|
||||
pass
|
||||
|
||||
306
enterprise/server/auth/authorization.py
Normal file
306
enterprise/server/auth/authorization.py
Normal file
@@ -0,0 +1,306 @@
|
||||
"""
|
||||
Permission-based authorization dependencies for API endpoints.
|
||||
|
||||
This module provides FastAPI dependencies for checking user permissions
|
||||
within organizations. It uses a permission-based authorization model where
|
||||
roles (owner, admin, member) are mapped to specific permissions.
|
||||
|
||||
Permissions are defined in the Permission enum and mapped to roles via
|
||||
ROLE_PERMISSIONS. This allows fine-grained access control while maintaining
|
||||
the familiar role-based hierarchy.
|
||||
|
||||
Usage:
|
||||
from server.auth.authorization import (
|
||||
Permission,
|
||||
require_permission,
|
||||
)
|
||||
|
||||
@router.get('/{org_id}/settings')
|
||||
async def get_settings(
|
||||
org_id: UUID,
|
||||
user_id: str = Depends(require_permission(Permission.VIEW_LLM_SETTINGS)),
|
||||
):
|
||||
# Only users with VIEW_LLM_SETTINGS permission can access
|
||||
...
|
||||
|
||||
@router.patch('/{org_id}/settings')
|
||||
async def update_settings(
|
||||
org_id: UUID,
|
||||
user_id: str = Depends(require_permission(Permission.EDIT_LLM_SETTINGS)),
|
||||
):
|
||||
# Only users with EDIT_LLM_SETTINGS permission can access
|
||||
...
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from storage.org_member_store import OrgMemberStore
|
||||
from storage.role import Role
|
||||
from storage.role_store import RoleStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.server.user_auth import get_user_id
|
||||
|
||||
|
||||
class Permission(str, Enum):
|
||||
"""Permissions that can be assigned to roles."""
|
||||
|
||||
# Secrets
|
||||
MANAGE_SECRETS = 'manage_secrets'
|
||||
|
||||
# MCP
|
||||
MANAGE_MCP = 'manage_mcp'
|
||||
|
||||
# Integrations
|
||||
MANAGE_INTEGRATIONS = 'manage_integrations'
|
||||
|
||||
# Application Settings
|
||||
MANAGE_APPLICATION_SETTINGS = 'manage_application_settings'
|
||||
|
||||
# API Keys
|
||||
MANAGE_API_KEYS = 'manage_api_keys'
|
||||
|
||||
# LLM Settings
|
||||
VIEW_LLM_SETTINGS = 'view_llm_settings'
|
||||
EDIT_LLM_SETTINGS = 'edit_llm_settings'
|
||||
|
||||
# Billing
|
||||
VIEW_BILLING = 'view_billing'
|
||||
ADD_CREDITS = 'add_credits'
|
||||
|
||||
# Organization Members
|
||||
INVITE_USER_TO_ORGANIZATION = 'invite_user_to_organization'
|
||||
CHANGE_USER_ROLE_MEMBER = 'change_user_role:member'
|
||||
CHANGE_USER_ROLE_ADMIN = 'change_user_role:admin'
|
||||
CHANGE_USER_ROLE_OWNER = 'change_user_role:owner'
|
||||
|
||||
# Organization Management
|
||||
VIEW_ORG_SETTINGS = 'view_org_settings'
|
||||
CHANGE_ORGANIZATION_NAME = 'change_organization_name'
|
||||
DELETE_ORGANIZATION = 'delete_organization'
|
||||
|
||||
# Temporary permissions until we finish the API updates.
|
||||
EDIT_ORG_SETTINGS = 'edit_org_settings'
|
||||
|
||||
|
||||
class RoleName(str, Enum):
|
||||
"""Role names used in the system."""
|
||||
|
||||
OWNER = 'owner'
|
||||
ADMIN = 'admin'
|
||||
MEMBER = 'member'
|
||||
|
||||
|
||||
# Permission mappings for each role
|
||||
ROLE_PERMISSIONS: dict[RoleName, frozenset[Permission]] = {
|
||||
RoleName.OWNER: frozenset(
|
||||
[
|
||||
# Settings (Full access)
|
||||
Permission.MANAGE_SECRETS,
|
||||
Permission.MANAGE_MCP,
|
||||
Permission.MANAGE_INTEGRATIONS,
|
||||
Permission.MANAGE_APPLICATION_SETTINGS,
|
||||
Permission.MANAGE_API_KEYS,
|
||||
Permission.VIEW_LLM_SETTINGS,
|
||||
Permission.EDIT_LLM_SETTINGS,
|
||||
Permission.VIEW_BILLING,
|
||||
Permission.ADD_CREDITS,
|
||||
# Organization Members
|
||||
Permission.INVITE_USER_TO_ORGANIZATION,
|
||||
Permission.CHANGE_USER_ROLE_MEMBER,
|
||||
Permission.CHANGE_USER_ROLE_ADMIN,
|
||||
Permission.CHANGE_USER_ROLE_OWNER,
|
||||
# Organization Management
|
||||
Permission.VIEW_ORG_SETTINGS,
|
||||
Permission.EDIT_ORG_SETTINGS,
|
||||
# Organization Management (Owner only)
|
||||
Permission.CHANGE_ORGANIZATION_NAME,
|
||||
Permission.DELETE_ORGANIZATION,
|
||||
]
|
||||
),
|
||||
RoleName.ADMIN: frozenset(
|
||||
[
|
||||
# Settings (Full access)
|
||||
Permission.MANAGE_SECRETS,
|
||||
Permission.MANAGE_MCP,
|
||||
Permission.MANAGE_INTEGRATIONS,
|
||||
Permission.MANAGE_APPLICATION_SETTINGS,
|
||||
Permission.MANAGE_API_KEYS,
|
||||
Permission.VIEW_LLM_SETTINGS,
|
||||
Permission.EDIT_LLM_SETTINGS,
|
||||
Permission.VIEW_BILLING,
|
||||
Permission.ADD_CREDITS,
|
||||
# Organization Members
|
||||
Permission.INVITE_USER_TO_ORGANIZATION,
|
||||
Permission.CHANGE_USER_ROLE_MEMBER,
|
||||
Permission.CHANGE_USER_ROLE_ADMIN,
|
||||
# Organization Management
|
||||
Permission.VIEW_ORG_SETTINGS,
|
||||
Permission.EDIT_ORG_SETTINGS,
|
||||
]
|
||||
),
|
||||
RoleName.MEMBER: frozenset(
|
||||
[
|
||||
# Settings (Full access)
|
||||
Permission.MANAGE_SECRETS,
|
||||
Permission.MANAGE_MCP,
|
||||
Permission.MANAGE_INTEGRATIONS,
|
||||
Permission.MANAGE_APPLICATION_SETTINGS,
|
||||
Permission.MANAGE_API_KEYS,
|
||||
# Settings (View only)
|
||||
Permission.VIEW_ORG_SETTINGS,
|
||||
Permission.VIEW_LLM_SETTINGS,
|
||||
]
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def get_user_org_role(user_id: str, org_id: UUID | None) -> Role | None:
|
||||
"""
|
||||
Get the user's role in an organization (synchronous version).
|
||||
|
||||
Args:
|
||||
user_id: User ID (string that will be converted to UUID)
|
||||
org_id: Organization ID, or None to use the user's current organization
|
||||
|
||||
Returns:
|
||||
Role object if user is a member, None otherwise
|
||||
"""
|
||||
from uuid import UUID as parse_uuid
|
||||
|
||||
if org_id is None:
|
||||
org_member = OrgMemberStore.get_org_member_for_current_org(parse_uuid(user_id))
|
||||
else:
|
||||
org_member = OrgMemberStore.get_org_member(org_id, parse_uuid(user_id))
|
||||
if not org_member:
|
||||
return None
|
||||
|
||||
return RoleStore.get_role_by_id(org_member.role_id)
|
||||
|
||||
|
||||
async def get_user_org_role_async(user_id: str, org_id: UUID | None) -> Role | None:
|
||||
"""
|
||||
Get the user's role in an organization (async version).
|
||||
|
||||
Args:
|
||||
user_id: User ID (string that will be converted to UUID)
|
||||
org_id: Organization ID, or None to use the user's current organization
|
||||
|
||||
Returns:
|
||||
Role object if user is a member, None otherwise
|
||||
"""
|
||||
from uuid import UUID as parse_uuid
|
||||
|
||||
if org_id is None:
|
||||
org_member = await OrgMemberStore.get_org_member_for_current_org_async(
|
||||
parse_uuid(user_id)
|
||||
)
|
||||
else:
|
||||
org_member = await OrgMemberStore.get_org_member_async(
|
||||
org_id, parse_uuid(user_id)
|
||||
)
|
||||
if not org_member:
|
||||
return None
|
||||
|
||||
return await RoleStore.get_role_by_id_async(org_member.role_id)
|
||||
|
||||
|
||||
def get_role_permissions(role_name: str) -> frozenset[Permission]:
|
||||
"""
|
||||
Get the permissions for a role.
|
||||
|
||||
Args:
|
||||
role_name: Name of the role
|
||||
|
||||
Returns:
|
||||
Set of permissions for the role
|
||||
"""
|
||||
try:
|
||||
role_enum = RoleName(role_name)
|
||||
return ROLE_PERMISSIONS.get(role_enum, frozenset())
|
||||
except ValueError:
|
||||
return frozenset()
|
||||
|
||||
|
||||
def has_permission(user_role: Role, permission: Permission) -> bool:
|
||||
"""
|
||||
Check if a role has a specific permission.
|
||||
|
||||
Args:
|
||||
user_role: User's Role object
|
||||
permission: Permission to check
|
||||
|
||||
Returns:
|
||||
True if the role has the permission
|
||||
"""
|
||||
permissions = get_role_permissions(user_role.name)
|
||||
return permission in permissions
|
||||
|
||||
|
||||
def require_permission(permission: Permission):
|
||||
"""
|
||||
Factory function that creates a dependency to require a specific permission.
|
||||
|
||||
This creates a FastAPI dependency that:
|
||||
1. Extracts org_id from the path parameter
|
||||
2. Gets the authenticated user_id
|
||||
3. Checks if the user has the required permission in the organization
|
||||
4. Returns the user_id if authorized, raises HTTPException otherwise
|
||||
|
||||
Usage:
|
||||
@router.get('/{org_id}/settings')
|
||||
async def get_settings(
|
||||
org_id: UUID,
|
||||
user_id: str = Depends(require_permission(Permission.VIEW_LLM_SETTINGS)),
|
||||
):
|
||||
...
|
||||
|
||||
Args:
|
||||
permission: The permission required to access the endpoint
|
||||
|
||||
Returns:
|
||||
Dependency function that validates permission and returns user_id
|
||||
"""
|
||||
|
||||
async def permission_checker(
|
||||
org_id: UUID | None = None,
|
||||
user_id: str | None = Depends(get_user_id),
|
||||
) -> str:
|
||||
if not user_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail='User not authenticated',
|
||||
)
|
||||
|
||||
user_role = await get_user_org_role_async(user_id, org_id)
|
||||
|
||||
if not user_role:
|
||||
logger.warning(
|
||||
'User not a member of organization',
|
||||
extra={'user_id': user_id, 'org_id': str(org_id)},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail='User is not a member of this organization',
|
||||
)
|
||||
|
||||
if not has_permission(user_role, permission):
|
||||
logger.warning(
|
||||
'Insufficient permissions',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
'user_role': user_role.name,
|
||||
'required_permission': permission.value,
|
||||
},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f'Requires {permission.value} permission',
|
||||
)
|
||||
|
||||
return user_id
|
||||
|
||||
return permission_checker
|
||||
@@ -1,11 +1,36 @@
|
||||
import asyncio
|
||||
|
||||
from pydantic import SecretStr
|
||||
from sqlalchemy import select
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.service_types import ProviderType
|
||||
from openhands.server.types import AppMode
|
||||
|
||||
|
||||
async def _user_has_gitlab_provider(user_id: str) -> bool:
|
||||
"""Check if the user has authenticated with GitLab.
|
||||
|
||||
Args:
|
||||
user_id: The Keycloak user ID
|
||||
|
||||
Returns:
|
||||
True if the user has a GitLab provider token, False otherwise
|
||||
"""
|
||||
# Lazy import to avoid circular dependency issues at module load time
|
||||
from storage.auth_tokens import AuthTokens
|
||||
from storage.database import a_session_maker
|
||||
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(AuthTokens).where(
|
||||
AuthTokens.keycloak_user_id == user_id,
|
||||
AuthTokens.identity_provider == ProviderType.GITLAB.value,
|
||||
)
|
||||
)
|
||||
return result.scalars().first() is not None
|
||||
|
||||
|
||||
def schedule_gitlab_repo_sync(
|
||||
user_id: str, keycloak_access_token: SecretStr | None = None
|
||||
) -> None:
|
||||
@@ -14,10 +39,20 @@ def schedule_gitlab_repo_sync(
|
||||
Because the outer call is already a background task, we instruct the service
|
||||
to store repository data synchronously (store_in_background=False) to avoid
|
||||
nested background tasks while still keeping the overall operation async.
|
||||
|
||||
The sync is only performed if the user has authenticated with GitLab.
|
||||
"""
|
||||
|
||||
async def _run():
|
||||
try:
|
||||
# Check if the user has a GitLab provider token before syncing
|
||||
if not await _user_has_gitlab_provider(user_id):
|
||||
logger.debug(
|
||||
'gitlab_repo_sync_skipped: user has no GitLab provider',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
return
|
||||
|
||||
# Lazy import to avoid circular dependency:
|
||||
# middleware -> gitlab_sync -> integrations.gitlab.gitlab_service
|
||||
# -> openhands.integrations.gitlab.gitlab_service -> get_impl
|
||||
|
||||
@@ -18,6 +18,7 @@ from openhands.core.logger import openhands_logger as logger
|
||||
class AssessmentResult:
|
||||
"""Result of a reCAPTCHA Enterprise assessment."""
|
||||
|
||||
name: str
|
||||
score: float
|
||||
valid: bool
|
||||
action_valid: bool
|
||||
@@ -63,6 +64,7 @@ class RecaptchaService:
|
||||
user_ip: str,
|
||||
user_agent: str,
|
||||
email: str | None = None,
|
||||
user_id: str | None = None,
|
||||
) -> AssessmentResult:
|
||||
"""Create a reCAPTCHA Enterprise assessment.
|
||||
|
||||
@@ -72,6 +74,7 @@ class RecaptchaService:
|
||||
user_ip: The user's IP address.
|
||||
user_agent: The user's browser user agent.
|
||||
email: Optional email for Account Defender hashing.
|
||||
user_id: Optional Keycloak user ID for logging correlation.
|
||||
|
||||
Returns:
|
||||
AssessmentResult with score, validity, and allowed status.
|
||||
@@ -100,6 +103,10 @@ class RecaptchaService:
|
||||
|
||||
response = self.client.create_assessment(request)
|
||||
|
||||
# Capture assessment name for potential annotation later
|
||||
# Format: projects/{project_id}/assessments/{assessment_id}
|
||||
assessment_name = response.name
|
||||
|
||||
token_properties = response.token_properties
|
||||
risk_analysis = response.risk_analysis
|
||||
|
||||
@@ -129,6 +136,7 @@ class RecaptchaService:
|
||||
logger.info(
|
||||
'recaptcha_assessment',
|
||||
extra={
|
||||
'assessment_name': assessment_name,
|
||||
'score': score,
|
||||
'valid': valid,
|
||||
'action_valid': action_valid,
|
||||
@@ -137,10 +145,13 @@ class RecaptchaService:
|
||||
'has_suspicious_labels': has_suspicious_labels,
|
||||
'allowed': allowed,
|
||||
'user_ip': user_ip,
|
||||
'user_id': user_id,
|
||||
'email': email,
|
||||
},
|
||||
)
|
||||
|
||||
return AssessmentResult(
|
||||
name=assessment_name,
|
||||
score=score,
|
||||
valid=valid,
|
||||
action_valid=action_valid,
|
||||
|
||||
@@ -216,9 +216,9 @@ class SaasUserAuth(UserAuth):
|
||||
|
||||
async def get_mcp_api_key(self) -> str:
|
||||
api_key_store = ApiKeyStore.get_instance()
|
||||
mcp_api_key = api_key_store.retrieve_mcp_api_key(self.user_id)
|
||||
mcp_api_key = await api_key_store.retrieve_mcp_api_key(self.user_id)
|
||||
if not mcp_api_key:
|
||||
mcp_api_key = api_key_store.create_api_key(
|
||||
mcp_api_key = await api_key_store.create_api_key(
|
||||
self.user_id, 'MCP_API_KEY', None
|
||||
)
|
||||
return mcp_api_key
|
||||
|
||||
@@ -49,6 +49,10 @@ from openhands.integrations.service_types import ProviderType
|
||||
from openhands.server.types import SessionExpiredError
|
||||
from openhands.utils.http_session import httpx_verify_option
|
||||
|
||||
# HTTP timeout for external IDP calls (in seconds)
|
||||
# This prevents indefinite blocking if an IDP is slow or unresponsive
|
||||
IDP_HTTP_TIMEOUT = 15.0
|
||||
|
||||
|
||||
def _before_sleep_callback(retry_state: RetryCallState) -> None:
|
||||
logger.info(f'Retry attempt {retry_state.attempt_number} for Keycloak operation')
|
||||
@@ -202,7 +206,9 @@ class TokenManager:
|
||||
access_token: str,
|
||||
idp: ProviderType,
|
||||
) -> dict[str, str | int]:
|
||||
async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
|
||||
async with httpx.AsyncClient(
|
||||
verify=httpx_verify_option(), timeout=IDP_HTTP_TIMEOUT
|
||||
) as client:
|
||||
base_url = KEYCLOAK_SERVER_URL_EXT if self.external else KEYCLOAK_SERVER_URL
|
||||
url = f'{base_url}/realms/{KEYCLOAK_REALM_NAME}/broker/{idp.value}/token'
|
||||
headers = {
|
||||
@@ -361,7 +367,9 @@ class TokenManager:
|
||||
'refresh_token': refresh_token,
|
||||
'grant_type': 'refresh_token',
|
||||
}
|
||||
async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
|
||||
async with httpx.AsyncClient(
|
||||
verify=httpx_verify_option(), timeout=IDP_HTTP_TIMEOUT
|
||||
) as client:
|
||||
response = await client.post(url, data=payload)
|
||||
response.raise_for_status()
|
||||
logger.info('Successfully refreshed GitHub token')
|
||||
@@ -387,7 +395,9 @@ class TokenManager:
|
||||
'refresh_token': refresh_token,
|
||||
'grant_type': 'refresh_token',
|
||||
}
|
||||
async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
|
||||
async with httpx.AsyncClient(
|
||||
verify=httpx_verify_option(), timeout=IDP_HTTP_TIMEOUT
|
||||
) as client:
|
||||
response = await client.post(url, data=payload)
|
||||
response.raise_for_status()
|
||||
logger.info('Successfully refreshed GitLab token')
|
||||
@@ -415,7 +425,9 @@ class TokenManager:
|
||||
'refresh_token': refresh_token,
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
|
||||
async with httpx.AsyncClient(
|
||||
verify=httpx_verify_option(), timeout=IDP_HTTP_TIMEOUT
|
||||
) as client:
|
||||
response = await client.post(url, data=data, headers=headers)
|
||||
response.raise_for_status()
|
||||
logger.info('Successfully refreshed Bitbucket token')
|
||||
|
||||
@@ -15,6 +15,11 @@ IS_FEATURE_ENV = (
|
||||
) # Does not include the staging deployment
|
||||
IS_LOCAL_ENV = bool(HOST == 'localhost')
|
||||
|
||||
# Role name constants
|
||||
ROLE_OWNER = 'owner'
|
||||
ROLE_ADMIN = 'admin'
|
||||
ROLE_MEMBER = 'member'
|
||||
|
||||
# Deprecated - billing margins are now handled internally in litellm
|
||||
DEFAULT_BILLING_MARGIN = float(os.environ.get('DEFAULT_BILLING_MARGIN', '1.0'))
|
||||
|
||||
@@ -25,7 +30,9 @@ PERSONAL_WORKSPACE_VERSION_TO_MODEL = {
|
||||
2: 'claude-3-7-sonnet-20250219',
|
||||
3: 'claude-sonnet-4-20250514',
|
||||
4: 'claude-sonnet-4-20250514',
|
||||
5: 'claude-opus-4-5-20251101',
|
||||
# Minimax is now the default as it gives results close to claude in terms of quality
|
||||
# but at a much lower price
|
||||
5: 'minimax-m2.5',
|
||||
}
|
||||
|
||||
LITELLM_DEFAULT_MODEL = os.getenv('LITELLM_DEFAULT_MODEL')
|
||||
@@ -54,7 +61,6 @@ SUBSCRIPTION_PRICE_DATA = {
|
||||
},
|
||||
}
|
||||
|
||||
DEFAULT_INITIAL_BUDGET = float(os.environ.get('DEFAULT_INITIAL_BUDGET', '10'))
|
||||
STRIPE_API_KEY = os.environ.get('STRIPE_API_KEY', None)
|
||||
REQUIRE_PAYMENT = os.environ.get('REQUIRE_PAYMENT', '0') in ('1', 'true')
|
||||
|
||||
|
||||
@@ -51,6 +51,14 @@ def custom_json_serializer(obj, **kwargs):
|
||||
obj['stack_info'] = format_stack(stack_info)
|
||||
|
||||
result = json.dumps(obj, **kwargs)
|
||||
|
||||
# Swap out newlines to make things easier to read. This will produce
|
||||
# invalid json but means we can have similar logs in local development
|
||||
# to production, making things easier to correlate. Obviously,
|
||||
# LOG_JSON_FOR_CONSOLE should not be used in production environments.
|
||||
if LOG_JSON_FOR_CONSOLE:
|
||||
result = result.replace('\\n', '\n')
|
||||
|
||||
return result
|
||||
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ from openhands.core.logger import openhands_logger as logger
|
||||
# NOTE: these details are specific to the MCP protocol
|
||||
class SaaSOpenHandsMCPConfig(OpenHandsMCPConfig):
|
||||
@staticmethod
|
||||
def create_default_mcp_server_config(
|
||||
async def create_default_mcp_server_config(
|
||||
host: str, config: 'OpenHandsConfig', user_id: str | None = None
|
||||
) -> tuple[MCPSHTTPServerConfig | None, list[MCPStdioServerConfig]]:
|
||||
"""
|
||||
@@ -38,10 +38,12 @@ class SaaSOpenHandsMCPConfig(OpenHandsMCPConfig):
|
||||
|
||||
api_key_store = ApiKeyStore.get_instance()
|
||||
if user_id:
|
||||
api_key = api_key_store.retrieve_mcp_api_key(user_id)
|
||||
api_key = await api_key_store.retrieve_mcp_api_key(user_id)
|
||||
|
||||
if not api_key:
|
||||
api_key = api_key_store.create_api_key(user_id, 'MCP_API_KEY', None)
|
||||
api_key = await api_key_store.create_api_key(
|
||||
user_id, 'MCP_API_KEY', None
|
||||
)
|
||||
|
||||
if not api_key:
|
||||
logger.error(f'Could not provision MCP API Key for user: {user_id}')
|
||||
|
||||
@@ -103,7 +103,7 @@ class SetAuthCookieMiddleware:
|
||||
keycloak_auth_cookie = request.cookies.get('keycloak_auth')
|
||||
auth_header = request.headers.get('Authorization')
|
||||
mcp_auth_header = request.headers.get('X-Session-API-Key')
|
||||
accepted_tos = False
|
||||
accepted_tos: bool | None = False
|
||||
if (
|
||||
keycloak_auth_cookie is None
|
||||
and (auth_header is None or not auth_header.startswith('Bearer '))
|
||||
@@ -160,10 +160,10 @@ class SetAuthCookieMiddleware:
|
||||
'/api/billing/customer-setup-success',
|
||||
'/api/billing/stripe-webhook',
|
||||
'/api/email/resend',
|
||||
'/api/organizations/members/invite/accept',
|
||||
'/oauth/device/authorize',
|
||||
'/oauth/device/token',
|
||||
'/api/v1/web-client/config',
|
||||
'/api/v1/webhooks/secrets',
|
||||
)
|
||||
if path in ignore_paths:
|
||||
return False
|
||||
@@ -174,6 +174,10 @@ class SetAuthCookieMiddleware:
|
||||
):
|
||||
return False
|
||||
|
||||
# Webhooks access is controlled using separate API keys
|
||||
if path.startswith('/api/v1/webhooks/'):
|
||||
return False
|
||||
|
||||
is_mcp = path.startswith('/mcp')
|
||||
is_api_route = path.startswith('/api')
|
||||
return is_api_route or is_mcp
|
||||
|
||||
@@ -2,66 +2,58 @@ from datetime import UTC, datetime
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from pydantic import BaseModel, field_validator
|
||||
from storage.api_key import ApiKey
|
||||
from storage.api_key_store import ApiKeyStore
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
from storage.org_member import OrgMember
|
||||
from storage.org_member_store import OrgMemberStore
|
||||
from storage.org_service import OrgService
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.server.user_auth import get_user_id
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
|
||||
|
||||
# Helper functions for BYOR API key management
|
||||
async def get_byor_key_from_db(user_id: str) -> str | None:
|
||||
"""Get the BYOR key from the database for a user."""
|
||||
|
||||
def _get_byor_key():
|
||||
user = UserStore.get_user_by_id(user_id)
|
||||
if not user:
|
||||
return None
|
||||
|
||||
current_org_id = user.current_org_id
|
||||
current_org_member: OrgMember = None
|
||||
for org_member in user.org_members:
|
||||
if org_member.org_id == current_org_id:
|
||||
current_org_member = org_member
|
||||
break
|
||||
if not current_org_member:
|
||||
return None
|
||||
if current_org_member.llm_api_key_for_byor:
|
||||
return current_org_member.llm_api_key_for_byor.get_secret_value()
|
||||
user = await UserStore.get_user_by_id_async(user_id)
|
||||
if not user:
|
||||
return None
|
||||
|
||||
return await call_sync_from_async(_get_byor_key)
|
||||
current_org_id = user.current_org_id
|
||||
current_org_member: OrgMember = None
|
||||
for org_member in user.org_members:
|
||||
if org_member.org_id == current_org_id:
|
||||
current_org_member = org_member
|
||||
break
|
||||
if not current_org_member:
|
||||
return None
|
||||
if current_org_member.llm_api_key_for_byor:
|
||||
return current_org_member.llm_api_key_for_byor.get_secret_value()
|
||||
return None
|
||||
|
||||
|
||||
async def store_byor_key_in_db(user_id: str, key: str) -> None:
|
||||
"""Store the BYOR key in the database for a user."""
|
||||
user = await UserStore.get_user_by_id_async(user_id)
|
||||
if not user:
|
||||
return None
|
||||
|
||||
def _update_user_settings():
|
||||
user = UserStore.get_user_by_id(user_id)
|
||||
if not user:
|
||||
return None
|
||||
|
||||
current_org_id = user.current_org_id
|
||||
current_org_member: OrgMember = None
|
||||
for org_member in user.org_members:
|
||||
if org_member.org_id == current_org_id:
|
||||
current_org_member = org_member
|
||||
break
|
||||
if not current_org_member:
|
||||
return None
|
||||
current_org_member.llm_api_key_for_byor = key
|
||||
OrgMemberStore.update_org_member(current_org_member)
|
||||
|
||||
await call_sync_from_async(_update_user_settings)
|
||||
current_org_id = user.current_org_id
|
||||
current_org_member: OrgMember = None
|
||||
for org_member in user.org_members:
|
||||
if org_member.org_id == current_org_id:
|
||||
current_org_member = org_member
|
||||
break
|
||||
if not current_org_member:
|
||||
return None
|
||||
current_org_member.llm_api_key_for_byor = key
|
||||
OrgMemberStore.update_org_member(current_org_member)
|
||||
|
||||
|
||||
async def generate_byor_key(user_id: str) -> str | None:
|
||||
"""Generate a new BYOR key for a user."""
|
||||
|
||||
try:
|
||||
user = await UserStore.get_user_by_id_async(user_id)
|
||||
if not user:
|
||||
@@ -144,9 +136,9 @@ class ApiKeyCreate(BaseModel):
|
||||
class ApiKeyResponse(BaseModel):
|
||||
id: int
|
||||
name: str | None = None
|
||||
created_at: str
|
||||
last_used_at: str | None = None
|
||||
expires_at: str | None = None
|
||||
created_at: datetime
|
||||
last_used_at: datetime | None = None
|
||||
expires_at: datetime | None = None
|
||||
|
||||
|
||||
class ApiKeyCreateResponse(ApiKeyResponse):
|
||||
@@ -157,58 +149,78 @@ class LlmApiKeyResponse(BaseModel):
|
||||
key: str | None
|
||||
|
||||
|
||||
@api_router.post('', response_model=ApiKeyCreateResponse)
|
||||
async def create_api_key(key_data: ApiKeyCreate, user_id: str = Depends(get_user_id)):
|
||||
class ByorPermittedResponse(BaseModel):
|
||||
permitted: bool
|
||||
|
||||
|
||||
class MessageResponse(BaseModel):
|
||||
message: str
|
||||
|
||||
|
||||
def api_key_to_response(key: ApiKey) -> ApiKeyResponse:
|
||||
"""Convert an ApiKey model to an ApiKeyResponse."""
|
||||
return ApiKeyResponse(
|
||||
id=key.id,
|
||||
name=key.name,
|
||||
created_at=key.created_at,
|
||||
last_used_at=key.last_used_at,
|
||||
expires_at=key.expires_at,
|
||||
)
|
||||
|
||||
|
||||
@api_router.get('/llm/byor/permitted', tags=['Keys'])
|
||||
async def check_byor_permitted(
|
||||
user_id: str = Depends(get_user_id),
|
||||
) -> ByorPermittedResponse:
|
||||
"""Check if BYOR key export is permitted for the user's current org."""
|
||||
try:
|
||||
permitted = await OrgService.check_byor_export_enabled(user_id)
|
||||
return ByorPermittedResponse(permitted=permitted)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'Error checking BYOR export permission', extra={'error': str(e)}
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to check BYOR export permission',
|
||||
)
|
||||
|
||||
|
||||
@api_router.post('', tags=['Keys'])
|
||||
async def create_api_key(
|
||||
key_data: ApiKeyCreate, user_id: str = Depends(get_user_id)
|
||||
) -> ApiKeyCreateResponse:
|
||||
"""Create a new API key for the authenticated user."""
|
||||
try:
|
||||
api_key = api_key_store.create_api_key(
|
||||
api_key = await api_key_store.create_api_key(
|
||||
user_id, key_data.name, key_data.expires_at
|
||||
)
|
||||
# Get the created key details
|
||||
keys = api_key_store.list_api_keys(user_id)
|
||||
keys = await api_key_store.list_api_keys(user_id)
|
||||
for key in keys:
|
||||
if key['name'] == key_data.name:
|
||||
return {
|
||||
**key,
|
||||
'key': api_key,
|
||||
'created_at': (
|
||||
key['created_at'].isoformat() if key['created_at'] else None
|
||||
),
|
||||
'last_used_at': (
|
||||
key['last_used_at'].isoformat() if key['last_used_at'] else None
|
||||
),
|
||||
'expires_at': (
|
||||
key['expires_at'].isoformat() if key['expires_at'] else None
|
||||
),
|
||||
}
|
||||
if key.name == key_data.name:
|
||||
return ApiKeyCreateResponse(
|
||||
id=key.id,
|
||||
name=key.name,
|
||||
key=api_key,
|
||||
created_at=key.created_at,
|
||||
last_used_at=key.last_used_at,
|
||||
expires_at=key.expires_at,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception('Error creating API key')
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to create API key',
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to create API key',
|
||||
)
|
||||
|
||||
|
||||
@api_router.get('', response_model=list[ApiKeyResponse])
|
||||
async def list_api_keys(user_id: str = Depends(get_user_id)):
|
||||
@api_router.get('', tags=['Keys'])
|
||||
async def list_api_keys(user_id: str = Depends(get_user_id)) -> list[ApiKeyResponse]:
|
||||
"""List all API keys for the authenticated user."""
|
||||
try:
|
||||
keys = api_key_store.list_api_keys(user_id)
|
||||
return [
|
||||
{
|
||||
**key,
|
||||
'created_at': (
|
||||
key['created_at'].isoformat() if key['created_at'] else None
|
||||
),
|
||||
'last_used_at': (
|
||||
key['last_used_at'].isoformat() if key['last_used_at'] else None
|
||||
),
|
||||
'expires_at': (
|
||||
key['expires_at'].isoformat() if key['expires_at'] else None
|
||||
),
|
||||
}
|
||||
for key in keys
|
||||
]
|
||||
keys = await api_key_store.list_api_keys(user_id)
|
||||
return [api_key_to_response(key) for key in keys]
|
||||
except Exception:
|
||||
logger.exception('Error listing API keys')
|
||||
raise HTTPException(
|
||||
@@ -217,16 +229,18 @@ async def list_api_keys(user_id: str = Depends(get_user_id)):
|
||||
)
|
||||
|
||||
|
||||
@api_router.delete('/{key_id}')
|
||||
async def delete_api_key(key_id: int, user_id: str = Depends(get_user_id)):
|
||||
@api_router.delete('/{key_id}', tags=['Keys'])
|
||||
async def delete_api_key(
|
||||
key_id: int, user_id: str = Depends(get_user_id)
|
||||
) -> MessageResponse:
|
||||
"""Delete an API key."""
|
||||
try:
|
||||
# First, verify the key belongs to the user
|
||||
keys = api_key_store.list_api_keys(user_id)
|
||||
keys = await api_key_store.list_api_keys(user_id)
|
||||
key_to_delete = None
|
||||
|
||||
for key in keys:
|
||||
if key['id'] == key_id:
|
||||
if key.id == key_id:
|
||||
key_to_delete = key
|
||||
break
|
||||
|
||||
@@ -244,7 +258,7 @@ async def delete_api_key(key_id: int, user_id: str = Depends(get_user_id)):
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to delete API key',
|
||||
)
|
||||
return {'message': 'API key deleted successfully'}
|
||||
return MessageResponse(message='API key deleted successfully')
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception:
|
||||
@@ -255,22 +269,33 @@ async def delete_api_key(key_id: int, user_id: str = Depends(get_user_id)):
|
||||
)
|
||||
|
||||
|
||||
@api_router.get('/llm/byor', response_model=LlmApiKeyResponse)
|
||||
async def get_llm_api_key_for_byor(user_id: str = Depends(get_user_id)):
|
||||
@api_router.get('/llm/byor', tags=['Keys'])
|
||||
async def get_llm_api_key_for_byor(
|
||||
user_id: str = Depends(get_user_id),
|
||||
) -> LlmApiKeyResponse:
|
||||
"""Get the LLM API key for BYOR (Bring Your Own Runtime) for the authenticated user.
|
||||
|
||||
This endpoint validates that the key exists in LiteLLM before returning it.
|
||||
If validation fails, it automatically generates a new key to ensure users
|
||||
always receive a working key.
|
||||
|
||||
Returns 402 Payment Required if BYOR export is not enabled for the user's org.
|
||||
"""
|
||||
try:
|
||||
# Check if BYOR export is enabled for the user's org
|
||||
if not await OrgService.check_byor_export_enabled(user_id):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||
detail='BYOR key export is not enabled. Purchase credits to enable this feature.',
|
||||
)
|
||||
|
||||
# Check if the BYOR key exists in the database
|
||||
byor_key = await get_byor_key_from_db(user_id)
|
||||
if byor_key:
|
||||
# Validate that the key is actually registered in LiteLLM
|
||||
is_valid = await LiteLlmManager.verify_key(byor_key, user_id)
|
||||
if is_valid:
|
||||
return {'key': byor_key}
|
||||
return LlmApiKeyResponse(key=byor_key)
|
||||
else:
|
||||
# Key exists in DB but is invalid in LiteLLM - regenerate it
|
||||
logger.warning(
|
||||
@@ -295,7 +320,7 @@ async def get_llm_api_key_for_byor(user_id: str = Depends(get_user_id)):
|
||||
'Successfully generated and stored new BYOR key',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
return {'key': key}
|
||||
return LlmApiKeyResponse(key=key)
|
||||
else:
|
||||
logger.error(
|
||||
'Failed to generate new BYOR LLM API key',
|
||||
@@ -317,12 +342,24 @@ async def get_llm_api_key_for_byor(user_id: str = Depends(get_user_id)):
|
||||
)
|
||||
|
||||
|
||||
@api_router.post('/llm/byor/refresh', response_model=LlmApiKeyResponse)
|
||||
async def refresh_llm_api_key_for_byor(user_id: str = Depends(get_user_id)):
|
||||
"""Refresh the LLM API key for BYOR (Bring Your Own Runtime) for the authenticated user."""
|
||||
@api_router.post('/llm/byor/refresh', tags=['Keys'])
|
||||
async def refresh_llm_api_key_for_byor(
|
||||
user_id: str = Depends(get_user_id),
|
||||
) -> LlmApiKeyResponse:
|
||||
"""Refresh the LLM API key for BYOR (Bring Your Own Runtime) for the authenticated user.
|
||||
|
||||
Returns 402 Payment Required if BYOR export is not enabled for the user's org.
|
||||
"""
|
||||
logger.info('Starting BYOR LLM API key refresh', extra={'user_id': user_id})
|
||||
|
||||
try:
|
||||
# Check if BYOR export is enabled for the user's org
|
||||
if not await OrgService.check_byor_export_enabled(user_id):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||
detail='BYOR key export is not enabled. Purchase credits to enable this feature.',
|
||||
)
|
||||
|
||||
# Get the existing BYOR key from the database
|
||||
existing_byor_key = await get_byor_key_from_db(user_id)
|
||||
|
||||
@@ -361,7 +398,7 @@ async def refresh_llm_api_key_for_byor(user_id: str = Depends(get_user_id)):
|
||||
'BYOR LLM API key refresh completed successfully',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
return {'key': key}
|
||||
return LlmApiKeyResponse(key=key)
|
||||
except HTTPException as he:
|
||||
logger.error(
|
||||
'HTTP exception during BYOR LLM API key refresh',
|
||||
|
||||
@@ -5,6 +5,7 @@ import warnings
|
||||
from datetime import datetime, timezone
|
||||
from typing import Annotated, Literal, Optional
|
||||
from urllib.parse import quote
|
||||
from uuid import UUID as parse_uuid
|
||||
|
||||
import posthog
|
||||
from fastapi import APIRouter, Header, HTTPException, Request, Response, status
|
||||
@@ -26,6 +27,13 @@ from server.auth.token_manager import TokenManager
|
||||
from server.config import sign_token
|
||||
from server.constants import IS_FEATURE_ENV
|
||||
from server.routes.event_webhook import _get_session_api_key, _get_user_id
|
||||
from server.services.org_invitation_service import (
|
||||
EmailMismatchError,
|
||||
InvitationExpiredError,
|
||||
InvitationInvalidError,
|
||||
OrgInvitationService,
|
||||
UserAlreadyMemberError,
|
||||
)
|
||||
from storage.database import session_maker
|
||||
from storage.user import User
|
||||
from storage.user_store import UserStore
|
||||
@@ -104,22 +112,40 @@ def get_cookie_samesite(request: Request) -> Literal['lax', 'strict']:
|
||||
)
|
||||
|
||||
|
||||
def _extract_oauth_state(state: str | None) -> tuple[str, str | None, str | None]:
|
||||
"""Extract redirect URL, reCAPTCHA token, and invitation token from OAuth state.
|
||||
|
||||
Returns:
|
||||
Tuple of (redirect_url, recaptcha_token, invitation_token).
|
||||
Tokens may be None.
|
||||
"""
|
||||
if not state:
|
||||
return '', None, None
|
||||
|
||||
try:
|
||||
# Try to decode as JSON (new format with reCAPTCHA and/or invitation)
|
||||
state_data = json.loads(base64.urlsafe_b64decode(state.encode()).decode())
|
||||
return (
|
||||
state_data.get('redirect_url', ''),
|
||||
state_data.get('recaptcha_token'),
|
||||
state_data.get('invitation_token'),
|
||||
)
|
||||
except Exception:
|
||||
# Old format - state is just the redirect URL
|
||||
return state, None, None
|
||||
|
||||
|
||||
# Keep alias for backward compatibility
|
||||
def _extract_recaptcha_state(state: str | None) -> tuple[str, str | None]:
|
||||
"""Extract redirect URL and reCAPTCHA token from OAuth state.
|
||||
|
||||
Deprecated: Use _extract_oauth_state instead.
|
||||
|
||||
Returns:
|
||||
Tuple of (redirect_url, recaptcha_token). Token may be None.
|
||||
"""
|
||||
if not state:
|
||||
return '', None
|
||||
|
||||
try:
|
||||
# Try to decode as JSON (new format with reCAPTCHA)
|
||||
state_data = json.loads(base64.urlsafe_b64decode(state.encode()).decode())
|
||||
return state_data.get('redirect_url', ''), state_data.get('recaptcha_token')
|
||||
except Exception:
|
||||
# Old format - state is just the redirect URL
|
||||
return state, None
|
||||
redirect_url, recaptcha_token, _ = _extract_oauth_state(state)
|
||||
return redirect_url, recaptcha_token
|
||||
|
||||
|
||||
@oauth_router.get('/keycloak/callback')
|
||||
@@ -130,8 +156,8 @@ async def keycloak_callback(
|
||||
error: Optional[str] = None,
|
||||
error_description: Optional[str] = None,
|
||||
):
|
||||
# Extract redirect URL and reCAPTCHA token from state
|
||||
redirect_url, recaptcha_token = _extract_recaptcha_state(state)
|
||||
# Extract redirect URL, reCAPTCHA token, and invitation token from state
|
||||
redirect_url, recaptcha_token, invitation_token = _extract_oauth_state(state)
|
||||
if not redirect_url:
|
||||
redirect_url = str(request.base_url)
|
||||
|
||||
@@ -179,6 +205,9 @@ async def keycloak_callback(
|
||||
user = await UserStore.get_user_by_id_async(user_id)
|
||||
if not user:
|
||||
user = await UserStore.create_user(user_id, user_info)
|
||||
else:
|
||||
# Existing user — gradually backfill contact_name if it still has a username-style value
|
||||
await UserStore.backfill_contact_name(user_id, user_info)
|
||||
|
||||
if not user:
|
||||
logger.error(f'Failed to authenticate user {user_info["preferred_username"]}')
|
||||
@@ -219,6 +248,7 @@ async def keycloak_callback(
|
||||
user_ip=user_ip,
|
||||
user_agent=user_agent,
|
||||
email=email,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
if not result.allowed:
|
||||
@@ -298,8 +328,13 @@ async def keycloak_callback(
|
||||
from server.routes.email import verify_email
|
||||
|
||||
await verify_email(request=request, user_id=user_id, is_auth_flow=True)
|
||||
redirect_url = f'{request.base_url}login?email_verification_required=true&user_id={user_id}'
|
||||
response = RedirectResponse(redirect_url, status_code=302)
|
||||
verification_redirect_url = f'{request.base_url}login?email_verification_required=true&user_id={user_id}'
|
||||
# Preserve invitation token so it can be included in OAuth state after verification
|
||||
if invitation_token:
|
||||
verification_redirect_url = (
|
||||
f'{verification_redirect_url}&invitation_token={invitation_token}'
|
||||
)
|
||||
response = RedirectResponse(verification_redirect_url, status_code=302)
|
||||
return response
|
||||
|
||||
# default to github IDP for now.
|
||||
@@ -377,14 +412,90 @@ async def keycloak_callback(
|
||||
)
|
||||
|
||||
has_accepted_tos = user.accepted_tos is not None
|
||||
|
||||
# Process invitation token if present (after email verification but before TOS)
|
||||
if invitation_token:
|
||||
try:
|
||||
logger.info(
|
||||
'Processing invitation token during auth callback',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'invitation_token_prefix': invitation_token[:10] + '...',
|
||||
},
|
||||
)
|
||||
|
||||
await OrgInvitationService.accept_invitation(
|
||||
invitation_token, parse_uuid(user_id)
|
||||
)
|
||||
logger.info(
|
||||
'Invitation accepted during auth callback',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
|
||||
except InvitationExpiredError:
|
||||
logger.warning(
|
||||
'Invitation expired during auth callback',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
# Add query param to redirect URL
|
||||
if '?' in redirect_url:
|
||||
redirect_url = f'{redirect_url}&invitation_expired=true'
|
||||
else:
|
||||
redirect_url = f'{redirect_url}?invitation_expired=true'
|
||||
|
||||
except InvitationInvalidError as e:
|
||||
logger.warning(
|
||||
'Invalid invitation during auth callback',
|
||||
extra={'user_id': user_id, 'error': str(e)},
|
||||
)
|
||||
if '?' in redirect_url:
|
||||
redirect_url = f'{redirect_url}&invitation_invalid=true'
|
||||
else:
|
||||
redirect_url = f'{redirect_url}?invitation_invalid=true'
|
||||
|
||||
except UserAlreadyMemberError:
|
||||
logger.info(
|
||||
'User already member during invitation acceptance',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
if '?' in redirect_url:
|
||||
redirect_url = f'{redirect_url}&already_member=true'
|
||||
else:
|
||||
redirect_url = f'{redirect_url}?already_member=true'
|
||||
|
||||
except EmailMismatchError as e:
|
||||
logger.warning(
|
||||
'Email mismatch during auth callback invitation acceptance',
|
||||
extra={'user_id': user_id, 'error': str(e)},
|
||||
)
|
||||
if '?' in redirect_url:
|
||||
redirect_url = f'{redirect_url}&email_mismatch=true'
|
||||
else:
|
||||
redirect_url = f'{redirect_url}?email_mismatch=true'
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'Unexpected error processing invitation during auth callback',
|
||||
extra={'user_id': user_id, 'error': str(e)},
|
||||
)
|
||||
# Don't fail the login if invitation processing fails
|
||||
if '?' in redirect_url:
|
||||
redirect_url = f'{redirect_url}&invitation_error=true'
|
||||
else:
|
||||
redirect_url = f'{redirect_url}?invitation_error=true'
|
||||
|
||||
# If the user hasn't accepted the TOS, redirect to the TOS page
|
||||
if not has_accepted_tos:
|
||||
encoded_redirect_url = quote(redirect_url, safe='')
|
||||
tos_redirect_url = (
|
||||
f'{request.base_url}accept-tos?redirect_url={encoded_redirect_url}'
|
||||
)
|
||||
if invitation_token:
|
||||
tos_redirect_url = f'{tos_redirect_url}&invitation_success=true'
|
||||
response = RedirectResponse(tos_redirect_url, status_code=302)
|
||||
else:
|
||||
if invitation_token:
|
||||
redirect_url = f'{redirect_url}&invitation_success=true'
|
||||
response = RedirectResponse(redirect_url, status_code=302)
|
||||
|
||||
set_response_cookie(
|
||||
|
||||
@@ -9,14 +9,13 @@ from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from fastapi.responses import RedirectResponse
|
||||
from integrations import stripe_service
|
||||
from pydantic import BaseModel
|
||||
from server.constants import (
|
||||
STRIPE_API_KEY,
|
||||
)
|
||||
from server.constants import STRIPE_API_KEY
|
||||
from server.logger import logger
|
||||
from starlette.datastructures import URL
|
||||
from storage.billing_session import BillingSession
|
||||
from storage.database import session_maker
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
from storage.org import Org
|
||||
from storage.subscription_access import SubscriptionAccess
|
||||
from storage.user_store import UserStore
|
||||
|
||||
@@ -94,9 +93,9 @@ async def get_credits(user_id: str = Depends(get_user_id)) -> GetCreditsResponse
|
||||
user_team_info = await LiteLlmManager.get_user_team_info(
|
||||
user_id, str(user.current_org_id)
|
||||
)
|
||||
# Update to use calculate_credits
|
||||
spend = user_team_info.get('spend', 0)
|
||||
max_budget = (user_team_info.get('litellm_budget_table') or {}).get('max_budget', 0)
|
||||
max_budget, spend = LiteLlmManager.get_budget_from_team_info(
|
||||
user_team_info, user_id, str(user.current_org_id)
|
||||
)
|
||||
credits = max(max_budget - spend, 0)
|
||||
return GetCreditsResponse(credits=Decimal('{:.2f}'.format(credits)))
|
||||
|
||||
@@ -148,7 +147,7 @@ async def create_customer_setup_session(
|
||||
customer=customer_info['customer_id'],
|
||||
mode='setup',
|
||||
payment_method_types=['card'],
|
||||
success_url=f'{base_url}?free_credits=success',
|
||||
success_url=f'{base_url}?setup=success',
|
||||
cancel_url=f'{base_url}',
|
||||
)
|
||||
return CreateBillingSessionResponse(redirect_url=checkout_session.url)
|
||||
@@ -250,15 +249,21 @@ async def success_callback(session_id: str, request: Request):
|
||||
)
|
||||
amount_subtotal = stripe_session.amount_subtotal or 0
|
||||
add_credits = amount_subtotal / 100
|
||||
max_budget = (user_team_info.get('litellm_budget_table') or {}).get(
|
||||
'max_budget', 0
|
||||
max_budget, _ = LiteLlmManager.get_budget_from_team_info(
|
||||
user_team_info, billing_session.user_id, str(user.current_org_id)
|
||||
)
|
||||
|
||||
org = session.query(Org).filter(Org.id == user.current_org_id).first()
|
||||
new_max_budget = max_budget + add_credits
|
||||
|
||||
await LiteLlmManager.update_team_and_users_budget(
|
||||
str(user.current_org_id), new_max_budget
|
||||
)
|
||||
|
||||
# Enable BYOR export for the org now that they've purchased credits
|
||||
if org:
|
||||
org.byor_export_enabled = True
|
||||
|
||||
# Store transaction status
|
||||
billing_session.status = 'completed'
|
||||
billing_session.price = add_credits
|
||||
|
||||
@@ -4,7 +4,7 @@ import json
|
||||
import os
|
||||
import re
|
||||
import uuid
|
||||
from urllib.parse import urlparse
|
||||
from urllib.parse import urlencode, urlparse
|
||||
|
||||
import requests
|
||||
from fastapi import APIRouter, BackgroundTasks, Header, HTTPException, Request, status
|
||||
@@ -371,9 +371,7 @@ async def create_jira_workspace(request: Request, workspace_data: JiraWorkspaceC
|
||||
'prompt': 'consent',
|
||||
}
|
||||
|
||||
auth_url = (
|
||||
f"{JIRA_AUTH_URL}?{'&'.join([f'{k}={v}' for k, v in auth_params.items()])}"
|
||||
)
|
||||
auth_url = f'{JIRA_AUTH_URL}?{urlencode(auth_params)}'
|
||||
|
||||
return JSONResponse(
|
||||
content={
|
||||
@@ -432,9 +430,7 @@ async def create_workspace_link(request: Request, link_data: JiraLinkCreate):
|
||||
'response_type': 'code',
|
||||
'prompt': 'consent',
|
||||
}
|
||||
auth_url = (
|
||||
f"{JIRA_AUTH_URL}?{'&'.join([f'{k}={v}' for k, v in auth_params.items()])}"
|
||||
)
|
||||
auth_url = f'{JIRA_AUTH_URL}?{urlencode(auth_params)}'
|
||||
|
||||
return JSONResponse(
|
||||
content={
|
||||
|
||||
@@ -2,7 +2,7 @@ import json
|
||||
import os
|
||||
import re
|
||||
import uuid
|
||||
from urllib.parse import urlparse
|
||||
from urllib.parse import urlencode, urlparse
|
||||
|
||||
import requests
|
||||
from fastapi import (
|
||||
@@ -316,7 +316,7 @@ async def create_jira_dc_workspace(
|
||||
'response_type': 'code',
|
||||
}
|
||||
|
||||
auth_url = f"{JIRA_DC_AUTH_URL}?{'&'.join([f'{k}={v}' for k, v in auth_params.items()])}"
|
||||
auth_url = f'{JIRA_DC_AUTH_URL}?{urlencode(auth_params)}'
|
||||
|
||||
return JSONResponse(
|
||||
content={
|
||||
@@ -436,7 +436,7 @@ async def create_workspace_link(request: Request, link_data: JiraDcLinkCreate):
|
||||
'state': state,
|
||||
'response_type': 'code',
|
||||
}
|
||||
auth_url = f"{JIRA_DC_AUTH_URL}?{'&'.join([f'{k}={v}' for k, v in auth_params.items()])}"
|
||||
auth_url = f'{JIRA_DC_AUTH_URL}?{urlencode(auth_params)}'
|
||||
|
||||
return JSONResponse(
|
||||
content={
|
||||
|
||||
@@ -272,7 +272,7 @@ async def device_verification_authenticated(
|
||||
try:
|
||||
# Create a unique API key for this device using user_code in the name
|
||||
device_key_name = f'{API_KEY_NAME} ({user_code})'
|
||||
api_key_store.create_api_key(
|
||||
await api_key_store.create_api_key(
|
||||
user_id,
|
||||
name=device_key_name,
|
||||
expires_at=datetime.now(UTC) + KEY_EXPIRATION_TIME,
|
||||
|
||||
122
enterprise/server/routes/org_invitation_models.py
Normal file
122
enterprise/server/routes/org_invitation_models.py
Normal file
@@ -0,0 +1,122 @@
|
||||
"""
|
||||
Pydantic models and custom exceptions for organization invitations.
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel, EmailStr
|
||||
from storage.org_invitation import OrgInvitation
|
||||
from storage.role_store import RoleStore
|
||||
|
||||
|
||||
class InvitationError(Exception):
|
||||
"""Base exception for invitation errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class InvitationAlreadyExistsError(InvitationError):
|
||||
"""Raised when a pending invitation already exists for the email."""
|
||||
|
||||
def __init__(
|
||||
self, message: str = 'A pending invitation already exists for this email'
|
||||
):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class UserAlreadyMemberError(InvitationError):
|
||||
"""Raised when the user is already a member of the organization."""
|
||||
|
||||
def __init__(self, message: str = 'User is already a member of this organization'):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class InvitationExpiredError(InvitationError):
|
||||
"""Raised when the invitation has expired."""
|
||||
|
||||
def __init__(self, message: str = 'Invitation has expired'):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class InvitationInvalidError(InvitationError):
|
||||
"""Raised when the invitation is invalid or revoked."""
|
||||
|
||||
def __init__(self, message: str = 'Invitation is no longer valid'):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class InsufficientPermissionError(InvitationError):
|
||||
"""Raised when the user lacks permission to perform the action."""
|
||||
|
||||
def __init__(self, message: str = 'Insufficient permission'):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class EmailMismatchError(InvitationError):
|
||||
"""Raised when the accepting user's email doesn't match the invitation email."""
|
||||
|
||||
def __init__(self, message: str = 'Your email does not match the invitation'):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class InvitationCreate(BaseModel):
|
||||
"""Request model for creating invitation(s)."""
|
||||
|
||||
emails: list[EmailStr]
|
||||
role: str = 'member' # Default to member role
|
||||
|
||||
|
||||
class InvitationResponse(BaseModel):
|
||||
"""Response model for invitation details."""
|
||||
|
||||
id: int
|
||||
email: str
|
||||
role: str
|
||||
status: str
|
||||
created_at: str
|
||||
expires_at: str
|
||||
inviter_email: str | None = None
|
||||
|
||||
@classmethod
|
||||
def from_invitation(
|
||||
cls,
|
||||
invitation: OrgInvitation,
|
||||
inviter_email: str | None = None,
|
||||
) -> 'InvitationResponse':
|
||||
"""Create an InvitationResponse from an OrgInvitation entity.
|
||||
|
||||
Args:
|
||||
invitation: The invitation entity to convert
|
||||
inviter_email: Optional email of the inviter
|
||||
|
||||
Returns:
|
||||
InvitationResponse: The response model instance
|
||||
"""
|
||||
role_name = ''
|
||||
if invitation.role:
|
||||
role_name = invitation.role.name
|
||||
elif invitation.role_id:
|
||||
role = RoleStore.get_role_by_id(invitation.role_id)
|
||||
role_name = role.name if role else ''
|
||||
|
||||
return cls(
|
||||
id=invitation.id,
|
||||
email=invitation.email,
|
||||
role=role_name,
|
||||
status=invitation.status,
|
||||
created_at=invitation.created_at.isoformat(),
|
||||
expires_at=invitation.expires_at.isoformat(),
|
||||
inviter_email=inviter_email,
|
||||
)
|
||||
|
||||
|
||||
class InvitationFailure(BaseModel):
|
||||
"""Response model for a failed invitation."""
|
||||
|
||||
email: str
|
||||
error: str
|
||||
|
||||
|
||||
class BatchInvitationResponse(BaseModel):
|
||||
"""Response model for batch invitation creation."""
|
||||
|
||||
successful: list[InvitationResponse]
|
||||
failed: list[InvitationFailure]
|
||||
226
enterprise/server/routes/org_invitations.py
Normal file
226
enterprise/server/routes/org_invitations.py
Normal file
@@ -0,0 +1,226 @@
|
||||
"""API routes for organization invitations."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from fastapi.responses import RedirectResponse
|
||||
from server.routes.org_invitation_models import (
|
||||
BatchInvitationResponse,
|
||||
EmailMismatchError,
|
||||
InsufficientPermissionError,
|
||||
InvitationCreate,
|
||||
InvitationExpiredError,
|
||||
InvitationFailure,
|
||||
InvitationInvalidError,
|
||||
InvitationResponse,
|
||||
UserAlreadyMemberError,
|
||||
)
|
||||
from server.services.org_invitation_service import OrgInvitationService
|
||||
from server.utils.rate_limit_utils import check_rate_limit_by_user_id
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.server.user_auth import get_user_id
|
||||
from openhands.server.user_auth.user_auth import get_user_auth
|
||||
|
||||
# Router for invitation operations on an organization (requires org_id)
|
||||
invitation_router = APIRouter(prefix='/api/organizations/{org_id}/members')
|
||||
|
||||
# Router for accepting invitations (no org_id required)
|
||||
accept_router = APIRouter(prefix='/api/organizations/members/invite')
|
||||
|
||||
|
||||
@invitation_router.post(
|
||||
'/invite',
|
||||
response_model=BatchInvitationResponse,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
)
|
||||
async def create_invitation(
|
||||
org_id: UUID,
|
||||
invitation_data: InvitationCreate,
|
||||
request: Request,
|
||||
user_id: str = Depends(get_user_id),
|
||||
):
|
||||
"""Create organization invitations for multiple email addresses.
|
||||
|
||||
Sends emails to invitees with secure links to join the organization.
|
||||
Supports batch invitations - some may succeed while others fail.
|
||||
|
||||
Permission rules:
|
||||
- Only owners and admins can create invitations
|
||||
- Admins can only invite with 'member' or 'admin' role (not 'owner')
|
||||
- Owners can invite with any role
|
||||
|
||||
Args:
|
||||
org_id: Organization UUID
|
||||
invitation_data: Invitation details (emails array, role)
|
||||
request: FastAPI request
|
||||
user_id: Authenticated user ID (from dependency)
|
||||
|
||||
Returns:
|
||||
BatchInvitationResponse: Lists of successful and failed invitations
|
||||
|
||||
Raises:
|
||||
HTTPException 400: Invalid role or organization not found
|
||||
HTTPException 403: User lacks permission to invite
|
||||
HTTPException 429: Rate limit exceeded
|
||||
"""
|
||||
# Rate limit: 10 invitations per minute per user (6 seconds between requests)
|
||||
await check_rate_limit_by_user_id(
|
||||
request=request,
|
||||
key_prefix='org_invitation_create',
|
||||
user_id=user_id,
|
||||
user_rate_limit_seconds=6,
|
||||
)
|
||||
|
||||
try:
|
||||
successful, failed = await OrgInvitationService.create_invitations_batch(
|
||||
org_id=org_id,
|
||||
emails=[str(email) for email in invitation_data.emails],
|
||||
role_name=invitation_data.role,
|
||||
inviter_id=UUID(user_id),
|
||||
)
|
||||
|
||||
logger.info(
|
||||
'Batch organization invitations created',
|
||||
extra={
|
||||
'org_id': str(org_id),
|
||||
'total_emails': len(invitation_data.emails),
|
||||
'successful': len(successful),
|
||||
'failed': len(failed),
|
||||
'inviter_id': user_id,
|
||||
},
|
||||
)
|
||||
|
||||
return BatchInvitationResponse(
|
||||
successful=[InvitationResponse.from_invitation(inv) for inv in successful],
|
||||
failed=[
|
||||
InvitationFailure(email=email, error=error) for email, error in failed
|
||||
],
|
||||
)
|
||||
|
||||
except InsufficientPermissionError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=str(e),
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'Unexpected error creating batch invitations',
|
||||
extra={'org_id': str(org_id), 'error': str(e)},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='An unexpected error occurred',
|
||||
)
|
||||
|
||||
|
||||
@accept_router.get('/accept')
|
||||
async def accept_invitation(
|
||||
token: str,
|
||||
request: Request,
|
||||
):
|
||||
"""Accept an organization invitation via token.
|
||||
|
||||
This endpoint is accessed via the link in the invitation email.
|
||||
|
||||
Flow:
|
||||
1. If user is authenticated: Accept invitation directly and redirect to home
|
||||
2. If user is not authenticated: Redirect to login page with invitation token
|
||||
- Frontend stores token and includes it in OAuth state during login
|
||||
- After authentication, keycloak_callback processes the invitation
|
||||
|
||||
Args:
|
||||
token: The invitation token from the email link
|
||||
request: FastAPI request
|
||||
|
||||
Returns:
|
||||
RedirectResponse: Redirect to home page on success, or login page if not authenticated,
|
||||
or home page with error query params on failure
|
||||
"""
|
||||
base_url = str(request.base_url).rstrip('/')
|
||||
|
||||
# Try to get user_id from auth (may not be authenticated)
|
||||
user_id = None
|
||||
try:
|
||||
user_auth = await get_user_auth(request)
|
||||
if user_auth:
|
||||
user_id = await user_auth.get_user_id()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if not user_id:
|
||||
# User not authenticated - redirect to login page with invitation token
|
||||
# Frontend will store the token and include it in OAuth state during login
|
||||
logger.info(
|
||||
'Invitation accept: redirecting unauthenticated user to login',
|
||||
extra={'token_prefix': token[:10] + '...'},
|
||||
)
|
||||
login_url = f'{base_url}/login?invitation_token={token}'
|
||||
return RedirectResponse(login_url, status_code=302)
|
||||
|
||||
# User is authenticated - process the invitation directly
|
||||
try:
|
||||
await OrgInvitationService.accept_invitation(token, UUID(user_id))
|
||||
|
||||
logger.info(
|
||||
'Invitation accepted successfully',
|
||||
extra={
|
||||
'token_prefix': token[:10] + '...',
|
||||
'user_id': user_id,
|
||||
},
|
||||
)
|
||||
|
||||
# Redirect to home page on success
|
||||
return RedirectResponse(f'{base_url}/', status_code=302)
|
||||
|
||||
except InvitationExpiredError:
|
||||
logger.warning(
|
||||
'Invitation accept failed: expired',
|
||||
extra={'token_prefix': token[:10] + '...', 'user_id': user_id},
|
||||
)
|
||||
return RedirectResponse(f'{base_url}/?invitation_expired=true', status_code=302)
|
||||
|
||||
except InvitationInvalidError as e:
|
||||
logger.warning(
|
||||
'Invitation accept failed: invalid',
|
||||
extra={
|
||||
'token_prefix': token[:10] + '...',
|
||||
'user_id': user_id,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
return RedirectResponse(f'{base_url}/?invitation_invalid=true', status_code=302)
|
||||
|
||||
except UserAlreadyMemberError:
|
||||
logger.info(
|
||||
'Invitation accept: user already member',
|
||||
extra={'token_prefix': token[:10] + '...', 'user_id': user_id},
|
||||
)
|
||||
return RedirectResponse(f'{base_url}/?already_member=true', status_code=302)
|
||||
|
||||
except EmailMismatchError as e:
|
||||
logger.warning(
|
||||
'Invitation accept failed: email mismatch',
|
||||
extra={
|
||||
'token_prefix': token[:10] + '...',
|
||||
'user_id': user_id,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
return RedirectResponse(f'{base_url}/?email_mismatch=true', status_code=302)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'Unexpected error accepting invitation',
|
||||
extra={
|
||||
'token_prefix': token[:10] + '...',
|
||||
'user_id': user_id,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
return RedirectResponse(f'{base_url}/?invitation_error=true', status_code=302)
|
||||
@@ -1,5 +1,9 @@
|
||||
from pydantic import BaseModel, EmailStr, Field
|
||||
from typing import Annotated
|
||||
|
||||
from pydantic import BaseModel, EmailStr, Field, SecretStr, StringConstraints
|
||||
from storage.org import Org
|
||||
from storage.org_member import OrgMember
|
||||
from storage.role import Role
|
||||
|
||||
|
||||
class OrgCreationError(Exception):
|
||||
@@ -41,6 +45,16 @@ class OrgAuthorizationError(OrgDeletionError):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class OrphanedUserError(OrgDeletionError):
|
||||
"""Raised when deleting an org would leave users without any organization."""
|
||||
|
||||
def __init__(self, user_ids: list[str]):
|
||||
self.user_ids = user_ids
|
||||
super().__init__(
|
||||
f'Cannot delete organization: {len(user_ids)} user(s) would have no remaining organization'
|
||||
)
|
||||
|
||||
|
||||
class OrgNotFoundError(Exception):
|
||||
"""Raised when organization is not found or user doesn't have access."""
|
||||
|
||||
@@ -49,13 +63,70 @@ class OrgNotFoundError(Exception):
|
||||
super().__init__(f'Organization with id "{org_id}" not found')
|
||||
|
||||
|
||||
class OrgMemberNotFoundError(Exception):
|
||||
"""Raised when a member is not found in an organization."""
|
||||
|
||||
def __init__(self, org_id: str, user_id: str):
|
||||
self.org_id = org_id
|
||||
self.user_id = user_id
|
||||
super().__init__(f'Member "{user_id}" not found in organization "{org_id}"')
|
||||
|
||||
|
||||
class RoleNotFoundError(Exception):
|
||||
"""Raised when a role is not found."""
|
||||
|
||||
def __init__(self, role_id: int):
|
||||
self.role_id = role_id
|
||||
super().__init__(f'Role with id "{role_id}" not found')
|
||||
|
||||
|
||||
class InvalidRoleError(Exception):
|
||||
"""Raised when an invalid role name is specified."""
|
||||
|
||||
def __init__(self, role_name: str):
|
||||
self.role_name = role_name
|
||||
super().__init__(f'Invalid role: "{role_name}"')
|
||||
|
||||
|
||||
class InsufficientPermissionError(Exception):
|
||||
"""Raised when user lacks permission to perform an operation."""
|
||||
|
||||
def __init__(self, message: str = 'Insufficient permission'):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class CannotModifySelfError(Exception):
|
||||
"""Raised when user attempts to modify their own membership."""
|
||||
|
||||
def __init__(self, action: str = 'modify'):
|
||||
self.action = action
|
||||
super().__init__(f'Cannot {action} your own membership')
|
||||
|
||||
|
||||
class LastOwnerError(Exception):
|
||||
"""Raised when attempting to remove or demote the last owner."""
|
||||
|
||||
def __init__(self, action: str = 'remove'):
|
||||
self.action = action
|
||||
super().__init__(f'Cannot {action} the last owner of an organization')
|
||||
|
||||
|
||||
class MemberUpdateError(Exception):
|
||||
"""Raised when member update operation fails."""
|
||||
|
||||
def __init__(self, message: str = 'Failed to update member'):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class OrgCreate(BaseModel):
|
||||
"""Request model for creating a new organization."""
|
||||
|
||||
# Required fields
|
||||
name: str = Field(min_length=1, max_length=255, strip_whitespace=True)
|
||||
name: Annotated[
|
||||
str, StringConstraints(strip_whitespace=True, min_length=1, max_length=255)
|
||||
]
|
||||
contact_name: str
|
||||
contact_email: EmailStr = Field(strip_whitespace=True)
|
||||
contact_email: EmailStr
|
||||
|
||||
|
||||
class OrgResponse(BaseModel):
|
||||
@@ -87,14 +158,18 @@ class OrgResponse(BaseModel):
|
||||
enable_solvability_analysis: bool | None = None
|
||||
v1_enabled: bool | None = None
|
||||
credits: float | None = None
|
||||
is_personal: bool = False
|
||||
|
||||
@classmethod
|
||||
def from_org(cls, org: Org, credits: float | None = None) -> 'OrgResponse':
|
||||
def from_org(
|
||||
cls, org: Org, credits: float | None = None, user_id: str | None = None
|
||||
) -> 'OrgResponse':
|
||||
"""Create an OrgResponse from an Org entity.
|
||||
|
||||
Args:
|
||||
org: The organization entity to convert
|
||||
credits: Optional credits value (defaults to None)
|
||||
user_id: Optional user ID to determine if org is personal (defaults to None)
|
||||
|
||||
Returns:
|
||||
OrgResponse: The response model instance
|
||||
@@ -130,6 +205,7 @@ class OrgResponse(BaseModel):
|
||||
enable_solvability_analysis=org.enable_solvability_analysis,
|
||||
v1_enabled=org.v1_enabled,
|
||||
credits=credits,
|
||||
is_personal=str(org.id) == user_id if user_id else False,
|
||||
)
|
||||
|
||||
|
||||
@@ -138,14 +214,19 @@ class OrgPage(BaseModel):
|
||||
|
||||
items: list[OrgResponse]
|
||||
next_page_id: str | None = None
|
||||
current_org_id: str | None = None
|
||||
|
||||
|
||||
class OrgUpdate(BaseModel):
|
||||
"""Request model for updating an organization."""
|
||||
|
||||
# Basic organization information (any authenticated user can update)
|
||||
name: Annotated[
|
||||
str | None,
|
||||
StringConstraints(strip_whitespace=True, min_length=1, max_length=255),
|
||||
] = None
|
||||
contact_name: str | None = None
|
||||
contact_email: EmailStr | None = Field(default=None, strip_whitespace=True)
|
||||
contact_email: EmailStr | None = None
|
||||
conversation_expiration: int | None = None
|
||||
default_max_iterations: int | None = Field(default=None, gt=0)
|
||||
remote_runtime_resource_factor: int | None = Field(default=None, gt=0)
|
||||
@@ -169,3 +250,79 @@ class OrgUpdate(BaseModel):
|
||||
confirmation_mode: bool | None = None
|
||||
enable_default_condenser: bool | None = None
|
||||
condenser_max_size: int | None = Field(default=None, ge=20)
|
||||
|
||||
|
||||
class OrgMemberResponse(BaseModel):
|
||||
"""Response model for a single organization member."""
|
||||
|
||||
user_id: str
|
||||
email: str | None
|
||||
role_id: int
|
||||
role: str
|
||||
role_rank: int
|
||||
status: str | None
|
||||
|
||||
|
||||
class OrgMemberPage(BaseModel):
|
||||
"""Paginated response for organization members."""
|
||||
|
||||
items: list[OrgMemberResponse]
|
||||
next_page_id: str | None = None
|
||||
|
||||
|
||||
class OrgMemberUpdate(BaseModel):
|
||||
"""Request model for updating an organization member."""
|
||||
|
||||
role: str | None = None # Role name: 'owner', 'admin', or 'member'
|
||||
|
||||
|
||||
class MeResponse(BaseModel):
|
||||
"""Response model for the current user's membership in an organization."""
|
||||
|
||||
org_id: str
|
||||
user_id: str
|
||||
email: str
|
||||
role: str
|
||||
llm_api_key: str
|
||||
max_iterations: int | None = None
|
||||
llm_model: str | None = None
|
||||
llm_api_key_for_byor: str | None = None
|
||||
llm_base_url: str | None = None
|
||||
status: str | None = None
|
||||
|
||||
@staticmethod
|
||||
def _mask_key(secret: SecretStr | None) -> str:
|
||||
"""Mask an API key, showing only last 4 characters."""
|
||||
if secret is None:
|
||||
return ''
|
||||
raw = secret.get_secret_value()
|
||||
if not raw:
|
||||
return ''
|
||||
if len(raw) <= 4:
|
||||
return '****'
|
||||
return '****' + raw[-4:]
|
||||
|
||||
@classmethod
|
||||
def from_org_member(cls, member: OrgMember, role: Role, email: str) -> 'MeResponse':
|
||||
"""Create a MeResponse from an OrgMember, Role, and user email.
|
||||
|
||||
Args:
|
||||
member: The OrgMember entity
|
||||
role: The Role entity (provides role name)
|
||||
email: The user's email address
|
||||
|
||||
Returns:
|
||||
MeResponse with masked API keys
|
||||
"""
|
||||
return cls(
|
||||
org_id=str(member.org_id),
|
||||
user_id=str(member.user_id),
|
||||
email=email,
|
||||
role=role.name,
|
||||
llm_api_key=cls._mask_key(member.llm_api_key),
|
||||
max_iterations=member.max_iterations,
|
||||
llm_model=member.llm_model,
|
||||
llm_api_key_for_byor=cls._mask_key(member.llm_api_key_for_byor) or None,
|
||||
llm_base_url=member.llm_base_url,
|
||||
status=member.status,
|
||||
)
|
||||
|
||||
@@ -2,19 +2,37 @@ from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from server.auth.authorization import (
|
||||
Permission,
|
||||
require_permission,
|
||||
)
|
||||
from server.email_validation import get_admin_user_id
|
||||
from server.routes.org_models import (
|
||||
CannotModifySelfError,
|
||||
InsufficientPermissionError,
|
||||
InvalidRoleError,
|
||||
LastOwnerError,
|
||||
LiteLLMIntegrationError,
|
||||
MemberUpdateError,
|
||||
MeResponse,
|
||||
OrgAuthorizationError,
|
||||
OrgCreate,
|
||||
OrgDatabaseError,
|
||||
OrgMemberNotFoundError,
|
||||
OrgMemberPage,
|
||||
OrgMemberResponse,
|
||||
OrgMemberUpdate,
|
||||
OrgNameExistsError,
|
||||
OrgNotFoundError,
|
||||
OrgPage,
|
||||
OrgResponse,
|
||||
OrgUpdate,
|
||||
OrphanedUserError,
|
||||
RoleNotFoundError,
|
||||
)
|
||||
from server.services.org_member_service import OrgMemberService
|
||||
from storage.org_service import OrgService
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.server.user_auth import get_user_id
|
||||
@@ -61,6 +79,12 @@ async def list_user_orgs(
|
||||
)
|
||||
|
||||
try:
|
||||
# Fetch user to get current_org_id
|
||||
user = await UserStore.get_user_by_id_async(user_id)
|
||||
current_org_id = (
|
||||
str(user.current_org_id) if user and user.current_org_id else None
|
||||
)
|
||||
|
||||
# Fetch organizations from service layer
|
||||
orgs, next_page_id = OrgService.get_user_orgs_paginated(
|
||||
user_id=user_id,
|
||||
@@ -69,7 +93,9 @@ async def list_user_orgs(
|
||||
)
|
||||
|
||||
# Convert Org entities to OrgResponse objects
|
||||
org_responses = [OrgResponse.from_org(org, credits=None) for org in orgs]
|
||||
org_responses = [
|
||||
OrgResponse.from_org(org, credits=None, user_id=user_id) for org in orgs
|
||||
]
|
||||
|
||||
logger.info(
|
||||
'Successfully retrieved organizations',
|
||||
@@ -80,7 +106,11 @@ async def list_user_orgs(
|
||||
},
|
||||
)
|
||||
|
||||
return OrgPage(items=org_responses, next_page_id=next_page_id)
|
||||
return OrgPage(
|
||||
items=org_responses,
|
||||
next_page_id=next_page_id,
|
||||
current_org_id=current_org_id,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
@@ -136,7 +166,7 @@ async def create_org(
|
||||
# Retrieve credits from LiteLLM
|
||||
credits = await OrgService.get_org_credits(user_id, org.id)
|
||||
|
||||
return OrgResponse.from_org(org, credits=credits)
|
||||
return OrgResponse.from_org(org, credits=credits, user_id=user_id)
|
||||
except OrgNameExistsError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
@@ -174,23 +204,26 @@ async def create_org(
|
||||
@org_router.get('/{org_id}', response_model=OrgResponse, status_code=status.HTTP_200_OK)
|
||||
async def get_org(
|
||||
org_id: UUID,
|
||||
user_id: str = Depends(get_user_id),
|
||||
user_id: str = Depends(require_permission(Permission.VIEW_ORG_SETTINGS)),
|
||||
) -> OrgResponse:
|
||||
"""Get organization details by ID.
|
||||
|
||||
This endpoint allows authenticated users who are members of an organization
|
||||
to retrieve its details. Only members of the organization can access this endpoint.
|
||||
This endpoint retrieves details for a specific organization. Access requires
|
||||
the VIEW_ORG_SETTINGS permission, which is granted to all organization members
|
||||
(member, admin, and owner roles).
|
||||
|
||||
Args:
|
||||
org_id: Organization ID (UUID)
|
||||
user_id: Authenticated user ID (injected by dependency)
|
||||
user_id: Authenticated user ID (injected by require_permission dependency)
|
||||
|
||||
Returns:
|
||||
OrgResponse: The organization details
|
||||
|
||||
Raises:
|
||||
HTTPException: 401 if user is not authenticated
|
||||
HTTPException: 403 if user lacks VIEW_ORG_SETTINGS permission
|
||||
HTTPException: 404 if organization not found
|
||||
HTTPException: 422 if org_id is not a valid UUID (handled by FastAPI)
|
||||
HTTPException: 404 if organization not found or user is not a member
|
||||
HTTPException: 500 if retrieval fails
|
||||
"""
|
||||
logger.info(
|
||||
@@ -211,7 +244,7 @@ async def get_org(
|
||||
# Retrieve credits from LiteLLM
|
||||
credits = await OrgService.get_org_credits(user_id, org.id)
|
||||
|
||||
return OrgResponse.from_org(org, credits=credits)
|
||||
return OrgResponse.from_org(org, credits=credits, user_id=user_id)
|
||||
except OrgNotFoundError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
@@ -228,26 +261,86 @@ async def get_org(
|
||||
)
|
||||
|
||||
|
||||
@org_router.get('/{org_id}/me', response_model=MeResponse)
|
||||
async def get_me(
|
||||
org_id: UUID,
|
||||
user_id: str = Depends(get_user_id),
|
||||
) -> MeResponse:
|
||||
"""Get the current user's membership record for an organization.
|
||||
|
||||
Returns the authenticated user's role, status, email, and LLM override
|
||||
fields (with masked API keys) within the specified organization.
|
||||
|
||||
Args:
|
||||
org_id: Organization ID (UUID)
|
||||
user_id: Authenticated user ID (injected by dependency)
|
||||
|
||||
Returns:
|
||||
MeResponse: The user's membership data
|
||||
|
||||
Raises:
|
||||
HTTPException: 404 if user is not a member or org doesn't exist
|
||||
HTTPException: 500 if retrieval fails
|
||||
"""
|
||||
logger.info(
|
||||
'Retrieving current member details',
|
||||
extra={'user_id': user_id, 'org_id': str(org_id)},
|
||||
)
|
||||
|
||||
try:
|
||||
user_uuid = UUID(user_id)
|
||||
return OrgMemberService.get_me(org_id, user_uuid)
|
||||
|
||||
except OrgMemberNotFoundError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f'Organization with id "{org_id}" not found',
|
||||
)
|
||||
except RoleNotFoundError as e:
|
||||
logger.exception(
|
||||
'Role not found for org member',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
'role_id': e.role_id,
|
||||
},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='An unexpected error occurred',
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'Unexpected error retrieving member details',
|
||||
extra={'user_id': user_id, 'org_id': str(org_id), 'error': str(e)},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='An unexpected error occurred',
|
||||
)
|
||||
|
||||
|
||||
@org_router.delete('/{org_id}', status_code=status.HTTP_200_OK)
|
||||
async def delete_org(
|
||||
org_id: UUID,
|
||||
user_id: str = Depends(get_admin_user_id),
|
||||
user_id: str = Depends(require_permission(Permission.DELETE_ORGANIZATION)),
|
||||
) -> dict:
|
||||
"""Delete an organization.
|
||||
|
||||
This endpoint allows authenticated organization owners to delete their organization.
|
||||
All associated data including organization members, conversations, billing data,
|
||||
and external LiteLLM team resources will be permanently removed.
|
||||
This endpoint permanently deletes an organization and all associated data including
|
||||
organization members, conversations, billing data, and external LiteLLM team resources.
|
||||
Access requires the DELETE_ORGANIZATION permission, which is granted only to owners.
|
||||
|
||||
Args:
|
||||
org_id: Organization ID to delete
|
||||
user_id: Authenticated user ID (injected by dependency)
|
||||
org_id: Organization ID to delete (UUID)
|
||||
user_id: Authenticated user ID (injected by require_permission dependency)
|
||||
|
||||
Returns:
|
||||
dict: Confirmation message with deleted organization details
|
||||
|
||||
Raises:
|
||||
HTTPException: 403 if user is not the organization owner
|
||||
HTTPException: 401 if user is not authenticated
|
||||
HTTPException: 403 if user lacks DELETE_ORGANIZATION permission
|
||||
HTTPException: 404 if organization not found
|
||||
HTTPException: 500 if deletion fails
|
||||
"""
|
||||
@@ -303,6 +396,19 @@ async def delete_org(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=str(e),
|
||||
)
|
||||
except OrphanedUserError as e:
|
||||
logger.warning(
|
||||
'Cannot delete organization: users would be orphaned',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
'orphaned_users': e.user_ids,
|
||||
},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e),
|
||||
)
|
||||
except OrgDatabaseError as e:
|
||||
logger.error(
|
||||
'Database error during organization deletion',
|
||||
@@ -327,25 +433,26 @@ async def delete_org(
|
||||
async def update_org(
|
||||
org_id: UUID,
|
||||
update_data: OrgUpdate,
|
||||
user_id: str = Depends(get_user_id),
|
||||
user_id: str = Depends(require_permission(Permission.EDIT_ORG_SETTINGS)),
|
||||
) -> OrgResponse:
|
||||
"""Update an existing organization.
|
||||
|
||||
This endpoint allows authenticated users to update organization settings.
|
||||
LLM-related settings require admin or owner role in the organization.
|
||||
This endpoint updates organization settings. Access requires the EDIT_ORG_SETTINGS
|
||||
permission, which is granted to admin and owner roles.
|
||||
|
||||
Args:
|
||||
org_id: Organization ID to update (UUID validated by FastAPI)
|
||||
org_id: Organization ID to update (UUID)
|
||||
update_data: Organization update data
|
||||
user_id: Authenticated user ID (injected by dependency)
|
||||
user_id: Authenticated user ID (injected by require_permission dependency)
|
||||
|
||||
Returns:
|
||||
OrgResponse: The updated organization details
|
||||
|
||||
Raises:
|
||||
HTTPException: 400 if org_id is invalid UUID format (handled by FastAPI)
|
||||
HTTPException: 403 if user lacks permission for LLM settings
|
||||
HTTPException: 401 if user is not authenticated
|
||||
HTTPException: 403 if user lacks EDIT_ORG_SETTINGS permission
|
||||
HTTPException: 404 if organization not found
|
||||
HTTPException: 409 if organization name already exists
|
||||
HTTPException: 422 if validation errors occur (handled by FastAPI)
|
||||
HTTPException: 500 if update fails
|
||||
"""
|
||||
@@ -368,7 +475,7 @@ async def update_org(
|
||||
# Retrieve credits from LiteLLM (following same pattern as create endpoint)
|
||||
credits = await OrgService.get_org_credits(user_id, updated_org.id)
|
||||
|
||||
return OrgResponse.from_org(updated_org, credits=credits)
|
||||
return OrgResponse.from_org(updated_org, credits=credits, user_id=user_id)
|
||||
|
||||
except ValueError as e:
|
||||
# Organization not found
|
||||
@@ -376,6 +483,11 @@ async def update_org(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=str(e),
|
||||
)
|
||||
except OrgNameExistsError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail=str(e),
|
||||
)
|
||||
except PermissionError as e:
|
||||
# User lacks permission for LLM settings
|
||||
raise HTTPException(
|
||||
@@ -400,3 +512,314 @@ async def update_org(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='An unexpected error occurred',
|
||||
)
|
||||
|
||||
|
||||
@org_router.get('/{org_id}/members')
|
||||
async def get_org_members(
|
||||
org_id: UUID,
|
||||
page_id: Annotated[
|
||||
str | None,
|
||||
Query(title='Optional next_page_id from the previously returned page'),
|
||||
] = None,
|
||||
limit: Annotated[
|
||||
int,
|
||||
Query(
|
||||
title='The max number of results in the page',
|
||||
gt=0,
|
||||
lte=100,
|
||||
),
|
||||
] = 100,
|
||||
user_id: str = Depends(require_permission(Permission.VIEW_ORG_SETTINGS)),
|
||||
) -> OrgMemberPage:
|
||||
"""Get all members of an organization with cursor-based pagination.
|
||||
|
||||
This endpoint retrieves a paginated list of organization members. Access requires
|
||||
the VIEW_ORG_SETTINGS permission, which is granted to all organization members
|
||||
(member, admin, and owner roles).
|
||||
|
||||
Args:
|
||||
org_id: Organization ID (UUID)
|
||||
page_id: Optional page ID (offset) for pagination
|
||||
limit: Maximum number of members to return (1-100, default 100)
|
||||
user_id: Authenticated user ID (injected by require_permission dependency)
|
||||
|
||||
Returns:
|
||||
OrgMemberPage: Paginated list of organization members
|
||||
|
||||
Raises:
|
||||
HTTPException: 401 if user is not authenticated
|
||||
HTTPException: 403 if user lacks VIEW_ORG_SETTINGS permission
|
||||
HTTPException: 400 if org_id or page_id format is invalid
|
||||
HTTPException: 500 if retrieval fails
|
||||
"""
|
||||
try:
|
||||
success, error_code, data = await OrgMemberService.get_org_members(
|
||||
org_id=org_id,
|
||||
current_user_id=UUID(user_id),
|
||||
page_id=page_id,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
if not success:
|
||||
error_map = {
|
||||
'not_a_member': (
|
||||
status.HTTP_403_FORBIDDEN,
|
||||
'You are not a member of this organization',
|
||||
),
|
||||
'invalid_page_id': (
|
||||
status.HTTP_400_BAD_REQUEST,
|
||||
'Invalid page_id format',
|
||||
),
|
||||
}
|
||||
status_code, detail = error_map.get(
|
||||
error_code, (status.HTTP_500_INTERNAL_SERVER_ERROR, 'An error occurred')
|
||||
)
|
||||
raise HTTPException(status_code=status_code, detail=detail)
|
||||
|
||||
if data is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to retrieve members',
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except ValueError:
|
||||
logger.exception('Invalid UUID format')
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail='Invalid organization ID format',
|
||||
)
|
||||
except Exception:
|
||||
logger.exception('Error retrieving organization members')
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to retrieve members',
|
||||
)
|
||||
|
||||
|
||||
@org_router.delete('/{org_id}/members/{user_id}')
|
||||
async def remove_org_member(
|
||||
org_id: UUID,
|
||||
user_id: str,
|
||||
current_user_id: str = Depends(get_user_id),
|
||||
):
|
||||
"""Remove a member from an organization.
|
||||
|
||||
Only owners and admins can remove members:
|
||||
- Owners can remove admins and regular users
|
||||
- Admins can only remove regular users
|
||||
|
||||
Users cannot remove themselves. The last owner cannot be removed.
|
||||
"""
|
||||
try:
|
||||
success, error = await OrgMemberService.remove_org_member(
|
||||
org_id=org_id,
|
||||
target_user_id=UUID(user_id),
|
||||
current_user_id=UUID(current_user_id),
|
||||
)
|
||||
|
||||
if not success:
|
||||
error_map = {
|
||||
'not_a_member': (
|
||||
status.HTTP_403_FORBIDDEN,
|
||||
'You are not a member of this organization',
|
||||
),
|
||||
'cannot_remove_self': (
|
||||
status.HTTP_403_FORBIDDEN,
|
||||
'Cannot remove yourself from an organization',
|
||||
),
|
||||
'member_not_found': (
|
||||
status.HTTP_404_NOT_FOUND,
|
||||
'Member not found in this organization',
|
||||
),
|
||||
'insufficient_permission': (
|
||||
status.HTTP_403_FORBIDDEN,
|
||||
'You do not have permission to remove this member',
|
||||
),
|
||||
'cannot_remove_last_owner': (
|
||||
status.HTTP_400_BAD_REQUEST,
|
||||
'Cannot remove the last owner of an organization',
|
||||
),
|
||||
'removal_failed': (
|
||||
status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
'Failed to remove member',
|
||||
),
|
||||
}
|
||||
status_code, detail = error_map.get(
|
||||
error, (status.HTTP_500_INTERNAL_SERVER_ERROR, 'An error occurred')
|
||||
)
|
||||
raise HTTPException(status_code=status_code, detail=detail)
|
||||
|
||||
return {'message': 'Member removed successfully'}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except ValueError:
|
||||
logger.exception('Invalid UUID format')
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail='Invalid organization or user ID format',
|
||||
)
|
||||
except Exception:
|
||||
logger.exception('Error removing organization member')
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to remove member',
|
||||
)
|
||||
|
||||
|
||||
@org_router.post(
|
||||
'/{org_id}/switch', response_model=OrgResponse, status_code=status.HTTP_200_OK
|
||||
)
|
||||
async def switch_org(
|
||||
org_id: UUID,
|
||||
user_id: str = Depends(get_user_id),
|
||||
) -> OrgResponse:
|
||||
"""Switch to a different organization.
|
||||
|
||||
This endpoint allows authenticated users to switch their current active
|
||||
organization. The user must be a member of the target organization.
|
||||
|
||||
Args:
|
||||
org_id: Organization ID to switch to (UUID)
|
||||
user_id: Authenticated user ID (injected by dependency)
|
||||
|
||||
Returns:
|
||||
OrgResponse: The organization details that was switched to
|
||||
|
||||
Raises:
|
||||
HTTPException: 422 if org_id is not a valid UUID (handled by FastAPI)
|
||||
HTTPException: 403 if user is not a member of the organization
|
||||
HTTPException: 404 if organization not found
|
||||
HTTPException: 500 if switch fails
|
||||
"""
|
||||
logger.info(
|
||||
'Switching organization',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
# Use service layer to switch organization with membership validation
|
||||
org = await OrgService.switch_org(
|
||||
user_id=user_id,
|
||||
org_id=org_id,
|
||||
)
|
||||
|
||||
# Retrieve credits from LiteLLM for the new current org
|
||||
credits = await OrgService.get_org_credits(user_id, org.id)
|
||||
|
||||
return OrgResponse.from_org(org, credits=credits, user_id=user_id)
|
||||
|
||||
except OrgNotFoundError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=str(e),
|
||||
)
|
||||
except OrgAuthorizationError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=str(e),
|
||||
)
|
||||
except OrgDatabaseError as e:
|
||||
logger.error(
|
||||
'Database operation failed during organization switch',
|
||||
extra={'user_id': user_id, 'org_id': str(org_id), 'error': str(e)},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to switch organization',
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'Unexpected error switching organization',
|
||||
extra={'user_id': user_id, 'org_id': str(org_id), 'error': str(e)},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='An unexpected error occurred',
|
||||
)
|
||||
|
||||
|
||||
@org_router.patch('/{org_id}/members/{user_id}', response_model=OrgMemberResponse)
|
||||
async def update_org_member(
|
||||
org_id: UUID,
|
||||
user_id: str,
|
||||
update_data: OrgMemberUpdate,
|
||||
current_user_id: str = Depends(get_user_id),
|
||||
) -> OrgMemberResponse:
|
||||
"""Update a member's role in an organization.
|
||||
|
||||
Permission rules:
|
||||
- Admins can change roles of regular members to Admin or Member
|
||||
- Admins cannot modify other Admins or Owners
|
||||
- Owners can change roles of Admins and Members to any role (Owner, Admin, Member)
|
||||
- Owners cannot modify other Owners
|
||||
|
||||
Members cannot modify their own role. The last owner cannot be demoted.
|
||||
"""
|
||||
try:
|
||||
return await OrgMemberService.update_org_member(
|
||||
org_id=org_id,
|
||||
target_user_id=UUID(user_id),
|
||||
current_user_id=UUID(current_user_id),
|
||||
update_data=update_data,
|
||||
)
|
||||
except OrgMemberNotFoundError as e:
|
||||
# Distinguish between requester not being a member vs target not found
|
||||
if str(current_user_id) in str(e):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail='You are not a member of this organization',
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail='Member not found in this organization',
|
||||
)
|
||||
except CannotModifySelfError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail='Cannot modify your own role',
|
||||
)
|
||||
except RoleNotFoundError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Role configuration error',
|
||||
)
|
||||
except InvalidRoleError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail='Invalid role specified',
|
||||
)
|
||||
except InsufficientPermissionError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail='You do not have permission to modify this member',
|
||||
)
|
||||
except LastOwnerError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail='Cannot demote the last owner of an organization',
|
||||
)
|
||||
except MemberUpdateError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to update member',
|
||||
)
|
||||
except ValueError:
|
||||
logger.exception('Invalid UUID format')
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail='Invalid organization or user ID format',
|
||||
)
|
||||
except Exception:
|
||||
logger.exception('Error updating organization member')
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to update member',
|
||||
)
|
||||
|
||||
@@ -4,6 +4,7 @@ from fastapi import APIRouter, Depends, Query, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import SecretStr
|
||||
from server.auth.token_manager import TokenManager
|
||||
from utils.identity import resolve_display_name
|
||||
|
||||
from openhands.integrations.provider import (
|
||||
PROVIDER_TOKEN_TYPE,
|
||||
@@ -121,6 +122,8 @@ async def saas_get_user(
|
||||
login=(user_info.get('preferred_username') if user_info else '') or '',
|
||||
avatar_url='',
|
||||
email=user_info.get('email') if user_info else None,
|
||||
name=resolve_display_name(user_info) if user_info else None,
|
||||
company=user_info.get('company') if user_info else None,
|
||||
),
|
||||
user_info=user_info,
|
||||
)
|
||||
|
||||
@@ -516,11 +516,13 @@ class SaasNestedConversationManager(ConversationManager):
|
||||
)
|
||||
raise
|
||||
|
||||
def _get_mcp_config(self, user_id: str) -> MCPConfig | None:
|
||||
async def _get_mcp_config(self, user_id: str) -> MCPConfig | None:
|
||||
api_key_store = ApiKeyStore.get_instance()
|
||||
mcp_api_key = api_key_store.retrieve_mcp_api_key(user_id)
|
||||
mcp_api_key = await api_key_store.retrieve_mcp_api_key(user_id)
|
||||
if not mcp_api_key:
|
||||
mcp_api_key = api_key_store.create_api_key(user_id, 'MCP_API_KEY', None)
|
||||
mcp_api_key = await api_key_store.create_api_key(
|
||||
user_id, 'MCP_API_KEY', None
|
||||
)
|
||||
if not mcp_api_key:
|
||||
return None
|
||||
web_host = os.environ.get('WEB_HOST', 'app.all-hands.dev')
|
||||
@@ -547,7 +549,7 @@ class SaasNestedConversationManager(ConversationManager):
|
||||
'conversation_id': sid,
|
||||
}
|
||||
|
||||
mcp_config = self._get_mcp_config(user_id)
|
||||
mcp_config = await self._get_mcp_config(user_id)
|
||||
if mcp_config:
|
||||
# Merge with any MCP config from settings
|
||||
if settings.mcp_config:
|
||||
@@ -1137,6 +1139,71 @@ class SaasNestedConversationManager(ConversationManager):
|
||||
}
|
||||
update_conversation_metadata(conversation_id, metadata_content)
|
||||
|
||||
async def list_files(self, sid: str, path: str | None = None) -> list[str]:
|
||||
"""List files in the workspace for a conversation.
|
||||
|
||||
Delegates to the nested container's list-files endpoint.
|
||||
|
||||
Args:
|
||||
sid: The session/conversation ID.
|
||||
path: Optional path to list files from. If None, lists from workspace root.
|
||||
|
||||
Returns:
|
||||
A list of file paths.
|
||||
|
||||
Raises:
|
||||
ValueError: If the conversation is not running.
|
||||
httpx.HTTPError: If there's an error communicating with the nested runtime.
|
||||
"""
|
||||
runtime = await self._get_runtime(sid)
|
||||
if runtime is None or runtime.get('status') != 'running':
|
||||
raise ValueError(f'Conversation {sid} is not running')
|
||||
|
||||
nested_url = self._get_nested_url_for_runtime(runtime['runtime_id'], sid)
|
||||
session_api_key = runtime.get('session_api_key')
|
||||
|
||||
return await self._fetch_list_files_from_nested(
|
||||
sid, nested_url, session_api_key, path
|
||||
)
|
||||
|
||||
async def select_file(self, sid: str, file: str) -> tuple[str | None, str | None]:
|
||||
"""Read a file from the workspace via nested container.
|
||||
|
||||
Raises:
|
||||
ValueError: If the conversation is not running.
|
||||
httpx.HTTPError: If there's an error communicating with the nested runtime.
|
||||
"""
|
||||
runtime = await self._get_runtime(sid)
|
||||
if runtime is None or runtime.get('status') != 'running':
|
||||
raise ValueError(f'Conversation {sid} is not running')
|
||||
|
||||
nested_url = self._get_nested_url_for_runtime(runtime['runtime_id'], sid)
|
||||
session_api_key = runtime.get('session_api_key')
|
||||
|
||||
return await self._fetch_select_file_from_nested(
|
||||
sid, nested_url, session_api_key, file
|
||||
)
|
||||
|
||||
async def upload_files(
|
||||
self, sid: str, files: list[tuple[str, bytes]]
|
||||
) -> tuple[list[str], list[dict[str, str]]]:
|
||||
"""Upload files to the workspace via nested container.
|
||||
|
||||
Raises:
|
||||
ValueError: If the conversation is not running.
|
||||
httpx.HTTPError: If there's an error communicating with the nested runtime.
|
||||
"""
|
||||
runtime = await self._get_runtime(sid)
|
||||
if runtime is None or runtime.get('status') != 'running':
|
||||
raise ValueError(f'Conversation {sid} is not running')
|
||||
|
||||
nested_url = self._get_nested_url_for_runtime(runtime['runtime_id'], sid)
|
||||
session_api_key = runtime.get('session_api_key')
|
||||
|
||||
return await self._fetch_upload_files_to_nested(
|
||||
sid, nested_url, session_api_key, files
|
||||
)
|
||||
|
||||
|
||||
def _last_updated_at_key(conversation: ConversationMetadata) -> float:
|
||||
last_updated_at = conversation.last_updated_at
|
||||
|
||||
131
enterprise/server/services/email_service.py
Normal file
131
enterprise/server/services/email_service.py
Normal file
@@ -0,0 +1,131 @@
|
||||
"""Email service for sending transactional emails via Resend."""
|
||||
|
||||
import os
|
||||
|
||||
try:
|
||||
import resend
|
||||
|
||||
RESEND_AVAILABLE = True
|
||||
except ImportError:
|
||||
RESEND_AVAILABLE = False
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
DEFAULT_FROM_EMAIL = 'OpenHands <no-reply@openhands.dev>'
|
||||
DEFAULT_WEB_HOST = 'https://app.all-hands.dev'
|
||||
|
||||
|
||||
class EmailService:
|
||||
"""Service for sending transactional emails."""
|
||||
|
||||
@staticmethod
|
||||
def _get_resend_client() -> bool:
|
||||
"""Initialize and return the Resend client.
|
||||
|
||||
Returns:
|
||||
bool: True if client is ready, False otherwise
|
||||
"""
|
||||
if not RESEND_AVAILABLE:
|
||||
logger.warning('Resend library not installed, skipping email')
|
||||
return False
|
||||
|
||||
resend_api_key = os.environ.get('RESEND_API_KEY')
|
||||
if not resend_api_key:
|
||||
logger.warning('RESEND_API_KEY not configured, skipping email')
|
||||
return False
|
||||
|
||||
resend.api_key = resend_api_key
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def send_invitation_email(
|
||||
to_email: str,
|
||||
org_name: str,
|
||||
inviter_name: str,
|
||||
role_name: str,
|
||||
invitation_token: str,
|
||||
invitation_id: int,
|
||||
) -> None:
|
||||
"""Send an organization invitation email.
|
||||
|
||||
Args:
|
||||
to_email: Recipient's email address
|
||||
org_name: Name of the organization
|
||||
inviter_name: Display name of the person who sent the invite
|
||||
role_name: Role being offered (e.g., 'member', 'admin')
|
||||
invitation_token: The secure invitation token
|
||||
invitation_id: The invitation ID for logging
|
||||
"""
|
||||
if not EmailService._get_resend_client():
|
||||
return
|
||||
|
||||
# Build invitation URL
|
||||
web_host = os.environ.get('WEB_HOST', DEFAULT_WEB_HOST)
|
||||
invitation_url = f'{web_host}/api/organizations/members/invite/accept?token={invitation_token}'
|
||||
|
||||
from_email = os.environ.get('RESEND_FROM_EMAIL', DEFAULT_FROM_EMAIL)
|
||||
|
||||
params = {
|
||||
'from': from_email,
|
||||
'to': [to_email],
|
||||
'subject': f"You're invited to join {org_name} on OpenHands",
|
||||
'html': f"""
|
||||
<div style="font-family: Arial, sans-serif; max-width: 600px; margin: 0 auto;">
|
||||
<p>Hi,</p>
|
||||
|
||||
<p><strong>{inviter_name}</strong> has invited you to join <strong>{org_name}</strong> on OpenHands as a <strong>{role_name}</strong>.</p>
|
||||
|
||||
<p>Click the button below to accept the invitation:</p>
|
||||
|
||||
<p style="margin: 30px 0;">
|
||||
<a href="{invitation_url}"
|
||||
style="background-color: #c9b974; color: #0D0F11; padding: 8px 16px;
|
||||
text-decoration: none; border-radius: 8px; display: inline-block;
|
||||
font-size: 14px; font-weight: 600;">
|
||||
Accept Invitation
|
||||
</a>
|
||||
</p>
|
||||
|
||||
<p style="color: #666; font-size: 14px;">
|
||||
Or copy and paste this link into your browser:<br>
|
||||
<a href="{invitation_url}" style="color: #c9b974; font-weight: 600;">{invitation_url}</a>
|
||||
</p>
|
||||
|
||||
<p style="color: #666; font-size: 14px;">
|
||||
This invitation will expire in 7 days.
|
||||
</p>
|
||||
|
||||
<p style="color: #666; font-size: 14px;">
|
||||
If you weren't expecting this invitation, you can safely ignore this email.
|
||||
</p>
|
||||
|
||||
<hr style="border: none; border-top: 1px solid #eee; margin: 30px 0;">
|
||||
|
||||
<p style="color: #999; font-size: 12px;">
|
||||
Best,<br>
|
||||
The OpenHands Team
|
||||
</p>
|
||||
</div>
|
||||
""",
|
||||
}
|
||||
|
||||
try:
|
||||
response = resend.Emails.send(params)
|
||||
logger.info(
|
||||
'Invitation email sent',
|
||||
extra={
|
||||
'invitation_id': invitation_id,
|
||||
'email': to_email,
|
||||
'response_id': response.get('id') if response else None,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'Failed to send invitation email',
|
||||
extra={
|
||||
'invitation_id': invitation_id,
|
||||
'email': to_email,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
raise
|
||||
397
enterprise/server/services/org_invitation_service.py
Normal file
397
enterprise/server/services/org_invitation_service.py
Normal file
@@ -0,0 +1,397 @@
|
||||
"""Service for managing organization invitations."""
|
||||
|
||||
import asyncio
|
||||
from uuid import UUID
|
||||
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.constants import ROLE_ADMIN, ROLE_OWNER
|
||||
from server.routes.org_invitation_models import (
|
||||
EmailMismatchError,
|
||||
InsufficientPermissionError,
|
||||
InvitationExpiredError,
|
||||
InvitationInvalidError,
|
||||
UserAlreadyMemberError,
|
||||
)
|
||||
from server.services.email_service import EmailService
|
||||
from storage.org_invitation import OrgInvitation
|
||||
from storage.org_invitation_store import OrgInvitationStore
|
||||
from storage.org_member_store import OrgMemberStore
|
||||
from storage.org_service import OrgService
|
||||
from storage.org_store import OrgStore
|
||||
from storage.role_store import RoleStore
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
|
||||
class OrgInvitationService:
|
||||
"""Service for organization invitation operations."""
|
||||
|
||||
@staticmethod
|
||||
async def create_invitation(
|
||||
org_id: UUID,
|
||||
email: str,
|
||||
role_name: str,
|
||||
inviter_id: UUID,
|
||||
) -> OrgInvitation:
|
||||
"""Create a new organization invitation.
|
||||
|
||||
This method:
|
||||
1. Validates the organization exists
|
||||
2. Validates this is not a personal workspace
|
||||
3. Checks inviter has owner/admin role
|
||||
4. Validates role assignment permissions
|
||||
5. Checks if user is already a member
|
||||
6. Creates the invitation
|
||||
7. Sends the invitation email
|
||||
|
||||
Args:
|
||||
org_id: Organization UUID
|
||||
email: Invitee's email address
|
||||
role_name: Role to assign on acceptance (owner, admin, member)
|
||||
inviter_id: User ID of the person creating the invitation
|
||||
|
||||
Returns:
|
||||
OrgInvitation: The created invitation
|
||||
|
||||
Raises:
|
||||
ValueError: If organization or role not found
|
||||
InsufficientPermissionError: If inviter lacks permission
|
||||
UserAlreadyMemberError: If email is already a member
|
||||
InvitationAlreadyExistsError: If pending invitation exists
|
||||
"""
|
||||
email = email.lower().strip()
|
||||
|
||||
logger.info(
|
||||
'Creating organization invitation',
|
||||
extra={
|
||||
'org_id': str(org_id),
|
||||
'email': email,
|
||||
'role_name': role_name,
|
||||
'inviter_id': str(inviter_id),
|
||||
},
|
||||
)
|
||||
|
||||
# Step 1: Validate organization exists
|
||||
org = OrgStore.get_org_by_id(org_id)
|
||||
if not org:
|
||||
raise ValueError(f'Organization {org_id} not found')
|
||||
|
||||
# Step 2: Check this is not a personal workspace
|
||||
# A personal workspace has org_id matching the user's id
|
||||
if str(org_id) == str(inviter_id):
|
||||
raise InsufficientPermissionError(
|
||||
'Cannot invite users to a personal workspace'
|
||||
)
|
||||
|
||||
# Step 3: Check inviter is a member and has permission
|
||||
inviter_member = OrgMemberStore.get_org_member(org_id, inviter_id)
|
||||
if not inviter_member:
|
||||
raise InsufficientPermissionError(
|
||||
'You are not a member of this organization'
|
||||
)
|
||||
|
||||
inviter_role = RoleStore.get_role_by_id(inviter_member.role_id)
|
||||
if not inviter_role or inviter_role.name not in [ROLE_OWNER, ROLE_ADMIN]:
|
||||
raise InsufficientPermissionError('Only owners and admins can invite users')
|
||||
|
||||
# Step 4: Validate role assignment permissions
|
||||
role_name_lower = role_name.lower()
|
||||
if role_name_lower == ROLE_OWNER and inviter_role.name != ROLE_OWNER:
|
||||
raise InsufficientPermissionError('Only owners can invite with owner role')
|
||||
|
||||
# Get the target role
|
||||
target_role = RoleStore.get_role_by_name(role_name_lower)
|
||||
if not target_role:
|
||||
raise ValueError(f'Invalid role: {role_name}')
|
||||
|
||||
# Step 5: Check if user is already a member (by email)
|
||||
existing_user = await UserStore.get_user_by_email_async(email)
|
||||
if existing_user:
|
||||
existing_member = OrgMemberStore.get_org_member(org_id, existing_user.id)
|
||||
if existing_member:
|
||||
raise UserAlreadyMemberError(
|
||||
'User is already a member of this organization'
|
||||
)
|
||||
|
||||
# Step 6: Create the invitation
|
||||
invitation = await OrgInvitationStore.create_invitation(
|
||||
org_id=org_id,
|
||||
email=email,
|
||||
role_id=target_role.id,
|
||||
inviter_id=inviter_id,
|
||||
)
|
||||
|
||||
# Step 7: Send invitation email
|
||||
try:
|
||||
# Get inviter info for the email
|
||||
inviter_user = UserStore.get_user_by_id(str(inviter_member.user_id))
|
||||
inviter_name = 'A team member'
|
||||
if inviter_user and inviter_user.email:
|
||||
inviter_name = inviter_user.email.split('@')[0]
|
||||
|
||||
EmailService.send_invitation_email(
|
||||
to_email=email,
|
||||
org_name=org.name,
|
||||
inviter_name=inviter_name,
|
||||
role_name=target_role.name,
|
||||
invitation_token=invitation.token,
|
||||
invitation_id=invitation.id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'Failed to send invitation email',
|
||||
extra={
|
||||
'invitation_id': invitation.id,
|
||||
'email': email,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
# Don't fail the invitation creation if email fails
|
||||
# The user can still access via direct link
|
||||
|
||||
return invitation
|
||||
|
||||
@staticmethod
|
||||
async def create_invitations_batch(
|
||||
org_id: UUID,
|
||||
emails: list[str],
|
||||
role_name: str,
|
||||
inviter_id: UUID,
|
||||
) -> tuple[list[OrgInvitation], list[tuple[str, str]]]:
|
||||
"""Create multiple organization invitations concurrently.
|
||||
|
||||
Validates permissions once upfront, then creates invitations in parallel.
|
||||
|
||||
Args:
|
||||
org_id: Organization UUID
|
||||
emails: List of invitee email addresses
|
||||
role_name: Role to assign on acceptance (owner, admin, member)
|
||||
inviter_id: User ID of the person creating the invitations
|
||||
|
||||
Returns:
|
||||
Tuple of (successful_invitations, failed_emails_with_errors)
|
||||
|
||||
Raises:
|
||||
ValueError: If organization or role not found
|
||||
InsufficientPermissionError: If inviter lacks permission
|
||||
"""
|
||||
logger.info(
|
||||
'Creating batch organization invitations',
|
||||
extra={
|
||||
'org_id': str(org_id),
|
||||
'email_count': len(emails),
|
||||
'role_name': role_name,
|
||||
'inviter_id': str(inviter_id),
|
||||
},
|
||||
)
|
||||
|
||||
# Step 1: Validate permissions upfront (shared for all emails)
|
||||
org = OrgStore.get_org_by_id(org_id)
|
||||
if not org:
|
||||
raise ValueError(f'Organization {org_id} not found')
|
||||
|
||||
if str(org_id) == str(inviter_id):
|
||||
raise InsufficientPermissionError(
|
||||
'Cannot invite users to a personal workspace'
|
||||
)
|
||||
|
||||
inviter_member = OrgMemberStore.get_org_member(org_id, inviter_id)
|
||||
if not inviter_member:
|
||||
raise InsufficientPermissionError(
|
||||
'You are not a member of this organization'
|
||||
)
|
||||
|
||||
inviter_role = RoleStore.get_role_by_id(inviter_member.role_id)
|
||||
if not inviter_role or inviter_role.name not in [ROLE_OWNER, ROLE_ADMIN]:
|
||||
raise InsufficientPermissionError('Only owners and admins can invite users')
|
||||
|
||||
role_name_lower = role_name.lower()
|
||||
if role_name_lower == ROLE_OWNER and inviter_role.name != ROLE_OWNER:
|
||||
raise InsufficientPermissionError('Only owners can invite with owner role')
|
||||
|
||||
target_role = RoleStore.get_role_by_name(role_name_lower)
|
||||
if not target_role:
|
||||
raise ValueError(f'Invalid role: {role_name}')
|
||||
|
||||
# Step 2: Create invitations concurrently
|
||||
async def create_single(
|
||||
email: str,
|
||||
) -> tuple[str, OrgInvitation | None, str | None]:
|
||||
"""Create single invitation, return (email, invitation, error)."""
|
||||
try:
|
||||
invitation = await OrgInvitationService.create_invitation(
|
||||
org_id=org_id,
|
||||
email=email,
|
||||
role_name=role_name,
|
||||
inviter_id=inviter_id,
|
||||
)
|
||||
return (email, invitation, None)
|
||||
except (UserAlreadyMemberError, ValueError) as e:
|
||||
return (email, None, str(e))
|
||||
|
||||
results = await asyncio.gather(*[create_single(email) for email in emails])
|
||||
|
||||
# Step 3: Separate successes and failures
|
||||
successful: list[OrgInvitation] = []
|
||||
failed: list[tuple[str, str]] = []
|
||||
for email, invitation, error in results:
|
||||
if invitation:
|
||||
successful.append(invitation)
|
||||
elif error:
|
||||
failed.append((email, error))
|
||||
|
||||
logger.info(
|
||||
'Batch invitation creation completed',
|
||||
extra={
|
||||
'org_id': str(org_id),
|
||||
'successful': len(successful),
|
||||
'failed': len(failed),
|
||||
},
|
||||
)
|
||||
|
||||
return successful, failed
|
||||
|
||||
@staticmethod
|
||||
async def accept_invitation(token: str, user_id: UUID) -> OrgInvitation:
|
||||
"""Accept an organization invitation.
|
||||
|
||||
This method:
|
||||
1. Validates the token and invitation status
|
||||
2. Checks expiration
|
||||
3. Verifies user is not already a member
|
||||
4. Creates LiteLLM integration
|
||||
5. Adds user to the organization
|
||||
6. Marks invitation as accepted
|
||||
|
||||
Args:
|
||||
token: The invitation token
|
||||
user_id: The user accepting the invitation
|
||||
|
||||
Returns:
|
||||
OrgInvitation: The accepted invitation
|
||||
|
||||
Raises:
|
||||
InvitationInvalidError: If token is invalid or invitation not pending
|
||||
InvitationExpiredError: If invitation has expired
|
||||
UserAlreadyMemberError: If user is already a member
|
||||
"""
|
||||
logger.info(
|
||||
'Accepting organization invitation',
|
||||
extra={
|
||||
'token_prefix': token[:10] + '...' if len(token) > 10 else token,
|
||||
'user_id': str(user_id),
|
||||
},
|
||||
)
|
||||
|
||||
# Step 1: Get and validate invitation
|
||||
invitation = await OrgInvitationStore.get_invitation_by_token(token)
|
||||
|
||||
if not invitation:
|
||||
raise InvitationInvalidError('Invalid invitation token')
|
||||
|
||||
if invitation.status != OrgInvitation.STATUS_PENDING:
|
||||
if invitation.status == OrgInvitation.STATUS_ACCEPTED:
|
||||
raise InvitationInvalidError('Invitation has already been accepted')
|
||||
elif invitation.status == OrgInvitation.STATUS_REVOKED:
|
||||
raise InvitationInvalidError('Invitation has been revoked')
|
||||
else:
|
||||
raise InvitationInvalidError('Invitation is no longer valid')
|
||||
|
||||
# Step 2: Check expiration
|
||||
if OrgInvitationStore.is_token_expired(invitation):
|
||||
await OrgInvitationStore.update_invitation_status(
|
||||
invitation.id, OrgInvitation.STATUS_EXPIRED
|
||||
)
|
||||
raise InvitationExpiredError('Invitation has expired')
|
||||
|
||||
# Step 2.5: Verify user email matches invitation email
|
||||
user = await UserStore.get_user_by_id_async(str(user_id))
|
||||
if not user:
|
||||
raise InvitationInvalidError('User not found')
|
||||
|
||||
user_email = user.email
|
||||
# Fallback: fetch email from Keycloak if not in database (for existing users)
|
||||
if not user_email:
|
||||
token_manager = TokenManager()
|
||||
user_info = await token_manager.get_user_info_from_user_id(str(user_id))
|
||||
user_email = user_info.get('email') if user_info else None
|
||||
|
||||
if not user_email:
|
||||
raise EmailMismatchError('Your account does not have an email address')
|
||||
|
||||
user_email = user_email.lower().strip()
|
||||
invitation_email = invitation.email.lower().strip()
|
||||
|
||||
if user_email != invitation_email:
|
||||
logger.warning(
|
||||
'Email mismatch during invitation acceptance',
|
||||
extra={
|
||||
'user_id': str(user_id),
|
||||
'user_email': user_email,
|
||||
'invitation_email': invitation_email,
|
||||
'invitation_id': invitation.id,
|
||||
},
|
||||
)
|
||||
raise EmailMismatchError()
|
||||
|
||||
# Step 3: Check if user is already a member
|
||||
existing_member = OrgMemberStore.get_org_member(invitation.org_id, user_id)
|
||||
if existing_member:
|
||||
raise UserAlreadyMemberError(
|
||||
'You are already a member of this organization'
|
||||
)
|
||||
|
||||
# Step 4: Create LiteLLM integration for the user in the new org
|
||||
try:
|
||||
settings = await OrgService.create_litellm_integration(
|
||||
invitation.org_id, str(user_id)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'Failed to create LiteLLM integration for invitation acceptance',
|
||||
extra={
|
||||
'invitation_id': invitation.id,
|
||||
'user_id': str(user_id),
|
||||
'org_id': str(invitation.org_id),
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
raise InvitationInvalidError(
|
||||
'Failed to set up organization access. Please try again.'
|
||||
)
|
||||
|
||||
# Step 5: Add user to organization
|
||||
from storage.org_member_store import OrgMemberStore as OMS
|
||||
|
||||
org_member_kwargs = OMS.get_kwargs_from_settings(settings)
|
||||
# Don't override with org defaults - use invitation-specified role
|
||||
org_member_kwargs.pop('llm_model', None)
|
||||
org_member_kwargs.pop('llm_base_url', None)
|
||||
|
||||
OrgMemberStore.add_user_to_org(
|
||||
org_id=invitation.org_id,
|
||||
user_id=user_id,
|
||||
role_id=invitation.role_id,
|
||||
llm_api_key=settings.llm_api_key,
|
||||
status='active',
|
||||
)
|
||||
|
||||
# Step 6: Mark invitation as accepted
|
||||
updated_invitation = await OrgInvitationStore.update_invitation_status(
|
||||
invitation.id,
|
||||
OrgInvitation.STATUS_ACCEPTED,
|
||||
accepted_by_user_id=user_id,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
'Organization invitation accepted',
|
||||
extra={
|
||||
'invitation_id': invitation.id,
|
||||
'user_id': str(user_id),
|
||||
'org_id': str(invitation.org_id),
|
||||
'role_id': invitation.role_id,
|
||||
},
|
||||
)
|
||||
|
||||
return updated_invitation
|
||||
342
enterprise/server/services/org_member_service.py
Normal file
342
enterprise/server/services/org_member_service.py
Normal file
@@ -0,0 +1,342 @@
|
||||
"""Service for managing organization members."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from server.constants import ROLE_ADMIN, ROLE_MEMBER, ROLE_OWNER
|
||||
from server.routes.org_models import (
|
||||
CannotModifySelfError,
|
||||
InsufficientPermissionError,
|
||||
InvalidRoleError,
|
||||
LastOwnerError,
|
||||
MemberUpdateError,
|
||||
MeResponse,
|
||||
OrgMemberNotFoundError,
|
||||
OrgMemberPage,
|
||||
OrgMemberResponse,
|
||||
OrgMemberUpdate,
|
||||
RoleNotFoundError,
|
||||
)
|
||||
from storage.org_member_store import OrgMemberStore
|
||||
from storage.role_store import RoleStore
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
|
||||
|
||||
class OrgMemberService:
|
||||
"""Service for organization member operations."""
|
||||
|
||||
@staticmethod
|
||||
def get_me(org_id: UUID, user_id: UUID) -> MeResponse:
|
||||
"""Get the current user's membership record for an organization.
|
||||
|
||||
Retrieves the authenticated user's role, status, email, and LLM override
|
||||
fields (with masked API keys) within the specified organization.
|
||||
|
||||
Args:
|
||||
org_id: Organization ID (UUID)
|
||||
user_id: User ID (UUID)
|
||||
|
||||
Returns:
|
||||
MeResponse: The user's membership data with masked API keys
|
||||
|
||||
Raises:
|
||||
OrgMemberNotFoundError: If user is not a member of the organization
|
||||
RoleNotFoundError: If the role associated with the member is not found
|
||||
"""
|
||||
# Look up the user's membership in this org
|
||||
org_member = OrgMemberStore.get_org_member(org_id, user_id)
|
||||
if org_member is None:
|
||||
raise OrgMemberNotFoundError(str(org_id), str(user_id))
|
||||
|
||||
# Resolve role name from role_id
|
||||
role = RoleStore.get_role_by_id(org_member.role_id)
|
||||
if role is None:
|
||||
raise RoleNotFoundError(org_member.role_id)
|
||||
|
||||
# Get user email
|
||||
user = UserStore.get_user_by_id(str(user_id))
|
||||
email = user.email if user and user.email else ''
|
||||
|
||||
return MeResponse.from_org_member(org_member, role, email)
|
||||
|
||||
@staticmethod
|
||||
async def get_org_members(
|
||||
org_id: UUID,
|
||||
current_user_id: UUID,
|
||||
page_id: str | None = None,
|
||||
limit: int = 100,
|
||||
) -> tuple[bool, str | None, OrgMemberPage | None]:
|
||||
"""Get organization members with authorization check.
|
||||
|
||||
Returns:
|
||||
Tuple of (success, error_code, data). If success is True, error_code is None.
|
||||
"""
|
||||
# Verify current user is a member of the organization
|
||||
requester_membership = OrgMemberStore.get_org_member(org_id, current_user_id)
|
||||
if not requester_membership:
|
||||
return False, 'not_a_member', None
|
||||
|
||||
# Parse page_id to get offset (page_id is offset encoded as string)
|
||||
offset = 0
|
||||
if page_id is not None:
|
||||
try:
|
||||
offset = int(page_id)
|
||||
if offset < 0:
|
||||
return False, 'invalid_page_id', None
|
||||
except ValueError:
|
||||
return False, 'invalid_page_id', None
|
||||
|
||||
# Call store to get paginated members
|
||||
members, has_more = await OrgMemberStore.get_org_members_paginated(
|
||||
org_id=org_id, offset=offset, limit=limit
|
||||
)
|
||||
|
||||
# Transform data to response format
|
||||
items = []
|
||||
for member in members:
|
||||
# Access user and role relationships (eagerly loaded)
|
||||
user = member.user
|
||||
role = member.role
|
||||
|
||||
items.append(
|
||||
OrgMemberResponse(
|
||||
user_id=str(member.user_id),
|
||||
email=user.email if user else None,
|
||||
role_id=member.role_id,
|
||||
role=role.name if role else '',
|
||||
role_rank=role.rank if role else 0,
|
||||
status=member.status,
|
||||
)
|
||||
)
|
||||
|
||||
# Calculate next_page_id
|
||||
next_page_id = None
|
||||
if has_more:
|
||||
next_page_id = str(offset + limit)
|
||||
|
||||
return True, None, OrgMemberPage(items=items, next_page_id=next_page_id)
|
||||
|
||||
@staticmethod
|
||||
async def remove_org_member(
|
||||
org_id: UUID,
|
||||
target_user_id: UUID,
|
||||
current_user_id: UUID,
|
||||
) -> tuple[bool, str | None]:
|
||||
"""Remove a member from an organization.
|
||||
|
||||
Returns:
|
||||
Tuple of (success, error_message). If success is True, error_message is None.
|
||||
"""
|
||||
|
||||
def _remove_member():
|
||||
# Get current user's membership in the org
|
||||
requester_membership = OrgMemberStore.get_org_member(
|
||||
org_id, current_user_id
|
||||
)
|
||||
if not requester_membership:
|
||||
return False, 'not_a_member'
|
||||
|
||||
# Check if trying to remove self
|
||||
if str(current_user_id) == str(target_user_id):
|
||||
return False, 'cannot_remove_self'
|
||||
|
||||
# Get target user's membership
|
||||
target_membership = OrgMemberStore.get_org_member(org_id, target_user_id)
|
||||
if not target_membership:
|
||||
return False, 'member_not_found'
|
||||
|
||||
requester_role = RoleStore.get_role_by_id(requester_membership.role_id)
|
||||
target_role = RoleStore.get_role_by_id(target_membership.role_id)
|
||||
|
||||
if not requester_role or not target_role:
|
||||
return False, 'role_not_found'
|
||||
|
||||
# Check permission based on roles
|
||||
if not OrgMemberService._can_remove_member(
|
||||
requester_role.name, target_role.name
|
||||
):
|
||||
return False, 'insufficient_permission'
|
||||
|
||||
# Check if removing the last owner
|
||||
if target_role.name == ROLE_OWNER:
|
||||
if OrgMemberService._is_last_owner(org_id, target_user_id):
|
||||
return False, 'cannot_remove_last_owner'
|
||||
|
||||
# Perform the removal
|
||||
success = OrgMemberStore.remove_user_from_org(org_id, target_user_id)
|
||||
if not success:
|
||||
return False, 'removal_failed'
|
||||
|
||||
return True, None
|
||||
|
||||
return await call_sync_from_async(_remove_member)
|
||||
|
||||
@staticmethod
|
||||
async def update_org_member(
|
||||
org_id: UUID,
|
||||
target_user_id: UUID,
|
||||
current_user_id: UUID,
|
||||
update_data: OrgMemberUpdate,
|
||||
) -> OrgMemberResponse:
|
||||
"""Update a member's role in an organization.
|
||||
|
||||
Permission rules:
|
||||
- Admins can change roles of users (rank > ADMIN_RANK) to Admin or User
|
||||
- Admins cannot modify other Admins or Owners
|
||||
- Owners can change roles of non-owners (rank > OWNER_RANK) to any role
|
||||
- Owners cannot modify other Owners
|
||||
|
||||
Args:
|
||||
org_id: Organization ID
|
||||
target_user_id: User ID of the member to update
|
||||
current_user_id: User ID of the requester
|
||||
update_data: Update data containing fields to modify
|
||||
|
||||
Returns:
|
||||
OrgMemberResponse: The updated member data
|
||||
|
||||
Raises:
|
||||
OrgMemberNotFoundError: If requester or target is not a member
|
||||
CannotModifySelfError: If trying to modify self
|
||||
RoleNotFoundError: If role configuration is invalid
|
||||
InvalidRoleError: If new_role_name is not a valid role
|
||||
InsufficientPermissionError: If requester lacks permission
|
||||
LastOwnerError: If trying to demote the last owner
|
||||
MemberUpdateError: If update operation fails
|
||||
"""
|
||||
new_role_name = update_data.role
|
||||
|
||||
def _update_member():
|
||||
# Get current user's membership in the org
|
||||
requester_membership = OrgMemberStore.get_org_member(
|
||||
org_id, current_user_id
|
||||
)
|
||||
if not requester_membership:
|
||||
raise OrgMemberNotFoundError(str(org_id), str(current_user_id))
|
||||
|
||||
# Check if trying to modify self
|
||||
if str(current_user_id) == str(target_user_id):
|
||||
raise CannotModifySelfError('modify')
|
||||
|
||||
# Get target user's membership
|
||||
target_membership = OrgMemberStore.get_org_member(org_id, target_user_id)
|
||||
if not target_membership:
|
||||
raise OrgMemberNotFoundError(str(org_id), str(target_user_id))
|
||||
|
||||
# Get roles
|
||||
requester_role = RoleStore.get_role_by_id(requester_membership.role_id)
|
||||
target_role = RoleStore.get_role_by_id(target_membership.role_id)
|
||||
|
||||
if not requester_role:
|
||||
raise RoleNotFoundError(requester_membership.role_id)
|
||||
if not target_role:
|
||||
raise RoleNotFoundError(target_membership.role_id)
|
||||
|
||||
# If no role change requested, return current state
|
||||
if new_role_name is None:
|
||||
user = UserStore.get_user_by_id(str(target_user_id))
|
||||
return OrgMemberResponse(
|
||||
user_id=str(target_membership.user_id),
|
||||
email=user.email if user else None,
|
||||
role_id=target_membership.role_id,
|
||||
role=target_role.name,
|
||||
role_rank=target_role.rank,
|
||||
status=target_membership.status,
|
||||
)
|
||||
|
||||
# Validate new role exists
|
||||
new_role = RoleStore.get_role_by_name(new_role_name.lower())
|
||||
if not new_role:
|
||||
raise InvalidRoleError(new_role_name)
|
||||
|
||||
# Check permission to modify target
|
||||
if not OrgMemberService._can_update_member_role(
|
||||
requester_role.name, target_role.name, new_role.name
|
||||
):
|
||||
raise InsufficientPermissionError(
|
||||
'You do not have permission to modify this member'
|
||||
)
|
||||
|
||||
# Check if demoting the last owner
|
||||
if (
|
||||
target_role.name == ROLE_OWNER
|
||||
and new_role.name != ROLE_OWNER
|
||||
and OrgMemberService._is_last_owner(org_id, target_user_id)
|
||||
):
|
||||
raise LastOwnerError('demote')
|
||||
|
||||
# Perform the update
|
||||
updated_member = OrgMemberStore.update_user_role_in_org(
|
||||
org_id, target_user_id, new_role.id
|
||||
)
|
||||
if not updated_member:
|
||||
raise MemberUpdateError('Failed to update member')
|
||||
|
||||
# Get user email for response
|
||||
user = UserStore.get_user_by_id(str(target_user_id))
|
||||
|
||||
return OrgMemberResponse(
|
||||
user_id=str(updated_member.user_id),
|
||||
email=user.email if user else None,
|
||||
role_id=updated_member.role_id,
|
||||
role=new_role.name,
|
||||
role_rank=new_role.rank,
|
||||
status=updated_member.status,
|
||||
)
|
||||
|
||||
return await call_sync_from_async(_update_member)
|
||||
|
||||
@staticmethod
|
||||
def _can_update_member_role(
|
||||
requester_role_name: str, target_role_name: str, new_role_name: str
|
||||
) -> bool:
|
||||
"""Check if requester can change target's role to new_role.
|
||||
|
||||
Permission rules:
|
||||
- Owners can modify admins and users, can set any role
|
||||
- Owners cannot modify other owners
|
||||
- Admins can only modify users
|
||||
- Admins can only set admin or user roles (not owner)
|
||||
"""
|
||||
is_requester_owner = requester_role_name == ROLE_OWNER
|
||||
is_requester_admin = requester_role_name == ROLE_ADMIN
|
||||
is_target_owner = target_role_name == ROLE_OWNER
|
||||
is_target_admin = target_role_name == ROLE_ADMIN
|
||||
is_new_role_owner = new_role_name == ROLE_OWNER
|
||||
|
||||
if is_requester_owner:
|
||||
# Owners cannot modify other owners
|
||||
if is_target_owner:
|
||||
return False
|
||||
# Owners can set any role (owner, admin, user)
|
||||
return True
|
||||
elif is_requester_admin:
|
||||
# Admins cannot modify owners or other admins
|
||||
if is_target_owner or is_target_admin:
|
||||
return False
|
||||
# Admins can only set admin or user roles (not owner)
|
||||
return not is_new_role_owner
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _can_remove_member(requester_role_name: str, target_role_name: str) -> bool:
|
||||
"""Check if requester can remove target based on roles."""
|
||||
if requester_role_name == ROLE_OWNER:
|
||||
return True
|
||||
elif requester_role_name == ROLE_ADMIN:
|
||||
# Admins can only remove members (not owners or other admins)
|
||||
return target_role_name == ROLE_MEMBER
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _is_last_owner(org_id: UUID, user_id: UUID) -> bool:
|
||||
"""Check if user is the last owner of the organization."""
|
||||
members = OrgMemberStore.get_org_members(org_id)
|
||||
owners = []
|
||||
for m in members:
|
||||
# Use role_id (column) instead of role (relationship) to avoid DetachedInstanceError
|
||||
role = RoleStore.get_role_by_id(m.role_id)
|
||||
if role and role.name == ROLE_OWNER:
|
||||
owners.append(m)
|
||||
return len(owners) == 1 and str(owners[0].user_id) == str(user_id)
|
||||
@@ -22,11 +22,70 @@ from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
|
||||
SQLAppConversationInfoService,
|
||||
)
|
||||
from openhands.app_server.errors import AuthError
|
||||
from openhands.app_server.services.injector import InjectorState
|
||||
from openhands.app_server.user.specifiy_user_context import ADMIN
|
||||
|
||||
|
||||
class SaasSQLAppConversationInfoService(SQLAppConversationInfoService):
|
||||
"""Extended SQLAppConversationInfoService with user-based filtering and SAAS metadata handling."""
|
||||
"""Extended SQLAppConversationInfoService with user and organization-based filtering and SAAS metadata handling."""
|
||||
|
||||
async def _get_current_user(self) -> User | None:
|
||||
"""Get the current user using the existing db_session.
|
||||
|
||||
Uses self.db_session to avoid opening a separate database session.
|
||||
|
||||
Returns:
|
||||
User object or None if no user_id is available
|
||||
"""
|
||||
user_id_str = await self.user_context.get_user_id()
|
||||
if not user_id_str:
|
||||
return None
|
||||
|
||||
user_id_uuid = UUID(user_id_str)
|
||||
result = await self.db_session.execute(
|
||||
select(User).where(User.id == user_id_uuid)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
async def _apply_user_and_org_filter(self, query):
|
||||
"""Apply user_id and org_id filters to ensure conversation isolation.
|
||||
|
||||
Filters conversations by:
|
||||
- user_id: Only show conversations belonging to the current user
|
||||
- org_id: Only show conversations belonging to the user's current organization
|
||||
|
||||
Args:
|
||||
query: SQLAlchemy query to apply filters to
|
||||
|
||||
Returns:
|
||||
Query with user and organization filters applied
|
||||
|
||||
Raises:
|
||||
AuthError: If no user_id is available (secure default: deny access)
|
||||
"""
|
||||
# For internal operations such as getting a conversation by session_api_key
|
||||
# we need a mode that does not have filtering. The dependency `as_admin()`
|
||||
# is used to enable it
|
||||
if self.user_context == ADMIN:
|
||||
return query
|
||||
|
||||
user_id_str = await self.user_context.get_user_id()
|
||||
if not user_id_str:
|
||||
# Secure default: no user means no access, not "show everything"
|
||||
raise AuthError('User authentication required')
|
||||
|
||||
user_id_uuid = UUID(user_id_str)
|
||||
query = query.where(StoredConversationMetadataSaas.user_id == user_id_uuid)
|
||||
|
||||
# Filter by organization ID to ensure conversations are isolated per organization
|
||||
user = await self._get_current_user()
|
||||
if user and user.current_org_id is not None:
|
||||
query = query.where(
|
||||
StoredConversationMetadataSaas.org_id == user.current_org_id
|
||||
)
|
||||
|
||||
return query
|
||||
|
||||
async def _secure_select(self):
|
||||
query = (
|
||||
@@ -38,13 +97,7 @@ class SaasSQLAppConversationInfoService(SQLAppConversationInfoService):
|
||||
)
|
||||
.where(StoredConversationMetadata.conversation_version == 'V1')
|
||||
)
|
||||
|
||||
user_id_str = await self.user_context.get_user_id()
|
||||
if user_id_str:
|
||||
user_id_uuid = UUID(user_id_str)
|
||||
query = query.where(StoredConversationMetadataSaas.user_id == user_id_uuid)
|
||||
|
||||
return query
|
||||
return await self._apply_user_and_org_filter(query)
|
||||
|
||||
async def _secure_select_with_saas_metadata(self):
|
||||
"""Select query that includes SAAS metadata for retrieving user_id."""
|
||||
@@ -57,13 +110,7 @@ class SaasSQLAppConversationInfoService(SQLAppConversationInfoService):
|
||||
)
|
||||
.where(StoredConversationMetadata.conversation_version == 'V1')
|
||||
)
|
||||
|
||||
user_id_str = await self.user_context.get_user_id()
|
||||
if user_id_str:
|
||||
user_id_uuid = UUID(user_id_str)
|
||||
query = query.where(StoredConversationMetadataSaas.user_id == user_id_uuid)
|
||||
|
||||
return query
|
||||
return await self._apply_user_and_org_filter(query)
|
||||
|
||||
async def search_app_conversation_info(
|
||||
self,
|
||||
@@ -155,21 +202,16 @@ class SaasSQLAppConversationInfoService(SQLAppConversationInfoService):
|
||||
"""Count conversations matching the given filters with SAAS metadata."""
|
||||
query = (
|
||||
select(func.count(StoredConversationMetadata.conversation_id))
|
||||
.select_from(
|
||||
StoredConversationMetadata.join(
|
||||
StoredConversationMetadataSaas,
|
||||
StoredConversationMetadata.conversation_id
|
||||
== StoredConversationMetadataSaas.conversation_id,
|
||||
)
|
||||
.join(
|
||||
StoredConversationMetadataSaas,
|
||||
StoredConversationMetadata.conversation_id
|
||||
== StoredConversationMetadataSaas.conversation_id,
|
||||
)
|
||||
.where(StoredConversationMetadata.conversation_version == 'V1')
|
||||
)
|
||||
|
||||
# Apply user filtering
|
||||
user_id_str = await self.user_context.get_user_id()
|
||||
if user_id_str:
|
||||
user_id_uuid = UUID(user_id_str)
|
||||
query = query.where(StoredConversationMetadataSaas.user_id == user_id_uuid)
|
||||
# Apply user and organization filtering
|
||||
query = await self._apply_user_and_org_filter(query)
|
||||
|
||||
query = self._apply_filters_with_saas_metadata(
|
||||
query=query,
|
||||
@@ -233,7 +275,13 @@ class SaasSQLAppConversationInfoService(SQLAppConversationInfoService):
|
||||
result = result_set.first()
|
||||
if result:
|
||||
stored_metadata, saas_metadata = result
|
||||
return self._to_info_with_user_id(stored_metadata, saas_metadata)
|
||||
# Fetch sub-conversation IDs
|
||||
sub_conversation_ids = await self.get_sub_conversation_ids(conversation_id)
|
||||
return self._to_info_with_user_id(
|
||||
stored_metadata,
|
||||
saas_metadata,
|
||||
sub_conversation_ids=sub_conversation_ids,
|
||||
)
|
||||
return None
|
||||
|
||||
async def batch_get_app_conversation_info(
|
||||
@@ -262,8 +310,16 @@ class SaasSQLAppConversationInfoService(SQLAppConversationInfoService):
|
||||
for conversation_id in conversation_id_strs:
|
||||
if conversation_id in info_by_id:
|
||||
stored_metadata, saas_metadata = info_by_id[conversation_id]
|
||||
# Fetch sub-conversation IDs for each conversation
|
||||
sub_conversation_ids = await self.get_sub_conversation_ids(
|
||||
UUID(conversation_id)
|
||||
)
|
||||
results.append(
|
||||
self._to_info_with_user_id(stored_metadata, saas_metadata)
|
||||
self._to_info_with_user_id(
|
||||
stored_metadata,
|
||||
saas_metadata,
|
||||
sub_conversation_ids=sub_conversation_ids,
|
||||
)
|
||||
)
|
||||
else:
|
||||
results.append(None)
|
||||
@@ -316,10 +372,11 @@ class SaasSQLAppConversationInfoService(SQLAppConversationInfoService):
|
||||
self,
|
||||
stored: StoredConversationMetadata,
|
||||
saas_metadata: StoredConversationMetadataSaas,
|
||||
sub_conversation_ids: list[UUID] | None = None,
|
||||
) -> AppConversationInfo:
|
||||
"""Convert stored metadata to AppConversationInfo with user_id from SAAS metadata."""
|
||||
# Use the base _to_info method to get the basic info
|
||||
info = self._to_info(stored)
|
||||
info = self._to_info(stored, sub_conversation_ids=sub_conversation_ids)
|
||||
|
||||
# Override the created_by_user_id with the user_id from SAAS metadata
|
||||
info.created_by_user_id = (
|
||||
|
||||
@@ -20,8 +20,10 @@ from storage.linear_workspace import LinearWorkspace
|
||||
from storage.maintenance_task import MaintenanceTask, MaintenanceTaskStatus
|
||||
from storage.openhands_pr import OpenhandsPR
|
||||
from storage.org import Org
|
||||
from storage.org_invitation import OrgInvitation
|
||||
from storage.org_member import OrgMember
|
||||
from storage.proactive_convos import ProactiveConversation
|
||||
from storage.resend_synced_user import ResendSyncedUser
|
||||
from storage.role import Role
|
||||
from storage.slack_conversation import SlackConversation
|
||||
from storage.slack_team import SlackTeam
|
||||
@@ -65,8 +67,10 @@ __all__ = [
|
||||
'MaintenanceTaskStatus',
|
||||
'OpenhandsPR',
|
||||
'Org',
|
||||
'OrgInvitation',
|
||||
'OrgMember',
|
||||
'ProactiveConversation',
|
||||
'ResendSyncedUser',
|
||||
'Role',
|
||||
'SlackConversation',
|
||||
'SlackTeam',
|
||||
|
||||
@@ -12,6 +12,7 @@ from storage.database import session_maker
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -26,7 +27,7 @@ class ApiKeyStore:
|
||||
random_part = ''.join(secrets.choice(alphabet) for _ in range(length))
|
||||
return f'{self.API_KEY_PREFIX}{random_part}'
|
||||
|
||||
def create_api_key(
|
||||
async def create_api_key(
|
||||
self, user_id: str, name: str | None = None, expires_at: datetime | None = None
|
||||
) -> str:
|
||||
"""Create a new API key for a user.
|
||||
@@ -40,8 +41,23 @@ class ApiKeyStore:
|
||||
The generated API key
|
||||
"""
|
||||
api_key = self.generate_api_key()
|
||||
user = UserStore.get_user_by_id(user_id)
|
||||
user = await UserStore.get_user_by_id_async(user_id)
|
||||
org_id = user.current_org_id
|
||||
await call_sync_from_async(
|
||||
self._store_api_key, user_id, org_id, api_key, name, expires_at
|
||||
)
|
||||
|
||||
return api_key
|
||||
|
||||
def _store_api_key(
|
||||
self,
|
||||
user_id: str,
|
||||
org_id: str,
|
||||
api_key: str,
|
||||
name: str | None,
|
||||
expires_at: datetime | None = None,
|
||||
) -> None:
|
||||
"""Store an existing API key in the database."""
|
||||
with self.session_maker() as session:
|
||||
key_record = ApiKey(
|
||||
key=api_key,
|
||||
@@ -53,8 +69,6 @@ class ApiKeyStore:
|
||||
session.add(key_record)
|
||||
session.commit()
|
||||
|
||||
return api_key
|
||||
|
||||
def validate_api_key(self, api_key: str) -> str | None:
|
||||
"""Validate an API key and return the associated user_id if valid."""
|
||||
now = datetime.now(UTC)
|
||||
@@ -112,33 +126,31 @@ class ApiKeyStore:
|
||||
|
||||
return True
|
||||
|
||||
def list_api_keys(self, user_id: str) -> list[dict]:
|
||||
async def list_api_keys(self, user_id: str) -> list[ApiKey]:
|
||||
"""List all API keys for a user."""
|
||||
user = UserStore.get_user_by_id(user_id)
|
||||
user = await UserStore.get_user_by_id_async(user_id)
|
||||
org_id = user.current_org_id
|
||||
return await call_sync_from_async(self._list_api_keys_from_db, user_id, org_id)
|
||||
|
||||
def _list_api_keys_from_db(self, user_id: str, org_id: str) -> list[ApiKey]:
|
||||
with self.session_maker() as session:
|
||||
keys = (
|
||||
keys: list[ApiKey] = (
|
||||
session.query(ApiKey)
|
||||
.filter(ApiKey.user_id == user_id)
|
||||
.filter(ApiKey.org_id == org_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
return [
|
||||
{
|
||||
'id': key.id,
|
||||
'name': key.name,
|
||||
'created_at': key.created_at,
|
||||
'last_used_at': key.last_used_at,
|
||||
'expires_at': key.expires_at,
|
||||
}
|
||||
for key in keys
|
||||
if 'MCP_API_KEY' != key.name
|
||||
]
|
||||
return [key for key in keys if key.name != 'MCP_API_KEY']
|
||||
|
||||
def retrieve_mcp_api_key(self, user_id: str) -> str | None:
|
||||
user = UserStore.get_user_by_id(user_id)
|
||||
async def retrieve_mcp_api_key(self, user_id: str) -> str | None:
|
||||
user = await UserStore.get_user_by_id_async(user_id)
|
||||
org_id = user.current_org_id
|
||||
return await call_sync_from_async(
|
||||
self._retrieve_mcp_api_key_from_db, user_id, org_id
|
||||
)
|
||||
|
||||
def _retrieve_mcp_api_key_from_db(self, user_id: str, org_id: str) -> str | None:
|
||||
with self.session_maker() as session:
|
||||
keys: list[ApiKey] = (
|
||||
session.query(ApiKey)
|
||||
|
||||
@@ -4,7 +4,9 @@ import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Awaitable, Callable, Dict
|
||||
|
||||
from sqlalchemy import select, update
|
||||
from server.auth.auth_error import TokenRefreshError
|
||||
from sqlalchemy import select, text, update
|
||||
from sqlalchemy.exc import OperationalError
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from storage.auth_tokens import AuthTokens
|
||||
from storage.database import a_session_maker
|
||||
@@ -12,6 +14,14 @@ from storage.database import a_session_maker
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.service_types import ProviderType
|
||||
|
||||
# Time buffer (in seconds) before actual expiration to consider token expired
|
||||
# This ensures tokens are refreshed before they actually expire. The
|
||||
# github default is 8 hours, so 15 minutes leeway is ~3% of this.
|
||||
ACCESS_TOKEN_EXPIRY_BUFFER = 900 # 15 minutes
|
||||
|
||||
# Database lock timeout to prevent indefinite blocking
|
||||
LOCK_TIMEOUT_SECONDS = 5
|
||||
|
||||
|
||||
@dataclass
|
||||
class AuthTokenStore:
|
||||
@@ -23,6 +33,31 @@ class AuthTokenStore:
|
||||
def identity_provider_value(self) -> str:
|
||||
return self.idp.value
|
||||
|
||||
def _is_token_expired(
|
||||
self, access_token_expires_at: int, refresh_token_expires_at: int
|
||||
) -> tuple[bool, bool]:
|
||||
"""Check if access and refresh tokens are expired.
|
||||
|
||||
Args:
|
||||
access_token_expires_at: Expiration time for access token (seconds since epoch)
|
||||
refresh_token_expires_at: Expiration time for refresh token (seconds since epoch)
|
||||
|
||||
Returns:
|
||||
Tuple of (access_expired, refresh_expired)
|
||||
"""
|
||||
current_time = int(time.time())
|
||||
access_expired = (
|
||||
False
|
||||
if access_token_expires_at == 0
|
||||
else access_token_expires_at < current_time + ACCESS_TOKEN_EXPIRY_BUFFER
|
||||
)
|
||||
refresh_expired = (
|
||||
False
|
||||
if refresh_token_expires_at == 0
|
||||
else refresh_token_expires_at < current_time
|
||||
)
|
||||
return access_expired, refresh_expired
|
||||
|
||||
async def store_tokens(
|
||||
self,
|
||||
access_token: str,
|
||||
@@ -73,87 +108,149 @@ class AuthTokenStore:
|
||||
]
|
||||
| None = None,
|
||||
) -> Dict[str, str | int] | None:
|
||||
"""
|
||||
Load authentication tokens from the database and refresh them if necessary.
|
||||
"""Load authentication tokens from the database and refresh them if necessary.
|
||||
|
||||
This method retrieves the current authentication tokens for the user and checks if they have expired.
|
||||
It uses the provided `check_expiration_and_refresh` function to determine if the tokens need
|
||||
to be refreshed and to refresh the tokens if needed.
|
||||
This method uses a double-checked locking pattern to minimize lock contention:
|
||||
1. First, check if the token is valid WITHOUT acquiring a lock (fast path)
|
||||
2. If refresh is needed, acquire a lock with a timeout
|
||||
3. Double-check if refresh is still needed (another request may have refreshed)
|
||||
4. Perform the refresh if still needed
|
||||
|
||||
The method ensures that only one refresh operation is performed per refresh token by using a
|
||||
row-level lock on the token record.
|
||||
|
||||
The method is designed to handle race conditions where multiple requests might attempt to refresh
|
||||
the same token simultaneously, ensuring that only one refresh call occurs per refresh token.
|
||||
The row-level lock ensures that only one refresh operation is performed per
|
||||
refresh token, which is important because most IDPs invalidate the old refresh
|
||||
token after it's used once.
|
||||
|
||||
Args:
|
||||
check_expiration_and_refresh (Callable, optional): A function that checks if the tokens have expired
|
||||
and attempts to refresh them. It should return a dictionary containing the new access_token, refresh_token,
|
||||
and their respective expiration timestamps. If no refresh is needed, it should return `None`.
|
||||
check_expiration_and_refresh: A function that checks if the tokens have
|
||||
expired and attempts to refresh them. It should return a dictionary
|
||||
containing the new access_token, refresh_token, and their respective
|
||||
expiration timestamps. If no refresh is needed, it should return None.
|
||||
|
||||
Returns:
|
||||
Dict[str, str | int] | None:
|
||||
A dictionary containing the access_token, refresh_token, access_token_expires_at,
|
||||
and refresh_token_expires_at. If no token record is found, returns `None`.
|
||||
A dictionary containing the access_token, refresh_token,
|
||||
access_token_expires_at, and refresh_token_expires_at.
|
||||
If no token record is found, returns None.
|
||||
|
||||
Raises:
|
||||
TokenRefreshError: If the lock cannot be acquired within the timeout
|
||||
period. This typically means another request is holding the lock
|
||||
for an extended period. Callers should handle this by returning
|
||||
a 401 response to prompt the user to re-authenticate.
|
||||
"""
|
||||
# FAST PATH: Check without lock first to avoid unnecessary lock contention
|
||||
async with self.a_session_maker() as session:
|
||||
async with session.begin(): # Ensures transaction management
|
||||
# Lock the row while we check if we need to refresh the tokens.
|
||||
# There is a race condition where 2 or more calls can load tokens simultaneously.
|
||||
# If it turns out the loaded tokens are expired, then there will be multiple
|
||||
# refresh token calls with the same refresh token. Most IDPs only allow one refresh
|
||||
# per refresh token. This lock ensure that only one refresh call occurs per refresh token
|
||||
result = await session.execute(
|
||||
select(AuthTokens)
|
||||
.filter(
|
||||
AuthTokens.keycloak_user_id == self.keycloak_user_id,
|
||||
AuthTokens.identity_provider == self.identity_provider_value,
|
||||
)
|
||||
.with_for_update()
|
||||
result = await session.execute(
|
||||
select(AuthTokens).filter(
|
||||
AuthTokens.keycloak_user_id == self.keycloak_user_id,
|
||||
AuthTokens.identity_provider == self.identity_provider_value,
|
||||
)
|
||||
token_record = result.scalars().one_or_none()
|
||||
)
|
||||
token_record = result.scalars().one_or_none()
|
||||
|
||||
if not token_record:
|
||||
return None
|
||||
if not token_record:
|
||||
return None
|
||||
|
||||
token_refresh = (
|
||||
await check_expiration_and_refresh(
|
||||
# Check if token needs refresh
|
||||
access_expired, _ = self._is_token_expired(
|
||||
token_record.access_token_expires_at,
|
||||
token_record.refresh_token_expires_at,
|
||||
)
|
||||
|
||||
# If token is still valid, return it without acquiring a lock
|
||||
if not access_expired or check_expiration_and_refresh is None:
|
||||
return {
|
||||
'access_token': token_record.access_token,
|
||||
'refresh_token': token_record.refresh_token,
|
||||
'access_token_expires_at': token_record.access_token_expires_at,
|
||||
'refresh_token_expires_at': token_record.refresh_token_expires_at,
|
||||
}
|
||||
|
||||
# SLOW PATH: Token needs refresh, acquire lock
|
||||
try:
|
||||
async with self.a_session_maker() as session:
|
||||
async with session.begin():
|
||||
# Set a lock timeout to prevent indefinite blocking
|
||||
# This ensures we don't hold connections forever if something goes wrong
|
||||
await session.execute(
|
||||
text(f"SET LOCAL lock_timeout = '{LOCK_TIMEOUT_SECONDS}s'")
|
||||
)
|
||||
|
||||
# Acquire row-level lock to prevent concurrent refresh attempts
|
||||
result = await session.execute(
|
||||
select(AuthTokens)
|
||||
.filter(
|
||||
AuthTokens.keycloak_user_id == self.keycloak_user_id,
|
||||
AuthTokens.identity_provider
|
||||
== self.identity_provider_value,
|
||||
)
|
||||
.with_for_update()
|
||||
)
|
||||
token_record = result.scalars().one_or_none()
|
||||
|
||||
if not token_record:
|
||||
return None
|
||||
|
||||
# Double-check: another request may have refreshed while we waited for the lock
|
||||
access_expired, _ = self._is_token_expired(
|
||||
token_record.access_token_expires_at,
|
||||
token_record.refresh_token_expires_at,
|
||||
)
|
||||
|
||||
if not access_expired:
|
||||
# Token was refreshed by another request while we waited
|
||||
logger.debug(
|
||||
'Token was refreshed by another request while waiting for lock'
|
||||
)
|
||||
return {
|
||||
'access_token': token_record.access_token,
|
||||
'refresh_token': token_record.refresh_token,
|
||||
'access_token_expires_at': token_record.access_token_expires_at,
|
||||
'refresh_token_expires_at': token_record.refresh_token_expires_at,
|
||||
}
|
||||
|
||||
# We're the one doing the refresh
|
||||
token_refresh = await check_expiration_and_refresh(
|
||||
self.idp,
|
||||
token_record.refresh_token,
|
||||
token_record.access_token_expires_at,
|
||||
token_record.refresh_token_expires_at,
|
||||
)
|
||||
if check_expiration_and_refresh
|
||||
else None
|
||||
)
|
||||
|
||||
if token_refresh:
|
||||
await session.execute(
|
||||
update(AuthTokens)
|
||||
.where(AuthTokens.id == token_record.id)
|
||||
.values(
|
||||
access_token=token_refresh['access_token'],
|
||||
refresh_token=token_refresh['refresh_token'],
|
||||
access_token_expires_at=token_refresh[
|
||||
'access_token_expires_at'
|
||||
],
|
||||
refresh_token_expires_at=token_refresh[
|
||||
'refresh_token_expires_at'
|
||||
],
|
||||
if token_refresh:
|
||||
await session.execute(
|
||||
update(AuthTokens)
|
||||
.where(AuthTokens.id == token_record.id)
|
||||
.values(
|
||||
access_token=token_refresh['access_token'],
|
||||
refresh_token=token_refresh['refresh_token'],
|
||||
access_token_expires_at=token_refresh[
|
||||
'access_token_expires_at'
|
||||
],
|
||||
refresh_token_expires_at=token_refresh[
|
||||
'refresh_token_expires_at'
|
||||
],
|
||||
)
|
||||
)
|
||||
)
|
||||
await session.commit()
|
||||
await session.commit()
|
||||
|
||||
return (
|
||||
token_refresh
|
||||
if token_refresh
|
||||
else {
|
||||
'access_token': token_record.access_token,
|
||||
'refresh_token': token_record.refresh_token,
|
||||
'access_token_expires_at': token_record.access_token_expires_at,
|
||||
'refresh_token_expires_at': token_record.refresh_token_expires_at,
|
||||
}
|
||||
)
|
||||
return (
|
||||
token_refresh
|
||||
if token_refresh
|
||||
else {
|
||||
'access_token': token_record.access_token,
|
||||
'refresh_token': token_record.refresh_token,
|
||||
'access_token_expires_at': token_record.access_token_expires_at,
|
||||
'refresh_token_expires_at': token_record.refresh_token_expires_at,
|
||||
}
|
||||
)
|
||||
except OperationalError as e:
|
||||
# Lock timeout - another request is holding the lock for too long
|
||||
logger.warning(
|
||||
f'Token refresh lock timeout for user {self.keycloak_user_id}: {e}'
|
||||
)
|
||||
raise TokenRefreshError(
|
||||
'Unable to refresh token due to lock timeout. Please try again.'
|
||||
) from e
|
||||
|
||||
async def is_access_token_valid(self) -> bool:
|
||||
"""Check if the access token is still valid.
|
||||
@@ -194,8 +291,8 @@ class AuthTokenStore:
|
||||
"""Get an instance of the AuthTokenStore.
|
||||
|
||||
Args:
|
||||
config: The application configuration
|
||||
keycloak_user_id: The Keycloak user ID
|
||||
idp: The identity provider type
|
||||
|
||||
Returns:
|
||||
An instance of AuthTokenStore
|
||||
|
||||
@@ -18,17 +18,17 @@ def _get_db_session_injector():
|
||||
return _config.db_session
|
||||
|
||||
|
||||
def session_maker():
|
||||
def session_maker(**kwargs):
|
||||
db_session_injector = _get_db_session_injector()
|
||||
session_maker = db_session_injector.get_session_maker()
|
||||
return session_maker()
|
||||
factory = db_session_injector.get_session_maker()
|
||||
return factory(**kwargs)
|
||||
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def a_session_maker():
|
||||
async def a_session_maker(**kwargs):
|
||||
db_session_injector = _get_db_session_injector()
|
||||
a_session_maker = await db_session_injector.get_async_session_maker()
|
||||
async with a_session_maker() as session:
|
||||
factory = await db_session_injector.get_async_session_maker()
|
||||
async with factory(**kwargs) as session:
|
||||
yield session
|
||||
|
||||
|
||||
|
||||
@@ -10,7 +10,6 @@ import httpx
|
||||
from pydantic import SecretStr
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.constants import (
|
||||
DEFAULT_INITIAL_BUDGET,
|
||||
LITE_LLM_API_KEY,
|
||||
LITE_LLM_API_URL,
|
||||
LITE_LLM_TEAM_ID,
|
||||
@@ -18,21 +17,60 @@ from server.constants import (
|
||||
get_default_litellm_model,
|
||||
)
|
||||
from server.logger import logger
|
||||
from storage.encrypt_utils import decrypt_legacy_value
|
||||
from storage.user_settings import UserSettings
|
||||
|
||||
from openhands.server.settings import Settings
|
||||
from openhands.utils.http_session import httpx_verify_option
|
||||
|
||||
# Timeout in seconds for BYOR key verification requests to LiteLLM
|
||||
BYOR_KEY_VERIFICATION_TIMEOUT = 5.0
|
||||
# Timeout in seconds for key verification requests to LiteLLM
|
||||
KEY_VERIFICATION_TIMEOUT = 5.0
|
||||
|
||||
# A very large number to represent "unlimited" until LiteLLM fixes their unlimited update bug.
|
||||
UNLIMITED_BUDGET_SETTING = 1000000000.0
|
||||
|
||||
|
||||
def get_openhands_cloud_key_alias(keycloak_user_id: str, org_id: str) -> str:
|
||||
"""Generate the key alias for OpenHands Cloud managed keys."""
|
||||
return f'OpenHands Cloud - user {keycloak_user_id} - org {org_id}'
|
||||
|
||||
|
||||
def get_byor_key_alias(keycloak_user_id: str, org_id: str) -> str:
|
||||
"""Generate the key alias for BYOR (Bring Your Own Runtime) keys."""
|
||||
return f'BYOR Key - user {keycloak_user_id}, org {org_id}'
|
||||
|
||||
|
||||
class LiteLlmManager:
|
||||
"""Manage LiteLLM interactions."""
|
||||
|
||||
@staticmethod
|
||||
def get_budget_from_team_info(
|
||||
user_team_info: dict | None, user_id: str, org_id: str
|
||||
) -> tuple[float, float]:
|
||||
"""Extract max_budget and spend from user team info.
|
||||
|
||||
For personal orgs (user_id == org_id), uses litellm_budget_table.max_budget.
|
||||
For team orgs, uses max_budget_in_team (populated by get_user_team_info).
|
||||
|
||||
Args:
|
||||
user_team_info: The response from get_user_team_info
|
||||
user_id: The user's ID
|
||||
org_id: The organization's ID
|
||||
|
||||
Returns:
|
||||
Tuple of (max_budget, spend)
|
||||
"""
|
||||
if not user_team_info:
|
||||
return 0, 0
|
||||
spend = user_team_info.get('spend', 0)
|
||||
if user_id == org_id:
|
||||
max_budget = (user_team_info.get('litellm_budget_table') or {}).get(
|
||||
'max_budget', 0
|
||||
)
|
||||
else:
|
||||
max_budget = user_team_info.get('max_budget_in_team') or 0
|
||||
return max_budget, spend
|
||||
|
||||
@staticmethod
|
||||
async def create_entries(
|
||||
org_id: str,
|
||||
@@ -61,8 +99,33 @@ class LiteLlmManager:
|
||||
'x-goog-api-key': LITE_LLM_API_KEY,
|
||||
}
|
||||
) as client:
|
||||
# Check if team already exists and get its budget
|
||||
# New users joining existing orgs should inherit the team's budget
|
||||
team_budget = 0.0
|
||||
try:
|
||||
existing_team = await LiteLlmManager._get_team(client, org_id)
|
||||
if existing_team:
|
||||
team_info = existing_team.get('team_info', {})
|
||||
team_budget = team_info.get('max_budget', 0.0) or 0.0
|
||||
logger.info(
|
||||
'LiteLlmManager:create_entries:existing_team_budget',
|
||||
extra={
|
||||
'org_id': org_id,
|
||||
'user_id': keycloak_user_id,
|
||||
'team_budget': team_budget,
|
||||
},
|
||||
)
|
||||
except httpx.HTTPStatusError as e:
|
||||
# Team doesn't exist yet (404) - this is expected for first user
|
||||
if e.response.status_code != 404:
|
||||
raise
|
||||
logger.info(
|
||||
'LiteLlmManager:create_entries:no_existing_team',
|
||||
extra={'org_id': org_id, 'user_id': keycloak_user_id},
|
||||
)
|
||||
|
||||
await LiteLlmManager._create_team(
|
||||
client, keycloak_user_id, org_id, DEFAULT_INITIAL_BUDGET
|
||||
client, keycloak_user_id, org_id, team_budget
|
||||
)
|
||||
|
||||
if create_user:
|
||||
@@ -71,14 +134,14 @@ class LiteLlmManager:
|
||||
)
|
||||
|
||||
await LiteLlmManager._add_user_to_team(
|
||||
client, keycloak_user_id, org_id, DEFAULT_INITIAL_BUDGET
|
||||
client, keycloak_user_id, org_id, team_budget
|
||||
)
|
||||
|
||||
key = await LiteLlmManager._generate_key(
|
||||
client,
|
||||
keycloak_user_id,
|
||||
org_id,
|
||||
f'OpenHands Cloud - user {keycloak_user_id} - org {org_id}',
|
||||
get_openhands_cloud_key_alias(keycloak_user_id, org_id),
|
||||
None,
|
||||
)
|
||||
|
||||
@@ -114,8 +177,24 @@ class LiteLlmManager:
|
||||
if not user_json:
|
||||
return None
|
||||
user_info = user_json['user_info']
|
||||
max_budget = user_info.get('max_budget', 0.0)
|
||||
spend = user_info.get('spend', 0.0)
|
||||
|
||||
# Log original user values before any modifications for debugging
|
||||
original_max_budget = user_info.get('max_budget')
|
||||
original_spend = user_info.get('spend')
|
||||
logger.info(
|
||||
'LiteLlmManager:migrate_lite_llm_entries:original_user_values',
|
||||
extra={
|
||||
'org_id': org_id,
|
||||
'user_id': keycloak_user_id,
|
||||
'original_max_budget': original_max_budget,
|
||||
'original_spend': original_spend,
|
||||
},
|
||||
)
|
||||
|
||||
max_budget = (
|
||||
original_max_budget if original_max_budget is not None else 0.0
|
||||
)
|
||||
spend = original_spend if original_spend is not None else 0.0
|
||||
# In upgrade to V4, we no longer use billing margin, but instead apply this directly
|
||||
# in litellm. The default billing marign was 2 before this (hence the magic numbers below)
|
||||
if (
|
||||
@@ -136,11 +215,37 @@ class LiteLlmManager:
|
||||
max_budget *= billing_margin
|
||||
spend *= billing_margin
|
||||
|
||||
if not max_budget:
|
||||
# if max_budget is None, then we've already migrated the User
|
||||
# Check if max_budget is None (not 0.0) or set to unlimited to determine if already migrated
|
||||
# A user with max_budget=0.0 is different from max_budget=None
|
||||
if (
|
||||
original_max_budget is None
|
||||
or original_max_budget == UNLIMITED_BUDGET_SETTING
|
||||
):
|
||||
# if max_budget is None or UNLIMITED, then we've already migrated the User
|
||||
logger.info(
|
||||
'LiteLlmManager:migrate_lite_llm_entries:already_migrated',
|
||||
extra={
|
||||
'org_id': org_id,
|
||||
'user_id': keycloak_user_id,
|
||||
'original_max_budget': original_max_budget,
|
||||
},
|
||||
)
|
||||
return None
|
||||
credits = max(max_budget - spend, 0.0)
|
||||
|
||||
# Log calculated migration values before performing updates
|
||||
logger.info(
|
||||
'LiteLlmManager:migrate_lite_llm_entries:calculated_values',
|
||||
extra={
|
||||
'org_id': org_id,
|
||||
'user_id': keycloak_user_id,
|
||||
'adjusted_max_budget': max_budget,
|
||||
'adjusted_spend': spend,
|
||||
'calculated_credits': credits,
|
||||
'new_user_max_budget': UNLIMITED_BUDGET_SETTING,
|
||||
},
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
'LiteLlmManager:migrate_lite_llm_entries:create_team',
|
||||
extra={'org_id': org_id, 'user_id': keycloak_user_id},
|
||||
@@ -165,29 +270,60 @@ class LiteLlmManager:
|
||||
client, keycloak_user_id, org_id, credits
|
||||
)
|
||||
|
||||
if user_settings.llm_api_key:
|
||||
logger.debug(
|
||||
'LiteLlmManager:migrate_lite_llm_entries:update_key',
|
||||
extra={'org_id': org_id, 'user_id': keycloak_user_id},
|
||||
)
|
||||
await LiteLlmManager._update_key(
|
||||
client,
|
||||
keycloak_user_id,
|
||||
user_settings.llm_api_key,
|
||||
team_id=org_id,
|
||||
)
|
||||
logger.debug(
|
||||
'LiteLlmManager:migrate_lite_llm_entries:update_user_keys',
|
||||
extra={'org_id': org_id, 'user_id': keycloak_user_id},
|
||||
)
|
||||
await LiteLlmManager._update_user_keys(
|
||||
client,
|
||||
keycloak_user_id,
|
||||
team_id=org_id,
|
||||
)
|
||||
|
||||
if user_settings.llm_api_key_for_byor:
|
||||
logger.debug(
|
||||
'LiteLlmManager:migrate_lite_llm_entries:update_byor_key',
|
||||
extra={'org_id': org_id, 'user_id': keycloak_user_id},
|
||||
)
|
||||
await LiteLlmManager._update_key(
|
||||
client,
|
||||
keycloak_user_id,
|
||||
user_settings.llm_api_key_for_byor,
|
||||
team_id=org_id,
|
||||
# Check if the database key exists in LiteLLM
|
||||
# If not, generate a new key to prevent verification failures later
|
||||
db_key = None
|
||||
if (
|
||||
user_settings
|
||||
and user_settings.llm_api_key
|
||||
and user_settings.llm_base_url == LITE_LLM_API_URL
|
||||
):
|
||||
db_key = user_settings.llm_api_key
|
||||
if hasattr(db_key, 'get_secret_value'):
|
||||
db_key = db_key.get_secret_value()
|
||||
|
||||
if db_key:
|
||||
# Verify the database key exists in LiteLLM
|
||||
key_valid = await LiteLlmManager.verify_key(
|
||||
db_key, keycloak_user_id
|
||||
)
|
||||
if not key_valid:
|
||||
logger.warning(
|
||||
'LiteLlmManager:migrate_lite_llm_entries:db_key_not_in_litellm',
|
||||
extra={
|
||||
'org_id': org_id,
|
||||
'user_id': keycloak_user_id,
|
||||
'key_prefix': db_key[:10] + '...'
|
||||
if len(db_key) > 10
|
||||
else db_key,
|
||||
},
|
||||
)
|
||||
# Generate a new key for the user
|
||||
new_key = await LiteLlmManager._generate_key(
|
||||
client,
|
||||
keycloak_user_id,
|
||||
org_id,
|
||||
get_openhands_cloud_key_alias(keycloak_user_id, org_id),
|
||||
None,
|
||||
)
|
||||
if new_key:
|
||||
logger.info(
|
||||
'LiteLlmManager:migrate_lite_llm_entries:generated_new_key',
|
||||
extra={'org_id': org_id, 'user_id': keycloak_user_id},
|
||||
)
|
||||
# Update user_settings with the new key so it gets stored in org_member
|
||||
user_settings.llm_api_key = SecretStr(new_key)
|
||||
user_settings.llm_api_key_for_byor = SecretStr(new_key)
|
||||
|
||||
logger.info(
|
||||
'LiteLlmManager:migrate_lite_llm_entries:complete',
|
||||
@@ -305,30 +441,16 @@ class LiteLlmManager:
|
||||
client, keycloak_user_id, LITE_LLM_TEAM_ID, restored_budget
|
||||
)
|
||||
|
||||
# Step 4: Update keys to remove org team association (set team_id to default)
|
||||
if user_settings.llm_api_key:
|
||||
logger.debug(
|
||||
'LiteLlmManager:downgrade_entries:update_key',
|
||||
extra={'org_id': org_id, 'user_id': keycloak_user_id},
|
||||
)
|
||||
await LiteLlmManager._update_key(
|
||||
client,
|
||||
keycloak_user_id,
|
||||
user_settings.llm_api_key,
|
||||
team_id=LITE_LLM_TEAM_ID,
|
||||
)
|
||||
|
||||
if user_settings.llm_api_key_for_byor:
|
||||
logger.debug(
|
||||
'LiteLlmManager:downgrade_entries:update_byor_key',
|
||||
extra={'org_id': org_id, 'user_id': keycloak_user_id},
|
||||
)
|
||||
await LiteLlmManager._update_key(
|
||||
client,
|
||||
keycloak_user_id,
|
||||
user_settings.llm_api_key_for_byor,
|
||||
team_id=LITE_LLM_TEAM_ID,
|
||||
)
|
||||
# Step 4: Update all user keys to remove org team association (set team_id to default)
|
||||
logger.debug(
|
||||
'LiteLlmManager:downgrade_entries:update_user_keys',
|
||||
extra={'org_id': org_id, 'user_id': keycloak_user_id},
|
||||
)
|
||||
await LiteLlmManager._update_user_keys(
|
||||
client,
|
||||
keycloak_user_id,
|
||||
team_id=LITE_LLM_TEAM_ID,
|
||||
)
|
||||
|
||||
# Step 5: Remove user from their org team
|
||||
logger.debug(
|
||||
@@ -605,6 +727,13 @@ class LiteLlmManager:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return
|
||||
|
||||
try:
|
||||
# Sometimes the key we get is encrypted - attempt to decrypt.
|
||||
key = decrypt_legacy_value(key)
|
||||
except Exception:
|
||||
# The key was not encrypted
|
||||
pass
|
||||
|
||||
payload = {
|
||||
'key': key,
|
||||
}
|
||||
@@ -621,6 +750,7 @@ class LiteLlmManager:
|
||||
'invalid_litellm_key_during_update',
|
||||
extra={
|
||||
'user_id': keycloak_user_id,
|
||||
'text': response.text,
|
||||
},
|
||||
)
|
||||
return
|
||||
@@ -634,6 +764,77 @@ class LiteLlmManager:
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
async def _get_user_keys(
|
||||
client: httpx.AsyncClient,
|
||||
keycloak_user_id: str,
|
||||
) -> list[str]:
|
||||
"""Get all keys for a user from LiteLLM.
|
||||
|
||||
Args:
|
||||
client: The HTTP client to use for the request
|
||||
keycloak_user_id: The user's Keycloak ID
|
||||
|
||||
Returns:
|
||||
A list of key strings belonging to the user
|
||||
"""
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return []
|
||||
|
||||
response = await client.get(
|
||||
f'{LITE_LLM_API_URL}/key/list',
|
||||
params={'user_id': keycloak_user_id},
|
||||
)
|
||||
|
||||
if not response.is_success:
|
||||
logger.error(
|
||||
'error_getting_user_keys',
|
||||
extra={
|
||||
'status_code': response.status_code,
|
||||
'text': response.text,
|
||||
'user_id': keycloak_user_id,
|
||||
},
|
||||
)
|
||||
return []
|
||||
|
||||
response_json = response.json()
|
||||
keys = response_json.get('keys', [])
|
||||
logger.debug(
|
||||
'LiteLlmManager:_get_user_keys:keys_retrieved',
|
||||
extra={
|
||||
'user_id': keycloak_user_id,
|
||||
'key_count': len(keys),
|
||||
},
|
||||
)
|
||||
return keys
|
||||
|
||||
@staticmethod
|
||||
async def _update_user_keys(
|
||||
client: httpx.AsyncClient,
|
||||
keycloak_user_id: str,
|
||||
**kwargs,
|
||||
):
|
||||
"""Update all keys belonging to a user with the given parameters.
|
||||
|
||||
Args:
|
||||
client: The HTTP client to use for the request
|
||||
keycloak_user_id: The user's Keycloak ID
|
||||
**kwargs: Parameters to update on each key (e.g., team_id)
|
||||
"""
|
||||
keys = await LiteLlmManager._get_user_keys(client, keycloak_user_id)
|
||||
|
||||
logger.debug(
|
||||
'LiteLlmManager:_update_user_keys:updating_keys',
|
||||
extra={
|
||||
'user_id': keycloak_user_id,
|
||||
'key_count': len(keys),
|
||||
},
|
||||
)
|
||||
|
||||
for key in keys:
|
||||
await LiteLlmManager._update_key(client, keycloak_user_id, key, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
async def _delete_user(
|
||||
client: httpx.AsyncClient,
|
||||
@@ -745,21 +946,31 @@ class LiteLlmManager:
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return None
|
||||
team_info = await LiteLlmManager._get_team(client, team_id)
|
||||
if not team_info:
|
||||
team_response = await LiteLlmManager._get_team(client, team_id)
|
||||
if not team_response:
|
||||
return None
|
||||
|
||||
# Filter team_memberships based on team_id and keycloak_user_id
|
||||
user_membership = next(
|
||||
(
|
||||
membership
|
||||
for membership in team_info.get('team_memberships', [])
|
||||
for membership in team_response.get('team_memberships', [])
|
||||
if membership.get('user_id') == keycloak_user_id
|
||||
and membership.get('team_id') == team_id
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if not user_membership:
|
||||
return None
|
||||
|
||||
# For team orgs (user_id != team_id), include team-level budget info
|
||||
# The team's max_budget and spend are shared across all members
|
||||
if keycloak_user_id != team_id:
|
||||
team_info = team_response.get('team_info', {})
|
||||
user_membership['max_budget_in_team'] = team_info.get('max_budget')
|
||||
user_membership['spend'] = team_info.get('spend', 0)
|
||||
|
||||
return user_membership
|
||||
|
||||
@staticmethod
|
||||
@@ -905,7 +1116,7 @@ class LiteLlmManager:
|
||||
try:
|
||||
async with httpx.AsyncClient(
|
||||
verify=httpx_verify_option(),
|
||||
timeout=BYOR_KEY_VERIFICATION_TIMEOUT,
|
||||
timeout=KEY_VERIFICATION_TIMEOUT,
|
||||
) as client:
|
||||
# Make a lightweight request to verify the key
|
||||
# Using /v1/models endpoint as it's lightweight and requires authentication
|
||||
@@ -919,7 +1130,7 @@ class LiteLlmManager:
|
||||
# Only 200 status code indicates valid key
|
||||
if response.status_code == 200:
|
||||
logger.debug(
|
||||
'BYOR key verification successful',
|
||||
'Key verification successful',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
return True
|
||||
@@ -927,7 +1138,7 @@ class LiteLlmManager:
|
||||
# All other status codes (401, 403, 500, etc.) are treated as invalid
|
||||
# This includes authentication errors and server errors
|
||||
logger.warning(
|
||||
'BYOR key verification failed - treating as invalid',
|
||||
'Key verification failed - treating as invalid',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'status_code': response.status_code,
|
||||
@@ -940,7 +1151,7 @@ class LiteLlmManager:
|
||||
# Any exception (timeout, network error, etc.) means we can't verify
|
||||
# Return False to trigger regeneration rather than returning potentially invalid key
|
||||
logger.warning(
|
||||
'BYOR key verification error - treating as invalid to ensure key validity',
|
||||
'Key verification error - treating as invalid to ensure key validity',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'error': str(e),
|
||||
@@ -984,6 +1195,103 @@ class LiteLlmManager:
|
||||
'key_spend': key_info.get('spend'),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
async def _get_all_keys_for_user(
|
||||
client: httpx.AsyncClient,
|
||||
keycloak_user_id: str,
|
||||
) -> list[dict]:
|
||||
"""Get all keys for a user from LiteLLM.
|
||||
|
||||
Returns a list of key info dictionaries containing:
|
||||
- token: the key value (hashed or partial)
|
||||
- key_alias: the alias for the key
|
||||
- key_name: the name of the key
|
||||
- spend: the amount spent on this key
|
||||
- max_budget: the max budget for this key
|
||||
- team_id: the team the key belongs to
|
||||
- metadata: any metadata associated with the key
|
||||
|
||||
Returns an empty list if no keys found or on error.
|
||||
"""
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return []
|
||||
|
||||
try:
|
||||
response = await client.get(
|
||||
f'{LITE_LLM_API_URL}/user/info?user_id={keycloak_user_id}',
|
||||
headers={'x-goog-api-key': LITE_LLM_API_KEY},
|
||||
)
|
||||
response.raise_for_status()
|
||||
user_json = response.json()
|
||||
# The user/info endpoint returns keys in the 'keys' field
|
||||
return user_json.get('keys', [])
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
'LiteLlmManager:_get_all_keys_for_user:error',
|
||||
extra={
|
||||
'user_id': keycloak_user_id,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
async def _verify_existing_key(
|
||||
client: httpx.AsyncClient,
|
||||
key_value: str,
|
||||
keycloak_user_id: str,
|
||||
org_id: str,
|
||||
openhands_type: bool = False,
|
||||
) -> bool:
|
||||
"""Check if an existing key exists for the user/org in LiteLLM.
|
||||
|
||||
Verifies the provided key_value matches a key registered in LiteLLM for
|
||||
the given user and organization. For openhands_type=True, looks for keys
|
||||
with metadata type='openhands' and matching team_id. For openhands_type=False,
|
||||
looks for keys with matching alias and team_id.
|
||||
|
||||
Returns True if the key is found and valid, False otherwise.
|
||||
"""
|
||||
found = False
|
||||
keys = await LiteLlmManager._get_all_keys_for_user(client, keycloak_user_id)
|
||||
for key_info in keys:
|
||||
metadata = key_info.get('metadata') or {}
|
||||
team_id = key_info.get('team_id')
|
||||
key_alias = key_info.get('key_alias')
|
||||
token = None
|
||||
if (
|
||||
openhands_type
|
||||
and metadata.get('type') == 'openhands'
|
||||
and team_id == org_id
|
||||
):
|
||||
# Found an existing OpenHands key for this org
|
||||
key_name = key_info.get('key_name')
|
||||
token = key_name[-4:] if key_name else None # last 4 digits of key
|
||||
if token and key_value.endswith(
|
||||
token
|
||||
): # check if this is our current key
|
||||
found = True
|
||||
break
|
||||
if (
|
||||
not openhands_type
|
||||
and team_id == org_id
|
||||
and (
|
||||
key_alias == get_openhands_cloud_key_alias(keycloak_user_id, org_id)
|
||||
or key_alias == get_byor_key_alias(keycloak_user_id, org_id)
|
||||
)
|
||||
):
|
||||
# Found an existing key for this org (regardless of type)
|
||||
key_name = key_info.get('key_name')
|
||||
token = key_name[-4:] if key_name else None # last 4 digits of key
|
||||
if token and key_value.endswith(
|
||||
token
|
||||
): # check if this is our current key
|
||||
found = True
|
||||
break
|
||||
|
||||
return found
|
||||
|
||||
@staticmethod
|
||||
async def _delete_key_by_alias(
|
||||
client: httpx.AsyncClient,
|
||||
@@ -1081,4 +1389,8 @@ class LiteLlmManager:
|
||||
update_user_in_team = staticmethod(with_http_client(_update_user_in_team))
|
||||
generate_key = staticmethod(with_http_client(_generate_key))
|
||||
get_key_info = staticmethod(with_http_client(_get_key_info))
|
||||
verify_existing_key = staticmethod(with_http_client(_verify_existing_key))
|
||||
delete_key = staticmethod(with_http_client(_delete_key))
|
||||
get_user_keys = staticmethod(with_http_client(_get_user_keys))
|
||||
delete_key_by_alias = staticmethod(with_http_client(_delete_key_by_alias))
|
||||
update_user_keys = staticmethod(with_http_client(_update_user_keys))
|
||||
|
||||
@@ -46,10 +46,12 @@ class Org(Base): # type: ignore
|
||||
v1_enabled = Column(Boolean, nullable=True)
|
||||
conversation_expiration = Column(Integer, nullable=True)
|
||||
condenser_max_size = Column(Integer, nullable=True)
|
||||
byor_export_enabled = Column(Boolean, nullable=False, default=False)
|
||||
|
||||
# Relationships
|
||||
org_members = relationship('OrgMember', back_populates='org')
|
||||
current_users = relationship('User', back_populates='current_org')
|
||||
invitations = relationship('OrgInvitation', back_populates='org')
|
||||
billing_sessions = relationship('BillingSession', back_populates='org')
|
||||
stored_conversation_metadata_saas = relationship(
|
||||
'StoredConversationMetadataSaas', back_populates='org'
|
||||
|
||||
59
enterprise/storage/org_invitation.py
Normal file
59
enterprise/storage/org_invitation.py
Normal file
@@ -0,0 +1,59 @@
|
||||
"""
|
||||
SQLAlchemy model for Organization Invitation.
|
||||
"""
|
||||
|
||||
from sqlalchemy import UUID, Column, DateTime, ForeignKey, Integer, String, text
|
||||
from sqlalchemy.orm import relationship
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
class OrgInvitation(Base): # type: ignore
|
||||
"""Organization invitation model.
|
||||
|
||||
Represents an invitation for a user to join an organization.
|
||||
Invitations are created by organization owners/admins and contain
|
||||
a secure token that can be used to accept the invitation.
|
||||
"""
|
||||
|
||||
__tablename__ = 'org_invitation'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
token = Column(String(64), nullable=False, unique=True, index=True)
|
||||
org_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey('org.id', ondelete='CASCADE'),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
email = Column(String(255), nullable=False, index=True)
|
||||
role_id = Column(Integer, ForeignKey('role.id'), nullable=False)
|
||||
inviter_id = Column(UUID(as_uuid=True), ForeignKey('user.id'), nullable=False)
|
||||
status = Column(
|
||||
String(20),
|
||||
nullable=False,
|
||||
server_default=text("'pending'"),
|
||||
)
|
||||
created_at = Column(
|
||||
DateTime,
|
||||
nullable=False,
|
||||
server_default=text('CURRENT_TIMESTAMP'),
|
||||
)
|
||||
expires_at = Column(DateTime, nullable=False)
|
||||
accepted_at = Column(DateTime, nullable=True)
|
||||
accepted_by_user_id = Column(
|
||||
UUID(as_uuid=True),
|
||||
ForeignKey('user.id'),
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
# Relationships
|
||||
org = relationship('Org', back_populates='invitations')
|
||||
role = relationship('Role')
|
||||
inviter = relationship('User', foreign_keys=[inviter_id])
|
||||
accepted_by_user = relationship('User', foreign_keys=[accepted_by_user_id])
|
||||
|
||||
# Status constants
|
||||
STATUS_PENDING = 'pending'
|
||||
STATUS_ACCEPTED = 'accepted'
|
||||
STATUS_REVOKED = 'revoked'
|
||||
STATUS_EXPIRED = 'expired'
|
||||
227
enterprise/storage/org_invitation_store.py
Normal file
227
enterprise/storage/org_invitation_store.py
Normal file
@@ -0,0 +1,227 @@
|
||||
"""
|
||||
Store class for managing organization invitations.
|
||||
"""
|
||||
|
||||
import secrets
|
||||
import string
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import and_, select
|
||||
from sqlalchemy.orm import joinedload
|
||||
from storage.database import a_session_maker
|
||||
from storage.org_invitation import OrgInvitation
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
# Invitation token configuration
|
||||
INVITATION_TOKEN_PREFIX = 'inv-'
|
||||
INVITATION_TOKEN_LENGTH = 48 # Total length will be 52 with prefix
|
||||
DEFAULT_EXPIRATION_DAYS = 7
|
||||
|
||||
|
||||
class OrgInvitationStore:
|
||||
"""Store for managing organization invitations."""
|
||||
|
||||
@staticmethod
|
||||
def generate_token(length: int = INVITATION_TOKEN_LENGTH) -> str:
|
||||
"""Generate a secure invitation token.
|
||||
|
||||
Uses cryptographically secure random generation for tokens.
|
||||
Pattern from api_key_store.py.
|
||||
|
||||
Args:
|
||||
length: Length of the random part of the token
|
||||
|
||||
Returns:
|
||||
str: Token with prefix (e.g., 'inv-aBcDeF123...')
|
||||
"""
|
||||
alphabet = string.ascii_letters + string.digits
|
||||
random_part = ''.join(secrets.choice(alphabet) for _ in range(length))
|
||||
return f'{INVITATION_TOKEN_PREFIX}{random_part}'
|
||||
|
||||
@staticmethod
|
||||
async def create_invitation(
|
||||
org_id: UUID,
|
||||
email: str,
|
||||
role_id: int,
|
||||
inviter_id: UUID,
|
||||
expiration_days: int = DEFAULT_EXPIRATION_DAYS,
|
||||
) -> OrgInvitation:
|
||||
"""Create a new organization invitation.
|
||||
|
||||
Args:
|
||||
org_id: Organization UUID
|
||||
email: Invitee's email address
|
||||
role_id: Role ID to assign on acceptance
|
||||
inviter_id: User ID of the person creating the invitation
|
||||
expiration_days: Days until the invitation expires
|
||||
|
||||
Returns:
|
||||
OrgInvitation: The created invitation record
|
||||
"""
|
||||
async with a_session_maker() as session:
|
||||
token = OrgInvitationStore.generate_token()
|
||||
# Use timezone-naive datetime for database compatibility
|
||||
expires_at = datetime.utcnow() + timedelta(days=expiration_days)
|
||||
|
||||
invitation = OrgInvitation(
|
||||
token=token,
|
||||
org_id=org_id,
|
||||
email=email.lower().strip(),
|
||||
role_id=role_id,
|
||||
inviter_id=inviter_id,
|
||||
status=OrgInvitation.STATUS_PENDING,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
session.add(invitation)
|
||||
await session.commit()
|
||||
|
||||
# Re-fetch with eagerly loaded relationships to avoid DetachedInstanceError
|
||||
result = await session.execute(
|
||||
select(OrgInvitation)
|
||||
.options(joinedload(OrgInvitation.role))
|
||||
.filter(OrgInvitation.id == invitation.id)
|
||||
)
|
||||
invitation = result.scalars().first()
|
||||
|
||||
logger.info(
|
||||
'Created organization invitation',
|
||||
extra={
|
||||
'invitation_id': invitation.id,
|
||||
'org_id': str(org_id),
|
||||
'email': email,
|
||||
'inviter_id': str(inviter_id),
|
||||
'expires_at': expires_at.isoformat(),
|
||||
},
|
||||
)
|
||||
|
||||
return invitation
|
||||
|
||||
@staticmethod
|
||||
async def get_invitation_by_token(token: str) -> Optional[OrgInvitation]:
|
||||
"""Get an invitation by its token.
|
||||
|
||||
Args:
|
||||
token: The invitation token
|
||||
|
||||
Returns:
|
||||
OrgInvitation or None if not found
|
||||
"""
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(OrgInvitation)
|
||||
.options(joinedload(OrgInvitation.org), joinedload(OrgInvitation.role))
|
||||
.filter(OrgInvitation.token == token)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
@staticmethod
|
||||
async def get_pending_invitation(
|
||||
org_id: UUID, email: str
|
||||
) -> Optional[OrgInvitation]:
|
||||
"""Get a pending invitation for an email in an organization.
|
||||
|
||||
Args:
|
||||
org_id: Organization UUID
|
||||
email: Email address to check
|
||||
|
||||
Returns:
|
||||
OrgInvitation or None if no pending invitation exists
|
||||
"""
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(OrgInvitation).filter(
|
||||
and_(
|
||||
OrgInvitation.org_id == org_id,
|
||||
OrgInvitation.email == email.lower().strip(),
|
||||
OrgInvitation.status == OrgInvitation.STATUS_PENDING,
|
||||
)
|
||||
)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
@staticmethod
|
||||
async def update_invitation_status(
|
||||
invitation_id: int,
|
||||
status: str,
|
||||
accepted_by_user_id: Optional[UUID] = None,
|
||||
) -> Optional[OrgInvitation]:
|
||||
"""Update an invitation's status.
|
||||
|
||||
Args:
|
||||
invitation_id: The invitation ID
|
||||
status: New status (pending, accepted, revoked, expired)
|
||||
accepted_by_user_id: User ID who accepted (only for 'accepted' status)
|
||||
|
||||
Returns:
|
||||
Updated OrgInvitation or None if not found
|
||||
"""
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(OrgInvitation).filter(OrgInvitation.id == invitation_id)
|
||||
)
|
||||
invitation = result.scalars().first()
|
||||
|
||||
if not invitation:
|
||||
return None
|
||||
|
||||
old_status = invitation.status
|
||||
invitation.status = status
|
||||
|
||||
if status == OrgInvitation.STATUS_ACCEPTED and accepted_by_user_id:
|
||||
# Use timezone-naive datetime for database compatibility
|
||||
invitation.accepted_at = datetime.utcnow()
|
||||
invitation.accepted_by_user_id = accepted_by_user_id
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(invitation)
|
||||
|
||||
logger.info(
|
||||
'Updated invitation status',
|
||||
extra={
|
||||
'invitation_id': invitation_id,
|
||||
'old_status': old_status,
|
||||
'new_status': status,
|
||||
'accepted_by_user_id': (
|
||||
str(accepted_by_user_id) if accepted_by_user_id else None
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
return invitation
|
||||
|
||||
@staticmethod
|
||||
def is_token_expired(invitation: OrgInvitation) -> bool:
|
||||
"""Check if an invitation token has expired.
|
||||
|
||||
Args:
|
||||
invitation: The invitation to check
|
||||
|
||||
Returns:
|
||||
bool: True if expired, False otherwise
|
||||
"""
|
||||
# Use timezone-naive datetime for comparison (database stores without timezone)
|
||||
now = datetime.utcnow()
|
||||
return invitation.expires_at < now
|
||||
|
||||
@staticmethod
|
||||
async def mark_expired_if_needed(invitation: OrgInvitation) -> bool:
|
||||
"""Check if invitation is expired and update status if needed.
|
||||
|
||||
Args:
|
||||
invitation: The invitation to check
|
||||
|
||||
Returns:
|
||||
bool: True if invitation was marked as expired, False otherwise
|
||||
"""
|
||||
if (
|
||||
invitation.status == OrgInvitation.STATUS_PENDING
|
||||
and OrgInvitationStore.is_token_expired(invitation)
|
||||
):
|
||||
await OrgInvitationStore.update_invitation_status(
|
||||
invitation.id, OrgInvitation.STATUS_EXPIRED
|
||||
)
|
||||
return True
|
||||
return False
|
||||
@@ -5,8 +5,11 @@ Store class for managing organization-member relationships.
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from storage.database import session_maker
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import joinedload
|
||||
from storage.database import a_session_maker, session_maker
|
||||
from storage.org_member import OrgMember
|
||||
from storage.user import User
|
||||
from storage.user_settings import UserSettings
|
||||
|
||||
from openhands.storage.data_models.settings import Settings
|
||||
@@ -38,7 +41,7 @@ class OrgMemberStore:
|
||||
return org_member
|
||||
|
||||
@staticmethod
|
||||
def get_org_member(org_id: UUID, user_id: int) -> Optional[OrgMember]:
|
||||
def get_org_member(org_id: UUID, user_id: UUID) -> Optional[OrgMember]:
|
||||
"""Get organization-user relationship."""
|
||||
with session_maker() as session:
|
||||
return (
|
||||
@@ -48,7 +51,63 @@ class OrgMemberStore:
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_user_orgs(user_id: int) -> list[OrgMember]:
|
||||
async def get_org_member_async(org_id: UUID, user_id: UUID) -> Optional[OrgMember]:
|
||||
"""Get organization-user relationship."""
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(OrgMember).filter(
|
||||
OrgMember.org_id == org_id, OrgMember.user_id == user_id
|
||||
)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
@staticmethod
|
||||
def get_org_member_for_current_org(user_id: UUID) -> Optional[OrgMember]:
|
||||
"""Get the org member for a user's current organization.
|
||||
|
||||
Args:
|
||||
user_id: The user's UUID.
|
||||
|
||||
Returns:
|
||||
The OrgMember for the user's current organization, or None if not found.
|
||||
"""
|
||||
with session_maker() as session:
|
||||
result = (
|
||||
session.query(OrgMember)
|
||||
.join(User, User.id == OrgMember.user_id)
|
||||
.filter(
|
||||
User.id == user_id,
|
||||
OrgMember.org_id == User.current_org_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
async def get_org_member_for_current_org_async(
|
||||
user_id: UUID,
|
||||
) -> Optional[OrgMember]:
|
||||
"""Get the org member for a user's current organization (async version).
|
||||
|
||||
Args:
|
||||
user_id: The user's UUID.
|
||||
|
||||
Returns:
|
||||
The OrgMember for the user's current organization, or None if not found.
|
||||
"""
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(OrgMember)
|
||||
.join(User, User.id == OrgMember.user_id)
|
||||
.filter(
|
||||
User.id == user_id,
|
||||
OrgMember.org_id == User.current_org_id,
|
||||
)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
@staticmethod
|
||||
def get_user_orgs(user_id: UUID) -> list[OrgMember]:
|
||||
"""Get all organizations for a user."""
|
||||
with session_maker() as session:
|
||||
return session.query(OrgMember).filter(OrgMember.user_id == user_id).all()
|
||||
@@ -68,7 +127,7 @@ class OrgMemberStore:
|
||||
|
||||
@staticmethod
|
||||
def update_user_role_in_org(
|
||||
org_id: UUID, user_id: int, role_id: int, status: Optional[str] = None
|
||||
org_id: UUID, user_id: UUID, role_id: int, status: Optional[str] = None
|
||||
) -> Optional[OrgMember]:
|
||||
"""Update user's role in an organization."""
|
||||
with session_maker() as session:
|
||||
@@ -90,7 +149,7 @@ class OrgMemberStore:
|
||||
return org_member
|
||||
|
||||
@staticmethod
|
||||
def remove_user_from_org(org_id: UUID, user_id: int) -> bool:
|
||||
def remove_user_from_org(org_id: UUID, user_id: UUID) -> bool:
|
||||
"""Remove a user from an organization."""
|
||||
with session_maker() as session:
|
||||
org_member = (
|
||||
@@ -123,3 +182,36 @@ class OrgMemberStore:
|
||||
if (normalized := c.name.lstrip('_')) and hasattr(user_settings, normalized)
|
||||
}
|
||||
return kwargs
|
||||
|
||||
@staticmethod
|
||||
async def get_org_members_paginated(
|
||||
org_id: UUID,
|
||||
offset: int = 0,
|
||||
limit: int = 100,
|
||||
) -> tuple[list[OrgMember], bool]:
|
||||
"""Get paginated list of organization members with user and role info.
|
||||
|
||||
Returns:
|
||||
Tuple of (members_list, has_more) where has_more indicates if there are more results.
|
||||
"""
|
||||
async with a_session_maker() as session:
|
||||
# Query for limit + 1 items to determine if there are more results
|
||||
# Order by user_id for consistent pagination
|
||||
query = (
|
||||
select(OrgMember)
|
||||
.options(joinedload(OrgMember.user), joinedload(OrgMember.role))
|
||||
.filter(OrgMember.org_id == org_id)
|
||||
.order_by(OrgMember.user_id)
|
||||
.offset(offset)
|
||||
.limit(limit + 1)
|
||||
)
|
||||
result = await session.execute(query)
|
||||
members = list(result.scalars().all())
|
||||
|
||||
# Check if there are more results
|
||||
has_more = len(members) > limit
|
||||
if has_more:
|
||||
# Remove the extra item
|
||||
members = members[:limit]
|
||||
|
||||
return members, has_more
|
||||
|
||||
@@ -521,6 +521,7 @@ class OrgService:
|
||||
Raises:
|
||||
ValueError: If organization not found
|
||||
PermissionError: If user is not a member, or lacks admin/owner role for LLM settings
|
||||
OrgNameExistsError: If new name already exists for another organization
|
||||
OrgDatabaseError: If database update fails
|
||||
"""
|
||||
logger.info(
|
||||
@@ -550,6 +551,24 @@ class OrgService:
|
||||
'User must be a member of the organization to update it'
|
||||
)
|
||||
|
||||
# Check if name is being updated and validate uniqueness
|
||||
if update_data.name is not None:
|
||||
# Check if new name conflicts with another org
|
||||
existing_org_with_name = OrgStore.get_org_by_name(update_data.name)
|
||||
if (
|
||||
existing_org_with_name is not None
|
||||
and existing_org_with_name.id != org_id
|
||||
):
|
||||
logger.warning(
|
||||
'Attempted to update organization with duplicate name',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
'attempted_name': update_data.name,
|
||||
},
|
||||
)
|
||||
raise OrgNameExistsError(update_data.name)
|
||||
|
||||
# Check if update contains any LLM settings
|
||||
llm_fields_being_updated = OrgService._has_llm_settings_updates(update_data)
|
||||
if llm_fields_being_updated:
|
||||
@@ -637,10 +656,9 @@ class OrgService:
|
||||
)
|
||||
return None
|
||||
|
||||
max_budget = (user_team_info.get('litellm_budget_table') or {}).get(
|
||||
'max_budget', 0
|
||||
max_budget, spend = LiteLlmManager.get_budget_from_team_info(
|
||||
user_team_info, user_id, str(org_id)
|
||||
)
|
||||
spend = user_team_info.get('spend', 0)
|
||||
credits = max(max_budget - spend, 0)
|
||||
|
||||
logger.debug(
|
||||
@@ -842,3 +860,94 @@ class OrgService:
|
||||
extra={'user_id': user_id, 'org_id': str(org_id), 'error': str(e)},
|
||||
)
|
||||
raise OrgDatabaseError(f'Failed to delete organization: {str(e)}')
|
||||
|
||||
@staticmethod
|
||||
async def check_byor_export_enabled(user_id: str) -> bool:
|
||||
"""Check if BYOR export is enabled for the user's current org.
|
||||
|
||||
Returns True if the user's current org has byor_export_enabled set to True.
|
||||
Returns False if the user is not found, has no current org, or the flag is False.
|
||||
|
||||
Args:
|
||||
user_id: User ID to check
|
||||
|
||||
Returns:
|
||||
bool: True if BYOR export is enabled, False otherwise
|
||||
"""
|
||||
user = await UserStore.get_user_by_id_async(user_id)
|
||||
if not user or not user.current_org_id:
|
||||
return False
|
||||
|
||||
org = OrgStore.get_org_by_id(user.current_org_id)
|
||||
if not org:
|
||||
return False
|
||||
|
||||
return org.byor_export_enabled
|
||||
|
||||
@staticmethod
|
||||
async def switch_org(user_id: str, org_id: UUID) -> Org:
|
||||
"""
|
||||
Switch user's current organization to the specified organization.
|
||||
|
||||
This method:
|
||||
1. Validates that the organization exists
|
||||
2. Validates that the user is a member of the organization
|
||||
3. Updates the user's current_org_id
|
||||
|
||||
Args:
|
||||
user_id: User ID (string that will be converted to UUID)
|
||||
org_id: Organization ID to switch to
|
||||
|
||||
Returns:
|
||||
Org: The organization that was switched to
|
||||
|
||||
Raises:
|
||||
OrgNotFoundError: If organization doesn't exist
|
||||
OrgAuthorizationError: If user is not a member of the organization
|
||||
OrgDatabaseError: If database update fails
|
||||
"""
|
||||
logger.info(
|
||||
'Switching user organization',
|
||||
extra={'user_id': user_id, 'org_id': str(org_id)},
|
||||
)
|
||||
|
||||
# Step 1: Check if organization exists
|
||||
org = OrgStore.get_org_by_id(org_id)
|
||||
if not org:
|
||||
raise OrgNotFoundError(str(org_id))
|
||||
|
||||
# Step 2: Validate user is a member of the organization
|
||||
if not OrgService.is_org_member(user_id, org_id):
|
||||
logger.warning(
|
||||
'User attempted to switch to organization they are not a member of',
|
||||
extra={'user_id': user_id, 'org_id': str(org_id)},
|
||||
)
|
||||
raise OrgAuthorizationError(
|
||||
'User must be a member of the organization to switch to it'
|
||||
)
|
||||
|
||||
# Step 3: Update user's current_org_id
|
||||
try:
|
||||
updated_user = UserStore.update_current_org(user_id, org_id)
|
||||
if not updated_user:
|
||||
raise OrgDatabaseError('User not found')
|
||||
|
||||
logger.info(
|
||||
'Successfully switched user organization',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
'org_name': org.name,
|
||||
},
|
||||
)
|
||||
|
||||
return org
|
||||
|
||||
except OrgDatabaseError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'Failed to switch user organization',
|
||||
extra={'user_id': user_id, 'org_id': str(org_id), 'error': str(e)},
|
||||
)
|
||||
raise OrgDatabaseError(f'Failed to switch organization: {str(e)}')
|
||||
|
||||
@@ -10,6 +10,7 @@ from server.constants import (
|
||||
ORG_SETTINGS_VERSION,
|
||||
get_default_litellm_model,
|
||||
)
|
||||
from server.routes.org_models import OrphanedUserError
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import joinedload
|
||||
from storage.database import session_maker
|
||||
@@ -320,17 +321,41 @@ class OrgStore:
|
||||
{'org_id': str(org_id)},
|
||||
)
|
||||
|
||||
# 3. Delete organization memberships
|
||||
# 3. Handle users with this as current_org_id BEFORE deleting memberships
|
||||
# Single query to find orphaned users (those with no alternative org)
|
||||
orphaned_users = session.execute(
|
||||
text("""
|
||||
SELECT u.id
|
||||
FROM "user" u
|
||||
WHERE u.current_org_id = :org_id
|
||||
AND NOT EXISTS (
|
||||
SELECT 1 FROM org_member om
|
||||
WHERE om.user_id = u.id AND om.org_id != :org_id
|
||||
)
|
||||
"""),
|
||||
{'org_id': str(org_id)},
|
||||
).fetchall()
|
||||
|
||||
if orphaned_users:
|
||||
raise OrphanedUserError([str(row[0]) for row in orphaned_users])
|
||||
|
||||
# Batch update: reassign current_org_id to an alternative org for all affected users
|
||||
session.execute(
|
||||
text('DELETE FROM org_member WHERE org_id = :org_id'),
|
||||
text("""
|
||||
UPDATE "user" u
|
||||
SET current_org_id = (
|
||||
SELECT om.org_id FROM org_member om
|
||||
WHERE om.user_id = u.id AND om.org_id != :org_id
|
||||
LIMIT 1
|
||||
)
|
||||
WHERE u.current_org_id = :org_id
|
||||
"""),
|
||||
{'org_id': str(org_id)},
|
||||
)
|
||||
|
||||
# 4. Handle users with this as current_org_id
|
||||
# 4. Delete organization memberships (now safe)
|
||||
session.execute(
|
||||
text(
|
||||
'UPDATE "user" SET current_org_id = NULL WHERE current_org_id = :org_id'
|
||||
),
|
||||
text('DELETE FROM org_member WHERE org_id = :org_id'),
|
||||
{'org_id': str(org_id)},
|
||||
)
|
||||
|
||||
|
||||
35
enterprise/storage/resend_synced_user.py
Normal file
35
enterprise/storage/resend_synced_user.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""SQLAlchemy model for tracking users synced to Resend audiences."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import Column, DateTime, String, UniqueConstraint
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
class ResendSyncedUser(Base): # type: ignore
|
||||
"""Tracks users that have been synced to a Resend audience.
|
||||
|
||||
This table ensures that once a user is synced to a Resend audience,
|
||||
they won't be re-added even if they are later deleted from the
|
||||
Resend UI. This respects manual deletions/unsubscribes.
|
||||
"""
|
||||
|
||||
__tablename__ = 'resend_synced_users'
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid4)
|
||||
email = Column(String, nullable=False, index=True)
|
||||
audience_id = Column(String, nullable=False, index=True)
|
||||
synced_at = Column(
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(UTC),
|
||||
nullable=False,
|
||||
)
|
||||
keycloak_user_id = Column(String, nullable=True)
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
'email', 'audience_id', name='uq_resend_synced_email_audience'
|
||||
),
|
||||
)
|
||||
125
enterprise/storage/resend_synced_user_store.py
Normal file
125
enterprise/storage/resend_synced_user_store.py
Normal file
@@ -0,0 +1,125 @@
|
||||
"""Store class for managing Resend synced users."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from typing import Optional, Set
|
||||
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from storage.resend_synced_user import ResendSyncedUser
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResendSyncedUserStore:
|
||||
"""Store for tracking users synced to Resend audiences."""
|
||||
|
||||
session_maker: sessionmaker
|
||||
|
||||
def is_user_synced(self, email: str, audience_id: str) -> bool:
|
||||
"""Check if a user has been synced to a specific audience.
|
||||
|
||||
Args:
|
||||
email: The email address to check.
|
||||
audience_id: The Resend audience ID.
|
||||
|
||||
Returns:
|
||||
True if the user has been synced, False otherwise.
|
||||
"""
|
||||
with self.session_maker() as session:
|
||||
stmt = select(ResendSyncedUser).where(
|
||||
ResendSyncedUser.email == email.lower(),
|
||||
ResendSyncedUser.audience_id == audience_id,
|
||||
)
|
||||
result = session.execute(stmt).first()
|
||||
return result is not None
|
||||
|
||||
def get_synced_emails_for_audience(self, audience_id: str) -> Set[str]:
|
||||
"""Get all synced email addresses for a specific audience.
|
||||
|
||||
Args:
|
||||
audience_id: The Resend audience ID.
|
||||
|
||||
Returns:
|
||||
A set of lowercase email addresses that have been synced.
|
||||
"""
|
||||
with self.session_maker() as session:
|
||||
stmt = select(ResendSyncedUser.email).where(
|
||||
ResendSyncedUser.audience_id == audience_id,
|
||||
)
|
||||
result = session.execute(stmt).scalars().all()
|
||||
return set(result)
|
||||
|
||||
def mark_user_synced(
|
||||
self,
|
||||
email: str,
|
||||
audience_id: str,
|
||||
keycloak_user_id: Optional[str] = None,
|
||||
) -> ResendSyncedUser:
|
||||
"""Mark a user as synced to a specific audience.
|
||||
|
||||
Uses upsert to handle race conditions - if the user is already
|
||||
marked as synced, this is a no-op.
|
||||
|
||||
Args:
|
||||
email: The email address of the user.
|
||||
audience_id: The Resend audience ID.
|
||||
keycloak_user_id: Optional Keycloak user ID.
|
||||
|
||||
Returns:
|
||||
The ResendSyncedUser record.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the record could not be created or retrieved.
|
||||
"""
|
||||
with self.session_maker() as session:
|
||||
stmt = (
|
||||
insert(ResendSyncedUser)
|
||||
.values(
|
||||
email=email.lower(),
|
||||
audience_id=audience_id,
|
||||
keycloak_user_id=keycloak_user_id,
|
||||
synced_at=datetime.now(UTC),
|
||||
)
|
||||
.on_conflict_do_nothing(constraint='uq_resend_synced_email_audience')
|
||||
.returning(ResendSyncedUser)
|
||||
)
|
||||
result = session.execute(stmt)
|
||||
session.commit()
|
||||
|
||||
row = result.first()
|
||||
if row:
|
||||
return row[0]
|
||||
|
||||
# on_conflict_do_nothing triggered, fetch the existing record
|
||||
existing = session.execute(
|
||||
select(ResendSyncedUser).where(
|
||||
ResendSyncedUser.email == email.lower(),
|
||||
ResendSyncedUser.audience_id == audience_id,
|
||||
)
|
||||
).first()
|
||||
if existing:
|
||||
return existing[0]
|
||||
|
||||
raise RuntimeError(
|
||||
f'Failed to create or retrieve synced user record for {email}'
|
||||
)
|
||||
|
||||
def remove_synced_user(self, email: str, audience_id: str) -> bool:
|
||||
"""Remove a user's synced status for a specific audience.
|
||||
|
||||
Args:
|
||||
email: The email address of the user.
|
||||
audience_id: The Resend audience ID.
|
||||
|
||||
Returns:
|
||||
True if a record was deleted, False if no record existed.
|
||||
"""
|
||||
with self.session_maker() as session:
|
||||
stmt = delete(ResendSyncedUser).where(
|
||||
ResendSyncedUser.email == email.lower(),
|
||||
ResendSyncedUser.audience_id == audience_id,
|
||||
)
|
||||
result = session.execute(stmt)
|
||||
session.commit()
|
||||
return result.rowcount > 0
|
||||
@@ -29,6 +29,20 @@ class RoleStore:
|
||||
with session_maker() as session:
|
||||
return session.query(Role).filter(Role.id == role_id).first()
|
||||
|
||||
@staticmethod
|
||||
async def get_role_by_id_async(
|
||||
role_id: int,
|
||||
session: Optional[AsyncSession] = None,
|
||||
) -> Optional[Role]:
|
||||
"""Get role by ID (async version)."""
|
||||
if session is not None:
|
||||
result = await session.execute(select(Role).where(Role.id == role_id))
|
||||
return result.scalars().first()
|
||||
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(select(Role).where(Role.id == role_id))
|
||||
return result.scalars().first()
|
||||
|
||||
@staticmethod
|
||||
def get_role_by_name(name: str) -> Optional[Role]:
|
||||
"""Get role by name."""
|
||||
|
||||
@@ -34,11 +34,10 @@ class SaasConversationStore(ConversationStore):
|
||||
session_maker: sessionmaker
|
||||
org_id: UUID | None = None # will be fetched automatically
|
||||
|
||||
def __init__(self, user_id: str, session_maker: sessionmaker):
|
||||
def __init__(self, user_id: str, org_id: UUID, session_maker: sessionmaker):
|
||||
self.user_id = user_id
|
||||
self.org_id = org_id
|
||||
self.session_maker = session_maker
|
||||
user = UserStore.get_user_by_id(user_id)
|
||||
self.org_id = user.current_org_id if user else None
|
||||
|
||||
def _select_by_id(self, session, conversation_id: str):
|
||||
# Join StoredConversationMetadata with ConversationMetadataSaas to filter by user/org
|
||||
@@ -235,4 +234,6 @@ class SaasConversationStore(ConversationStore):
|
||||
cls, config: OpenHandsConfig, user_id: str | None
|
||||
) -> ConversationStore:
|
||||
# user_id should not be None in SaaS, should we raise?
|
||||
return SaasConversationStore(str(user_id), session_maker)
|
||||
user = await UserStore.get_user_by_id_async(user_id)
|
||||
org_id = user.current_org_id if user else None
|
||||
return SaasConversationStore(str(user_id), org_id, session_maker)
|
||||
|
||||
@@ -8,10 +8,11 @@ from dataclasses import dataclass
|
||||
|
||||
from cryptography.fernet import Fernet
|
||||
from pydantic import SecretStr
|
||||
from server.constants import LITE_LLM_API_URL
|
||||
from server.logger import logger
|
||||
from sqlalchemy.orm import joinedload, sessionmaker
|
||||
from storage.database import session_maker
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
from storage.lite_llm_manager import LiteLlmManager, get_openhands_cloud_key_alias
|
||||
from storage.org import Org
|
||||
from storage.org_member import OrgMember
|
||||
from storage.org_store import OrgStore
|
||||
@@ -23,6 +24,7 @@ from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
from openhands.server.settings import Settings
|
||||
from openhands.storage.settings.settings_store import SettingsStore
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
from openhands.utils.llm import is_openhands_model
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -123,7 +125,6 @@ class SaasSettingsStore(SettingsStore):
|
||||
with self.session_maker() as session:
|
||||
if not item:
|
||||
return None
|
||||
kwargs = item.model_dump(context={'expose_secrets': True})
|
||||
user = (
|
||||
session.query(User)
|
||||
.options(joinedload(User.org_members))
|
||||
@@ -144,23 +145,29 @@ class SaasSettingsStore(SettingsStore):
|
||||
return None
|
||||
|
||||
org_id = user.current_org_id
|
||||
# Check if provider is OpenHands and generate API key if needed
|
||||
if self._is_openhands_provider(item):
|
||||
await self._ensure_openhands_api_key(item, str(org_id))
|
||||
org_member = None
|
||||
|
||||
org_member: OrgMember = None
|
||||
for om in user.org_members:
|
||||
if om.org_id == org_id:
|
||||
org_member = om
|
||||
break
|
||||
if not org_member or not org_member.llm_api_key:
|
||||
return None
|
||||
org = session.query(Org).filter(Org.id == org_id).first()
|
||||
|
||||
org: Org = session.query(Org).filter(Org.id == org_id).first()
|
||||
if not org:
|
||||
logger.error(
|
||||
f'Org not found for ID {org_id} as the current org for user {self.user_id}'
|
||||
)
|
||||
return None
|
||||
|
||||
# Check if we need to generate an LLM key.
|
||||
if item.llm_base_url == LITE_LLM_API_URL:
|
||||
await self._ensure_api_key(
|
||||
item, str(org_id), openhands_type=is_openhands_model(item.llm_model)
|
||||
)
|
||||
|
||||
kwargs = item.model_dump(context={'expose_secrets': True})
|
||||
for model in (user, org, org_member):
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(model, key):
|
||||
@@ -223,32 +230,49 @@ class SaasSettingsStore(SettingsStore):
|
||||
fernet_key = b64encode(hashlib.sha256(jwt_secret.encode()).digest())
|
||||
return Fernet(fernet_key)
|
||||
|
||||
def _is_openhands_provider(self, item: Settings) -> bool:
|
||||
"""Check if the settings use the OpenHands provider."""
|
||||
return bool(item.llm_model and item.llm_model.startswith('openhands/'))
|
||||
|
||||
async def _ensure_openhands_api_key(self, item: Settings, org_id: str) -> None:
|
||||
async def _ensure_api_key(
|
||||
self, item: Settings, org_id: str, openhands_type: bool = False
|
||||
) -> None:
|
||||
"""Generate and set the OpenHands API key for the given settings.
|
||||
|
||||
First checks if an existing key with the OpenHands alias exists,
|
||||
and reuses it if found. Otherwise, generates a new key.
|
||||
First checks if an existing key exists for the user and verifies it
|
||||
is valid in LiteLLM. If valid, reuses it. Otherwise, generates a new key.
|
||||
"""
|
||||
# Generate new key if none exists
|
||||
generated_key = await LiteLlmManager.generate_key(
|
||||
|
||||
# First, check if our current key is valid
|
||||
if item.llm_api_key and not await LiteLlmManager.verify_existing_key(
|
||||
item.llm_api_key.get_secret_value(),
|
||||
self.user_id,
|
||||
org_id,
|
||||
None,
|
||||
{'type': 'openhands'},
|
||||
)
|
||||
openhands_type=openhands_type,
|
||||
):
|
||||
generated_key = None
|
||||
if openhands_type:
|
||||
generated_key = await LiteLlmManager.generate_key(
|
||||
self.user_id,
|
||||
org_id,
|
||||
None,
|
||||
{'type': 'openhands'},
|
||||
)
|
||||
else:
|
||||
# Must delete any existing key with the same alias first
|
||||
key_alias = get_openhands_cloud_key_alias(self.user_id, org_id)
|
||||
await LiteLlmManager.delete_key_by_alias(key_alias=key_alias)
|
||||
generated_key = await LiteLlmManager.generate_key(
|
||||
self.user_id,
|
||||
org_id,
|
||||
key_alias,
|
||||
None,
|
||||
)
|
||||
|
||||
if generated_key:
|
||||
item.llm_api_key = SecretStr(generated_key)
|
||||
logger.info(
|
||||
'saas_settings_store:store:generated_openhands_key',
|
||||
extra={'user_id': self.user_id},
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
'saas_settings_store:store:failed_to_generate_openhands_key',
|
||||
extra={'user_id': self.user_id},
|
||||
)
|
||||
if generated_key:
|
||||
item.llm_api_key = SecretStr(generated_key)
|
||||
logger.info(
|
||||
'saas_settings_store:store:generated_openhands_key',
|
||||
extra={'user_id': self.user_id},
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
'saas_settings_store:store:failed_to_generate_openhands_key',
|
||||
extra={'user_id': self.user_id},
|
||||
)
|
||||
|
||||
@@ -5,6 +5,7 @@ Store class for managing users.
|
||||
import asyncio
|
||||
import uuid
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.constants import (
|
||||
@@ -19,6 +20,7 @@ from sqlalchemy.orm import joinedload
|
||||
from storage.database import a_session_maker, session_maker
|
||||
from storage.encrypt_utils import (
|
||||
decrypt_legacy_model,
|
||||
decrypt_legacy_value,
|
||||
encrypt_legacy_value,
|
||||
)
|
||||
from storage.org import Org
|
||||
@@ -26,6 +28,7 @@ from storage.org_member import OrgMember
|
||||
from storage.role_store import RoleStore
|
||||
from storage.user import User
|
||||
from storage.user_settings import UserSettings
|
||||
from utils.identity import resolve_display_name
|
||||
|
||||
from openhands.utils.async_utils import GENERAL_TIMEOUT, call_async_from_sync
|
||||
|
||||
@@ -52,7 +55,8 @@ class UserStore:
|
||||
org = Org(
|
||||
id=uuid.UUID(user_id),
|
||||
name=f'user_{user_id}_org',
|
||||
contact_name=user_info['preferred_username'],
|
||||
contact_name=resolve_display_name(user_info)
|
||||
or user_info.get('preferred_username', ''),
|
||||
contact_email=user_info['email'],
|
||||
v1_enabled=True,
|
||||
)
|
||||
@@ -79,6 +83,8 @@ class UserStore:
|
||||
role_id=role_id,
|
||||
**user_kwargs,
|
||||
)
|
||||
user.email = user_info.get('email')
|
||||
user.email_verified = user_info.get('email_verified')
|
||||
session.add(user)
|
||||
|
||||
role = RoleStore.get_role_by_name('owner')
|
||||
@@ -130,6 +136,25 @@ class UserStore:
|
||||
)
|
||||
return bool(lock_acquired)
|
||||
|
||||
@staticmethod
|
||||
async def _release_user_creation_lock(user_id: str) -> bool:
|
||||
"""Release the distributed lock for user creation.
|
||||
|
||||
Returns True if the lock was released or if Redis is unavailable.
|
||||
Returns False if the lock could not be released.
|
||||
"""
|
||||
redis_client = UserStore._get_redis_client()
|
||||
if redis_client is None:
|
||||
logger.warning(
|
||||
'user_store:_release_user_creation_lock:no_redis_client',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
return True # Nothing to release if Redis is unavailable
|
||||
|
||||
user_key = f'{_REDIS_USER_CREATION_KEY_PREFIX}{user_id}'
|
||||
deleted = await redis_client.delete(user_key)
|
||||
return bool(deleted)
|
||||
|
||||
@staticmethod
|
||||
async def migrate_user(
|
||||
user_id: str,
|
||||
@@ -150,13 +175,28 @@ class UserStore:
|
||||
)
|
||||
decrypted_user_settings = UserSettings(**kwargs)
|
||||
with session_maker() as session:
|
||||
# Check if user has completed billing sessions to enable BYOR export
|
||||
from storage.billing_session import BillingSession
|
||||
|
||||
has_completed_billing = (
|
||||
session.query(BillingSession)
|
||||
.filter(
|
||||
BillingSession.user_id == user_id,
|
||||
BillingSession.status == 'completed',
|
||||
)
|
||||
.first()
|
||||
is not None
|
||||
)
|
||||
|
||||
# create personal org
|
||||
org = Org(
|
||||
id=uuid.UUID(user_id),
|
||||
name=f'user_{user_id}_org',
|
||||
org_version=user_settings.user_version,
|
||||
contact_name=user_info['username'],
|
||||
contact_name=resolve_display_name(user_info)
|
||||
or user_info.get('username', ''),
|
||||
contact_email=user_info['email'],
|
||||
byor_export_enabled=has_completed_billing,
|
||||
)
|
||||
session.add(org)
|
||||
|
||||
@@ -334,7 +374,9 @@ class UserStore:
|
||||
|
||||
@staticmethod
|
||||
async def downgrade_user(user_id: str) -> UserSettings | None:
|
||||
"""Downgrade a migrated user back to the pre-migration state.
|
||||
"""
|
||||
This method can be removed once orgs is established - probably after Feb 15 2026
|
||||
Downgrade a migrated user back to the pre-migration state.
|
||||
|
||||
This reverses the migrate_user operation:
|
||||
1. Get the user's settings from user_settings table (migrated users) or
|
||||
@@ -388,6 +430,24 @@ class UserStore:
|
||||
)
|
||||
return None
|
||||
|
||||
# Get org_members for this org - should only be one for personal orgs
|
||||
org_members = (
|
||||
session.query(OrgMember).filter(OrgMember.org_id == org.id).all()
|
||||
)
|
||||
|
||||
if len(org_members) != 1:
|
||||
logger.error(
|
||||
'user_store:downgrade_user:unexpected_org_members_count',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org.id),
|
||||
'org_members_count': len(org_members),
|
||||
},
|
||||
)
|
||||
return None
|
||||
|
||||
org_member = org_members[0]
|
||||
|
||||
# Get the user_settings (for migrated users)
|
||||
user_settings = (
|
||||
session.query(UserSettings)
|
||||
@@ -400,42 +460,34 @@ class UserStore:
|
||||
|
||||
# For new sign-ups after migration, user_settings won't exist
|
||||
# Fall back to getting data from org_members
|
||||
is_new_signup = False
|
||||
if not user_settings:
|
||||
if user_settings:
|
||||
if org_member.llm_api_key and org_member.llm_api_key.get_secret_value():
|
||||
user_settings.llm_api_key = encrypt_legacy_value(
|
||||
org_member.llm_api_key.get_secret_value()
|
||||
)
|
||||
if (
|
||||
org_member.llm_api_key_for_byor
|
||||
and org_member.llm_api_key_for_byor.get_secret_value()
|
||||
):
|
||||
user_settings.llm_api_key_for_byor = encrypt_legacy_value(
|
||||
org_member.llm_api_key_for_byor.get_secret_value()
|
||||
)
|
||||
logger.info(
|
||||
'user_store:downgrade_user:user_settings_not_found_checking_org_members',
|
||||
'user_store:downgrade_user:updated_user_settings_from_org_member',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
# Get org_members for this org - should only be one for personal orgs
|
||||
org_members = (
|
||||
session.query(OrgMember).filter(OrgMember.org_id == org.id).all()
|
||||
)
|
||||
|
||||
if len(org_members) != 1:
|
||||
logger.error(
|
||||
'user_store:downgrade_user:unexpected_org_members_count',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org.id),
|
||||
'org_members_count': len(org_members),
|
||||
},
|
||||
)
|
||||
return None
|
||||
|
||||
org_member = org_members[0]
|
||||
is_new_signup = True
|
||||
|
||||
else:
|
||||
# Create a new user_settings entry from OrgMember, User, and Org data
|
||||
# This is needed for new sign-ups who don't have user_settings
|
||||
user_settings = UserStore._create_user_settings_from_entities(
|
||||
user_id, org_member, user, org
|
||||
)
|
||||
session.add(user_settings)
|
||||
session.flush()
|
||||
logger.info(
|
||||
'user_store:downgrade_user:created_user_settings_from_org_member',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
session.flush()
|
||||
|
||||
# Call LiteLLM downgrade
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
@@ -445,27 +497,25 @@ class UserStore:
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
|
||||
# Get the API keys for LiteLLM downgrade
|
||||
if is_new_signup:
|
||||
# For new signups, we already have decrypted values in user_settings
|
||||
decrypted_user_settings = user_settings
|
||||
else:
|
||||
# For migrated users, decrypt the legacy model
|
||||
kwargs = decrypt_legacy_model(
|
||||
[
|
||||
'llm_api_key',
|
||||
'llm_api_key_for_byor',
|
||||
'search_api_key',
|
||||
'sandbox_api_key',
|
||||
],
|
||||
user_settings,
|
||||
)
|
||||
decrypted_user_settings = UserSettings(**kwargs)
|
||||
encrypted_fields = [
|
||||
'llm_api_key',
|
||||
'llm_api_key_for_byor',
|
||||
'search_api_key',
|
||||
'sandbox_api_key',
|
||||
]
|
||||
for field in encrypted_fields:
|
||||
value = getattr(user_settings, field, None)
|
||||
if value:
|
||||
try:
|
||||
value = decrypt_legacy_value(value)
|
||||
setattr(user_settings, field, value)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
await LiteLlmManager.downgrade_entries(
|
||||
str(org.id),
|
||||
user_id,
|
||||
decrypted_user_settings,
|
||||
user_settings,
|
||||
)
|
||||
logger.debug(
|
||||
'user_store:downgrade_user:done_litellm_downgrade_entries',
|
||||
@@ -569,7 +619,7 @@ class UserStore:
|
||||
]
|
||||
for key in encrypt_keys:
|
||||
value = getattr(user_settings, key, None)
|
||||
if value is not None:
|
||||
if value is not None and not _is_legacy_value_encrypted(value):
|
||||
setattr(user_settings, key, encrypt_legacy_value(value))
|
||||
|
||||
session.merge(user_settings)
|
||||
@@ -613,41 +663,46 @@ class UserStore:
|
||||
asyncio.sleep, GENERAL_TIMEOUT, _RETRY_LOAD_DELAY_SECONDS
|
||||
)
|
||||
|
||||
# Check for user again as migration could have happened while trying to get the lock.
|
||||
user = (
|
||||
session.query(User)
|
||||
.options(joinedload(User.org_members))
|
||||
.filter(User.id == uuid.UUID(user_id))
|
||||
.first()
|
||||
)
|
||||
if user:
|
||||
return user
|
||||
try:
|
||||
# Check for user again as migration could have happened while trying to get the lock.
|
||||
user = (
|
||||
session.query(User)
|
||||
.options(joinedload(User.org_members))
|
||||
.filter(User.id == uuid.UUID(user_id))
|
||||
.first()
|
||||
)
|
||||
if user:
|
||||
return user
|
||||
|
||||
user_settings = (
|
||||
session.query(UserSettings)
|
||||
.filter(
|
||||
UserSettings.keycloak_user_id == user_id,
|
||||
UserSettings.already_migrated.is_(False),
|
||||
user_settings = (
|
||||
session.query(UserSettings)
|
||||
.filter(
|
||||
UserSettings.keycloak_user_id == user_id,
|
||||
UserSettings.already_migrated.is_(False),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if user_settings:
|
||||
token_manager = TokenManager()
|
||||
user_info = call_async_from_sync(
|
||||
token_manager.get_user_info_from_user_id,
|
||||
GENERAL_TIMEOUT,
|
||||
user_id,
|
||||
if user_settings:
|
||||
token_manager = TokenManager()
|
||||
user_info = call_async_from_sync(
|
||||
token_manager.get_user_info_from_user_id,
|
||||
GENERAL_TIMEOUT,
|
||||
user_id,
|
||||
)
|
||||
user = call_async_from_sync(
|
||||
UserStore.migrate_user,
|
||||
GENERAL_TIMEOUT,
|
||||
user_id,
|
||||
user_settings,
|
||||
user_info,
|
||||
)
|
||||
return user
|
||||
else:
|
||||
return None
|
||||
finally:
|
||||
call_async_from_sync(
|
||||
UserStore._release_user_creation_lock, GENERAL_TIMEOUT, user_id
|
||||
)
|
||||
user = call_async_from_sync(
|
||||
UserStore.migrate_user,
|
||||
GENERAL_TIMEOUT,
|
||||
user_id,
|
||||
user_settings,
|
||||
user_info,
|
||||
)
|
||||
return user
|
||||
else:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
async def get_user_by_id_async(user_id: str) -> Optional[User]:
|
||||
@@ -675,42 +730,69 @@ class UserStore:
|
||||
)
|
||||
await asyncio.sleep(_RETRY_LOAD_DELAY_SECONDS)
|
||||
|
||||
# Check for user again as migration could have happened while trying to get the lock.
|
||||
try:
|
||||
# Check for user again as migration could have happened while trying to get the lock.
|
||||
result = await session.execute(
|
||||
select(User)
|
||||
.options(joinedload(User.org_members))
|
||||
.filter(User.id == uuid.UUID(user_id))
|
||||
)
|
||||
user = result.scalars().first()
|
||||
if user:
|
||||
return user
|
||||
|
||||
logger.info(
|
||||
'user_store:get_user_by_id_async:start_migration',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
result = await session.execute(
|
||||
select(UserSettings).filter(
|
||||
UserSettings.keycloak_user_id == user_id,
|
||||
UserSettings.already_migrated.is_(False),
|
||||
)
|
||||
)
|
||||
user_settings = result.scalars().first()
|
||||
if user_settings:
|
||||
token_manager = TokenManager()
|
||||
user_info = await token_manager.get_user_info_from_user_id(user_id)
|
||||
logger.info(
|
||||
'user_store:get_user_by_id_async:calling_migrate_user',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
user = await UserStore.migrate_user(
|
||||
user_id,
|
||||
user_settings,
|
||||
user_info,
|
||||
)
|
||||
return user
|
||||
else:
|
||||
return None
|
||||
finally:
|
||||
await UserStore._release_user_creation_lock(user_id)
|
||||
|
||||
@staticmethod
|
||||
async def get_user_by_email_async(email: str) -> Optional[User]:
|
||||
"""Get user by email address (async version).
|
||||
|
||||
This method looks up a user by their email address. Note that email
|
||||
addresses may not be unique across all users in rare cases.
|
||||
|
||||
Args:
|
||||
email: The email address to search for
|
||||
|
||||
Returns:
|
||||
User: The user with the matching email, or None if not found
|
||||
"""
|
||||
if not email:
|
||||
return None
|
||||
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(User)
|
||||
.options(joinedload(User.org_members))
|
||||
.filter(User.id == uuid.UUID(user_id))
|
||||
.filter(User.email == email.lower().strip())
|
||||
)
|
||||
user = result.scalars().first()
|
||||
if user:
|
||||
return user
|
||||
|
||||
logger.info(
|
||||
'user_store:get_user_by_id_async:start_migration',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
result = await session.execute(
|
||||
select(UserSettings).filter(
|
||||
UserSettings.keycloak_user_id == user_id,
|
||||
UserSettings.already_migrated.is_(False),
|
||||
)
|
||||
)
|
||||
user_settings = result.scalars().first()
|
||||
if user_settings:
|
||||
token_manager = TokenManager()
|
||||
user_info = await token_manager.get_user_info_from_user_id(user_id)
|
||||
logger.info(
|
||||
'user_store:get_user_by_id_async:calling_migrate_user',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
user = await UserStore.migrate_user(
|
||||
user_id,
|
||||
user_settings,
|
||||
user_info,
|
||||
)
|
||||
return user
|
||||
else:
|
||||
return None
|
||||
return result.scalars().first()
|
||||
|
||||
@staticmethod
|
||||
def list_users() -> list[User]:
|
||||
@@ -718,6 +800,75 @@ class UserStore:
|
||||
with session_maker() as session:
|
||||
return session.query(User).all()
|
||||
|
||||
@staticmethod
|
||||
def update_current_org(user_id: str, org_id: UUID) -> Optional[User]:
|
||||
"""Update the user's current organization.
|
||||
|
||||
Args:
|
||||
user_id: The user's ID (Keycloak user ID)
|
||||
org_id: The organization ID to set as current
|
||||
|
||||
Returns:
|
||||
User: The updated user object, or None if user not found
|
||||
"""
|
||||
with session_maker() as session:
|
||||
user = (
|
||||
session.query(User)
|
||||
.filter(User.id == uuid.UUID(user_id))
|
||||
.with_for_update()
|
||||
.first()
|
||||
)
|
||||
if not user:
|
||||
return None
|
||||
|
||||
user.current_org_id = org_id
|
||||
session.commit()
|
||||
session.refresh(user)
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
async def backfill_contact_name(user_id: str, user_info: dict) -> None:
|
||||
"""Update contact_name on the personal org if it still has a username-style value.
|
||||
|
||||
Called during login to gradually fix existing users whose contact_name
|
||||
was stored as their username (before the resolve_display_name fix).
|
||||
Preserves custom values that were set via the PATCH endpoint.
|
||||
"""
|
||||
real_name = resolve_display_name(user_info)
|
||||
if not real_name:
|
||||
logger.debug(
|
||||
'backfill_contact_name:no_real_name',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
return
|
||||
|
||||
preferred_username = user_info.get('preferred_username', '')
|
||||
username = user_info.get('username', '')
|
||||
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(Org).filter(Org.id == uuid.UUID(user_id))
|
||||
)
|
||||
org = result.scalars().first()
|
||||
if not org:
|
||||
logger.debug(
|
||||
'backfill_contact_name:org_not_found',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
return
|
||||
|
||||
if org.contact_name in (preferred_username, username):
|
||||
logger.info(
|
||||
'backfill_contact_name:updated',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'old': org.contact_name,
|
||||
'new': real_name,
|
||||
},
|
||||
)
|
||||
org.contact_name = real_name
|
||||
await session.commit()
|
||||
|
||||
# Prevent circular imports
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
@@ -907,3 +1058,12 @@ class UserStore:
|
||||
return False # Matches old default
|
||||
|
||||
return True # Custom model
|
||||
|
||||
|
||||
def _is_legacy_value_encrypted(value: str) -> bool:
|
||||
"""Check if a legacy value is encrypted by trying to decrypt it"""
|
||||
try:
|
||||
decrypt_legacy_value(value)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
@@ -1,11 +1,19 @@
|
||||
import asyncio
|
||||
import asyncio # noqa: I001
|
||||
|
||||
from storage.proactive_conversation_store import ProactiveConversationStore
|
||||
# This must be before the import of storage
|
||||
# to set up logging and prevent alembic from
|
||||
# running its mouth.
|
||||
from openhands.core.logger import openhands_logger
|
||||
|
||||
from storage.proactive_conversation_store import (
|
||||
ProactiveConversationStore,
|
||||
)
|
||||
|
||||
OLDER_THAN = 30 # 30 minutes
|
||||
|
||||
|
||||
async def main():
|
||||
openhands_logger.info('clean_proactive_convo_table')
|
||||
convo_store = ProactiveConversationStore()
|
||||
await convo_store.clean_old_convos(older_than_minutes=OLDER_THAN)
|
||||
|
||||
|
||||
@@ -26,6 +26,7 @@ Optional environment variables:
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional
|
||||
@@ -34,6 +35,7 @@ import resend
|
||||
from keycloak.exceptions import KeycloakError
|
||||
from resend.exceptions import ResendError
|
||||
from server.auth.token_manager import get_keycloak_admin
|
||||
from storage.resend_synced_user_store import ResendSyncedUserStore
|
||||
from tenacity import (
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
@@ -68,9 +70,6 @@ RATE_LIMIT = float(os.environ.get('RATE_LIMIT', '2')) # Requests per second
|
||||
# Set up Resend API
|
||||
resend.api_key = RESEND_API_KEY
|
||||
|
||||
print('resend module', resend)
|
||||
print('has contacts', hasattr(resend, 'Contacts'))
|
||||
|
||||
|
||||
class ResendSyncError(Exception):
|
||||
"""Base exception for Resend sync errors."""
|
||||
@@ -90,6 +89,31 @@ class ResendAPIError(ResendSyncError):
|
||||
pass
|
||||
|
||||
|
||||
# Email validation regex pattern - matches standard email format
|
||||
# This pattern is intentionally strict to avoid Resend API validation errors
|
||||
# It rejects special characters like ! that some email providers technically allow
|
||||
# but Resend's API does not accept
|
||||
EMAIL_REGEX = re.compile(r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$')
|
||||
|
||||
|
||||
def is_valid_email(email: Optional[str]) -> bool:
|
||||
"""Validate an email address format.
|
||||
|
||||
This uses a regex pattern that matches most valid email addresses
|
||||
while rejecting addresses with special characters that Resend's API
|
||||
does not accept (e.g., exclamation marks).
|
||||
|
||||
Args:
|
||||
email: The email address to validate, or None.
|
||||
|
||||
Returns:
|
||||
True if the email is valid, False otherwise (including for None).
|
||||
"""
|
||||
if not email:
|
||||
return False
|
||||
return bool(EMAIL_REGEX.match(email))
|
||||
|
||||
|
||||
def get_keycloak_users(offset: int = 0, limit: int = 100) -> List[Dict[str, Any]]:
|
||||
"""Get users from Keycloak using the admin client.
|
||||
|
||||
@@ -173,8 +197,6 @@ def get_resend_contacts(audience_id: str) -> Dict[str, Dict[str, Any]]:
|
||||
Raises:
|
||||
ResendAPIError: If the API call fails.
|
||||
"""
|
||||
print('getting resend contacts')
|
||||
print('has resend contacts', hasattr(resend, 'Contacts'))
|
||||
try:
|
||||
contacts = resend.Contacts.list(audience_id).get('data', [])
|
||||
# Create a dictionary mapping email addresses to contact data for
|
||||
@@ -229,6 +251,15 @@ def add_contact_to_resend(
|
||||
raise
|
||||
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(MAX_RETRIES),
|
||||
wait=wait_exponential(
|
||||
multiplier=INITIAL_BACKOFF_SECONDS,
|
||||
max=MAX_BACKOFF_SECONDS,
|
||||
exp_base=BACKOFF_FACTOR,
|
||||
),
|
||||
retry=retry_if_exception_type(ResendError),
|
||||
)
|
||||
def send_welcome_email(
|
||||
email: str,
|
||||
first_name: Optional[str] = None,
|
||||
@@ -245,7 +276,7 @@ def send_welcome_email(
|
||||
The API response.
|
||||
|
||||
Raises:
|
||||
ResendError: If the API call fails.
|
||||
ResendError: If the API call fails after retries.
|
||||
"""
|
||||
try:
|
||||
# Prepare the recipient name
|
||||
@@ -291,8 +322,84 @@ def send_welcome_email(
|
||||
raise
|
||||
|
||||
|
||||
def _get_resend_synced_user_store() -> ResendSyncedUserStore:
|
||||
"""Get the ResendSyncedUserStore instance.
|
||||
|
||||
This is separated into a function to allow for easier testing/mocking.
|
||||
"""
|
||||
from openhands.app_server.config import get_global_config
|
||||
|
||||
config = get_global_config()
|
||||
db_session_injector = config.db_session
|
||||
return ResendSyncedUserStore(session_maker=db_session_injector.get_session_maker())
|
||||
|
||||
|
||||
def _backfill_existing_resend_contacts(
|
||||
synced_user_store: ResendSyncedUserStore,
|
||||
audience_id: str,
|
||||
) -> int:
|
||||
"""Backfill the synced_users table with contacts already in Resend.
|
||||
|
||||
This ensures that users who were added to Resend before the tracking
|
||||
table existed are properly recorded, preventing duplicate welcome emails.
|
||||
|
||||
Args:
|
||||
synced_user_store: The store for tracking synced users.
|
||||
audience_id: The Resend audience ID.
|
||||
|
||||
Returns:
|
||||
The number of contacts backfilled.
|
||||
"""
|
||||
logger.info('Starting backfill of existing Resend contacts...')
|
||||
|
||||
try:
|
||||
resend_contacts = get_resend_contacts(audience_id)
|
||||
logger.info(f'Found {len(resend_contacts)} contacts in Resend audience')
|
||||
|
||||
already_synced_emails = synced_user_store.get_synced_emails_for_audience(
|
||||
audience_id
|
||||
)
|
||||
logger.info(
|
||||
f'Found {len(already_synced_emails)} already synced emails in database'
|
||||
)
|
||||
|
||||
backfilled_count = 0
|
||||
for email in resend_contacts:
|
||||
if email.lower() not in already_synced_emails:
|
||||
synced_user_store.mark_user_synced(
|
||||
email=email,
|
||||
audience_id=audience_id,
|
||||
keycloak_user_id=None, # We don't have this info during backfill
|
||||
)
|
||||
backfilled_count += 1
|
||||
logger.debug(f'Backfilled existing Resend contact: {email}')
|
||||
|
||||
logger.info(
|
||||
f'Backfill completed: {backfilled_count} contacts added to tracking'
|
||||
)
|
||||
return backfilled_count
|
||||
|
||||
except Exception:
|
||||
logger.exception('Error during backfill of existing Resend contacts')
|
||||
# Don't fail the entire sync if backfill fails - just log and continue
|
||||
return 0
|
||||
|
||||
|
||||
def sync_users_to_resend():
|
||||
"""Sync users from Keycloak to Resend."""
|
||||
"""Sync users from Keycloak to Resend.
|
||||
|
||||
This function syncs users from Keycloak to a Resend audience. It tracks
|
||||
which users have been synced in the database to ensure that:
|
||||
1. Users are only added once (even across multiple sync runs)
|
||||
2. Users who are manually deleted from Resend are not re-added
|
||||
|
||||
The tracking is done via the resend_synced_users table, which records
|
||||
each email/audience_id combination that has been synced.
|
||||
|
||||
On first run (or when new contacts exist in Resend), it will backfill
|
||||
the tracking table with existing Resend contacts to avoid sending
|
||||
duplicate welcome emails.
|
||||
"""
|
||||
# Check required environment variables
|
||||
required_vars = {
|
||||
'RESEND_API_KEY': RESEND_API_KEY,
|
||||
@@ -318,27 +425,36 @@ def sync_users_to_resend():
|
||||
)
|
||||
|
||||
try:
|
||||
# Get the store for tracking synced users
|
||||
synced_user_store = _get_resend_synced_user_store()
|
||||
|
||||
# Backfill existing Resend contacts into our tracking table
|
||||
# This ensures users already in Resend don't get duplicate welcome emails
|
||||
backfilled_count = _backfill_existing_resend_contacts(
|
||||
synced_user_store, RESEND_AUDIENCE_ID
|
||||
)
|
||||
|
||||
# Get the total number of users
|
||||
total_users = get_total_keycloak_users()
|
||||
logger.info(
|
||||
f'Found {total_users} users in Keycloak realm {KEYCLOAK_REALM_NAME}'
|
||||
)
|
||||
|
||||
# Get contacts from Resend
|
||||
resend_contacts = get_resend_contacts(RESEND_AUDIENCE_ID)
|
||||
logger.info(
|
||||
f'Found {len(resend_contacts)} contacts in Resend audience '
|
||||
f'{RESEND_AUDIENCE_ID}'
|
||||
)
|
||||
|
||||
# Stats
|
||||
stats = {
|
||||
'total_users': total_users,
|
||||
'existing_contacts': len(resend_contacts),
|
||||
'backfilled_contacts': backfilled_count,
|
||||
'already_synced': 0,
|
||||
'added_contacts': 0,
|
||||
'skipped_invalid_emails': 0,
|
||||
'errors': 0,
|
||||
}
|
||||
|
||||
synced_emails = synced_user_store.get_synced_emails_for_audience(
|
||||
RESEND_AUDIENCE_ID
|
||||
)
|
||||
logger.info(f'Found {len(synced_emails)} already synced emails in database')
|
||||
|
||||
# Process users in batches
|
||||
offset = 0
|
||||
while offset < total_users:
|
||||
@@ -351,39 +467,65 @@ def sync_users_to_resend():
|
||||
continue
|
||||
|
||||
email = email.lower()
|
||||
if email in resend_contacts:
|
||||
logger.debug(f'User {email} already exists in Resend, skipping')
|
||||
|
||||
if email in synced_emails:
|
||||
logger.debug(
|
||||
f'User {email} was already synced to this audience, skipping'
|
||||
)
|
||||
stats['already_synced'] += 1
|
||||
continue
|
||||
|
||||
# Validate email format before attempting to add to Resend
|
||||
if not is_valid_email(email):
|
||||
logger.warning(f'Skipping user with invalid email format: {email}')
|
||||
stats['skipped_invalid_emails'] += 1
|
||||
continue
|
||||
|
||||
first_name = user.get('first_name')
|
||||
last_name = user.get('last_name')
|
||||
keycloak_user_id = user.get('id')
|
||||
|
||||
# Mark as synced first (optimistic) to ensure consistency.
|
||||
# If Resend API fails, we remove the record.
|
||||
try:
|
||||
synced_user_store.mark_user_synced(
|
||||
email=email,
|
||||
audience_id=RESEND_AUDIENCE_ID,
|
||||
keycloak_user_id=keycloak_user_id,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(f'Failed to mark user {email} as synced')
|
||||
stats['errors'] += 1
|
||||
continue
|
||||
|
||||
try:
|
||||
first_name = user.get('first_name')
|
||||
last_name = user.get('last_name')
|
||||
|
||||
# Add the contact to the Resend audience
|
||||
add_contact_to_resend(
|
||||
RESEND_AUDIENCE_ID, email, first_name, last_name
|
||||
)
|
||||
logger.info(f'Added user {email} to Resend')
|
||||
stats['added_contacts'] += 1
|
||||
|
||||
# Sleep to respect rate limit after first API call
|
||||
time.sleep(1 / RATE_LIMIT)
|
||||
|
||||
# Send a welcome email to the newly added contact
|
||||
try:
|
||||
send_welcome_email(email, first_name, last_name)
|
||||
logger.info(f'Sent welcome email to {email}')
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f'Failed to send welcome email to {email}, but contact was added to audience'
|
||||
)
|
||||
# Continue with the sync process even if sending the welcome email fails
|
||||
|
||||
# Sleep to respect rate limit after second API call
|
||||
time.sleep(1 / RATE_LIMIT)
|
||||
except Exception:
|
||||
logger.exception(f'Error adding user {email} to Resend')
|
||||
synced_user_store.remove_synced_user(email, RESEND_AUDIENCE_ID)
|
||||
stats['errors'] += 1
|
||||
continue
|
||||
|
||||
synced_emails.add(email)
|
||||
stats['added_contacts'] += 1
|
||||
|
||||
# Sleep to respect rate limit after first API call
|
||||
time.sleep(1 / RATE_LIMIT)
|
||||
|
||||
# Send a welcome email to the newly added contact
|
||||
try:
|
||||
send_welcome_email(email, first_name, last_name)
|
||||
logger.info(f'Sent welcome email to {email}')
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f'Failed to send welcome email to {email}, but contact was added to audience'
|
||||
)
|
||||
|
||||
# Sleep to respect rate limit after second API call
|
||||
time.sleep(1 / RATE_LIMIT)
|
||||
|
||||
offset += BATCH_SIZE
|
||||
|
||||
|
||||
@@ -126,3 +126,24 @@ def test_run_agent_variant_tests_v1_calls_handler_and_sets_system_prompt(monkeyp
|
||||
# Should be a different instance than the original (copied after handler runs)
|
||||
assert result is not agent
|
||||
assert result.system_prompt_filename == 'system_prompt_long_horizon.j2'
|
||||
|
||||
|
||||
@patch('experiments.experiment_manager.ENABLE_EXPERIMENT_MANAGER', True)
|
||||
@patch('experiments.experiment_manager.EXPERIMENT_SYSTEM_PROMPT_EXPERIMENT', True)
|
||||
def test_run_agent_variant_tests_v1_preserves_planning_agent_system_prompt():
|
||||
"""Planning agents should retain their specialized system prompt and not be overwritten by the experiment."""
|
||||
# Arrange
|
||||
planning_agent = make_agent().model_copy(
|
||||
update={'system_prompt_filename': 'system_prompt_planning.j2'}
|
||||
)
|
||||
conv_id = uuid4()
|
||||
|
||||
# Act
|
||||
result: Agent = SaaSExperimentManager.run_agent_variant_tests__v1(
|
||||
user_id='user-planning',
|
||||
conversation_id=conv_id,
|
||||
agent=planning_agent,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result.system_prompt_filename == 'system_prompt_planning.j2'
|
||||
|
||||
@@ -141,12 +141,14 @@ def test_custom_to_static_conversion():
|
||||
|
||||
def create_provider_tokens(
|
||||
tokens_dict: dict[ProviderType, str],
|
||||
) -> dict[ProviderType, ProviderToken]:
|
||||
"""Helper to create provider tokens dictionary."""
|
||||
return {
|
||||
provider_type: ProviderToken(token=SecretStr(token_value))
|
||||
for provider_type, token_value in tokens_dict.items()
|
||||
}
|
||||
) -> MappingProxyType:
|
||||
"""Helper to create provider tokens as MappingProxyType."""
|
||||
return MappingProxyType(
|
||||
{
|
||||
provider_type: ProviderToken(token=SecretStr(token_value))
|
||||
for provider_type, token_value in tokens_dict.items()
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -264,3 +266,63 @@ async def test_get_latest_token_can_be_used_with_static_secret(
|
||||
# Assert - this should NOT raise a ValidationError
|
||||
static_secret = StaticSecret(value=token, description='GITHUB authentication token')
|
||||
assert static_secret.get_value() == token_value
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests for get_authenticated_git_url - ensuring proper authenticated URLs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_authenticated_git_url_raises_when_no_tokens(
|
||||
resolver_context, mock_saas_user_auth
|
||||
):
|
||||
"""Test that get_authenticated_git_url raises error when no provider tokens available."""
|
||||
# Arrange
|
||||
mock_saas_user_auth.get_provider_tokens = AsyncMock(return_value=None)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match='No provider tokens available'):
|
||||
await resolver_context.get_authenticated_git_url('owner/repo')
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_provider_handler_caches_instance(
|
||||
resolver_context, mock_saas_user_auth
|
||||
):
|
||||
"""Test that _get_provider_handler caches the handler instance."""
|
||||
# Arrange
|
||||
token_value = 'ghp_test_token'
|
||||
provider_tokens = create_provider_tokens({ProviderType.GITHUB: token_value})
|
||||
mock_saas_user_auth.get_provider_tokens = AsyncMock(return_value=provider_tokens)
|
||||
mock_saas_user_auth.get_user_id = AsyncMock(return_value='test-user-id')
|
||||
|
||||
# Act - call _get_provider_handler twice
|
||||
handler1 = await resolver_context._get_provider_handler()
|
||||
handler2 = await resolver_context._get_provider_handler()
|
||||
|
||||
# Assert - should be the same instance (cached)
|
||||
assert handler1 is handler2
|
||||
# get_provider_tokens should only be called once
|
||||
assert mock_saas_user_auth.get_provider_tokens.call_count == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_provider_handler_creates_handler_with_correct_params(
|
||||
resolver_context, mock_saas_user_auth
|
||||
):
|
||||
"""Test that _get_provider_handler creates ProviderHandler with correct parameters."""
|
||||
# Arrange
|
||||
token_value = 'ghp_test_token'
|
||||
provider_tokens = create_provider_tokens({ProviderType.GITHUB: token_value})
|
||||
mock_saas_user_auth.get_provider_tokens = AsyncMock(return_value=provider_tokens)
|
||||
mock_saas_user_auth.get_user_id = AsyncMock(return_value='test-user-id')
|
||||
|
||||
# Act
|
||||
handler = await resolver_context._get_provider_handler()
|
||||
|
||||
# Assert
|
||||
from openhands.integrations.provider import ProviderHandler
|
||||
|
||||
assert isinstance(handler, ProviderHandler)
|
||||
assert handler.provider_tokens == provider_tokens
|
||||
|
||||
@@ -6,6 +6,9 @@ import httpx
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
from server.routes.api_keys import (
|
||||
ByorPermittedResponse,
|
||||
LlmApiKeyResponse,
|
||||
check_byor_permitted,
|
||||
delete_byor_key_from_litellm,
|
||||
get_llm_api_key_for_byor,
|
||||
)
|
||||
@@ -182,16 +185,18 @@ class TestGetLlmApiKeyForByor:
|
||||
"""Test the get_llm_api_key_for_byor endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.org_service.OrgService.check_byor_export_enabled')
|
||||
@patch('server.routes.api_keys.store_byor_key_in_db')
|
||||
@patch('server.routes.api_keys.generate_byor_key')
|
||||
@patch('server.routes.api_keys.get_byor_key_from_db')
|
||||
async def test_no_key_in_database_generates_new(
|
||||
self, mock_get_key, mock_generate_key, mock_store_key
|
||||
self, mock_get_key, mock_generate_key, mock_store_key, mock_check_enabled
|
||||
):
|
||||
"""Test that when no key exists in database, a new one is generated."""
|
||||
# Arrange
|
||||
user_id = 'user-123'
|
||||
new_key = 'sk-new-generated-key'
|
||||
mock_check_enabled.return_value = True
|
||||
mock_get_key.return_value = None
|
||||
mock_generate_key.return_value = new_key
|
||||
mock_store_key.return_value = None
|
||||
@@ -200,21 +205,24 @@ class TestGetLlmApiKeyForByor:
|
||||
result = await get_llm_api_key_for_byor(user_id=user_id)
|
||||
|
||||
# Assert
|
||||
assert result == {'key': new_key}
|
||||
assert result == LlmApiKeyResponse(key=new_key)
|
||||
mock_check_enabled.assert_called_once_with(user_id)
|
||||
mock_get_key.assert_called_once_with(user_id)
|
||||
mock_generate_key.assert_called_once_with(user_id)
|
||||
mock_store_key.assert_called_once_with(user_id, new_key)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.org_service.OrgService.check_byor_export_enabled')
|
||||
@patch('storage.lite_llm_manager.LiteLlmManager.verify_key')
|
||||
@patch('server.routes.api_keys.get_byor_key_from_db')
|
||||
async def test_valid_key_in_database_returns_key(
|
||||
self, mock_get_key, mock_verify_key
|
||||
self, mock_get_key, mock_verify_key, mock_check_enabled
|
||||
):
|
||||
"""Test that when a valid key exists in database, it is returned."""
|
||||
# Arrange
|
||||
user_id = 'user-123'
|
||||
existing_key = 'sk-existing-valid-key'
|
||||
mock_check_enabled.return_value = True
|
||||
mock_get_key.return_value = existing_key
|
||||
mock_verify_key.return_value = True
|
||||
|
||||
@@ -222,11 +230,13 @@ class TestGetLlmApiKeyForByor:
|
||||
result = await get_llm_api_key_for_byor(user_id=user_id)
|
||||
|
||||
# Assert
|
||||
assert result == {'key': existing_key}
|
||||
assert result == LlmApiKeyResponse(key=existing_key)
|
||||
mock_check_enabled.assert_called_once_with(user_id)
|
||||
mock_get_key.assert_called_once_with(user_id)
|
||||
mock_verify_key.assert_called_once_with(existing_key, user_id)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.org_service.OrgService.check_byor_export_enabled')
|
||||
@patch('server.routes.api_keys.store_byor_key_in_db')
|
||||
@patch('server.routes.api_keys.generate_byor_key')
|
||||
@patch('server.routes.api_keys.delete_byor_key_from_litellm')
|
||||
@@ -239,12 +249,14 @@ class TestGetLlmApiKeyForByor:
|
||||
mock_delete_key,
|
||||
mock_generate_key,
|
||||
mock_store_key,
|
||||
mock_check_enabled,
|
||||
):
|
||||
"""Test that when an invalid key exists in database, it is regenerated."""
|
||||
# Arrange
|
||||
user_id = 'user-123'
|
||||
invalid_key = 'sk-invalid-key'
|
||||
new_key = 'sk-new-generated-key'
|
||||
mock_check_enabled.return_value = True
|
||||
mock_get_key.return_value = invalid_key
|
||||
mock_verify_key.return_value = False
|
||||
mock_delete_key.return_value = True
|
||||
@@ -255,7 +267,8 @@ class TestGetLlmApiKeyForByor:
|
||||
result = await get_llm_api_key_for_byor(user_id=user_id)
|
||||
|
||||
# Assert
|
||||
assert result == {'key': new_key}
|
||||
assert result == LlmApiKeyResponse(key=new_key)
|
||||
mock_check_enabled.assert_called_once_with(user_id)
|
||||
mock_get_key.assert_called_once_with(user_id)
|
||||
mock_verify_key.assert_called_once_with(invalid_key, user_id)
|
||||
mock_delete_key.assert_called_once_with(user_id, invalid_key)
|
||||
@@ -263,6 +276,7 @@ class TestGetLlmApiKeyForByor:
|
||||
mock_store_key.assert_called_once_with(user_id, new_key)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.org_service.OrgService.check_byor_export_enabled')
|
||||
@patch('server.routes.api_keys.store_byor_key_in_db')
|
||||
@patch('server.routes.api_keys.generate_byor_key')
|
||||
@patch('server.routes.api_keys.delete_byor_key_from_litellm')
|
||||
@@ -275,12 +289,14 @@ class TestGetLlmApiKeyForByor:
|
||||
mock_delete_key,
|
||||
mock_generate_key,
|
||||
mock_store_key,
|
||||
mock_check_enabled,
|
||||
):
|
||||
"""Test that even if deletion fails, regeneration still proceeds."""
|
||||
# Arrange
|
||||
user_id = 'user-123'
|
||||
invalid_key = 'sk-invalid-key'
|
||||
new_key = 'sk-new-generated-key'
|
||||
mock_check_enabled.return_value = True
|
||||
mock_get_key.return_value = invalid_key
|
||||
mock_verify_key.return_value = False
|
||||
mock_delete_key.return_value = False # Deletion fails
|
||||
@@ -291,20 +307,23 @@ class TestGetLlmApiKeyForByor:
|
||||
result = await get_llm_api_key_for_byor(user_id=user_id)
|
||||
|
||||
# Assert
|
||||
assert result == {'key': new_key}
|
||||
assert result == LlmApiKeyResponse(key=new_key)
|
||||
mock_check_enabled.assert_called_once_with(user_id)
|
||||
mock_delete_key.assert_called_once_with(user_id, invalid_key)
|
||||
mock_generate_key.assert_called_once_with(user_id)
|
||||
mock_store_key.assert_called_once_with(user_id, new_key)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.org_service.OrgService.check_byor_export_enabled')
|
||||
@patch('server.routes.api_keys.generate_byor_key')
|
||||
@patch('server.routes.api_keys.get_byor_key_from_db')
|
||||
async def test_key_generation_failure_raises_exception(
|
||||
self, mock_get_key, mock_generate_key
|
||||
self, mock_get_key, mock_generate_key, mock_check_enabled
|
||||
):
|
||||
"""Test that when key generation fails, an HTTPException is raised."""
|
||||
# Arrange
|
||||
user_id = 'user-123'
|
||||
mock_check_enabled.return_value = True
|
||||
mock_get_key.return_value = None
|
||||
mock_generate_key.return_value = None
|
||||
|
||||
@@ -316,11 +335,15 @@ class TestGetLlmApiKeyForByor:
|
||||
assert 'Failed to generate new BYOR LLM API key' in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.org_service.OrgService.check_byor_export_enabled')
|
||||
@patch('server.routes.api_keys.get_byor_key_from_db')
|
||||
async def test_database_error_raises_exception(self, mock_get_key):
|
||||
async def test_database_error_raises_exception(
|
||||
self, mock_get_key, mock_check_enabled
|
||||
):
|
||||
"""Test that database errors are properly handled."""
|
||||
# Arrange
|
||||
user_id = 'user-123'
|
||||
mock_check_enabled.return_value = True
|
||||
mock_get_key.side_effect = Exception('Database connection error')
|
||||
|
||||
# Act & Assert
|
||||
@@ -330,6 +353,21 @@ class TestGetLlmApiKeyForByor:
|
||||
assert exc_info.value.status_code == 500
|
||||
assert 'Failed to retrieve BYOR LLM API key' in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.org_service.OrgService.check_byor_export_enabled')
|
||||
async def test_byor_export_disabled_returns_402(self, mock_check_enabled):
|
||||
"""Test that when BYOR export is disabled, 402 is returned."""
|
||||
# Arrange
|
||||
user_id = 'user-123'
|
||||
mock_check_enabled.return_value = False
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_llm_api_key_for_byor(user_id=user_id)
|
||||
|
||||
assert exc_info.value.status_code == 402
|
||||
assert 'BYOR key export is not enabled' in exc_info.value.detail
|
||||
|
||||
|
||||
class TestDeleteByorKeyFromLitellm:
|
||||
"""Test the delete_byor_key_from_litellm function with alias cleanup."""
|
||||
@@ -425,3 +463,52 @@ class TestDeleteByorKeyFromLitellm:
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestCheckByorPermitted:
|
||||
"""Test the check_byor_permitted endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.org_service.OrgService.check_byor_export_enabled')
|
||||
async def test_permitted_when_enabled(self, mock_check_enabled):
|
||||
"""Test that permitted=True is returned when BYOR export is enabled."""
|
||||
# Arrange
|
||||
user_id = 'user-123'
|
||||
mock_check_enabled.return_value = True
|
||||
|
||||
# Act
|
||||
result = await check_byor_permitted(user_id=user_id)
|
||||
|
||||
# Assert
|
||||
assert result == ByorPermittedResponse(permitted=True)
|
||||
mock_check_enabled.assert_called_once_with(user_id)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.org_service.OrgService.check_byor_export_enabled')
|
||||
async def test_not_permitted_when_disabled(self, mock_check_enabled):
|
||||
"""Test that permitted=False is returned when BYOR export is disabled."""
|
||||
# Arrange
|
||||
user_id = 'user-123'
|
||||
mock_check_enabled.return_value = False
|
||||
|
||||
# Act
|
||||
result = await check_byor_permitted(user_id=user_id)
|
||||
|
||||
# Assert
|
||||
assert result == ByorPermittedResponse(permitted=False)
|
||||
mock_check_enabled.assert_called_once_with(user_id)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.org_service.OrgService.check_byor_export_enabled')
|
||||
async def test_error_raises_500(self, mock_check_enabled):
|
||||
"""Test that an exception raises 500 error."""
|
||||
# Arrange
|
||||
user_id = 'user-123'
|
||||
mock_check_enabled.side_effect = Exception('Database error')
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await check_byor_permitted(user_id=user_id)
|
||||
|
||||
assert exc_info.value.status_code == 500
|
||||
assert 'Failed to check BYOR export permission' in exc_info.value.detail
|
||||
|
||||
@@ -1220,3 +1220,60 @@ async def test_validate_workspace_update_permissions_no_current_link(mock_manage
|
||||
|
||||
result = await _validate_workspace_update_permissions('user1', 'test-workspace')
|
||||
assert result == mock_workspace
|
||||
|
||||
|
||||
# Tests for OAuth URL encoding
|
||||
class TestJiraDcOAuthUrlEncoding:
|
||||
"""Tests to verify OAuth authorization URLs are properly URL-encoded."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('server.routes.integration.jira_dc.get_user_auth')
|
||||
@patch('server.routes.integration.jira_dc.redis_client')
|
||||
@patch('server.routes.integration.jira_dc.JIRA_DC_ENABLE_OAUTH', True)
|
||||
async def test_create_jira_dc_workspace_url_encoding(
|
||||
self, mock_redis, mock_get_auth, mock_request, mock_user_auth
|
||||
):
|
||||
"""Test that create_jira_dc_workspace properly URL-encodes the authorization URL."""
|
||||
mock_get_auth.return_value = mock_user_auth
|
||||
mock_redis.setex.return_value = True
|
||||
workspace_data = JiraDcWorkspaceCreate(
|
||||
workspace_name='test-workspace',
|
||||
webhook_secret='secret',
|
||||
svc_acc_email='svc@test.com',
|
||||
svc_acc_api_key='key',
|
||||
is_active=True,
|
||||
)
|
||||
|
||||
response = await create_jira_dc_workspace(mock_request, workspace_data)
|
||||
content = json.loads(response.body)
|
||||
|
||||
auth_url = content['authorizationUrl']
|
||||
# Verify no raw spaces in the URL (spaces should be encoded as + or %20)
|
||||
assert ' ' not in auth_url
|
||||
# Verify scope parameter contains encoded scopes (+ is valid URL encoding for space)
|
||||
assert 'scope=read%3Ame+read%3Ajira-user+read%3Ajira-work' in auth_url
|
||||
# Verify redirect_uri is properly encoded
|
||||
assert 'redirect_uri=https%3A%2F%2F' in auth_url
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('server.routes.integration.jira_dc.get_user_auth')
|
||||
@patch('server.routes.integration.jira_dc.redis_client')
|
||||
@patch('server.routes.integration.jira_dc.JIRA_DC_ENABLE_OAUTH', True)
|
||||
async def test_create_workspace_link_url_encoding(
|
||||
self, mock_redis, mock_get_auth, mock_request, mock_user_auth
|
||||
):
|
||||
"""Test that create_workspace_link properly URL-encodes the authorization URL."""
|
||||
mock_get_auth.return_value = mock_user_auth
|
||||
mock_redis.setex.return_value = True
|
||||
link_data = JiraDcLinkCreate(workspace_name='test-workspace')
|
||||
|
||||
response = await create_workspace_link(mock_request, link_data)
|
||||
content = json.loads(response.body)
|
||||
|
||||
auth_url = content['authorizationUrl']
|
||||
# Verify no raw spaces in the URL (spaces should be encoded as + or %20)
|
||||
assert ' ' not in auth_url
|
||||
# Verify scope parameter contains encoded scopes (+ is valid URL encoding for space)
|
||||
assert 'scope=read%3Ame+read%3Ajira-user+read%3Ajira-work' in auth_url
|
||||
# Verify redirect_uri is properly encoded
|
||||
assert 'redirect_uri=https%3A%2F%2F' in auth_url
|
||||
|
||||
@@ -1323,3 +1323,58 @@ async def test_validate_workspace_update_permissions_no_current_link(mock_manage
|
||||
|
||||
result = await _validate_workspace_update_permissions('user1', 'test-workspace')
|
||||
assert result == mock_workspace
|
||||
|
||||
|
||||
# Tests for OAuth URL encoding
|
||||
class TestJiraOAuthUrlEncoding:
|
||||
"""Tests to verify OAuth authorization URLs are properly URL-encoded."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('server.routes.integration.jira.get_user_auth')
|
||||
@patch('server.routes.integration.jira.redis_client')
|
||||
async def test_create_jira_workspace_url_encoding(
|
||||
self, mock_redis, mock_get_auth, mock_request, mock_user_auth
|
||||
):
|
||||
"""Test that create_jira_workspace properly URL-encodes the authorization URL."""
|
||||
mock_get_auth.return_value = mock_user_auth
|
||||
mock_redis.setex.return_value = True
|
||||
workspace_data = JiraWorkspaceCreate(
|
||||
workspace_name='test-workspace',
|
||||
webhook_secret='secret',
|
||||
svc_acc_email='svc@test.com',
|
||||
svc_acc_api_key='key',
|
||||
is_active=True,
|
||||
)
|
||||
|
||||
response = await create_jira_workspace(mock_request, workspace_data)
|
||||
content = json.loads(response.body)
|
||||
|
||||
auth_url = content['authorizationUrl']
|
||||
# Verify no raw spaces in the URL (spaces should be encoded as + or %20)
|
||||
assert ' ' not in auth_url
|
||||
# Verify scope parameter contains encoded scopes (+ is valid URL encoding for space)
|
||||
assert 'scope=read%3Ame+read%3Ajira-user+read%3Ajira-work' in auth_url
|
||||
# Verify redirect_uri is properly encoded
|
||||
assert 'redirect_uri=https%3A%2F%2F' in auth_url
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('server.routes.integration.jira.get_user_auth')
|
||||
@patch('server.routes.integration.jira.redis_client')
|
||||
async def test_create_workspace_link_url_encoding(
|
||||
self, mock_redis, mock_get_auth, mock_request, mock_user_auth
|
||||
):
|
||||
"""Test that create_workspace_link properly URL-encodes the authorization URL."""
|
||||
mock_get_auth.return_value = mock_user_auth
|
||||
mock_redis.setex.return_value = True
|
||||
link_data = JiraLinkCreate(workspace_name='test-workspace')
|
||||
|
||||
response = await create_workspace_link(mock_request, link_data)
|
||||
content = json.loads(response.body)
|
||||
|
||||
auth_url = content['authorizationUrl']
|
||||
# Verify no raw spaces in the URL (spaces should be encoded as + or %20)
|
||||
assert ' ' not in auth_url
|
||||
# Verify scope parameter contains encoded scopes (+ is valid URL encoding for space)
|
||||
assert 'scope=read%3Ame+read%3Ajira-user+read%3Ajira-work' in auth_url
|
||||
# Verify redirect_uri is properly encoded
|
||||
assert 'redirect_uri=https%3A%2F%2F' in auth_url
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Unit tests for OAuth2 Device Flow endpoints."""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException, Request
|
||||
@@ -22,8 +22,10 @@ def mock_device_code_store():
|
||||
|
||||
@pytest.fixture
|
||||
def mock_api_key_store():
|
||||
"""Mock API key store."""
|
||||
return MagicMock()
|
||||
"""Mock API key store with async create_api_key."""
|
||||
mock = MagicMock()
|
||||
mock.create_api_key = AsyncMock()
|
||||
return mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -204,8 +206,9 @@ class TestDeviceVerificationAuthenticated:
|
||||
mock_store.get_by_user_code.return_value = mock_device
|
||||
mock_store.authorize_device_code.return_value = True
|
||||
|
||||
# Mock API key store
|
||||
# Mock API key store with async create_api_key
|
||||
mock_api_key_store = MagicMock()
|
||||
mock_api_key_store.create_api_key = AsyncMock()
|
||||
mock_api_key_class.get_instance.return_value = mock_api_key_store
|
||||
|
||||
result = await device_verification_authenticated(
|
||||
@@ -228,8 +231,9 @@ class TestDeviceVerificationAuthenticated:
|
||||
@patch('server.routes.oauth_device.device_code_store')
|
||||
async def test_multiple_device_authentication(self, mock_store, mock_api_key_class):
|
||||
"""Test that multiple devices can authenticate simultaneously."""
|
||||
# Mock API key store
|
||||
# Mock API key store with async create_api_key
|
||||
mock_api_key_store = MagicMock()
|
||||
mock_api_key_store.create_api_key = AsyncMock()
|
||||
mock_api_key_class.get_instance.return_value = mock_api_key_store
|
||||
|
||||
# Simulate two different devices
|
||||
@@ -486,8 +490,9 @@ class TestDeviceVerificationTransactionIntegrity:
|
||||
mock_store.get_by_user_code.return_value = mock_device
|
||||
mock_store.authorize_device_code.return_value = False # Authorization fails
|
||||
|
||||
# Mock API key store
|
||||
# Mock API key store with async create_api_key
|
||||
mock_api_key_store = MagicMock()
|
||||
mock_api_key_store.create_api_key = AsyncMock()
|
||||
mock_api_key_class.get_instance.return_value = mock_api_key_store
|
||||
|
||||
# Should raise HTTPException due to authorization failure
|
||||
@@ -518,9 +523,11 @@ class TestDeviceVerificationTransactionIntegrity:
|
||||
mock_store.authorize_device_code.return_value = True # Authorization succeeds
|
||||
mock_store.deny_device_code.return_value = True # Cleanup succeeds
|
||||
|
||||
# Mock API key store to fail on creation
|
||||
# Mock API key store to fail on creation (async)
|
||||
mock_api_key_store = MagicMock()
|
||||
mock_api_key_store.create_api_key.side_effect = Exception('Database error')
|
||||
mock_api_key_store.create_api_key = AsyncMock(
|
||||
side_effect=Exception('Database error')
|
||||
)
|
||||
mock_api_key_class.get_instance.return_value = mock_api_key_store
|
||||
|
||||
# Should raise HTTPException due to API key creation failure
|
||||
@@ -558,9 +565,11 @@ class TestDeviceVerificationTransactionIntegrity:
|
||||
'Cleanup failed'
|
||||
) # Cleanup fails
|
||||
|
||||
# Mock API key store to fail on creation
|
||||
# Mock API key store to fail on creation (async)
|
||||
mock_api_key_store = MagicMock()
|
||||
mock_api_key_store.create_api_key.side_effect = Exception('Database error')
|
||||
mock_api_key_store.create_api_key = AsyncMock(
|
||||
side_effect=Exception('Database error')
|
||||
)
|
||||
mock_api_key_class.get_instance.return_value = mock_api_key_store
|
||||
|
||||
# Should still raise HTTPException for the original API key creation failure
|
||||
@@ -589,8 +598,9 @@ class TestDeviceVerificationTransactionIntegrity:
|
||||
mock_store.get_by_user_code.return_value = mock_device
|
||||
mock_store.authorize_device_code.return_value = True # Authorization succeeds
|
||||
|
||||
# Mock API key store
|
||||
# Mock API key store with async create_api_key
|
||||
mock_api_key_store = MagicMock()
|
||||
mock_api_key_store.create_api_key = AsyncMock()
|
||||
mock_api_key_class.get_instance.return_value = mock_api_key_store
|
||||
|
||||
result = await device_verification_authenticated(
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
1711
enterprise/tests/unit/server/services/test_org_member_service.py
Normal file
1711
enterprise/tests/unit/server/services/test_org_member_service.py
Normal file
File diff suppressed because it is too large
Load Diff
661
enterprise/tests/unit/storage/test_auth_token_store.py
Normal file
661
enterprise/tests/unit/storage/test_auth_token_store.py
Normal file
@@ -0,0 +1,661 @@
|
||||
"""Unit tests for AuthTokenStore."""
|
||||
|
||||
import time
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Dict
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from server.auth.auth_error import TokenRefreshError
|
||||
from sqlalchemy.exc import OperationalError
|
||||
from storage.auth_token_store import (
|
||||
ACCESS_TOKEN_EXPIRY_BUFFER,
|
||||
LOCK_TIMEOUT_SECONDS,
|
||||
AuthTokenStore,
|
||||
)
|
||||
|
||||
from openhands.integrations.service_types import ProviderType
|
||||
|
||||
|
||||
def create_mock_session():
|
||||
"""Create a mock async session with properly configured context managers."""
|
||||
session = AsyncMock()
|
||||
|
||||
# Create async context manager for begin()
|
||||
@asynccontextmanager
|
||||
async def begin_context():
|
||||
yield
|
||||
|
||||
session.begin = begin_context
|
||||
return session
|
||||
|
||||
|
||||
def create_mock_session_maker(mock_session):
|
||||
"""Create a mock async session maker."""
|
||||
|
||||
@asynccontextmanager
|
||||
async def session_context():
|
||||
yield mock_session
|
||||
|
||||
# Return a callable that returns the context manager
|
||||
return lambda: session_context()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session():
|
||||
"""Create mock async session."""
|
||||
return create_mock_session()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session_maker(mock_session):
|
||||
"""Create mock async session maker."""
|
||||
return create_mock_session_maker(mock_session)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def auth_token_store(mock_session_maker):
|
||||
"""Create AuthTokenStore instance with mocked session maker."""
|
||||
return AuthTokenStore(
|
||||
keycloak_user_id='test-user-123',
|
||||
idp=ProviderType.GITHUB,
|
||||
a_session_maker=mock_session_maker,
|
||||
)
|
||||
|
||||
|
||||
class TestIsTokenExpired:
|
||||
"""Tests for _is_token_expired method."""
|
||||
|
||||
def test_both_tokens_valid(self, auth_token_store):
|
||||
"""Test when both tokens are valid (not expired)."""
|
||||
current_time = int(time.time())
|
||||
access_expires = current_time + ACCESS_TOKEN_EXPIRY_BUFFER + 1000
|
||||
refresh_expires = current_time + 1000
|
||||
|
||||
access_expired, refresh_expired = auth_token_store._is_token_expired(
|
||||
access_expires, refresh_expires
|
||||
)
|
||||
|
||||
assert access_expired is False
|
||||
assert refresh_expired is False
|
||||
|
||||
def test_access_token_expired(self, auth_token_store):
|
||||
"""Test when access token is expired but within buffer."""
|
||||
current_time = int(time.time())
|
||||
# Access token expires within buffer period
|
||||
access_expires = current_time + ACCESS_TOKEN_EXPIRY_BUFFER - 100
|
||||
refresh_expires = current_time + 10000
|
||||
|
||||
access_expired, refresh_expired = auth_token_store._is_token_expired(
|
||||
access_expires, refresh_expires
|
||||
)
|
||||
|
||||
assert access_expired is True
|
||||
assert refresh_expired is False
|
||||
|
||||
def test_refresh_token_expired(self, auth_token_store):
|
||||
"""Test when refresh token is expired."""
|
||||
current_time = int(time.time())
|
||||
access_expires = current_time + ACCESS_TOKEN_EXPIRY_BUFFER + 1000
|
||||
refresh_expires = current_time - 100 # Already expired
|
||||
|
||||
access_expired, refresh_expired = auth_token_store._is_token_expired(
|
||||
access_expires, refresh_expires
|
||||
)
|
||||
|
||||
assert access_expired is False
|
||||
assert refresh_expired is True
|
||||
|
||||
def test_both_tokens_expired(self, auth_token_store):
|
||||
"""Test when both tokens are expired."""
|
||||
current_time = int(time.time())
|
||||
access_expires = current_time - 100
|
||||
refresh_expires = current_time - 100
|
||||
|
||||
access_expired, refresh_expired = auth_token_store._is_token_expired(
|
||||
access_expires, refresh_expires
|
||||
)
|
||||
|
||||
assert access_expired is True
|
||||
assert refresh_expired is True
|
||||
|
||||
def test_zero_expiration_treated_as_never_expires(self, auth_token_store):
|
||||
"""Test that 0 expiration time is treated as never expires."""
|
||||
access_expired, refresh_expired = auth_token_store._is_token_expired(0, 0)
|
||||
|
||||
assert access_expired is False
|
||||
assert refresh_expired is False
|
||||
|
||||
|
||||
class TestLoadTokensFastPath:
|
||||
"""Tests for load_tokens fast path (no lock needed)."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fast_path_token_not_found(
|
||||
self, auth_token_store, mock_session_maker, mock_session
|
||||
):
|
||||
"""Test fast path returns None when no token record exists."""
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.one_or_none.return_value = None
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
|
||||
result = await auth_token_store.load_tokens()
|
||||
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fast_path_valid_token_no_refresh_needed(
|
||||
self, auth_token_store, mock_session_maker, mock_session
|
||||
):
|
||||
"""Test fast path returns tokens when they are still valid."""
|
||||
current_time = int(time.time())
|
||||
mock_token = MagicMock()
|
||||
mock_token.access_token = 'valid-access-token'
|
||||
mock_token.refresh_token = 'valid-refresh-token'
|
||||
mock_token.access_token_expires_at = (
|
||||
current_time + ACCESS_TOKEN_EXPIRY_BUFFER + 1000
|
||||
)
|
||||
mock_token.refresh_token_expires_at = current_time + 10000
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.one_or_none.return_value = mock_token
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
|
||||
result = await auth_token_store.load_tokens()
|
||||
|
||||
assert result is not None
|
||||
assert result['access_token'] == 'valid-access-token'
|
||||
assert result['refresh_token'] == 'valid-refresh-token'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fast_path_no_refresh_callback_provided(
|
||||
self, auth_token_store, mock_session_maker, mock_session
|
||||
):
|
||||
"""Test fast path returns existing tokens when no refresh callback is provided."""
|
||||
current_time = int(time.time())
|
||||
mock_token = MagicMock()
|
||||
mock_token.access_token = 'expired-access-token'
|
||||
mock_token.refresh_token = 'valid-refresh-token'
|
||||
# Expired access token
|
||||
mock_token.access_token_expires_at = current_time - 100
|
||||
mock_token.refresh_token_expires_at = current_time + 10000
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.one_or_none.return_value = mock_token
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
|
||||
result = await auth_token_store.load_tokens(check_expiration_and_refresh=None)
|
||||
|
||||
assert result is not None
|
||||
assert result['access_token'] == 'expired-access-token'
|
||||
|
||||
|
||||
class TestLoadTokensSlowPath:
|
||||
"""Tests for load_tokens slow path (lock required for refresh)."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_slow_path_successful_refresh(self):
|
||||
"""Test slow path successfully refreshes expired tokens."""
|
||||
current_time = int(time.time())
|
||||
mock_session = create_mock_session()
|
||||
|
||||
# First call (fast path) - returns expired token
|
||||
# Second call (slow path) - returns same token for update
|
||||
expired_token = MagicMock()
|
||||
expired_token.id = 1
|
||||
expired_token.access_token = 'expired-access-token'
|
||||
expired_token.refresh_token = 'valid-refresh-token'
|
||||
expired_token.access_token_expires_at = current_time - 100 # Expired
|
||||
expired_token.refresh_token_expires_at = current_time + 10000
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.one_or_none.return_value = expired_token
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
mock_session.commit = AsyncMock()
|
||||
|
||||
mock_session_maker = create_mock_session_maker(mock_session)
|
||||
|
||||
auth_store = AuthTokenStore(
|
||||
keycloak_user_id='test-user-123',
|
||||
idp=ProviderType.GITHUB,
|
||||
a_session_maker=mock_session_maker,
|
||||
)
|
||||
|
||||
async def mock_refresh(
|
||||
idp: ProviderType, refresh_token: str, access_exp: int, refresh_exp: int
|
||||
) -> Dict[str, str | int]:
|
||||
return {
|
||||
'access_token': 'new-access-token',
|
||||
'refresh_token': 'new-refresh-token',
|
||||
'access_token_expires_at': current_time + 3600,
|
||||
'refresh_token_expires_at': current_time + 86400,
|
||||
}
|
||||
|
||||
result = await auth_store.load_tokens(check_expiration_and_refresh=mock_refresh)
|
||||
|
||||
assert result is not None
|
||||
assert result['access_token'] == 'new-access-token'
|
||||
assert result['refresh_token'] == 'new-refresh-token'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_slow_path_double_check_avoids_refresh(self):
|
||||
"""Test double-check locking: token was refreshed by another request."""
|
||||
current_time = int(time.time())
|
||||
mock_session = create_mock_session()
|
||||
|
||||
# Simulate scenario:
|
||||
# 1. Fast path sees expired token
|
||||
# 2. While waiting for lock, another request refreshes
|
||||
# 3. Slow path sees fresh token, skips refresh
|
||||
|
||||
call_count = [0]
|
||||
|
||||
def create_token():
|
||||
call_count[0] += 1
|
||||
token = MagicMock()
|
||||
token.id = 1
|
||||
token.access_token = 'fresh-access-token'
|
||||
token.refresh_token = 'fresh-refresh-token'
|
||||
if call_count[0] == 1:
|
||||
# First call (fast path) - expired
|
||||
token.access_token_expires_at = current_time - 100
|
||||
else:
|
||||
# Second call (slow path) - already refreshed
|
||||
token.access_token_expires_at = (
|
||||
current_time + ACCESS_TOKEN_EXPIRY_BUFFER + 1000
|
||||
)
|
||||
token.refresh_token_expires_at = current_time + 86400
|
||||
return token
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.one_or_none.side_effect = (
|
||||
lambda: create_token()
|
||||
)
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
mock_session.commit = AsyncMock()
|
||||
|
||||
mock_session_maker = create_mock_session_maker(mock_session)
|
||||
|
||||
auth_store = AuthTokenStore(
|
||||
keycloak_user_id='test-user-123',
|
||||
idp=ProviderType.GITHUB,
|
||||
a_session_maker=mock_session_maker,
|
||||
)
|
||||
|
||||
refresh_called = [False]
|
||||
|
||||
async def mock_refresh(
|
||||
idp: ProviderType, refresh_token: str, access_exp: int, refresh_exp: int
|
||||
) -> Dict[str, str | int]:
|
||||
refresh_called[0] = True
|
||||
return {
|
||||
'access_token': 'should-not-be-used',
|
||||
'refresh_token': 'should-not-be-used',
|
||||
'access_token_expires_at': current_time + 3600,
|
||||
'refresh_token_expires_at': current_time + 86400,
|
||||
}
|
||||
|
||||
result = await auth_store.load_tokens(check_expiration_and_refresh=mock_refresh)
|
||||
|
||||
# The refresh callback should not be called because double-check
|
||||
# found the token was already refreshed
|
||||
assert result is not None
|
||||
assert result['access_token'] == 'fresh-access-token'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_slow_path_token_not_found_after_lock(self):
|
||||
"""Test slow path returns None if token record disappears after lock."""
|
||||
current_time = int(time.time())
|
||||
mock_session = create_mock_session()
|
||||
|
||||
# First call (fast path) - token exists but expired
|
||||
# Second call (slow path with lock) - token no longer exists
|
||||
call_count = [0]
|
||||
|
||||
def get_token():
|
||||
call_count[0] += 1
|
||||
if call_count[0] == 1:
|
||||
token = MagicMock()
|
||||
token.access_token_expires_at = current_time - 100 # Expired
|
||||
token.refresh_token_expires_at = current_time + 10000
|
||||
return token
|
||||
return None
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.one_or_none.side_effect = get_token
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
|
||||
mock_session_maker = create_mock_session_maker(mock_session)
|
||||
|
||||
auth_store = AuthTokenStore(
|
||||
keycloak_user_id='test-user-123',
|
||||
idp=ProviderType.GITHUB,
|
||||
a_session_maker=mock_session_maker,
|
||||
)
|
||||
|
||||
async def mock_refresh(*args) -> Dict[str, str | int]:
|
||||
return {
|
||||
'access_token': 'new-token',
|
||||
'refresh_token': 'new-refresh',
|
||||
'access_token_expires_at': current_time + 3600,
|
||||
'refresh_token_expires_at': current_time + 86400,
|
||||
}
|
||||
|
||||
result = await auth_store.load_tokens(check_expiration_and_refresh=mock_refresh)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestLoadTokensLockTimeout:
|
||||
"""Tests for lock timeout handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lock_timeout_raises_token_refresh_error(self):
|
||||
"""Test that lock timeout raises TokenRefreshError."""
|
||||
current_time = int(time.time())
|
||||
mock_session = create_mock_session()
|
||||
|
||||
# First call (fast path) - returns expired token
|
||||
expired_token = MagicMock()
|
||||
expired_token.access_token_expires_at = current_time - 100
|
||||
expired_token.refresh_token_expires_at = current_time + 10000
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.one_or_none.return_value = expired_token
|
||||
|
||||
# First execute for fast path succeeds
|
||||
# Second execute (for slow path) raises OperationalError
|
||||
call_count = [0]
|
||||
|
||||
async def execute_side_effect(*args, **kwargs):
|
||||
call_count[0] += 1
|
||||
if call_count[0] <= 1:
|
||||
return mock_result
|
||||
# Simulate lock timeout
|
||||
raise OperationalError(
|
||||
'canceling statement due to lock timeout', None, None
|
||||
)
|
||||
|
||||
mock_session.execute = execute_side_effect
|
||||
|
||||
mock_session_maker = create_mock_session_maker(mock_session)
|
||||
|
||||
auth_store = AuthTokenStore(
|
||||
keycloak_user_id='test-user-123',
|
||||
idp=ProviderType.GITHUB,
|
||||
a_session_maker=mock_session_maker,
|
||||
)
|
||||
|
||||
async def mock_refresh(*args) -> Dict[str, str | int]:
|
||||
return {
|
||||
'access_token': 'new-token',
|
||||
'refresh_token': 'new-refresh',
|
||||
'access_token_expires_at': current_time + 3600,
|
||||
'refresh_token_expires_at': current_time + 86400,
|
||||
}
|
||||
|
||||
with pytest.raises(TokenRefreshError) as exc_info:
|
||||
await auth_store.load_tokens(check_expiration_and_refresh=mock_refresh)
|
||||
|
||||
assert 'lock timeout' in str(exc_info.value).lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lock_timeout_preserves_original_exception(self):
|
||||
"""Test that TokenRefreshError preserves the original OperationalError."""
|
||||
current_time = int(time.time())
|
||||
mock_session = create_mock_session()
|
||||
|
||||
expired_token = MagicMock()
|
||||
expired_token.access_token_expires_at = current_time - 100
|
||||
expired_token.refresh_token_expires_at = current_time + 10000
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.one_or_none.return_value = expired_token
|
||||
|
||||
original_error = OperationalError(
|
||||
'canceling statement due to lock timeout', None, None
|
||||
)
|
||||
|
||||
call_count = [0]
|
||||
|
||||
async def execute_side_effect(*args, **kwargs):
|
||||
call_count[0] += 1
|
||||
if call_count[0] <= 1:
|
||||
return mock_result
|
||||
raise original_error
|
||||
|
||||
mock_session.execute = execute_side_effect
|
||||
|
||||
mock_session_maker = create_mock_session_maker(mock_session)
|
||||
|
||||
auth_store = AuthTokenStore(
|
||||
keycloak_user_id='test-user-123',
|
||||
idp=ProviderType.GITHUB,
|
||||
a_session_maker=mock_session_maker,
|
||||
)
|
||||
|
||||
async def mock_refresh(*args) -> Dict[str, str | int]:
|
||||
return {
|
||||
'access_token': 'new-token',
|
||||
'refresh_token': 'new-refresh',
|
||||
'access_token_expires_at': current_time + 3600,
|
||||
'refresh_token_expires_at': current_time + 86400,
|
||||
}
|
||||
|
||||
with pytest.raises(TokenRefreshError) as exc_info:
|
||||
await auth_store.load_tokens(check_expiration_and_refresh=mock_refresh)
|
||||
|
||||
# Verify the original exception is chained
|
||||
assert exc_info.value.__cause__ is original_error
|
||||
|
||||
|
||||
class TestLoadTokensRefreshCallbackBehavior:
|
||||
"""Tests for refresh callback return values."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_callback_returns_none(self):
|
||||
"""Test behavior when refresh callback returns None (no refresh performed)."""
|
||||
current_time = int(time.time())
|
||||
mock_session = create_mock_session()
|
||||
|
||||
expired_token = MagicMock()
|
||||
expired_token.id = 1
|
||||
expired_token.access_token = 'old-access-token'
|
||||
expired_token.refresh_token = 'old-refresh-token'
|
||||
expired_token.access_token_expires_at = current_time - 100 # Expired
|
||||
expired_token.refresh_token_expires_at = current_time + 10000
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.one_or_none.return_value = expired_token
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
mock_session.commit = AsyncMock()
|
||||
|
||||
mock_session_maker = create_mock_session_maker(mock_session)
|
||||
|
||||
auth_store = AuthTokenStore(
|
||||
keycloak_user_id='test-user-123',
|
||||
idp=ProviderType.GITHUB,
|
||||
a_session_maker=mock_session_maker,
|
||||
)
|
||||
|
||||
async def mock_refresh_returns_none(
|
||||
idp: ProviderType, refresh_token: str, access_exp: int, refresh_exp: int
|
||||
) -> Dict[str, str | int] | None:
|
||||
return None
|
||||
|
||||
result = await auth_store.load_tokens(
|
||||
check_expiration_and_refresh=mock_refresh_returns_none
|
||||
)
|
||||
|
||||
# Should return the old tokens when refresh returns None
|
||||
assert result is not None
|
||||
assert result['access_token'] == 'old-access-token'
|
||||
assert result['refresh_token'] == 'old-refresh-token'
|
||||
|
||||
|
||||
class TestStoreTokens:
|
||||
"""Tests for store_tokens method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_tokens_creates_new_record(self):
|
||||
"""Test storing tokens when no existing record."""
|
||||
mock_session = create_mock_session()
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.first.return_value = None
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
mock_session.add = MagicMock()
|
||||
mock_session.commit = AsyncMock()
|
||||
|
||||
mock_session_maker = create_mock_session_maker(mock_session)
|
||||
|
||||
auth_store = AuthTokenStore(
|
||||
keycloak_user_id='test-user-123',
|
||||
idp=ProviderType.GITHUB,
|
||||
a_session_maker=mock_session_maker,
|
||||
)
|
||||
|
||||
await auth_store.store_tokens(
|
||||
access_token='new-access-token',
|
||||
refresh_token='new-refresh-token',
|
||||
access_token_expires_at=1234567890,
|
||||
refresh_token_expires_at=1234657890,
|
||||
)
|
||||
|
||||
mock_session.add.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_tokens_updates_existing_record(self):
|
||||
"""Test storing tokens updates existing record."""
|
||||
mock_session = create_mock_session()
|
||||
existing_token = MagicMock()
|
||||
existing_token.access_token = 'old-access'
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.first.return_value = existing_token
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
mock_session.commit = AsyncMock()
|
||||
|
||||
mock_session_maker = create_mock_session_maker(mock_session)
|
||||
|
||||
auth_store = AuthTokenStore(
|
||||
keycloak_user_id='test-user-123',
|
||||
idp=ProviderType.GITHUB,
|
||||
a_session_maker=mock_session_maker,
|
||||
)
|
||||
|
||||
await auth_store.store_tokens(
|
||||
access_token='new-access-token',
|
||||
refresh_token='new-refresh-token',
|
||||
access_token_expires_at=1234567890,
|
||||
refresh_token_expires_at=1234657890,
|
||||
)
|
||||
|
||||
assert existing_token.access_token == 'new-access-token'
|
||||
assert existing_token.refresh_token == 'new-refresh-token'
|
||||
|
||||
|
||||
class TestIsAccessTokenValid:
|
||||
"""Tests for is_access_token_valid method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_access_token_valid_returns_false_when_no_tokens(
|
||||
self, auth_token_store, mock_session_maker, mock_session
|
||||
):
|
||||
"""Test returns False when no tokens found."""
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.one_or_none.return_value = None
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
|
||||
result = await auth_token_store.is_access_token_valid()
|
||||
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_access_token_valid_returns_true_for_valid_token(
|
||||
self, auth_token_store, mock_session_maker, mock_session
|
||||
):
|
||||
"""Test returns True when token is valid."""
|
||||
current_time = int(time.time())
|
||||
mock_token = MagicMock()
|
||||
mock_token.access_token = 'valid-access'
|
||||
mock_token.refresh_token = 'valid-refresh'
|
||||
mock_token.access_token_expires_at = current_time + 1000
|
||||
mock_token.refresh_token_expires_at = current_time + 10000
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.one_or_none.return_value = mock_token
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
|
||||
result = await auth_token_store.is_access_token_valid()
|
||||
|
||||
assert result is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_access_token_valid_returns_false_for_expired_token(
|
||||
self, auth_token_store, mock_session_maker, mock_session
|
||||
):
|
||||
"""Test returns False when token is expired."""
|
||||
current_time = int(time.time())
|
||||
mock_token = MagicMock()
|
||||
mock_token.access_token = 'expired-access'
|
||||
mock_token.refresh_token = 'valid-refresh'
|
||||
mock_token.access_token_expires_at = current_time - 100 # Expired
|
||||
mock_token.refresh_token_expires_at = current_time + 10000
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.one_or_none.return_value = mock_token
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
|
||||
result = await auth_token_store.is_access_token_valid()
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestGetInstance:
|
||||
"""Tests for get_instance class method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_instance_creates_auth_token_store(self):
|
||||
"""Test get_instance creates an AuthTokenStore with correct params."""
|
||||
with patch('storage.auth_token_store.a_session_maker') as mock_a_session_maker:
|
||||
store = await AuthTokenStore.get_instance(
|
||||
keycloak_user_id='user-123', idp=ProviderType.GITHUB
|
||||
)
|
||||
|
||||
assert store.keycloak_user_id == 'user-123'
|
||||
assert store.idp == ProviderType.GITHUB
|
||||
assert store.a_session_maker is mock_a_session_maker
|
||||
|
||||
|
||||
class TestIdentityProviderValue:
|
||||
"""Tests for identity_provider_value property."""
|
||||
|
||||
def test_identity_provider_value_returns_idp_value(self, auth_token_store):
|
||||
"""Test that identity_provider_value returns the enum value."""
|
||||
assert auth_token_store.identity_provider_value == ProviderType.GITHUB.value
|
||||
|
||||
def test_identity_provider_value_for_different_providers(self):
|
||||
"""Test identity_provider_value for different providers."""
|
||||
for provider in [
|
||||
ProviderType.GITHUB,
|
||||
ProviderType.GITLAB,
|
||||
ProviderType.BITBUCKET,
|
||||
]:
|
||||
store = AuthTokenStore(
|
||||
keycloak_user_id='test-user',
|
||||
idp=provider,
|
||||
a_session_maker=MagicMock(),
|
||||
)
|
||||
assert store.identity_provider_value == provider.value
|
||||
|
||||
|
||||
class TestConstants:
|
||||
"""Tests for module constants."""
|
||||
|
||||
def test_access_token_expiry_buffer_value(self):
|
||||
"""Test ACCESS_TOKEN_EXPIRY_BUFFER is set to 15 minutes."""
|
||||
assert ACCESS_TOKEN_EXPIRY_BUFFER == 900
|
||||
|
||||
def test_lock_timeout_seconds_value(self):
|
||||
"""Test LOCK_TIMEOUT_SECONDS is set to 5 seconds."""
|
||||
assert LOCK_TIMEOUT_SECONDS == 5
|
||||
99
enterprise/tests/unit/storage/test_database.py
Normal file
99
enterprise/tests/unit/storage/test_database.py
Normal file
@@ -0,0 +1,99 @@
|
||||
"""Tests for the enterprise storage.database module.
|
||||
|
||||
These tests verify that the session_maker function properly forwards
|
||||
keyword arguments to the underlying session maker for backward compatibility.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
|
||||
class TestSessionMaker:
|
||||
"""Test cases for the session_maker function."""
|
||||
|
||||
@patch('enterprise.storage.database._get_db_session_injector')
|
||||
def test_session_maker_without_args(self, mock_get_injector):
|
||||
"""Test that session_maker works without any arguments."""
|
||||
from enterprise.storage.database import session_maker
|
||||
|
||||
# Set up mock
|
||||
mock_injector = MagicMock()
|
||||
mock_inner_session_maker = MagicMock()
|
||||
mock_session = MagicMock()
|
||||
mock_inner_session_maker.return_value = mock_session
|
||||
mock_injector.get_session_maker.return_value = mock_inner_session_maker
|
||||
mock_get_injector.return_value = mock_injector
|
||||
|
||||
# Call session_maker without arguments
|
||||
result = session_maker()
|
||||
|
||||
# Verify the inner session maker was called without arguments
|
||||
mock_inner_session_maker.assert_called_once_with()
|
||||
assert result == mock_session
|
||||
|
||||
@patch('enterprise.storage.database._get_db_session_injector')
|
||||
def test_session_maker_with_expire_on_commit_false(self, mock_get_injector):
|
||||
"""Test that session_maker accepts expire_on_commit keyword argument.
|
||||
|
||||
This is a critical backward compatibility test - the session_maker
|
||||
must accept keyword arguments like expire_on_commit=False which is
|
||||
used in slack.py and potentially other integration modules.
|
||||
"""
|
||||
from enterprise.storage.database import session_maker
|
||||
|
||||
# Set up mock
|
||||
mock_injector = MagicMock()
|
||||
mock_inner_session_maker = MagicMock()
|
||||
mock_session = MagicMock()
|
||||
mock_inner_session_maker.return_value = mock_session
|
||||
mock_injector.get_session_maker.return_value = mock_inner_session_maker
|
||||
mock_get_injector.return_value = mock_injector
|
||||
|
||||
# Call session_maker with expire_on_commit=False
|
||||
# This is the exact call pattern used in slack.py line 242
|
||||
result = session_maker(expire_on_commit=False)
|
||||
|
||||
# Verify the inner session maker was called with the keyword argument
|
||||
mock_inner_session_maker.assert_called_once_with(expire_on_commit=False)
|
||||
assert result == mock_session
|
||||
|
||||
@patch('enterprise.storage.database._get_db_session_injector')
|
||||
def test_session_maker_with_multiple_kwargs(self, mock_get_injector):
|
||||
"""Test that session_maker passes through multiple keyword arguments."""
|
||||
from enterprise.storage.database import session_maker
|
||||
|
||||
# Set up mock
|
||||
mock_injector = MagicMock()
|
||||
mock_inner_session_maker = MagicMock()
|
||||
mock_session = MagicMock()
|
||||
mock_inner_session_maker.return_value = mock_session
|
||||
mock_injector.get_session_maker.return_value = mock_inner_session_maker
|
||||
mock_get_injector.return_value = mock_injector
|
||||
|
||||
# Call with multiple kwargs
|
||||
result = session_maker(
|
||||
expire_on_commit=False, autoflush=False, autocommit=False
|
||||
)
|
||||
|
||||
# Verify all kwargs were passed through
|
||||
mock_inner_session_maker.assert_called_once_with(
|
||||
expire_on_commit=False, autoflush=False, autocommit=False
|
||||
)
|
||||
assert result == mock_session
|
||||
|
||||
@patch('enterprise.storage.database._get_db_session_injector')
|
||||
def test_session_maker_returns_correct_session(self, mock_get_injector):
|
||||
"""Test that session_maker returns the session from the inner session maker."""
|
||||
from enterprise.storage.database import session_maker
|
||||
|
||||
# Set up mock
|
||||
mock_injector = MagicMock()
|
||||
mock_inner_session_maker = MagicMock()
|
||||
mock_session = MagicMock()
|
||||
mock_inner_session_maker.return_value = mock_session
|
||||
mock_injector.get_session_maker.return_value = mock_inner_session_maker
|
||||
mock_get_injector.return_value = mock_injector
|
||||
|
||||
result = session_maker()
|
||||
|
||||
# Verify the returned session is from the inner session maker
|
||||
assert result is mock_session
|
||||
158
enterprise/tests/unit/storage/test_resend_synced_user_store.py
Normal file
158
enterprise/tests/unit/storage/test_resend_synced_user_store.py
Normal file
@@ -0,0 +1,158 @@
|
||||
"""Unit tests for ResendSyncedUserStore."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
# Import directly from the module files to avoid loading all of storage/__init__.py
|
||||
# which has many dependencies
|
||||
from storage.resend_synced_user import ResendSyncedUser
|
||||
from storage.resend_synced_user_store import ResendSyncedUserStore
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session():
|
||||
"""Mock database session."""
|
||||
session = MagicMock()
|
||||
return session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session_maker(mock_session):
|
||||
"""Mock session maker."""
|
||||
session_maker = MagicMock()
|
||||
session_maker.return_value.__enter__.return_value = mock_session
|
||||
session_maker.return_value.__exit__.return_value = None
|
||||
return session_maker
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def store(mock_session_maker):
|
||||
"""Create ResendSyncedUserStore instance."""
|
||||
return ResendSyncedUserStore(session_maker=mock_session_maker)
|
||||
|
||||
|
||||
class TestResendSyncedUserStore:
|
||||
"""Test cases for ResendSyncedUserStore."""
|
||||
|
||||
def test_is_user_synced_returns_true_when_exists(self, store, mock_session):
|
||||
"""Test is_user_synced returns True when user exists in database."""
|
||||
email = 'test@example.com'
|
||||
audience_id = 'test-audience-123'
|
||||
|
||||
mock_row = MagicMock()
|
||||
mock_session.execute.return_value.first.return_value = mock_row
|
||||
|
||||
result = store.is_user_synced(email, audience_id)
|
||||
|
||||
assert result is True
|
||||
mock_session.execute.assert_called_once()
|
||||
|
||||
def test_is_user_synced_returns_false_when_not_exists(self, store, mock_session):
|
||||
"""Test is_user_synced returns False when user doesn't exist."""
|
||||
email = 'test@example.com'
|
||||
audience_id = 'test-audience-123'
|
||||
|
||||
mock_session.execute.return_value.first.return_value = None
|
||||
|
||||
result = store.is_user_synced(email, audience_id)
|
||||
|
||||
assert result is False
|
||||
|
||||
def test_is_user_synced_normalizes_email_to_lowercase(self, store, mock_session):
|
||||
"""Test that is_user_synced normalizes email to lowercase."""
|
||||
email = 'TEST@EXAMPLE.COM'
|
||||
audience_id = 'test-audience-123'
|
||||
|
||||
mock_session.execute.return_value.first.return_value = None
|
||||
|
||||
store.is_user_synced(email, audience_id)
|
||||
|
||||
# Verify the query was called (we can't easily check the exact SQL)
|
||||
mock_session.execute.assert_called_once()
|
||||
|
||||
def test_mark_user_synced_creates_new_record(self, store, mock_session):
|
||||
"""Test that mark_user_synced creates a new record."""
|
||||
email = 'test@example.com'
|
||||
audience_id = 'test-audience-123'
|
||||
keycloak_user_id = 'kc-user-123'
|
||||
|
||||
mock_synced_user = MagicMock(spec=ResendSyncedUser)
|
||||
mock_result = MagicMock()
|
||||
mock_result.first.return_value = (mock_synced_user,)
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
result = store.mark_user_synced(email, audience_id, keycloak_user_id)
|
||||
|
||||
assert result == mock_synced_user
|
||||
mock_session.execute.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
def test_mark_user_synced_handles_existing_record(self, store, mock_session):
|
||||
"""Test that mark_user_synced handles conflict (existing record)."""
|
||||
email = 'test@example.com'
|
||||
audience_id = 'test-audience-123'
|
||||
|
||||
# First execute (insert) returns None (conflict occurred)
|
||||
# Second execute (select existing) returns the record
|
||||
mock_existing_user = MagicMock(spec=ResendSyncedUser)
|
||||
mock_result_insert = MagicMock()
|
||||
mock_result_insert.first.return_value = None
|
||||
|
||||
mock_result_select = MagicMock()
|
||||
mock_result_select.first.return_value = (mock_existing_user,)
|
||||
|
||||
mock_session.execute.side_effect = [mock_result_insert, mock_result_select]
|
||||
|
||||
result = store.mark_user_synced(email, audience_id)
|
||||
|
||||
assert result == mock_existing_user
|
||||
assert mock_session.execute.call_count == 2
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
def test_mark_user_synced_normalizes_email_to_lowercase(self, store, mock_session):
|
||||
"""Test that mark_user_synced normalizes email to lowercase."""
|
||||
email = 'TEST@EXAMPLE.COM'
|
||||
audience_id = 'test-audience-123'
|
||||
|
||||
mock_synced_user = MagicMock(spec=ResendSyncedUser)
|
||||
mock_result = MagicMock()
|
||||
mock_result.first.return_value = (mock_synced_user,)
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
store.mark_user_synced(email, audience_id)
|
||||
|
||||
# Verify execute was called (the email normalization happens in the SQL)
|
||||
mock_session.execute.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
def test_mark_user_synced_without_keycloak_user_id(self, store, mock_session):
|
||||
"""Test that mark_user_synced works without keycloak_user_id."""
|
||||
email = 'test@example.com'
|
||||
audience_id = 'test-audience-123'
|
||||
|
||||
mock_synced_user = MagicMock(spec=ResendSyncedUser)
|
||||
mock_result = MagicMock()
|
||||
mock_result.first.return_value = (mock_synced_user,)
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
result = store.mark_user_synced(email, audience_id)
|
||||
|
||||
assert result == mock_synced_user
|
||||
mock_session.execute.assert_called_once()
|
||||
|
||||
|
||||
class TestResendSyncedUser:
|
||||
"""Test cases for ResendSyncedUser model."""
|
||||
|
||||
def test_model_has_required_fields(self):
|
||||
"""Test that the model has all required fields."""
|
||||
assert hasattr(ResendSyncedUser, 'id')
|
||||
assert hasattr(ResendSyncedUser, 'email')
|
||||
assert hasattr(ResendSyncedUser, 'audience_id')
|
||||
assert hasattr(ResendSyncedUser, 'synced_at')
|
||||
assert hasattr(ResendSyncedUser, 'keycloak_user_id')
|
||||
|
||||
def test_model_table_name(self):
|
||||
"""Test the model's table name."""
|
||||
assert ResendSyncedUser.__tablename__ == 'resend_synced_users'
|
||||
@@ -10,8 +10,12 @@ from unittest.mock import AsyncMock, MagicMock
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.pool import StaticPool
|
||||
from storage.base import Base
|
||||
from storage.org import Org
|
||||
from storage.user import User
|
||||
|
||||
from enterprise.server.utils.saas_app_conversation_info_injector import (
|
||||
SaasSQLAppConversationInfoService,
|
||||
@@ -20,10 +24,15 @@ from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
AppConversationInfo,
|
||||
)
|
||||
from openhands.app_server.user.specifiy_user_context import SpecifyUserContext
|
||||
from openhands.app_server.utils.sql_utils import Base
|
||||
from openhands.integrations.service_types import ProviderType
|
||||
from openhands.storage.data_models.conversation_metadata import ConversationTrigger
|
||||
|
||||
# Test UUIDs
|
||||
USER1_ID = UUID('a1111111-1111-1111-1111-111111111111')
|
||||
USER2_ID = UUID('b2222222-2222-2222-2222-222222222222')
|
||||
ORG1_ID = UUID('c1111111-1111-1111-1111-111111111111')
|
||||
ORG2_ID = UUID('d2222222-2222-2222-2222-222222222222')
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def async_engine():
|
||||
@@ -55,6 +64,41 @@ async def async_session(async_engine) -> AsyncGenerator[AsyncSession, None]:
|
||||
yield db_session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def async_session_with_users(async_engine) -> AsyncGenerator[AsyncSession, None]:
|
||||
"""Create an async session with pre-populated Org and User rows for testing."""
|
||||
async_session_maker = async_sessionmaker(
|
||||
async_engine, class_=AsyncSession, expire_on_commit=False
|
||||
)
|
||||
|
||||
async with async_session_maker() as db_session:
|
||||
# Insert Orgs first (required for User foreign key)
|
||||
org1 = Org(
|
||||
id=ORG1_ID,
|
||||
name='test-org-1',
|
||||
enable_default_condenser=True,
|
||||
enable_proactive_conversation_starters=True,
|
||||
)
|
||||
org2 = Org(
|
||||
id=ORG2_ID,
|
||||
name='test-org-2',
|
||||
enable_default_condenser=True,
|
||||
enable_proactive_conversation_starters=True,
|
||||
)
|
||||
db_session.add(org1)
|
||||
db_session.add(org2)
|
||||
await db_session.flush()
|
||||
|
||||
# Insert Users
|
||||
user1 = User(id=USER1_ID, current_org_id=ORG1_ID)
|
||||
user2 = User(id=USER2_ID, current_org_id=ORG2_ID)
|
||||
db_session.add(user1)
|
||||
db_session.add(user2)
|
||||
await db_session.commit()
|
||||
|
||||
yield db_session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def service(async_session) -> SaasSQLAppConversationInfoService:
|
||||
"""Create a SQLAppConversationInfoService instance for testing."""
|
||||
@@ -178,15 +222,26 @@ class TestSaasSQLAppConversationInfoService:
|
||||
assert user1_id != user2_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_secure_select_includes_user_filtering(
|
||||
async def test_secure_select_includes_user_and_org_filtering(
|
||||
self,
|
||||
saas_service_user1: SaasSQLAppConversationInfoService,
|
||||
async_session_with_users: AsyncSession,
|
||||
):
|
||||
"""Test that _secure_select method includes user filtering."""
|
||||
# This test verifies that the _secure_select method exists and can be called
|
||||
# The actual SQL generation is tested implicitly through integration
|
||||
query = await saas_service_user1._secure_select()
|
||||
assert query is not None
|
||||
"""Test that _secure_select method includes both user_id and org_id filtering."""
|
||||
service = SaasSQLAppConversationInfoService(
|
||||
db_session=async_session_with_users,
|
||||
user_context=SpecifyUserContext(user_id=str(USER1_ID)),
|
||||
)
|
||||
|
||||
query = await service._secure_select()
|
||||
|
||||
# Convert query to string to verify filters are present
|
||||
query_str = str(query.compile(compile_kwargs={'literal_binds': True}))
|
||||
|
||||
# Verify user_id filter is present
|
||||
assert str(USER1_ID) in query_str or str(USER1_ID).replace('-', '') in query_str
|
||||
|
||||
# Verify org_id filter is present (user1 is in org1)
|
||||
assert str(ORG1_ID) in query_str or str(ORG1_ID).replace('-', '') in query_str
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_to_info_with_user_id_functionality(
|
||||
@@ -241,100 +296,32 @@ class TestSaasSQLAppConversationInfoService:
|
||||
assert result.sandbox_id == 'test-sandbox'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_isolation(
|
||||
async def test_user_isolation_different_users(
|
||||
self,
|
||||
async_session: AsyncSession,
|
||||
multiple_conversation_infos: list[AppConversationInfo],
|
||||
async_session_with_users: AsyncSession,
|
||||
):
|
||||
"""Test that user isolation works correctly."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from storage.user import User
|
||||
|
||||
# Mock the database session execute method to return mock users
|
||||
# This mock intercepts User queries and returns a mock user object
|
||||
# with user_id and org_id the same as the user_id_uuid from the query
|
||||
original_execute = async_session.execute
|
||||
|
||||
async def mock_execute(query):
|
||||
query_str = str(query)
|
||||
|
||||
# Check if this is a User query
|
||||
if '"user"' in query_str.lower() and '"user".id' in query_str.lower():
|
||||
# Extract the UUID from the query parameters
|
||||
# The query will have bound parameters, we need to get the UUID value
|
||||
if hasattr(query, 'compile'):
|
||||
try:
|
||||
compiled = query.compile(compile_kwargs={'literal_binds': True})
|
||||
query_with_params = str(compiled)
|
||||
|
||||
# Extract UUID from the query string
|
||||
import re
|
||||
|
||||
# Try both formats: with dashes and without dashes
|
||||
uuid_pattern_with_dashes = r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}'
|
||||
uuid_pattern_without_dashes = r'[a-f0-9]{32}'
|
||||
|
||||
uuid_match = re.search(
|
||||
uuid_pattern_with_dashes, query_with_params
|
||||
)
|
||||
if not uuid_match:
|
||||
uuid_match = re.search(
|
||||
uuid_pattern_without_dashes, query_with_params
|
||||
)
|
||||
|
||||
if uuid_match:
|
||||
user_id_str = uuid_match.group(0)
|
||||
# If the UUID doesn't have dashes, add them
|
||||
if len(user_id_str) == 32 and '-' not in user_id_str:
|
||||
# Convert from 'a1111111111111111111111111111111' to 'a1111111-1111-1111-1111-111111111111'
|
||||
user_id_str = f'{user_id_str[:8]}-{user_id_str[8:12]}-{user_id_str[12:16]}-{user_id_str[16:20]}-{user_id_str[20:]}'
|
||||
user_id_uuid = UUID(user_id_str)
|
||||
|
||||
# Create a mock user with user_id and org_id the same as user_id_uuid
|
||||
mock_user = MagicMock(spec=User)
|
||||
mock_user.id = user_id_uuid
|
||||
mock_user.current_org_id = user_id_uuid
|
||||
|
||||
# Create a mock result
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalar_one_or_none.return_value = mock_user
|
||||
return mock_result
|
||||
except Exception:
|
||||
# If there's any error in parsing, fall back to original execute
|
||||
pass
|
||||
|
||||
# For all other queries, use the original execute method
|
||||
return await original_execute(query)
|
||||
|
||||
# Apply the mock
|
||||
async_session.execute = mock_execute
|
||||
|
||||
"""Test that different users cannot see each other's conversations."""
|
||||
# Create services for different users
|
||||
user1_service = SaasSQLAppConversationInfoService(
|
||||
db_session=async_session,
|
||||
user_context=SpecifyUserContext(
|
||||
user_id='a1111111-1111-1111-1111-111111111111'
|
||||
),
|
||||
db_session=async_session_with_users,
|
||||
user_context=SpecifyUserContext(user_id=str(USER1_ID)),
|
||||
)
|
||||
user2_service = SaasSQLAppConversationInfoService(
|
||||
db_session=async_session,
|
||||
user_context=SpecifyUserContext(
|
||||
user_id='b2222222-2222-2222-2222-222222222222'
|
||||
),
|
||||
db_session=async_session_with_users,
|
||||
user_context=SpecifyUserContext(user_id=str(USER2_ID)),
|
||||
)
|
||||
|
||||
# Create conversations for different users
|
||||
user1_info = AppConversationInfo(
|
||||
id=uuid4(),
|
||||
created_by_user_id='a1111111-1111-1111-1111-111111111111',
|
||||
created_by_user_id=str(USER1_ID),
|
||||
sandbox_id='sandbox_user1',
|
||||
title='User 1 Conversation',
|
||||
)
|
||||
|
||||
user2_info = AppConversationInfo(
|
||||
id=uuid4(),
|
||||
created_by_user_id='b2222222-2222-2222-2222-222222222222',
|
||||
created_by_user_id=str(USER2_ID),
|
||||
sandbox_id='sandbox_user2',
|
||||
title='User 2 Conversation',
|
||||
)
|
||||
@@ -346,18 +333,12 @@ class TestSaasSQLAppConversationInfoService:
|
||||
# User 1 should only see their conversation
|
||||
user1_page = await user1_service.search_app_conversation_info()
|
||||
assert len(user1_page.items) == 1
|
||||
assert (
|
||||
user1_page.items[0].created_by_user_id
|
||||
== 'a1111111-1111-1111-1111-111111111111'
|
||||
)
|
||||
assert user1_page.items[0].created_by_user_id == str(USER1_ID)
|
||||
|
||||
# User 2 should only see their conversation
|
||||
user2_page = await user2_service.search_app_conversation_info()
|
||||
assert len(user2_page.items) == 1
|
||||
assert (
|
||||
user2_page.items[0].created_by_user_id
|
||||
== 'b2222222-2222-2222-2222-222222222222'
|
||||
)
|
||||
assert user2_page.items[0].created_by_user_id == str(USER2_ID)
|
||||
|
||||
# User 1 should not be able to get user 2's conversation
|
||||
user2_from_user1 = await user1_service.get_app_conversation_info(user2_info.id)
|
||||
@@ -366,3 +347,319 @@ class TestSaasSQLAppConversationInfoService:
|
||||
# User 2 should not be able to get user 1's conversation
|
||||
user1_from_user2 = await user2_service.get_app_conversation_info(user1_info.id)
|
||||
assert user1_from_user2 is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_same_user_org_switching_isolation(
|
||||
self,
|
||||
async_session_with_users: AsyncSession,
|
||||
):
|
||||
"""Test that the same user switching orgs cannot see conversations from other orgs.
|
||||
|
||||
This tests the actual bug scenario: a user creates a conversation in org1,
|
||||
then switches to org2, and should NOT see org1's conversations.
|
||||
"""
|
||||
# Create service for user1 in org1
|
||||
user1_service_org1 = SaasSQLAppConversationInfoService(
|
||||
db_session=async_session_with_users,
|
||||
user_context=SpecifyUserContext(user_id=str(USER1_ID)),
|
||||
)
|
||||
|
||||
# Create a conversation while user is in org1
|
||||
conv_in_org1 = AppConversationInfo(
|
||||
id=uuid4(),
|
||||
created_by_user_id=str(USER1_ID),
|
||||
sandbox_id='sandbox_org1',
|
||||
title='Conversation in Org 1',
|
||||
)
|
||||
await user1_service_org1.save_app_conversation_info(conv_in_org1)
|
||||
|
||||
# Verify user can see the conversation in org1
|
||||
page_in_org1 = await user1_service_org1.search_app_conversation_info()
|
||||
assert len(page_in_org1.items) == 1
|
||||
assert page_in_org1.items[0].title == 'Conversation in Org 1'
|
||||
|
||||
# Simulate user switching to org2 by updating current_org_id using ORM
|
||||
result = await async_session_with_users.execute(
|
||||
select(User).where(User.id == USER1_ID)
|
||||
)
|
||||
user_to_update = result.scalars().first()
|
||||
user_to_update.current_org_id = ORG2_ID
|
||||
await async_session_with_users.commit()
|
||||
# Clear SQLAlchemy's identity map cache to simulate a new request
|
||||
async_session_with_users.expire_all()
|
||||
|
||||
# Create new service instance (simulating a new request after org switch)
|
||||
user1_service_org2 = SaasSQLAppConversationInfoService(
|
||||
db_session=async_session_with_users,
|
||||
user_context=SpecifyUserContext(user_id=str(USER1_ID)),
|
||||
)
|
||||
|
||||
# User should NOT see org1's conversations after switching to org2
|
||||
page_in_org2 = await user1_service_org2.search_app_conversation_info()
|
||||
assert (
|
||||
len(page_in_org2.items) == 0
|
||||
), 'User should not see conversations from org1 after switching to org2'
|
||||
|
||||
# User should not be able to get the specific conversation from org1
|
||||
conv_from_org2 = await user1_service_org2.get_app_conversation_info(
|
||||
conv_in_org1.id
|
||||
)
|
||||
assert (
|
||||
conv_from_org2 is None
|
||||
), 'User should not be able to access org1 conversation from org2'
|
||||
|
||||
# Now create a conversation in org2
|
||||
conv_in_org2 = AppConversationInfo(
|
||||
id=uuid4(),
|
||||
created_by_user_id=str(USER1_ID),
|
||||
sandbox_id='sandbox_org2',
|
||||
title='Conversation in Org 2',
|
||||
)
|
||||
await user1_service_org2.save_app_conversation_info(conv_in_org2)
|
||||
|
||||
# User should only see org2's conversation
|
||||
page_in_org2_after = await user1_service_org2.search_app_conversation_info()
|
||||
assert len(page_in_org2_after.items) == 1
|
||||
assert page_in_org2_after.items[0].title == 'Conversation in Org 2'
|
||||
|
||||
# Switch back to org1 and verify isolation works both ways
|
||||
result = await async_session_with_users.execute(
|
||||
select(User).where(User.id == USER1_ID)
|
||||
)
|
||||
user_to_update = result.scalars().first()
|
||||
user_to_update.current_org_id = ORG1_ID
|
||||
await async_session_with_users.commit()
|
||||
async_session_with_users.expire_all()
|
||||
|
||||
user1_service_back_to_org1 = SaasSQLAppConversationInfoService(
|
||||
db_session=async_session_with_users,
|
||||
user_context=SpecifyUserContext(user_id=str(USER1_ID)),
|
||||
)
|
||||
|
||||
# User should only see org1's conversation now
|
||||
page_back_in_org1 = (
|
||||
await user1_service_back_to_org1.search_app_conversation_info()
|
||||
)
|
||||
assert len(page_back_in_org1.items) == 1
|
||||
assert page_back_in_org1.items[0].title == 'Conversation in Org 1'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_count_respects_org_isolation(
|
||||
self,
|
||||
async_session_with_users: AsyncSession,
|
||||
):
|
||||
"""Test that count_app_conversation_info respects org isolation."""
|
||||
# Create service for user1 in org1
|
||||
user1_service = SaasSQLAppConversationInfoService(
|
||||
db_session=async_session_with_users,
|
||||
user_context=SpecifyUserContext(user_id=str(USER1_ID)),
|
||||
)
|
||||
|
||||
# Create conversations in org1
|
||||
for i in range(3):
|
||||
conv = AppConversationInfo(
|
||||
id=uuid4(),
|
||||
created_by_user_id=str(USER1_ID),
|
||||
sandbox_id=f'sandbox_org1_{i}',
|
||||
title=f'Org1 Conversation {i}',
|
||||
)
|
||||
await user1_service.save_app_conversation_info(conv)
|
||||
|
||||
# Count should be 3
|
||||
count_org1 = await user1_service.count_app_conversation_info()
|
||||
assert count_org1 == 3
|
||||
|
||||
# Switch to org2 using ORM
|
||||
result = await async_session_with_users.execute(
|
||||
select(User).where(User.id == USER1_ID)
|
||||
)
|
||||
user_to_update = result.scalars().first()
|
||||
user_to_update.current_org_id = ORG2_ID
|
||||
await async_session_with_users.commit()
|
||||
async_session_with_users.expire_all()
|
||||
|
||||
user1_service_org2 = SaasSQLAppConversationInfoService(
|
||||
db_session=async_session_with_users,
|
||||
user_context=SpecifyUserContext(user_id=str(USER1_ID)),
|
||||
)
|
||||
|
||||
# Count should be 0 in org2
|
||||
count_org2 = await user1_service_org2.count_app_conversation_info()
|
||||
assert count_org2 == 0
|
||||
|
||||
|
||||
class TestSaasSQLAppConversationInfoServiceAdminContext:
|
||||
"""Test suite for SaasSQLAppConversationInfoService with ADMIN context."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_context_returns_unfiltered_data(
|
||||
self,
|
||||
async_session_with_users: AsyncSession,
|
||||
):
|
||||
"""Test that ADMIN context returns unfiltered data (no user/org filtering)."""
|
||||
# Create conversations for different users
|
||||
user1_service = SaasSQLAppConversationInfoService(
|
||||
db_session=async_session_with_users,
|
||||
user_context=SpecifyUserContext(user_id=str(USER1_ID)),
|
||||
)
|
||||
|
||||
# Create conversations for user1 in org1
|
||||
for i in range(3):
|
||||
conv = AppConversationInfo(
|
||||
id=uuid4(),
|
||||
created_by_user_id=str(USER1_ID),
|
||||
sandbox_id=f'sandbox_user1_{i}',
|
||||
title=f'User1 Conversation {i}',
|
||||
)
|
||||
await user1_service.save_app_conversation_info(conv)
|
||||
|
||||
# Now create an ADMIN service
|
||||
from openhands.app_server.user.specifiy_user_context import ADMIN
|
||||
|
||||
admin_service = SaasSQLAppConversationInfoService(
|
||||
db_session=async_session_with_users,
|
||||
user_context=ADMIN,
|
||||
)
|
||||
|
||||
# ADMIN should see ALL conversations (unfiltered)
|
||||
admin_page = await admin_service.search_app_conversation_info()
|
||||
assert (
|
||||
len(admin_page.items) == 3
|
||||
), 'ADMIN context should see all conversations without filtering'
|
||||
|
||||
# ADMIN count should return total count (3)
|
||||
admin_count = await admin_service.count_app_conversation_info()
|
||||
assert (
|
||||
admin_count == 3
|
||||
), 'ADMIN context should count all conversations without filtering'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_context_can_access_any_conversation(
|
||||
self,
|
||||
async_session_with_users: AsyncSession,
|
||||
):
|
||||
"""Test that ADMIN context can access any conversation regardless of owner."""
|
||||
from openhands.app_server.user.specifiy_user_context import ADMIN
|
||||
|
||||
# Create a conversation as user1
|
||||
user1_service = SaasSQLAppConversationInfoService(
|
||||
db_session=async_session_with_users,
|
||||
user_context=SpecifyUserContext(user_id=str(USER1_ID)),
|
||||
)
|
||||
|
||||
conv = AppConversationInfo(
|
||||
id=uuid4(),
|
||||
created_by_user_id=str(USER1_ID),
|
||||
sandbox_id='sandbox_user1',
|
||||
title='User1 Private Conversation',
|
||||
)
|
||||
await user1_service.save_app_conversation_info(conv)
|
||||
|
||||
# Create a service as user2 in org2 - should not see user1's conversation
|
||||
user2_service = SaasSQLAppConversationInfoService(
|
||||
db_session=async_session_with_users,
|
||||
user_context=SpecifyUserContext(user_id=str(USER2_ID)),
|
||||
)
|
||||
|
||||
user2_page = await user2_service.search_app_conversation_info()
|
||||
assert len(user2_page.items) == 0, 'User2 should not see User1 conversation'
|
||||
|
||||
# But ADMIN should see ALL conversations including user1's
|
||||
admin_service = SaasSQLAppConversationInfoService(
|
||||
db_session=async_session_with_users,
|
||||
user_context=ADMIN,
|
||||
)
|
||||
|
||||
admin_page = await admin_service.search_app_conversation_info()
|
||||
assert len(admin_page.items) == 1
|
||||
assert admin_page.items[0].id == conv.id
|
||||
|
||||
# ADMIN should also be able to get specific conversation by ID
|
||||
admin_get_conv = await admin_service.get_app_conversation_info(conv.id)
|
||||
assert admin_get_conv is not None
|
||||
assert admin_get_conv.id == conv.id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_secure_select_admin_bypasses_filtering(
|
||||
self,
|
||||
async_session_with_users: AsyncSession,
|
||||
):
|
||||
"""Test that _secure_select returns unfiltered query for ADMIN context."""
|
||||
from openhands.app_server.user.specifiy_user_context import ADMIN
|
||||
|
||||
# Create an ADMIN service
|
||||
admin_service = SaasSQLAppConversationInfoService(
|
||||
db_session=async_session_with_users,
|
||||
user_context=ADMIN,
|
||||
)
|
||||
|
||||
# Get the secure select query
|
||||
query = await admin_service._secure_select()
|
||||
|
||||
# Convert query to string to verify NO filters are present
|
||||
query_str = str(query.compile(compile_kwargs={'literal_binds': True}))
|
||||
|
||||
# For ADMIN, there should be no user_id or org_id filtering
|
||||
# The query should not contain filters for user_id or org_id
|
||||
assert str(USER1_ID) not in query_str.replace(
|
||||
'-', ''
|
||||
), 'ADMIN context should not filter by user_id'
|
||||
assert str(USER2_ID) not in query_str.replace(
|
||||
'-', ''
|
||||
), 'ADMIN context should not filter by user_id'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_regular_user_context_filters_correctly(
|
||||
self,
|
||||
async_session_with_users: AsyncSession,
|
||||
):
|
||||
"""Test that regular user context properly filters data (control test)."""
|
||||
from openhands.app_server.user.specifiy_user_context import ADMIN
|
||||
|
||||
# Create conversations for different users
|
||||
user1_service = SaasSQLAppConversationInfoService(
|
||||
db_session=async_session_with_users,
|
||||
user_context=SpecifyUserContext(user_id=str(USER1_ID)),
|
||||
)
|
||||
|
||||
# Create 3 conversations for user1
|
||||
for i in range(3):
|
||||
conv = AppConversationInfo(
|
||||
id=uuid4(),
|
||||
created_by_user_id=str(USER1_ID),
|
||||
sandbox_id=f'sandbox_user1_{i}',
|
||||
title=f'User1 Conversation {i}',
|
||||
)
|
||||
await user1_service.save_app_conversation_info(conv)
|
||||
|
||||
# Create 2 conversations for user2
|
||||
user2_service = SaasSQLAppConversationInfoService(
|
||||
db_session=async_session_with_users,
|
||||
user_context=SpecifyUserContext(user_id=str(USER2_ID)),
|
||||
)
|
||||
|
||||
for i in range(2):
|
||||
conv = AppConversationInfo(
|
||||
id=uuid4(),
|
||||
created_by_user_id=str(USER2_ID),
|
||||
sandbox_id=f'sandbox_user2_{i}',
|
||||
title=f'User2 Conversation {i}',
|
||||
)
|
||||
await user2_service.save_app_conversation_info(conv)
|
||||
|
||||
# User1 should only see their 3 conversations
|
||||
user1_page = await user1_service.search_app_conversation_info()
|
||||
assert len(user1_page.items) == 3
|
||||
|
||||
# User2 should only see their 2 conversations
|
||||
user2_page = await user2_service.search_app_conversation_info()
|
||||
assert len(user2_page.items) == 2
|
||||
|
||||
# But ADMIN should see all 5 conversations
|
||||
admin_service = SaasSQLAppConversationInfoService(
|
||||
db_session=async_session_with_users,
|
||||
user_context=ADMIN,
|
||||
)
|
||||
|
||||
admin_page = await admin_service.search_app_conversation_info()
|
||||
assert len(admin_page.items) == 5
|
||||
|
||||
267
enterprise/tests/unit/sync/test_resend_keycloak.py
Normal file
267
enterprise/tests/unit/sync/test_resend_keycloak.py
Normal file
@@ -0,0 +1,267 @@
|
||||
"""Tests for Resend Keycloak sync functionality."""
|
||||
|
||||
import os
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from resend.exceptions import ResendError
|
||||
from tenacity import RetryError
|
||||
|
||||
# Set required environment variables before importing the module
|
||||
# that reads them at import time
|
||||
os.environ['RESEND_API_KEY'] = 'test_api_key'
|
||||
os.environ['RESEND_AUDIENCE_ID'] = 'test_audience_id'
|
||||
os.environ['KEYCLOAK_SERVER_URL'] = 'http://localhost:8080'
|
||||
os.environ['KEYCLOAK_REALM_NAME'] = 'test_realm'
|
||||
os.environ['KEYCLOAK_ADMIN_PASSWORD'] = 'test_password'
|
||||
|
||||
from enterprise.sync.resend_keycloak import ( # noqa: E402
|
||||
add_contact_to_resend,
|
||||
is_valid_email,
|
||||
send_welcome_email,
|
||||
)
|
||||
|
||||
|
||||
class TestIsValidEmail:
|
||||
"""Test cases for is_valid_email function."""
|
||||
|
||||
def test_valid_simple_email(self):
|
||||
"""Test that a simple valid email passes validation."""
|
||||
assert is_valid_email('user@example.com') is True
|
||||
|
||||
def test_valid_email_with_plus(self):
|
||||
"""Test that email with + modifier passes validation."""
|
||||
assert is_valid_email('user+tag@example.com') is True
|
||||
|
||||
def test_valid_email_with_dots(self):
|
||||
"""Test that email with dots in local part passes validation."""
|
||||
assert is_valid_email('first.last@example.com') is True
|
||||
|
||||
def test_valid_email_with_numbers(self):
|
||||
"""Test that email with numbers passes validation."""
|
||||
assert is_valid_email('user123@example.com') is True
|
||||
|
||||
def test_valid_email_with_subdomain(self):
|
||||
"""Test that email with subdomain passes validation."""
|
||||
assert is_valid_email('user@mail.example.com') is True
|
||||
|
||||
def test_valid_email_with_hyphen_domain(self):
|
||||
"""Test that email with hyphen in domain passes validation."""
|
||||
assert is_valid_email('user@example-site.com') is True
|
||||
|
||||
def test_valid_email_with_underscore(self):
|
||||
"""Test that email with underscore passes validation."""
|
||||
assert is_valid_email('user_name@example.com') is True
|
||||
|
||||
def test_valid_email_with_percent(self):
|
||||
"""Test that email with percent sign passes validation."""
|
||||
assert is_valid_email('user%name@example.com') is True
|
||||
|
||||
def test_invalid_email_with_exclamation(self):
|
||||
"""Test that email with exclamation mark fails validation.
|
||||
|
||||
This is the specific case from the bug report:
|
||||
ethanjames3713+!@gmail.com
|
||||
"""
|
||||
assert is_valid_email('ethanjames3713+!@gmail.com') is False
|
||||
|
||||
def test_invalid_email_with_special_chars(self):
|
||||
"""Test that email with other special characters fails validation."""
|
||||
assert is_valid_email('user!name@example.com') is False
|
||||
assert is_valid_email('user#name@example.com') is False
|
||||
assert is_valid_email('user$name@example.com') is False
|
||||
assert is_valid_email('user&name@example.com') is False
|
||||
assert is_valid_email("user'name@example.com") is False
|
||||
assert is_valid_email('user*name@example.com') is False
|
||||
assert is_valid_email('user=name@example.com') is False
|
||||
assert is_valid_email('user^name@example.com') is False
|
||||
assert is_valid_email('user`name@example.com') is False
|
||||
assert is_valid_email('user{name@example.com') is False
|
||||
assert is_valid_email('user|name@example.com') is False
|
||||
assert is_valid_email('user}name@example.com') is False
|
||||
assert is_valid_email('user~name@example.com') is False
|
||||
|
||||
def test_invalid_email_no_at_symbol(self):
|
||||
"""Test that email without @ symbol fails validation."""
|
||||
assert is_valid_email('userexample.com') is False
|
||||
|
||||
def test_invalid_email_no_domain(self):
|
||||
"""Test that email without domain fails validation."""
|
||||
assert is_valid_email('user@') is False
|
||||
|
||||
def test_invalid_email_no_local_part(self):
|
||||
"""Test that email without local part fails validation."""
|
||||
assert is_valid_email('@example.com') is False
|
||||
|
||||
def test_invalid_email_no_tld(self):
|
||||
"""Test that email without TLD fails validation."""
|
||||
assert is_valid_email('user@example') is False
|
||||
|
||||
def test_invalid_email_single_char_tld(self):
|
||||
"""Test that email with single character TLD fails validation."""
|
||||
assert is_valid_email('user@example.c') is False
|
||||
|
||||
def test_invalid_email_empty_string(self):
|
||||
"""Test that empty string fails validation."""
|
||||
assert is_valid_email('') is False
|
||||
|
||||
def test_invalid_email_none(self):
|
||||
"""Test that None fails validation."""
|
||||
assert is_valid_email(None) is False
|
||||
|
||||
def test_invalid_email_whitespace(self):
|
||||
"""Test that email with whitespace fails validation."""
|
||||
assert is_valid_email('user @example.com') is False
|
||||
assert is_valid_email('user@ example.com') is False
|
||||
assert is_valid_email(' user@example.com') is False
|
||||
assert is_valid_email('user@example.com ') is False
|
||||
|
||||
def test_invalid_email_double_at(self):
|
||||
"""Test that email with double @ fails validation."""
|
||||
assert is_valid_email('user@@example.com') is False
|
||||
|
||||
def test_email_double_dot_domain(self):
|
||||
"""Test email with double dot in domain.
|
||||
|
||||
Note: The regex allows this as it's technically valid in some edge cases,
|
||||
and Resend's API may accept it. The main goal is to reject special
|
||||
characters like ! that Resend definitely rejects.
|
||||
"""
|
||||
# This is allowed by our regex - Resend may or may not accept it
|
||||
assert is_valid_email('user@example..com') is True
|
||||
|
||||
def test_case_insensitive_validation(self):
|
||||
"""Test that validation works for uppercase emails."""
|
||||
assert is_valid_email('USER@EXAMPLE.COM') is True
|
||||
assert is_valid_email('User@Example.Com') is True
|
||||
|
||||
|
||||
class TestSendWelcomeEmail:
|
||||
"""Tests for send_welcome_email function."""
|
||||
|
||||
@patch('enterprise.sync.resend_keycloak.resend.Emails.send')
|
||||
def test_send_welcome_email_success(self, mock_send: MagicMock) -> None:
|
||||
"""Test successful welcome email sending."""
|
||||
mock_send.return_value = {'id': 'email_123'}
|
||||
|
||||
result = send_welcome_email(
|
||||
email='test@example.com',
|
||||
first_name='John',
|
||||
last_name='Doe',
|
||||
)
|
||||
|
||||
assert result == {'id': 'email_123'}
|
||||
mock_send.assert_called_once()
|
||||
call_args = mock_send.call_args[0][0]
|
||||
assert call_args['to'] == ['test@example.com']
|
||||
assert call_args['subject'] == 'Welcome to OpenHands Cloud'
|
||||
assert 'Hi John Doe,' in call_args['html']
|
||||
|
||||
@patch('enterprise.sync.resend_keycloak.resend.Emails.send')
|
||||
def test_send_welcome_email_retries_on_rate_limit(
|
||||
self, mock_send: MagicMock
|
||||
) -> None:
|
||||
"""Test that send_welcome_email retries on rate limit errors."""
|
||||
# First two calls raise rate limit error, third succeeds
|
||||
mock_send.side_effect = [
|
||||
ResendError(
|
||||
code=429,
|
||||
message='Too many requests',
|
||||
error_type='rate_limit_exceeded',
|
||||
suggested_action='',
|
||||
),
|
||||
ResendError(
|
||||
code=429,
|
||||
message='Too many requests',
|
||||
error_type='rate_limit_exceeded',
|
||||
suggested_action='',
|
||||
),
|
||||
{'id': 'email_123'},
|
||||
]
|
||||
|
||||
result = send_welcome_email(
|
||||
email='test@example.com',
|
||||
first_name='John',
|
||||
last_name='Doe',
|
||||
)
|
||||
|
||||
assert result == {'id': 'email_123'}
|
||||
assert mock_send.call_count == 3
|
||||
|
||||
@patch('enterprise.sync.resend_keycloak.resend.Emails.send')
|
||||
def test_send_welcome_email_fails_after_max_retries(
|
||||
self, mock_send: MagicMock
|
||||
) -> None:
|
||||
"""Test that send_welcome_email fails after max retries."""
|
||||
# All calls raise rate limit error
|
||||
mock_send.side_effect = ResendError(
|
||||
code=429,
|
||||
message='Too many requests',
|
||||
error_type='rate_limit_exceeded',
|
||||
suggested_action='',
|
||||
)
|
||||
|
||||
# Tenacity wraps the final exception in RetryError
|
||||
with pytest.raises(RetryError):
|
||||
send_welcome_email(
|
||||
email='test@example.com',
|
||||
first_name='John',
|
||||
last_name='Doe',
|
||||
)
|
||||
|
||||
# Default MAX_RETRIES is 3
|
||||
assert mock_send.call_count == 3
|
||||
|
||||
@patch('enterprise.sync.resend_keycloak.resend.Emails.send')
|
||||
def test_send_welcome_email_no_name(self, mock_send: MagicMock) -> None:
|
||||
"""Test welcome email with no name provided."""
|
||||
mock_send.return_value = {'id': 'email_123'}
|
||||
|
||||
result = send_welcome_email(email='test@example.com')
|
||||
|
||||
assert result == {'id': 'email_123'}
|
||||
call_args = mock_send.call_args[0][0]
|
||||
assert 'Hi there,' in call_args['html']
|
||||
|
||||
|
||||
class TestAddContactToResend:
|
||||
"""Tests for add_contact_to_resend function."""
|
||||
|
||||
@patch('enterprise.sync.resend_keycloak.resend.Contacts.create')
|
||||
def test_add_contact_to_resend_success(self, mock_create: MagicMock) -> None:
|
||||
"""Test successful contact addition."""
|
||||
mock_create.return_value = {'id': 'contact_123'}
|
||||
|
||||
result = add_contact_to_resend(
|
||||
audience_id='test_audience',
|
||||
email='test@example.com',
|
||||
first_name='John',
|
||||
last_name='Doe',
|
||||
)
|
||||
|
||||
assert result == {'id': 'contact_123'}
|
||||
mock_create.assert_called_once()
|
||||
|
||||
@patch('enterprise.sync.resend_keycloak.resend.Contacts.create')
|
||||
def test_add_contact_to_resend_retries_on_rate_limit(
|
||||
self, mock_create: MagicMock
|
||||
) -> None:
|
||||
"""Test that add_contact_to_resend retries on rate limit errors."""
|
||||
# First call raises rate limit error, second succeeds
|
||||
mock_create.side_effect = [
|
||||
ResendError(
|
||||
code=429,
|
||||
message='Too many requests',
|
||||
error_type='rate_limit_exceeded',
|
||||
suggested_action='',
|
||||
),
|
||||
{'id': 'contact_123'},
|
||||
]
|
||||
|
||||
result = add_contact_to_resend(
|
||||
audience_id='test_audience',
|
||||
email='test@example.com',
|
||||
)
|
||||
|
||||
assert result == {'id': 'contact_123'}
|
||||
assert mock_create.call_count == 2
|
||||
@@ -32,6 +32,11 @@ def api_key_store(mock_session_maker):
|
||||
return ApiKeyStore(mock_session_maker)
|
||||
|
||||
|
||||
def run_sync(func, *args, **kwargs):
|
||||
"""Helper to execute sync functions directly (mocks call_sync_from_async)."""
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
def test_generate_api_key(api_key_store):
|
||||
"""Test that generate_api_key returns a string with sk-oh- prefix and expected length."""
|
||||
key = api_key_store.generate_api_key(length=32)
|
||||
@@ -41,8 +46,12 @@ def test_generate_api_key(api_key_store):
|
||||
assert len(key) == len('sk-oh-') + 32
|
||||
|
||||
|
||||
@patch('storage.api_key_store.UserStore.get_user_by_id')
|
||||
def test_create_api_key(mock_get_user, api_key_store, mock_session, mock_user):
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.api_key_store.call_sync_from_async', side_effect=run_sync)
|
||||
@patch('storage.api_key_store.UserStore.get_user_by_id_async')
|
||||
async def test_create_api_key(
|
||||
mock_get_user, mock_call_sync, api_key_store, mock_session, mock_user
|
||||
):
|
||||
"""Test creating an API key."""
|
||||
# Setup
|
||||
user_id = 'test-user-123'
|
||||
@@ -51,7 +60,7 @@ def test_create_api_key(mock_get_user, api_key_store, mock_session, mock_user):
|
||||
api_key_store.generate_api_key = MagicMock(return_value='test-api-key')
|
||||
|
||||
# Execute
|
||||
result = api_key_store.create_api_key(user_id, name)
|
||||
result = await api_key_store.create_api_key(user_id, name)
|
||||
|
||||
# Verify
|
||||
assert result == 'test-api-key'
|
||||
@@ -219,8 +228,12 @@ def test_delete_api_key_by_id(api_key_store, mock_session):
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
|
||||
@patch('storage.api_key_store.UserStore.get_user_by_id')
|
||||
def test_list_api_keys(mock_get_user, api_key_store, mock_session, mock_user):
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.api_key_store.call_sync_from_async', side_effect=run_sync)
|
||||
@patch('storage.api_key_store.UserStore.get_user_by_id_async')
|
||||
async def test_list_api_keys(
|
||||
mock_get_user, mock_call_sync, api_key_store, mock_session, mock_user
|
||||
):
|
||||
"""Test listing API keys for a user."""
|
||||
# Setup
|
||||
user_id = 'test-user-123'
|
||||
@@ -247,26 +260,30 @@ def test_list_api_keys(mock_get_user, api_key_store, mock_session, mock_user):
|
||||
mock_filter_org.all.return_value = [mock_key1, mock_key2]
|
||||
|
||||
# Execute
|
||||
result = api_key_store.list_api_keys(user_id)
|
||||
result = await api_key_store.list_api_keys(user_id)
|
||||
|
||||
# Verify
|
||||
mock_get_user.assert_called_once_with(user_id)
|
||||
assert len(result) == 2
|
||||
assert result[0]['id'] == 1
|
||||
assert result[0]['name'] == 'Key 1'
|
||||
assert result[0]['created_at'] == now
|
||||
assert result[0]['last_used_at'] == now
|
||||
assert result[0]['expires_at'] == now + timedelta(days=30)
|
||||
assert result[0].id == 1
|
||||
assert result[0].name == 'Key 1'
|
||||
assert result[0].created_at == now
|
||||
assert result[0].last_used_at == now
|
||||
assert result[0].expires_at == now + timedelta(days=30)
|
||||
|
||||
assert result[1]['id'] == 2
|
||||
assert result[1]['name'] == 'Key 2'
|
||||
assert result[1]['created_at'] == now
|
||||
assert result[1]['last_used_at'] is None
|
||||
assert result[1]['expires_at'] is None
|
||||
assert result[1].id == 2
|
||||
assert result[1].name == 'Key 2'
|
||||
assert result[1].created_at == now
|
||||
assert result[1].last_used_at is None
|
||||
assert result[1].expires_at is None
|
||||
|
||||
|
||||
@patch('storage.api_key_store.UserStore.get_user_by_id')
|
||||
def test_retrieve_mcp_api_key(mock_get_user, api_key_store, mock_session, mock_user):
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.api_key_store.call_sync_from_async', side_effect=run_sync)
|
||||
@patch('storage.api_key_store.UserStore.get_user_by_id_async')
|
||||
async def test_retrieve_mcp_api_key(
|
||||
mock_get_user, mock_call_sync, api_key_store, mock_session, mock_user
|
||||
):
|
||||
"""Test retrieving MCP API key for a user."""
|
||||
# Setup
|
||||
user_id = 'test-user-123'
|
||||
@@ -287,16 +304,18 @@ def test_retrieve_mcp_api_key(mock_get_user, api_key_store, mock_session, mock_u
|
||||
mock_filter_org.all.return_value = [mock_other_key, mock_mcp_key]
|
||||
|
||||
# Execute
|
||||
result = api_key_store.retrieve_mcp_api_key(user_id)
|
||||
result = await api_key_store.retrieve_mcp_api_key(user_id)
|
||||
|
||||
# Verify
|
||||
mock_get_user.assert_called_once_with(user_id)
|
||||
assert result == 'mcp-test-key'
|
||||
|
||||
|
||||
@patch('storage.api_key_store.UserStore.get_user_by_id')
|
||||
def test_retrieve_mcp_api_key_not_found(
|
||||
mock_get_user, api_key_store, mock_session, mock_user
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.api_key_store.call_sync_from_async', side_effect=run_sync)
|
||||
@patch('storage.api_key_store.UserStore.get_user_by_id_async')
|
||||
async def test_retrieve_mcp_api_key_not_found(
|
||||
mock_get_user, mock_call_sync, api_key_store, mock_session, mock_user
|
||||
):
|
||||
"""Test retrieving MCP API key when none exists."""
|
||||
# Setup
|
||||
@@ -314,7 +333,7 @@ def test_retrieve_mcp_api_key_not_found(
|
||||
mock_filter_org.all.return_value = [mock_other_key]
|
||||
|
||||
# Execute
|
||||
result = api_key_store.retrieve_mcp_api_key(user_id)
|
||||
result = await api_key_store.retrieve_mcp_api_key(user_id)
|
||||
|
||||
# Verify
|
||||
mock_get_user.assert_called_once_with(user_id)
|
||||
|
||||
181
enterprise/tests/unit/test_auth_invitation_callback.py
Normal file
181
enterprise/tests/unit/test_auth_invitation_callback.py
Normal file
@@ -0,0 +1,181 @@
|
||||
"""Tests for auth callback invitation acceptance - EmailMismatchError handling."""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestAuthCallbackInvitationEmailMismatch:
|
||||
"""Test cases for EmailMismatchError handling during auth callback."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_redirect_url(self):
|
||||
"""Base redirect URL."""
|
||||
return 'https://app.example.com/'
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user_id(self):
|
||||
"""Mock user ID."""
|
||||
return '87654321-4321-8765-4321-876543218765'
|
||||
|
||||
def test_email_mismatch_appends_to_url_without_query_params(
|
||||
self, mock_redirect_url, mock_user_id
|
||||
):
|
||||
"""Test that email_mismatch=true is appended correctly when URL has no query params."""
|
||||
from server.routes.org_invitation_models import EmailMismatchError
|
||||
|
||||
# Simulate the logic from auth.py
|
||||
redirect_url = mock_redirect_url
|
||||
try:
|
||||
raise EmailMismatchError('Your email does not match the invitation')
|
||||
except EmailMismatchError:
|
||||
if '?' in redirect_url:
|
||||
redirect_url = f'{redirect_url}&email_mismatch=true'
|
||||
else:
|
||||
redirect_url = f'{redirect_url}?email_mismatch=true'
|
||||
|
||||
assert redirect_url == 'https://app.example.com/?email_mismatch=true'
|
||||
|
||||
def test_email_mismatch_appends_to_url_with_query_params(self, mock_user_id):
|
||||
"""Test that email_mismatch=true is appended correctly when URL has existing query params."""
|
||||
from server.routes.org_invitation_models import EmailMismatchError
|
||||
|
||||
redirect_url = 'https://app.example.com/?other_param=value'
|
||||
try:
|
||||
raise EmailMismatchError()
|
||||
except EmailMismatchError:
|
||||
if '?' in redirect_url:
|
||||
redirect_url = f'{redirect_url}&email_mismatch=true'
|
||||
else:
|
||||
redirect_url = f'{redirect_url}?email_mismatch=true'
|
||||
|
||||
assert (
|
||||
redirect_url
|
||||
== 'https://app.example.com/?other_param=value&email_mismatch=true'
|
||||
)
|
||||
|
||||
def test_email_mismatch_error_has_default_message(self):
|
||||
"""Test that EmailMismatchError has the default message."""
|
||||
from server.routes.org_invitation_models import EmailMismatchError
|
||||
|
||||
error = EmailMismatchError()
|
||||
assert str(error) == 'Your email does not match the invitation'
|
||||
|
||||
def test_email_mismatch_error_accepts_custom_message(self):
|
||||
"""Test that EmailMismatchError accepts a custom message."""
|
||||
from server.routes.org_invitation_models import EmailMismatchError
|
||||
|
||||
custom_message = 'Custom error message'
|
||||
error = EmailMismatchError(custom_message)
|
||||
assert str(error) == custom_message
|
||||
|
||||
def test_email_mismatch_error_is_invitation_error(self):
|
||||
"""Test that EmailMismatchError inherits from InvitationError."""
|
||||
from server.routes.org_invitation_models import (
|
||||
EmailMismatchError,
|
||||
InvitationError,
|
||||
)
|
||||
|
||||
error = EmailMismatchError()
|
||||
assert isinstance(error, InvitationError)
|
||||
|
||||
|
||||
class TestInvitationTokenInOAuthState:
|
||||
"""Test cases for invitation token handling in OAuth state."""
|
||||
|
||||
def test_invitation_token_included_in_oauth_state(self):
|
||||
"""Test that invitation token is included in OAuth state data."""
|
||||
import base64
|
||||
import json
|
||||
|
||||
# Simulate building OAuth state with invitation token
|
||||
state_data = {
|
||||
'redirect_url': 'https://app.example.com/',
|
||||
'invitation_token': 'inv-test-token-12345',
|
||||
}
|
||||
|
||||
encoded_state = base64.b64encode(json.dumps(state_data).encode()).decode()
|
||||
decoded_data = json.loads(base64.b64decode(encoded_state))
|
||||
|
||||
assert decoded_data['invitation_token'] == 'inv-test-token-12345'
|
||||
assert decoded_data['redirect_url'] == 'https://app.example.com/'
|
||||
|
||||
def test_invitation_token_extracted_from_oauth_state(self):
|
||||
"""Test that invitation token can be extracted from OAuth state."""
|
||||
import base64
|
||||
import json
|
||||
|
||||
state_data = {
|
||||
'redirect_url': 'https://app.example.com/',
|
||||
'invitation_token': 'inv-test-token-12345',
|
||||
}
|
||||
|
||||
encoded_state = base64.b64encode(json.dumps(state_data).encode()).decode()
|
||||
|
||||
# Simulate decoding in callback
|
||||
decoded_state = json.loads(base64.b64decode(encoded_state))
|
||||
invitation_token = decoded_state.get('invitation_token')
|
||||
|
||||
assert invitation_token == 'inv-test-token-12345'
|
||||
|
||||
def test_oauth_state_without_invitation_token(self):
|
||||
"""Test that OAuth state works without invitation token."""
|
||||
import base64
|
||||
import json
|
||||
|
||||
state_data = {
|
||||
'redirect_url': 'https://app.example.com/',
|
||||
}
|
||||
|
||||
encoded_state = base64.b64encode(json.dumps(state_data).encode()).decode()
|
||||
decoded_data = json.loads(base64.b64decode(encoded_state))
|
||||
|
||||
assert 'invitation_token' not in decoded_data
|
||||
assert decoded_data['redirect_url'] == 'https://app.example.com/'
|
||||
|
||||
|
||||
class TestAuthCallbackInvitationErrors:
|
||||
"""Test cases for various invitation error scenarios in auth callback."""
|
||||
|
||||
def test_invitation_expired_appends_flag(self):
|
||||
"""Test that invitation_expired=true is appended for expired invitations."""
|
||||
from server.routes.org_invitation_models import InvitationExpiredError
|
||||
|
||||
redirect_url = 'https://app.example.com/'
|
||||
try:
|
||||
raise InvitationExpiredError()
|
||||
except InvitationExpiredError:
|
||||
if '?' in redirect_url:
|
||||
redirect_url = f'{redirect_url}&invitation_expired=true'
|
||||
else:
|
||||
redirect_url = f'{redirect_url}?invitation_expired=true'
|
||||
|
||||
assert redirect_url == 'https://app.example.com/?invitation_expired=true'
|
||||
|
||||
def test_invitation_invalid_appends_flag(self):
|
||||
"""Test that invitation_invalid=true is appended for invalid invitations."""
|
||||
from server.routes.org_invitation_models import InvitationInvalidError
|
||||
|
||||
redirect_url = 'https://app.example.com/'
|
||||
try:
|
||||
raise InvitationInvalidError()
|
||||
except InvitationInvalidError:
|
||||
if '?' in redirect_url:
|
||||
redirect_url = f'{redirect_url}&invitation_invalid=true'
|
||||
else:
|
||||
redirect_url = f'{redirect_url}?invitation_invalid=true'
|
||||
|
||||
assert redirect_url == 'https://app.example.com/?invitation_invalid=true'
|
||||
|
||||
def test_already_member_appends_flag(self):
|
||||
"""Test that already_member=true is appended when user is already a member."""
|
||||
from server.routes.org_invitation_models import UserAlreadyMemberError
|
||||
|
||||
redirect_url = 'https://app.example.com/'
|
||||
try:
|
||||
raise UserAlreadyMemberError()
|
||||
except UserAlreadyMemberError:
|
||||
if '?' in redirect_url:
|
||||
redirect_url = f'{redirect_url}&already_member=true'
|
||||
else:
|
||||
redirect_url = f'{redirect_url}?already_member=true'
|
||||
|
||||
assert redirect_url == 'https://app.example.com/?already_member=true'
|
||||
@@ -284,3 +284,85 @@ async def test_middleware_ignores_email_resend_path_no_tos_check(
|
||||
assert result == mock_response
|
||||
mock_call_next.assert_called_once_with(mock_request)
|
||||
# Should not raise TosNotAcceptedError for this path
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_middleware_skips_webhook_endpoints(
|
||||
middleware, mock_request, mock_response
|
||||
):
|
||||
"""Test middleware skips webhook endpoints (/api/v1/webhooks/*) and doesn't require auth."""
|
||||
# Test various webhook paths
|
||||
webhook_paths = [
|
||||
'/api/v1/webhooks/events',
|
||||
'/api/v1/webhooks/events/123',
|
||||
'/api/v1/webhooks/stats',
|
||||
'/api/v1/webhooks/parent-conversation',
|
||||
]
|
||||
|
||||
for path in webhook_paths:
|
||||
mock_request.cookies = {}
|
||||
mock_request.url = MagicMock()
|
||||
mock_request.url.hostname = 'localhost'
|
||||
mock_request.url.path = path
|
||||
mock_call_next = AsyncMock(return_value=mock_response)
|
||||
|
||||
# Act
|
||||
result = await middleware(mock_request, mock_call_next)
|
||||
|
||||
# Assert - middleware should skip auth check and call next
|
||||
assert result == mock_response
|
||||
mock_call_next.assert_called_once_with(mock_request)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_middleware_skips_webhook_secrets_endpoint(
|
||||
middleware, mock_request, mock_response
|
||||
):
|
||||
"""Test middleware skips the old /api/v1/webhooks/secrets endpoint."""
|
||||
# This was explicitly in ignore_paths but is now handled by the prefix check
|
||||
mock_request.cookies = {}
|
||||
mock_request.url = MagicMock()
|
||||
mock_request.url.hostname = 'localhost'
|
||||
mock_request.url.path = '/api/v1/webhooks/secrets'
|
||||
mock_call_next = AsyncMock(return_value=mock_response)
|
||||
|
||||
# Act
|
||||
result = await middleware(mock_request, mock_call_next)
|
||||
|
||||
# Assert - middleware should skip auth check and call next
|
||||
assert result == mock_response
|
||||
mock_call_next.assert_called_once_with(mock_request)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_middleware_does_not_skip_similar_non_webhook_paths(
|
||||
middleware, mock_response
|
||||
):
|
||||
"""Test middleware does NOT skip paths that start with /api/v1/webhook (without 's')."""
|
||||
# These paths should still be processed by the middleware (not skipped)
|
||||
# They start with /api so _should_attach returns True, and since there's no auth,
|
||||
# middleware should return 401 response (it catches NoCredentialsError internally)
|
||||
non_webhook_paths = [
|
||||
'/api/v1/webhook/events',
|
||||
'/api/v1/webhook/something',
|
||||
]
|
||||
|
||||
for path in non_webhook_paths:
|
||||
# Create a fresh mock request for each test
|
||||
mock_request = MagicMock(spec=Request)
|
||||
mock_request.cookies = {}
|
||||
mock_request.url = MagicMock()
|
||||
mock_request.url.hostname = 'localhost'
|
||||
mock_request.url.path = path
|
||||
mock_request.headers = MagicMock()
|
||||
mock_request.headers.get = MagicMock(side_effect=lambda k: None)
|
||||
|
||||
# Since these paths start with /api, _should_attach returns True
|
||||
# Since there's no auth, middleware catches NoCredentialsError and returns 401
|
||||
mock_call_next = AsyncMock()
|
||||
result = await middleware(mock_request, mock_call_next)
|
||||
|
||||
# Should return a 401 response, not raise an exception
|
||||
assert result.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
# Should NOT call next for non-webhook paths when auth is missing
|
||||
mock_call_next.assert_not_called()
|
||||
|
||||
@@ -153,6 +153,7 @@ async def test_keycloak_callback_user_not_allowed(mock_request):
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.migrate_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.backfill_contact_name = AsyncMock()
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = False
|
||||
@@ -188,6 +189,7 @@ async def test_keycloak_callback_success_with_valid_offline_token(mock_request):
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.migrate_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.backfill_contact_name = AsyncMock()
|
||||
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
@@ -259,6 +261,7 @@ async def test_keycloak_callback_email_not_verified(mock_request):
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.backfill_contact_name = AsyncMock()
|
||||
|
||||
# Act
|
||||
result = await keycloak_callback(
|
||||
@@ -306,6 +309,7 @@ async def test_keycloak_callback_email_not_verified_missing_field(mock_request):
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.backfill_contact_name = AsyncMock()
|
||||
|
||||
# Act
|
||||
result = await keycloak_callback(
|
||||
@@ -347,6 +351,7 @@ async def test_keycloak_callback_success_without_offline_token(mock_request):
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.migrate_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.backfill_contact_name = AsyncMock()
|
||||
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
@@ -581,6 +586,7 @@ async def test_keycloak_callback_blocked_email_domain(mock_request):
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.backfill_contact_name = AsyncMock()
|
||||
|
||||
mock_domain_blocker.is_active.return_value = True
|
||||
mock_domain_blocker.is_domain_blocked.return_value = True
|
||||
@@ -644,6 +650,7 @@ async def test_keycloak_callback_allowed_email_domain(mock_request):
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.backfill_contact_name = AsyncMock()
|
||||
|
||||
mock_domain_blocker.is_active.return_value = True
|
||||
mock_domain_blocker.is_domain_blocked.return_value = False
|
||||
@@ -707,6 +714,7 @@ async def test_keycloak_callback_domain_blocking_inactive(mock_request):
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.backfill_contact_name = AsyncMock()
|
||||
|
||||
mock_domain_blocker.is_active.return_value = False
|
||||
mock_domain_blocker.is_domain_blocked.return_value = False
|
||||
@@ -768,6 +776,7 @@ async def test_keycloak_callback_missing_email(mock_request):
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.backfill_contact_name = AsyncMock()
|
||||
|
||||
mock_domain_blocker.is_active.return_value = True
|
||||
|
||||
@@ -813,6 +822,7 @@ async def test_keycloak_callback_duplicate_email_detected(mock_request):
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.backfill_contact_name = AsyncMock()
|
||||
|
||||
# Act
|
||||
result = await keycloak_callback(
|
||||
@@ -857,6 +867,7 @@ async def test_keycloak_callback_duplicate_email_deletion_fails(mock_request):
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.backfill_contact_name = AsyncMock()
|
||||
|
||||
# Act
|
||||
result = await keycloak_callback(
|
||||
@@ -914,6 +925,7 @@ async def test_keycloak_callback_duplicate_check_exception(mock_request):
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.backfill_contact_name = AsyncMock()
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
@@ -971,6 +983,7 @@ async def test_keycloak_callback_no_duplicate_email(mock_request):
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.backfill_contact_name = AsyncMock()
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
@@ -1031,6 +1044,7 @@ async def test_keycloak_callback_no_email_in_user_info(mock_request):
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.backfill_contact_name = AsyncMock()
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
@@ -1187,6 +1201,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.backfill_contact_name = AsyncMock()
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
@@ -1251,6 +1266,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.backfill_contact_name = AsyncMock()
|
||||
|
||||
mock_domain_blocker.is_domain_blocked.return_value = False
|
||||
|
||||
@@ -1333,6 +1349,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.backfill_contact_name = AsyncMock()
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
@@ -1420,6 +1437,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.backfill_contact_name = AsyncMock()
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
@@ -1504,6 +1522,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.backfill_contact_name = AsyncMock()
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
@@ -1587,6 +1606,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.backfill_contact_name = AsyncMock()
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
@@ -1667,6 +1687,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.backfill_contact_name = AsyncMock()
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
@@ -1733,6 +1754,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.backfill_contact_name = AsyncMock()
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
@@ -1805,6 +1827,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.backfill_contact_name = AsyncMock()
|
||||
|
||||
mock_verifier.is_active.return_value = True
|
||||
mock_verifier.is_user_allowed.return_value = True
|
||||
@@ -1875,6 +1898,7 @@ class TestKeycloakCallbackRecaptcha:
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user_store.get_user_by_id_async = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.create_user = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.backfill_contact_name = AsyncMock()
|
||||
|
||||
mock_domain_blocker.is_domain_blocked.return_value = False
|
||||
|
||||
|
||||
756
enterprise/tests/unit/test_authorization.py
Normal file
756
enterprise/tests/unit/test_authorization.py
Normal file
@@ -0,0 +1,756 @@
|
||||
"""
|
||||
Unit tests for permission-based authorization (authorization.py).
|
||||
|
||||
Tests the FastAPI dependencies that validate user permissions within organizations.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
from server.auth.authorization import (
|
||||
ROLE_PERMISSIONS,
|
||||
Permission,
|
||||
RoleName,
|
||||
get_role_permissions,
|
||||
get_user_org_role,
|
||||
has_permission,
|
||||
require_permission,
|
||||
)
|
||||
|
||||
# =============================================================================
|
||||
# Tests for Permission enum
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestPermission:
|
||||
"""Tests for Permission enum."""
|
||||
|
||||
def test_permission_values(self):
|
||||
"""
|
||||
GIVEN: Permission enum
|
||||
WHEN: Accessing permission values
|
||||
THEN: All expected permissions exist with correct string values
|
||||
"""
|
||||
assert Permission.MANAGE_SECRETS.value == 'manage_secrets'
|
||||
assert Permission.MANAGE_MCP.value == 'manage_mcp'
|
||||
assert Permission.MANAGE_INTEGRATIONS.value == 'manage_integrations'
|
||||
assert (
|
||||
Permission.MANAGE_APPLICATION_SETTINGS.value
|
||||
== 'manage_application_settings'
|
||||
)
|
||||
assert Permission.MANAGE_API_KEYS.value == 'manage_api_keys'
|
||||
assert Permission.VIEW_LLM_SETTINGS.value == 'view_llm_settings'
|
||||
assert Permission.EDIT_LLM_SETTINGS.value == 'edit_llm_settings'
|
||||
assert Permission.VIEW_BILLING.value == 'view_billing'
|
||||
assert Permission.ADD_CREDITS.value == 'add_credits'
|
||||
assert (
|
||||
Permission.INVITE_USER_TO_ORGANIZATION.value
|
||||
== 'invite_user_to_organization'
|
||||
)
|
||||
assert Permission.CHANGE_USER_ROLE_MEMBER.value == 'change_user_role:member'
|
||||
assert Permission.CHANGE_USER_ROLE_ADMIN.value == 'change_user_role:admin'
|
||||
assert Permission.CHANGE_USER_ROLE_OWNER.value == 'change_user_role:owner'
|
||||
assert Permission.VIEW_ORG_SETTINGS.value == 'view_org_settings'
|
||||
assert Permission.CHANGE_ORGANIZATION_NAME.value == 'change_organization_name'
|
||||
assert Permission.DELETE_ORGANIZATION.value == 'delete_organization'
|
||||
|
||||
def test_permission_from_string(self):
|
||||
"""
|
||||
GIVEN: Valid permission string
|
||||
WHEN: Creating Permission from string
|
||||
THEN: Correct enum value is returned
|
||||
"""
|
||||
assert Permission('manage_secrets') == Permission.MANAGE_SECRETS
|
||||
assert Permission('view_llm_settings') == Permission.VIEW_LLM_SETTINGS
|
||||
assert Permission('delete_organization') == Permission.DELETE_ORGANIZATION
|
||||
|
||||
def test_permission_invalid_string(self):
|
||||
"""
|
||||
GIVEN: Invalid permission string
|
||||
WHEN: Creating Permission from string
|
||||
THEN: ValueError is raised
|
||||
"""
|
||||
with pytest.raises(ValueError):
|
||||
Permission('invalid_permission')
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests for RoleName enum
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestRoleName:
|
||||
"""Tests for RoleName enum."""
|
||||
|
||||
def test_role_name_values(self):
|
||||
"""
|
||||
GIVEN: RoleName enum
|
||||
WHEN: Accessing role name values
|
||||
THEN: All expected roles exist with correct string values
|
||||
"""
|
||||
assert RoleName.OWNER.value == 'owner'
|
||||
assert RoleName.ADMIN.value == 'admin'
|
||||
assert RoleName.MEMBER.value == 'member'
|
||||
|
||||
def test_role_name_from_string(self):
|
||||
"""
|
||||
GIVEN: Valid role name string
|
||||
WHEN: Creating RoleName from string
|
||||
THEN: Correct enum value is returned
|
||||
"""
|
||||
assert RoleName('owner') == RoleName.OWNER
|
||||
assert RoleName('admin') == RoleName.ADMIN
|
||||
assert RoleName('member') == RoleName.MEMBER
|
||||
|
||||
def test_role_name_invalid_string(self):
|
||||
"""
|
||||
GIVEN: Invalid role name string
|
||||
WHEN: Creating RoleName from string
|
||||
THEN: ValueError is raised
|
||||
"""
|
||||
with pytest.raises(ValueError):
|
||||
RoleName('invalid_role')
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests for ROLE_PERMISSIONS mapping
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestRolePermissions:
|
||||
"""Tests for role permission mappings."""
|
||||
|
||||
def test_owner_has_all_permissions(self):
|
||||
"""
|
||||
GIVEN: ROLE_PERMISSIONS mapping
|
||||
WHEN: Checking owner permissions
|
||||
THEN: Owner has all permissions including owner-only permissions
|
||||
"""
|
||||
owner_perms = ROLE_PERMISSIONS[RoleName.OWNER]
|
||||
assert Permission.MANAGE_SECRETS in owner_perms
|
||||
assert Permission.MANAGE_MCP in owner_perms
|
||||
assert Permission.VIEW_LLM_SETTINGS in owner_perms
|
||||
assert Permission.EDIT_LLM_SETTINGS in owner_perms
|
||||
assert Permission.VIEW_BILLING in owner_perms
|
||||
assert Permission.ADD_CREDITS in owner_perms
|
||||
assert Permission.INVITE_USER_TO_ORGANIZATION in owner_perms
|
||||
assert Permission.CHANGE_USER_ROLE_MEMBER in owner_perms
|
||||
assert Permission.CHANGE_USER_ROLE_ADMIN in owner_perms
|
||||
assert Permission.CHANGE_USER_ROLE_OWNER in owner_perms
|
||||
assert Permission.CHANGE_ORGANIZATION_NAME in owner_perms
|
||||
assert Permission.DELETE_ORGANIZATION in owner_perms
|
||||
|
||||
def test_admin_has_admin_permissions(self):
|
||||
"""
|
||||
GIVEN: ROLE_PERMISSIONS mapping
|
||||
WHEN: Checking admin permissions
|
||||
THEN: Admin has admin permissions but not owner-only permissions
|
||||
"""
|
||||
admin_perms = ROLE_PERMISSIONS[RoleName.ADMIN]
|
||||
assert Permission.MANAGE_SECRETS in admin_perms
|
||||
assert Permission.MANAGE_MCP in admin_perms
|
||||
assert Permission.VIEW_LLM_SETTINGS in admin_perms
|
||||
assert Permission.EDIT_LLM_SETTINGS in admin_perms
|
||||
assert Permission.VIEW_BILLING in admin_perms
|
||||
assert Permission.ADD_CREDITS in admin_perms
|
||||
assert Permission.INVITE_USER_TO_ORGANIZATION in admin_perms
|
||||
assert Permission.CHANGE_USER_ROLE_MEMBER in admin_perms
|
||||
assert Permission.CHANGE_USER_ROLE_ADMIN in admin_perms
|
||||
# Admin should NOT have owner-only permissions
|
||||
assert Permission.CHANGE_USER_ROLE_OWNER not in admin_perms
|
||||
assert Permission.CHANGE_ORGANIZATION_NAME not in admin_perms
|
||||
assert Permission.DELETE_ORGANIZATION not in admin_perms
|
||||
|
||||
def test_member_has_limited_permissions(self):
|
||||
"""
|
||||
GIVEN: ROLE_PERMISSIONS mapping
|
||||
WHEN: Checking member permissions
|
||||
THEN: Member has limited permissions
|
||||
"""
|
||||
member_perms = ROLE_PERMISSIONS[RoleName.MEMBER]
|
||||
# Member has basic settings permissions
|
||||
assert Permission.MANAGE_SECRETS in member_perms
|
||||
assert Permission.MANAGE_MCP in member_perms
|
||||
assert Permission.MANAGE_INTEGRATIONS in member_perms
|
||||
assert Permission.MANAGE_APPLICATION_SETTINGS in member_perms
|
||||
assert Permission.MANAGE_API_KEYS in member_perms
|
||||
assert Permission.VIEW_LLM_SETTINGS in member_perms
|
||||
assert Permission.VIEW_ORG_SETTINGS in member_perms
|
||||
# Member should NOT have admin/owner permissions
|
||||
assert Permission.EDIT_LLM_SETTINGS not in member_perms
|
||||
assert Permission.VIEW_BILLING not in member_perms
|
||||
assert Permission.ADD_CREDITS not in member_perms
|
||||
assert Permission.INVITE_USER_TO_ORGANIZATION not in member_perms
|
||||
assert Permission.CHANGE_USER_ROLE_MEMBER not in member_perms
|
||||
assert Permission.CHANGE_USER_ROLE_ADMIN not in member_perms
|
||||
assert Permission.CHANGE_USER_ROLE_OWNER not in member_perms
|
||||
assert Permission.CHANGE_ORGANIZATION_NAME not in member_perms
|
||||
assert Permission.DELETE_ORGANIZATION not in member_perms
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests for get_role_permissions function
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestGetRolePermissions:
|
||||
"""Tests for get_role_permissions function."""
|
||||
|
||||
def test_get_owner_permissions(self):
|
||||
"""
|
||||
GIVEN: Role name 'owner'
|
||||
WHEN: get_role_permissions is called
|
||||
THEN: Owner permissions are returned
|
||||
"""
|
||||
perms = get_role_permissions('owner')
|
||||
assert Permission.DELETE_ORGANIZATION in perms
|
||||
assert Permission.CHANGE_ORGANIZATION_NAME in perms
|
||||
|
||||
def test_get_admin_permissions(self):
|
||||
"""
|
||||
GIVEN: Role name 'admin'
|
||||
WHEN: get_role_permissions is called
|
||||
THEN: Admin permissions are returned
|
||||
"""
|
||||
perms = get_role_permissions('admin')
|
||||
assert Permission.EDIT_LLM_SETTINGS in perms
|
||||
assert Permission.DELETE_ORGANIZATION not in perms
|
||||
|
||||
def test_get_member_permissions(self):
|
||||
"""
|
||||
GIVEN: Role name 'member'
|
||||
WHEN: get_role_permissions is called
|
||||
THEN: Member permissions are returned
|
||||
"""
|
||||
perms = get_role_permissions('member')
|
||||
assert Permission.VIEW_LLM_SETTINGS in perms
|
||||
assert Permission.EDIT_LLM_SETTINGS not in perms
|
||||
|
||||
def test_get_invalid_role_permissions(self):
|
||||
"""
|
||||
GIVEN: Invalid role name
|
||||
WHEN: get_role_permissions is called
|
||||
THEN: Empty frozenset is returned
|
||||
"""
|
||||
perms = get_role_permissions('invalid_role')
|
||||
assert perms == frozenset()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests for has_permission function
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestHasPermission:
|
||||
"""Tests for has_permission function."""
|
||||
|
||||
def test_owner_has_delete_organization_permission(self):
|
||||
"""
|
||||
GIVEN: User with owner role
|
||||
WHEN: Checking for DELETE_ORGANIZATION permission
|
||||
THEN: Returns True
|
||||
"""
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'owner'
|
||||
assert has_permission(mock_role, Permission.DELETE_ORGANIZATION) is True
|
||||
|
||||
def test_owner_has_view_llm_settings_permission(self):
|
||||
"""
|
||||
GIVEN: User with owner role
|
||||
WHEN: Checking for VIEW_LLM_SETTINGS permission
|
||||
THEN: Returns True
|
||||
"""
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'owner'
|
||||
assert has_permission(mock_role, Permission.VIEW_LLM_SETTINGS) is True
|
||||
|
||||
def test_admin_has_edit_llm_settings_permission(self):
|
||||
"""
|
||||
GIVEN: User with admin role
|
||||
WHEN: Checking for EDIT_LLM_SETTINGS permission
|
||||
THEN: Returns True
|
||||
"""
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'admin'
|
||||
assert has_permission(mock_role, Permission.EDIT_LLM_SETTINGS) is True
|
||||
|
||||
def test_admin_lacks_delete_organization_permission(self):
|
||||
"""
|
||||
GIVEN: User with admin role
|
||||
WHEN: Checking for DELETE_ORGANIZATION permission
|
||||
THEN: Returns False
|
||||
"""
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'admin'
|
||||
assert has_permission(mock_role, Permission.DELETE_ORGANIZATION) is False
|
||||
|
||||
def test_member_has_view_llm_settings_permission(self):
|
||||
"""
|
||||
GIVEN: User with member role
|
||||
WHEN: Checking for VIEW_LLM_SETTINGS permission
|
||||
THEN: Returns True
|
||||
"""
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'member'
|
||||
assert has_permission(mock_role, Permission.VIEW_LLM_SETTINGS) is True
|
||||
|
||||
def test_member_lacks_edit_llm_settings_permission(self):
|
||||
"""
|
||||
GIVEN: User with member role
|
||||
WHEN: Checking for EDIT_LLM_SETTINGS permission
|
||||
THEN: Returns False
|
||||
"""
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'member'
|
||||
assert has_permission(mock_role, Permission.EDIT_LLM_SETTINGS) is False
|
||||
|
||||
def test_member_lacks_delete_organization_permission(self):
|
||||
"""
|
||||
GIVEN: User with member role
|
||||
WHEN: Checking for DELETE_ORGANIZATION permission
|
||||
THEN: Returns False
|
||||
"""
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'member'
|
||||
assert has_permission(mock_role, Permission.DELETE_ORGANIZATION) is False
|
||||
|
||||
def test_invalid_role_has_no_permissions(self):
|
||||
"""
|
||||
GIVEN: User with invalid role
|
||||
WHEN: Checking for any permission
|
||||
THEN: Returns False
|
||||
"""
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'invalid_role'
|
||||
assert has_permission(mock_role, Permission.VIEW_LLM_SETTINGS) is False
|
||||
assert has_permission(mock_role, Permission.DELETE_ORGANIZATION) is False
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests for get_user_org_role function
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestGetUserOrgRole:
|
||||
"""Tests for get_user_org_role function."""
|
||||
|
||||
def test_returns_role_when_member_exists(self):
|
||||
"""
|
||||
GIVEN: User is a member of organization with role
|
||||
WHEN: get_user_org_role is called
|
||||
THEN: Role object is returned
|
||||
"""
|
||||
user_id = str(uuid4())
|
||||
org_id = uuid4()
|
||||
|
||||
mock_org_member = MagicMock()
|
||||
mock_org_member.role_id = 1
|
||||
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'admin'
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.auth.authorization.OrgMemberStore.get_org_member',
|
||||
return_value=mock_org_member,
|
||||
),
|
||||
patch(
|
||||
'server.auth.authorization.RoleStore.get_role_by_id',
|
||||
return_value=mock_role,
|
||||
),
|
||||
):
|
||||
result = get_user_org_role(user_id, org_id)
|
||||
assert result == mock_role
|
||||
|
||||
def test_returns_none_when_not_member(self):
|
||||
"""
|
||||
GIVEN: User is not a member of organization
|
||||
WHEN: get_user_org_role is called
|
||||
THEN: None is returned
|
||||
"""
|
||||
user_id = str(uuid4())
|
||||
org_id = uuid4()
|
||||
|
||||
with patch(
|
||||
'server.auth.authorization.OrgMemberStore.get_org_member',
|
||||
return_value=None,
|
||||
):
|
||||
result = get_user_org_role(user_id, org_id)
|
||||
assert result is None
|
||||
|
||||
def test_returns_role_when_org_id_is_none(self):
|
||||
"""
|
||||
GIVEN: User with a current organization
|
||||
WHEN: get_user_org_role is called with org_id=None
|
||||
THEN: Role object is returned using get_org_member_for_current_org
|
||||
"""
|
||||
user_id = str(uuid4())
|
||||
|
||||
mock_org_member = MagicMock()
|
||||
mock_org_member.role_id = 1
|
||||
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'admin'
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.auth.authorization.OrgMemberStore.get_org_member_for_current_org',
|
||||
return_value=mock_org_member,
|
||||
) as mock_get_current,
|
||||
patch(
|
||||
'server.auth.authorization.OrgMemberStore.get_org_member',
|
||||
) as mock_get_org_member,
|
||||
patch(
|
||||
'server.auth.authorization.RoleStore.get_role_by_id',
|
||||
return_value=mock_role,
|
||||
),
|
||||
):
|
||||
result = get_user_org_role(user_id, None)
|
||||
assert result == mock_role
|
||||
mock_get_current.assert_called_once()
|
||||
mock_get_org_member.assert_not_called()
|
||||
|
||||
def test_returns_none_when_org_id_is_none_and_no_current_org(self):
|
||||
"""
|
||||
GIVEN: User with no current organization membership
|
||||
WHEN: get_user_org_role is called with org_id=None
|
||||
THEN: None is returned
|
||||
"""
|
||||
user_id = str(uuid4())
|
||||
|
||||
with patch(
|
||||
'server.auth.authorization.OrgMemberStore.get_org_member_for_current_org',
|
||||
return_value=None,
|
||||
):
|
||||
result = get_user_org_role(user_id, None)
|
||||
assert result is None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests for require_permission dependency
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestRequirePermission:
|
||||
"""Tests for require_permission dependency factory."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_user_id_when_authorized(self):
|
||||
"""
|
||||
GIVEN: User with required permission
|
||||
WHEN: Permission checker is called
|
||||
THEN: User ID is returned
|
||||
"""
|
||||
user_id = str(uuid4())
|
||||
org_id = uuid4()
|
||||
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'admin'
|
||||
|
||||
with patch(
|
||||
'server.auth.authorization.get_user_org_role_async',
|
||||
AsyncMock(return_value=mock_role),
|
||||
):
|
||||
permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS)
|
||||
result = await permission_checker(org_id=org_id, user_id=user_id)
|
||||
assert result == user_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_raises_401_when_not_authenticated(self):
|
||||
"""
|
||||
GIVEN: No user ID (not authenticated)
|
||||
WHEN: Permission checker is called
|
||||
THEN: 401 Unauthorized is raised
|
||||
"""
|
||||
org_id = uuid4()
|
||||
|
||||
permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await permission_checker(org_id=org_id, user_id=None)
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert 'not authenticated' in exc_info.value.detail.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_raises_403_when_not_member(self):
|
||||
"""
|
||||
GIVEN: User is not a member of organization
|
||||
WHEN: Permission checker is called
|
||||
THEN: 403 Forbidden is raised
|
||||
"""
|
||||
user_id = str(uuid4())
|
||||
org_id = uuid4()
|
||||
|
||||
with patch(
|
||||
'server.auth.authorization.get_user_org_role_async',
|
||||
AsyncMock(return_value=None),
|
||||
):
|
||||
permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await permission_checker(org_id=org_id, user_id=user_id)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
assert 'not a member' in exc_info.value.detail.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_raises_403_when_insufficient_permission(self):
|
||||
"""
|
||||
GIVEN: User without required permission
|
||||
WHEN: Permission checker is called
|
||||
THEN: 403 Forbidden is raised
|
||||
"""
|
||||
user_id = str(uuid4())
|
||||
org_id = uuid4()
|
||||
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'member'
|
||||
|
||||
with patch(
|
||||
'server.auth.authorization.get_user_org_role_async',
|
||||
AsyncMock(return_value=mock_role),
|
||||
):
|
||||
permission_checker = require_permission(Permission.DELETE_ORGANIZATION)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await permission_checker(org_id=org_id, user_id=user_id)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
assert 'delete_organization' in exc_info.value.detail.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_owner_can_delete_organization(self):
|
||||
"""
|
||||
GIVEN: User with owner role
|
||||
WHEN: DELETE_ORGANIZATION permission is required
|
||||
THEN: User ID is returned
|
||||
"""
|
||||
user_id = str(uuid4())
|
||||
org_id = uuid4()
|
||||
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'owner'
|
||||
|
||||
with patch(
|
||||
'server.auth.authorization.get_user_org_role_async',
|
||||
AsyncMock(return_value=mock_role),
|
||||
):
|
||||
permission_checker = require_permission(Permission.DELETE_ORGANIZATION)
|
||||
result = await permission_checker(org_id=org_id, user_id=user_id)
|
||||
assert result == user_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_cannot_delete_organization(self):
|
||||
"""
|
||||
GIVEN: User with admin role
|
||||
WHEN: DELETE_ORGANIZATION permission is required
|
||||
THEN: 403 Forbidden is raised
|
||||
"""
|
||||
user_id = str(uuid4())
|
||||
org_id = uuid4()
|
||||
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'admin'
|
||||
|
||||
with patch(
|
||||
'server.auth.authorization.get_user_org_role_async',
|
||||
AsyncMock(return_value=mock_role),
|
||||
):
|
||||
permission_checker = require_permission(Permission.DELETE_ORGANIZATION)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await permission_checker(org_id=org_id, user_id=user_id)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_logs_warning_on_insufficient_permission(self):
|
||||
"""
|
||||
GIVEN: User without required permission
|
||||
WHEN: Permission checker is called
|
||||
THEN: Warning is logged with details
|
||||
"""
|
||||
user_id = str(uuid4())
|
||||
org_id = uuid4()
|
||||
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'member'
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.auth.authorization.get_user_org_role_async',
|
||||
AsyncMock(return_value=mock_role),
|
||||
),
|
||||
patch('server.auth.authorization.logger') as mock_logger,
|
||||
):
|
||||
permission_checker = require_permission(Permission.DELETE_ORGANIZATION)
|
||||
with pytest.raises(HTTPException):
|
||||
await permission_checker(org_id=org_id, user_id=user_id)
|
||||
|
||||
mock_logger.warning.assert_called()
|
||||
call_args = mock_logger.warning.call_args
|
||||
assert call_args[1]['extra']['user_id'] == user_id
|
||||
assert call_args[1]['extra']['user_role'] == 'member'
|
||||
assert call_args[1]['extra']['required_permission'] == 'delete_organization'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_user_id_when_org_id_is_none(self):
|
||||
"""
|
||||
GIVEN: User with required permission in their current org
|
||||
WHEN: Permission checker is called with org_id=None
|
||||
THEN: User ID is returned
|
||||
"""
|
||||
user_id = str(uuid4())
|
||||
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'admin'
|
||||
|
||||
with patch(
|
||||
'server.auth.authorization.get_user_org_role_async',
|
||||
AsyncMock(return_value=mock_role),
|
||||
) as mock_get_role:
|
||||
permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS)
|
||||
result = await permission_checker(org_id=None, user_id=user_id)
|
||||
assert result == user_id
|
||||
mock_get_role.assert_called_once_with(user_id, None)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_raises_403_when_org_id_is_none_and_not_member(self):
|
||||
"""
|
||||
GIVEN: User not a member of their current organization
|
||||
WHEN: Permission checker is called with org_id=None
|
||||
THEN: HTTPException with 403 status is raised
|
||||
"""
|
||||
user_id = str(uuid4())
|
||||
|
||||
with patch(
|
||||
'server.auth.authorization.get_user_org_role_async',
|
||||
AsyncMock(return_value=None),
|
||||
):
|
||||
permission_checker = require_permission(Permission.VIEW_LLM_SETTINGS)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await permission_checker(org_id=None, user_id=user_id)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
assert 'not a member' in exc_info.value.detail
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests for permission-based access control scenarios
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestPermissionScenarios:
|
||||
"""Tests for real-world permission scenarios."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_member_can_manage_secrets(self):
|
||||
"""
|
||||
GIVEN: User with member role
|
||||
WHEN: MANAGE_SECRETS permission is required
|
||||
THEN: User ID is returned
|
||||
"""
|
||||
user_id = str(uuid4())
|
||||
org_id = uuid4()
|
||||
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'member'
|
||||
|
||||
with patch(
|
||||
'server.auth.authorization.get_user_org_role_async',
|
||||
AsyncMock(return_value=mock_role),
|
||||
):
|
||||
permission_checker = require_permission(Permission.MANAGE_SECRETS)
|
||||
result = await permission_checker(org_id=org_id, user_id=user_id)
|
||||
assert result == user_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_member_cannot_invite_users(self):
|
||||
"""
|
||||
GIVEN: User with member role
|
||||
WHEN: INVITE_USER_TO_ORGANIZATION permission is required
|
||||
THEN: 403 Forbidden is raised
|
||||
"""
|
||||
user_id = str(uuid4())
|
||||
org_id = uuid4()
|
||||
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'member'
|
||||
|
||||
with patch(
|
||||
'server.auth.authorization.get_user_org_role_async',
|
||||
AsyncMock(return_value=mock_role),
|
||||
):
|
||||
permission_checker = require_permission(
|
||||
Permission.INVITE_USER_TO_ORGANIZATION
|
||||
)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await permission_checker(org_id=org_id, user_id=user_id)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_can_invite_users(self):
|
||||
"""
|
||||
GIVEN: User with admin role
|
||||
WHEN: INVITE_USER_TO_ORGANIZATION permission is required
|
||||
THEN: User ID is returned
|
||||
"""
|
||||
user_id = str(uuid4())
|
||||
org_id = uuid4()
|
||||
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'admin'
|
||||
|
||||
with patch(
|
||||
'server.auth.authorization.get_user_org_role_async',
|
||||
AsyncMock(return_value=mock_role),
|
||||
):
|
||||
permission_checker = require_permission(
|
||||
Permission.INVITE_USER_TO_ORGANIZATION
|
||||
)
|
||||
result = await permission_checker(org_id=org_id, user_id=user_id)
|
||||
assert result == user_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_cannot_change_owner_role(self):
|
||||
"""
|
||||
GIVEN: User with admin role
|
||||
WHEN: CHANGE_USER_ROLE_OWNER permission is required
|
||||
THEN: 403 Forbidden is raised
|
||||
"""
|
||||
user_id = str(uuid4())
|
||||
org_id = uuid4()
|
||||
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'admin'
|
||||
|
||||
with patch(
|
||||
'server.auth.authorization.get_user_org_role_async',
|
||||
AsyncMock(return_value=mock_role),
|
||||
):
|
||||
permission_checker = require_permission(Permission.CHANGE_USER_ROLE_OWNER)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await permission_checker(org_id=org_id, user_id=user_id)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_owner_can_change_owner_role(self):
|
||||
"""
|
||||
GIVEN: User with owner role
|
||||
WHEN: CHANGE_USER_ROLE_OWNER permission is required
|
||||
THEN: User ID is returned
|
||||
"""
|
||||
user_id = str(uuid4())
|
||||
org_id = uuid4()
|
||||
|
||||
mock_role = MagicMock()
|
||||
mock_role.name = 'owner'
|
||||
|
||||
with patch(
|
||||
'server.auth.authorization.get_user_org_role_async',
|
||||
AsyncMock(return_value=mock_role),
|
||||
):
|
||||
permission_checker = require_permission(Permission.CHANGE_USER_ROLE_OWNER)
|
||||
result = await permission_checker(org_id=org_id, user_id=user_id)
|
||||
assert result == user_id
|
||||
@@ -101,7 +101,7 @@ async def test_get_credits_success():
|
||||
json={
|
||||
'user_info': {
|
||||
'spend': 25.50,
|
||||
'litellm_budget_table': {'max_budget': 100.00},
|
||||
'max_budget_in_team': 100.00,
|
||||
}
|
||||
},
|
||||
request=MagicMock(),
|
||||
@@ -121,7 +121,7 @@ async def test_get_credits_success():
|
||||
'storage.lite_llm_manager.LiteLlmManager.get_user_team_info',
|
||||
return_value={
|
||||
'spend': 25.50,
|
||||
'litellm_budget_table': {'max_budget': 100.00},
|
||||
'max_budget_in_team': 100.00,
|
||||
},
|
||||
),
|
||||
):
|
||||
@@ -299,6 +299,8 @@ async def test_success_callback_success():
|
||||
mock_billing_session.status = 'in_progress'
|
||||
mock_billing_session.user_id = 'mock_user'
|
||||
|
||||
mock_org = MagicMock()
|
||||
|
||||
with (
|
||||
patch('server.routes.billing.session_maker') as mock_session_maker,
|
||||
patch('stripe.checkout.Session.retrieve') as mock_stripe_retrieve,
|
||||
@@ -311,7 +313,7 @@ async def test_success_callback_success():
|
||||
'storage.lite_llm_manager.LiteLlmManager.get_user_team_info',
|
||||
return_value={
|
||||
'spend': 25.50,
|
||||
'litellm_budget_table': {'max_budget': 100.00},
|
||||
'max_budget_in_team': 100.00,
|
||||
},
|
||||
),
|
||||
patch(
|
||||
@@ -319,7 +321,17 @@ async def test_success_callback_success():
|
||||
) as mock_update_budget,
|
||||
):
|
||||
mock_db_session = MagicMock()
|
||||
# First query: BillingSession (query().filter().filter().first())
|
||||
mock_db_session.query.return_value.filter.return_value.filter.return_value.first.return_value = mock_billing_session
|
||||
# Second query: Org (query().filter().first()) - use side_effect for different return chains
|
||||
mock_query_chain_billing = MagicMock()
|
||||
mock_query_chain_billing.filter.return_value.filter.return_value.first.return_value = mock_billing_session
|
||||
mock_query_chain_org = MagicMock()
|
||||
mock_query_chain_org.filter.return_value.first.return_value = mock_org
|
||||
mock_db_session.query.side_effect = [
|
||||
mock_query_chain_billing,
|
||||
mock_query_chain_org,
|
||||
]
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_db_session
|
||||
|
||||
mock_stripe_retrieve.return_value = MagicMock(
|
||||
@@ -337,9 +349,12 @@ async def test_success_callback_success():
|
||||
# Verify LiteLLM API calls
|
||||
mock_update_budget.assert_called_once_with(
|
||||
'mock_org_id',
|
||||
125.0, # 100 + (25.00 from Stripe)
|
||||
125.0, # 100 + 25.00
|
||||
)
|
||||
|
||||
# Verify BYOR export is enabled for the org (updated in same session)
|
||||
assert mock_org.byor_export_enabled is True
|
||||
|
||||
# Verify database updates
|
||||
assert mock_billing_session.status == 'completed'
|
||||
assert mock_billing_session.price == 25.0
|
||||
@@ -387,6 +402,68 @@ async def test_success_callback_lite_llm_error():
|
||||
mock_db_session.commit.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_success_callback_lite_llm_update_budget_error_rollback():
|
||||
"""Test that database changes are not committed when update_team_and_users_budget fails.
|
||||
|
||||
This test verifies that if LiteLlmManager.update_team_and_users_budget raises an exception,
|
||||
the database transaction rolls back.
|
||||
"""
|
||||
mock_request = Request(scope={'type': 'http'})
|
||||
mock_request._base_url = URL('http://test.com/')
|
||||
|
||||
mock_billing_session = MagicMock()
|
||||
mock_billing_session.status = 'in_progress'
|
||||
mock_billing_session.user_id = 'mock_user'
|
||||
|
||||
mock_org = MagicMock()
|
||||
|
||||
with (
|
||||
patch('server.routes.billing.session_maker') as mock_session_maker,
|
||||
patch('stripe.checkout.Session.retrieve') as mock_stripe_retrieve,
|
||||
patch(
|
||||
'storage.user_store.UserStore.get_user_by_id_async',
|
||||
new_callable=AsyncMock,
|
||||
return_value=MagicMock(current_org_id='mock_org_id'),
|
||||
),
|
||||
patch(
|
||||
'storage.lite_llm_manager.LiteLlmManager.get_user_team_info',
|
||||
return_value={
|
||||
'spend': 0,
|
||||
'max_budget_in_team': 0,
|
||||
},
|
||||
),
|
||||
patch(
|
||||
'storage.lite_llm_manager.LiteLlmManager.update_team_and_users_budget',
|
||||
side_effect=Exception('LiteLLM API Error'),
|
||||
),
|
||||
):
|
||||
mock_db_session = MagicMock()
|
||||
mock_query_chain_billing = MagicMock()
|
||||
mock_query_chain_billing.filter.return_value.filter.return_value.first.return_value = mock_billing_session
|
||||
mock_query_chain_org = MagicMock()
|
||||
mock_query_chain_org.filter.return_value.first.return_value = mock_org
|
||||
mock_db_session.query.side_effect = [
|
||||
mock_query_chain_billing,
|
||||
mock_query_chain_org,
|
||||
]
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_db_session
|
||||
|
||||
mock_stripe_retrieve.return_value = MagicMock(
|
||||
status='complete',
|
||||
amount_subtotal=1000, # $10
|
||||
customer='mock_customer_id',
|
||||
)
|
||||
|
||||
with pytest.raises(Exception, match='LiteLLM API Error'):
|
||||
await success_callback('test_session_id', mock_request)
|
||||
|
||||
# Verify no database commit occurred - the transaction should roll back
|
||||
assert mock_billing_session.status == 'in_progress'
|
||||
mock_db_session.merge.assert_not_called()
|
||||
mock_db_session.commit.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_callback_session_not_found():
|
||||
"""Test cancel callback when billing session is not found."""
|
||||
@@ -502,6 +579,6 @@ async def test_create_customer_setup_session_success():
|
||||
customer='mock-customer-id',
|
||||
mode='setup',
|
||||
payment_method_types=['card'],
|
||||
success_url='https://test.com/?free_credits=success',
|
||||
success_url='https://test.com/?setup=success',
|
||||
cancel_url='https://test.com/',
|
||||
)
|
||||
|
||||
@@ -48,7 +48,7 @@ async def test_create_customer_setup_session_uses_customer_id():
|
||||
customer=customer_id,
|
||||
mode='setup',
|
||||
payment_method_types=['card'],
|
||||
success_url=f'{request.base_url}?free_credits=success',
|
||||
success_url=f'{request.base_url}?setup=success',
|
||||
cancel_url=f'{request.base_url}',
|
||||
)
|
||||
|
||||
|
||||
192
enterprise/tests/unit/test_email_service.py
Normal file
192
enterprise/tests/unit/test_email_service.py
Normal file
@@ -0,0 +1,192 @@
|
||||
"""Tests for email service."""
|
||||
|
||||
import os
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from server.services.email_service import (
|
||||
DEFAULT_WEB_HOST,
|
||||
EmailService,
|
||||
)
|
||||
|
||||
|
||||
class TestEmailServiceInvitationUrl:
|
||||
"""Test cases for invitation URL generation."""
|
||||
|
||||
def test_invitation_url_uses_correct_endpoint(self):
|
||||
"""Test that invitation URL points to the correct API endpoint."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.get.return_value = 'test-email-id'
|
||||
|
||||
with (
|
||||
patch.dict(os.environ, {'RESEND_API_KEY': 'test-key'}),
|
||||
patch('server.services.email_service.RESEND_AVAILABLE', True),
|
||||
patch('server.services.email_service.resend') as mock_resend,
|
||||
):
|
||||
mock_resend.Emails.send.return_value = mock_response
|
||||
|
||||
EmailService.send_invitation_email(
|
||||
to_email='test@example.com',
|
||||
org_name='Test Org',
|
||||
inviter_name='Inviter',
|
||||
role_name='member',
|
||||
invitation_token='inv-test-token-12345',
|
||||
invitation_id=1,
|
||||
)
|
||||
|
||||
# Get the call arguments
|
||||
call_args = mock_resend.Emails.send.call_args
|
||||
email_params = call_args[0][0]
|
||||
|
||||
# Verify the URL in the email HTML contains the correct endpoint
|
||||
assert (
|
||||
'/api/organizations/members/invite/accept?token='
|
||||
in email_params['html']
|
||||
)
|
||||
assert 'inv-test-token-12345' in email_params['html']
|
||||
|
||||
def test_invitation_url_uses_web_host_env_var(self):
|
||||
"""Test that invitation URL uses WEB_HOST environment variable."""
|
||||
custom_host = 'https://custom.example.com'
|
||||
mock_response = MagicMock()
|
||||
mock_response.get.return_value = 'test-email-id'
|
||||
|
||||
with (
|
||||
patch.dict(
|
||||
os.environ,
|
||||
{'RESEND_API_KEY': 'test-key', 'WEB_HOST': custom_host},
|
||||
),
|
||||
patch('server.services.email_service.RESEND_AVAILABLE', True),
|
||||
patch('server.services.email_service.resend') as mock_resend,
|
||||
):
|
||||
mock_resend.Emails.send.return_value = mock_response
|
||||
|
||||
EmailService.send_invitation_email(
|
||||
to_email='test@example.com',
|
||||
org_name='Test Org',
|
||||
inviter_name='Inviter',
|
||||
role_name='member',
|
||||
invitation_token='inv-test-token-12345',
|
||||
invitation_id=1,
|
||||
)
|
||||
|
||||
call_args = mock_resend.Emails.send.call_args
|
||||
email_params = call_args[0][0]
|
||||
|
||||
expected_url = f'{custom_host}/api/organizations/members/invite/accept?token=inv-test-token-12345'
|
||||
assert expected_url in email_params['html']
|
||||
|
||||
def test_invitation_url_uses_default_host_when_env_not_set(self):
|
||||
"""Test that invitation URL falls back to DEFAULT_WEB_HOST when env not set."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.get.return_value = 'test-email-id'
|
||||
|
||||
env_without_web_host = {'RESEND_API_KEY': 'test-key'}
|
||||
# Remove WEB_HOST if it exists
|
||||
env_without_web_host.pop('WEB_HOST', None)
|
||||
|
||||
with (
|
||||
patch.dict(os.environ, env_without_web_host, clear=True),
|
||||
patch('server.services.email_service.RESEND_AVAILABLE', True),
|
||||
patch('server.services.email_service.resend') as mock_resend,
|
||||
):
|
||||
# Clear WEB_HOST from the environment
|
||||
os.environ.pop('WEB_HOST', None)
|
||||
mock_resend.Emails.send.return_value = mock_response
|
||||
|
||||
EmailService.send_invitation_email(
|
||||
to_email='test@example.com',
|
||||
org_name='Test Org',
|
||||
inviter_name='Inviter',
|
||||
role_name='member',
|
||||
invitation_token='inv-test-token-12345',
|
||||
invitation_id=1,
|
||||
)
|
||||
|
||||
call_args = mock_resend.Emails.send.call_args
|
||||
email_params = call_args[0][0]
|
||||
|
||||
expected_url = f'{DEFAULT_WEB_HOST}/api/organizations/members/invite/accept?token=inv-test-token-12345'
|
||||
assert expected_url in email_params['html']
|
||||
|
||||
|
||||
class TestEmailServiceGetResendClient:
|
||||
"""Test cases for Resend client initialization."""
|
||||
|
||||
def test_get_resend_client_returns_false_when_resend_not_available(self):
|
||||
"""Test that _get_resend_client returns False when resend is not installed."""
|
||||
with patch('server.services.email_service.RESEND_AVAILABLE', False):
|
||||
result = EmailService._get_resend_client()
|
||||
assert result is False
|
||||
|
||||
def test_get_resend_client_returns_false_when_api_key_not_configured(self):
|
||||
"""Test that _get_resend_client returns False when API key is missing."""
|
||||
with (
|
||||
patch('server.services.email_service.RESEND_AVAILABLE', True),
|
||||
patch.dict(os.environ, {}, clear=True),
|
||||
):
|
||||
os.environ.pop('RESEND_API_KEY', None)
|
||||
result = EmailService._get_resend_client()
|
||||
assert result is False
|
||||
|
||||
def test_get_resend_client_returns_true_when_configured(self):
|
||||
"""Test that _get_resend_client returns True when properly configured."""
|
||||
with (
|
||||
patch.dict(os.environ, {'RESEND_API_KEY': 'test-key'}),
|
||||
patch('server.services.email_service.RESEND_AVAILABLE', True),
|
||||
patch('server.services.email_service.resend') as mock_resend,
|
||||
):
|
||||
result = EmailService._get_resend_client()
|
||||
assert result is True
|
||||
assert mock_resend.api_key == 'test-key'
|
||||
|
||||
|
||||
class TestEmailServiceSendInvitationEmail:
|
||||
"""Test cases for send_invitation_email method."""
|
||||
|
||||
def test_send_invitation_email_skips_when_client_not_ready(self):
|
||||
"""Test that email sending is skipped when client is not ready."""
|
||||
with patch.object(
|
||||
EmailService, '_get_resend_client', return_value=False
|
||||
) as mock_get_client:
|
||||
# Should not raise, just return early
|
||||
EmailService.send_invitation_email(
|
||||
to_email='test@example.com',
|
||||
org_name='Test Org',
|
||||
inviter_name='Inviter',
|
||||
role_name='member',
|
||||
invitation_token='inv-test-token',
|
||||
invitation_id=1,
|
||||
)
|
||||
|
||||
mock_get_client.assert_called_once()
|
||||
|
||||
def test_send_invitation_email_includes_all_required_info(self):
|
||||
"""Test that invitation email includes org name, inviter name, and role."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.get.return_value = 'test-email-id'
|
||||
|
||||
with (
|
||||
patch.dict(os.environ, {'RESEND_API_KEY': 'test-key'}),
|
||||
patch('server.services.email_service.RESEND_AVAILABLE', True),
|
||||
patch('server.services.email_service.resend') as mock_resend,
|
||||
):
|
||||
mock_resend.Emails.send.return_value = mock_response
|
||||
|
||||
EmailService.send_invitation_email(
|
||||
to_email='test@example.com',
|
||||
org_name='Acme Corp',
|
||||
inviter_name='John Doe',
|
||||
role_name='admin',
|
||||
invitation_token='inv-test-token-12345',
|
||||
invitation_id=42,
|
||||
)
|
||||
|
||||
call_args = mock_resend.Emails.send.call_args
|
||||
email_params = call_args[0][0]
|
||||
|
||||
# Verify email content
|
||||
assert email_params['to'] == ['test@example.com']
|
||||
assert 'Acme Corp' in email_params['subject']
|
||||
assert 'John Doe' in email_params['html']
|
||||
assert 'Acme Corp' in email_params['html']
|
||||
assert 'admin' in email_params['html']
|
||||
116
enterprise/tests/unit/test_identity_utils.py
Normal file
116
enterprise/tests/unit/test_identity_utils.py
Normal file
@@ -0,0 +1,116 @@
|
||||
"""Tests for the resolve_display_name helper.
|
||||
|
||||
The resolve_display_name helper extracts the best available display name from
|
||||
Keycloak user_info claims. It is used by both the /api/user/info fallback path
|
||||
and the user_store create/migrate paths to avoid duplicating name-resolution logic.
|
||||
|
||||
The fallback chain is: name → given_name + family_name → None.
|
||||
It intentionally does NOT fall back to preferred_username/username — callers
|
||||
that need a guaranteed non-None value handle that separately, because the
|
||||
/api/user/info route should return name=None when no real name is available.
|
||||
"""
|
||||
|
||||
from utils.identity import resolve_display_name
|
||||
|
||||
|
||||
class TestResolveDisplayName:
|
||||
"""Test resolve_display_name with various Keycloak claim combinations."""
|
||||
|
||||
def test_returns_name_when_present(self):
|
||||
"""When user_info has a 'name' claim, use it directly."""
|
||||
user_info = {
|
||||
'sub': '123',
|
||||
'name': 'Jane Doe',
|
||||
'given_name': 'Jane',
|
||||
'family_name': 'Doe',
|
||||
'preferred_username': 'j.doe',
|
||||
}
|
||||
assert resolve_display_name(user_info) == 'Jane Doe'
|
||||
|
||||
def test_combines_given_and_family_name_when_name_absent(self):
|
||||
"""When 'name' is missing, combine given_name + family_name."""
|
||||
user_info = {
|
||||
'sub': '123',
|
||||
'given_name': 'Jane',
|
||||
'family_name': 'Doe',
|
||||
'preferred_username': 'j.doe',
|
||||
}
|
||||
assert resolve_display_name(user_info) == 'Jane Doe'
|
||||
|
||||
def test_uses_given_name_only_when_family_name_absent(self):
|
||||
"""When only given_name is available, use it alone."""
|
||||
user_info = {
|
||||
'sub': '123',
|
||||
'given_name': 'Jane',
|
||||
'preferred_username': 'j.doe',
|
||||
}
|
||||
assert resolve_display_name(user_info) == 'Jane'
|
||||
|
||||
def test_uses_family_name_only_when_given_name_absent(self):
|
||||
"""When only family_name is available, use it alone."""
|
||||
user_info = {
|
||||
'sub': '123',
|
||||
'family_name': 'Doe',
|
||||
'preferred_username': 'j.doe',
|
||||
}
|
||||
assert resolve_display_name(user_info) == 'Doe'
|
||||
|
||||
def test_returns_none_when_no_name_claims(self):
|
||||
"""When no name claims exist at all, return None."""
|
||||
user_info = {
|
||||
'sub': '123',
|
||||
'preferred_username': 'j.doe',
|
||||
'email': 'jane@example.com',
|
||||
}
|
||||
assert resolve_display_name(user_info) is None
|
||||
|
||||
def test_returns_none_for_empty_name(self):
|
||||
"""When 'name' is an empty string, treat it as absent."""
|
||||
user_info = {
|
||||
'sub': '123',
|
||||
'name': '',
|
||||
'preferred_username': 'j.doe',
|
||||
}
|
||||
assert resolve_display_name(user_info) is None
|
||||
|
||||
def test_returns_none_for_whitespace_only_name(self):
|
||||
"""When 'name' is whitespace only, treat it as absent."""
|
||||
user_info = {
|
||||
'sub': '123',
|
||||
'name': ' ',
|
||||
'preferred_username': 'j.doe',
|
||||
}
|
||||
assert resolve_display_name(user_info) is None
|
||||
|
||||
def test_returns_none_for_empty_given_and_family_names(self):
|
||||
"""When given_name and family_name are both empty strings, return None."""
|
||||
user_info = {
|
||||
'sub': '123',
|
||||
'given_name': '',
|
||||
'family_name': '',
|
||||
'preferred_username': 'j.doe',
|
||||
}
|
||||
assert resolve_display_name(user_info) is None
|
||||
|
||||
def test_returns_none_for_empty_dict(self):
|
||||
"""An empty user_info dict returns None."""
|
||||
assert resolve_display_name({}) is None
|
||||
|
||||
def test_strips_whitespace_from_combined_name(self):
|
||||
"""Whitespace around given_name/family_name is stripped."""
|
||||
user_info = {
|
||||
'sub': '123',
|
||||
'given_name': ' Jane ',
|
||||
'family_name': ' Doe ',
|
||||
}
|
||||
assert resolve_display_name(user_info) == 'Jane Doe'
|
||||
|
||||
def test_name_claim_takes_priority_over_given_family(self):
|
||||
"""When both 'name' and given/family are present, 'name' wins."""
|
||||
user_info = {
|
||||
'sub': '123',
|
||||
'name': 'Dr. Jane Doe',
|
||||
'given_name': 'Jane',
|
||||
'family_name': 'Doe',
|
||||
}
|
||||
assert resolve_display_name(user_info) == 'Dr. Jane Doe'
|
||||
@@ -11,7 +11,11 @@ from pydantic import SecretStr
|
||||
from server.constants import (
|
||||
get_default_litellm_model,
|
||||
)
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
from storage.lite_llm_manager import (
|
||||
LiteLlmManager,
|
||||
get_byor_key_alias,
|
||||
get_openhands_cloud_key_alias,
|
||||
)
|
||||
from storage.user_settings import UserSettings
|
||||
|
||||
from openhands.server.settings import Settings
|
||||
@@ -138,44 +142,192 @@ class TestLiteLlmManager:
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_entries_cloud_deployment(self, mock_settings, mock_response):
|
||||
"""Test create_entries in cloud deployment mode."""
|
||||
with patch.dict(os.environ, {'LOCAL_DEPLOYMENT': ''}):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
|
||||
with patch(
|
||||
'storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'
|
||||
):
|
||||
with patch(
|
||||
'storage.lite_llm_manager.TokenManager'
|
||||
) as mock_token_manager:
|
||||
mock_token_manager.return_value.get_user_info_from_user_id = (
|
||||
AsyncMock(return_value={'email': 'test@example.com'})
|
||||
)
|
||||
mock_404_response = MagicMock()
|
||||
mock_404_response.status_code = 404
|
||||
mock_404_response.is_success = False
|
||||
|
||||
with patch('httpx.AsyncClient') as mock_client_class:
|
||||
mock_client = AsyncMock()
|
||||
mock_client_class.return_value.__aenter__.return_value = (
|
||||
mock_client
|
||||
)
|
||||
mock_client.post.return_value = mock_response
|
||||
mock_token_manager = MagicMock()
|
||||
mock_token_manager.return_value.get_user_info_from_user_id = AsyncMock(
|
||||
return_value={'email': 'test@example.com'}
|
||||
)
|
||||
|
||||
result = await LiteLlmManager.create_entries(
|
||||
'test-org-id',
|
||||
'test-user-id',
|
||||
mock_settings,
|
||||
create_user=False,
|
||||
)
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get.return_value = mock_404_response
|
||||
mock_client.get.return_value.raise_for_status.side_effect = (
|
||||
httpx.HTTPStatusError(
|
||||
message='Not Found', request=MagicMock(), response=mock_404_response
|
||||
)
|
||||
)
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
assert result is not None
|
||||
assert result.agent == 'CodeActAgent'
|
||||
assert result.llm_model == get_default_litellm_model()
|
||||
assert (
|
||||
result.llm_api_key.get_secret_value() == 'test-api-key'
|
||||
)
|
||||
assert result.llm_base_url == 'http://test.com'
|
||||
mock_client_class = MagicMock()
|
||||
mock_client_class.return_value.__aenter__.return_value = mock_client
|
||||
|
||||
# Verify API calls were made
|
||||
assert (
|
||||
mock_client.post.call_count == 3
|
||||
) # create_team, create_user, add_user_to_team, generate_key
|
||||
with (
|
||||
patch.dict(os.environ, {'LOCAL_DEPLOYMENT': ''}),
|
||||
patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'),
|
||||
patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'),
|
||||
patch('storage.lite_llm_manager.TokenManager', mock_token_manager),
|
||||
patch('httpx.AsyncClient', mock_client_class),
|
||||
):
|
||||
result = await LiteLlmManager.create_entries(
|
||||
'test-org-id', 'test-user-id', mock_settings, create_user=False
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.agent == 'CodeActAgent'
|
||||
assert result.llm_model == get_default_litellm_model()
|
||||
assert result.llm_api_key.get_secret_value() == 'test-api-key'
|
||||
assert result.llm_base_url == 'http://test.com'
|
||||
|
||||
# Verify API calls were made (get_team + 3 posts)
|
||||
assert mock_client.get.call_count == 1 # get_team
|
||||
assert (
|
||||
mock_client.post.call_count == 3
|
||||
) # create_team, add_user_to_team, generate_key
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_entries_inherits_existing_team_budget(
|
||||
self, mock_settings, mock_response
|
||||
):
|
||||
"""Test that create_entries inherits budget from existing team."""
|
||||
mock_team_response = MagicMock()
|
||||
mock_team_response.is_success = True
|
||||
mock_team_response.status_code = 200
|
||||
mock_team_response.json.return_value = {
|
||||
'team_info': {'max_budget': 30.0, 'spend': 5.0},
|
||||
'team_memberships': [],
|
||||
}
|
||||
mock_team_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_token_manager = MagicMock()
|
||||
mock_token_manager.return_value.get_user_info_from_user_id = AsyncMock(
|
||||
return_value={'email': 'test@example.com'}
|
||||
)
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get.return_value = mock_team_response
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
mock_client_class = MagicMock()
|
||||
mock_client_class.return_value.__aenter__.return_value = mock_client
|
||||
|
||||
with (
|
||||
patch.dict(os.environ, {'LOCAL_DEPLOYMENT': ''}),
|
||||
patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'),
|
||||
patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'),
|
||||
patch('storage.lite_llm_manager.TokenManager', mock_token_manager),
|
||||
patch('httpx.AsyncClient', mock_client_class),
|
||||
):
|
||||
result = await LiteLlmManager.create_entries(
|
||||
'test-org-id', 'test-user-id', mock_settings, create_user=False
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
|
||||
# Verify _get_team was called first
|
||||
mock_client.get.assert_called_once()
|
||||
get_call_url = mock_client.get.call_args[0][0]
|
||||
assert 'team/info' in get_call_url
|
||||
assert 'test-org-id' in get_call_url
|
||||
|
||||
# Verify _create_team was called with inherited budget (30.0)
|
||||
create_team_call = mock_client.post.call_args_list[0]
|
||||
assert 'team/new' in create_team_call[0][0]
|
||||
assert create_team_call[1]['json']['max_budget'] == 30.0
|
||||
|
||||
# Verify _add_user_to_team was called with inherited budget (30.0)
|
||||
add_user_call = mock_client.post.call_args_list[1]
|
||||
assert 'team/member_add' in add_user_call[0][0]
|
||||
assert add_user_call[1]['json']['max_budget_in_team'] == 30.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_entries_new_org_uses_zero_budget(
|
||||
self, mock_settings, mock_response
|
||||
):
|
||||
"""Test that create_entries uses budget=0 for new org (team doesn't exist)."""
|
||||
mock_404_response = MagicMock()
|
||||
mock_404_response.status_code = 404
|
||||
mock_404_response.is_success = False
|
||||
|
||||
mock_token_manager = MagicMock()
|
||||
mock_token_manager.return_value.get_user_info_from_user_id = AsyncMock(
|
||||
return_value={'email': 'test@example.com'}
|
||||
)
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get.return_value = mock_404_response
|
||||
mock_client.get.return_value.raise_for_status.side_effect = (
|
||||
httpx.HTTPStatusError(
|
||||
message='Not Found', request=MagicMock(), response=mock_404_response
|
||||
)
|
||||
)
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
mock_client_class = MagicMock()
|
||||
mock_client_class.return_value.__aenter__.return_value = mock_client
|
||||
|
||||
with (
|
||||
patch.dict(os.environ, {'LOCAL_DEPLOYMENT': ''}),
|
||||
patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'),
|
||||
patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'),
|
||||
patch('storage.lite_llm_manager.TokenManager', mock_token_manager),
|
||||
patch('httpx.AsyncClient', mock_client_class),
|
||||
):
|
||||
result = await LiteLlmManager.create_entries(
|
||||
'test-org-id', 'test-user-id', mock_settings, create_user=False
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
|
||||
# Verify _create_team was called with budget=0
|
||||
create_team_call = mock_client.post.call_args_list[0]
|
||||
assert 'team/new' in create_team_call[0][0]
|
||||
assert create_team_call[1]['json']['max_budget'] == 0.0
|
||||
|
||||
# Verify _add_user_to_team was called with budget=0
|
||||
add_user_call = mock_client.post.call_args_list[1]
|
||||
assert 'team/member_add' in add_user_call[0][0]
|
||||
assert add_user_call[1]['json']['max_budget_in_team'] == 0.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_entries_propagates_non_404_errors(self, mock_settings):
|
||||
"""Test that create_entries propagates non-404 errors from _get_team."""
|
||||
mock_500_response = MagicMock()
|
||||
mock_500_response.status_code = 500
|
||||
mock_500_response.is_success = False
|
||||
|
||||
mock_token_manager = MagicMock()
|
||||
mock_token_manager.return_value.get_user_info_from_user_id = AsyncMock(
|
||||
return_value={'email': 'test@example.com'}
|
||||
)
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get.return_value = mock_500_response
|
||||
mock_client.get.return_value.raise_for_status.side_effect = (
|
||||
httpx.HTTPStatusError(
|
||||
message='Internal Server Error',
|
||||
request=MagicMock(),
|
||||
response=mock_500_response,
|
||||
)
|
||||
)
|
||||
|
||||
mock_client_class = MagicMock()
|
||||
mock_client_class.return_value.__aenter__.return_value = mock_client
|
||||
|
||||
with (
|
||||
patch.dict(os.environ, {'LOCAL_DEPLOYMENT': ''}),
|
||||
patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'),
|
||||
patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'),
|
||||
patch('storage.lite_llm_manager.TokenManager', mock_token_manager),
|
||||
patch('httpx.AsyncClient', mock_client_class),
|
||||
):
|
||||
with pytest.raises(httpx.HTTPStatusError) as exc_info:
|
||||
await LiteLlmManager.create_entries(
|
||||
'test-org-id', 'test-user-id', mock_settings, create_user=False
|
||||
)
|
||||
|
||||
assert exc_info.value.response.status_code == 500
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_migrate_entries_missing_config(self, mock_user_settings):
|
||||
@@ -284,6 +436,16 @@ class TestLiteLlmManager:
|
||||
self, mock_user_settings, mock_user_response, mock_response
|
||||
):
|
||||
"""Test successful migrate_entries operation."""
|
||||
# Mock response for key list
|
||||
mock_key_list_response = MagicMock()
|
||||
mock_key_list_response.is_success = True
|
||||
mock_key_list_response.status_code = 200
|
||||
mock_key_list_response.json.return_value = {
|
||||
'keys': ['test-key-1', 'test-key-2'],
|
||||
'total_count': 2,
|
||||
}
|
||||
mock_key_list_response.raise_for_status = MagicMock()
|
||||
|
||||
with patch.dict(os.environ, {'LOCAL_DEPLOYMENT': ''}):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
|
||||
with patch(
|
||||
@@ -301,14 +463,22 @@ class TestLiteLlmManager:
|
||||
mock_client_class.return_value.__aenter__.return_value = (
|
||||
mock_client
|
||||
)
|
||||
mock_client.get.return_value = mock_user_response
|
||||
# First GET is for _get_user, second GET is for _get_user_keys
|
||||
mock_client.get.side_effect = [
|
||||
mock_user_response,
|
||||
mock_key_list_response,
|
||||
]
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
result = await LiteLlmManager.migrate_entries(
|
||||
'test-org-id',
|
||||
'test-user-id',
|
||||
mock_user_settings,
|
||||
)
|
||||
# Mock verify_key to return True (key exists in LiteLLM)
|
||||
with patch.object(
|
||||
LiteLlmManager, 'verify_key', return_value=True
|
||||
):
|
||||
result = await LiteLlmManager.migrate_entries(
|
||||
'test-org-id',
|
||||
'test-user-id',
|
||||
mock_user_settings,
|
||||
)
|
||||
|
||||
# migrate_entries returns the user_settings unchanged
|
||||
assert result is not None
|
||||
@@ -317,10 +487,97 @@ class TestLiteLlmManager:
|
||||
assert result.llm_api_key.get_secret_value() == 'test-key'
|
||||
assert result.llm_base_url == 'http://test.com'
|
||||
|
||||
# Verify migration steps were called
|
||||
# Verify migration steps were called:
|
||||
# - 2 GET requests: _get_user, _get_user_keys
|
||||
# - POST requests: create_team, update_user, add_user_to_team,
|
||||
# and update_key for each key (2 keys)
|
||||
assert mock_client.get.call_count == 2
|
||||
assert (
|
||||
mock_client.post.call_count == 4
|
||||
) # create_team, update_user, add_user_to_team, update_key
|
||||
mock_client.post.call_count == 5
|
||||
) # create_team, update_user, add_user_to_team, 2x update_key
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_migrate_entries_generates_key_when_db_key_not_in_litellm(
|
||||
self, mock_user_settings, mock_user_response, mock_response
|
||||
):
|
||||
"""Test migrate_entries generates a new key when the DB key doesn't exist in LiteLLM."""
|
||||
# Mock response for key list
|
||||
mock_key_list_response = MagicMock()
|
||||
mock_key_list_response.is_success = True
|
||||
mock_key_list_response.status_code = 200
|
||||
mock_key_list_response.json.return_value = {
|
||||
'keys': ['test-key-1', 'test-key-2'],
|
||||
'total_count': 2,
|
||||
}
|
||||
mock_key_list_response.raise_for_status = MagicMock()
|
||||
|
||||
# Mock response for key generation
|
||||
mock_generate_response = MagicMock()
|
||||
mock_generate_response.is_success = True
|
||||
mock_generate_response.status_code = 200
|
||||
mock_generate_response.json.return_value = {'key': 'new-generated-key'}
|
||||
mock_generate_response.raise_for_status = MagicMock()
|
||||
|
||||
with patch.dict(os.environ, {'LOCAL_DEPLOYMENT': ''}):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
|
||||
with patch(
|
||||
'storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'
|
||||
):
|
||||
with patch(
|
||||
'storage.lite_llm_manager.TokenManager'
|
||||
) as mock_token_manager:
|
||||
mock_token_manager.return_value.get_user_info_from_user_id = (
|
||||
AsyncMock(return_value={'email': 'test@example.com'})
|
||||
)
|
||||
|
||||
with patch('httpx.AsyncClient') as mock_client_class:
|
||||
mock_client = AsyncMock()
|
||||
mock_client_class.return_value.__aenter__.return_value = (
|
||||
mock_client
|
||||
)
|
||||
# First GET is for _get_user, second GET is for _get_user_keys
|
||||
mock_client.get.side_effect = [
|
||||
mock_user_response,
|
||||
mock_key_list_response,
|
||||
]
|
||||
# POST responses: create_team, update_user, add_user_to_team,
|
||||
# 2x update_key, and 1x generate_key
|
||||
mock_client.post.side_effect = [
|
||||
mock_response, # create_team
|
||||
mock_response, # update_user
|
||||
mock_response, # add_user_to_team
|
||||
mock_response, # update_key 1
|
||||
mock_response, # update_key 2
|
||||
mock_generate_response, # generate_key
|
||||
]
|
||||
|
||||
# Mock verify_key to return False (key doesn't exist in LiteLLM)
|
||||
with patch.object(
|
||||
LiteLlmManager, 'verify_key', return_value=False
|
||||
):
|
||||
result = await LiteLlmManager.migrate_entries(
|
||||
'test-org-id',
|
||||
'test-user-id',
|
||||
mock_user_settings,
|
||||
)
|
||||
|
||||
# migrate_entries should update user_settings with the new key
|
||||
assert result is not None
|
||||
assert (
|
||||
result.llm_api_key.get_secret_value()
|
||||
== 'new-generated-key'
|
||||
)
|
||||
assert (
|
||||
result.llm_api_key_for_byor.get_secret_value()
|
||||
== 'new-generated-key'
|
||||
)
|
||||
|
||||
# Verify migration steps were called including key generation:
|
||||
# - 2 GET requests: _get_user, _get_user_keys
|
||||
# - 6 POST requests: create_team, update_user, add_user_to_team,
|
||||
# 2x update_key, 1x generate_key
|
||||
assert mock_client.get.call_count == 2
|
||||
assert mock_client.post.call_count == 6
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_team_and_users_budget_missing_config(self):
|
||||
@@ -654,7 +911,7 @@ class TestLiteLlmManager:
|
||||
# Assert
|
||||
mock_logger.warning.assert_called_once_with(
|
||||
'invalid_litellm_key_during_update',
|
||||
extra={'user_id': 'test-user-id'},
|
||||
extra={'user_id': 'test-user-id', 'text': 'Unauthorized'},
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -678,6 +935,133 @@ class TestLiteLlmManager:
|
||||
mock_http_client, 'test-user-id', 'test-api-key', team_id='test-team-id'
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com')
|
||||
@patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key')
|
||||
async def test_get_user_keys_success(self, mock_http_client):
|
||||
"""Test successful _get_user_keys operation."""
|
||||
# Arrange
|
||||
mock_response = MagicMock()
|
||||
mock_response.is_success = True
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
'keys': ['key-1', 'key-2', 'key-3'],
|
||||
'total_count': 3,
|
||||
}
|
||||
mock_http_client.get.return_value = mock_response
|
||||
|
||||
# Act
|
||||
keys = await LiteLlmManager._get_user_keys(mock_http_client, 'test-user-id')
|
||||
|
||||
# Assert
|
||||
assert keys == ['key-1', 'key-2', 'key-3']
|
||||
mock_http_client.get.assert_called_once()
|
||||
call_args = mock_http_client.get.call_args
|
||||
assert 'http://test.com/key/list' in call_args[0]
|
||||
assert call_args[1]['params'] == {'user_id': 'test-user-id'}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com')
|
||||
@patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key')
|
||||
async def test_get_user_keys_empty_list(self, mock_http_client):
|
||||
"""Test _get_user_keys returns empty list when user has no keys."""
|
||||
# Arrange
|
||||
mock_response = MagicMock()
|
||||
mock_response.is_success = True
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
'keys': [],
|
||||
'total_count': 0,
|
||||
}
|
||||
mock_http_client.get.return_value = mock_response
|
||||
|
||||
# Act
|
||||
keys = await LiteLlmManager._get_user_keys(mock_http_client, 'test-user-id')
|
||||
|
||||
# Assert
|
||||
assert keys == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com')
|
||||
@patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key')
|
||||
async def test_get_user_keys_error_returns_empty_list(self, mock_http_client):
|
||||
"""Test _get_user_keys returns empty list on error."""
|
||||
# Arrange
|
||||
mock_response = MagicMock()
|
||||
mock_response.is_success = False
|
||||
mock_response.status_code = 500
|
||||
mock_response.text = 'Internal server error'
|
||||
mock_http_client.get.return_value = mock_response
|
||||
|
||||
# Act
|
||||
keys = await LiteLlmManager._get_user_keys(mock_http_client, 'test-user-id')
|
||||
|
||||
# Assert
|
||||
assert keys == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_keys_missing_config(self, mock_http_client):
|
||||
"""Test _get_user_keys returns empty list when config is missing."""
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', None):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', None):
|
||||
keys = await LiteLlmManager._get_user_keys(
|
||||
mock_http_client, 'test-user-id'
|
||||
)
|
||||
assert keys == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com')
|
||||
@patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key')
|
||||
async def test_update_user_keys_success(self, mock_http_client, mock_response):
|
||||
"""Test successful _update_user_keys operation."""
|
||||
# Arrange
|
||||
mock_key_list_response = MagicMock()
|
||||
mock_key_list_response.is_success = True
|
||||
mock_key_list_response.status_code = 200
|
||||
mock_key_list_response.json.return_value = {
|
||||
'keys': ['key-1', 'key-2'],
|
||||
'total_count': 2,
|
||||
}
|
||||
mock_http_client.get.return_value = mock_key_list_response
|
||||
mock_http_client.post.return_value = mock_response
|
||||
|
||||
# Act
|
||||
await LiteLlmManager._update_user_keys(
|
||||
mock_http_client, 'test-user-id', team_id='test-team-id'
|
||||
)
|
||||
|
||||
# Assert
|
||||
# Should call GET once for key list
|
||||
assert mock_http_client.get.call_count == 1
|
||||
# Should call POST twice (once for each key)
|
||||
assert mock_http_client.post.call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com')
|
||||
@patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key')
|
||||
async def test_update_user_keys_no_keys(self, mock_http_client, mock_response):
|
||||
"""Test _update_user_keys when user has no keys."""
|
||||
# Arrange
|
||||
mock_key_list_response = MagicMock()
|
||||
mock_key_list_response.is_success = True
|
||||
mock_key_list_response.status_code = 200
|
||||
mock_key_list_response.json.return_value = {
|
||||
'keys': [],
|
||||
'total_count': 0,
|
||||
}
|
||||
mock_http_client.get.return_value = mock_key_list_response
|
||||
|
||||
# Act
|
||||
await LiteLlmManager._update_user_keys(
|
||||
mock_http_client, 'test-user-id', team_id='test-team-id'
|
||||
)
|
||||
|
||||
# Assert
|
||||
# Should call GET once for key list
|
||||
assert mock_http_client.get.call_count == 1
|
||||
# Should not call POST since there are no keys
|
||||
assert mock_http_client.post.call_count == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_key_success(self, mock_http_client, mock_response):
|
||||
"""Test successful _generate_key operation."""
|
||||
@@ -1242,6 +1626,15 @@ class TestLiteLlmManager:
|
||||
}
|
||||
mock_team_info_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_key_list_response = MagicMock()
|
||||
mock_key_list_response.is_success = True
|
||||
mock_key_list_response.status_code = 200
|
||||
mock_key_list_response.json.return_value = {
|
||||
'keys': ['test-key-1', 'test-key-2'],
|
||||
'total_count': 2,
|
||||
}
|
||||
mock_key_list_response.raise_for_status = MagicMock()
|
||||
|
||||
with patch.dict(os.environ, {'LOCAL_DEPLOYMENT': ''}):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
|
||||
with patch(
|
||||
@@ -1255,7 +1648,12 @@ class TestLiteLlmManager:
|
||||
mock_client_class.return_value.__aenter__.return_value = (
|
||||
mock_client
|
||||
)
|
||||
mock_client.get.return_value = mock_team_info_response
|
||||
# GET requests: get_team (x2 for team info), get_user_keys
|
||||
mock_client.get.side_effect = [
|
||||
mock_team_info_response,
|
||||
mock_team_info_response,
|
||||
mock_key_list_response,
|
||||
]
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
result = await LiteLlmManager.downgrade_entries(
|
||||
@@ -1269,15 +1667,19 @@ class TestLiteLlmManager:
|
||||
assert result.agent == 'TestAgent'
|
||||
|
||||
# Verify downgrade steps were called:
|
||||
# GET requests:
|
||||
# 1. get_team (GET)
|
||||
# 2. get_user_team_info (GET via _get_team)
|
||||
# 3. update_user (POST)
|
||||
# 4. add_user_to_team (POST)
|
||||
# 5. update_key (POST)
|
||||
# 6. remove_user_from_team (POST)
|
||||
# 7. delete_team (POST)
|
||||
assert mock_client.get.call_count >= 1
|
||||
assert mock_client.post.call_count >= 4
|
||||
# 3. get_user_keys (GET)
|
||||
# POST requests:
|
||||
# 1. update_user (POST)
|
||||
# 2. add_user_to_team (POST)
|
||||
# 3. update_key for key 1 (POST)
|
||||
# 4. update_key for key 2 (POST)
|
||||
# 5. remove_user_from_team (POST)
|
||||
# 6. delete_team (POST)
|
||||
assert mock_client.get.call_count == 3
|
||||
assert mock_client.post.call_count == 6
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_downgrade_entries_local_deployment(self, mock_user_settings):
|
||||
@@ -1297,3 +1699,321 @@ class TestLiteLlmManager:
|
||||
# making any LiteLLM calls
|
||||
assert result is not None
|
||||
assert result.agent == 'TestAgent'
|
||||
|
||||
|
||||
class TestGetAllKeysForUser:
|
||||
"""Test cases for _get_all_keys_for_user method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_all_keys_missing_config(self):
|
||||
"""Test _get_all_keys_for_user when LiteLLM config is missing."""
|
||||
mock_client = AsyncMock(spec=httpx.AsyncClient)
|
||||
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', None):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', None):
|
||||
result = await LiteLlmManager._get_all_keys_for_user(
|
||||
mock_client, 'test-user-id'
|
||||
)
|
||||
assert result == []
|
||||
mock_client.get.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_all_keys_success(self):
|
||||
"""Test _get_all_keys_for_user returns keys on success."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
'keys': [
|
||||
{
|
||||
'key_name': 'sk-test1234',
|
||||
'key_alias': 'test-alias',
|
||||
'team_id': 'test-org',
|
||||
'metadata': {'type': 'openhands'},
|
||||
},
|
||||
{
|
||||
'key_name': 'sk-test5678',
|
||||
'key_alias': 'another-alias',
|
||||
'team_id': 'test-org',
|
||||
'metadata': None,
|
||||
},
|
||||
]
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock(spec=httpx.AsyncClient)
|
||||
mock_client.get.return_value = mock_response
|
||||
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-api-key'):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'):
|
||||
result = await LiteLlmManager._get_all_keys_for_user(
|
||||
mock_client, 'test-user-id'
|
||||
)
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0]['key_name'] == 'sk-test1234'
|
||||
assert result[1]['key_name'] == 'sk-test5678'
|
||||
|
||||
# Verify API key header is included
|
||||
mock_client.get.assert_called_once()
|
||||
call_kwargs = mock_client.get.call_args
|
||||
assert call_kwargs.kwargs['headers'] == {
|
||||
'x-goog-api-key': 'test-api-key'
|
||||
}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_all_keys_empty_response(self):
|
||||
"""Test _get_all_keys_for_user returns empty list when user has no keys."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {'keys': []}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock(spec=httpx.AsyncClient)
|
||||
mock_client.get.return_value = mock_response
|
||||
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-api-key'):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'):
|
||||
result = await LiteLlmManager._get_all_keys_for_user(
|
||||
mock_client, 'test-user-id'
|
||||
)
|
||||
assert result == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_all_keys_api_error(self):
|
||||
"""Test _get_all_keys_for_user handles API errors gracefully."""
|
||||
mock_client = AsyncMock(spec=httpx.AsyncClient)
|
||||
mock_client.get.side_effect = Exception('API Error')
|
||||
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-api-key'):
|
||||
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'):
|
||||
result = await LiteLlmManager._get_all_keys_for_user(
|
||||
mock_client, 'test-user-id'
|
||||
)
|
||||
assert result == []
|
||||
|
||||
|
||||
class TestVerifyExistingKey:
|
||||
"""Test cases for _verify_existing_key method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_existing_key_openhands_type_found(self):
|
||||
"""Test _verify_existing_key finds matching OpenHands key."""
|
||||
mock_keys = [
|
||||
{
|
||||
'key_name': 'sk-test1234',
|
||||
'key_alias': 'some-alias',
|
||||
'team_id': 'test-org',
|
||||
'metadata': {'type': 'openhands'},
|
||||
}
|
||||
]
|
||||
|
||||
mock_client = AsyncMock(spec=httpx.AsyncClient)
|
||||
|
||||
with patch.object(
|
||||
LiteLlmManager, '_get_all_keys_for_user', new_callable=AsyncMock
|
||||
) as mock_get_keys:
|
||||
mock_get_keys.return_value = mock_keys
|
||||
|
||||
# Key ending with '1234' should match 'sk-test1234'
|
||||
result = await LiteLlmManager._verify_existing_key(
|
||||
mock_client,
|
||||
'my-key-ending-with-1234',
|
||||
'test-user-id',
|
||||
'test-org',
|
||||
openhands_type=True,
|
||||
)
|
||||
assert result is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_existing_key_openhands_type_not_found(self):
|
||||
"""Test _verify_existing_key returns False when key doesn't match."""
|
||||
mock_keys = [
|
||||
{
|
||||
'key_name': 'sk-test1234',
|
||||
'key_alias': 'some-alias',
|
||||
'team_id': 'test-org',
|
||||
'metadata': {'type': 'openhands'},
|
||||
}
|
||||
]
|
||||
|
||||
mock_client = AsyncMock(spec=httpx.AsyncClient)
|
||||
|
||||
with patch.object(
|
||||
LiteLlmManager, '_get_all_keys_for_user', new_callable=AsyncMock
|
||||
) as mock_get_keys:
|
||||
mock_get_keys.return_value = mock_keys
|
||||
|
||||
# Key ending with '5678' should NOT match 'sk-test1234'
|
||||
result = await LiteLlmManager._verify_existing_key(
|
||||
mock_client,
|
||||
'my-key-ending-with-5678',
|
||||
'test-user-id',
|
||||
'test-org',
|
||||
openhands_type=True,
|
||||
)
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_existing_key_by_alias_openhands_cloud(self):
|
||||
"""Test _verify_existing_key finds key by OpenHands Cloud alias."""
|
||||
user_id = 'test-user-id'
|
||||
org_id = 'test-org'
|
||||
mock_keys = [
|
||||
{
|
||||
'key_name': 'sk-testABCD',
|
||||
'key_alias': get_openhands_cloud_key_alias(user_id, org_id),
|
||||
'team_id': org_id,
|
||||
'metadata': None,
|
||||
}
|
||||
]
|
||||
|
||||
mock_client = AsyncMock(spec=httpx.AsyncClient)
|
||||
|
||||
with patch.object(
|
||||
LiteLlmManager, '_get_all_keys_for_user', new_callable=AsyncMock
|
||||
) as mock_get_keys:
|
||||
mock_get_keys.return_value = mock_keys
|
||||
|
||||
result = await LiteLlmManager._verify_existing_key(
|
||||
mock_client,
|
||||
'my-key-ending-with-ABCD',
|
||||
user_id,
|
||||
org_id,
|
||||
openhands_type=False,
|
||||
)
|
||||
assert result is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_existing_key_by_alias_byor(self):
|
||||
"""Test _verify_existing_key finds key by BYOR alias."""
|
||||
user_id = 'test-user-id'
|
||||
org_id = 'test-org'
|
||||
mock_keys = [
|
||||
{
|
||||
'key_name': 'sk-testXYZW',
|
||||
'key_alias': get_byor_key_alias(user_id, org_id),
|
||||
'team_id': org_id,
|
||||
'metadata': None,
|
||||
}
|
||||
]
|
||||
|
||||
mock_client = AsyncMock(spec=httpx.AsyncClient)
|
||||
|
||||
with patch.object(
|
||||
LiteLlmManager, '_get_all_keys_for_user', new_callable=AsyncMock
|
||||
) as mock_get_keys:
|
||||
mock_get_keys.return_value = mock_keys
|
||||
|
||||
result = await LiteLlmManager._verify_existing_key(
|
||||
mock_client,
|
||||
'my-key-ending-with-XYZW',
|
||||
user_id,
|
||||
org_id,
|
||||
openhands_type=False,
|
||||
)
|
||||
assert result is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_existing_key_wrong_team(self):
|
||||
"""Test _verify_existing_key returns False for wrong team_id."""
|
||||
mock_keys = [
|
||||
{
|
||||
'key_name': 'sk-test1234',
|
||||
'key_alias': 'some-alias',
|
||||
'team_id': 'different-org',
|
||||
'metadata': {'type': 'openhands'},
|
||||
}
|
||||
]
|
||||
|
||||
mock_client = AsyncMock(spec=httpx.AsyncClient)
|
||||
|
||||
with patch.object(
|
||||
LiteLlmManager, '_get_all_keys_for_user', new_callable=AsyncMock
|
||||
) as mock_get_keys:
|
||||
mock_get_keys.return_value = mock_keys
|
||||
|
||||
result = await LiteLlmManager._verify_existing_key(
|
||||
mock_client,
|
||||
'my-key-ending-with-1234',
|
||||
'test-user-id',
|
||||
'test-org',
|
||||
openhands_type=True,
|
||||
)
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_existing_key_no_keys(self):
|
||||
"""Test _verify_existing_key returns False when user has no keys."""
|
||||
mock_client = AsyncMock(spec=httpx.AsyncClient)
|
||||
|
||||
with patch.object(
|
||||
LiteLlmManager, '_get_all_keys_for_user', new_callable=AsyncMock
|
||||
) as mock_get_keys:
|
||||
mock_get_keys.return_value = []
|
||||
|
||||
result = await LiteLlmManager._verify_existing_key(
|
||||
mock_client,
|
||||
'some-key-value',
|
||||
'test-user-id',
|
||||
'test-org',
|
||||
openhands_type=True,
|
||||
)
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_existing_key_handles_none_key_name(self):
|
||||
"""Test _verify_existing_key handles None key_name gracefully."""
|
||||
mock_keys = [
|
||||
{
|
||||
'key_name': None,
|
||||
'key_alias': 'some-alias',
|
||||
'team_id': 'test-org',
|
||||
'metadata': {'type': 'openhands'},
|
||||
}
|
||||
]
|
||||
|
||||
mock_client = AsyncMock(spec=httpx.AsyncClient)
|
||||
|
||||
with patch.object(
|
||||
LiteLlmManager, '_get_all_keys_for_user', new_callable=AsyncMock
|
||||
) as mock_get_keys:
|
||||
mock_get_keys.return_value = mock_keys
|
||||
|
||||
# Should not raise TypeError, should return False
|
||||
result = await LiteLlmManager._verify_existing_key(
|
||||
mock_client,
|
||||
'some-key-value',
|
||||
'test-user-id',
|
||||
'test-org',
|
||||
openhands_type=True,
|
||||
)
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_existing_key_handles_empty_key_name(self):
|
||||
"""Test _verify_existing_key handles empty key_name gracefully."""
|
||||
mock_keys = [
|
||||
{
|
||||
'key_name': '',
|
||||
'key_alias': 'some-alias',
|
||||
'team_id': 'test-org',
|
||||
'metadata': {'type': 'openhands'},
|
||||
}
|
||||
]
|
||||
|
||||
mock_client = AsyncMock(spec=httpx.AsyncClient)
|
||||
|
||||
with patch.object(
|
||||
LiteLlmManager, '_get_all_keys_for_user', new_callable=AsyncMock
|
||||
) as mock_get_keys:
|
||||
mock_get_keys.return_value = mock_keys
|
||||
|
||||
# Should not raise error, should return False
|
||||
result = await LiteLlmManager._verify_existing_key(
|
||||
mock_client,
|
||||
'some-key-value',
|
||||
'test-user-id',
|
||||
'test-org',
|
||||
openhands_type=True,
|
||||
)
|
||||
assert result is False
|
||||
|
||||
464
enterprise/tests/unit/test_org_invitation_service.py
Normal file
464
enterprise/tests/unit/test_org_invitation_service.py
Normal file
@@ -0,0 +1,464 @@
|
||||
"""Tests for organization invitation service - email validation."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from server.routes.org_invitation_models import (
|
||||
EmailMismatchError,
|
||||
)
|
||||
from server.services.org_invitation_service import OrgInvitationService
|
||||
from storage.org_invitation import OrgInvitation
|
||||
|
||||
|
||||
class TestAcceptInvitationEmailValidation:
|
||||
"""Test cases for email validation during invitation acceptance."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_invitation(self):
|
||||
"""Create a mock invitation with pending status."""
|
||||
invitation = MagicMock(spec=OrgInvitation)
|
||||
invitation.id = 1
|
||||
invitation.email = 'alice@example.com'
|
||||
invitation.status = OrgInvitation.STATUS_PENDING
|
||||
invitation.org_id = UUID('12345678-1234-5678-1234-567812345678')
|
||||
invitation.role_id = 1
|
||||
return invitation
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user(self):
|
||||
"""Create a mock user with email."""
|
||||
user = MagicMock()
|
||||
user.id = UUID('87654321-4321-8765-4321-876543218765')
|
||||
user.email = 'alice@example.com'
|
||||
return user
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_accept_invitation_email_matches(self, mock_invitation, mock_user):
|
||||
"""Test that invitation is accepted when user email matches invitation email."""
|
||||
# Arrange
|
||||
user_id = mock_user.id
|
||||
token = 'inv-test-token-12345'
|
||||
|
||||
with patch.object(
|
||||
OrgInvitationService, 'accept_invitation', new_callable=AsyncMock
|
||||
) as mock_accept:
|
||||
mock_accept.return_value = mock_invitation
|
||||
|
||||
# Act
|
||||
await OrgInvitationService.accept_invitation(token, user_id)
|
||||
|
||||
# Assert
|
||||
mock_accept.assert_called_once_with(token, user_id)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_accept_invitation_email_mismatch_raises_error(
|
||||
self, mock_invitation, mock_user
|
||||
):
|
||||
"""Test that EmailMismatchError is raised when emails don't match."""
|
||||
# Arrange
|
||||
user_id = mock_user.id
|
||||
token = 'inv-test-token-12345'
|
||||
mock_user.email = 'bob@example.com' # Different email
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.services.org_invitation_service.OrgInvitationStore.get_invitation_by_token',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_invitation,
|
||||
patch(
|
||||
'server.services.org_invitation_service.OrgInvitationStore.is_token_expired'
|
||||
) as mock_is_expired,
|
||||
patch(
|
||||
'server.services.org_invitation_service.UserStore.get_user_by_id_async',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_user,
|
||||
):
|
||||
mock_get_invitation.return_value = mock_invitation
|
||||
mock_is_expired.return_value = False
|
||||
mock_get_user.return_value = mock_user
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(EmailMismatchError):
|
||||
await OrgInvitationService.accept_invitation(token, user_id)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_accept_invitation_user_no_email_keycloak_fallback_matches(
|
||||
self, mock_invitation
|
||||
):
|
||||
"""Test that Keycloak email is used when user has no email in database."""
|
||||
# Arrange
|
||||
user_id = UUID('87654321-4321-8765-4321-876543218765')
|
||||
token = 'inv-test-token-12345'
|
||||
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = user_id
|
||||
mock_user.email = None # No email in database
|
||||
|
||||
mock_keycloak_user_info = {'email': 'alice@example.com'} # Email from Keycloak
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.services.org_invitation_service.OrgInvitationStore.get_invitation_by_token',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_invitation,
|
||||
patch(
|
||||
'server.services.org_invitation_service.OrgInvitationStore.is_token_expired'
|
||||
) as mock_is_expired,
|
||||
patch(
|
||||
'server.services.org_invitation_service.UserStore.get_user_by_id_async',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_user,
|
||||
patch(
|
||||
'server.services.org_invitation_service.TokenManager'
|
||||
) as mock_token_manager_class,
|
||||
patch(
|
||||
'server.services.org_invitation_service.OrgMemberStore.get_org_member'
|
||||
) as mock_get_member,
|
||||
patch(
|
||||
'server.services.org_invitation_service.OrgService.create_litellm_integration',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_create_litellm,
|
||||
patch(
|
||||
'server.services.org_invitation_service.OrgMemberStore.add_user_to_org'
|
||||
),
|
||||
patch(
|
||||
'server.services.org_invitation_service.OrgInvitationStore.update_invitation_status',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_update_status,
|
||||
):
|
||||
mock_get_invitation.return_value = mock_invitation
|
||||
mock_is_expired.return_value = False
|
||||
mock_get_user.return_value = mock_user
|
||||
|
||||
# Mock TokenManager instance
|
||||
mock_token_manager = MagicMock()
|
||||
mock_token_manager.get_user_info_from_user_id = AsyncMock(
|
||||
return_value=mock_keycloak_user_info
|
||||
)
|
||||
mock_token_manager_class.return_value = mock_token_manager
|
||||
|
||||
mock_get_member.return_value = None # Not already a member
|
||||
mock_create_litellm.return_value = MagicMock(llm_api_key='test-key')
|
||||
mock_update_status.return_value = mock_invitation
|
||||
|
||||
# Act - should not raise error because Keycloak email matches
|
||||
await OrgInvitationService.accept_invitation(token, user_id)
|
||||
|
||||
# Assert
|
||||
mock_token_manager.get_user_info_from_user_id.assert_called_once_with(
|
||||
str(user_id)
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_accept_invitation_no_email_anywhere_raises_error(
|
||||
self, mock_invitation
|
||||
):
|
||||
"""Test that EmailMismatchError is raised when user has no email in database or Keycloak."""
|
||||
# Arrange
|
||||
user_id = UUID('87654321-4321-8765-4321-876543218765')
|
||||
token = 'inv-test-token-12345'
|
||||
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = user_id
|
||||
mock_user.email = None # No email in database
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.services.org_invitation_service.OrgInvitationStore.get_invitation_by_token',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_invitation,
|
||||
patch(
|
||||
'server.services.org_invitation_service.OrgInvitationStore.is_token_expired'
|
||||
) as mock_is_expired,
|
||||
patch(
|
||||
'server.services.org_invitation_service.UserStore.get_user_by_id_async',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_user,
|
||||
patch(
|
||||
'server.services.org_invitation_service.TokenManager'
|
||||
) as mock_token_manager_class,
|
||||
):
|
||||
mock_get_invitation.return_value = mock_invitation
|
||||
mock_is_expired.return_value = False
|
||||
mock_get_user.return_value = mock_user
|
||||
|
||||
# Mock TokenManager to return no email
|
||||
mock_token_manager = MagicMock()
|
||||
mock_token_manager.get_user_info_from_user_id = AsyncMock(return_value={})
|
||||
mock_token_manager_class.return_value = mock_token_manager
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(EmailMismatchError) as exc_info:
|
||||
await OrgInvitationService.accept_invitation(token, user_id)
|
||||
|
||||
assert 'does not have an email address' in str(exc_info.value)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_accept_invitation_email_comparison_is_case_insensitive(
|
||||
self, mock_invitation
|
||||
):
|
||||
"""Test that email comparison is case insensitive."""
|
||||
# Arrange
|
||||
user_id = UUID('87654321-4321-8765-4321-876543218765')
|
||||
token = 'inv-test-token-12345'
|
||||
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = user_id
|
||||
mock_user.email = 'ALICE@EXAMPLE.COM' # Uppercase email
|
||||
|
||||
mock_invitation.email = 'alice@example.com' # Lowercase in invitation
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.services.org_invitation_service.OrgInvitationStore.get_invitation_by_token',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_invitation,
|
||||
patch(
|
||||
'server.services.org_invitation_service.OrgInvitationStore.is_token_expired'
|
||||
) as mock_is_expired,
|
||||
patch(
|
||||
'server.services.org_invitation_service.UserStore.get_user_by_id_async',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_user,
|
||||
patch(
|
||||
'server.services.org_invitation_service.OrgMemberStore.get_org_member'
|
||||
) as mock_get_member,
|
||||
patch(
|
||||
'server.services.org_invitation_service.OrgService.create_litellm_integration',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_create_litellm,
|
||||
patch(
|
||||
'server.services.org_invitation_service.OrgMemberStore.add_user_to_org'
|
||||
),
|
||||
patch(
|
||||
'server.services.org_invitation_service.OrgInvitationStore.update_invitation_status',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_update_status,
|
||||
):
|
||||
mock_get_invitation.return_value = mock_invitation
|
||||
mock_is_expired.return_value = False
|
||||
mock_get_user.return_value = mock_user
|
||||
mock_get_member.return_value = None
|
||||
mock_create_litellm.return_value = MagicMock(llm_api_key='test-key')
|
||||
mock_update_status.return_value = mock_invitation
|
||||
|
||||
# Act - should not raise error because emails match case-insensitively
|
||||
await OrgInvitationService.accept_invitation(token, user_id)
|
||||
|
||||
# Assert - invitation was accepted (update_invitation_status was called)
|
||||
mock_update_status.assert_called_once()
|
||||
|
||||
|
||||
class TestCreateInvitationsBatch:
|
||||
"""Test cases for batch invitation creation."""
|
||||
|
||||
@pytest.fixture
|
||||
def org_id(self):
|
||||
"""Organization UUID for testing."""
|
||||
return UUID('12345678-1234-5678-1234-567812345678')
|
||||
|
||||
@pytest.fixture
|
||||
def inviter_id(self):
|
||||
"""Inviter UUID for testing."""
|
||||
return UUID('87654321-4321-8765-4321-876543218765')
|
||||
|
||||
@pytest.fixture
|
||||
def mock_org(self):
|
||||
"""Create a mock organization."""
|
||||
org = MagicMock()
|
||||
org.id = UUID('12345678-1234-5678-1234-567812345678')
|
||||
org.name = 'Test Org'
|
||||
return org
|
||||
|
||||
@pytest.fixture
|
||||
def mock_inviter_member(self):
|
||||
"""Create a mock inviter member with owner role."""
|
||||
member = MagicMock()
|
||||
member.user_id = UUID('87654321-4321-8765-4321-876543218765')
|
||||
member.role_id = 1
|
||||
return member
|
||||
|
||||
@pytest.fixture
|
||||
def mock_owner_role(self):
|
||||
"""Create a mock owner role."""
|
||||
role = MagicMock()
|
||||
role.id = 1
|
||||
role.name = 'owner'
|
||||
return role
|
||||
|
||||
@pytest.fixture
|
||||
def mock_member_role(self):
|
||||
"""Create a mock member role."""
|
||||
role = MagicMock()
|
||||
role.id = 3
|
||||
role.name = 'member'
|
||||
return role
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_batch_creates_all_invitations_successfully(
|
||||
self,
|
||||
org_id,
|
||||
inviter_id,
|
||||
mock_org,
|
||||
mock_inviter_member,
|
||||
mock_owner_role,
|
||||
mock_member_role,
|
||||
):
|
||||
"""Test that batch creation succeeds for all valid emails."""
|
||||
# Arrange
|
||||
emails = ['alice@example.com', 'bob@example.com']
|
||||
mock_invitation_1 = MagicMock(spec=OrgInvitation)
|
||||
mock_invitation_1.id = 1
|
||||
mock_invitation_2 = MagicMock(spec=OrgInvitation)
|
||||
mock_invitation_2.id = 2
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.services.org_invitation_service.OrgStore.get_org_by_id',
|
||||
return_value=mock_org,
|
||||
),
|
||||
patch(
|
||||
'server.services.org_invitation_service.OrgMemberStore.get_org_member',
|
||||
return_value=mock_inviter_member,
|
||||
),
|
||||
patch(
|
||||
'server.services.org_invitation_service.RoleStore.get_role_by_id',
|
||||
return_value=mock_owner_role,
|
||||
),
|
||||
patch(
|
||||
'server.services.org_invitation_service.RoleStore.get_role_by_name',
|
||||
return_value=mock_member_role,
|
||||
),
|
||||
patch.object(
|
||||
OrgInvitationService,
|
||||
'create_invitation',
|
||||
new_callable=AsyncMock,
|
||||
side_effect=[mock_invitation_1, mock_invitation_2],
|
||||
),
|
||||
):
|
||||
# Act
|
||||
successful, failed = await OrgInvitationService.create_invitations_batch(
|
||||
org_id=org_id,
|
||||
emails=emails,
|
||||
role_name='member',
|
||||
inviter_id=inviter_id,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(successful) == 2
|
||||
assert len(failed) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_batch_handles_partial_success(
|
||||
self,
|
||||
org_id,
|
||||
inviter_id,
|
||||
mock_org,
|
||||
mock_inviter_member,
|
||||
mock_owner_role,
|
||||
mock_member_role,
|
||||
):
|
||||
"""Test that batch returns partial results when some emails fail."""
|
||||
# Arrange
|
||||
from server.routes.org_invitation_models import UserAlreadyMemberError
|
||||
|
||||
emails = ['alice@example.com', 'existing@example.com']
|
||||
mock_invitation = MagicMock(spec=OrgInvitation)
|
||||
mock_invitation.id = 1
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.services.org_invitation_service.OrgStore.get_org_by_id',
|
||||
return_value=mock_org,
|
||||
),
|
||||
patch(
|
||||
'server.services.org_invitation_service.OrgMemberStore.get_org_member',
|
||||
return_value=mock_inviter_member,
|
||||
),
|
||||
patch(
|
||||
'server.services.org_invitation_service.RoleStore.get_role_by_id',
|
||||
return_value=mock_owner_role,
|
||||
),
|
||||
patch(
|
||||
'server.services.org_invitation_service.RoleStore.get_role_by_name',
|
||||
return_value=mock_member_role,
|
||||
),
|
||||
patch.object(
|
||||
OrgInvitationService,
|
||||
'create_invitation',
|
||||
new_callable=AsyncMock,
|
||||
side_effect=[mock_invitation, UserAlreadyMemberError()],
|
||||
),
|
||||
):
|
||||
# Act
|
||||
successful, failed = await OrgInvitationService.create_invitations_batch(
|
||||
org_id=org_id,
|
||||
emails=emails,
|
||||
role_name='member',
|
||||
inviter_id=inviter_id,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(successful) == 1
|
||||
assert len(failed) == 1
|
||||
assert failed[0][0] == 'existing@example.com'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_batch_fails_entirely_on_permission_error(self, org_id, inviter_id):
|
||||
"""Test that permission error fails the entire batch upfront."""
|
||||
# Arrange
|
||||
|
||||
emails = ['alice@example.com', 'bob@example.com']
|
||||
|
||||
with patch(
|
||||
'server.services.org_invitation_service.OrgStore.get_org_by_id',
|
||||
return_value=None, # Organization not found
|
||||
):
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await OrgInvitationService.create_invitations_batch(
|
||||
org_id=org_id,
|
||||
emails=emails,
|
||||
role_name='member',
|
||||
inviter_id=inviter_id,
|
||||
)
|
||||
|
||||
assert 'not found' in str(exc_info.value)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_batch_fails_on_invalid_role(
|
||||
self, org_id, inviter_id, mock_org, mock_inviter_member, mock_owner_role
|
||||
):
|
||||
"""Test that invalid role fails the entire batch."""
|
||||
# Arrange
|
||||
emails = ['alice@example.com']
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.services.org_invitation_service.OrgStore.get_org_by_id',
|
||||
return_value=mock_org,
|
||||
),
|
||||
patch(
|
||||
'server.services.org_invitation_service.OrgMemberStore.get_org_member',
|
||||
return_value=mock_inviter_member,
|
||||
),
|
||||
patch(
|
||||
'server.services.org_invitation_service.RoleStore.get_role_by_id',
|
||||
return_value=mock_owner_role,
|
||||
),
|
||||
patch(
|
||||
'server.services.org_invitation_service.RoleStore.get_role_by_name',
|
||||
return_value=None, # Invalid role
|
||||
),
|
||||
):
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await OrgInvitationService.create_invitations_batch(
|
||||
org_id=org_id,
|
||||
emails=emails,
|
||||
role_name='invalid_role',
|
||||
inviter_id=inviter_id,
|
||||
)
|
||||
|
||||
assert 'Invalid role' in str(exc_info.value)
|
||||
308
enterprise/tests/unit/test_org_invitation_store.py
Normal file
308
enterprise/tests/unit/test_org_invitation_store.py
Normal file
@@ -0,0 +1,308 @@
|
||||
"""Tests for organization invitation store."""
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from storage.org_invitation import OrgInvitation
|
||||
from storage.org_invitation_store import (
|
||||
INVITATION_TOKEN_LENGTH,
|
||||
INVITATION_TOKEN_PREFIX,
|
||||
OrgInvitationStore,
|
||||
)
|
||||
|
||||
|
||||
class TestGenerateToken:
|
||||
"""Test cases for token generation."""
|
||||
|
||||
def test_generate_token_has_correct_prefix(self):
|
||||
"""Test that generated tokens have the correct prefix."""
|
||||
token = OrgInvitationStore.generate_token()
|
||||
assert token.startswith(INVITATION_TOKEN_PREFIX)
|
||||
|
||||
def test_generate_token_has_correct_length(self):
|
||||
"""Test that generated tokens have the correct total length."""
|
||||
token = OrgInvitationStore.generate_token()
|
||||
expected_length = len(INVITATION_TOKEN_PREFIX) + INVITATION_TOKEN_LENGTH
|
||||
assert len(token) == expected_length
|
||||
|
||||
def test_generate_token_uses_alphanumeric_characters(self):
|
||||
"""Test that generated tokens use only alphanumeric characters."""
|
||||
token = OrgInvitationStore.generate_token()
|
||||
# Remove prefix and check the rest is alphanumeric
|
||||
random_part = token[len(INVITATION_TOKEN_PREFIX) :]
|
||||
assert random_part.isalnum()
|
||||
|
||||
def test_generate_token_is_unique(self):
|
||||
"""Test that generated tokens are unique (probabilistically)."""
|
||||
tokens = [OrgInvitationStore.generate_token() for _ in range(100)]
|
||||
assert len(set(tokens)) == 100
|
||||
|
||||
|
||||
class TestIsTokenExpired:
|
||||
"""Test cases for token expiration checking."""
|
||||
|
||||
def test_token_not_expired_when_future(self):
|
||||
"""Test that tokens with future expiration are not expired."""
|
||||
invitation = MagicMock(spec=OrgInvitation)
|
||||
invitation.expires_at = datetime.utcnow() + timedelta(days=1)
|
||||
|
||||
result = OrgInvitationStore.is_token_expired(invitation)
|
||||
assert result is False
|
||||
|
||||
def test_token_expired_when_past(self):
|
||||
"""Test that tokens with past expiration are expired."""
|
||||
invitation = MagicMock(spec=OrgInvitation)
|
||||
invitation.expires_at = datetime.utcnow() - timedelta(seconds=1)
|
||||
|
||||
result = OrgInvitationStore.is_token_expired(invitation)
|
||||
assert result is True
|
||||
|
||||
def test_token_expired_at_exact_boundary(self):
|
||||
"""Test that tokens at exact expiration time are expired."""
|
||||
# A token that expires "now" should be expired
|
||||
now = datetime.utcnow()
|
||||
invitation = MagicMock(spec=OrgInvitation)
|
||||
invitation.expires_at = now - timedelta(microseconds=1)
|
||||
|
||||
result = OrgInvitationStore.is_token_expired(invitation)
|
||||
assert result is True
|
||||
|
||||
|
||||
class TestCreateInvitation:
|
||||
"""Test cases for invitation creation."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_invitation_normalizes_email(self):
|
||||
"""Test that email is normalized (lowercase, stripped) on creation."""
|
||||
mock_session = AsyncMock()
|
||||
mock_session.add = MagicMock()
|
||||
mock_session.commit = AsyncMock()
|
||||
mock_session.execute = AsyncMock()
|
||||
|
||||
# Mock the result of the re-fetch query
|
||||
mock_result = MagicMock()
|
||||
mock_invitation = MagicMock()
|
||||
mock_invitation.id = 1
|
||||
mock_invitation.email = 'test@example.com'
|
||||
mock_result.scalars.return_value.first.return_value = mock_invitation
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
with patch(
|
||||
'storage.org_invitation_store.a_session_maker'
|
||||
) as mock_session_maker:
|
||||
mock_session_manager = AsyncMock()
|
||||
mock_session_manager.__aenter__.return_value = mock_session
|
||||
mock_session_manager.__aexit__.return_value = None
|
||||
mock_session_maker.return_value = mock_session_manager
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
await OrgInvitationStore.create_invitation(
|
||||
org_id=UUID('12345678-1234-5678-1234-567812345678'),
|
||||
email=' TEST@EXAMPLE.COM ',
|
||||
role_id=1,
|
||||
inviter_id=UUID('87654321-4321-8765-4321-876543218765'),
|
||||
)
|
||||
|
||||
# Verify that the OrgInvitation was created with normalized email
|
||||
add_call = mock_session.add.call_args
|
||||
created_invitation = add_call[0][0]
|
||||
assert created_invitation.email == 'test@example.com'
|
||||
|
||||
|
||||
class TestGetInvitationByToken:
|
||||
"""Test cases for getting invitation by token."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_invitation_by_token_returns_invitation(self):
|
||||
"""Test that get_invitation_by_token returns the invitation when found."""
|
||||
mock_invitation = MagicMock(spec=OrgInvitation)
|
||||
mock_invitation.token = 'inv-test-token-12345'
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.first.return_value = mock_invitation
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
|
||||
with patch(
|
||||
'storage.org_invitation_store.a_session_maker'
|
||||
) as mock_session_maker:
|
||||
mock_session_manager = AsyncMock()
|
||||
mock_session_manager.__aenter__.return_value = mock_session
|
||||
mock_session_manager.__aexit__.return_value = None
|
||||
mock_session_maker.return_value = mock_session_manager
|
||||
|
||||
result = await OrgInvitationStore.get_invitation_by_token(
|
||||
'inv-test-token-12345'
|
||||
)
|
||||
assert result == mock_invitation
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_invitation_by_token_returns_none_when_not_found(self):
|
||||
"""Test that get_invitation_by_token returns None when not found."""
|
||||
mock_session = AsyncMock()
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.first.return_value = None
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
|
||||
with patch(
|
||||
'storage.org_invitation_store.a_session_maker'
|
||||
) as mock_session_maker:
|
||||
mock_session_manager = AsyncMock()
|
||||
mock_session_manager.__aenter__.return_value = mock_session
|
||||
mock_session_manager.__aexit__.return_value = None
|
||||
mock_session_maker.return_value = mock_session_manager
|
||||
|
||||
result = await OrgInvitationStore.get_invitation_by_token(
|
||||
'inv-nonexistent-token'
|
||||
)
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestGetPendingInvitation:
|
||||
"""Test cases for getting pending invitation."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_pending_invitation_normalizes_email(self):
|
||||
"""Test that email is normalized when querying for pending invitations."""
|
||||
mock_session = AsyncMock()
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.first.return_value = None
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
|
||||
with patch(
|
||||
'storage.org_invitation_store.a_session_maker'
|
||||
) as mock_session_maker:
|
||||
mock_session_manager = AsyncMock()
|
||||
mock_session_manager.__aenter__.return_value = mock_session
|
||||
mock_session_manager.__aexit__.return_value = None
|
||||
mock_session_maker.return_value = mock_session_manager
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
await OrgInvitationStore.get_pending_invitation(
|
||||
org_id=UUID('12345678-1234-5678-1234-567812345678'),
|
||||
email=' TEST@EXAMPLE.COM ',
|
||||
)
|
||||
|
||||
# Verify the query was called (email normalization happens in the filter)
|
||||
assert mock_session.execute.called
|
||||
|
||||
|
||||
class TestUpdateInvitationStatus:
|
||||
"""Test cases for updating invitation status."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_status_sets_accepted_at_for_accepted(self):
|
||||
"""Test that accepted_at is set when status is accepted."""
|
||||
from uuid import UUID
|
||||
|
||||
mock_invitation = MagicMock(spec=OrgInvitation)
|
||||
mock_invitation.id = 1
|
||||
mock_invitation.status = OrgInvitation.STATUS_PENDING
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.first.return_value = mock_invitation
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
mock_session.commit = AsyncMock()
|
||||
mock_session.refresh = AsyncMock()
|
||||
|
||||
with patch(
|
||||
'storage.org_invitation_store.a_session_maker'
|
||||
) as mock_session_maker:
|
||||
mock_session_manager = AsyncMock()
|
||||
mock_session_manager.__aenter__.return_value = mock_session
|
||||
mock_session_manager.__aexit__.return_value = None
|
||||
mock_session_maker.return_value = mock_session_manager
|
||||
|
||||
user_id = UUID('87654321-4321-8765-4321-876543218765')
|
||||
await OrgInvitationStore.update_invitation_status(
|
||||
invitation_id=1,
|
||||
status=OrgInvitation.STATUS_ACCEPTED,
|
||||
accepted_by_user_id=user_id,
|
||||
)
|
||||
|
||||
assert mock_invitation.accepted_at is not None
|
||||
assert mock_invitation.accepted_by_user_id == user_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_status_returns_none_when_not_found(self):
|
||||
"""Test that update returns None when invitation not found."""
|
||||
mock_session = AsyncMock()
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.first.return_value = None
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
|
||||
with patch(
|
||||
'storage.org_invitation_store.a_session_maker'
|
||||
) as mock_session_maker:
|
||||
mock_session_manager = AsyncMock()
|
||||
mock_session_manager.__aenter__.return_value = mock_session
|
||||
mock_session_manager.__aexit__.return_value = None
|
||||
mock_session_maker.return_value = mock_session_manager
|
||||
|
||||
result = await OrgInvitationStore.update_invitation_status(
|
||||
invitation_id=999,
|
||||
status=OrgInvitation.STATUS_ACCEPTED,
|
||||
)
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestMarkExpiredIfNeeded:
|
||||
"""Test cases for marking expired invitations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_marks_expired_when_pending_and_past_expiry(self):
|
||||
"""Test that pending expired invitations are marked as expired."""
|
||||
mock_invitation = MagicMock(spec=OrgInvitation)
|
||||
mock_invitation.id = 1
|
||||
mock_invitation.status = OrgInvitation.STATUS_PENDING
|
||||
mock_invitation.expires_at = datetime.utcnow() - timedelta(days=1)
|
||||
|
||||
with patch.object(
|
||||
OrgInvitationStore,
|
||||
'update_invitation_status',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_update:
|
||||
result = await OrgInvitationStore.mark_expired_if_needed(mock_invitation)
|
||||
|
||||
assert result is True
|
||||
mock_update.assert_called_once_with(1, OrgInvitation.STATUS_EXPIRED)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_does_not_mark_when_not_expired(self):
|
||||
"""Test that non-expired invitations are not marked."""
|
||||
mock_invitation = MagicMock(spec=OrgInvitation)
|
||||
mock_invitation.id = 1
|
||||
mock_invitation.status = OrgInvitation.STATUS_PENDING
|
||||
mock_invitation.expires_at = datetime.utcnow() + timedelta(days=1)
|
||||
|
||||
with patch.object(
|
||||
OrgInvitationStore,
|
||||
'update_invitation_status',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_update:
|
||||
result = await OrgInvitationStore.mark_expired_if_needed(mock_invitation)
|
||||
|
||||
assert result is False
|
||||
mock_update.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_does_not_mark_when_not_pending(self):
|
||||
"""Test that non-pending invitations are not marked even if expired."""
|
||||
mock_invitation = MagicMock(spec=OrgInvitation)
|
||||
mock_invitation.id = 1
|
||||
mock_invitation.status = OrgInvitation.STATUS_ACCEPTED
|
||||
mock_invitation.expires_at = datetime.utcnow() - timedelta(days=1)
|
||||
|
||||
with patch.object(
|
||||
OrgInvitationStore,
|
||||
'update_invitation_status',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_update:
|
||||
result = await OrgInvitationStore.mark_expired_if_needed(mock_invitation)
|
||||
|
||||
assert result is False
|
||||
mock_update.assert_not_called()
|
||||
388
enterprise/tests/unit/test_org_invitations_router.py
Normal file
388
enterprise/tests/unit/test_org_invitations_router.py
Normal file
@@ -0,0 +1,388 @@
|
||||
"""Tests for organization invitations API router."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
from server.routes.org_invitation_models import (
|
||||
EmailMismatchError,
|
||||
InvitationExpiredError,
|
||||
InvitationInvalidError,
|
||||
UserAlreadyMemberError,
|
||||
)
|
||||
from server.routes.org_invitations import accept_router, invitation_router
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
"""Create a FastAPI app with the invitation routers."""
|
||||
app = FastAPI()
|
||||
app.include_router(invitation_router)
|
||||
app.include_router(accept_router)
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(app):
|
||||
"""Create a test client for the app."""
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
class TestRouterPrefixes:
|
||||
"""Test that router prefixes are configured correctly."""
|
||||
|
||||
def test_invitation_router_has_correct_prefix(self):
|
||||
"""Test that invitation_router has /api/organizations/{org_id}/members prefix."""
|
||||
assert invitation_router.prefix == '/api/organizations/{org_id}/members'
|
||||
|
||||
def test_accept_router_has_correct_prefix(self):
|
||||
"""Test that accept_router has /api/organizations/members/invite prefix."""
|
||||
assert accept_router.prefix == '/api/organizations/members/invite'
|
||||
|
||||
|
||||
class TestAcceptInvitationEndpoint:
|
||||
"""Test cases for the accept invitation endpoint."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user_auth(self):
|
||||
"""Create a mock user auth."""
|
||||
user_auth = MagicMock()
|
||||
user_auth.get_user_id = AsyncMock(
|
||||
return_value='87654321-4321-8765-4321-876543218765'
|
||||
)
|
||||
return user_auth
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_accept_unauthenticated_redirects_to_login(self, client):
|
||||
"""Test that unauthenticated users are redirected to login with invitation token."""
|
||||
with patch(
|
||||
'server.routes.org_invitations.get_user_auth',
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
):
|
||||
response = client.get(
|
||||
'/api/organizations/members/invite/accept?token=inv-test-token-123',
|
||||
follow_redirects=False,
|
||||
)
|
||||
|
||||
assert response.status_code == 302
|
||||
assert '/login?invitation_token=inv-test-token-123' in response.headers.get(
|
||||
'location', ''
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_accept_authenticated_success_redirects_home(
|
||||
self, client, mock_user_auth
|
||||
):
|
||||
"""Test that successful acceptance redirects to home page."""
|
||||
mock_invitation = MagicMock()
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.routes.org_invitations.get_user_auth',
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_user_auth,
|
||||
),
|
||||
patch(
|
||||
'server.routes.org_invitations.OrgInvitationService.accept_invitation',
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_invitation,
|
||||
),
|
||||
):
|
||||
response = client.get(
|
||||
'/api/organizations/members/invite/accept?token=inv-test-token-123',
|
||||
follow_redirects=False,
|
||||
)
|
||||
|
||||
assert response.status_code == 302
|
||||
location = response.headers.get('location', '')
|
||||
assert location.endswith('/')
|
||||
assert 'invitation_expired' not in location
|
||||
assert 'invitation_invalid' not in location
|
||||
assert 'email_mismatch' not in location
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_accept_expired_invitation_redirects_with_flag(
|
||||
self, client, mock_user_auth
|
||||
):
|
||||
"""Test that expired invitation redirects with invitation_expired=true."""
|
||||
with (
|
||||
patch(
|
||||
'server.routes.org_invitations.get_user_auth',
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_user_auth,
|
||||
),
|
||||
patch(
|
||||
'server.routes.org_invitations.OrgInvitationService.accept_invitation',
|
||||
new_callable=AsyncMock,
|
||||
side_effect=InvitationExpiredError(),
|
||||
),
|
||||
):
|
||||
response = client.get(
|
||||
'/api/organizations/members/invite/accept?token=inv-test-token-123',
|
||||
follow_redirects=False,
|
||||
)
|
||||
|
||||
assert response.status_code == 302
|
||||
assert 'invitation_expired=true' in response.headers.get('location', '')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_accept_invalid_invitation_redirects_with_flag(
|
||||
self, client, mock_user_auth
|
||||
):
|
||||
"""Test that invalid invitation redirects with invitation_invalid=true."""
|
||||
with (
|
||||
patch(
|
||||
'server.routes.org_invitations.get_user_auth',
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_user_auth,
|
||||
),
|
||||
patch(
|
||||
'server.routes.org_invitations.OrgInvitationService.accept_invitation',
|
||||
new_callable=AsyncMock,
|
||||
side_effect=InvitationInvalidError(),
|
||||
),
|
||||
):
|
||||
response = client.get(
|
||||
'/api/organizations/members/invite/accept?token=inv-test-token-123',
|
||||
follow_redirects=False,
|
||||
)
|
||||
|
||||
assert response.status_code == 302
|
||||
assert 'invitation_invalid=true' in response.headers.get('location', '')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_accept_already_member_redirects_with_flag(
|
||||
self, client, mock_user_auth
|
||||
):
|
||||
"""Test that already member error redirects with already_member=true."""
|
||||
with (
|
||||
patch(
|
||||
'server.routes.org_invitations.get_user_auth',
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_user_auth,
|
||||
),
|
||||
patch(
|
||||
'server.routes.org_invitations.OrgInvitationService.accept_invitation',
|
||||
new_callable=AsyncMock,
|
||||
side_effect=UserAlreadyMemberError(),
|
||||
),
|
||||
):
|
||||
response = client.get(
|
||||
'/api/organizations/members/invite/accept?token=inv-test-token-123',
|
||||
follow_redirects=False,
|
||||
)
|
||||
|
||||
assert response.status_code == 302
|
||||
assert 'already_member=true' in response.headers.get('location', '')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_accept_email_mismatch_redirects_with_flag(
|
||||
self, client, mock_user_auth
|
||||
):
|
||||
"""Test that email mismatch error redirects with email_mismatch=true."""
|
||||
with (
|
||||
patch(
|
||||
'server.routes.org_invitations.get_user_auth',
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_user_auth,
|
||||
),
|
||||
patch(
|
||||
'server.routes.org_invitations.OrgInvitationService.accept_invitation',
|
||||
new_callable=AsyncMock,
|
||||
side_effect=EmailMismatchError(),
|
||||
),
|
||||
):
|
||||
response = client.get(
|
||||
'/api/organizations/members/invite/accept?token=inv-test-token-123',
|
||||
follow_redirects=False,
|
||||
)
|
||||
|
||||
assert response.status_code == 302
|
||||
assert 'email_mismatch=true' in response.headers.get('location', '')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_accept_unexpected_error_redirects_with_flag(
|
||||
self, client, mock_user_auth
|
||||
):
|
||||
"""Test that unexpected errors redirect with invitation_error=true."""
|
||||
with (
|
||||
patch(
|
||||
'server.routes.org_invitations.get_user_auth',
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_user_auth,
|
||||
),
|
||||
patch(
|
||||
'server.routes.org_invitations.OrgInvitationService.accept_invitation',
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Exception('Unexpected error'),
|
||||
),
|
||||
):
|
||||
response = client.get(
|
||||
'/api/organizations/members/invite/accept?token=inv-test-token-123',
|
||||
follow_redirects=False,
|
||||
)
|
||||
|
||||
assert response.status_code == 302
|
||||
assert 'invitation_error=true' in response.headers.get('location', '')
|
||||
|
||||
|
||||
class TestCreateInvitationBatchEndpoint:
|
||||
"""Test cases for the batch invitation creation endpoint."""
|
||||
|
||||
@pytest.fixture
|
||||
def batch_app(self):
|
||||
"""Create a FastAPI app with dependency overrides for batch tests."""
|
||||
from openhands.server.user_auth import get_user_id
|
||||
|
||||
app = FastAPI()
|
||||
app.include_router(invitation_router)
|
||||
|
||||
# Override the get_user_id dependency
|
||||
app.dependency_overrides[get_user_id] = (
|
||||
lambda: '87654321-4321-8765-4321-876543218765'
|
||||
)
|
||||
|
||||
return app
|
||||
|
||||
@pytest.fixture
|
||||
def batch_client(self, batch_app):
|
||||
"""Create a test client with dependency overrides."""
|
||||
return TestClient(batch_app)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_invitation(self):
|
||||
"""Create a mock invitation."""
|
||||
from datetime import datetime
|
||||
|
||||
invitation = MagicMock()
|
||||
invitation.id = 1
|
||||
invitation.email = 'alice@example.com'
|
||||
invitation.role = MagicMock(name='member')
|
||||
invitation.role.name = 'member'
|
||||
invitation.role_id = 3
|
||||
invitation.status = 'pending'
|
||||
invitation.created_at = datetime(2026, 2, 17, 10, 0, 0)
|
||||
invitation.expires_at = datetime(2026, 2, 24, 10, 0, 0)
|
||||
return invitation
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_batch_create_returns_successful_invitations(
|
||||
self, batch_client, mock_invitation
|
||||
):
|
||||
"""Test that batch creation returns successful invitations."""
|
||||
mock_invitation_2 = MagicMock()
|
||||
mock_invitation_2.id = 2
|
||||
mock_invitation_2.email = 'bob@example.com'
|
||||
mock_invitation_2.role = MagicMock()
|
||||
mock_invitation_2.role.name = 'member'
|
||||
mock_invitation_2.role_id = 3
|
||||
mock_invitation_2.status = 'pending'
|
||||
mock_invitation_2.created_at = mock_invitation.created_at
|
||||
mock_invitation_2.expires_at = mock_invitation.expires_at
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.routes.org_invitations.check_rate_limit_by_user_id',
|
||||
new_callable=AsyncMock,
|
||||
),
|
||||
patch(
|
||||
'server.routes.org_invitations.OrgInvitationService.create_invitations_batch',
|
||||
new_callable=AsyncMock,
|
||||
return_value=([mock_invitation, mock_invitation_2], []),
|
||||
),
|
||||
):
|
||||
response = batch_client.post(
|
||||
'/api/organizations/12345678-1234-5678-1234-567812345678/members/invite',
|
||||
json={
|
||||
'emails': ['alice@example.com', 'bob@example.com'],
|
||||
'role': 'member',
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 201
|
||||
data = response.json()
|
||||
assert len(data['successful']) == 2
|
||||
assert len(data['failed']) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_batch_create_returns_partial_success(
|
||||
self, batch_client, mock_invitation
|
||||
):
|
||||
"""Test that batch creation returns both successful and failed invitations."""
|
||||
failed_emails = [('existing@example.com', 'User is already a member')]
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.routes.org_invitations.check_rate_limit_by_user_id',
|
||||
new_callable=AsyncMock,
|
||||
),
|
||||
patch(
|
||||
'server.routes.org_invitations.OrgInvitationService.create_invitations_batch',
|
||||
new_callable=AsyncMock,
|
||||
return_value=([mock_invitation], failed_emails),
|
||||
),
|
||||
):
|
||||
response = batch_client.post(
|
||||
'/api/organizations/12345678-1234-5678-1234-567812345678/members/invite',
|
||||
json={
|
||||
'emails': ['alice@example.com', 'existing@example.com'],
|
||||
'role': 'member',
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 201
|
||||
data = response.json()
|
||||
assert len(data['successful']) == 1
|
||||
assert len(data['failed']) == 1
|
||||
assert data['failed'][0]['email'] == 'existing@example.com'
|
||||
assert 'already a member' in data['failed'][0]['error']
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_batch_create_permission_denied_returns_403(self, batch_client):
|
||||
"""Test that permission denied returns 403 for entire batch."""
|
||||
from server.routes.org_invitation_models import InsufficientPermissionError
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.routes.org_invitations.check_rate_limit_by_user_id',
|
||||
new_callable=AsyncMock,
|
||||
),
|
||||
patch(
|
||||
'server.routes.org_invitations.OrgInvitationService.create_invitations_batch',
|
||||
new_callable=AsyncMock,
|
||||
side_effect=InsufficientPermissionError(
|
||||
'Only owners and admins can invite'
|
||||
),
|
||||
),
|
||||
):
|
||||
response = batch_client.post(
|
||||
'/api/organizations/12345678-1234-5678-1234-567812345678/members/invite',
|
||||
json={'emails': ['alice@example.com'], 'role': 'member'},
|
||||
)
|
||||
|
||||
assert response.status_code == 403
|
||||
assert 'owners and admins' in response.json()['detail']
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_batch_create_invalid_role_returns_400(self, batch_client):
|
||||
"""Test that invalid role returns 400."""
|
||||
with (
|
||||
patch(
|
||||
'server.routes.org_invitations.check_rate_limit_by_user_id',
|
||||
new_callable=AsyncMock,
|
||||
),
|
||||
patch(
|
||||
'server.routes.org_invitations.OrgInvitationService.create_invitations_batch',
|
||||
new_callable=AsyncMock,
|
||||
side_effect=ValueError('Invalid role: superuser'),
|
||||
),
|
||||
):
|
||||
response = batch_client.post(
|
||||
'/api/organizations/12345678-1234-5678-1234-567812345678/members/invite',
|
||||
json={'emails': ['alice@example.com'], 'role': 'superuser'},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert 'Invalid role' in response.json()['detail']
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user