mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-04-29 03:00:45 -04:00
Compare commits
260 Commits
1.2.1
...
optimize-d
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9e9a0bbe87 | ||
|
|
c9af5edad9 | ||
|
|
0c7ce4ad48 | ||
|
|
4dab34e7b0 | ||
|
|
f8bbd352a9 | ||
|
|
17347a95f8 | ||
|
|
01ef87aaaa | ||
|
|
8059c18b57 | ||
|
|
c82ee4c7db | ||
|
|
7fdb423f99 | ||
|
|
530065dfa7 | ||
|
|
a4cd2d81a5 | ||
|
|
003b430e96 | ||
|
|
d63565186e | ||
|
|
5f42d03ec5 | ||
|
|
62241e2e00 | ||
|
|
f5197bd76a | ||
|
|
e1408f7b15 | ||
|
|
d6b8d80026 | ||
|
|
1e6a92b454 | ||
|
|
b4a3e5db2f | ||
|
|
f9d553d0bb | ||
|
|
f6f6c1ab25 | ||
|
|
c511a89426 | ||
|
|
1f82ff04d9 | ||
|
|
eec17311c7 | ||
|
|
c34fdf4b37 | ||
|
|
25076ee44c | ||
|
|
baaec8473a | ||
|
|
402fa47422 | ||
|
|
8dde385843 | ||
|
|
a905e35531 | ||
|
|
1f185173b7 | ||
|
|
ddc7a78723 | ||
|
|
a29ed4d926 | ||
|
|
b8ab4bb44e | ||
|
|
ddd544f8d6 | ||
|
|
3804b66e32 | ||
|
|
b97adf392a | ||
|
|
dcb584913a | ||
|
|
d2fd54a083 | ||
|
|
112d863287 | ||
|
|
c8680caec3 | ||
|
|
d4b9fb1d03 | ||
|
|
409df1287d | ||
|
|
a92bfe6cc0 | ||
|
|
f93e3254d3 | ||
|
|
0476d57451 | ||
|
|
a4cd21e155 | ||
|
|
7f3af371d1 | ||
|
|
1421794c1b | ||
|
|
2fc689457c | ||
|
|
3161b365a8 | ||
|
|
18ab56ef4e | ||
|
|
a9c0df778c | ||
|
|
51b989b5f8 | ||
|
|
dc039d81d6 | ||
|
|
8e4559b14a | ||
|
|
b84f352b63 | ||
|
|
a0dba6124a | ||
|
|
951739f3eb | ||
|
|
0f1ad46a47 | ||
|
|
5367bef43a | ||
|
|
3afeccfe7f | ||
|
|
0677c035ff | ||
|
|
68165b52d9 | ||
|
|
dcc8217317 | ||
|
|
d1410949ff | ||
|
|
a6c0d80fe1 | ||
|
|
0efb1db85d | ||
|
|
8e0f74c92c | ||
|
|
6e1ba3d836 | ||
|
|
0ec97893d1 | ||
|
|
ddb809bc43 | ||
|
|
872f2b87f2 | ||
|
|
ee86005a3a | ||
|
|
d4aa30580b | ||
|
|
2f0e879129 | ||
|
|
3bc2ef954e | ||
|
|
32ab2a24c6 | ||
|
|
a6e148d1e6 | ||
|
|
3fc977eddd | ||
|
|
89a6890269 | ||
|
|
8927ac2230 | ||
|
|
f3429e33ca | ||
|
|
7cd219792b | ||
|
|
2aabe2ed8c | ||
|
|
731a9a813e | ||
|
|
123e556fed | ||
|
|
6676cae249 | ||
|
|
fede37b496 | ||
|
|
3bcd6f18df | ||
|
|
0da18440c2 | ||
|
|
ac76e10048 | ||
|
|
b98bae8b5f | ||
|
|
516721d1ee | ||
|
|
4d6f66ca28 | ||
|
|
b18568da0b | ||
|
|
83dd3c169c | ||
|
|
35bddb14f1 | ||
|
|
e8425218e2 | ||
|
|
0a879fa781 | ||
|
|
41e142bbab | ||
|
|
b06b9eedac | ||
|
|
a9afafa991 | ||
|
|
663ace4b39 | ||
|
|
2d085a6e0a | ||
|
|
8b7112abe8 | ||
|
|
34547ba947 | ||
|
|
5f958ab60d | ||
|
|
d7656bf1c9 | ||
|
|
2bc107564c | ||
|
|
85eb1e1504 | ||
|
|
cd235cc8c7 | ||
|
|
40f52dfabc | ||
|
|
bab7bf85e8 | ||
|
|
c856537f65 | ||
|
|
736f5b2255 | ||
|
|
c1d9d11772 | ||
|
|
85244499fe | ||
|
|
c55084e223 | ||
|
|
e3bb75deb4 | ||
|
|
1948200762 | ||
|
|
affe0af361 | ||
|
|
f20c956196 | ||
|
|
4a089a3a0d | ||
|
|
aa0b2d0b74 | ||
|
|
bef9b80b9d | ||
|
|
c4a90b1f89 | ||
|
|
0d13c57d9f | ||
|
|
b3422f1275 | ||
|
|
f139a9970b | ||
|
|
54d156122c | ||
|
|
ac072bf686 | ||
|
|
a53812c029 | ||
|
|
1d1c0925b5 | ||
|
|
872f41e3c0 | ||
|
|
d43ff82534 | ||
|
|
8cd8c011b2 | ||
|
|
5c68b10983 | ||
|
|
a97fad1976 | ||
|
|
4c3542a91c | ||
|
|
f460057f58 | ||
|
|
4fa2ad0f47 | ||
|
|
dd8be12809 | ||
|
|
89475095d9 | ||
|
|
05d5f8848a | ||
|
|
ee2885eb0b | ||
|
|
545257f870 | ||
|
|
b23ab33a01 | ||
|
|
a9ede73391 | ||
|
|
634c2439b4 | ||
|
|
a1989a40b3 | ||
|
|
e38f1283ea | ||
|
|
07eb791735 | ||
|
|
c355c4819f | ||
|
|
9d8e4c44cc | ||
|
|
25cc55e558 | ||
|
|
0e825c38d7 | ||
|
|
ce04e70b5b | ||
|
|
7b0589ad40 | ||
|
|
682465a862 | ||
|
|
1bb4c844d4 | ||
|
|
d6c11fe517 | ||
|
|
b088d4857e | ||
|
|
0f05898d55 | ||
|
|
d1f0a01a57 | ||
|
|
f5a9d28999 | ||
|
|
afa0417608 | ||
|
|
e688fba761 | ||
|
|
d1ec5cbdf6 | ||
|
|
f42625f789 | ||
|
|
fe28519677 | ||
|
|
e62ceafa4a | ||
|
|
0b8c69fad2 | ||
|
|
37d9b672a4 | ||
|
|
c8b867a634 | ||
|
|
59834beba7 | ||
|
|
d2eced9cff | ||
|
|
7836136ff8 | ||
|
|
fdb04dfe5d | ||
|
|
3d4cb89441 | ||
|
|
9fb9efd3d2 | ||
|
|
5511c01c2e | ||
|
|
02825fb5bb | ||
|
|
876e773589 | ||
|
|
9e1ae86191 | ||
|
|
df47b7b79d | ||
|
|
7d1c105b55 | ||
|
|
db6a9e8895 | ||
|
|
d76ac44dc3 | ||
|
|
c483c80a3c | ||
|
|
570ab904f6 | ||
|
|
00a74731ae | ||
|
|
102095affb | ||
|
|
b6ce45b474 | ||
|
|
11c87caba4 | ||
|
|
b8a608c45e | ||
|
|
8a446787be | ||
|
|
353124e171 | ||
|
|
e9298c89bd | ||
|
|
29b77be807 | ||
|
|
7094835ef0 | ||
|
|
7ad0eec325 | ||
|
|
31d5081163 | ||
|
|
250736cb7a | ||
|
|
a9bd3a70c9 | ||
|
|
d7436a4af4 | ||
|
|
f327e76be7 | ||
|
|
52e39e5d12 | ||
|
|
6c5ef256fd | ||
|
|
373ade8554 | ||
|
|
9d0a19cf8f | ||
|
|
d60dd38d78 | ||
|
|
d5ee799670 | ||
|
|
b685fd43dd | ||
|
|
0e04f6fdbe | ||
|
|
9c40929197 | ||
|
|
af309e8586 | ||
|
|
cc5d5c2335 | ||
|
|
60e668f4a7 | ||
|
|
743f6256a6 | ||
|
|
a87b4efd41 | ||
|
|
730d9970f5 | ||
|
|
2440593431 | ||
|
|
8c94ddbf1a | ||
|
|
af60183f22 | ||
|
|
40fb69325f | ||
|
|
f9891a2c0d | ||
|
|
44f1bb022b | ||
|
|
95c2a0285c | ||
|
|
8ea4813936 | ||
|
|
650bf8c0c0 | ||
|
|
dd4bb5d362 | ||
|
|
b3137e7ae8 | ||
|
|
0d740925c5 | ||
|
|
030ff59c40 | ||
|
|
3bc56740b9 | ||
|
|
87a5762bf2 | ||
|
|
5b149927fb | ||
|
|
275bb689e4 | ||
|
|
d595945309 | ||
|
|
e8ba741f28 | ||
|
|
57929b73ee | ||
|
|
9fd4e42438 | ||
|
|
0176cd9669 | ||
|
|
2e9b29970e | ||
|
|
a07fc1510b | ||
|
|
78f067acef | ||
|
|
0d5f97c8c7 | ||
|
|
a987387353 | ||
|
|
941b6dd340 | ||
|
|
23251a2487 | ||
|
|
6af8301996 | ||
|
|
b57ec9eda1 | ||
|
|
c8594a2eaa | ||
|
|
102715a3c9 | ||
|
|
d5e66b4f3a | ||
|
|
f5315887ec | ||
|
|
6e4ac8e2ce |
137
.github/ISSUE_TEMPLATE/bug_template.yml
vendored
137
.github/ISSUE_TEMPLATE/bug_template.yml
vendored
@@ -5,52 +5,113 @@ labels: ['bug']
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: Thank you for taking the time to fill out this bug report. Please provide as much information as possible
|
||||
to help us understand and address the issue effectively.
|
||||
value: |
|
||||
## Thank you for reporting a bug! 🐛
|
||||
|
||||
**Please fill out all required fields.** Issues missing critical information (version, installation method, reproduction steps, etc.) will be delayed or closed until complete details are provided.
|
||||
|
||||
Clear, detailed reports help us resolve issues faster.
|
||||
|
||||
- type: checkboxes
|
||||
attributes:
|
||||
label: Is there an existing issue for the same bug? (If one exists, thumbs up or comment on the issue instead).
|
||||
description: Please check if an issue already exists for the bug you encountered.
|
||||
label: Is there an existing issue for the same bug?
|
||||
description: Please search existing issues before creating a new one. If found, react or comment to the duplicate issue instead of making a new one.
|
||||
options:
|
||||
- label: I have checked the existing issues.
|
||||
- label: I have searched existing issues and this is not a duplicate.
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: bug-description
|
||||
attributes:
|
||||
label: Describe the bug and reproduction steps
|
||||
description: Provide a description of the issue along with any reproduction steps.
|
||||
label: Bug Description
|
||||
description: Clearly describe what went wrong. Be specific and concise.
|
||||
placeholder: Example - "When I run a Python task, OpenHands crashes after 30 seconds with a connection timeout error."
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: expected-behavior
|
||||
attributes:
|
||||
label: Expected Behavior
|
||||
description: What did you expect to happen?
|
||||
placeholder: Example - "OpenHands should execute the Python script and return results."
|
||||
validations:
|
||||
required: false
|
||||
|
||||
- type: textarea
|
||||
id: actual-behavior
|
||||
attributes:
|
||||
label: Actual Behavior
|
||||
description: What actually happened?
|
||||
placeholder: Example - "Connection timed out after 30 seconds, task failed with error code 500."
|
||||
validations:
|
||||
required: false
|
||||
|
||||
- type: textarea
|
||||
id: reproduction-steps
|
||||
attributes:
|
||||
label: Steps to Reproduce
|
||||
description: Provide clear, step-by-step instructions to reproduce the bug.
|
||||
placeholder: |
|
||||
1. Install OpenHands using Docker
|
||||
2. Configure with Claude 3.5 Sonnet
|
||||
3. Run command: `openhands run "write a python script"`
|
||||
4. Wait 30 seconds
|
||||
5. Error appears
|
||||
validations:
|
||||
required: false
|
||||
|
||||
- type: dropdown
|
||||
id: installation
|
||||
attributes:
|
||||
label: OpenHands Installation
|
||||
label: OpenHands Installation Method
|
||||
description: How are you running OpenHands?
|
||||
options:
|
||||
- Docker command in README
|
||||
- GitHub resolver
|
||||
- CLI (uv tool install)
|
||||
- CLI (executable binary)
|
||||
- CLI (Docker)
|
||||
- Local GUI (Docker web interface)
|
||||
- OpenHands Cloud (app.all-hands.dev)
|
||||
- SDK (Python library)
|
||||
- Development workflow
|
||||
- CLI
|
||||
- app.all-hands.dev
|
||||
- Other
|
||||
default: 0
|
||||
validations:
|
||||
required: false
|
||||
|
||||
- type: input
|
||||
id: installation-other
|
||||
attributes:
|
||||
label: If you selected "Other", please specify
|
||||
description: Describe your installation method
|
||||
placeholder: ex. Custom Kubernetes deployment, pip install from source, etc.
|
||||
|
||||
- type: input
|
||||
id: openhands-version
|
||||
attributes:
|
||||
label: OpenHands Version
|
||||
description: What version of OpenHands are you using?
|
||||
placeholder: ex. 0.9.8, main, etc.
|
||||
description: What version are you using? Find this in settings or by running `openhands --version`
|
||||
placeholder: ex. 0.9.8, main, commit hash, etc.
|
||||
validations:
|
||||
required: false
|
||||
|
||||
- type: checkboxes
|
||||
id: version-confirmation
|
||||
attributes:
|
||||
label: Version Confirmation
|
||||
description: Bugs on older versions may already be fixed. Please upgrade before submitting.
|
||||
options:
|
||||
- label: "I have confirmed this bug exists on the LATEST version of OpenHands"
|
||||
required: false
|
||||
|
||||
- type: input
|
||||
id: model-name
|
||||
attributes:
|
||||
label: Model Name
|
||||
description: What model are you using?
|
||||
placeholder: ex. gpt-4o, claude-3-5-sonnet, openrouter/deepseek-r1, etc.
|
||||
description: Which LLM model are you using?
|
||||
placeholder: ex. gpt-4o, claude-3-5-sonnet-20241022, openrouter/deepseek-r1, etc.
|
||||
validations:
|
||||
required: false
|
||||
|
||||
- type: dropdown
|
||||
id: os
|
||||
@@ -60,12 +121,46 @@ body:
|
||||
- MacOS
|
||||
- Linux
|
||||
- WSL on Windows
|
||||
- Windows (Docker Desktop)
|
||||
- Other
|
||||
validations:
|
||||
required: false
|
||||
|
||||
- type: input
|
||||
id: browser
|
||||
attributes:
|
||||
label: Browser (if using web UI)
|
||||
description: |
|
||||
If applicable, which browser and version?
|
||||
|
||||
placeholder: ex. Chrome 131, Firefox 133, Safari 17.2
|
||||
|
||||
- type: textarea
|
||||
id: logs
|
||||
attributes:
|
||||
label: Logs and Error Messages
|
||||
description: |
|
||||
**Paste relevant logs, error messages, or stack traces.** Use code blocks (```) for formatting.
|
||||
|
||||
LLM logs are in `logs/llm/default/`. Include timestamps if errors occurred at a specific time.
|
||||
placeholder: |
|
||||
```
|
||||
Paste error logs here
|
||||
```
|
||||
|
||||
- type: textarea
|
||||
id: additional-context
|
||||
attributes:
|
||||
label: Logs, Errors, Screenshots, and Additional Context
|
||||
description: Please provide any additional information you think might help. If you want to share the chat history
|
||||
you can click the thumbs-down (👎) button above the input field and you will get a shareable link
|
||||
(you can also click thumbs up when things are going well of course!). LLM logs will be stored in the
|
||||
`logs/llm/default` folder. Please add any additional context about the problem here.
|
||||
label: Screenshots and Additional Context
|
||||
description: |
|
||||
Add screenshots, videos, runtime environment, or other context that helps explain the issue.
|
||||
|
||||
💡 **Share conversation history:** In the OpenHands chat UI, click the 👎 or 👍 button (above the message input) to generate a shareable link to your conversation.
|
||||
|
||||
placeholder: Drag and drop screenshots here, paste links, or add additional context.
|
||||
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
---
|
||||
**Note:** Issues with incomplete information may be closed or deprioritized. Maintainers and community members have limited bandwidth and prioritize well-documented bugs that are easier to reproduce and fix. Thank you for your understanding!
|
||||
|
||||
2
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
2
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
# disable blank issue creation
|
||||
blank_issues_enabled: false
|
||||
17
.github/ISSUE_TEMPLATE/feature_request.md
vendored
17
.github/ISSUE_TEMPLATE/feature_request.md
vendored
@@ -1,17 +0,0 @@
|
||||
---
|
||||
name: Feature Request or Enhancement
|
||||
about: Suggest an idea for an OpenHands feature or enhancement
|
||||
title: ''
|
||||
labels: 'enhancement'
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
**What problem or use case are you trying to solve?**
|
||||
|
||||
**Describe the UX or technical implementation you have in mind**
|
||||
|
||||
**Additional context**
|
||||
|
||||
|
||||
### If you find this feature request or enhancement useful, make sure to add a 👍 to the issue
|
||||
105
.github/ISSUE_TEMPLATE/feature_request.yml
vendored
Normal file
105
.github/ISSUE_TEMPLATE/feature_request.yml
vendored
Normal file
@@ -0,0 +1,105 @@
|
||||
name: Feature Request or Enhancement
|
||||
description: Suggest a new feature or improvement for OpenHands
|
||||
title: '[Feature]: '
|
||||
labels: ['enhancement']
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
## Thank you for suggesting a feature! 💡
|
||||
|
||||
**Please provide detailed information.** Vague or low-effort requests may be closed. Well-documented feature requests with strong community support are more likely to be added to the roadmap.
|
||||
|
||||
- type: checkboxes
|
||||
attributes:
|
||||
label: Is there an existing feature request for this?
|
||||
description: Please search existing issues and feature requests before creating a new one. If found, react or comment to the duplicate issue instead of making a new one.
|
||||
options:
|
||||
- label: I have searched existing issues and feature requests, and this is not a duplicate.
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: problem-statement
|
||||
attributes:
|
||||
label: Problem or Use Case
|
||||
description: What problem are you trying to solve? What use case would this feature enable?
|
||||
placeholder: |
|
||||
Example - "As a developer working on large codebases, I need to search across multiple files simultaneously. Currently, I have to search file-by-file which is time-consuming and inefficient."
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: proposed-solution
|
||||
attributes:
|
||||
label: Proposed Solution
|
||||
description: Describe your ideal solution. What should this feature do? How should it work?
|
||||
placeholder: |
|
||||
Example - "Add a global search feature that allows searching across all files in the workspace. Results should show file name, line number, and context around matches. Include regex support and filtering options."
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: alternatives
|
||||
attributes:
|
||||
label: Alternatives Considered
|
||||
description: Have you considered any alternative solutions or workarounds? What are their limitations?
|
||||
placeholder: Example - "I tried using grep in the terminal, but it's not integrated with the UI and doesn't provide click-to-navigate functionality."
|
||||
|
||||
- type: dropdown
|
||||
id: priority
|
||||
attributes:
|
||||
label: Priority / Severity
|
||||
description: How important is this feature to your workflow?
|
||||
options:
|
||||
- "Critical - Blocking my work, no workaround available"
|
||||
- "High - Significant impact on productivity"
|
||||
- "Medium - Would improve experience"
|
||||
- "Low - Nice to have"
|
||||
default: 2
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: dropdown
|
||||
id: scope
|
||||
attributes:
|
||||
label: Estimated Scope
|
||||
description: To the best of your knowledge, how complex do you think this feature would be to implement?
|
||||
options:
|
||||
- "Small - UI tweak, config option, or minor change"
|
||||
- "Medium - New feature with moderate complexity"
|
||||
- "Large - Significant feature requiring architecture changes"
|
||||
- "Unknown - Not sure about the technical complexity"
|
||||
default: 3
|
||||
|
||||
- type: dropdown
|
||||
id: feature-area
|
||||
attributes:
|
||||
label: Feature Area
|
||||
description: Which part of OpenHands does this feature relate to? If you select "Other", please specify the area in the Additional Context section below.
|
||||
options:
|
||||
- "Agent / AI behavior"
|
||||
- "User Interface / UX"
|
||||
- "CLI / Command-line interface"
|
||||
- "File system / Workspace management"
|
||||
- "Configuration / Settings"
|
||||
- "Integrations (GitHub, GitLab, etc.)"
|
||||
- "Performance / Optimization"
|
||||
- "Documentation"
|
||||
- "Other"
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: technical-details
|
||||
attributes:
|
||||
label: Technical Implementation Ideas (Optional)
|
||||
description: If you have technical expertise, share implementation ideas, API suggestions, or relevant technical details.
|
||||
placeholder: |
|
||||
Example - "Could use ripgrep library for fast search. Expose results via /api/search endpoint. Frontend can use virtualized list for rendering large result sets."
|
||||
|
||||
- type: textarea
|
||||
id: additional-context
|
||||
attributes:
|
||||
label: Additional Context
|
||||
description: Add any other context, screenshots, mockups, or examples that help illustrate this feature request.
|
||||
placeholder: Drag and drop screenshots, mockups, or links here.
|
||||
2
.github/pull_request_template.md
vendored
2
.github/pull_request_template.md
vendored
@@ -1,4 +1,4 @@
|
||||
<!-- Ideally you should open a PR when it is ready for review. Draft PRs will not be reviewed -->
|
||||
<!-- If you are still working on the PR, please mark it as draft. Maintainers will review PRs marked ready for review, which leads to lost time if your PR is actually not ready yet. Keep the PR marked as draft until it is finally ready for review -->
|
||||
|
||||
## Summary of PR
|
||||
|
||||
|
||||
29
.github/workflows/enterprise-preview.yml
vendored
29
.github/workflows/enterprise-preview.yml
vendored
@@ -1,29 +0,0 @@
|
||||
# Feature branch preview for enterprise code
|
||||
name: Enterprise Preview
|
||||
|
||||
# Run on PRs labeled
|
||||
on:
|
||||
pull_request:
|
||||
types: [labeled]
|
||||
|
||||
# Match ghcr-build.yml, but don't interrupt it.
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ (github.head_ref && github.ref) || github.run_id }}
|
||||
cancel-in-progress: false
|
||||
|
||||
jobs:
|
||||
# This must happen for the PR Docker workflow when the label is present,
|
||||
# and also if it's added after the fact. Thus, it exists in both places.
|
||||
enterprise-preview:
|
||||
name: Enterprise preview
|
||||
if: github.event.label.name == 'deploy'
|
||||
runs-on: blacksmith-4vcpu-ubuntu-2204
|
||||
steps:
|
||||
# This should match the version in ghcr-build.yml
|
||||
- name: Trigger remote job
|
||||
run: |
|
||||
curl --fail-with-body -sS -X POST \
|
||||
-H "Authorization: Bearer ${{ secrets.ALLHANDS_BOT_GITHUB_PAT }}" \
|
||||
-H "Accept: application/vnd.github+json" \
|
||||
-d "{\"ref\": \"main\", \"inputs\": {\"openhandsPrNumber\": \"${{ github.event.pull_request.number }}\", \"deployEnvironment\": \"feature\", \"enterpriseImageTag\": \"pr-${{ github.event.pull_request.number }}\" }}" \
|
||||
https://api.github.com/repos/OpenHands/deploy/actions/workflows/deploy.yaml/dispatches
|
||||
22
.github/workflows/ghcr-build.yml
vendored
22
.github/workflows/ghcr-build.yml
vendored
@@ -9,6 +9,7 @@ on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
- "saas-rel-*"
|
||||
tags:
|
||||
- "*"
|
||||
pull_request:
|
||||
@@ -39,8 +40,7 @@ jobs:
|
||||
run: |
|
||||
if [[ "$GITHUB_EVENT_NAME" == "pull_request" ]]; then
|
||||
json=$(jq -n -c '[
|
||||
{ image: "nikolaik/python-nodejs:python3.12-nodejs22", tag: "nikolaik" },
|
||||
{ image: "ubuntu:24.04", tag: "ubuntu" }
|
||||
{ image: "nikolaik/python-nodejs:python3.12-nodejs22", tag: "nikolaik" }
|
||||
]')
|
||||
else
|
||||
json=$(jq -n -c '[
|
||||
@@ -149,6 +149,9 @@ jobs:
|
||||
push: true
|
||||
tags: ${{ env.DOCKER_TAGS }}
|
||||
platforms: ${{ env.DOCKER_PLATFORM }}
|
||||
# Caching directives to boost performance
|
||||
cache-from: type=registry,ref=ghcr.io/${{ env.REPO_OWNER }}/runtime:buildcache-${{ matrix.base_image.tag }}
|
||||
cache-to: type=registry,ref=ghcr.io/${{ env.REPO_OWNER }}/runtime:buildcache-${{ matrix.base_image.tag }},mode=max
|
||||
build-args: ${{ env.DOCKER_BUILD_ARGS }}
|
||||
context: containers/runtime
|
||||
provenance: false
|
||||
@@ -237,21 +240,6 @@ jobs:
|
||||
# Add build attestations for better security
|
||||
sbom: true
|
||||
|
||||
enterprise-preview:
|
||||
name: Enterprise preview
|
||||
if: github.event_name == 'pull_request' && contains(github.event.pull_request.labels.*.name, 'deploy')
|
||||
runs-on: blacksmith-4vcpu-ubuntu-2204
|
||||
needs: [ghcr_build_enterprise]
|
||||
steps:
|
||||
# This should match the version in enterprise-preview.yml
|
||||
- name: Trigger remote job
|
||||
run: |
|
||||
curl --fail-with-body -sS -X POST \
|
||||
-H "Authorization: Bearer ${{ secrets.ALLHANDS_BOT_GITHUB_PAT }}" \
|
||||
-H "Accept: application/vnd.github+json" \
|
||||
-d "{\"ref\": \"main\", \"inputs\": {\"openhandsPrNumber\": \"${{ github.event.pull_request.number }}\", \"deployEnvironment\": \"feature\", \"enterpriseImageTag\": \"pr-${{ github.event.pull_request.number }}\" }}" \
|
||||
https://api.github.com/repos/OpenHands/deploy/actions/workflows/deploy.yaml/dispatches
|
||||
|
||||
# "All Runtime Tests Passed" is a required job for PRs to merge
|
||||
# We can remove this once the config changes
|
||||
runtime_tests_check_success:
|
||||
|
||||
48
.github/workflows/pr-review-by-openhands.yml
vendored
Normal file
48
.github/workflows/pr-review-by-openhands.yml
vendored
Normal file
@@ -0,0 +1,48 @@
|
||||
---
|
||||
name: PR Review by OpenHands
|
||||
|
||||
on:
|
||||
# TEMPORARY MITIGATION (Clinejection hardening)
|
||||
#
|
||||
# We temporarily avoid `pull_request_target` here. We'll restore it after the PR review
|
||||
# workflow is fully hardened for untrusted execution.
|
||||
pull_request:
|
||||
types: [opened, ready_for_review, labeled, review_requested]
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: write
|
||||
issues: write
|
||||
|
||||
jobs:
|
||||
pr-review:
|
||||
# Note: fork PRs will not have access to repository secrets under `pull_request`.
|
||||
# Skip forks to avoid noisy failures until we restore a hardened `pull_request_target` flow.
|
||||
if: |
|
||||
github.event.pull_request.head.repo.full_name == github.repository &&
|
||||
(
|
||||
(github.event.action == 'opened' && github.event.pull_request.draft == false) ||
|
||||
github.event.action == 'ready_for_review' ||
|
||||
(github.event.action == 'labeled' && github.event.label.name == 'review-this') ||
|
||||
(
|
||||
github.event.action == 'review_requested' &&
|
||||
(
|
||||
github.event.requested_reviewer.login == 'openhands-agent' ||
|
||||
github.event.requested_reviewer.login == 'all-hands-bot'
|
||||
)
|
||||
)
|
||||
)
|
||||
concurrency:
|
||||
group: pr-review-${{ github.event.pull_request.number }}
|
||||
cancel-in-progress: true
|
||||
runs-on: ubuntu-24.04
|
||||
steps:
|
||||
- name: Run PR Review
|
||||
uses: OpenHands/extensions/plugins/pr-review@main
|
||||
with:
|
||||
llm-model: litellm_proxy/claude-sonnet-4-5-20250929
|
||||
llm-base-url: https://llm-proxy.app.all-hands.dev
|
||||
review-style: roasted
|
||||
llm-api-key: ${{ secrets.LLM_API_KEY }}
|
||||
github-token: ${{ secrets.ALLHANDS_BOT_GITHUB_PAT }}
|
||||
lmnr-api-key: ${{ secrets.LMNR_SKILLS_API_KEY }}
|
||||
85
.github/workflows/pr-review-evaluation.yml
vendored
Normal file
85
.github/workflows/pr-review-evaluation.yml
vendored
Normal file
@@ -0,0 +1,85 @@
|
||||
---
|
||||
name: PR Review Evaluation
|
||||
|
||||
# This workflow evaluates how well PR review comments were addressed.
|
||||
# It runs when a PR is closed to assess review effectiveness.
|
||||
#
|
||||
# Security note: pull_request_target is safe here because:
|
||||
# 1. Only triggers on PR close (not on code changes)
|
||||
# 2. Does not checkout PR code - only downloads artifacts from trusted workflow runs
|
||||
# 3. Runs evaluation scripts from the extensions repo, not from the PR
|
||||
|
||||
on:
|
||||
pull_request_target:
|
||||
types: [closed]
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: read
|
||||
|
||||
jobs:
|
||||
evaluate:
|
||||
runs-on: ubuntu-24.04
|
||||
env:
|
||||
PR_NUMBER: ${{ github.event.pull_request.number }}
|
||||
REPO_NAME: ${{ github.repository }}
|
||||
PR_MERGED: ${{ github.event.pull_request.merged }}
|
||||
|
||||
steps:
|
||||
- name: Download review trace artifact
|
||||
id: download-trace
|
||||
uses: dawidd6/action-download-artifact@v6
|
||||
continue-on-error: true
|
||||
with:
|
||||
workflow: pr-review-by-openhands.yml
|
||||
name: pr-review-trace-${{ github.event.pull_request.number }}
|
||||
path: trace-info
|
||||
search_artifacts: true
|
||||
if_no_artifact_found: warn
|
||||
|
||||
- name: Check if trace file exists
|
||||
id: check-trace
|
||||
run: |
|
||||
if [ -f "trace-info/laminar_trace_info.json" ]; then
|
||||
echo "trace_exists=true" >> $GITHUB_OUTPUT
|
||||
echo "Found trace file for PR #$PR_NUMBER"
|
||||
else
|
||||
echo "trace_exists=false" >> $GITHUB_OUTPUT
|
||||
echo "No trace file found for PR #$PR_NUMBER - skipping evaluation"
|
||||
fi
|
||||
|
||||
# Always checkout main branch for security - cannot test script changes in PRs
|
||||
- name: Checkout extensions repository
|
||||
if: steps.check-trace.outputs.trace_exists == 'true'
|
||||
uses: actions/checkout@v5
|
||||
with:
|
||||
repository: OpenHands/extensions
|
||||
path: extensions
|
||||
|
||||
- name: Set up Python
|
||||
if: steps.check-trace.outputs.trace_exists == 'true'
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: '3.12'
|
||||
|
||||
- name: Install dependencies
|
||||
if: steps.check-trace.outputs.trace_exists == 'true'
|
||||
run: pip install lmnr
|
||||
|
||||
- name: Run evaluation
|
||||
if: steps.check-trace.outputs.trace_exists == 'true'
|
||||
env:
|
||||
# Script expects LMNR_PROJECT_API_KEY; org secret is named LMNR_SKILLS_API_KEY
|
||||
LMNR_PROJECT_API_KEY: ${{ secrets.LMNR_SKILLS_API_KEY }}
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: |
|
||||
python extensions/plugins/pr-review/scripts/evaluate_review.py \
|
||||
--trace-file trace-info/laminar_trace_info.json
|
||||
|
||||
- name: Upload evaluation logs
|
||||
uses: actions/upload-artifact@v5
|
||||
if: always() && steps.check-trace.outputs.trace_exists == 'true'
|
||||
with:
|
||||
name: pr-review-evaluation-${{ github.event.pull_request.number }}
|
||||
path: '*.log'
|
||||
retention-days: 30
|
||||
@@ -6,11 +6,12 @@ Thanks for your interest in contributing to OpenHands! We welcome and appreciate
|
||||
|
||||
To understand the codebase, please refer to the README in each module:
|
||||
- [frontend](./frontend/README.md)
|
||||
- [evaluation](./evaluation/README.md)
|
||||
- [openhands](./openhands/README.md)
|
||||
- [agenthub](./openhands/agenthub/README.md)
|
||||
- [server](./openhands/server/README.md)
|
||||
|
||||
For benchmarks and evaluation, see the [OpenHands/benchmarks](https://github.com/OpenHands/benchmarks) repository.
|
||||
|
||||
## Setting up Your Development Environment
|
||||
|
||||
We have a separate doc [Development.md](https://github.com/OpenHands/OpenHands/blob/main/Development.md) that tells
|
||||
|
||||
@@ -161,7 +161,7 @@ poetry run pytest ./tests/unit/test_*.py
|
||||
To reduce build time (e.g., if no changes were made to the client-runtime component), you can use an existing Docker
|
||||
container image by setting the SANDBOX_RUNTIME_CONTAINER_IMAGE environment variable to the desired Docker image.
|
||||
|
||||
Example: `export SANDBOX_RUNTIME_CONTAINER_IMAGE=ghcr.io/openhands/runtime:1.1-nikolaik`
|
||||
Example: `export SANDBOX_RUNTIME_CONTAINER_IMAGE=ghcr.io/openhands/runtime:1.2-nikolaik`
|
||||
|
||||
## Develop inside Docker container
|
||||
|
||||
@@ -200,7 +200,7 @@ Here's a guide to the important documentation files in the repository:
|
||||
- [/frontend/README.md](./frontend/README.md): Frontend React application setup and development guide
|
||||
- [/containers/README.md](./containers/README.md): Information about Docker containers and deployment
|
||||
- [/tests/unit/README.md](./tests/unit/README.md): Guide to writing and running unit tests
|
||||
- [/evaluation/README.md](./evaluation/README.md): Documentation for the evaluation framework and benchmarks
|
||||
- [OpenHands/benchmarks](https://github.com/OpenHands/benchmarks): Documentation for the evaluation framework and benchmarks
|
||||
- [/skills/README.md](./skills/README.md): Information about the skills architecture and implementation
|
||||
- [/openhands/server/README.md](./openhands/server/README.md): Server implementation details and API documentation
|
||||
- [/openhands/runtime/README.md](./openhands/runtime/README.md): Documentation for the runtime environment and execution model
|
||||
|
||||
@@ -54,7 +54,7 @@ The experience will be familiar to anyone who has used Devin or Jules.
|
||||
### OpenHands Cloud
|
||||
This is a deployment of OpenHands GUI, running on hosted infrastructure.
|
||||
|
||||
You can try it with a free $10 credit by [signing in with your GitHub or GitLab account](https://app.all-hands.dev).
|
||||
You can try it for free using the Minimax model by [signing in with your GitHub or GitLab account](https://app.all-hands.dev).
|
||||
|
||||
OpenHands Cloud comes with source-available features and integrations:
|
||||
- Integrations with Slack, Jira, and Linear
|
||||
|
||||
@@ -440,12 +440,6 @@ type = "noop"
|
||||
#temperature = 0.1
|
||||
#max_input_tokens = 1024
|
||||
|
||||
#################################### Eval ####################################
|
||||
# Configuration for the evaluation, please refer to the specific evaluation
|
||||
# plugin for the available options
|
||||
##############################################################################
|
||||
|
||||
|
||||
########################### Kubernetes #######################################
|
||||
# Kubernetes configuration when using the Kubernetes runtime
|
||||
##############################################################################
|
||||
|
||||
@@ -12,7 +12,8 @@ services:
|
||||
- SANDBOX_API_HOSTNAME=host.docker.internal
|
||||
- DOCKER_HOST_ADDR=host.docker.internal
|
||||
#
|
||||
- SANDBOX_RUNTIME_CONTAINER_IMAGE=${SANDBOX_RUNTIME_CONTAINER_IMAGE:-ghcr.io/openhands/runtime:1.1-nikolaik}
|
||||
- AGENT_SERVER_IMAGE_REPOSITORY=${AGENT_SERVER_IMAGE_REPOSITORY:-ghcr.io/openhands/runtime}
|
||||
- AGENT_SERVER_IMAGE_TAG=${AGENT_SERVER_IMAGE_TAG:-1.2-nikolaik}
|
||||
- SANDBOX_USER_ID=${SANDBOX_USER_ID:-1234}
|
||||
- WORKSPACE_MOUNT_PATH=${WORKSPACE_BASE:-$PWD/workspace}
|
||||
ports:
|
||||
|
||||
@@ -7,7 +7,8 @@ services:
|
||||
image: openhands:latest
|
||||
container_name: openhands-app-${DATE:-}
|
||||
environment:
|
||||
- SANDBOX_RUNTIME_CONTAINER_IMAGE=${SANDBOX_RUNTIME_CONTAINER_IMAGE:-docker.openhands.dev/openhands/runtime:1.1-nikolaik}
|
||||
- AGENT_SERVER_IMAGE_REPOSITORY=${AGENT_SERVER_IMAGE_REPOSITORY:-ghcr.io/openhands/agent-server}
|
||||
- AGENT_SERVER_IMAGE_TAG=${AGENT_SERVER_IMAGE_TAG:-31536c8-python}
|
||||
#- SANDBOX_USER_ID=${SANDBOX_USER_ID:-1234} # enable this only if you want a specific non-root sandbox user but you will have to manually adjust permissions of ~/.openhands for this user
|
||||
- WORKSPACE_MOUNT_PATH=${WORKSPACE_BASE:-$PWD/workspace}
|
||||
ports:
|
||||
|
||||
@@ -23,12 +23,23 @@ RUN apt-get update && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Install Python packages with security fixes
|
||||
RUN /app/.venv/bin/pip install alembic psycopg2-binary cloud-sql-python-connector pg8000 gspread stripe python-keycloak asyncpg sqlalchemy[asyncio] resend tenacity slack-sdk ddtrace "posthog>=6.0.0" "limits==5.2.0" coredis prometheus-client shap scikit-learn pandas numpy google-cloud-recaptcha-enterprise && \
|
||||
# Update packages with known CVE fixes
|
||||
/app/.venv/bin/pip install --upgrade \
|
||||
"mcp>=1.10.0" \
|
||||
"pillow>=11.3.0"
|
||||
# Install poetry and export before importing current code.
|
||||
RUN /app/.venv/bin/pip install poetry poetry-plugin-export
|
||||
|
||||
# Install Python dependencies from poetry.lock for reproducible builds
|
||||
# Copy lock files first for better Docker layer caching
|
||||
COPY --chown=openhands:openhands enterprise/pyproject.toml enterprise/poetry.lock /tmp/enterprise/
|
||||
RUN cd /tmp/enterprise && \
|
||||
# Export only main dependencies with hashes for supply chain security
|
||||
/app/.venv/bin/poetry export --only main -o requirements.txt && \
|
||||
# Remove the local path dependency (openhands-ai is already in base image)
|
||||
sed -i '/^-e /d; /openhands-ai/d' requirements.txt && \
|
||||
# Install pinned dependencies from lock file
|
||||
/app/.venv/bin/pip install -r requirements.txt && \
|
||||
# Cleanup - return to /app before removing /tmp/enterprise
|
||||
cd /app && \
|
||||
rm -rf /tmp/enterprise && \
|
||||
/app/.venv/bin/pip uninstall -y poetry poetry-plugin-export
|
||||
|
||||
WORKDIR /app
|
||||
COPY --chown=openhands:openhands --chmod=770 enterprise .
|
||||
|
||||
@@ -2,7 +2,7 @@ BACKEND_HOST ?= "127.0.0.1"
|
||||
BACKEND_PORT = 3000
|
||||
BACKEND_HOST_PORT = "$(BACKEND_HOST):$(BACKEND_PORT)"
|
||||
FRONTEND_PORT = 3001
|
||||
OPENHANDS_PATH ?= "../../OpenHands"
|
||||
OPENHANDS_PATH ?= ".."
|
||||
OPENHANDS := $(OPENHANDS_PATH)
|
||||
OPENHANDS_FRONTEND_PATH = $(OPENHANDS)/frontend/build
|
||||
|
||||
|
||||
131
enterprise/doc/design-doc/plugin-launch-flow.md
Normal file
131
enterprise/doc/design-doc/plugin-launch-flow.md
Normal file
@@ -0,0 +1,131 @@
|
||||
# Plugin Launch Flow
|
||||
|
||||
This document describes how plugins are launched in OpenHands Saas / Enterprise, from the plugin directory through to agent execution.
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
```
|
||||
Plugin Directory ──▶ Frontend /launch ──▶ App Server ──▶ Agent Server ──▶ SDK
|
||||
(external) (modal) (API) (in sandbox) (plugin loading)
|
||||
```
|
||||
|
||||
| Component | Responsibility |
|
||||
|-----------|---------------|
|
||||
| **Plugin Directory** | Index plugins, present to user, construct launch URLs |
|
||||
| **Frontend** | Display confirmation modal, collect parameters, call API |
|
||||
| **App Server** | Validate request, pass plugin specs to agent server |
|
||||
| **Agent Server** | Run inside sandbox, delegate plugin loading to SDK |
|
||||
| **SDK** | Fetch plugins, load contents, merge skills/hooks/MCP into agent |
|
||||
|
||||
## User Experience
|
||||
|
||||
### Plugin Directory
|
||||
|
||||
The plugin directory presents users with a catalog of available plugins. For each plugin, users see:
|
||||
- Plugin name and description (from `plugin.json`)
|
||||
- Author and version information
|
||||
- A "Launch" button
|
||||
|
||||
When a user clicks "Launch", the plugin directory:
|
||||
1. Reads the plugin's `entry_command` to know which slash command to invoke
|
||||
2. Determines what parameters the plugin accepts (if any)
|
||||
3. Redirects to OpenHands with this information encoded in the URL
|
||||
|
||||
### Parameter Collection
|
||||
|
||||
If a plugin requires user input (API keys, configuration values, etc.), the frontend displays a form modal before starting the conversation. Parameters are passed in the launch URL and rendered as form fields based on their type:
|
||||
|
||||
- **String values** → Text input
|
||||
- **Number values** → Number input
|
||||
- **Boolean values** → Checkbox
|
||||
|
||||
Only primitive types are supported. Complex types (arrays, objects) are not currently supported for parameter input.
|
||||
|
||||
The user fills in required values, then clicks "Start Conversation" to proceed.
|
||||
|
||||
## Launch Flow
|
||||
|
||||
1. **Plugin Directory** (external) constructs a launch URL to the OpenHands app server when user clicks "Launch":
|
||||
```
|
||||
/launch?plugins=BASE64_JSON&message=/city-weather:now%20Tokyo
|
||||
```
|
||||
|
||||
The `plugins` parameter includes any parameter definitions with default values:
|
||||
```json
|
||||
[{
|
||||
"source": "github:owner/repo",
|
||||
"repo_path": "plugins/my-plugin",
|
||||
"parameters": {"api_key": "", "timeout": 30, "debug": false}
|
||||
}]
|
||||
```
|
||||
|
||||
2. **OpenHands Frontend** (`/launch` route, [PR #12699](https://github.com/OpenHands/OpenHands/pull/12699)) displays modal with parameter form, collects user input
|
||||
|
||||
3. **OpenHands App Server** ([PR #12338](https://github.com/OpenHands/OpenHands/pull/12338)) receives the API call:
|
||||
```
|
||||
POST /api/v1/app-conversations
|
||||
{
|
||||
"plugins": [{"source": "github:owner/repo", "repo_path": "plugins/city-weather"}],
|
||||
"initial_message": {"content": [{"type": "text", "text": "/city-weather:now Tokyo"}]}
|
||||
}
|
||||
```
|
||||
|
||||
Call stack:
|
||||
- `AppConversationRouter` receives request with `PluginSpec` list
|
||||
- `LiveStatusAppConversationService._finalize_conversation_request()` converts `PluginSpec` → `PluginSource`
|
||||
- Creates `StartConversationRequest(plugins=sdk_plugins, ...)` and sends to agent server
|
||||
|
||||
4. **Agent Server** (inside sandbox, [SDK PR #1651](https://github.com/OpenHands/software-agent-sdk/pull/1651)) stores specs, defers loading:
|
||||
|
||||
Call stack:
|
||||
- `ConversationService.start_conversation()` receives `StartConversationRequest`
|
||||
- Creates `StoredConversation` with plugin specs
|
||||
- Creates `LocalConversation(plugins=request.plugins, ...)`
|
||||
- Plugin loading deferred until first `run()` or `send_message()`
|
||||
|
||||
5. **SDK** fetches and loads plugins on first use:
|
||||
|
||||
Call stack:
|
||||
- `LocalConversation._ensure_plugins_loaded()` triggered by first message
|
||||
- For each plugin spec:
|
||||
- `Plugin.fetch(source, ref, repo_path)` → clones/caches git repo
|
||||
- `Plugin.load(path)` → parses `plugin.json`, loads commands/skills/hooks
|
||||
- `plugin.add_skills_to(context)` → merges skills into agent
|
||||
- `plugin.add_mcp_config_to(config)` → merges MCP servers
|
||||
|
||||
6. **Agent** receives message, `/city-weather:now` triggers the skill
|
||||
|
||||
## Key Design Decisions
|
||||
|
||||
### Plugin Loading in Sandbox
|
||||
|
||||
Plugins load **inside the sandbox** because:
|
||||
- Plugin hooks and scripts need isolated execution
|
||||
- MCP servers run inside the sandbox
|
||||
- Skills may reference sandbox filesystem
|
||||
|
||||
### Entry Command Handling
|
||||
|
||||
The `entry_command` field in `plugin.json` allows plugin authors to declare a default command:
|
||||
|
||||
```json
|
||||
{
|
||||
"name": "city-weather",
|
||||
"entry_command": "now"
|
||||
}
|
||||
```
|
||||
|
||||
This flows through the system:
|
||||
1. Plugin author declares `entry_command` in plugin.json
|
||||
2. Plugin directory reads it when indexing
|
||||
3. Plugin directory includes `/city-weather:now` in the launch URL's `message` parameter
|
||||
4. Message passes through to agent as `initial_message`
|
||||
|
||||
The SDK exposes this field but does not auto-invoke it—callers control the initial message.
|
||||
|
||||
## Related
|
||||
|
||||
- [OpenHands PR #12338](https://github.com/OpenHands/OpenHands/pull/12338) - App server plugin support
|
||||
- [OpenHands PR #12699](https://github.com/OpenHands/OpenHands/pull/12699) - Frontend `/launch` route
|
||||
- [SDK PR #1651](https://github.com/OpenHands/software-agent-sdk/pull/1651) - Agent server plugin loading
|
||||
- [SDK PR #1647](https://github.com/OpenHands/software-agent-sdk/pull/1647) - Plugin.fetch() for remote plugin fetching
|
||||
207
enterprise/downgrade_migrated_users.py
Normal file
207
enterprise/downgrade_migrated_users.py
Normal file
@@ -0,0 +1,207 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
This script can be removed once orgs is established - probably after Feb 15 2026
|
||||
|
||||
Downgrade script for migrated users.
|
||||
|
||||
This script identifies users who have been migrated (already_migrated=True)
|
||||
and reverts them back to the pre-migration state.
|
||||
|
||||
Usage:
|
||||
# Dry run - just list the users that would be downgraded
|
||||
python downgrade_migrated_users.py --dry-run
|
||||
|
||||
# Downgrade a specific user by their keycloak_user_id
|
||||
python downgrade_migrated_users.py --user-id <user_id>
|
||||
|
||||
# Downgrade all migrated users (with confirmation)
|
||||
python downgrade_migrated_users.py --all
|
||||
|
||||
# Downgrade all migrated users without confirmation (dangerous!)
|
||||
python downgrade_migrated_users.py --all --no-confirm
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import sys
|
||||
|
||||
# Add the enterprise directory to the path
|
||||
sys.path.insert(0, '/workspace/project/OpenHands/enterprise')
|
||||
|
||||
from server.logger import logger
|
||||
from sqlalchemy import select, text
|
||||
from storage.database import session_maker
|
||||
from storage.user_settings import UserSettings
|
||||
from storage.user_store import UserStore
|
||||
|
||||
|
||||
def get_migrated_users() -> list[str]:
|
||||
"""Get list of keycloak_user_ids for users who have been migrated.
|
||||
|
||||
This includes:
|
||||
1. Users with already_migrated=True in user_settings (migrated users)
|
||||
2. Users in the 'user' table who don't have a user_settings entry (new sign-ups)
|
||||
"""
|
||||
with session_maker() as session:
|
||||
# Get users from user_settings with already_migrated=True
|
||||
migrated_result = session.execute(
|
||||
select(UserSettings.keycloak_user_id).where(
|
||||
UserSettings.already_migrated.is_(True)
|
||||
)
|
||||
)
|
||||
migrated_users = {row[0] for row in migrated_result.fetchall() if row[0]}
|
||||
|
||||
# Get users from the 'user' table (new sign-ups won't have user_settings)
|
||||
# These are users who signed up after the migration was deployed
|
||||
new_signup_result = session.execute(
|
||||
text("""
|
||||
SELECT CAST(u.id AS VARCHAR)
|
||||
FROM "user" u
|
||||
WHERE NOT EXISTS (
|
||||
SELECT 1 FROM user_settings us
|
||||
WHERE us.keycloak_user_id = CAST(u.id AS VARCHAR)
|
||||
)
|
||||
""")
|
||||
)
|
||||
new_signups = {row[0] for row in new_signup_result.fetchall() if row[0]}
|
||||
|
||||
# Combine both sets
|
||||
all_users = migrated_users | new_signups
|
||||
return list(all_users)
|
||||
|
||||
|
||||
async def downgrade_user(user_id: str) -> bool:
|
||||
"""Downgrade a single user.
|
||||
|
||||
Args:
|
||||
user_id: The keycloak_user_id to downgrade
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
result = await UserStore.downgrade_user(user_id)
|
||||
if result:
|
||||
print(f'✓ Successfully downgraded user: {user_id}')
|
||||
return True
|
||||
else:
|
||||
print(f'✗ Failed to downgrade user: {user_id}')
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f'✗ Error downgrading user {user_id}: {e}')
|
||||
logger.exception(
|
||||
'downgrade_script:error',
|
||||
extra={'user_id': user_id, 'error': str(e)},
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
async def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Downgrade migrated users back to pre-migration state'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--dry-run',
|
||||
action='store_true',
|
||||
help='Just list users that would be downgraded, without making changes',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--user-id',
|
||||
type=str,
|
||||
help='Downgrade a specific user by keycloak_user_id',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--all',
|
||||
action='store_true',
|
||||
help='Downgrade all migrated users',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--no-confirm',
|
||||
action='store_true',
|
||||
help='Skip confirmation prompt (use with caution!)',
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Get list of migrated users
|
||||
migrated_users = get_migrated_users()
|
||||
print(f'\nFound {len(migrated_users)} migrated user(s).')
|
||||
|
||||
if args.dry_run:
|
||||
print('\n--- DRY RUN MODE ---')
|
||||
print('The following users would be downgraded:')
|
||||
for user_id in migrated_users:
|
||||
print(f' - {user_id}')
|
||||
print('\nNo changes were made.')
|
||||
return
|
||||
|
||||
if args.user_id:
|
||||
# Downgrade a specific user
|
||||
if args.user_id not in migrated_users:
|
||||
print(f'\nUser {args.user_id} is not in the migrated users list.')
|
||||
print('Either the user was not migrated, or the user_id is incorrect.')
|
||||
return
|
||||
|
||||
print(f'\nDowngrading user: {args.user_id}')
|
||||
if not args.no_confirm:
|
||||
confirm = input('Are you sure? (yes/no): ')
|
||||
if confirm.lower() != 'yes':
|
||||
print('Cancelled.')
|
||||
return
|
||||
|
||||
success = await downgrade_user(args.user_id)
|
||||
if success:
|
||||
print('\nDowngrade completed successfully.')
|
||||
else:
|
||||
print('\nDowngrade failed. Check logs for details.')
|
||||
sys.exit(1)
|
||||
|
||||
elif args.all:
|
||||
# Downgrade all migrated users
|
||||
if not migrated_users:
|
||||
print('\nNo migrated users to downgrade.')
|
||||
return
|
||||
|
||||
print(f'\n⚠️ About to downgrade {len(migrated_users)} user(s).')
|
||||
if not args.no_confirm:
|
||||
print('\nThis will:')
|
||||
print(' - Revert LiteLLM team/user budget settings')
|
||||
print(' - Delete organization entries')
|
||||
print(' - Delete user entries in the new schema')
|
||||
print(' - Reset the already_migrated flag')
|
||||
print('\nUsers to downgrade:')
|
||||
for user_id in migrated_users[:10]: # Show first 10
|
||||
print(f' - {user_id}')
|
||||
if len(migrated_users) > 10:
|
||||
print(f' ... and {len(migrated_users) - 10} more')
|
||||
|
||||
confirm = input('\nType "yes" to proceed: ')
|
||||
if confirm.lower() != 'yes':
|
||||
print('Cancelled.')
|
||||
return
|
||||
|
||||
print('\nStarting downgrade...\n')
|
||||
success_count = 0
|
||||
fail_count = 0
|
||||
|
||||
for user_id in migrated_users:
|
||||
success = await downgrade_user(user_id)
|
||||
if success:
|
||||
success_count += 1
|
||||
else:
|
||||
fail_count += 1
|
||||
|
||||
print('\n--- Summary ---')
|
||||
print(f'Successful: {success_count}')
|
||||
print(f'Failed: {fail_count}')
|
||||
|
||||
if fail_count > 0:
|
||||
sys.exit(1)
|
||||
|
||||
else:
|
||||
parser.print_help()
|
||||
print('\nPlease specify --dry-run, --user-id, or --all')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
asyncio.run(main())
|
||||
@@ -28,9 +28,11 @@ class SaaSExperimentManager(ExperimentManager):
|
||||
return agent
|
||||
|
||||
if EXPERIMENT_SYSTEM_PROMPT_EXPERIMENT:
|
||||
agent = agent.model_copy(
|
||||
update={'system_prompt_filename': 'system_prompt_long_horizon.j2'}
|
||||
)
|
||||
# Skip experiment for planning agents which require their specialized prompt
|
||||
if agent.system_prompt_filename != 'system_prompt_planning.j2':
|
||||
agent = agent.model_copy(
|
||||
update={'system_prompt_filename': 'system_prompt_long_horizon.j2'}
|
||||
)
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
@@ -26,12 +26,14 @@ from integrations.utils import (
|
||||
from integrations.v1_utils import get_saas_user_auth
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
from pydantic import SecretStr
|
||||
from server.auth.auth_error import ExpiredError
|
||||
from server.auth.constants import GITHUB_APP_CLIENT_ID, GITHUB_APP_PRIVATE_KEY
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.utils.conversation_callback_utils import register_callback_processor
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.provider import ProviderToken, ProviderType
|
||||
from openhands.integrations.service_types import AuthenticationError
|
||||
from openhands.server.types import (
|
||||
LLMAuthenticationError,
|
||||
MissingSettingsError,
|
||||
@@ -143,11 +145,7 @@ class GithubManager(Manager):
|
||||
).get('body', ''):
|
||||
return False
|
||||
|
||||
if GithubFactory.is_eligible_for_conversation_starter(
|
||||
message
|
||||
) and self._user_has_write_access_to_repo(installation_id, repo_name, username):
|
||||
await GithubFactory.trigger_conversation_starter(message)
|
||||
|
||||
# Check event types before making expensive API calls (e.g., _user_has_write_access_to_repo)
|
||||
if not (
|
||||
GithubFactory.is_labeled_issue(message)
|
||||
or GithubFactory.is_issue_comment(message)
|
||||
@@ -157,8 +155,17 @@ class GithubManager(Manager):
|
||||
return False
|
||||
|
||||
logger.info(f'[GitHub] Checking permissions for {username} in {repo_name}')
|
||||
user_has_write_access = self._user_has_write_access_to_repo(
|
||||
installation_id, repo_name, username
|
||||
)
|
||||
|
||||
return self._user_has_write_access_to_repo(installation_id, repo_name, username)
|
||||
if (
|
||||
GithubFactory.is_eligible_for_conversation_starter(message)
|
||||
and user_has_write_access
|
||||
):
|
||||
await GithubFactory.trigger_conversation_starter(message)
|
||||
|
||||
return user_has_write_access
|
||||
|
||||
async def receive_message(self, message: Message):
|
||||
self._confirm_incoming_source_type(message)
|
||||
@@ -186,14 +193,20 @@ class GithubManager(Manager):
|
||||
github_view.installation_id
|
||||
)
|
||||
# Store the installation token
|
||||
self.token_manager.store_org_token(
|
||||
await self.token_manager.store_org_token(
|
||||
github_view.installation_id, installation_token
|
||||
)
|
||||
# Add eyes reaction to acknowledge we've read the request
|
||||
self._add_reaction(github_view, 'eyes', installation_token)
|
||||
await self.start_job(github_view)
|
||||
|
||||
async def send_message(self, message: Message, github_view: ResolverViewInterface):
|
||||
async def send_message(self, message: str, github_view: ResolverViewInterface):
|
||||
"""Send a message to GitHub.
|
||||
|
||||
Args:
|
||||
message: The message content to send (plain text string)
|
||||
github_view: The GitHub view object containing issue/PR/comment info
|
||||
"""
|
||||
installation_token = self.token_manager.load_org_token(
|
||||
github_view.installation_id
|
||||
)
|
||||
@@ -201,14 +214,12 @@ class GithubManager(Manager):
|
||||
logger.warning('Missing installation token')
|
||||
return
|
||||
|
||||
outgoing_message = message.message
|
||||
|
||||
if isinstance(github_view, GithubInlinePRComment):
|
||||
with Github(auth=Auth.Token(installation_token)) as github_client:
|
||||
repo = github_client.get_repo(github_view.full_repo_name)
|
||||
pr = repo.get_pull(github_view.issue_number)
|
||||
pr.create_review_comment_reply(
|
||||
comment_id=github_view.comment_id, body=outgoing_message
|
||||
comment_id=github_view.comment_id, body=message
|
||||
)
|
||||
|
||||
elif (
|
||||
@@ -219,7 +230,7 @@ class GithubManager(Manager):
|
||||
with Github(auth=Auth.Token(installation_token)) as github_client:
|
||||
repo = github_client.get_repo(github_view.full_repo_name)
|
||||
issue = repo.get_issue(number=github_view.issue_number)
|
||||
issue.create_comment(outgoing_message)
|
||||
issue.create_comment(message)
|
||||
|
||||
else:
|
||||
logger.warning('Unsupported location')
|
||||
@@ -238,7 +249,7 @@ class GithubManager(Manager):
|
||||
)
|
||||
|
||||
try:
|
||||
msg_info = None
|
||||
msg_info: str = ''
|
||||
|
||||
try:
|
||||
user_info = github_view.user_info
|
||||
@@ -347,22 +358,20 @@ class GithubManager(Manager):
|
||||
|
||||
msg_info = f'@{user_info.username} please set a valid LLM API key in [OpenHands Cloud]({HOST_URL}) before starting a job.'
|
||||
|
||||
except SessionExpiredError as e:
|
||||
except (AuthenticationError, ExpiredError, SessionExpiredError) as e:
|
||||
logger.warning(
|
||||
f'[GitHub] Session expired for user {user_info.username}: {str(e)}'
|
||||
)
|
||||
|
||||
msg_info = get_session_expired_message(user_info.username)
|
||||
|
||||
msg = self.create_outgoing_message(msg_info)
|
||||
await self.send_message(msg, github_view)
|
||||
await self.send_message(msg_info, github_view)
|
||||
|
||||
except Exception:
|
||||
logger.exception('[Github]: Error starting job')
|
||||
msg = self.create_outgoing_message(
|
||||
msg='Uh oh! There was an unexpected error starting the job :('
|
||||
await self.send_message(
|
||||
'Uh oh! There was an unexpected error starting the job :(', github_view
|
||||
)
|
||||
await self.send_message(msg, github_view)
|
||||
|
||||
try:
|
||||
await self.data_collector.save_data(github_view)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import asyncio
|
||||
|
||||
from integrations.utils import store_repositories_in_db
|
||||
from integrations.store_repo_utils import store_repositories_in_db
|
||||
from pydantic import SecretStr
|
||||
from server.auth.token_manager import TokenManager
|
||||
|
||||
|
||||
@@ -14,7 +14,6 @@ from integrations.solvability.models.summary import SolvabilitySummary
|
||||
from integrations.utils import ENABLE_SOLVABILITY_ANALYSIS
|
||||
from pydantic import ValidationError
|
||||
from server.config import get_config
|
||||
from storage.database import session_maker
|
||||
from storage.saas_settings_store import SaasSettingsStore
|
||||
|
||||
from openhands.core.config import LLMConfig
|
||||
@@ -90,7 +89,6 @@ async def summarize_issue_solvability(
|
||||
# Grab the user's information so we can load their LLM configuration
|
||||
store = SaasSettingsStore(
|
||||
user_id=github_view.user_info.keycloak_user_id,
|
||||
session_maker=session_maker,
|
||||
config=get_config(),
|
||||
)
|
||||
|
||||
|
||||
@@ -24,10 +24,9 @@ from jinja2 import Environment
|
||||
from server.auth.constants import GITHUB_APP_CLIENT_ID, GITHUB_APP_PRIVATE_KEY
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.config import get_config
|
||||
from storage.database import session_maker
|
||||
from storage.org_store import OrgStore
|
||||
from storage.proactive_conversation_store import ProactiveConversationStore
|
||||
from storage.saas_secrets_store import SaasSecretsStore
|
||||
from storage.saas_settings_store import SaasSettingsStore
|
||||
|
||||
from openhands.agent_server.models import SendMessageRequest
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
@@ -78,19 +77,17 @@ async def get_user_proactive_conversation_setting(user_id: str | None) -> bool:
|
||||
if not user_id:
|
||||
return False
|
||||
|
||||
config = get_config()
|
||||
settings_store = SaasSettingsStore(
|
||||
user_id=user_id, session_maker=session_maker, config=config
|
||||
)
|
||||
|
||||
settings = await call_sync_from_async(
|
||||
settings_store.get_user_settings_by_keycloak_id, user_id
|
||||
)
|
||||
|
||||
if not settings or settings.enable_proactive_conversation_starters is None:
|
||||
# Check global setting first - if disabled globally, return False
|
||||
if not ENABLE_PROACTIVE_CONVERSATION_STARTERS:
|
||||
return False
|
||||
|
||||
return settings.enable_proactive_conversation_starters
|
||||
def _get_setting():
|
||||
org = OrgStore.get_current_org_from_keycloak_user_id(user_id)
|
||||
if not org:
|
||||
return False
|
||||
return bool(org.enable_proactive_conversation_starters)
|
||||
|
||||
return await call_sync_from_async(_get_setting)
|
||||
|
||||
|
||||
# =================================================
|
||||
@@ -151,12 +148,11 @@ class GithubIssue(ResolverViewInterface):
|
||||
issue_body=self.description,
|
||||
previous_comments=self.previous_comments,
|
||||
)
|
||||
|
||||
return user_instructions, conversation_instructions
|
||||
|
||||
async def _get_user_secrets(self):
|
||||
secrets_store = SaasSecretsStore(
|
||||
self.user_info.keycloak_user_id, session_maker, get_config()
|
||||
)
|
||||
secrets_store = SaasSecretsStore(self.user_info.keycloak_user_id, get_config())
|
||||
user_secrets = await secrets_store.load()
|
||||
|
||||
return user_secrets.custom_secrets if user_secrets else None
|
||||
@@ -189,6 +185,7 @@ class GithubIssue(ResolverViewInterface):
|
||||
conversation_trigger=ConversationTrigger.RESOLVER,
|
||||
git_provider=ProviderType.GITHUB,
|
||||
)
|
||||
|
||||
self.conversation_id = conversation_metadata.conversation_id
|
||||
return conversation_metadata
|
||||
|
||||
@@ -327,7 +324,6 @@ class GithubIssueComment(GithubIssue):
|
||||
conversation_instructions_template = jinja_env.get_template(
|
||||
'issue_conversation_instructions.j2'
|
||||
)
|
||||
|
||||
conversation_instructions = conversation_instructions_template.render(
|
||||
issue_number=self.issue_number,
|
||||
issue_title=self.title,
|
||||
@@ -397,7 +393,6 @@ class GithubInlinePRComment(GithubPRComment):
|
||||
conversation_instructions_template = jinja_env.get_template(
|
||||
'pr_update_conversation_instructions.j2'
|
||||
)
|
||||
|
||||
conversation_instructions = conversation_instructions_template.render(
|
||||
pr_number=self.issue_number,
|
||||
pr_title=self.title,
|
||||
|
||||
@@ -121,12 +121,11 @@ class GitlabManager(Manager):
|
||||
# Check if the user has write access to the repository
|
||||
return has_write_access
|
||||
|
||||
async def send_message(self, message: Message, gitlab_view: ResolverViewInterface):
|
||||
"""
|
||||
Send a message to GitLab based on the view type.
|
||||
async def send_message(self, message: str, gitlab_view: ResolverViewInterface):
|
||||
"""Send a message to GitLab based on the view type.
|
||||
|
||||
Args:
|
||||
message: The message to send
|
||||
message: The message content to send (plain text string)
|
||||
gitlab_view: The GitLab view object containing issue/PR/comment info
|
||||
"""
|
||||
keycloak_user_id = gitlab_view.user_info.keycloak_user_id
|
||||
@@ -138,8 +137,6 @@ class GitlabManager(Manager):
|
||||
external_auth_id=keycloak_user_id
|
||||
)
|
||||
|
||||
outgoing_message = message.message
|
||||
|
||||
if isinstance(gitlab_view, GitlabInlineMRComment) or isinstance(
|
||||
gitlab_view, GitlabMRComment
|
||||
):
|
||||
@@ -147,7 +144,7 @@ class GitlabManager(Manager):
|
||||
gitlab_view.project_id,
|
||||
gitlab_view.issue_number,
|
||||
gitlab_view.discussion_id,
|
||||
message.message,
|
||||
message,
|
||||
)
|
||||
|
||||
elif isinstance(gitlab_view, GitlabIssueComment):
|
||||
@@ -155,14 +152,14 @@ class GitlabManager(Manager):
|
||||
gitlab_view.project_id,
|
||||
gitlab_view.issue_number,
|
||||
gitlab_view.discussion_id,
|
||||
outgoing_message,
|
||||
message,
|
||||
)
|
||||
elif isinstance(gitlab_view, GitlabIssue):
|
||||
await gitlab_service.reply_to_issue(
|
||||
gitlab_view.project_id,
|
||||
gitlab_view.issue_number,
|
||||
None, # no discussion id, issue is tagged
|
||||
outgoing_message,
|
||||
message,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
@@ -262,12 +259,10 @@ class GitlabManager(Manager):
|
||||
msg_info = get_session_expired_message(user_info.username)
|
||||
|
||||
# Send the acknowledgment message
|
||||
msg = self.create_outgoing_message(msg_info)
|
||||
await self.send_message(msg, gitlab_view)
|
||||
await self.send_message(msg_info, gitlab_view)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f'[GitLab] Error starting job: {str(e)}')
|
||||
msg = self.create_outgoing_message(
|
||||
msg='Uh oh! There was an unexpected error starting the job :('
|
||||
await self.send_message(
|
||||
'Uh oh! There was an unexpected error starting the job :(', gitlab_view
|
||||
)
|
||||
await self.send_message(msg, gitlab_view)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import asyncio
|
||||
|
||||
from integrations.store_repo_utils import store_repositories_in_db
|
||||
from integrations.types import GitLabResourceType
|
||||
from integrations.utils import store_repositories_in_db
|
||||
from pydantic import SecretStr
|
||||
from server.auth.token_manager import TokenManager
|
||||
from storage.gitlab_webhook import GitlabWebhook, WebhookStatus
|
||||
|
||||
@@ -6,7 +6,6 @@ from integrations.utils import HOST, get_oh_labels, has_exact_mention
|
||||
from jinja2 import Environment
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.config import get_config
|
||||
from storage.database import session_maker
|
||||
from storage.saas_secrets_store import SaasSecretsStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
@@ -78,9 +77,7 @@ class GitlabIssue(ResolverViewInterface):
|
||||
return user_instructions, conversation_instructions
|
||||
|
||||
async def _get_user_secrets(self):
|
||||
secrets_store = SaasSecretsStore(
|
||||
self.user_info.keycloak_user_id, session_maker, get_config()
|
||||
)
|
||||
secrets_store = SaasSecretsStore(self.user_info.keycloak_user_id, get_config())
|
||||
user_secrets = await secrets_store.load()
|
||||
|
||||
return user_secrets.custom_secrets if user_secrets else None
|
||||
@@ -449,3 +446,5 @@ class GitlabFactory:
|
||||
previous_comments=[],
|
||||
is_mr=True,
|
||||
)
|
||||
|
||||
raise ValueError(f'Unhandled GitLab webhook event: {message}')
|
||||
|
||||
@@ -167,17 +167,15 @@ async def install_webhook_on_resource(
|
||||
scopes=SCOPES,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
'Creating new webhook',
|
||||
extra={
|
||||
'webhook_id': webhook_id,
|
||||
'status': status,
|
||||
'resource_id': resource_id,
|
||||
'resource_type': resource_type,
|
||||
},
|
||||
)
|
||||
log_extra = {
|
||||
'webhook_id': webhook_id,
|
||||
'status': status,
|
||||
'resource_id': resource_id,
|
||||
'resource_type': resource_type,
|
||||
}
|
||||
|
||||
if status == WebhookStatus.RATE_LIMITED:
|
||||
logger.warning('Rate limited while creating webhook', extra=log_extra)
|
||||
raise BreakLoopException()
|
||||
|
||||
if webhook_id:
|
||||
@@ -191,9 +189,8 @@ async def install_webhook_on_resource(
|
||||
'webhook_uuid': webhook_uuid, # required to identify which webhook installation is sending payload
|
||||
},
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f'Installed webhook for {webhook.user_id} on {resource_type}:{resource_id}'
|
||||
)
|
||||
logger.info('Created new webhook', extra=log_extra)
|
||||
else:
|
||||
logger.error('Failed to create webhook', extra=log_extra)
|
||||
|
||||
return webhook_id, status
|
||||
|
||||
@@ -1,22 +1,37 @@
|
||||
import hashlib
|
||||
import hmac
|
||||
from typing import Dict, Optional, Tuple
|
||||
from urllib.parse import urlparse
|
||||
"""Jira integration manager.
|
||||
|
||||
This module orchestrates the processing of Jira webhook events:
|
||||
1. Parse webhook payload (via JiraPayloadParser)
|
||||
2. Validate workspace
|
||||
3. Authenticate user
|
||||
4. Create view with repository selection (via JiraFactory)
|
||||
5. Start conversation job
|
||||
|
||||
The manager delegates payload parsing to JiraPayloadParser and view creation
|
||||
to JiraFactory, keeping the orchestration logic clean and traceable.
|
||||
"""
|
||||
|
||||
import httpx
|
||||
from fastapi import Request
|
||||
from integrations.jira.jira_types import JiraViewInterface
|
||||
from integrations.jira.jira_view import (
|
||||
JiraExistingConversationView,
|
||||
JiraFactory,
|
||||
JiraNewConversationView,
|
||||
from integrations.jira.jira_payload import (
|
||||
JiraPayloadError,
|
||||
JiraPayloadParser,
|
||||
JiraPayloadSkipped,
|
||||
JiraPayloadSuccess,
|
||||
JiraWebhookPayload,
|
||||
)
|
||||
from integrations.jira.jira_types import (
|
||||
JiraViewInterface,
|
||||
RepositoryNotFoundError,
|
||||
StartingConvoException,
|
||||
)
|
||||
from integrations.jira.jira_view import JiraFactory, JiraNewConversationView
|
||||
from integrations.manager import Manager
|
||||
from integrations.models import JobContext, Message
|
||||
from integrations.models import Message
|
||||
from integrations.utils import (
|
||||
HOST,
|
||||
HOST_URL,
|
||||
OPENHANDS_RESOLVER_TEMPLATES_DIR,
|
||||
filter_potential_repos_by_user_msg,
|
||||
get_oh_labels,
|
||||
get_session_expired_message,
|
||||
)
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
@@ -28,9 +43,6 @@ from storage.jira_user import JiraUser
|
||||
from storage.jira_workspace import JiraWorkspace
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.provider import ProviderHandler
|
||||
from openhands.integrations.service_types import Repository
|
||||
from openhands.server.shared import server_config
|
||||
from openhands.server.types import (
|
||||
LLMAuthenticationError,
|
||||
MissingSettingsError,
|
||||
@@ -41,303 +53,211 @@ from openhands.utils.http_session import httpx_verify_option
|
||||
|
||||
JIRA_CLOUD_API_URL = 'https://api.atlassian.com/ex/jira'
|
||||
|
||||
# Get OH labels for this environment
|
||||
OH_LABEL, INLINE_OH_LABEL = get_oh_labels(HOST)
|
||||
|
||||
|
||||
class JiraManager(Manager):
|
||||
"""Manager for processing Jira webhook events.
|
||||
|
||||
This class orchestrates the flow from webhook receipt to conversation creation,
|
||||
delegating parsing to JiraPayloadParser and view creation to JiraFactory.
|
||||
"""
|
||||
|
||||
def __init__(self, token_manager: TokenManager):
|
||||
self.token_manager = token_manager
|
||||
self.integration_store = JiraIntegrationStore.get_instance()
|
||||
self.jinja_env = Environment(
|
||||
loader=FileSystemLoader(OPENHANDS_RESOLVER_TEMPLATES_DIR + 'jira')
|
||||
)
|
||||
self.payload_parser = JiraPayloadParser(
|
||||
oh_label=OH_LABEL,
|
||||
inline_oh_label=INLINE_OH_LABEL,
|
||||
)
|
||||
|
||||
async def authenticate_user(
|
||||
self, jira_user_id: str, workspace_id: int
|
||||
async def receive_message(self, message: Message):
|
||||
"""Process incoming Jira webhook message.
|
||||
|
||||
Flow:
|
||||
1. Parse webhook payload
|
||||
2. Validate workspace exists and is active
|
||||
3. Authenticate user
|
||||
4. Create view (includes fetching issue details and selecting repository)
|
||||
5. Start job
|
||||
|
||||
Each step has clear logging for traceability.
|
||||
"""
|
||||
raw_payload = message.message.get('payload', {})
|
||||
|
||||
# Step 1: Parse webhook payload
|
||||
logger.info(
|
||||
'[Jira] Received webhook',
|
||||
extra={'raw_payload': raw_payload},
|
||||
)
|
||||
|
||||
parse_result = self.payload_parser.parse(raw_payload)
|
||||
|
||||
if isinstance(parse_result, JiraPayloadSkipped):
|
||||
logger.info(
|
||||
'[Jira] Webhook skipped', extra={'reason': parse_result.skip_reason}
|
||||
)
|
||||
return
|
||||
|
||||
if isinstance(parse_result, JiraPayloadError):
|
||||
logger.warning(
|
||||
'[Jira] Webhook parse failed', extra={'error': parse_result.error}
|
||||
)
|
||||
return
|
||||
|
||||
payload = parse_result.payload
|
||||
logger.info(
|
||||
'[Jira] Processing webhook',
|
||||
extra={
|
||||
'event_type': payload.event_type.value,
|
||||
'issue_key': payload.issue_key,
|
||||
'user_email': payload.user_email,
|
||||
},
|
||||
)
|
||||
|
||||
# Step 2: Validate workspace
|
||||
workspace = await self._get_active_workspace(payload)
|
||||
if not workspace:
|
||||
return
|
||||
|
||||
# Step 3: Authenticate user
|
||||
jira_user, saas_user_auth = await self._authenticate_user(payload, workspace)
|
||||
if not jira_user or not saas_user_auth:
|
||||
return
|
||||
|
||||
# Step 4: Create view (includes issue details fetch and repo selection)
|
||||
decrypted_api_key = self.token_manager.decrypt_text(workspace.svc_acc_api_key)
|
||||
|
||||
try:
|
||||
view = await JiraFactory.create_view(
|
||||
payload=payload,
|
||||
workspace=workspace,
|
||||
user=jira_user,
|
||||
user_auth=saas_user_auth,
|
||||
decrypted_api_key=decrypted_api_key,
|
||||
)
|
||||
except RepositoryNotFoundError as e:
|
||||
logger.warning(
|
||||
'[Jira] Repository not found',
|
||||
extra={'issue_key': payload.issue_key, 'error': str(e)},
|
||||
)
|
||||
await self._send_error_from_payload(payload, workspace, str(e))
|
||||
return
|
||||
except StartingConvoException as e:
|
||||
logger.warning(
|
||||
'[Jira] View creation failed',
|
||||
extra={'issue_key': payload.issue_key, 'error': str(e)},
|
||||
)
|
||||
await self._send_error_from_payload(payload, workspace, str(e))
|
||||
return
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'[Jira] Unexpected error creating view',
|
||||
extra={'issue_key': payload.issue_key, 'error': str(e)},
|
||||
exc_info=True,
|
||||
)
|
||||
await self._send_error_from_payload(
|
||||
payload,
|
||||
workspace,
|
||||
'Failed to initialize conversation. Please try again.',
|
||||
)
|
||||
return
|
||||
|
||||
# Step 5: Start job
|
||||
await self.start_job(view)
|
||||
|
||||
async def _get_active_workspace(
|
||||
self, payload: JiraWebhookPayload
|
||||
) -> JiraWorkspace | None:
|
||||
"""Validate and return the workspace for the webhook.
|
||||
|
||||
Returns None if:
|
||||
- Workspace not found
|
||||
- Workspace is inactive
|
||||
- Request is from service account (to prevent recursion)
|
||||
"""
|
||||
workspace = await self.integration_store.get_workspace_by_name(
|
||||
payload.workspace_name
|
||||
)
|
||||
|
||||
if not workspace:
|
||||
logger.warning(
|
||||
'[Jira] Workspace not found',
|
||||
extra={'workspace_name': payload.workspace_name},
|
||||
)
|
||||
# Can't send error without workspace credentials
|
||||
return None
|
||||
|
||||
# Prevent recursive triggers from service account
|
||||
if payload.user_email == workspace.svc_acc_email:
|
||||
logger.debug(
|
||||
'[Jira] Ignoring service account trigger',
|
||||
extra={'workspace_name': payload.workspace_name},
|
||||
)
|
||||
return None
|
||||
|
||||
if workspace.status != 'active':
|
||||
logger.warning(
|
||||
'[Jira] Workspace inactive',
|
||||
extra={'workspace_id': workspace.id, 'status': workspace.status},
|
||||
)
|
||||
await self._send_error_from_payload(
|
||||
payload, workspace, 'Jira integration is not active for your workspace.'
|
||||
)
|
||||
return None
|
||||
|
||||
return workspace
|
||||
|
||||
async def _authenticate_user(
|
||||
self, payload: JiraWebhookPayload, workspace: JiraWorkspace
|
||||
) -> tuple[JiraUser | None, UserAuth | None]:
|
||||
"""Authenticate Jira user and get their OpenHands user auth."""
|
||||
|
||||
# Find active Jira user by Keycloak user ID and workspace ID
|
||||
"""Authenticate the Jira user and get OpenHands auth."""
|
||||
jira_user = await self.integration_store.get_active_user(
|
||||
jira_user_id, workspace_id
|
||||
payload.account_id, workspace.id
|
||||
)
|
||||
|
||||
if not jira_user:
|
||||
logger.warning(
|
||||
f'[Jira] No active Jira user found for {jira_user_id} in workspace {workspace_id}'
|
||||
'[Jira] User not found or inactive',
|
||||
extra={
|
||||
'account_id': payload.account_id,
|
||||
'user_email': payload.user_email,
|
||||
'workspace_id': workspace.id,
|
||||
},
|
||||
)
|
||||
await self._send_error_from_payload(
|
||||
payload,
|
||||
workspace,
|
||||
f'User {payload.user_email} is not authenticated or active in the Jira integration.',
|
||||
)
|
||||
return None, None
|
||||
|
||||
saas_user_auth = await get_user_auth_from_keycloak_id(
|
||||
jira_user.keycloak_user_id
|
||||
)
|
||||
|
||||
if not saas_user_auth:
|
||||
logger.warning(
|
||||
'[Jira] Failed to get OpenHands auth',
|
||||
extra={
|
||||
'keycloak_user_id': jira_user.keycloak_user_id,
|
||||
'user_email': payload.user_email,
|
||||
},
|
||||
)
|
||||
await self._send_error_from_payload(
|
||||
payload,
|
||||
workspace,
|
||||
f'User {payload.user_email} is not authenticated with OpenHands.',
|
||||
)
|
||||
return None, None
|
||||
|
||||
return jira_user, saas_user_auth
|
||||
|
||||
async def _get_repositories(self, user_auth: UserAuth) -> list[Repository]:
|
||||
"""Get repositories that the user has access to."""
|
||||
provider_tokens = await user_auth.get_provider_tokens()
|
||||
if provider_tokens is None:
|
||||
return []
|
||||
access_token = await user_auth.get_access_token()
|
||||
user_id = await user_auth.get_user_id()
|
||||
client = ProviderHandler(
|
||||
provider_tokens=provider_tokens,
|
||||
external_auth_token=access_token,
|
||||
external_auth_id=user_id,
|
||||
)
|
||||
repos: list[Repository] = await client.get_repositories(
|
||||
'pushed', server_config.app_mode, None, None, None, None
|
||||
)
|
||||
return repos
|
||||
|
||||
async def validate_request(
|
||||
self, request: Request
|
||||
) -> Tuple[bool, Optional[str], Optional[Dict]]:
|
||||
"""Verify Jira webhook signature."""
|
||||
signature_header = request.headers.get('x-hub-signature')
|
||||
signature = signature_header.split('=')[1] if signature_header else None
|
||||
body = await request.body()
|
||||
payload = await request.json()
|
||||
workspace_name = ''
|
||||
|
||||
if payload.get('webhookEvent') == 'comment_created':
|
||||
selfUrl = payload.get('comment', {}).get('author', {}).get('self')
|
||||
elif payload.get('webhookEvent') == 'jira:issue_updated':
|
||||
selfUrl = payload.get('user', {}).get('self')
|
||||
else:
|
||||
workspace_name = ''
|
||||
|
||||
parsedUrl = urlparse(selfUrl)
|
||||
if parsedUrl.hostname:
|
||||
workspace_name = parsedUrl.hostname
|
||||
|
||||
if not workspace_name:
|
||||
logger.warning('[Jira] No workspace name found in webhook payload')
|
||||
return False, None, None
|
||||
|
||||
if not signature:
|
||||
logger.warning('[Jira] No signature found in webhook headers')
|
||||
return False, None, None
|
||||
|
||||
workspace = await self.integration_store.get_workspace_by_name(workspace_name)
|
||||
|
||||
if not workspace:
|
||||
logger.warning('[Jira] Could not identify workspace for webhook')
|
||||
return False, None, None
|
||||
|
||||
if workspace.status != 'active':
|
||||
logger.warning(f'[Jira] Workspace {workspace.id} is not active')
|
||||
return False, None, None
|
||||
|
||||
webhook_secret = self.token_manager.decrypt_text(workspace.webhook_secret)
|
||||
digest = hmac.new(webhook_secret.encode(), body, hashlib.sha256).hexdigest()
|
||||
|
||||
if hmac.compare_digest(signature, digest):
|
||||
logger.info('[Jira] Webhook signature verified successfully')
|
||||
return True, signature, payload
|
||||
|
||||
return False, None, None
|
||||
|
||||
def parse_webhook(self, payload: Dict) -> JobContext | None:
|
||||
event_type = payload.get('webhookEvent')
|
||||
|
||||
if event_type == 'comment_created':
|
||||
comment_data = payload.get('comment', {})
|
||||
comment = comment_data.get('body', '')
|
||||
|
||||
if '@openhands' not in comment:
|
||||
return None
|
||||
|
||||
issue_data = payload.get('issue', {})
|
||||
issue_id = issue_data.get('id')
|
||||
issue_key = issue_data.get('key')
|
||||
base_api_url = issue_data.get('self', '').split('/rest/')[0]
|
||||
|
||||
user_data = comment_data.get('author', {})
|
||||
user_email = user_data.get('emailAddress')
|
||||
display_name = user_data.get('displayName')
|
||||
account_id = user_data.get('accountId')
|
||||
elif event_type == 'jira:issue_updated':
|
||||
changelog = payload.get('changelog', {})
|
||||
items = changelog.get('items', [])
|
||||
labels = [
|
||||
item.get('toString', '')
|
||||
for item in items
|
||||
if item.get('field') == 'labels' and 'toString' in item
|
||||
]
|
||||
|
||||
if 'openhands' not in labels:
|
||||
return None
|
||||
|
||||
issue_data = payload.get('issue', {})
|
||||
issue_id = issue_data.get('id')
|
||||
issue_key = issue_data.get('key')
|
||||
base_api_url = issue_data.get('self', '').split('/rest/')[0]
|
||||
|
||||
user_data = payload.get('user', {})
|
||||
user_email = user_data.get('emailAddress')
|
||||
display_name = user_data.get('displayName')
|
||||
account_id = user_data.get('accountId')
|
||||
comment = ''
|
||||
else:
|
||||
return None
|
||||
|
||||
workspace_name = ''
|
||||
|
||||
parsedUrl = urlparse(base_api_url)
|
||||
if parsedUrl.hostname:
|
||||
workspace_name = parsedUrl.hostname
|
||||
|
||||
if not all(
|
||||
[
|
||||
issue_id,
|
||||
issue_key,
|
||||
user_email,
|
||||
display_name,
|
||||
account_id,
|
||||
workspace_name,
|
||||
base_api_url,
|
||||
]
|
||||
):
|
||||
return None
|
||||
|
||||
return JobContext(
|
||||
issue_id=issue_id,
|
||||
issue_key=issue_key,
|
||||
user_msg=comment,
|
||||
user_email=user_email,
|
||||
display_name=display_name,
|
||||
platform_user_id=account_id,
|
||||
workspace_name=workspace_name,
|
||||
base_api_url=base_api_url,
|
||||
)
|
||||
|
||||
async def receive_message(self, message: Message):
|
||||
"""Process incoming Jira webhook message."""
|
||||
|
||||
payload = message.message.get('payload', {})
|
||||
job_context = self.parse_webhook(payload)
|
||||
|
||||
if not job_context:
|
||||
logger.info('[Jira] Webhook does not match trigger conditions')
|
||||
return
|
||||
|
||||
# Get workspace by user email domain
|
||||
workspace = await self.integration_store.get_workspace_by_name(
|
||||
job_context.workspace_name
|
||||
)
|
||||
if not workspace:
|
||||
logger.warning(
|
||||
f'[Jira] No workspace found for email domain: {job_context.user_email}'
|
||||
)
|
||||
await self._send_error_comment(
|
||||
job_context,
|
||||
'Your workspace is not configured with Jira integration.',
|
||||
None,
|
||||
)
|
||||
return
|
||||
|
||||
# Prevent any recursive triggers from the service account
|
||||
if job_context.user_email == workspace.svc_acc_email:
|
||||
return
|
||||
|
||||
if workspace.status != 'active':
|
||||
logger.warning(f'[Jira] Workspace {workspace.id} is not active')
|
||||
await self._send_error_comment(
|
||||
job_context,
|
||||
'Jira integration is not active for your workspace.',
|
||||
workspace,
|
||||
)
|
||||
return
|
||||
|
||||
# Authenticate user
|
||||
jira_user, saas_user_auth = await self.authenticate_user(
|
||||
job_context.platform_user_id, workspace.id
|
||||
)
|
||||
if not jira_user or not saas_user_auth:
|
||||
logger.warning(
|
||||
f'[Jira] User authentication failed for {job_context.user_email}'
|
||||
)
|
||||
await self._send_error_comment(
|
||||
job_context,
|
||||
f'User {job_context.user_email} is not authenticated or active in the Jira integration.',
|
||||
workspace,
|
||||
)
|
||||
return
|
||||
|
||||
# Get issue details
|
||||
try:
|
||||
api_key = self.token_manager.decrypt_text(workspace.svc_acc_api_key)
|
||||
issue_title, issue_description = await self.get_issue_details(
|
||||
job_context, workspace.jira_cloud_id, workspace.svc_acc_email, api_key
|
||||
)
|
||||
job_context.issue_title = issue_title
|
||||
job_context.issue_description = issue_description
|
||||
except Exception as e:
|
||||
logger.error(f'[Jira] Failed to get issue context: {str(e)}')
|
||||
await self._send_error_comment(
|
||||
job_context,
|
||||
'Failed to retrieve issue details. Please check the issue key and try again.',
|
||||
workspace,
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
# Create Jira view
|
||||
jira_view = await JiraFactory.create_jira_view_from_payload(
|
||||
job_context,
|
||||
saas_user_auth,
|
||||
jira_user,
|
||||
workspace,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f'[Jira] Failed to create jira view: {str(e)}', exc_info=True)
|
||||
await self._send_error_comment(
|
||||
job_context,
|
||||
'Failed to initialize conversation. Please try again.',
|
||||
workspace,
|
||||
)
|
||||
return
|
||||
|
||||
if not await self.is_job_requested(message, jira_view):
|
||||
return
|
||||
|
||||
await self.start_job(jira_view)
|
||||
|
||||
async def is_job_requested(
|
||||
self, message: Message, jira_view: JiraViewInterface
|
||||
) -> bool:
|
||||
"""
|
||||
Check if a job is requested and handle repository selection.
|
||||
"""
|
||||
|
||||
if isinstance(jira_view, JiraExistingConversationView):
|
||||
return True
|
||||
|
||||
try:
|
||||
# Get user repositories
|
||||
user_repos: list[Repository] = await self._get_repositories(
|
||||
jira_view.saas_user_auth
|
||||
)
|
||||
|
||||
target_str = f'{jira_view.job_context.issue_description}\n{jira_view.job_context.user_msg}'
|
||||
|
||||
# Try to infer repository from issue description
|
||||
match, repos = filter_potential_repos_by_user_msg(target_str, user_repos)
|
||||
|
||||
if match:
|
||||
# Found exact repository match
|
||||
jira_view.selected_repo = repos[0].full_name
|
||||
logger.info(f'[Jira] Inferred repository: {repos[0].full_name}')
|
||||
return True
|
||||
else:
|
||||
# No clear match - send repository selection comment
|
||||
await self._send_repo_selection_comment(jira_view)
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f'[Jira] Error in is_job_requested: {str(e)}')
|
||||
return False
|
||||
|
||||
async def start_job(self, jira_view: JiraViewInterface):
|
||||
async def start_job(self, view: JiraViewInterface):
|
||||
"""Start a Jira job/conversation."""
|
||||
# Import here to prevent circular import
|
||||
from server.conversation_callback_processor.jira_callback_processor import (
|
||||
@@ -345,114 +265,101 @@ class JiraManager(Manager):
|
||||
)
|
||||
|
||||
try:
|
||||
user_info: JiraUser = jira_view.jira_user
|
||||
logger.info(
|
||||
f'[Jira] Starting job for user {user_info.keycloak_user_id} '
|
||||
f'issue {jira_view.job_context.issue_key}',
|
||||
'[Jira] Starting job',
|
||||
extra={
|
||||
'issue_key': view.payload.issue_key,
|
||||
'user_id': view.jira_user.keycloak_user_id,
|
||||
'selected_repo': view.selected_repo,
|
||||
},
|
||||
)
|
||||
|
||||
# Create conversation
|
||||
conversation_id = await jira_view.create_or_update_conversation(
|
||||
self.jinja_env
|
||||
)
|
||||
conversation_id = await view.create_or_update_conversation(self.jinja_env)
|
||||
|
||||
logger.info(
|
||||
f'[Jira] Created/Updated conversation {conversation_id} for issue {jira_view.job_context.issue_key}'
|
||||
'[Jira] Conversation created',
|
||||
extra={
|
||||
'conversation_id': conversation_id,
|
||||
'issue_key': view.payload.issue_key,
|
||||
},
|
||||
)
|
||||
|
||||
# Register callback processor for updates
|
||||
if isinstance(jira_view, JiraNewConversationView):
|
||||
if isinstance(view, JiraNewConversationView):
|
||||
processor = JiraCallbackProcessor(
|
||||
issue_key=jira_view.job_context.issue_key,
|
||||
workspace_name=jira_view.jira_workspace.name,
|
||||
issue_key=view.payload.issue_key,
|
||||
workspace_name=view.jira_workspace.name,
|
||||
)
|
||||
|
||||
# Register the callback processor
|
||||
register_callback_processor(conversation_id, processor)
|
||||
|
||||
logger.info(
|
||||
f'[Jira] Created callback processor for conversation {conversation_id}'
|
||||
'[Jira] Callback processor registered',
|
||||
extra={'conversation_id': conversation_id},
|
||||
)
|
||||
|
||||
# Send initial response
|
||||
msg_info = jira_view.get_response_msg()
|
||||
# Send success response
|
||||
msg_info = view.get_response_msg()
|
||||
|
||||
except MissingSettingsError as e:
|
||||
logger.warning(f'[Jira] Missing settings error: {str(e)}')
|
||||
logger.warning(
|
||||
'[Jira] Missing settings error',
|
||||
extra={'issue_key': view.payload.issue_key, 'error': str(e)},
|
||||
)
|
||||
msg_info = f'Please re-login into [OpenHands Cloud]({HOST_URL}) before starting a job.'
|
||||
|
||||
except LLMAuthenticationError as e:
|
||||
logger.warning(f'[Jira] LLM authentication error: {str(e)}')
|
||||
logger.warning(
|
||||
'[Jira] LLM authentication error',
|
||||
extra={'issue_key': view.payload.issue_key, 'error': str(e)},
|
||||
)
|
||||
msg_info = f'Please set a valid LLM API key in [OpenHands Cloud]({HOST_URL}) before starting a job.'
|
||||
|
||||
except SessionExpiredError as e:
|
||||
logger.warning(f'[Jira] Session expired: {str(e)}')
|
||||
logger.warning(
|
||||
'[Jira] Session expired',
|
||||
extra={'issue_key': view.payload.issue_key, 'error': str(e)},
|
||||
)
|
||||
msg_info = get_session_expired_message()
|
||||
|
||||
except StartingConvoException as e:
|
||||
logger.warning(
|
||||
'[Jira] Conversation start failed',
|
||||
extra={'issue_key': view.payload.issue_key, 'error': str(e)},
|
||||
)
|
||||
msg_info = str(e)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f'[Jira] Unexpected error starting job: {str(e)}', exc_info=True
|
||||
'[Jira] Unexpected error starting job',
|
||||
extra={'issue_key': view.payload.issue_key, 'error': str(e)},
|
||||
exc_info=True,
|
||||
)
|
||||
msg_info = 'Sorry, there was an unexpected error starting the job. Please try again.'
|
||||
|
||||
# Send response comment
|
||||
try:
|
||||
api_key = self.token_manager.decrypt_text(
|
||||
jira_view.jira_workspace.svc_acc_api_key
|
||||
)
|
||||
await self.send_message(
|
||||
self.create_outgoing_message(msg=msg_info),
|
||||
issue_key=jira_view.job_context.issue_key,
|
||||
jira_cloud_id=jira_view.jira_workspace.jira_cloud_id,
|
||||
svc_acc_email=jira_view.jira_workspace.svc_acc_email,
|
||||
svc_acc_api_key=api_key,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f'[Jira] Failed to send response message: {str(e)}')
|
||||
|
||||
async def get_issue_details(
|
||||
self,
|
||||
job_context: JobContext,
|
||||
jira_cloud_id: str,
|
||||
svc_acc_email: str,
|
||||
svc_acc_api_key: str,
|
||||
) -> Tuple[str, str]:
|
||||
url = f'{JIRA_CLOUD_API_URL}/{jira_cloud_id}/rest/api/2/issue/{job_context.issue_key}'
|
||||
async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
|
||||
response = await client.get(url, auth=(svc_acc_email, svc_acc_api_key))
|
||||
response.raise_for_status()
|
||||
issue_payload = response.json()
|
||||
|
||||
if not issue_payload:
|
||||
raise ValueError(f'Issue with key {job_context.issue_key} not found.')
|
||||
|
||||
title = issue_payload.get('fields', {}).get('summary', '')
|
||||
description = issue_payload.get('fields', {}).get('description', '')
|
||||
|
||||
if not title:
|
||||
raise ValueError(
|
||||
f'Issue with key {job_context.issue_key} does not have a title.'
|
||||
)
|
||||
|
||||
if not description:
|
||||
raise ValueError(
|
||||
f'Issue with key {job_context.issue_key} does not have a description.'
|
||||
)
|
||||
|
||||
return title, description
|
||||
await self._send_comment(view, msg_info)
|
||||
|
||||
async def send_message(
|
||||
self,
|
||||
message: Message,
|
||||
message: str,
|
||||
issue_key: str,
|
||||
jira_cloud_id: str,
|
||||
svc_acc_email: str,
|
||||
svc_acc_api_key: str,
|
||||
):
|
||||
"""Send a comment to a Jira issue.
|
||||
|
||||
Args:
|
||||
message: The message content to send (plain text string)
|
||||
issue_key: The Jira issue key (e.g., 'PROJ-123')
|
||||
jira_cloud_id: The Jira Cloud ID
|
||||
svc_acc_email: Service account email for authentication
|
||||
svc_acc_api_key: Service account API key for authentication
|
||||
"""
|
||||
url = (
|
||||
f'{JIRA_CLOUD_API_URL}/{jira_cloud_id}/rest/api/2/issue/{issue_key}/comment'
|
||||
)
|
||||
data = {'body': message.message}
|
||||
data = {'body': message}
|
||||
async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
|
||||
response = await client.post(
|
||||
url, auth=(svc_acc_email, svc_acc_api_key), json=data
|
||||
@@ -460,54 +367,53 @@ class JiraManager(Manager):
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
async def _send_error_comment(
|
||||
self,
|
||||
job_context: JobContext,
|
||||
error_msg: str,
|
||||
workspace: JiraWorkspace | None,
|
||||
):
|
||||
"""Send error comment to Jira issue."""
|
||||
if not workspace:
|
||||
logger.error('[Jira] Cannot send error comment - no workspace available')
|
||||
return
|
||||
async def _send_comment(self, view: JiraViewInterface, msg: str):
|
||||
"""Send a comment using credentials from the view."""
|
||||
try:
|
||||
api_key = self.token_manager.decrypt_text(
|
||||
view.jira_workspace.svc_acc_api_key
|
||||
)
|
||||
await self.send_message(
|
||||
msg,
|
||||
issue_key=view.payload.issue_key,
|
||||
jira_cloud_id=view.jira_workspace.jira_cloud_id,
|
||||
svc_acc_email=view.jira_workspace.svc_acc_email,
|
||||
svc_acc_api_key=api_key,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'[Jira] Failed to send comment',
|
||||
extra={'issue_key': view.payload.issue_key, 'error': str(e)},
|
||||
)
|
||||
|
||||
async def _send_error_from_payload(
|
||||
self,
|
||||
payload: JiraWebhookPayload,
|
||||
workspace: JiraWorkspace,
|
||||
error_msg: str,
|
||||
):
|
||||
"""Send error comment before view is created (using payload directly)."""
|
||||
try:
|
||||
api_key = self.token_manager.decrypt_text(workspace.svc_acc_api_key)
|
||||
await self.send_message(
|
||||
self.create_outgoing_message(msg=error_msg),
|
||||
issue_key=job_context.issue_key,
|
||||
error_msg,
|
||||
issue_key=payload.issue_key,
|
||||
jira_cloud_id=workspace.jira_cloud_id,
|
||||
svc_acc_email=workspace.svc_acc_email,
|
||||
svc_acc_api_key=api_key,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f'[Jira] Failed to send error comment: {str(e)}')
|
||||
|
||||
async def _send_repo_selection_comment(self, jira_view: JiraViewInterface):
|
||||
"""Send a comment with repository options for the user to choose."""
|
||||
try:
|
||||
comment_msg = (
|
||||
'I need to know which repository to work with. '
|
||||
'Please add it to your issue description or send a followup comment.'
|
||||
)
|
||||
|
||||
api_key = self.token_manager.decrypt_text(
|
||||
jira_view.jira_workspace.svc_acc_api_key
|
||||
)
|
||||
|
||||
await self.send_message(
|
||||
self.create_outgoing_message(msg=comment_msg),
|
||||
issue_key=jira_view.job_context.issue_key,
|
||||
jira_cloud_id=jira_view.jira_workspace.jira_cloud_id,
|
||||
svc_acc_email=jira_view.jira_workspace.svc_acc_email,
|
||||
svc_acc_api_key=api_key,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f'[Jira] Sent repository selection comment for issue {jira_view.job_context.issue_key}'
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f'[Jira] Failed to send repository selection comment: {str(e)}'
|
||||
'[Jira] Failed to send error comment',
|
||||
extra={'issue_key': payload.issue_key, 'error': str(e)},
|
||||
)
|
||||
|
||||
def get_workspace_name_from_payload(self, payload: dict) -> str | None:
|
||||
"""Extract workspace name from Jira webhook payload.
|
||||
|
||||
This method is used by the route for signature verification.
|
||||
"""
|
||||
parse_result = self.payload_parser.parse(payload)
|
||||
if isinstance(parse_result, JiraPayloadSuccess):
|
||||
return parse_result.payload.workspace_name
|
||||
return None
|
||||
|
||||
265
enterprise/integrations/jira/jira_payload.py
Normal file
265
enterprise/integrations/jira/jira_payload.py
Normal file
@@ -0,0 +1,265 @@
|
||||
"""Centralized payload parsing for Jira webhooks.
|
||||
|
||||
This module provides a single source of truth for parsing and validating
|
||||
Jira webhook payloads, replacing scattered parsing logic throughout the codebase.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
|
||||
class JiraEventType(Enum):
|
||||
"""Types of Jira events we handle."""
|
||||
|
||||
LABELED_TICKET = 'labeled_ticket'
|
||||
COMMENT_MENTION = 'comment_mention'
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class JiraWebhookPayload:
|
||||
"""Normalized, validated representation of a Jira webhook payload.
|
||||
|
||||
This immutable dataclass replaces JobContext and provides a single
|
||||
source of truth for all webhook data. All parsing happens in
|
||||
JiraPayloadParser, ensuring consistent validation.
|
||||
"""
|
||||
|
||||
event_type: JiraEventType
|
||||
raw_event: str # Original webhookEvent value
|
||||
|
||||
# Issue data
|
||||
issue_id: str
|
||||
issue_key: str
|
||||
|
||||
# User data
|
||||
user_email: str
|
||||
display_name: str
|
||||
account_id: str
|
||||
|
||||
# Workspace data (derived from issue self URL)
|
||||
workspace_name: str
|
||||
base_api_url: str
|
||||
|
||||
# Event-specific data
|
||||
comment_body: str = '' # For comment events
|
||||
|
||||
@property
|
||||
def user_msg(self) -> str:
|
||||
"""Alias for comment_body for backward compatibility."""
|
||||
return self.comment_body
|
||||
|
||||
|
||||
class JiraPayloadParseError(Exception):
|
||||
"""Raised when payload parsing fails."""
|
||||
|
||||
def __init__(self, reason: str, event_type: str | None = None):
|
||||
self.reason = reason
|
||||
self.event_type = event_type
|
||||
super().__init__(reason)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class JiraPayloadSuccess:
|
||||
"""Result when parsing succeeds."""
|
||||
|
||||
payload: JiraWebhookPayload
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class JiraPayloadSkipped:
|
||||
"""Result when event is intentionally skipped."""
|
||||
|
||||
skip_reason: str
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class JiraPayloadError:
|
||||
"""Result when parsing fails due to invalid data."""
|
||||
|
||||
error: str
|
||||
|
||||
|
||||
JiraPayloadParseResult = JiraPayloadSuccess | JiraPayloadSkipped | JiraPayloadError
|
||||
|
||||
|
||||
class JiraPayloadParser:
|
||||
"""Centralized parser for Jira webhook payloads.
|
||||
|
||||
This class provides a single entry point for parsing webhooks,
|
||||
determining event types, and extracting all necessary fields.
|
||||
Replaces scattered parsing in JiraFactory and JiraManager.
|
||||
"""
|
||||
|
||||
def __init__(self, oh_label: str, inline_oh_label: str):
|
||||
"""Initialize parser with OpenHands label configuration.
|
||||
|
||||
Args:
|
||||
oh_label: Label that triggers OpenHands (e.g., 'openhands')
|
||||
inline_oh_label: Mention that triggers OpenHands (e.g., '@openhands')
|
||||
"""
|
||||
self.oh_label = oh_label
|
||||
self.inline_oh_label = inline_oh_label
|
||||
|
||||
def parse(self, raw_payload: dict) -> JiraPayloadParseResult:
|
||||
"""Parse a raw webhook payload into a normalized JiraWebhookPayload.
|
||||
|
||||
Args:
|
||||
raw_payload: The raw webhook payload dict from Jira
|
||||
|
||||
Returns:
|
||||
One of:
|
||||
- JiraPayloadSuccess: Valid, actionable event with payload
|
||||
- JiraPayloadSkipped: Event we intentionally don't process
|
||||
- JiraPayloadError: Malformed payload we expected to process
|
||||
"""
|
||||
webhook_event = raw_payload.get('webhookEvent', '')
|
||||
|
||||
logger.debug(
|
||||
'[Jira] Parsing webhook payload', extra={'webhook_event': webhook_event}
|
||||
)
|
||||
|
||||
if webhook_event == 'jira:issue_updated':
|
||||
return self._parse_label_event(raw_payload, webhook_event)
|
||||
elif webhook_event == 'comment_created':
|
||||
return self._parse_comment_event(raw_payload, webhook_event)
|
||||
else:
|
||||
return JiraPayloadSkipped(f'Unhandled webhook event type: {webhook_event}')
|
||||
|
||||
def _parse_label_event(
|
||||
self, payload: dict, webhook_event: str
|
||||
) -> JiraPayloadParseResult:
|
||||
"""Parse an issue_updated event for label changes."""
|
||||
changelog = payload.get('changelog', {})
|
||||
items = changelog.get('items', [])
|
||||
|
||||
# Extract labels that were added
|
||||
labels = [
|
||||
item.get('toString', '')
|
||||
for item in items
|
||||
if item.get('field') == 'labels' and 'toString' in item
|
||||
]
|
||||
|
||||
if self.oh_label not in labels:
|
||||
return JiraPayloadSkipped(
|
||||
f"Label event does not contain '{self.oh_label}' label"
|
||||
)
|
||||
|
||||
# For label events, user data comes from 'user' field
|
||||
user_data = payload.get('user', {})
|
||||
return self._extract_and_validate(
|
||||
payload=payload,
|
||||
user_data=user_data,
|
||||
event_type=JiraEventType.LABELED_TICKET,
|
||||
webhook_event=webhook_event,
|
||||
comment_body='',
|
||||
)
|
||||
|
||||
def _parse_comment_event(
|
||||
self, payload: dict, webhook_event: str
|
||||
) -> JiraPayloadParseResult:
|
||||
"""Parse a comment_created event."""
|
||||
comment_data = payload.get('comment', {})
|
||||
comment_body = comment_data.get('body', '')
|
||||
|
||||
if not self._has_mention(comment_body):
|
||||
return JiraPayloadSkipped(
|
||||
f"Comment does not mention '{self.inline_oh_label}'"
|
||||
)
|
||||
|
||||
# For comment events, user data comes from 'comment.author'
|
||||
user_data = comment_data.get('author', {})
|
||||
return self._extract_and_validate(
|
||||
payload=payload,
|
||||
user_data=user_data,
|
||||
event_type=JiraEventType.COMMENT_MENTION,
|
||||
webhook_event=webhook_event,
|
||||
comment_body=comment_body,
|
||||
)
|
||||
|
||||
def _has_mention(self, text: str) -> bool:
|
||||
"""Check if text contains an exact mention of OpenHands."""
|
||||
from integrations.utils import has_exact_mention
|
||||
|
||||
return has_exact_mention(text, self.inline_oh_label)
|
||||
|
||||
def _extract_and_validate(
|
||||
self,
|
||||
payload: dict,
|
||||
user_data: dict,
|
||||
event_type: JiraEventType,
|
||||
webhook_event: str,
|
||||
comment_body: str,
|
||||
) -> JiraPayloadParseResult:
|
||||
"""Extract common fields and validate required data is present."""
|
||||
issue_data = payload.get('issue', {})
|
||||
|
||||
# Extract all fields with empty string defaults (makes them str type)
|
||||
issue_id = issue_data.get('id', '')
|
||||
issue_key = issue_data.get('key', '')
|
||||
user_email = user_data.get('emailAddress', '')
|
||||
display_name = user_data.get('displayName', '')
|
||||
account_id = user_data.get('accountId', '')
|
||||
base_api_url, workspace_name = self._extract_workspace_from_url(
|
||||
issue_data.get('self', '')
|
||||
)
|
||||
|
||||
# Validate required fields
|
||||
missing: list[str] = []
|
||||
if not issue_id:
|
||||
missing.append('issue.id')
|
||||
if not issue_key:
|
||||
missing.append('issue.key')
|
||||
if not display_name:
|
||||
missing.append('user.displayName')
|
||||
if not account_id:
|
||||
missing.append('user.accountId')
|
||||
if not workspace_name:
|
||||
missing.append('workspace_name (derived from issue.self)')
|
||||
if not base_api_url:
|
||||
missing.append('base_api_url (derived from issue.self)')
|
||||
|
||||
if missing:
|
||||
return JiraPayloadError(f"Missing required fields: {', '.join(missing)}")
|
||||
|
||||
return JiraPayloadSuccess(
|
||||
JiraWebhookPayload(
|
||||
event_type=event_type,
|
||||
raw_event=webhook_event,
|
||||
issue_id=issue_id,
|
||||
issue_key=issue_key,
|
||||
user_email=user_email,
|
||||
display_name=display_name,
|
||||
account_id=account_id,
|
||||
workspace_name=workspace_name,
|
||||
base_api_url=base_api_url,
|
||||
comment_body=comment_body,
|
||||
)
|
||||
)
|
||||
|
||||
def _extract_workspace_from_url(self, self_url: str) -> tuple[str, str]:
|
||||
"""Extract base API URL and workspace name from issue self URL.
|
||||
|
||||
Args:
|
||||
self_url: The 'self' URL from the issue data
|
||||
|
||||
Returns:
|
||||
Tuple of (base_api_url, workspace_name)
|
||||
"""
|
||||
if not self_url:
|
||||
return '', ''
|
||||
|
||||
# Extract base URL (everything before /rest/)
|
||||
if '/rest/' in self_url:
|
||||
base_api_url = self_url.split('/rest/')[0]
|
||||
else:
|
||||
parsed = urlparse(self_url)
|
||||
base_api_url = f'{parsed.scheme}://{parsed.netloc}'
|
||||
|
||||
# Extract workspace name (hostname)
|
||||
parsed = urlparse(base_api_url)
|
||||
workspace_name = parsed.hostname or ''
|
||||
|
||||
return base_api_url, workspace_name
|
||||
@@ -1,26 +1,42 @@
|
||||
from abc import ABC, abstractmethod
|
||||
"""Type definitions and interfaces for Jira integration."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from integrations.models import JobContext
|
||||
from jinja2 import Environment
|
||||
from storage.jira_user import JiraUser
|
||||
from storage.jira_workspace import JiraWorkspace
|
||||
|
||||
from openhands.server.user_auth.user_auth import UserAuth
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from integrations.jira.jira_payload import JiraWebhookPayload
|
||||
|
||||
|
||||
class JiraViewInterface(ABC):
|
||||
"""Interface for Jira views that handle different types of Jira interactions."""
|
||||
"""Interface for Jira views that handle different types of Jira interactions.
|
||||
|
||||
job_context: JobContext
|
||||
Views hold the webhook payload directly rather than duplicating fields,
|
||||
and fetch issue details lazily when needed.
|
||||
"""
|
||||
|
||||
# Core data - view holds these references
|
||||
payload: 'JiraWebhookPayload'
|
||||
saas_user_auth: UserAuth
|
||||
jira_user: JiraUser
|
||||
jira_workspace: JiraWorkspace
|
||||
|
||||
# Mutable state set during processing
|
||||
selected_repo: str | None
|
||||
conversation_id: str
|
||||
|
||||
@abstractmethod
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
"""Get initial instructions for the conversation."""
|
||||
async def get_issue_details(self) -> tuple[str, str]:
|
||||
"""Fetch and cache issue title and description from Jira API.
|
||||
|
||||
Returns:
|
||||
Tuple of (issue_title, issue_description)
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@@ -35,6 +51,21 @@ class JiraViewInterface(ABC):
|
||||
|
||||
|
||||
class StartingConvoException(Exception):
|
||||
"""Exception raised when starting a conversation fails."""
|
||||
"""Exception raised when starting a conversation fails.
|
||||
|
||||
This provides user-friendly error messages that can be sent back to Jira.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class RepositoryNotFoundError(Exception):
|
||||
"""Raised when a repository cannot be determined from the issue.
|
||||
|
||||
This is a separate error domain from StartingConvoException - it represents
|
||||
a precondition failure (no repo configured/found) rather than a conversation
|
||||
creation failure. The manager catches this and converts it to a user-friendly
|
||||
message.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
@@ -1,8 +1,21 @@
|
||||
from dataclasses import dataclass
|
||||
"""Jira view implementations and factory.
|
||||
|
||||
from integrations.jira.jira_types import JiraViewInterface, StartingConvoException
|
||||
from integrations.models import JobContext
|
||||
from integrations.utils import CONVERSATION_URL, get_final_agent_observation
|
||||
Views are responsible for:
|
||||
- Holding the webhook payload and auth context
|
||||
- Lazy-loading issue details from Jira API when needed
|
||||
- Creating conversations with the selected repository
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import httpx
|
||||
from integrations.jira.jira_payload import JiraWebhookPayload
|
||||
from integrations.jira.jira_types import (
|
||||
JiraViewInterface,
|
||||
RepositoryNotFoundError,
|
||||
StartingConvoException,
|
||||
)
|
||||
from integrations.utils import CONVERSATION_URL, infer_repo_from_message
|
||||
from jinja2 import Environment
|
||||
from storage.jira_conversation import JiraConversation
|
||||
from storage.jira_integration_store import JiraIntegrationStore
|
||||
@@ -10,55 +23,147 @@ from storage.jira_user import JiraUser
|
||||
from storage.jira_workspace import JiraWorkspace
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.schema.agent import AgentState
|
||||
from openhands.events.action import MessageAction
|
||||
from openhands.events.serialization.event import event_to_dict
|
||||
from openhands.server.services.conversation_service import (
|
||||
create_new_conversation,
|
||||
setup_init_conversation_settings,
|
||||
)
|
||||
from openhands.server.shared import ConversationStoreImpl, config, conversation_manager
|
||||
from openhands.integrations.provider import ProviderHandler
|
||||
from openhands.server.services.conversation_service import create_new_conversation
|
||||
from openhands.server.user_auth.user_auth import UserAuth
|
||||
from openhands.storage.data_models.conversation_metadata import ConversationTrigger
|
||||
from openhands.utils.http_session import httpx_verify_option
|
||||
|
||||
JIRA_CLOUD_API_URL = 'https://api.atlassian.com/ex/jira'
|
||||
|
||||
integration_store = JiraIntegrationStore.get_instance()
|
||||
|
||||
|
||||
@dataclass
|
||||
class JiraNewConversationView(JiraViewInterface):
|
||||
job_context: JobContext
|
||||
"""View for creating a new Jira conversation.
|
||||
|
||||
This view holds the webhook payload directly and lazily fetches
|
||||
issue details when needed for rendering templates.
|
||||
"""
|
||||
|
||||
payload: JiraWebhookPayload
|
||||
saas_user_auth: UserAuth
|
||||
jira_user: JiraUser
|
||||
jira_workspace: JiraWorkspace
|
||||
selected_repo: str | None
|
||||
conversation_id: str
|
||||
selected_repo: str | None = None
|
||||
conversation_id: str = ''
|
||||
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
"""Instructions passed when conversation is first initialized"""
|
||||
# Lazy-loaded issue details (cached after first fetch)
|
||||
_issue_title: str | None = field(default=None, repr=False)
|
||||
_issue_description: str | None = field(default=None, repr=False)
|
||||
|
||||
# Decrypted API key (set by factory)
|
||||
_decrypted_api_key: str = field(default='', repr=False)
|
||||
|
||||
async def get_issue_details(self) -> tuple[str, str]:
|
||||
"""Fetch issue details from Jira API (cached after first call).
|
||||
|
||||
Returns:
|
||||
Tuple of (issue_title, issue_description)
|
||||
|
||||
Raises:
|
||||
StartingConvoException: If issue details cannot be fetched
|
||||
"""
|
||||
if self._issue_title is not None and self._issue_description is not None:
|
||||
return self._issue_title, self._issue_description
|
||||
|
||||
try:
|
||||
url = f'{JIRA_CLOUD_API_URL}/{self.jira_workspace.jira_cloud_id}/rest/api/2/issue/{self.payload.issue_key}'
|
||||
async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
|
||||
response = await client.get(
|
||||
url,
|
||||
auth=(
|
||||
self.jira_workspace.svc_acc_email,
|
||||
self._decrypted_api_key,
|
||||
),
|
||||
)
|
||||
response.raise_for_status()
|
||||
issue_payload = response.json()
|
||||
|
||||
if not issue_payload:
|
||||
raise StartingConvoException(
|
||||
f'Issue {self.payload.issue_key} not found.'
|
||||
)
|
||||
|
||||
self._issue_title = issue_payload.get('fields', {}).get('summary', '')
|
||||
self._issue_description = (
|
||||
issue_payload.get('fields', {}).get('description', '') or ''
|
||||
)
|
||||
|
||||
if not self._issue_title:
|
||||
raise StartingConvoException(
|
||||
f'Issue {self.payload.issue_key} does not have a title.'
|
||||
)
|
||||
|
||||
logger.info(
|
||||
'[Jira] Fetched issue details',
|
||||
extra={
|
||||
'issue_key': self.payload.issue_key,
|
||||
'has_description': bool(self._issue_description),
|
||||
},
|
||||
)
|
||||
|
||||
return self._issue_title, self._issue_description
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(
|
||||
'[Jira] Failed to fetch issue details',
|
||||
extra={
|
||||
'issue_key': self.payload.issue_key,
|
||||
'status': e.response.status_code,
|
||||
},
|
||||
)
|
||||
raise StartingConvoException(
|
||||
f'Failed to fetch issue details: HTTP {e.response.status_code}'
|
||||
)
|
||||
except Exception as e:
|
||||
if isinstance(e, StartingConvoException):
|
||||
raise
|
||||
logger.error(
|
||||
'[Jira] Failed to fetch issue details',
|
||||
extra={'issue_key': self.payload.issue_key, 'error': str(e)},
|
||||
)
|
||||
raise StartingConvoException(f'Failed to fetch issue details: {str(e)}')
|
||||
|
||||
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
"""Get instructions for the conversation.
|
||||
|
||||
This fetches issue details if not already cached.
|
||||
|
||||
Returns:
|
||||
Tuple of (system_instructions, user_message)
|
||||
"""
|
||||
issue_title, issue_description = await self.get_issue_details()
|
||||
|
||||
instructions_template = jinja_env.get_template('jira_instructions.j2')
|
||||
instructions = instructions_template.render()
|
||||
|
||||
user_msg_template = jinja_env.get_template('jira_new_conversation.j2')
|
||||
|
||||
user_msg = user_msg_template.render(
|
||||
issue_key=self.job_context.issue_key,
|
||||
issue_title=self.job_context.issue_title,
|
||||
issue_description=self.job_context.issue_description,
|
||||
user_message=self.job_context.user_msg or '',
|
||||
issue_key=self.payload.issue_key,
|
||||
issue_title=issue_title,
|
||||
issue_description=issue_description,
|
||||
user_message=self.payload.user_msg,
|
||||
)
|
||||
|
||||
return instructions, user_msg
|
||||
|
||||
async def create_or_update_conversation(self, jinja_env: Environment) -> str:
|
||||
"""Create a new Jira conversation"""
|
||||
"""Create a new Jira conversation.
|
||||
|
||||
Returns:
|
||||
The conversation ID
|
||||
|
||||
Raises:
|
||||
StartingConvoException: If conversation creation fails
|
||||
"""
|
||||
if not self.selected_repo:
|
||||
raise StartingConvoException('No repository selected for this conversation')
|
||||
|
||||
provider_tokens = await self.saas_user_auth.get_provider_tokens()
|
||||
user_secrets = await self.saas_user_auth.get_secrets()
|
||||
instructions, user_msg = self._get_instructions(jinja_env)
|
||||
instructions, user_msg = await self._get_instructions(jinja_env)
|
||||
|
||||
try:
|
||||
agent_loop_info = await create_new_conversation(
|
||||
@@ -76,149 +181,259 @@ class JiraNewConversationView(JiraViewInterface):
|
||||
|
||||
self.conversation_id = agent_loop_info.conversation_id
|
||||
|
||||
logger.info(f'[Jira] Created conversation {self.conversation_id}')
|
||||
logger.info(
|
||||
'[Jira] Created conversation',
|
||||
extra={
|
||||
'conversation_id': self.conversation_id,
|
||||
'issue_key': self.payload.issue_key,
|
||||
'selected_repo': self.selected_repo,
|
||||
},
|
||||
)
|
||||
|
||||
# Store Jira conversation mapping
|
||||
jira_conversation = JiraConversation(
|
||||
conversation_id=self.conversation_id,
|
||||
issue_id=self.job_context.issue_id,
|
||||
issue_key=self.job_context.issue_key,
|
||||
issue_id=self.payload.issue_id,
|
||||
issue_key=self.payload.issue_key,
|
||||
jira_user_id=self.jira_user.id,
|
||||
)
|
||||
|
||||
await integration_store.create_conversation(jira_conversation)
|
||||
|
||||
return self.conversation_id
|
||||
|
||||
except Exception as e:
|
||||
if isinstance(e, StartingConvoException):
|
||||
raise
|
||||
logger.error(
|
||||
f'[Jira] Failed to create conversation: {str(e)}', exc_info=True
|
||||
'[Jira] Failed to create conversation',
|
||||
extra={'issue_key': self.payload.issue_key, 'error': str(e)},
|
||||
exc_info=True,
|
||||
)
|
||||
raise StartingConvoException(f'Failed to create conversation: {str(e)}')
|
||||
|
||||
def get_response_msg(self) -> str:
|
||||
"""Get the response message to send back to Jira"""
|
||||
"""Get the response message to send back to Jira."""
|
||||
conversation_link = CONVERSATION_URL.format(self.conversation_id)
|
||||
return f"I'm on it! {self.job_context.display_name} can [track my progress here|{conversation_link}]."
|
||||
|
||||
|
||||
@dataclass
|
||||
class JiraExistingConversationView(JiraViewInterface):
|
||||
job_context: JobContext
|
||||
saas_user_auth: UserAuth
|
||||
jira_user: JiraUser
|
||||
jira_workspace: JiraWorkspace
|
||||
selected_repo: str | None
|
||||
conversation_id: str
|
||||
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
"""Instructions passed when conversation is first initialized"""
|
||||
|
||||
user_msg_template = jinja_env.get_template('jira_existing_conversation.j2')
|
||||
user_msg = user_msg_template.render(
|
||||
issue_key=self.job_context.issue_key,
|
||||
user_message=self.job_context.user_msg or '',
|
||||
issue_title=self.job_context.issue_title,
|
||||
issue_description=self.job_context.issue_description,
|
||||
)
|
||||
|
||||
return '', user_msg
|
||||
|
||||
async def create_or_update_conversation(self, jinja_env: Environment) -> str:
|
||||
"""Update an existing Jira conversation"""
|
||||
|
||||
user_id = self.jira_user.keycloak_user_id
|
||||
|
||||
try:
|
||||
conversation_store = await ConversationStoreImpl.get_instance(
|
||||
config, user_id
|
||||
)
|
||||
|
||||
try:
|
||||
await conversation_store.get_metadata(self.conversation_id)
|
||||
except FileNotFoundError:
|
||||
raise StartingConvoException('Conversation no longer exists.')
|
||||
|
||||
provider_tokens = await self.saas_user_auth.get_provider_tokens()
|
||||
# Should we raise here if there are no providers?
|
||||
providers_set = list(provider_tokens.keys()) if provider_tokens else []
|
||||
|
||||
conversation_init_data = await setup_init_conversation_settings(
|
||||
user_id, self.conversation_id, providers_set
|
||||
)
|
||||
|
||||
# Either join ongoing conversation, or restart the conversation
|
||||
agent_loop_info = await conversation_manager.maybe_start_agent_loop(
|
||||
self.conversation_id, conversation_init_data, user_id
|
||||
)
|
||||
|
||||
final_agent_observation = get_final_agent_observation(
|
||||
agent_loop_info.event_store
|
||||
)
|
||||
agent_state = (
|
||||
None
|
||||
if len(final_agent_observation) == 0
|
||||
else final_agent_observation[0].agent_state
|
||||
)
|
||||
|
||||
if not agent_state or agent_state == AgentState.LOADING:
|
||||
raise StartingConvoException('Conversation is still starting')
|
||||
|
||||
_, user_msg = self._get_instructions(jinja_env)
|
||||
user_message_event = MessageAction(content=user_msg)
|
||||
await conversation_manager.send_event_to_conversation(
|
||||
self.conversation_id, event_to_dict(user_message_event)
|
||||
)
|
||||
|
||||
return self.conversation_id
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f'[Jira] Failed to create conversation: {str(e)}', exc_info=True
|
||||
)
|
||||
raise StartingConvoException(f'Failed to create conversation: {str(e)}')
|
||||
|
||||
def get_response_msg(self) -> str:
|
||||
"""Get the response message to send back to Jira"""
|
||||
conversation_link = CONVERSATION_URL.format(self.conversation_id)
|
||||
return f"I'm on it! {self.job_context.display_name} can [continue tracking my progress here|{conversation_link}]."
|
||||
return f"I'm on it! {self.payload.display_name} can [track my progress here|{conversation_link}]."
|
||||
|
||||
|
||||
class JiraFactory:
|
||||
"""Factory for creating Jira views based on message content"""
|
||||
"""Factory for creating Jira views.
|
||||
|
||||
The factory is responsible for:
|
||||
- Creating the appropriate view type
|
||||
- Inferring and selecting the repository
|
||||
- Validating all required data is available
|
||||
|
||||
Repository selection happens here so that view creation either
|
||||
succeeds with a valid repo or fails with a clear error.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
async def create_jira_view_from_payload(
|
||||
job_context: JobContext,
|
||||
saas_user_auth: UserAuth,
|
||||
jira_user: JiraUser,
|
||||
jira_workspace: JiraWorkspace,
|
||||
async def _create_provider_handler(user_auth: UserAuth) -> ProviderHandler | None:
|
||||
"""Create a ProviderHandler for the user."""
|
||||
provider_tokens = await user_auth.get_provider_tokens()
|
||||
if provider_tokens is None:
|
||||
return None
|
||||
|
||||
access_token = await user_auth.get_access_token()
|
||||
user_id = await user_auth.get_user_id()
|
||||
|
||||
return ProviderHandler(
|
||||
provider_tokens=provider_tokens,
|
||||
external_auth_token=access_token,
|
||||
external_auth_id=user_id,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _extract_potential_repos(
|
||||
issue_key: str,
|
||||
issue_title: str,
|
||||
issue_description: str,
|
||||
user_msg: str,
|
||||
) -> list[str]:
|
||||
"""Extract potential repository names from issue content.
|
||||
|
||||
Raises:
|
||||
RepositoryNotFoundError: If no potential repos found in text.
|
||||
"""
|
||||
search_text = f'{issue_title}\n{issue_description}\n{user_msg}'
|
||||
potential_repos = infer_repo_from_message(search_text)
|
||||
|
||||
if not potential_repos:
|
||||
raise RepositoryNotFoundError(
|
||||
'Could not determine which repository to use. '
|
||||
'Please mention the repository (e.g., owner/repo) in the issue description or comment.'
|
||||
)
|
||||
|
||||
logger.info(
|
||||
'[Jira] Found potential repositories in issue content',
|
||||
extra={'issue_key': issue_key, 'potential_repos': potential_repos},
|
||||
)
|
||||
return potential_repos
|
||||
|
||||
@staticmethod
|
||||
async def _verify_repos(
|
||||
issue_key: str,
|
||||
potential_repos: list[str],
|
||||
provider_handler: ProviderHandler,
|
||||
) -> list[str]:
|
||||
"""Verify which repos the user has access to."""
|
||||
verified_repos: list[str] = []
|
||||
|
||||
for repo_name in potential_repos:
|
||||
try:
|
||||
repository = await provider_handler.verify_repo_provider(repo_name)
|
||||
verified_repos.append(repository.full_name)
|
||||
logger.debug(
|
||||
'[Jira] Repository verification succeeded',
|
||||
extra={'issue_key': issue_key, 'repository': repository.full_name},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
'[Jira] Repository verification failed',
|
||||
extra={
|
||||
'issue_key': issue_key,
|
||||
'repo_name': repo_name,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
|
||||
return verified_repos
|
||||
|
||||
@staticmethod
|
||||
def _select_single_repo(
|
||||
issue_key: str,
|
||||
potential_repos: list[str],
|
||||
verified_repos: list[str],
|
||||
) -> str:
|
||||
"""Select exactly one repo from verified repos.
|
||||
|
||||
Raises:
|
||||
RepositoryNotFoundError: If zero or multiple repos verified.
|
||||
"""
|
||||
if len(verified_repos) == 0:
|
||||
raise RepositoryNotFoundError(
|
||||
f'Could not access any of the mentioned repositories: {", ".join(potential_repos)}. '
|
||||
'Please ensure you have access to the repository and it exists.'
|
||||
)
|
||||
|
||||
if len(verified_repos) > 1:
|
||||
raise RepositoryNotFoundError(
|
||||
f'Multiple repositories found: {", ".join(verified_repos)}. '
|
||||
'Please specify exactly one repository in the issue description or comment.'
|
||||
)
|
||||
|
||||
logger.info(
|
||||
'[Jira] Verified repository access',
|
||||
extra={'issue_key': issue_key, 'repository': verified_repos[0]},
|
||||
)
|
||||
return verified_repos[0]
|
||||
|
||||
@staticmethod
|
||||
async def _infer_repository(
|
||||
payload: JiraWebhookPayload,
|
||||
user_auth: UserAuth,
|
||||
issue_title: str,
|
||||
issue_description: str,
|
||||
) -> str:
|
||||
"""Infer and verify the repository from issue content.
|
||||
|
||||
Raises:
|
||||
RepositoryNotFoundError: If no valid repository can be determined.
|
||||
"""
|
||||
provider_handler = await JiraFactory._create_provider_handler(user_auth)
|
||||
if not provider_handler:
|
||||
raise RepositoryNotFoundError(
|
||||
'No Git provider connected. Please connect a Git provider in OpenHands settings.'
|
||||
)
|
||||
|
||||
potential_repos = JiraFactory._extract_potential_repos(
|
||||
payload.issue_key, issue_title, issue_description, payload.user_msg
|
||||
)
|
||||
|
||||
verified_repos = await JiraFactory._verify_repos(
|
||||
payload.issue_key, potential_repos, provider_handler
|
||||
)
|
||||
|
||||
return JiraFactory._select_single_repo(
|
||||
payload.issue_key, potential_repos, verified_repos
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def create_view(
|
||||
payload: JiraWebhookPayload,
|
||||
workspace: JiraWorkspace,
|
||||
user: JiraUser,
|
||||
user_auth: UserAuth,
|
||||
decrypted_api_key: str,
|
||||
) -> JiraViewInterface:
|
||||
"""Create appropriate Jira view based on the message and user state"""
|
||||
"""Create a Jira view with repository already selected.
|
||||
|
||||
if not jira_user or not saas_user_auth or not jira_workspace:
|
||||
raise StartingConvoException('User not authenticated with Jira integration')
|
||||
This factory method:
|
||||
1. Creates the view with payload and auth context
|
||||
2. Fetches issue details (needed for repo inference)
|
||||
3. Infers and selects the repository
|
||||
|
||||
conversation = await integration_store.get_user_conversations_by_issue_id(
|
||||
job_context.issue_id, jira_user.id
|
||||
If any step fails, an appropriate exception is raised with
|
||||
a user-friendly message.
|
||||
|
||||
Args:
|
||||
payload: Parsed webhook payload
|
||||
workspace: The Jira workspace
|
||||
user: The Jira user
|
||||
user_auth: OpenHands user authentication
|
||||
decrypted_api_key: Decrypted service account API key
|
||||
|
||||
Returns:
|
||||
A JiraViewInterface with selected_repo populated
|
||||
|
||||
Raises:
|
||||
StartingConvoException: If view creation fails
|
||||
RepositoryNotFoundError: If repository cannot be determined
|
||||
"""
|
||||
logger.info(
|
||||
'[Jira] Creating view',
|
||||
extra={
|
||||
'issue_key': payload.issue_key,
|
||||
'event_type': payload.event_type.value,
|
||||
},
|
||||
)
|
||||
|
||||
if conversation:
|
||||
logger.info(
|
||||
f'[Jira] Found existing conversation for issue {job_context.issue_id}'
|
||||
)
|
||||
return JiraExistingConversationView(
|
||||
job_context=job_context,
|
||||
saas_user_auth=saas_user_auth,
|
||||
jira_user=jira_user,
|
||||
jira_workspace=jira_workspace,
|
||||
selected_repo=None,
|
||||
conversation_id=conversation.conversation_id,
|
||||
)
|
||||
|
||||
return JiraNewConversationView(
|
||||
job_context=job_context,
|
||||
saas_user_auth=saas_user_auth,
|
||||
jira_user=jira_user,
|
||||
jira_workspace=jira_workspace,
|
||||
selected_repo=None, # Will be set later after repo inference
|
||||
conversation_id='', # Will be set when conversation is created
|
||||
# Create the view
|
||||
view = JiraNewConversationView(
|
||||
payload=payload,
|
||||
saas_user_auth=user_auth,
|
||||
jira_user=user,
|
||||
jira_workspace=workspace,
|
||||
_decrypted_api_key=decrypted_api_key,
|
||||
)
|
||||
|
||||
# Fetch issue details (needed for repo inference)
|
||||
try:
|
||||
issue_title, issue_description = await view.get_issue_details()
|
||||
except StartingConvoException:
|
||||
raise # Re-raise with original message
|
||||
except Exception as e:
|
||||
raise StartingConvoException(f'Failed to fetch issue details: {str(e)}')
|
||||
|
||||
# Infer and select repository
|
||||
selected_repo = await JiraFactory._infer_repository(
|
||||
payload=payload,
|
||||
user_auth=user_auth,
|
||||
issue_title=issue_title,
|
||||
issue_description=issue_description,
|
||||
)
|
||||
|
||||
view.selected_repo = selected_repo
|
||||
|
||||
logger.info(
|
||||
'[Jira] View created successfully',
|
||||
extra={
|
||||
'issue_key': payload.issue_key,
|
||||
'selected_repo': selected_repo,
|
||||
},
|
||||
)
|
||||
|
||||
return view
|
||||
|
||||
@@ -418,7 +418,7 @@ class JiraDcManager(Manager):
|
||||
jira_dc_view.jira_dc_workspace.svc_acc_api_key
|
||||
)
|
||||
await self.send_message(
|
||||
self.create_outgoing_message(msg=msg_info),
|
||||
msg_info,
|
||||
issue_key=jira_dc_view.job_context.issue_key,
|
||||
base_api_url=jira_dc_view.job_context.base_api_url,
|
||||
svc_acc_api_key=api_key,
|
||||
@@ -456,12 +456,19 @@ class JiraDcManager(Manager):
|
||||
return title, description
|
||||
|
||||
async def send_message(
|
||||
self, message: Message, issue_key: str, base_api_url: str, svc_acc_api_key: str
|
||||
self, message: str, issue_key: str, base_api_url: str, svc_acc_api_key: str
|
||||
):
|
||||
"""Send message/comment to Jira DC issue."""
|
||||
"""Send message/comment to Jira DC issue.
|
||||
|
||||
Args:
|
||||
message: The message content to send (plain text string)
|
||||
issue_key: The Jira issue key (e.g., 'PROJ-123')
|
||||
base_api_url: The base API URL for the Jira DC instance
|
||||
svc_acc_api_key: Service account API key for authentication
|
||||
"""
|
||||
url = f'{base_api_url}/rest/api/2/issue/{issue_key}/comment'
|
||||
headers = {'Authorization': f'Bearer {svc_acc_api_key}'}
|
||||
data = {'body': message.message}
|
||||
data = {'body': message}
|
||||
async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
|
||||
response = await client.post(url, headers=headers, json=data)
|
||||
response.raise_for_status()
|
||||
@@ -481,7 +488,7 @@ class JiraDcManager(Manager):
|
||||
try:
|
||||
api_key = self.token_manager.decrypt_text(workspace.svc_acc_api_key)
|
||||
await self.send_message(
|
||||
self.create_outgoing_message(msg=error_msg),
|
||||
error_msg,
|
||||
issue_key=job_context.issue_key,
|
||||
base_api_url=job_context.base_api_url,
|
||||
svc_acc_api_key=api_key,
|
||||
@@ -502,7 +509,7 @@ class JiraDcManager(Manager):
|
||||
)
|
||||
|
||||
await self.send_message(
|
||||
self.create_outgoing_message(msg=comment_msg),
|
||||
comment_msg,
|
||||
issue_key=jira_dc_view.job_context.issue_key,
|
||||
base_api_url=jira_dc_view.job_context.base_api_url,
|
||||
svc_acc_api_key=api_key,
|
||||
|
||||
@@ -19,7 +19,7 @@ class JiraDcViewInterface(ABC):
|
||||
conversation_id: str
|
||||
|
||||
@abstractmethod
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
"""Get initial instructions for the conversation."""
|
||||
pass
|
||||
|
||||
|
||||
@@ -36,7 +36,7 @@ class JiraDcNewConversationView(JiraDcViewInterface):
|
||||
selected_repo: str | None
|
||||
conversation_id: str
|
||||
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
"""Instructions passed when conversation is first initialized"""
|
||||
|
||||
instructions_template = jinja_env.get_template('jira_dc_instructions.j2')
|
||||
@@ -61,7 +61,7 @@ class JiraDcNewConversationView(JiraDcViewInterface):
|
||||
|
||||
provider_tokens = await self.saas_user_auth.get_provider_tokens()
|
||||
user_secrets = await self.saas_user_auth.get_secrets()
|
||||
instructions, user_msg = self._get_instructions(jinja_env)
|
||||
instructions, user_msg = await self._get_instructions(jinja_env)
|
||||
|
||||
try:
|
||||
agent_loop_info = await create_new_conversation(
|
||||
@@ -113,7 +113,7 @@ class JiraDcExistingConversationView(JiraDcViewInterface):
|
||||
selected_repo: str | None
|
||||
conversation_id: str
|
||||
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
"""Instructions passed when conversation is first initialized"""
|
||||
|
||||
user_msg_template = jinja_env.get_template('jira_dc_existing_conversation.j2')
|
||||
@@ -167,7 +167,7 @@ class JiraDcExistingConversationView(JiraDcViewInterface):
|
||||
if not agent_state or agent_state == AgentState.LOADING:
|
||||
raise StartingConvoException('Conversation is still starting')
|
||||
|
||||
_, user_msg = self._get_instructions(jinja_env)
|
||||
_, user_msg = await self._get_instructions(jinja_env)
|
||||
user_message_event = MessageAction(content=user_msg)
|
||||
await conversation_manager.send_event_to_conversation(
|
||||
self.conversation_id, event_to_dict(user_message_event)
|
||||
|
||||
@@ -408,7 +408,7 @@ class LinearManager(Manager):
|
||||
linear_view.linear_workspace.svc_acc_api_key
|
||||
)
|
||||
await self.send_message(
|
||||
self.create_outgoing_message(msg=msg_info),
|
||||
msg_info,
|
||||
linear_view.job_context.issue_id,
|
||||
api_key,
|
||||
)
|
||||
@@ -473,8 +473,14 @@ class LinearManager(Manager):
|
||||
|
||||
return title, description
|
||||
|
||||
async def send_message(self, message: Message, issue_id: str, api_key: str):
|
||||
"""Send message/comment to Linear issue."""
|
||||
async def send_message(self, message: str, issue_id: str, api_key: str):
|
||||
"""Send message/comment to Linear issue.
|
||||
|
||||
Args:
|
||||
message: The message content to send (plain text string)
|
||||
issue_id: The Linear issue ID to comment on
|
||||
api_key: The Linear API key for authentication
|
||||
"""
|
||||
query = """
|
||||
mutation CommentCreate($input: CommentCreateInput!) {
|
||||
commentCreate(input: $input) {
|
||||
@@ -485,7 +491,7 @@ class LinearManager(Manager):
|
||||
}
|
||||
}
|
||||
"""
|
||||
variables = {'input': {'issueId': issue_id, 'body': message.message}}
|
||||
variables = {'input': {'issueId': issue_id, 'body': message}}
|
||||
return await self._query_api(query, variables, api_key)
|
||||
|
||||
async def _send_error_comment(
|
||||
@@ -498,9 +504,7 @@ class LinearManager(Manager):
|
||||
|
||||
try:
|
||||
api_key = self.token_manager.decrypt_text(workspace.svc_acc_api_key)
|
||||
await self.send_message(
|
||||
self.create_outgoing_message(msg=error_msg), issue_id, api_key
|
||||
)
|
||||
await self.send_message(error_msg, issue_id, api_key)
|
||||
except Exception as e:
|
||||
logger.error(f'[Linear] Failed to send error comment: {str(e)}')
|
||||
|
||||
@@ -517,7 +521,7 @@ class LinearManager(Manager):
|
||||
)
|
||||
|
||||
await self.send_message(
|
||||
self.create_outgoing_message(msg=comment_msg),
|
||||
comment_msg,
|
||||
linear_view.job_context.issue_id,
|
||||
api_key,
|
||||
)
|
||||
|
||||
@@ -19,7 +19,7 @@ class LinearViewInterface(ABC):
|
||||
conversation_id: str
|
||||
|
||||
@abstractmethod
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
"""Get initial instructions for the conversation."""
|
||||
pass
|
||||
|
||||
|
||||
@@ -33,7 +33,7 @@ class LinearNewConversationView(LinearViewInterface):
|
||||
selected_repo: str | None
|
||||
conversation_id: str
|
||||
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
"""Instructions passed when conversation is first initialized"""
|
||||
|
||||
instructions_template = jinja_env.get_template('linear_instructions.j2')
|
||||
@@ -58,7 +58,7 @@ class LinearNewConversationView(LinearViewInterface):
|
||||
|
||||
provider_tokens = await self.saas_user_auth.get_provider_tokens()
|
||||
user_secrets = await self.saas_user_auth.get_secrets()
|
||||
instructions, user_msg = self._get_instructions(jinja_env)
|
||||
instructions, user_msg = await self._get_instructions(jinja_env)
|
||||
|
||||
try:
|
||||
agent_loop_info = await create_new_conversation(
|
||||
@@ -110,7 +110,7 @@ class LinearExistingConversationView(LinearViewInterface):
|
||||
selected_repo: str | None
|
||||
conversation_id: str
|
||||
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
"""Instructions passed when conversation is first initialized"""
|
||||
|
||||
user_msg_template = jinja_env.get_template('linear_existing_conversation.j2')
|
||||
@@ -164,7 +164,7 @@ class LinearExistingConversationView(LinearViewInterface):
|
||||
if not agent_state or agent_state == AgentState.LOADING:
|
||||
raise StartingConvoException('Conversation is still starting')
|
||||
|
||||
_, user_msg = self._get_instructions(jinja_env)
|
||||
_, user_msg = await self._get_instructions(jinja_env)
|
||||
user_message_event = MessageAction(content=user_msg)
|
||||
await conversation_manager.send_event_to_conversation(
|
||||
self.conversation_id, event_to_dict(user_message_event)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
from integrations.models import Message, SourceType
|
||||
|
||||
@@ -12,19 +13,15 @@ class Manager(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def send_message(self, message: Message):
|
||||
"Send message to integration from Openhands server"
|
||||
raise NotImplementedError
|
||||
def send_message(self, message: str, *args: Any, **kwargs: Any):
|
||||
"""Send message to integration from OpenHands server.
|
||||
|
||||
@abstractmethod
|
||||
async def is_job_requested(self, message: Message) -> bool:
|
||||
"Confirm that a job is being requested"
|
||||
Args:
|
||||
message: The message content to send (plain text string).
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def start_job(self):
|
||||
"Kick off a job with openhands agent"
|
||||
raise NotImplementedError
|
||||
|
||||
def create_outgoing_message(self, msg: str | dict, ephemeral: bool = False):
|
||||
return Message(source=SourceType.OPENHANDS, message=msg, ephemeral=ephemeral)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -16,8 +17,16 @@ class SourceType(str, Enum):
|
||||
|
||||
|
||||
class Message(BaseModel):
|
||||
"""Message model for incoming webhook payloads from integrations.
|
||||
|
||||
Note: This model is intended for INCOMING messages only.
|
||||
For outgoing messages (e.g., sending comments to GitHub/GitLab),
|
||||
pass strings directly to the send_message methods instead of
|
||||
wrapping them in a Message object.
|
||||
"""
|
||||
|
||||
source: SourceType
|
||||
message: str | dict
|
||||
message: dict[str, Any]
|
||||
ephemeral: bool = False
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from openhands.app_server.user.user_context import UserContext
|
||||
from openhands.app_server.user.user_models import UserInfo
|
||||
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE
|
||||
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE, ProviderHandler
|
||||
from openhands.integrations.service_types import ProviderType
|
||||
from openhands.sdk.secret import SecretSource, StaticSecret
|
||||
from openhands.server.user_auth.user_auth import UserAuth
|
||||
@@ -14,6 +14,7 @@ class ResolverUserContext(UserContext):
|
||||
saas_user_auth: UserAuth,
|
||||
):
|
||||
self.saas_user_auth = saas_user_auth
|
||||
self._provider_handler: ProviderHandler | None = None
|
||||
|
||||
async def get_user_id(self) -> str | None:
|
||||
return await self.saas_user_auth.get_user_id()
|
||||
@@ -29,17 +30,34 @@ class ResolverUserContext(UserContext):
|
||||
|
||||
return UserInfo(id=user_id)
|
||||
|
||||
async def get_authenticated_git_url(self, repository: str) -> str:
|
||||
# This would need to be implemented based on the git provider tokens
|
||||
# For now, return a basic HTTPS URL
|
||||
return f'https://github.com/{repository}.git'
|
||||
async def _get_provider_handler(self) -> ProviderHandler:
|
||||
"""Get or create a ProviderHandler for git operations."""
|
||||
if self._provider_handler is None:
|
||||
provider_tokens = await self.saas_user_auth.get_provider_tokens()
|
||||
if provider_tokens is None:
|
||||
raise ValueError('No provider tokens available')
|
||||
user_id = await self.saas_user_auth.get_user_id()
|
||||
self._provider_handler = ProviderHandler(
|
||||
provider_tokens=provider_tokens, external_auth_id=user_id
|
||||
)
|
||||
return self._provider_handler
|
||||
|
||||
async def get_authenticated_git_url(
|
||||
self, repository: str, is_optional: bool = False
|
||||
) -> str:
|
||||
provider_handler = await self._get_provider_handler()
|
||||
url = await provider_handler.get_authenticated_git_url(
|
||||
repository, is_optional=is_optional
|
||||
)
|
||||
return url
|
||||
|
||||
async def get_latest_token(self, provider_type: ProviderType) -> str | None:
|
||||
# Return the appropriate token from git_provider_tokens
|
||||
|
||||
# Return the appropriate token string from git_provider_tokens
|
||||
provider_tokens = await self.saas_user_auth.get_provider_tokens()
|
||||
if provider_tokens:
|
||||
return provider_tokens.get(provider_type)
|
||||
provider_token = provider_tokens.get(provider_type)
|
||||
if provider_token and provider_token.token:
|
||||
return provider_token.token.get_secret_value()
|
||||
return None
|
||||
|
||||
async def get_provider_tokens(self) -> PROVIDER_TOKEN_TYPE | None:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
import jwt
|
||||
from integrations.manager import Manager
|
||||
@@ -22,7 +23,8 @@ from server.constants import SLACK_CLIENT_ID
|
||||
from server.utils.conversation_callback_utils import register_callback_processor
|
||||
from slack_sdk.oauth import AuthorizeUrlGenerator
|
||||
from slack_sdk.web.async_client import AsyncWebClient
|
||||
from storage.database import session_maker
|
||||
from sqlalchemy import select
|
||||
from storage.database import a_session_maker
|
||||
from storage.slack_user import SlackUser
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
@@ -63,12 +65,11 @@ class SlackManager(Manager):
|
||||
) -> tuple[SlackUser | None, UserAuth | None]:
|
||||
# We get the user and correlate them back to a user in OpenHands - if we can
|
||||
slack_user = None
|
||||
with session_maker() as session:
|
||||
slack_user = (
|
||||
session.query(SlackUser)
|
||||
.filter(SlackUser.slack_user_id == slack_user_id)
|
||||
.first()
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(SlackUser).where(SlackUser.slack_user_id == slack_user_id)
|
||||
)
|
||||
slack_user = result.scalar_one_or_none()
|
||||
|
||||
# slack_view.slack_to_openhands_user = slack_user # attach user auth info to view
|
||||
|
||||
@@ -202,9 +203,7 @@ class SlackManager(Manager):
|
||||
msg = self.login_link.format(link)
|
||||
|
||||
logger.info('slack_not_yet_authenticated')
|
||||
await self.send_message(
|
||||
self.create_outgoing_message(msg, ephemeral=True), slack_view
|
||||
)
|
||||
await self.send_message(msg, slack_view, ephemeral=True)
|
||||
return
|
||||
|
||||
if not await self.is_job_requested(message, slack_view):
|
||||
@@ -212,27 +211,40 @@ class SlackManager(Manager):
|
||||
|
||||
await self.start_job(slack_view)
|
||||
|
||||
async def send_message(self, message: Message, slack_view: SlackViewInterface):
|
||||
async def send_message(
|
||||
self,
|
||||
message: str | dict[str, Any],
|
||||
slack_view: SlackViewInterface,
|
||||
ephemeral: bool = False,
|
||||
):
|
||||
"""Send a message to Slack.
|
||||
|
||||
Args:
|
||||
message: The message content. Can be a string (for simple text) or
|
||||
a dict with 'text' and 'blocks' keys (for structured messages).
|
||||
slack_view: The Slack view object containing channel/thread info.
|
||||
ephemeral: If True, send as an ephemeral message visible only to the user.
|
||||
"""
|
||||
client = AsyncWebClient(token=slack_view.bot_access_token)
|
||||
if message.ephemeral and isinstance(message.message, str):
|
||||
if ephemeral and isinstance(message, str):
|
||||
await client.chat_postEphemeral(
|
||||
channel=slack_view.channel_id,
|
||||
markdown_text=message.message,
|
||||
markdown_text=message,
|
||||
user=slack_view.slack_user_id,
|
||||
thread_ts=slack_view.thread_ts,
|
||||
)
|
||||
elif message.ephemeral and isinstance(message.message, dict):
|
||||
elif ephemeral and isinstance(message, dict):
|
||||
await client.chat_postEphemeral(
|
||||
channel=slack_view.channel_id,
|
||||
user=slack_view.slack_user_id,
|
||||
thread_ts=slack_view.thread_ts,
|
||||
text=message.message['text'],
|
||||
blocks=message.message['blocks'],
|
||||
text=message['text'],
|
||||
blocks=message['blocks'],
|
||||
)
|
||||
else:
|
||||
await client.chat_postMessage(
|
||||
channel=slack_view.channel_id,
|
||||
markdown_text=message.message,
|
||||
markdown_text=message,
|
||||
thread_ts=slack_view.message_ts,
|
||||
)
|
||||
|
||||
@@ -279,10 +291,7 @@ class SlackManager(Manager):
|
||||
repos, slack_view.message_ts, slack_view.thread_ts
|
||||
),
|
||||
}
|
||||
await self.send_message(
|
||||
self.create_outgoing_message(repo_selection_msg, ephemeral=True),
|
||||
slack_view,
|
||||
)
|
||||
await self.send_message(repo_selection_msg, slack_view, ephemeral=True)
|
||||
|
||||
return False
|
||||
|
||||
@@ -368,9 +377,10 @@ class SlackManager(Manager):
|
||||
except StartingConvoException as e:
|
||||
msg_info = str(e)
|
||||
|
||||
await self.send_message(self.create_outgoing_message(msg_info), slack_view)
|
||||
await self.send_message(msg_info, slack_view)
|
||||
|
||||
except Exception:
|
||||
logger.exception('[Slack]: Error starting job')
|
||||
msg = 'Uh oh! There was an unexpected error starting the job :('
|
||||
await self.send_message(self.create_outgoing_message(msg), slack_view)
|
||||
await self.send_message(
|
||||
'Uh oh! There was an unexpected error starting the job :(', slack_view
|
||||
)
|
||||
|
||||
@@ -24,7 +24,7 @@ class SlackViewInterface(SummaryExtractionTracker, ABC):
|
||||
v1_enabled: bool
|
||||
|
||||
@abstractmethod
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
"""Instructions passed when conversation is first initialized"""
|
||||
pass
|
||||
|
||||
|
||||
@@ -75,7 +75,7 @@ class SlackUnkownUserView(SlackViewInterface):
|
||||
team_id: str
|
||||
v1_enabled: bool
|
||||
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
raise NotImplementedError
|
||||
|
||||
async def create_or_update_conversation(self, jinja_env: Environment):
|
||||
@@ -118,7 +118,7 @@ class SlackNewConversationView(SlackViewInterface):
|
||||
return block['user_id']
|
||||
return ''
|
||||
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
"""Instructions passed when conversation is first initialized"""
|
||||
user_info: SlackUser = self.slack_to_openhands_user
|
||||
|
||||
@@ -189,6 +189,7 @@ class SlackNewConversationView(SlackViewInterface):
|
||||
'channel_id': self.channel_id,
|
||||
'conversation_id': self.conversation_id,
|
||||
'keycloak_user_id': user_info.keycloak_user_id,
|
||||
'org_id': user_info.org_id,
|
||||
'parent_id': self.thread_ts or self.message_ts,
|
||||
'v1_enabled': v1_enabled,
|
||||
},
|
||||
@@ -197,6 +198,7 @@ class SlackNewConversationView(SlackViewInterface):
|
||||
conversation_id=self.conversation_id,
|
||||
channel_id=self.channel_id,
|
||||
keycloak_user_id=user_info.keycloak_user_id,
|
||||
org_id=user_info.org_id,
|
||||
parent_id=self.thread_ts
|
||||
or self.message_ts, # conversations can start in a thread reply as well; we should always references the parent's (root level msg's) message ID
|
||||
v1_enabled=v1_enabled,
|
||||
@@ -240,7 +242,9 @@ class SlackNewConversationView(SlackViewInterface):
|
||||
self, jinja: Environment, provider_tokens, user_secrets
|
||||
) -> None:
|
||||
"""Create conversation using the legacy V0 system."""
|
||||
user_instructions, conversation_instructions = self._get_instructions(jinja)
|
||||
user_instructions, conversation_instructions = await self._get_instructions(
|
||||
jinja
|
||||
)
|
||||
|
||||
# Determine git provider from repository
|
||||
git_provider = None
|
||||
@@ -271,7 +275,9 @@ class SlackNewConversationView(SlackViewInterface):
|
||||
|
||||
async def _create_v1_conversation(self, jinja: Environment) -> None:
|
||||
"""Create conversation using the new V1 app conversation system."""
|
||||
user_instructions, conversation_instructions = self._get_instructions(jinja)
|
||||
user_instructions, conversation_instructions = await self._get_instructions(
|
||||
jinja
|
||||
)
|
||||
|
||||
# Create the initial message request
|
||||
initial_message = SendMessageRequest(
|
||||
@@ -344,7 +350,7 @@ class SlackNewConversationFromRepoFormView(SlackNewConversationView):
|
||||
class SlackUpdateExistingConversationView(SlackNewConversationView):
|
||||
slack_conversation: SlackConversation
|
||||
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
client = WebClient(token=self.bot_access_token)
|
||||
result = client.conversations_replies(
|
||||
channel=self.channel_id,
|
||||
@@ -399,10 +405,10 @@ class SlackUpdateExistingConversationView(SlackNewConversationView):
|
||||
if not agent_state or agent_state == AgentState.LOADING:
|
||||
raise StartingConvoException('Conversation is still starting')
|
||||
|
||||
user_msg, _ = self._get_instructions(jinja)
|
||||
user_msg_action = MessageAction(content=user_msg)
|
||||
instructions, _ = await self._get_instructions(jinja)
|
||||
user_msg = MessageAction(content=instructions)
|
||||
await conversation_manager.send_event_to_conversation(
|
||||
self.conversation_id, event_to_dict(user_msg_action)
|
||||
self.conversation_id, event_to_dict(user_msg)
|
||||
)
|
||||
|
||||
async def send_message_to_v1_conversation(self, jinja: Environment):
|
||||
@@ -467,7 +473,7 @@ class SlackUpdateExistingConversationView(SlackNewConversationView):
|
||||
agent_server_url = get_agent_server_url_from_sandbox(running_sandbox)
|
||||
|
||||
# 4. Prepare the message content
|
||||
user_msg, _ = self._get_instructions(jinja)
|
||||
user_msg, _ = await self._get_instructions(jinja)
|
||||
|
||||
# 5. Create the message request
|
||||
send_message_request = SendMessageRequest(
|
||||
|
||||
53
enterprise/integrations/store_repo_utils.py
Normal file
53
enterprise/integrations/store_repo_utils.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from storage.repository_store import RepositoryStore
|
||||
from storage.stored_repository import StoredRepository
|
||||
from storage.user_repo_map import UserRepositoryMap
|
||||
from storage.user_repo_map_store import UserRepositoryMapStore
|
||||
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.service_types import Repository
|
||||
|
||||
|
||||
async def store_repositories_in_db(repos: list[Repository], user_id: str) -> None:
|
||||
"""
|
||||
Store repositories in DB and create user-repository mappings
|
||||
|
||||
Args:
|
||||
repos: List of Repository objects to store
|
||||
user_id: User ID associated with these repositories
|
||||
"""
|
||||
|
||||
# Convert Repository objects to StoredRepository objects
|
||||
# Convert Repository objects to UserRepositoryMap objects
|
||||
stored_repos = []
|
||||
user_repos = []
|
||||
for repo in repos:
|
||||
repo_id = f'{repo.git_provider.value}##{str(repo.id)}'
|
||||
stored_repo = StoredRepository(
|
||||
repo_name=repo.full_name,
|
||||
repo_id=repo_id,
|
||||
is_public=repo.is_public,
|
||||
# Optional fields set to None by default
|
||||
has_microagent=None,
|
||||
has_setup_script=None,
|
||||
)
|
||||
stored_repos.append(stored_repo)
|
||||
user_repo_map = UserRepositoryMap(user_id=user_id, repo_id=repo_id, admin=None)
|
||||
|
||||
user_repos.append(user_repo_map)
|
||||
|
||||
# Get config instance
|
||||
config = OpenHandsConfig()
|
||||
|
||||
try:
|
||||
# Store repositories in the repos table
|
||||
repo_store = RepositoryStore.get_instance(config)
|
||||
await repo_store.store_projects(stored_repos)
|
||||
|
||||
# Store user-repository mappings in the user-repos table
|
||||
user_repo_store = UserRepositoryMapStore.get_instance(config)
|
||||
await user_repo_store.store_user_repo_mappings(user_repos)
|
||||
|
||||
logger.info(f'Saved repos for user {user_id}')
|
||||
except Exception:
|
||||
logger.warning('Failed to save repos', exc_info=True)
|
||||
@@ -1,66 +1,99 @@
|
||||
from uuid import UUID
|
||||
|
||||
import stripe
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.constants import STRIPE_API_KEY
|
||||
from server.logger import logger
|
||||
from storage.database import session_maker
|
||||
from sqlalchemy import select
|
||||
from storage.database import a_session_maker
|
||||
from storage.org import Org
|
||||
from storage.org_store import OrgStore
|
||||
from storage.stripe_customer import StripeCustomer
|
||||
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
|
||||
stripe.api_key = STRIPE_API_KEY
|
||||
|
||||
|
||||
async def find_customer_id_by_user_id(user_id: str) -> str | None:
|
||||
# First search our own DB...
|
||||
with session_maker() as session:
|
||||
stripe_customer = (
|
||||
session.query(StripeCustomer)
|
||||
.filter(StripeCustomer.keycloak_user_id == user_id)
|
||||
.first()
|
||||
)
|
||||
async def find_customer_id_by_org_id(org_id: UUID) -> str | None:
|
||||
async with a_session_maker() as session:
|
||||
stmt = select(StripeCustomer).where(StripeCustomer.org_id == org_id)
|
||||
result = await session.execute(stmt)
|
||||
stripe_customer = result.scalar_one_or_none()
|
||||
if stripe_customer:
|
||||
return stripe_customer.stripe_customer_id
|
||||
|
||||
# If that fails, fallback to stripe
|
||||
search_result = await stripe.Customer.search_async(
|
||||
query=f"metadata['user_id']:'{user_id}'",
|
||||
query=f"metadata['org_id']:'{str(org_id)}'",
|
||||
)
|
||||
data = search_result.data
|
||||
if not data:
|
||||
logger.info('no_customer_for_user_id', extra={'user_id': user_id})
|
||||
logger.info(
|
||||
'no_customer_for_org_id',
|
||||
extra={'org_id': str(org_id)},
|
||||
)
|
||||
return None
|
||||
return data[0].id # type: ignore [attr-defined]
|
||||
|
||||
|
||||
async def find_or_create_customer(user_id: str) -> str:
|
||||
customer_id = await find_customer_id_by_user_id(user_id)
|
||||
if customer_id:
|
||||
return customer_id
|
||||
logger.info('creating_customer', extra={'user_id': user_id})
|
||||
async def find_customer_id_by_user_id(user_id: str) -> str | None:
|
||||
# First search our own DB...
|
||||
org = await call_sync_from_async(
|
||||
OrgStore.get_current_org_from_keycloak_user_id, user_id
|
||||
)
|
||||
if not org:
|
||||
logger.warning(f'Org not found for user {user_id}')
|
||||
return None
|
||||
customer_id = await find_customer_id_by_org_id(org.id)
|
||||
return customer_id
|
||||
|
||||
# Get the user info from keycloak
|
||||
token_manager = TokenManager()
|
||||
user_info = await token_manager.get_user_info_from_user_id(user_id) or {}
|
||||
|
||||
async def find_or_create_customer_by_user_id(user_id: str) -> dict | None:
|
||||
# Get the current org for the user
|
||||
org = await call_sync_from_async(
|
||||
OrgStore.get_current_org_from_keycloak_user_id, user_id
|
||||
)
|
||||
if not org:
|
||||
logger.warning(f'Org not found for user {user_id}')
|
||||
return None
|
||||
|
||||
customer_id = await find_customer_id_by_org_id(org.id)
|
||||
if customer_id:
|
||||
return {'customer_id': customer_id, 'org_id': str(org.id)}
|
||||
logger.info(
|
||||
'creating_customer',
|
||||
extra={'user_id': user_id, 'org_id': str(org.id)},
|
||||
)
|
||||
|
||||
# Create the customer in stripe
|
||||
customer = await stripe.Customer.create_async(
|
||||
email=str(user_info.get('email', '')),
|
||||
metadata={'user_id': user_id},
|
||||
email=org.contact_email,
|
||||
metadata={'org_id': str(org.id)},
|
||||
)
|
||||
|
||||
# Save the stripe customer in the local db
|
||||
with session_maker() as session:
|
||||
async with a_session_maker() as session:
|
||||
session.add(
|
||||
StripeCustomer(keycloak_user_id=user_id, stripe_customer_id=customer.id)
|
||||
StripeCustomer(
|
||||
keycloak_user_id=user_id,
|
||||
org_id=org.id,
|
||||
stripe_customer_id=customer.id,
|
||||
)
|
||||
)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
logger.info(
|
||||
'created_customer',
|
||||
extra={'user_id': user_id, 'stripe_customer_id': customer.id},
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org.id),
|
||||
'stripe_customer_id': customer.id,
|
||||
},
|
||||
)
|
||||
return customer.id
|
||||
return {'customer_id': customer.id, 'org_id': str(org.id)}
|
||||
|
||||
|
||||
async def has_payment_method(user_id: str) -> bool:
|
||||
async def has_payment_method_by_user_id(user_id: str) -> bool:
|
||||
customer_id = await find_customer_id_by_user_id(user_id)
|
||||
if customer_id is None:
|
||||
return False
|
||||
@@ -71,3 +104,29 @@ async def has_payment_method(user_id: str) -> bool:
|
||||
f'has_payment_method:{user_id}:{customer_id}:{bool(payment_methods.data)}'
|
||||
)
|
||||
return bool(payment_methods.data)
|
||||
|
||||
|
||||
async def migrate_customer(user_id: str, org: Org):
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(StripeCustomer).where(StripeCustomer.keycloak_user_id == user_id)
|
||||
)
|
||||
stripe_customer = result.scalar_one_or_none()
|
||||
if stripe_customer is None:
|
||||
return
|
||||
stripe_customer.org_id = org.id
|
||||
customer = await stripe.Customer.modify_async(
|
||||
id=stripe_customer.stripe_customer_id,
|
||||
email=org.contact_email,
|
||||
metadata={'user_id': '', 'org_id': str(org.id)},
|
||||
)
|
||||
|
||||
logger.info(
|
||||
'migrated_customer',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org.id),
|
||||
'stripe_customer_id': customer.id,
|
||||
},
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
@@ -38,7 +38,7 @@ class ResolverViewInterface(SummaryExtractionTracker):
|
||||
is_public_repo: bool
|
||||
raw_payload: dict
|
||||
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
"Instructions passed when conversation is first initialized"
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
@@ -6,15 +6,9 @@ import re
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
from server.config import get_config
|
||||
from server.constants import WEB_HOST
|
||||
from storage.database import session_maker
|
||||
from storage.repository_store import RepositoryStore
|
||||
from storage.stored_repository import StoredRepository
|
||||
from storage.user_repo_map import UserRepositoryMap
|
||||
from storage.user_repo_map_store import UserRepositoryMapStore
|
||||
from storage.org_store import OrgStore
|
||||
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.schema.agent import AgentState
|
||||
from openhands.events import Event, EventSource
|
||||
@@ -125,26 +119,17 @@ async def get_user_v1_enabled_setting(user_id: str | None) -> bool:
|
||||
Returns:
|
||||
True if V1 conversations are enabled for this user, False otherwise
|
||||
"""
|
||||
|
||||
# If no user ID is provided, we can't check user settings
|
||||
if not user_id:
|
||||
return False
|
||||
|
||||
from storage.saas_settings_store import SaasSettingsStore
|
||||
|
||||
config = get_config()
|
||||
settings_store = SaasSettingsStore(
|
||||
user_id=user_id, session_maker=session_maker, config=config
|
||||
org = await call_sync_from_async(
|
||||
OrgStore.get_current_org_from_keycloak_user_id, user_id
|
||||
)
|
||||
|
||||
settings = await call_sync_from_async(
|
||||
settings_store.get_user_settings_by_keycloak_id, user_id
|
||||
)
|
||||
|
||||
if not settings or settings.v1_enabled is None:
|
||||
if not org or org.v1_enabled is None:
|
||||
return False
|
||||
|
||||
return settings.v1_enabled
|
||||
return org.v1_enabled
|
||||
|
||||
|
||||
def has_exact_mention(text: str, mention: str) -> bool:
|
||||
@@ -409,102 +394,46 @@ def append_conversation_footer(message: str, conversation_id: str) -> str:
|
||||
return message + footer
|
||||
|
||||
|
||||
async def store_repositories_in_db(repos: list[Repository], user_id: str) -> None:
|
||||
"""
|
||||
Store repositories in DB and create user-repository mappings
|
||||
|
||||
Args:
|
||||
repos: List of Repository objects to store
|
||||
user_id: User ID associated with these repositories
|
||||
"""
|
||||
|
||||
# Convert Repository objects to StoredRepository objects
|
||||
# Convert Repository objects to UserRepositoryMap objects
|
||||
stored_repos = []
|
||||
user_repos = []
|
||||
for repo in repos:
|
||||
repo_id = f'{repo.git_provider.value}##{str(repo.id)}'
|
||||
stored_repo = StoredRepository(
|
||||
repo_name=repo.full_name,
|
||||
repo_id=repo_id,
|
||||
is_public=repo.is_public,
|
||||
# Optional fields set to None by default
|
||||
has_microagent=None,
|
||||
has_setup_script=None,
|
||||
)
|
||||
stored_repos.append(stored_repo)
|
||||
user_repo_map = UserRepositoryMap(user_id=user_id, repo_id=repo_id, admin=None)
|
||||
|
||||
user_repos.append(user_repo_map)
|
||||
|
||||
# Get config instance
|
||||
config = OpenHandsConfig()
|
||||
|
||||
try:
|
||||
# Store repositories in the repos table
|
||||
repo_store = RepositoryStore.get_instance(config)
|
||||
repo_store.store_projects(stored_repos)
|
||||
|
||||
# Store user-repository mappings in the user-repos table
|
||||
user_repo_store = UserRepositoryMapStore.get_instance(config)
|
||||
user_repo_store.store_user_repo_mappings(user_repos)
|
||||
|
||||
logger.info(f'Saved repos for user {user_id}')
|
||||
except Exception:
|
||||
logger.warning('Failed to save repos', exc_info=True)
|
||||
|
||||
|
||||
def infer_repo_from_message(user_msg: str) -> list[str]:
|
||||
"""
|
||||
Extract all repository names in the format 'owner/repo' from various Git provider URLs
|
||||
and direct mentions in text. Supports GitHub, GitLab, and BitBucket.
|
||||
Args:
|
||||
user_msg: Input message that may contain repository references
|
||||
Returns:
|
||||
List of repository names in 'owner/repo' format, empty list if none found
|
||||
"""
|
||||
# Normalize the message by removing extra whitespace and newlines
|
||||
normalized_msg = re.sub(r'\s+', ' ', user_msg.strip())
|
||||
|
||||
# Pattern to match Git URLs from GitHub, GitLab, and BitBucket
|
||||
# Captures: protocol, domain, owner, repo (with optional .git extension)
|
||||
git_url_pattern = r'https?://(?:github\.com|gitlab\.com|bitbucket\.org)/([a-zA-Z0-9_.-]+)/([a-zA-Z0-9_.-]+?)(?:\.git)?(?:[/?#].*?)?(?=\s|$|[^\w.-])'
|
||||
|
||||
# Pattern to match direct owner/repo mentions (e.g., "OpenHands/OpenHands")
|
||||
# Must be surrounded by word boundaries or specific characters to avoid false positives
|
||||
direct_pattern = (
|
||||
r'(?:^|\s|[\[\(\'"])([a-zA-Z0-9_.-]+)/([a-zA-Z0-9_.-]+)(?=\s|$|[\]\)\'",.])'
|
||||
git_url_pattern = (
|
||||
r'https?://(?:github\.com|gitlab\.com|bitbucket\.org)/'
|
||||
r'([a-zA-Z0-9_.-]+)/([a-zA-Z0-9_.-]+?)(?:\.git)?'
|
||||
r'(?:[/?#].*?)?(?=\s|$|[^\w.-])'
|
||||
)
|
||||
|
||||
matches = []
|
||||
# UPDATED: allow {{ owner/repo }} in addition to existing boundaries
|
||||
direct_pattern = (
|
||||
r'(?:^|\s|{{|[\[\(\'":`])' # left boundary
|
||||
r'([a-zA-Z0-9_.-]+)/([a-zA-Z0-9_.-]+)'
|
||||
r'(?=\s|$|}}|[\]\)\'",.:`])' # right boundary
|
||||
)
|
||||
|
||||
# First, find all Git URLs (highest priority)
|
||||
git_matches = re.findall(git_url_pattern, normalized_msg)
|
||||
for owner, repo in git_matches:
|
||||
# Remove .git extension if present
|
||||
matches: list[str] = []
|
||||
|
||||
# Git URLs first (highest priority)
|
||||
for owner, repo in re.findall(git_url_pattern, normalized_msg):
|
||||
repo = re.sub(r'\.git$', '', repo)
|
||||
matches.append(f'{owner}/{repo}')
|
||||
|
||||
# Second, find all direct owner/repo mentions
|
||||
direct_matches = re.findall(direct_pattern, normalized_msg)
|
||||
for owner, repo in direct_matches:
|
||||
# Direct mentions
|
||||
for owner, repo in re.findall(direct_pattern, normalized_msg):
|
||||
full_match = f'{owner}/{repo}'
|
||||
|
||||
# Skip if it looks like a version number, date, or file path
|
||||
if (
|
||||
re.match(r'^\d+\.\d+/\d+\.\d+$', full_match) # version numbers
|
||||
or re.match(r'^\d{1,2}/\d{1,2}$', full_match) # dates
|
||||
or re.match(r'^[A-Z]/[A-Z]$', full_match) # single letters
|
||||
or repo.endswith('.txt')
|
||||
or repo.endswith('.md') # file extensions
|
||||
or repo.endswith('.py')
|
||||
or repo.endswith('.js')
|
||||
or '.' in repo
|
||||
and len(repo.split('.')) > 2
|
||||
): # complex file paths
|
||||
re.match(r'^\d+\.\d+/\d+\.\d+$', full_match)
|
||||
or re.match(r'^\d{1,2}/\d{1,2}$', full_match)
|
||||
or re.match(r'^[A-Z]/[A-Z]$', full_match)
|
||||
or repo.endswith(('.txt', '.md', '.py', '.js'))
|
||||
or ('.' in repo and len(repo.split('.')) > 2)
|
||||
):
|
||||
continue
|
||||
|
||||
# Avoid duplicates from Git URLs already found
|
||||
if full_match not in matches:
|
||||
matches.append(full_match)
|
||||
|
||||
|
||||
@@ -1,10 +1,15 @@
|
||||
import logging
|
||||
import os
|
||||
from logging.config import fileConfig
|
||||
|
||||
from alembic import context
|
||||
from google.cloud.sql.connector import Connector
|
||||
from sqlalchemy import create_engine
|
||||
from storage.base import Base
|
||||
# Suppress alembic.runtime.plugins INFO logs during import to prevent non-JSON logs in production
|
||||
# These plugin setup messages would otherwise appear before logging is configured
|
||||
logging.getLogger('alembic.runtime.plugins').setLevel(logging.WARNING)
|
||||
|
||||
from alembic import context # noqa: E402
|
||||
from google.cloud.sql.connector import Connector # noqa: E402
|
||||
from sqlalchemy import create_engine # noqa: E402
|
||||
from storage.base import Base # noqa: E402
|
||||
|
||||
target_metadata = Base.metadata
|
||||
|
||||
|
||||
@@ -20,6 +20,8 @@ down_revision = '059'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
# TODO: decide whether to modify this for orgs or users
|
||||
|
||||
|
||||
def upgrade():
|
||||
"""
|
||||
@@ -28,8 +30,10 @@ def upgrade():
|
||||
|
||||
This replaces the functionality of the removed admin maintenance endpoint.
|
||||
"""
|
||||
# Import here to avoid circular imports
|
||||
from server.constants import CURRENT_USER_SETTINGS_VERSION
|
||||
|
||||
# Hardcoded value to prevent migration failures when constant is removed from codebase
|
||||
# This migration has already run in production, so we use the value that was current at the time
|
||||
CURRENT_USER_SETTINGS_VERSION = 4
|
||||
|
||||
# Create a connection and bind it to a session
|
||||
connection = op.get_bind()
|
||||
|
||||
272
enterprise/migrations/versions/089_create_org_tables.py
Normal file
272
enterprise/migrations/versions/089_create_org_tables.py
Normal file
@@ -0,0 +1,272 @@
|
||||
"""create org tables from pgerd schema
|
||||
|
||||
Revision ID: 089
|
||||
Revises: 088
|
||||
Create Date: 2025-01-07 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '089'
|
||||
down_revision: Union[str, None] = '088'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Remove current settings table
|
||||
op.execute('DROP TABLE IF EXISTS settings')
|
||||
|
||||
# Add already_migrated column to user_settings table
|
||||
op.add_column(
|
||||
'user_settings',
|
||||
sa.Column(
|
||||
'already_migrated',
|
||||
sa.Boolean,
|
||||
nullable=True,
|
||||
server_default=sa.text('false'),
|
||||
),
|
||||
)
|
||||
|
||||
# Create role table
|
||||
op.create_table(
|
||||
'role',
|
||||
sa.Column('id', sa.Integer, sa.Identity(), primary_key=True),
|
||||
sa.Column('name', sa.String, nullable=False),
|
||||
sa.Column('rank', sa.Integer, nullable=False),
|
||||
sa.UniqueConstraint('name', name='role_name_unique'),
|
||||
)
|
||||
|
||||
# 1. Create default roles
|
||||
op.execute(
|
||||
sa.text("""
|
||||
INSERT INTO role (name, rank) VALUES ('owner', 10), ('admin', 20), ('user', 1000)
|
||||
ON CONFLICT (name) DO NOTHING;
|
||||
""")
|
||||
)
|
||||
|
||||
# Create org table with settings fields
|
||||
op.create_table(
|
||||
'org',
|
||||
sa.Column(
|
||||
'id',
|
||||
postgresql.UUID(as_uuid=True),
|
||||
primary_key=True,
|
||||
),
|
||||
sa.Column('name', sa.String, nullable=False),
|
||||
sa.Column('contact_name', sa.String, nullable=True),
|
||||
sa.Column('contact_email', sa.String, nullable=True),
|
||||
sa.Column('conversation_expiration', sa.Integer, nullable=True),
|
||||
# Settings fields moved to org table
|
||||
sa.Column('agent', sa.String, nullable=True),
|
||||
sa.Column('default_max_iterations', sa.Integer, nullable=True),
|
||||
sa.Column('security_analyzer', sa.String, nullable=True),
|
||||
sa.Column(
|
||||
'confirmation_mode',
|
||||
sa.Boolean,
|
||||
nullable=True,
|
||||
server_default=sa.text('false'),
|
||||
),
|
||||
sa.Column('default_llm_model', sa.String, nullable=True),
|
||||
sa.Column('default_llm_base_url', sa.String, nullable=True),
|
||||
sa.Column('remote_runtime_resource_factor', sa.Integer, nullable=True),
|
||||
sa.Column(
|
||||
'enable_default_condenser',
|
||||
sa.Boolean,
|
||||
nullable=False,
|
||||
server_default=sa.text('true'),
|
||||
),
|
||||
sa.Column('billing_margin', sa.Float, nullable=True),
|
||||
sa.Column(
|
||||
'enable_proactive_conversation_starters',
|
||||
sa.Boolean,
|
||||
nullable=False,
|
||||
server_default=sa.text('true'),
|
||||
),
|
||||
sa.Column('sandbox_base_container_image', sa.String, nullable=True),
|
||||
sa.Column('sandbox_runtime_container_image', sa.String, nullable=True),
|
||||
sa.Column(
|
||||
'org_version', sa.Integer, nullable=False, server_default=sa.text('0')
|
||||
),
|
||||
sa.Column('mcp_config', sa.JSON, nullable=True),
|
||||
sa.Column('_search_api_key', sa.String, nullable=True),
|
||||
sa.Column('_sandbox_api_key', sa.String, nullable=True),
|
||||
sa.Column('max_budget_per_task', sa.Float, nullable=True),
|
||||
sa.Column(
|
||||
'enable_solvability_analysis',
|
||||
sa.Boolean,
|
||||
nullable=True,
|
||||
server_default=sa.text('false'),
|
||||
),
|
||||
sa.Column('v1_enabled', sa.Boolean, nullable=True),
|
||||
sa.Column('condenser_max_size', sa.Integer, nullable=True),
|
||||
sa.UniqueConstraint('name', name='org_name_unique'),
|
||||
)
|
||||
|
||||
# Create user table with user-specific settings fields
|
||||
op.create_table(
|
||||
'user',
|
||||
sa.Column(
|
||||
'id',
|
||||
postgresql.UUID(as_uuid=True),
|
||||
primary_key=True,
|
||||
),
|
||||
sa.Column('current_org_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('role_id', sa.Integer, nullable=True),
|
||||
sa.Column('accepted_tos', sa.DateTime, nullable=True),
|
||||
sa.Column(
|
||||
'enable_sound_notifications',
|
||||
sa.Boolean,
|
||||
nullable=True,
|
||||
server_default=sa.text('false'),
|
||||
),
|
||||
sa.Column('language', sa.String, nullable=True),
|
||||
sa.Column('user_consents_to_analytics', sa.Boolean, nullable=True),
|
||||
sa.Column('email', sa.String, nullable=True),
|
||||
sa.Column('email_verified', sa.Boolean, nullable=True),
|
||||
sa.ForeignKeyConstraint(
|
||||
['current_org_id'], ['org.id'], name='current_org_fkey'
|
||||
),
|
||||
sa.ForeignKeyConstraint(['role_id'], ['role.id'], name='user_role_fkey'),
|
||||
)
|
||||
|
||||
# Create org_member table (junction table for many-to-many relationship)
|
||||
op.create_table(
|
||||
'org_member',
|
||||
sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('user_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('role_id', sa.Integer, nullable=False),
|
||||
sa.Column('_llm_api_key', sa.String, nullable=False),
|
||||
sa.Column('max_iterations', sa.Integer, nullable=True),
|
||||
sa.Column('llm_model', sa.String, nullable=True),
|
||||
sa.Column('_llm_api_key_for_byor', sa.String, nullable=True),
|
||||
sa.Column('llm_base_url', sa.String, nullable=True),
|
||||
sa.Column('status', sa.String, nullable=True),
|
||||
sa.ForeignKeyConstraint(['org_id'], ['org.id'], name='om_org_fkey'),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['user.id'], name='om_user_fkey'),
|
||||
sa.ForeignKeyConstraint(['role_id'], ['role.id'], name='om_role_fkey'),
|
||||
sa.PrimaryKeyConstraint('org_id', 'user_id'),
|
||||
)
|
||||
|
||||
# Add org_id column to existing tables
|
||||
# billing_sessions
|
||||
op.add_column(
|
||||
'billing_sessions',
|
||||
sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=True),
|
||||
)
|
||||
op.create_foreign_key(
|
||||
'billing_sessions_org_fkey', 'billing_sessions', 'org', ['org_id'], ['id']
|
||||
)
|
||||
|
||||
# Create conversation_metadata_saas table
|
||||
op.create_table(
|
||||
'conversation_metadata_saas',
|
||||
sa.Column('conversation_id', sa.String(), nullable=False),
|
||||
sa.Column('user_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
['user_id'], ['user.id'], name='conversation_metadata_saas_user_fkey'
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
['org_id'], ['org.id'], name='conversation_metadata_saas_org_fkey'
|
||||
),
|
||||
sa.PrimaryKeyConstraint('conversation_id'),
|
||||
)
|
||||
|
||||
# custom_secrets
|
||||
op.add_column(
|
||||
'custom_secrets',
|
||||
sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=True),
|
||||
)
|
||||
op.create_foreign_key(
|
||||
'custom_secrets_org_fkey', 'custom_secrets', 'org', ['org_id'], ['id']
|
||||
)
|
||||
|
||||
# api_keys
|
||||
op.add_column(
|
||||
'api_keys', sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=True)
|
||||
)
|
||||
op.create_foreign_key('api_keys_org_fkey', 'api_keys', 'org', ['org_id'], ['id'])
|
||||
|
||||
# slack_conversation
|
||||
op.add_column(
|
||||
'slack_conversation',
|
||||
sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=True),
|
||||
)
|
||||
op.create_foreign_key(
|
||||
'slack_conversation_org_fkey', 'slack_conversation', 'org', ['org_id'], ['id']
|
||||
)
|
||||
|
||||
# slack_users
|
||||
op.add_column(
|
||||
'slack_users', sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=True)
|
||||
)
|
||||
op.create_foreign_key(
|
||||
'slack_users_org_fkey', 'slack_users', 'org', ['org_id'], ['id']
|
||||
)
|
||||
|
||||
# stripe_customers
|
||||
op.alter_column(
|
||||
'stripe_customers',
|
||||
'keycloak_user_id',
|
||||
existing_type=sa.String(),
|
||||
nullable=True,
|
||||
)
|
||||
op.add_column(
|
||||
'stripe_customers',
|
||||
sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=True),
|
||||
)
|
||||
op.create_foreign_key(
|
||||
'stripe_customers_org_fkey', 'stripe_customers', 'org', ['org_id'], ['id']
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop already_migrated column from user_settings table
|
||||
op.drop_column('user_settings', 'already_migrated')
|
||||
|
||||
# Drop foreign keys and columns added to existing tables
|
||||
op.drop_constraint(
|
||||
'stripe_customers_org_fkey', 'stripe_customers', type_='foreignkey'
|
||||
)
|
||||
op.drop_column('stripe_customers', 'org_id')
|
||||
op.alter_column(
|
||||
'stripe_customers',
|
||||
'keycloak_user_id',
|
||||
existing_type=sa.String(),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
op.drop_constraint('slack_users_org_fkey', 'slack_users', type_='foreignkey')
|
||||
op.drop_column('slack_users', 'org_id')
|
||||
|
||||
op.drop_constraint(
|
||||
'slack_conversation_org_fkey', 'slack_conversation', type_='foreignkey'
|
||||
)
|
||||
op.drop_column('slack_conversation', 'org_id')
|
||||
|
||||
op.drop_constraint('api_keys_org_fkey', 'api_keys', type_='foreignkey')
|
||||
op.drop_column('api_keys', 'org_id')
|
||||
|
||||
op.drop_constraint('custom_secrets_org_fkey', 'custom_secrets', type_='foreignkey')
|
||||
op.drop_column('custom_secrets', 'org_id')
|
||||
|
||||
# Drop conversation_metadata_saas table
|
||||
op.drop_table('conversation_metadata_saas')
|
||||
|
||||
op.drop_constraint(
|
||||
'billing_sessions_org_fkey', 'billing_sessions', type_='foreignkey'
|
||||
)
|
||||
op.drop_column('billing_sessions', 'org_id')
|
||||
|
||||
# Drop tables in reverse order due to foreign key constraints
|
||||
op.drop_table('org_member')
|
||||
op.drop_table('user')
|
||||
op.drop_table('org')
|
||||
op.drop_table('role')
|
||||
@@ -0,0 +1,28 @@
|
||||
"""Add git_user_name and git_user_email columns to user table.
|
||||
|
||||
Revision ID: 090
|
||||
Revises: 089
|
||||
Create Date: 2025-01-22
|
||||
"""
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
revision = '090'
|
||||
down_revision = '089'
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
'user',
|
||||
sa.Column('git_user_name', sa.String, nullable=True),
|
||||
)
|
||||
op.add_column(
|
||||
'user',
|
||||
sa.Column('git_user_email', sa.String, nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column('user', 'git_user_email')
|
||||
op.drop_column('user', 'git_user_name')
|
||||
@@ -0,0 +1,46 @@
|
||||
"""Add byor_export_enabled flag to org table.
|
||||
|
||||
Revision ID: 091
|
||||
Revises: 090
|
||||
Create Date: 2025-01-15 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '091'
|
||||
down_revision: Union[str, None] = '090'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add byor_export_enabled column to org table with default false
|
||||
op.add_column(
|
||||
'org',
|
||||
sa.Column(
|
||||
'byor_export_enabled',
|
||||
sa.Boolean,
|
||||
nullable=False,
|
||||
server_default=sa.text('false'),
|
||||
),
|
||||
)
|
||||
|
||||
# Set byor_export_enabled to true for orgs that have completed billing sessions
|
||||
op.execute(
|
||||
sa.text("""
|
||||
UPDATE org SET byor_export_enabled = TRUE
|
||||
WHERE id IN (
|
||||
SELECT DISTINCT org_id FROM billing_sessions
|
||||
WHERE status = 'completed' AND org_id IS NOT NULL
|
||||
)
|
||||
""")
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column('org', 'byor_export_enabled')
|
||||
@@ -0,0 +1,29 @@
|
||||
"""Rename 'user' role to 'member' in role table.
|
||||
|
||||
Revision ID: 092
|
||||
Revises: 091
|
||||
Create Date: 2025-02-12 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '092'
|
||||
down_revision: Union[str, None] = '091'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Rename 'user' role to 'member' for clarity
|
||||
# This avoids confusion between the 'user' role and the 'user' entity/account
|
||||
op.execute(sa.text("UPDATE role SET name = 'member' WHERE name = 'user'"))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Revert 'member' role back to 'user'
|
||||
op.execute(sa.text("UPDATE role SET name = 'user' WHERE name = 'member'"))
|
||||
@@ -0,0 +1,37 @@
|
||||
"""Add pending_free_credits flag to org table.
|
||||
|
||||
Revision ID: 093
|
||||
Revises: 092
|
||||
Create Date: 2025-02-17 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '093'
|
||||
down_revision: Union[str, None] = '092'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add pending_free_credits column to org table with default false.
|
||||
# New orgs will have this set to TRUE at creation time.
|
||||
# Existing orgs default to FALSE (not eligible - they already got $10 at signup).
|
||||
op.add_column(
|
||||
'org',
|
||||
sa.Column(
|
||||
'pending_free_credits',
|
||||
sa.Boolean,
|
||||
nullable=False,
|
||||
server_default=sa.text('false'),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column('org', 'pending_free_credits')
|
||||
@@ -0,0 +1,110 @@
|
||||
"""create org_invitation table
|
||||
|
||||
Revision ID: 094
|
||||
Revises: 093
|
||||
Create Date: 2026-02-18 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '094'
|
||||
down_revision: Union[str, None] = '093'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Create org_invitation table
|
||||
op.create_table(
|
||||
'org_invitation',
|
||||
sa.Column('id', sa.Integer, sa.Identity(), primary_key=True),
|
||||
sa.Column('token', sa.String(64), nullable=False),
|
||||
sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('email', sa.String(255), nullable=False),
|
||||
sa.Column('role_id', sa.Integer, nullable=False),
|
||||
sa.Column('inviter_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column(
|
||||
'status',
|
||||
sa.String(20),
|
||||
nullable=False,
|
||||
server_default=sa.text("'pending'"),
|
||||
),
|
||||
sa.Column(
|
||||
'created_at',
|
||||
sa.DateTime,
|
||||
nullable=False,
|
||||
server_default=sa.text('CURRENT_TIMESTAMP'),
|
||||
),
|
||||
sa.Column('expires_at', sa.DateTime, nullable=False),
|
||||
sa.Column('accepted_at', sa.DateTime, nullable=True),
|
||||
sa.Column('accepted_by_user_id', postgresql.UUID(as_uuid=True), nullable=True),
|
||||
# Foreign key constraints
|
||||
sa.ForeignKeyConstraint(
|
||||
['org_id'],
|
||||
['org.id'],
|
||||
name='org_invitation_org_fkey',
|
||||
ondelete='CASCADE',
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
['role_id'],
|
||||
['role.id'],
|
||||
name='org_invitation_role_fkey',
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
['inviter_id'],
|
||||
['user.id'],
|
||||
name='org_invitation_inviter_fkey',
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
['accepted_by_user_id'],
|
||||
['user.id'],
|
||||
name='org_invitation_accepter_fkey',
|
||||
),
|
||||
)
|
||||
|
||||
# Create indexes
|
||||
op.create_index(
|
||||
'ix_org_invitation_token',
|
||||
'org_invitation',
|
||||
['token'],
|
||||
unique=True,
|
||||
)
|
||||
op.create_index(
|
||||
'ix_org_invitation_org_id',
|
||||
'org_invitation',
|
||||
['org_id'],
|
||||
)
|
||||
op.create_index(
|
||||
'ix_org_invitation_email',
|
||||
'org_invitation',
|
||||
['email'],
|
||||
)
|
||||
op.create_index(
|
||||
'ix_org_invitation_status',
|
||||
'org_invitation',
|
||||
['status'],
|
||||
)
|
||||
# Composite index for checking pending invitations
|
||||
op.create_index(
|
||||
'ix_org_invitation_org_email_status',
|
||||
'org_invitation',
|
||||
['org_id', 'email', 'status'],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop indexes
|
||||
op.drop_index('ix_org_invitation_org_email_status', table_name='org_invitation')
|
||||
op.drop_index('ix_org_invitation_status', table_name='org_invitation')
|
||||
op.drop_index('ix_org_invitation_email', table_name='org_invitation')
|
||||
op.drop_index('ix_org_invitation_org_id', table_name='org_invitation')
|
||||
op.drop_index('ix_org_invitation_token', table_name='org_invitation')
|
||||
|
||||
# Drop table
|
||||
op.drop_table('org_invitation')
|
||||
@@ -0,0 +1,37 @@
|
||||
"""Drop pending_free_credits column from org table.
|
||||
|
||||
Revision ID: 095
|
||||
Revises: 094
|
||||
Create Date: 2025-02-18 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '095'
|
||||
down_revision: Union[str, None] = '094'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Drop the pending_free_credits column from org table.
|
||||
# This column was used for tracking free credit eligibility but is no longer needed.
|
||||
op.drop_column('org', 'pending_free_credits')
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Re-add pending_free_credits column with default false.
|
||||
op.add_column(
|
||||
'org',
|
||||
sa.Column(
|
||||
'pending_free_credits',
|
||||
sa.Boolean,
|
||||
nullable=False,
|
||||
server_default=sa.text('false'),
|
||||
),
|
||||
)
|
||||
@@ -0,0 +1,67 @@
|
||||
"""Create resend_synced_users table.
|
||||
|
||||
Revision ID: 096
|
||||
Revises: 095
|
||||
Create Date: 2025-02-17 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '096'
|
||||
down_revision: Union[str, None] = '095'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Create resend_synced_users table for tracking users synced to Resend audiences."""
|
||||
op.create_table(
|
||||
'resend_synced_users',
|
||||
sa.Column(
|
||||
'id',
|
||||
sa.UUID(as_uuid=True),
|
||||
nullable=False,
|
||||
primary_key=True,
|
||||
),
|
||||
sa.Column('email', sa.String(), nullable=False),
|
||||
sa.Column('audience_id', sa.String(), nullable=False),
|
||||
sa.Column(
|
||||
'synced_at',
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.text('CURRENT_TIMESTAMP'),
|
||||
),
|
||||
sa.Column('keycloak_user_id', sa.String(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint(
|
||||
'email', 'audience_id', name='uq_resend_synced_email_audience'
|
||||
),
|
||||
)
|
||||
|
||||
# Create index on email for fast lookups
|
||||
op.create_index(
|
||||
'ix_resend_synced_users_email',
|
||||
'resend_synced_users',
|
||||
['email'],
|
||||
)
|
||||
|
||||
# Create index on audience_id for filtering by audience
|
||||
op.create_index(
|
||||
'ix_resend_synced_users_audience_id',
|
||||
'resend_synced_users',
|
||||
['audience_id'],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Drop resend_synced_users table."""
|
||||
op.drop_index(
|
||||
'ix_resend_synced_users_audience_id', table_name='resend_synced_users'
|
||||
)
|
||||
op.drop_index('ix_resend_synced_users_email', table_name='resend_synced_users')
|
||||
op.drop_table('resend_synced_users')
|
||||
@@ -0,0 +1,41 @@
|
||||
"""Add session_api_key_hash to v1_remote_sandbox table
|
||||
|
||||
Revision ID: 097
|
||||
Revises: 096
|
||||
Create Date: 2025-02-24 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '097'
|
||||
down_revision: Union[str, None] = '096'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Add session_api_key_hash column to v1_remote_sandbox table."""
|
||||
op.add_column(
|
||||
'v1_remote_sandbox',
|
||||
sa.Column('session_api_key_hash', sa.String(), nullable=True),
|
||||
)
|
||||
op.create_index(
|
||||
op.f('ix_v1_remote_sandbox_session_api_key_hash'),
|
||||
'v1_remote_sandbox',
|
||||
['session_api_key_hash'],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Remove session_api_key_hash column from v1_remote_sandbox table."""
|
||||
op.drop_index(
|
||||
op.f('ix_v1_remote_sandbox_session_api_key_hash'),
|
||||
table_name='v1_remote_sandbox',
|
||||
)
|
||||
op.drop_column('v1_remote_sandbox', 'session_api_key_hash')
|
||||
@@ -0,0 +1,92 @@
|
||||
"""Create verified_models table.
|
||||
|
||||
Revision ID: 098
|
||||
Revises: 097
|
||||
Create Date: 2026-02-26 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '098'
|
||||
down_revision: Union[str, None] = '097'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Create verified_models table and seed with current model list."""
|
||||
op.create_table(
|
||||
'verified_models',
|
||||
sa.Column('id', sa.Integer, sa.Identity(), primary_key=True),
|
||||
sa.Column('model_name', sa.String(255), nullable=False),
|
||||
sa.Column('provider', sa.String(100), nullable=False),
|
||||
sa.Column(
|
||||
'is_enabled',
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.text('true'),
|
||||
),
|
||||
sa.Column(
|
||||
'created_at',
|
||||
sa.DateTime(),
|
||||
nullable=False,
|
||||
server_default=sa.text('CURRENT_TIMESTAMP'),
|
||||
),
|
||||
sa.Column(
|
||||
'updated_at',
|
||||
sa.DateTime(),
|
||||
nullable=False,
|
||||
server_default=sa.text('CURRENT_TIMESTAMP'),
|
||||
),
|
||||
sa.UniqueConstraint(
|
||||
'model_name', 'provider', name='uq_verified_model_provider'
|
||||
),
|
||||
)
|
||||
|
||||
op.create_index(
|
||||
'ix_verified_models_provider',
|
||||
'verified_models',
|
||||
['provider'],
|
||||
)
|
||||
op.create_index(
|
||||
'ix_verified_models_is_enabled',
|
||||
'verified_models',
|
||||
['is_enabled'],
|
||||
)
|
||||
|
||||
# Seed with current openhands provider models
|
||||
models = [
|
||||
('claude-opus-4-5-20251101', 'openhands'),
|
||||
('claude-sonnet-4-5-20250929', 'openhands'),
|
||||
('gpt-5.2-codex', 'openhands'),
|
||||
('gpt-5.2', 'openhands'),
|
||||
('minimax-m2.5', 'openhands'),
|
||||
('gemini-3-pro-preview', 'openhands'),
|
||||
('gemini-3-flash-preview', 'openhands'),
|
||||
('deepseek-chat', 'openhands'),
|
||||
('devstral-medium-2512', 'openhands'),
|
||||
('kimi-k2-0711-preview', 'openhands'),
|
||||
('qwen3-coder-480b', 'openhands'),
|
||||
]
|
||||
|
||||
for model_name, provider in models:
|
||||
op.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO verified_models (model_name, provider)
|
||||
VALUES (:model_name, :provider)
|
||||
"""
|
||||
).bindparams(model_name=model_name, provider=provider)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Drop verified_models table."""
|
||||
op.drop_index('ix_verified_models_is_enabled', table_name='verified_models')
|
||||
op.drop_index('ix_verified_models_provider', table_name='verified_models')
|
||||
op.drop_table('verified_models')
|
||||
11683
enterprise/poetry.lock
generated
11683
enterprise/poetry.lock
generated
File diff suppressed because one or more lines are too long
@@ -44,6 +44,12 @@ httpx = "*"
|
||||
scikit-learn = "^1.7.0"
|
||||
shap = "^0.48.0"
|
||||
google-cloud-recaptcha-enterprise = "^1.24.0"
|
||||
# Dependencies previously only in Dockerfile, now managed via poetry.lock
|
||||
prometheus-client = "^0.24.0"
|
||||
pandas = "^2.2.0"
|
||||
numpy = "^2.2.0"
|
||||
mcp = "^1.10.0"
|
||||
pillow = "^12.1.1"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
ruff = "0.8.3"
|
||||
|
||||
@@ -4,6 +4,10 @@ from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
# Ensure SAAS configuration is used
|
||||
if not os.getenv('OPENHANDS_CONFIG_CLS'):
|
||||
os.environ['OPENHANDS_CONFIG_CLS'] = 'server.config.SaaSServerConfig'
|
||||
|
||||
import socketio # noqa: E402
|
||||
from fastapi import Request, status # noqa: E402
|
||||
from fastapi.middleware.cors import CORSMiddleware # noqa: E402
|
||||
@@ -34,14 +38,28 @@ from server.routes.integration.linear import linear_integration_router # noqa:
|
||||
from server.routes.integration.slack import slack_router # noqa: E402
|
||||
from server.routes.mcp_patch import patch_mcp_server # noqa: E402
|
||||
from server.routes.oauth_device import oauth_device_router # noqa: E402
|
||||
from server.routes.org_invitations import ( # noqa: E402
|
||||
accept_router as invitation_accept_router,
|
||||
)
|
||||
from server.routes.org_invitations import ( # noqa: E402
|
||||
invitation_router,
|
||||
)
|
||||
from server.routes.orgs import org_router # noqa: E402
|
||||
from server.routes.readiness import readiness_router # noqa: E402
|
||||
from server.routes.user import saas_user_router # noqa: E402
|
||||
from server.routes.user_app_settings import user_app_settings_router # noqa: E402
|
||||
from server.sharing.shared_conversation_router import ( # noqa: E402
|
||||
router as shared_conversation_router,
|
||||
)
|
||||
from server.sharing.shared_event_router import ( # noqa: E402
|
||||
router as shared_event_router,
|
||||
)
|
||||
from server.verified_models.verified_model_router import ( # noqa: E402
|
||||
api_router as verified_models_router,
|
||||
)
|
||||
from server.verified_models.verified_model_router import ( # noqa: E402
|
||||
override_llm_models_dependency,
|
||||
)
|
||||
|
||||
from openhands.server.app import app as base_app # noqa: E402
|
||||
from openhands.server.listen_socket import sio # noqa: E402
|
||||
@@ -65,6 +83,7 @@ base_app.include_router(api_router) # Add additional route for github auth
|
||||
base_app.include_router(oauth_router) # Add additional route for oauth callback
|
||||
base_app.include_router(oauth_device_router) # Add OAuth 2.0 Device Flow routes
|
||||
base_app.include_router(saas_user_router) # Add additional route SAAS user calls
|
||||
base_app.include_router(user_app_settings_router) # Add routes for user app settings
|
||||
base_app.include_router(
|
||||
billing_router
|
||||
) # Add routes for credit management and Stripe payment integration
|
||||
@@ -73,8 +92,15 @@ base_app.include_router(shared_event_router)
|
||||
|
||||
# Add GitHub integration router only if GITHUB_APP_CLIENT_ID is set
|
||||
if GITHUB_APP_CLIENT_ID:
|
||||
# Make sure that the callback processor is loaded here so we don't get an error when deserializing
|
||||
from integrations.github.github_v1_callback_processor import ( # noqa: E402
|
||||
GithubV1CallbackProcessor,
|
||||
)
|
||||
from server.routes.integration.github import github_integration_router # noqa: E402
|
||||
|
||||
# Bludgeon mypy into not deleting my import
|
||||
logger.debug(f'Loaded {GithubV1CallbackProcessor.__name__}')
|
||||
|
||||
base_app.include_router(
|
||||
github_integration_router
|
||||
) # Add additional route for integration webhook events
|
||||
@@ -86,6 +112,17 @@ if GITLAB_APP_CLIENT_ID:
|
||||
base_app.include_router(gitlab_integration_router)
|
||||
|
||||
base_app.include_router(api_keys_router) # Add routes for API key management
|
||||
base_app.include_router(org_router) # Add routes for organization management
|
||||
base_app.include_router(
|
||||
verified_models_router
|
||||
) # Add routes for verified models management
|
||||
|
||||
# Override the default LLM models implementation with SaaS version
|
||||
# This must happen after all routers are included
|
||||
override_llm_models_dependency(base_app)
|
||||
|
||||
base_app.include_router(invitation_router) # Add routes for org invitation management
|
||||
base_app.include_router(invitation_accept_router) # Add route for accepting invitations
|
||||
add_github_proxy_routes(base_app)
|
||||
add_debugging_routes(
|
||||
base_app
|
||||
|
||||
@@ -38,3 +38,9 @@ class ExpiredError(AuthError):
|
||||
"""Error when a token has expired (Usually the refresh token)"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class TokenRefreshError(AuthError):
|
||||
"""Error when token refresh fails due to timeout or lock contention"""
|
||||
|
||||
pass
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
import os
|
||||
|
||||
from server.auth.sheets_client import GoogleSheetsClient
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
|
||||
@@ -9,12 +7,9 @@ class UserVerifier:
|
||||
def __init__(self) -> None:
|
||||
logger.debug('Initializing UserVerifier')
|
||||
self.file_users: list[str] | None = None
|
||||
self.sheets_client: GoogleSheetsClient | None = None
|
||||
self.spreadsheet_id: str | None = None
|
||||
|
||||
# Initialize from environment variables
|
||||
self._init_file_users()
|
||||
self._init_sheets_client()
|
||||
|
||||
def _init_file_users(self) -> None:
|
||||
"""Load users from text file if configured."""
|
||||
@@ -36,23 +31,11 @@ class UserVerifier:
|
||||
except Exception:
|
||||
logger.exception(f'Error reading user list file {waitlist}')
|
||||
|
||||
def _init_sheets_client(self) -> None:
|
||||
"""Initialize Google Sheets client if configured."""
|
||||
sheet_id = os.getenv('GITHUB_USERS_SHEET_ID')
|
||||
|
||||
if not sheet_id:
|
||||
logger.debug('GITHUB_USERS_SHEET_ID not configured')
|
||||
return
|
||||
|
||||
logger.debug('Initializing Google Sheets integration')
|
||||
self.sheets_client = GoogleSheetsClient()
|
||||
self.spreadsheet_id = sheet_id
|
||||
|
||||
def is_active(self) -> bool:
|
||||
if os.getenv('DISABLE_WAITLIST', '').lower() == 'true':
|
||||
logger.info('Waitlist disabled via DISABLE_WAITLIST env var')
|
||||
return False
|
||||
return bool(self.file_users or (self.sheets_client and self.spreadsheet_id))
|
||||
return bool(self.file_users)
|
||||
|
||||
def is_user_allowed(self, username: str) -> bool:
|
||||
"""Check if user is allowed based on file and/or sheet configuration."""
|
||||
@@ -63,15 +46,6 @@ class UserVerifier:
|
||||
return True
|
||||
logger.debug(f'User {username} not found in text file allowlist')
|
||||
|
||||
if self.sheets_client and self.spreadsheet_id:
|
||||
sheet_users = [
|
||||
u.lower() for u in self.sheets_client.get_usernames(self.spreadsheet_id)
|
||||
]
|
||||
if username.lower() in sheet_users:
|
||||
logger.debug(f'User {username} found in Google Sheets allowlist')
|
||||
return True
|
||||
logger.debug(f'User {username} not found in Google Sheets allowlist')
|
||||
|
||||
logger.debug(f'User {username} not found in any allowlist')
|
||||
return False
|
||||
|
||||
|
||||
306
enterprise/server/auth/authorization.py
Normal file
306
enterprise/server/auth/authorization.py
Normal file
@@ -0,0 +1,306 @@
|
||||
"""
|
||||
Permission-based authorization dependencies for API endpoints.
|
||||
|
||||
This module provides FastAPI dependencies for checking user permissions
|
||||
within organizations. It uses a permission-based authorization model where
|
||||
roles (owner, admin, member) are mapped to specific permissions.
|
||||
|
||||
Permissions are defined in the Permission enum and mapped to roles via
|
||||
ROLE_PERMISSIONS. This allows fine-grained access control while maintaining
|
||||
the familiar role-based hierarchy.
|
||||
|
||||
Usage:
|
||||
from server.auth.authorization import (
|
||||
Permission,
|
||||
require_permission,
|
||||
)
|
||||
|
||||
@router.get('/{org_id}/settings')
|
||||
async def get_settings(
|
||||
org_id: UUID,
|
||||
user_id: str = Depends(require_permission(Permission.VIEW_LLM_SETTINGS)),
|
||||
):
|
||||
# Only users with VIEW_LLM_SETTINGS permission can access
|
||||
...
|
||||
|
||||
@router.patch('/{org_id}/settings')
|
||||
async def update_settings(
|
||||
org_id: UUID,
|
||||
user_id: str = Depends(require_permission(Permission.EDIT_LLM_SETTINGS)),
|
||||
):
|
||||
# Only users with EDIT_LLM_SETTINGS permission can access
|
||||
...
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from storage.org_member_store import OrgMemberStore
|
||||
from storage.role import Role
|
||||
from storage.role_store import RoleStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.server.user_auth import get_user_id
|
||||
|
||||
|
||||
class Permission(str, Enum):
|
||||
"""Permissions that can be assigned to roles."""
|
||||
|
||||
# Secrets
|
||||
MANAGE_SECRETS = 'manage_secrets'
|
||||
|
||||
# MCP
|
||||
MANAGE_MCP = 'manage_mcp'
|
||||
|
||||
# Integrations
|
||||
MANAGE_INTEGRATIONS = 'manage_integrations'
|
||||
|
||||
# Application Settings
|
||||
MANAGE_APPLICATION_SETTINGS = 'manage_application_settings'
|
||||
|
||||
# API Keys
|
||||
MANAGE_API_KEYS = 'manage_api_keys'
|
||||
|
||||
# LLM Settings
|
||||
VIEW_LLM_SETTINGS = 'view_llm_settings'
|
||||
EDIT_LLM_SETTINGS = 'edit_llm_settings'
|
||||
|
||||
# Billing
|
||||
VIEW_BILLING = 'view_billing'
|
||||
ADD_CREDITS = 'add_credits'
|
||||
|
||||
# Organization Members
|
||||
INVITE_USER_TO_ORGANIZATION = 'invite_user_to_organization'
|
||||
CHANGE_USER_ROLE_MEMBER = 'change_user_role:member'
|
||||
CHANGE_USER_ROLE_ADMIN = 'change_user_role:admin'
|
||||
CHANGE_USER_ROLE_OWNER = 'change_user_role:owner'
|
||||
|
||||
# Organization Management
|
||||
VIEW_ORG_SETTINGS = 'view_org_settings'
|
||||
CHANGE_ORGANIZATION_NAME = 'change_organization_name'
|
||||
DELETE_ORGANIZATION = 'delete_organization'
|
||||
|
||||
# Temporary permissions until we finish the API updates.
|
||||
EDIT_ORG_SETTINGS = 'edit_org_settings'
|
||||
|
||||
|
||||
class RoleName(str, Enum):
|
||||
"""Role names used in the system."""
|
||||
|
||||
OWNER = 'owner'
|
||||
ADMIN = 'admin'
|
||||
MEMBER = 'member'
|
||||
|
||||
|
||||
# Permission mappings for each role
|
||||
ROLE_PERMISSIONS: dict[RoleName, frozenset[Permission]] = {
|
||||
RoleName.OWNER: frozenset(
|
||||
[
|
||||
# Settings (Full access)
|
||||
Permission.MANAGE_SECRETS,
|
||||
Permission.MANAGE_MCP,
|
||||
Permission.MANAGE_INTEGRATIONS,
|
||||
Permission.MANAGE_APPLICATION_SETTINGS,
|
||||
Permission.MANAGE_API_KEYS,
|
||||
Permission.VIEW_LLM_SETTINGS,
|
||||
Permission.EDIT_LLM_SETTINGS,
|
||||
Permission.VIEW_BILLING,
|
||||
Permission.ADD_CREDITS,
|
||||
# Organization Members
|
||||
Permission.INVITE_USER_TO_ORGANIZATION,
|
||||
Permission.CHANGE_USER_ROLE_MEMBER,
|
||||
Permission.CHANGE_USER_ROLE_ADMIN,
|
||||
Permission.CHANGE_USER_ROLE_OWNER,
|
||||
# Organization Management
|
||||
Permission.VIEW_ORG_SETTINGS,
|
||||
Permission.EDIT_ORG_SETTINGS,
|
||||
# Organization Management (Owner only)
|
||||
Permission.CHANGE_ORGANIZATION_NAME,
|
||||
Permission.DELETE_ORGANIZATION,
|
||||
]
|
||||
),
|
||||
RoleName.ADMIN: frozenset(
|
||||
[
|
||||
# Settings (Full access)
|
||||
Permission.MANAGE_SECRETS,
|
||||
Permission.MANAGE_MCP,
|
||||
Permission.MANAGE_INTEGRATIONS,
|
||||
Permission.MANAGE_APPLICATION_SETTINGS,
|
||||
Permission.MANAGE_API_KEYS,
|
||||
Permission.VIEW_LLM_SETTINGS,
|
||||
Permission.EDIT_LLM_SETTINGS,
|
||||
Permission.VIEW_BILLING,
|
||||
Permission.ADD_CREDITS,
|
||||
# Organization Members
|
||||
Permission.INVITE_USER_TO_ORGANIZATION,
|
||||
Permission.CHANGE_USER_ROLE_MEMBER,
|
||||
Permission.CHANGE_USER_ROLE_ADMIN,
|
||||
# Organization Management
|
||||
Permission.VIEW_ORG_SETTINGS,
|
||||
Permission.EDIT_ORG_SETTINGS,
|
||||
]
|
||||
),
|
||||
RoleName.MEMBER: frozenset(
|
||||
[
|
||||
# Settings (Full access)
|
||||
Permission.MANAGE_SECRETS,
|
||||
Permission.MANAGE_MCP,
|
||||
Permission.MANAGE_INTEGRATIONS,
|
||||
Permission.MANAGE_APPLICATION_SETTINGS,
|
||||
Permission.MANAGE_API_KEYS,
|
||||
# Settings (View only)
|
||||
Permission.VIEW_ORG_SETTINGS,
|
||||
Permission.VIEW_LLM_SETTINGS,
|
||||
]
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def get_user_org_role(user_id: str, org_id: UUID | None) -> Role | None:
|
||||
"""
|
||||
Get the user's role in an organization (synchronous version).
|
||||
|
||||
Args:
|
||||
user_id: User ID (string that will be converted to UUID)
|
||||
org_id: Organization ID, or None to use the user's current organization
|
||||
|
||||
Returns:
|
||||
Role object if user is a member, None otherwise
|
||||
"""
|
||||
from uuid import UUID as parse_uuid
|
||||
|
||||
if org_id is None:
|
||||
org_member = OrgMemberStore.get_org_member_for_current_org(parse_uuid(user_id))
|
||||
else:
|
||||
org_member = OrgMemberStore.get_org_member(org_id, parse_uuid(user_id))
|
||||
if not org_member:
|
||||
return None
|
||||
|
||||
return RoleStore.get_role_by_id(org_member.role_id)
|
||||
|
||||
|
||||
async def get_user_org_role_async(user_id: str, org_id: UUID | None) -> Role | None:
|
||||
"""
|
||||
Get the user's role in an organization (async version).
|
||||
|
||||
Args:
|
||||
user_id: User ID (string that will be converted to UUID)
|
||||
org_id: Organization ID, or None to use the user's current organization
|
||||
|
||||
Returns:
|
||||
Role object if user is a member, None otherwise
|
||||
"""
|
||||
from uuid import UUID as parse_uuid
|
||||
|
||||
if org_id is None:
|
||||
org_member = await OrgMemberStore.get_org_member_for_current_org_async(
|
||||
parse_uuid(user_id)
|
||||
)
|
||||
else:
|
||||
org_member = await OrgMemberStore.get_org_member_async(
|
||||
org_id, parse_uuid(user_id)
|
||||
)
|
||||
if not org_member:
|
||||
return None
|
||||
|
||||
return await RoleStore.get_role_by_id_async(org_member.role_id)
|
||||
|
||||
|
||||
def get_role_permissions(role_name: str) -> frozenset[Permission]:
|
||||
"""
|
||||
Get the permissions for a role.
|
||||
|
||||
Args:
|
||||
role_name: Name of the role
|
||||
|
||||
Returns:
|
||||
Set of permissions for the role
|
||||
"""
|
||||
try:
|
||||
role_enum = RoleName(role_name)
|
||||
return ROLE_PERMISSIONS.get(role_enum, frozenset())
|
||||
except ValueError:
|
||||
return frozenset()
|
||||
|
||||
|
||||
def has_permission(user_role: Role, permission: Permission) -> bool:
|
||||
"""
|
||||
Check if a role has a specific permission.
|
||||
|
||||
Args:
|
||||
user_role: User's Role object
|
||||
permission: Permission to check
|
||||
|
||||
Returns:
|
||||
True if the role has the permission
|
||||
"""
|
||||
permissions = get_role_permissions(user_role.name)
|
||||
return permission in permissions
|
||||
|
||||
|
||||
def require_permission(permission: Permission):
|
||||
"""
|
||||
Factory function that creates a dependency to require a specific permission.
|
||||
|
||||
This creates a FastAPI dependency that:
|
||||
1. Extracts org_id from the path parameter
|
||||
2. Gets the authenticated user_id
|
||||
3. Checks if the user has the required permission in the organization
|
||||
4. Returns the user_id if authorized, raises HTTPException otherwise
|
||||
|
||||
Usage:
|
||||
@router.get('/{org_id}/settings')
|
||||
async def get_settings(
|
||||
org_id: UUID,
|
||||
user_id: str = Depends(require_permission(Permission.VIEW_LLM_SETTINGS)),
|
||||
):
|
||||
...
|
||||
|
||||
Args:
|
||||
permission: The permission required to access the endpoint
|
||||
|
||||
Returns:
|
||||
Dependency function that validates permission and returns user_id
|
||||
"""
|
||||
|
||||
async def permission_checker(
|
||||
org_id: UUID | None = None,
|
||||
user_id: str | None = Depends(get_user_id),
|
||||
) -> str:
|
||||
if not user_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail='User not authenticated',
|
||||
)
|
||||
|
||||
user_role = await get_user_org_role_async(user_id, org_id)
|
||||
|
||||
if not user_role:
|
||||
logger.warning(
|
||||
'User not a member of organization',
|
||||
extra={'user_id': user_id, 'org_id': str(org_id)},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail='User is not a member of this organization',
|
||||
)
|
||||
|
||||
if not has_permission(user_role, permission):
|
||||
logger.warning(
|
||||
'Insufficient permissions',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
'user_role': user_role.name,
|
||||
'required_permission': permission.value,
|
||||
},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f'Requires {permission.value} permission',
|
||||
)
|
||||
|
||||
return user_id
|
||||
|
||||
return permission_checker
|
||||
@@ -39,6 +39,8 @@ ROLE_CHECK_ENABLED = os.getenv('ROLE_CHECK_ENABLED', 'false').lower() in (
|
||||
'on',
|
||||
)
|
||||
|
||||
DUPLICATE_EMAIL_CHECK = os.getenv('DUPLICATE_EMAIL_CHECK', 'true') in ('1', 'true')
|
||||
|
||||
# reCAPTCHA Enterprise
|
||||
RECAPTCHA_PROJECT_ID = os.getenv('RECAPTCHA_PROJECT_ID', '').strip()
|
||||
RECAPTCHA_SITE_KEY = os.getenv('RECAPTCHA_SITE_KEY', '').strip()
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from storage.blocked_email_domain_store import BlockedEmailDomainStore
|
||||
from storage.database import session_maker
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
@@ -23,7 +22,7 @@ class DomainBlocker:
|
||||
logger.debug(f'Error extracting domain from email: {email}', exc_info=True)
|
||||
return None
|
||||
|
||||
def is_domain_blocked(self, email: str) -> bool:
|
||||
async def is_domain_blocked(self, email: str) -> bool:
|
||||
"""Check if email domain is blocked by querying the database directly via SQL.
|
||||
|
||||
Supports blocking:
|
||||
@@ -45,7 +44,7 @@ class DomainBlocker:
|
||||
|
||||
try:
|
||||
# Query database directly via SQL to check if domain is blocked
|
||||
is_blocked = self.store.is_domain_blocked(domain)
|
||||
is_blocked = await self.store.is_domain_blocked(domain)
|
||||
|
||||
if is_blocked:
|
||||
logger.warning(f'Email domain {domain} is blocked for email: {email}')
|
||||
@@ -63,5 +62,5 @@ class DomainBlocker:
|
||||
|
||||
|
||||
# Initialize store and domain blocker
|
||||
_store = BlockedEmailDomainStore(session_maker=session_maker)
|
||||
_store = BlockedEmailDomainStore()
|
||||
domain_blocker = DomainBlocker(store=_store)
|
||||
|
||||
@@ -1,87 +1,11 @@
|
||||
import os
|
||||
|
||||
from integrations.github.github_service import SaaSGitHubService
|
||||
from pydantic import SecretStr
|
||||
from server.auth.sheets_client import GoogleSheetsClient
|
||||
from server.auth.auth_utils import user_verifier
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.github.github_types import GitHubUser
|
||||
|
||||
|
||||
class UserVerifier:
|
||||
def __init__(self) -> None:
|
||||
logger.debug('Initializing UserVerifier')
|
||||
self.file_users: list[str] | None = None
|
||||
self.sheets_client: GoogleSheetsClient | None = None
|
||||
self.spreadsheet_id: str | None = None
|
||||
|
||||
# Initialize from environment variables
|
||||
self._init_file_users()
|
||||
self._init_sheets_client()
|
||||
|
||||
def _init_file_users(self) -> None:
|
||||
"""Load users from text file if configured"""
|
||||
waitlist = os.getenv('GITHUB_USER_LIST_FILE')
|
||||
if not waitlist:
|
||||
logger.debug('GITHUB_USER_LIST_FILE not configured')
|
||||
return
|
||||
|
||||
if not os.path.exists(waitlist):
|
||||
logger.error(f'User list file not found: {waitlist}')
|
||||
raise FileNotFoundError(f'User list file not found: {waitlist}')
|
||||
|
||||
try:
|
||||
with open(waitlist, 'r') as f:
|
||||
self.file_users = [line.strip().lower() for line in f if line.strip()]
|
||||
logger.info(
|
||||
f'Successfully loaded {len(self.file_users)} users from {waitlist}'
|
||||
)
|
||||
except Exception:
|
||||
logger.error(f'Error reading user list file {waitlist}', exc_info=True)
|
||||
|
||||
def _init_sheets_client(self) -> None:
|
||||
"""Initialize Google Sheets client if configured"""
|
||||
sheet_id = os.getenv('GITHUB_USERS_SHEET_ID')
|
||||
|
||||
if not sheet_id:
|
||||
logger.debug('GITHUB_USERS_SHEET_ID not configured')
|
||||
return
|
||||
|
||||
logger.debug('Initializing Google Sheets integration')
|
||||
self.sheets_client = GoogleSheetsClient()
|
||||
self.spreadsheet_id = sheet_id
|
||||
|
||||
def is_active(self) -> bool:
|
||||
if os.getenv('DISABLE_WAITLIST', '').lower() == 'true':
|
||||
logger.info('Waitlist disabled via DISABLE_WAITLIST env var')
|
||||
return False
|
||||
return bool(self.file_users or (self.sheets_client and self.spreadsheet_id))
|
||||
|
||||
def is_user_allowed(self, username: str) -> bool:
|
||||
"""Check if user is allowed based on file and/or sheet configuration"""
|
||||
logger.debug(f'Checking if GitHub user {username} is allowed')
|
||||
if self.file_users:
|
||||
if username.lower() in self.file_users:
|
||||
logger.debug(f'User {username} found in text file allowlist')
|
||||
return True
|
||||
logger.debug(f'User {username} not found in text file allowlist')
|
||||
|
||||
if self.sheets_client and self.spreadsheet_id:
|
||||
sheet_users = [
|
||||
u.lower() for u in self.sheets_client.get_usernames(self.spreadsheet_id)
|
||||
]
|
||||
if username.lower() in sheet_users:
|
||||
logger.debug(f'User {username} found in Google Sheets allowlist')
|
||||
return True
|
||||
logger.debug(f'User {username} not found in Google Sheets allowlist')
|
||||
|
||||
logger.debug(f'User {username} not found in any allowlist')
|
||||
return False
|
||||
|
||||
|
||||
user_verifier = UserVerifier()
|
||||
|
||||
|
||||
def is_user_allowed(user_login: str):
|
||||
if user_verifier.is_active() and not user_verifier.is_user_allowed(user_login):
|
||||
logger.warning(f'GitHub user {user_login} not in allow list')
|
||||
|
||||
@@ -1,11 +1,36 @@
|
||||
import asyncio
|
||||
|
||||
from pydantic import SecretStr
|
||||
from sqlalchemy import select
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.service_types import ProviderType
|
||||
from openhands.server.types import AppMode
|
||||
|
||||
|
||||
async def _user_has_gitlab_provider(user_id: str) -> bool:
|
||||
"""Check if the user has authenticated with GitLab.
|
||||
|
||||
Args:
|
||||
user_id: The Keycloak user ID
|
||||
|
||||
Returns:
|
||||
True if the user has a GitLab provider token, False otherwise
|
||||
"""
|
||||
# Lazy import to avoid circular dependency issues at module load time
|
||||
from storage.auth_tokens import AuthTokens
|
||||
from storage.database import a_session_maker
|
||||
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(AuthTokens).where(
|
||||
AuthTokens.keycloak_user_id == user_id,
|
||||
AuthTokens.identity_provider == ProviderType.GITLAB.value,
|
||||
)
|
||||
)
|
||||
return result.scalars().first() is not None
|
||||
|
||||
|
||||
def schedule_gitlab_repo_sync(
|
||||
user_id: str, keycloak_access_token: SecretStr | None = None
|
||||
) -> None:
|
||||
@@ -14,10 +39,20 @@ def schedule_gitlab_repo_sync(
|
||||
Because the outer call is already a background task, we instruct the service
|
||||
to store repository data synchronously (store_in_background=False) to avoid
|
||||
nested background tasks while still keeping the overall operation async.
|
||||
|
||||
The sync is only performed if the user has authenticated with GitLab.
|
||||
"""
|
||||
|
||||
async def _run():
|
||||
try:
|
||||
# Check if the user has a GitLab provider token before syncing
|
||||
if not await _user_has_gitlab_provider(user_id):
|
||||
logger.debug(
|
||||
'gitlab_repo_sync_skipped: user has no GitLab provider',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
return
|
||||
|
||||
# Lazy import to avoid circular dependency:
|
||||
# middleware -> gitlab_sync -> integrations.gitlab.gitlab_service
|
||||
# -> openhands.integrations.gitlab.gitlab_service -> get_impl
|
||||
|
||||
@@ -18,6 +18,7 @@ from openhands.core.logger import openhands_logger as logger
|
||||
class AssessmentResult:
|
||||
"""Result of a reCAPTCHA Enterprise assessment."""
|
||||
|
||||
name: str
|
||||
score: float
|
||||
valid: bool
|
||||
action_valid: bool
|
||||
@@ -63,6 +64,7 @@ class RecaptchaService:
|
||||
user_ip: str,
|
||||
user_agent: str,
|
||||
email: str | None = None,
|
||||
user_id: str | None = None,
|
||||
) -> AssessmentResult:
|
||||
"""Create a reCAPTCHA Enterprise assessment.
|
||||
|
||||
@@ -72,6 +74,7 @@ class RecaptchaService:
|
||||
user_ip: The user's IP address.
|
||||
user_agent: The user's browser user agent.
|
||||
email: Optional email for Account Defender hashing.
|
||||
user_id: Optional Keycloak user ID for logging correlation.
|
||||
|
||||
Returns:
|
||||
AssessmentResult with score, validity, and allowed status.
|
||||
@@ -100,6 +103,10 @@ class RecaptchaService:
|
||||
|
||||
response = self.client.create_assessment(request)
|
||||
|
||||
# Capture assessment name for potential annotation later
|
||||
# Format: projects/{project_id}/assessments/{assessment_id}
|
||||
assessment_name = response.name
|
||||
|
||||
token_properties = response.token_properties
|
||||
risk_analysis = response.risk_analysis
|
||||
|
||||
@@ -129,6 +136,7 @@ class RecaptchaService:
|
||||
logger.info(
|
||||
'recaptcha_assessment',
|
||||
extra={
|
||||
'assessment_name': assessment_name,
|
||||
'score': score,
|
||||
'valid': valid,
|
||||
'action_valid': action_valid,
|
||||
@@ -137,10 +145,13 @@ class RecaptchaService:
|
||||
'has_suspicious_labels': has_suspicious_labels,
|
||||
'allowed': allowed,
|
||||
'user_ip': user_ip,
|
||||
'user_id': user_id,
|
||||
'email': email,
|
||||
},
|
||||
)
|
||||
|
||||
return AssessmentResult(
|
||||
name=assessment_name,
|
||||
score=score,
|
||||
valid=valid,
|
||||
action_valid=action_valid,
|
||||
|
||||
@@ -18,9 +18,10 @@ from server.auth.token_manager import TokenManager
|
||||
from server.config import get_config
|
||||
from server.logger import logger
|
||||
from server.rate_limit import RateLimiter, create_redis_rate_limiter
|
||||
from sqlalchemy import delete, select
|
||||
from storage.api_key_store import ApiKeyStore
|
||||
from storage.auth_tokens import AuthTokens
|
||||
from storage.database import session_maker
|
||||
from storage.database import a_session_maker
|
||||
from storage.saas_secrets_store import SaasSecretsStore
|
||||
from storage.saas_settings_store import SaasSettingsStore
|
||||
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed
|
||||
@@ -77,6 +78,15 @@ class SaasUserAuth(UserAuth):
|
||||
self.access_token = SecretStr(tokens['access_token'])
|
||||
self.refresh_token = SecretStr(tokens['refresh_token'])
|
||||
self.refreshed = True
|
||||
if not self.email or not self.email_verified or not self.user_id:
|
||||
# We don't need to verify the signature here because we just refreshed
|
||||
# this token from the IDP via token_manager.refresh()
|
||||
access_token_payload = jwt.decode(
|
||||
tokens['access_token'], options={'verify_signature': False}
|
||||
)
|
||||
self.user_id = access_token_payload['sub']
|
||||
self.email = access_token_payload['email']
|
||||
self.email_verified = access_token_payload['email_verified']
|
||||
|
||||
def _is_token_expired(self, token: SecretStr):
|
||||
logger.debug('saas_user_auth_is_token_expired')
|
||||
@@ -103,7 +113,6 @@ class SaasUserAuth(UserAuth):
|
||||
return settings
|
||||
settings_store = await self.get_user_settings_store()
|
||||
settings = await settings_store.load()
|
||||
# If load() returned None, should settings be created?
|
||||
if settings:
|
||||
settings.email = self.email
|
||||
settings.email_verified = self.email_verified
|
||||
@@ -116,7 +125,7 @@ class SaasUserAuth(UserAuth):
|
||||
if secrets_store:
|
||||
return secrets_store
|
||||
user_id = await self.get_user_id()
|
||||
secrets_store = SaasSecretsStore(user_id, session_maker, get_config())
|
||||
secrets_store = SaasSecretsStore(user_id, get_config())
|
||||
self.secrets_store = secrets_store
|
||||
return secrets_store
|
||||
|
||||
@@ -153,12 +162,13 @@ class SaasUserAuth(UserAuth):
|
||||
|
||||
try:
|
||||
# TODO: I think we can do this in a single request if we refactor
|
||||
with session_maker() as session:
|
||||
tokens = (
|
||||
session.query(AuthTokens)
|
||||
.where(AuthTokens.keycloak_user_id == self.user_id)
|
||||
.all()
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(AuthTokens).where(
|
||||
AuthTokens.keycloak_user_id == self.user_id
|
||||
)
|
||||
)
|
||||
tokens = result.scalars().all()
|
||||
|
||||
for token in tokens:
|
||||
idp_type = ProviderType(token.identity_provider)
|
||||
@@ -184,11 +194,11 @@ class SaasUserAuth(UserAuth):
|
||||
'idp_type': token.identity_provider,
|
||||
},
|
||||
)
|
||||
with session_maker() as session:
|
||||
session.query(AuthTokens).filter(
|
||||
AuthTokens.id == token.id
|
||||
).delete()
|
||||
session.commit()
|
||||
async with a_session_maker() as session:
|
||||
await session.execute(
|
||||
delete(AuthTokens).where(AuthTokens.id == token.id)
|
||||
)
|
||||
await session.commit()
|
||||
raise
|
||||
|
||||
self.provider_tokens = MappingProxyType(provider_tokens)
|
||||
@@ -202,15 +212,15 @@ class SaasUserAuth(UserAuth):
|
||||
if settings_store:
|
||||
return settings_store
|
||||
user_id = await self.get_user_id()
|
||||
settings_store = SaasSettingsStore(user_id, session_maker, get_config())
|
||||
settings_store = SaasSettingsStore(user_id, get_config())
|
||||
self.settings_store = settings_store
|
||||
return settings_store
|
||||
|
||||
async def get_mcp_api_key(self) -> str:
|
||||
api_key_store = ApiKeyStore.get_instance()
|
||||
mcp_api_key = api_key_store.retrieve_mcp_api_key(self.user_id)
|
||||
mcp_api_key = await api_key_store.retrieve_mcp_api_key(self.user_id)
|
||||
if not mcp_api_key:
|
||||
mcp_api_key = api_key_store.create_api_key(
|
||||
mcp_api_key = await api_key_store.create_api_key(
|
||||
self.user_id, 'MCP_API_KEY', None
|
||||
)
|
||||
return mcp_api_key
|
||||
@@ -270,15 +280,17 @@ async def saas_user_auth_from_bearer(request: Request) -> SaasUserAuth | None:
|
||||
return None
|
||||
|
||||
api_key_store = ApiKeyStore.get_instance()
|
||||
user_id = api_key_store.validate_api_key(api_key)
|
||||
user_id = await api_key_store.validate_api_key(api_key)
|
||||
if not user_id:
|
||||
return None
|
||||
offline_token = await token_manager.load_offline_token(user_id)
|
||||
return SaasUserAuth(
|
||||
saas_user_auth = SaasUserAuth(
|
||||
user_id=user_id,
|
||||
refresh_token=SecretStr(offline_token),
|
||||
auth_type=AuthType.BEARER,
|
||||
)
|
||||
await saas_user_auth.refresh()
|
||||
return saas_user_auth
|
||||
except Exception as exc:
|
||||
raise BearerTokenError from exc
|
||||
|
||||
@@ -317,7 +329,7 @@ async def saas_user_auth_from_signed_token(signed_token: str) -> SaasUserAuth:
|
||||
email_verified = access_token_payload['email_verified']
|
||||
|
||||
# Check if email domain is blocked
|
||||
if email and domain_blocker.is_domain_blocked(email):
|
||||
if email and await domain_blocker.is_domain_blocked(email):
|
||||
logger.warning(
|
||||
f'Blocked authentication attempt for existing user with email: {email}'
|
||||
)
|
||||
|
||||
@@ -16,9 +16,11 @@ from keycloak.exceptions import (
|
||||
KeycloakError,
|
||||
KeycloakPostError,
|
||||
)
|
||||
from server.auth.auth_error import ExpiredError
|
||||
from server.auth.constants import (
|
||||
BITBUCKET_APP_CLIENT_ID,
|
||||
BITBUCKET_APP_CLIENT_SECRET,
|
||||
DUPLICATE_EMAIL_CHECK,
|
||||
GITHUB_APP_CLIENT_ID,
|
||||
GITHUB_APP_CLIENT_SECRET,
|
||||
GITLAB_APP_CLIENT_ID,
|
||||
@@ -36,9 +38,9 @@ from server.auth.keycloak_manager import get_keycloak_admin, get_keycloak_openid
|
||||
from server.config import get_config
|
||||
from server.logger import logger
|
||||
from sqlalchemy import String as SQLString
|
||||
from sqlalchemy import type_coerce
|
||||
from sqlalchemy import select, type_coerce
|
||||
from storage.auth_token_store import AuthTokenStore
|
||||
from storage.database import session_maker
|
||||
from storage.database import a_session_maker
|
||||
from storage.github_app_installation import GithubAppInstallation
|
||||
from storage.offline_token_store import OfflineTokenStore
|
||||
from tenacity import RetryCallState, retry, retry_if_exception_type, stop_after_attempt
|
||||
@@ -47,6 +49,10 @@ from openhands.integrations.service_types import ProviderType
|
||||
from openhands.server.types import SessionExpiredError
|
||||
from openhands.utils.http_session import httpx_verify_option
|
||||
|
||||
# HTTP timeout for external IDP calls (in seconds)
|
||||
# This prevents indefinite blocking if an IDP is slow or unresponsive
|
||||
IDP_HTTP_TIMEOUT = 15.0
|
||||
|
||||
|
||||
def _before_sleep_callback(retry_state: RetryCallState) -> None:
|
||||
logger.info(f'Retry attempt {retry_state.attempt_number} for Keycloak operation')
|
||||
@@ -200,7 +206,9 @@ class TokenManager:
|
||||
access_token: str,
|
||||
idp: ProviderType,
|
||||
) -> dict[str, str | int]:
|
||||
async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
|
||||
async with httpx.AsyncClient(
|
||||
verify=httpx_verify_option(), timeout=IDP_HTTP_TIMEOUT
|
||||
) as client:
|
||||
base_url = KEYCLOAK_SERVER_URL_EXT if self.external else KEYCLOAK_SERVER_URL
|
||||
url = f'{base_url}/realms/{KEYCLOAK_REALM_NAME}/broker/{idp.value}/token'
|
||||
headers = {
|
||||
@@ -359,7 +367,9 @@ class TokenManager:
|
||||
'refresh_token': refresh_token,
|
||||
'grant_type': 'refresh_token',
|
||||
}
|
||||
async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
|
||||
async with httpx.AsyncClient(
|
||||
verify=httpx_verify_option(), timeout=IDP_HTTP_TIMEOUT
|
||||
) as client:
|
||||
response = await client.post(url, data=payload)
|
||||
response.raise_for_status()
|
||||
logger.info('Successfully refreshed GitHub token')
|
||||
@@ -385,7 +395,9 @@ class TokenManager:
|
||||
'refresh_token': refresh_token,
|
||||
'grant_type': 'refresh_token',
|
||||
}
|
||||
async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
|
||||
async with httpx.AsyncClient(
|
||||
verify=httpx_verify_option(), timeout=IDP_HTTP_TIMEOUT
|
||||
) as client:
|
||||
response = await client.post(url, data=payload)
|
||||
response.raise_for_status()
|
||||
logger.info('Successfully refreshed GitLab token')
|
||||
@@ -413,7 +425,9 @@ class TokenManager:
|
||||
'refresh_token': refresh_token,
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
|
||||
async with httpx.AsyncClient(
|
||||
verify=httpx_verify_option(), timeout=IDP_HTTP_TIMEOUT
|
||||
) as client:
|
||||
response = await client.post(url, data=data, headers=headers)
|
||||
response.raise_for_status()
|
||||
logger.info('Successfully refreshed Bitbucket token')
|
||||
@@ -425,6 +439,8 @@ class TokenManager:
|
||||
access_token = data.get('access_token')
|
||||
refresh_token = data.get('refresh_token')
|
||||
if not access_token or not refresh_token:
|
||||
if data.get('error') == 'bad_refresh_token':
|
||||
raise ExpiredError()
|
||||
raise ValueError(
|
||||
'Failed to refresh token: missing access_token or refresh_token in response.'
|
||||
)
|
||||
@@ -646,6 +662,10 @@ class TokenManager:
|
||||
if not email:
|
||||
return False
|
||||
|
||||
# We have the option to skip the duplicate email check in test environments
|
||||
if not DUPLICATE_EMAIL_CHECK:
|
||||
return False
|
||||
|
||||
base_email = extract_base_email(email)
|
||||
if not base_email:
|
||||
logger.warning(f'Could not extract base email from: {email}')
|
||||
@@ -763,25 +783,24 @@ class TokenManager:
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
def store_org_token(self, installation_id: int, installation_token: str):
|
||||
async def store_org_token(self, installation_id: int, installation_token: str):
|
||||
"""Store a GitHub App installation token.
|
||||
|
||||
Args:
|
||||
installation_id: GitHub installation ID (integer or string)
|
||||
installation_token: The token to store
|
||||
"""
|
||||
with session_maker() as session:
|
||||
async with a_session_maker() as session:
|
||||
# Ensure installation_id is a string
|
||||
str_installation_id = str(installation_id)
|
||||
# Use type_coerce to ensure SQLAlchemy treats the parameter as a string
|
||||
installation = (
|
||||
session.query(GithubAppInstallation)
|
||||
.filter(
|
||||
result = await session.execute(
|
||||
select(GithubAppInstallation).filter(
|
||||
GithubAppInstallation.installation_id
|
||||
== type_coerce(str_installation_id, SQLString)
|
||||
)
|
||||
.first()
|
||||
)
|
||||
installation = result.scalars().first()
|
||||
if installation:
|
||||
installation.encrypted_token = self.encrypt_text(installation_token)
|
||||
else:
|
||||
@@ -791,9 +810,9 @@ class TokenManager:
|
||||
encrypted_token=self.encrypt_text(installation_token),
|
||||
)
|
||||
)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
def load_org_token(self, installation_id: int) -> str | None:
|
||||
async def load_org_token(self, installation_id: int) -> str | None:
|
||||
"""Load a GitHub App installation token.
|
||||
|
||||
Args:
|
||||
@@ -802,17 +821,16 @@ class TokenManager:
|
||||
Returns:
|
||||
The decrypted token if found, None otherwise
|
||||
"""
|
||||
with session_maker() as session:
|
||||
async with a_session_maker() as session:
|
||||
# Ensure installation_id is a string and use type_coerce
|
||||
str_installation_id = str(installation_id)
|
||||
installation = (
|
||||
session.query(GithubAppInstallation)
|
||||
.filter(
|
||||
result = await session.execute(
|
||||
select(GithubAppInstallation).filter(
|
||||
GithubAppInstallation.installation_id
|
||||
== type_coerce(str_installation_id, SQLString)
|
||||
)
|
||||
.first()
|
||||
)
|
||||
installation = result.scalars().first()
|
||||
if not installation:
|
||||
return None
|
||||
token = self.decrypt_text(installation.encrypted_token)
|
||||
|
||||
@@ -8,7 +8,7 @@ import socketio
|
||||
from server.logger import logger
|
||||
from server.utils.conversation_callback_utils import invoke_conversation_callbacks
|
||||
from storage.database import session_maker
|
||||
from storage.stored_conversation_metadata import StoredConversationMetadata
|
||||
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
|
||||
|
||||
from openhands.core.config import LLMConfig
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
@@ -524,16 +524,18 @@ class ClusteredConversationManager(StandaloneConversationManager):
|
||||
)
|
||||
# Look up the user_id from the database
|
||||
with session_maker() as session:
|
||||
conversation_metadata = (
|
||||
session.query(StoredConversationMetadata)
|
||||
conversation_metadata_saas = (
|
||||
session.query(StoredConversationMetadataSaas)
|
||||
.filter(
|
||||
StoredConversationMetadata.conversation_id
|
||||
StoredConversationMetadataSaas.conversation_id
|
||||
== conversation_id
|
||||
)
|
||||
.first()
|
||||
)
|
||||
user_id = (
|
||||
conversation_metadata.user_id if conversation_metadata else None
|
||||
str(conversation_metadata_saas.user_id)
|
||||
if conversation_metadata_saas
|
||||
else None
|
||||
)
|
||||
# Handle the stopped conversation asynchronously
|
||||
asyncio.create_task(
|
||||
|
||||
@@ -15,23 +15,31 @@ IS_FEATURE_ENV = (
|
||||
) # Does not include the staging deployment
|
||||
IS_LOCAL_ENV = bool(HOST == 'localhost')
|
||||
|
||||
# Role name constants
|
||||
ROLE_OWNER = 'owner'
|
||||
ROLE_ADMIN = 'admin'
|
||||
ROLE_MEMBER = 'member'
|
||||
|
||||
# Deprecated - billing margins are now handled internally in litellm
|
||||
DEFAULT_BILLING_MARGIN = float(os.environ.get('DEFAULT_BILLING_MARGIN', '1.0'))
|
||||
|
||||
# Map of user settings versions to their corresponding default LLM models
|
||||
# This ensures that CURRENT_USER_SETTINGS_VERSION and LITELLM_DEFAULT_MODEL stay in sync
|
||||
USER_SETTINGS_VERSION_TO_MODEL = {
|
||||
# This ensures that PERSONAL_WORKSPACE_VERSION_TO_MODEL and LITELLM_DEFAULT_MODEL stay in sync
|
||||
PERSONAL_WORKSPACE_VERSION_TO_MODEL = {
|
||||
1: 'claude-3-5-sonnet-20241022',
|
||||
2: 'claude-3-7-sonnet-20250219',
|
||||
3: 'claude-sonnet-4-20250514',
|
||||
4: 'claude-sonnet-4-20250514',
|
||||
5: 'claude-opus-4-5-20251101',
|
||||
# Minimax is now the default as it gives results close to claude in terms of quality
|
||||
# but at a much lower price
|
||||
5: 'minimax-m2.5',
|
||||
}
|
||||
|
||||
LITELLM_DEFAULT_MODEL = os.getenv('LITELLM_DEFAULT_MODEL')
|
||||
|
||||
# Current user settings version - this should be the latest key in USER_SETTINGS_VERSION_TO_MODEL
|
||||
CURRENT_USER_SETTINGS_VERSION = max(USER_SETTINGS_VERSION_TO_MODEL.keys())
|
||||
ORG_SETTINGS_VERSION = max(PERSONAL_WORKSPACE_VERSION_TO_MODEL.keys())
|
||||
PERSONAL_WORKSPACE_VERSION = max(PERSONAL_WORKSPACE_VERSION_TO_MODEL.keys())
|
||||
|
||||
LITE_LLM_API_URL = os.environ.get(
|
||||
'LITE_LLM_API_URL', 'https://llm-proxy.app.all-hands.dev'
|
||||
@@ -53,9 +61,7 @@ SUBSCRIPTION_PRICE_DATA = {
|
||||
},
|
||||
}
|
||||
|
||||
DEFAULT_INITIAL_BUDGET = float(os.environ.get('DEFAULT_INITIAL_BUDGET', '10'))
|
||||
STRIPE_API_KEY = os.environ.get('STRIPE_API_KEY', None)
|
||||
STRIPE_WEBHOOK_SECRET = os.environ.get('STRIPE_WEBHOOK_SECRET', None)
|
||||
REQUIRE_PAYMENT = os.environ.get('REQUIRE_PAYMENT', '0') in ('1', 'true')
|
||||
|
||||
SLACK_CLIENT_ID = os.environ.get('SLACK_CLIENT_ID', None)
|
||||
@@ -91,5 +97,5 @@ def get_default_litellm_model():
|
||||
"""Construct proxy for litellm model based on user settings if not set explicitly."""
|
||||
if LITELLM_DEFAULT_MODEL:
|
||||
return LITELLM_DEFAULT_MODEL
|
||||
model = USER_SETTINGS_VERSION_TO_MODEL[CURRENT_USER_SETTINGS_VERSION]
|
||||
model = PERSONAL_WORKSPACE_VERSION_TO_MODEL[PERSONAL_WORKSPACE_VERSION]
|
||||
return build_litellm_proxy_model_path(model)
|
||||
|
||||
@@ -3,7 +3,6 @@ from datetime import datetime
|
||||
|
||||
from integrations.github.github_manager import GithubManager
|
||||
from integrations.github.github_view import GithubViewType
|
||||
from integrations.models import Message, SourceType
|
||||
from integrations.utils import (
|
||||
extract_summary_from_conversation_manager,
|
||||
get_summary_instruction,
|
||||
@@ -14,7 +13,6 @@ from storage.conversation_callback import (
|
||||
ConversationCallback,
|
||||
ConversationCallbackProcessor,
|
||||
)
|
||||
from storage.database import session_maker
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.schema.agent import AgentState
|
||||
@@ -36,16 +34,12 @@ class GithubCallbackProcessor(ConversationCallbackProcessor):
|
||||
send_summary_instruction: bool = True
|
||||
|
||||
async def _send_message_to_github(self, message: str) -> None:
|
||||
"""
|
||||
Send a message to GitHub.
|
||||
"""Send a message to GitHub.
|
||||
|
||||
Args:
|
||||
message: The message content to send to GitHub
|
||||
"""
|
||||
try:
|
||||
# Create a message object for GitHub
|
||||
message_obj = Message(source=SourceType.OPENHANDS, message=message)
|
||||
|
||||
# Get the token manager
|
||||
token_manager = TokenManager()
|
||||
|
||||
@@ -54,8 +48,8 @@ class GithubCallbackProcessor(ConversationCallbackProcessor):
|
||||
|
||||
github_manager = GithubManager(token_manager, GitHubDataCollector())
|
||||
|
||||
# Send the message
|
||||
await github_manager.send_message(message_obj, self.github_view)
|
||||
# Send the message directly as a string
|
||||
await github_manager.send_message(message, self.github_view)
|
||||
|
||||
logger.info(
|
||||
f'[GitHub] Sent summary message to {self.github_view.full_repo_name}#{self.github_view.issue_number}'
|
||||
@@ -108,13 +102,10 @@ class GithubCallbackProcessor(ConversationCallbackProcessor):
|
||||
f'[GitHub] Sent summary instruction to conversation {conversation_id} {summary_event}'
|
||||
)
|
||||
|
||||
# Update the processor state
|
||||
# Update the processor state - the outer session will commit this
|
||||
self.send_summary_instruction = False
|
||||
callback.set_processor(self)
|
||||
callback.updated_at = datetime.now()
|
||||
with session_maker() as session:
|
||||
session.merge(callback)
|
||||
session.commit()
|
||||
return
|
||||
|
||||
# Extract the summary from the event store
|
||||
@@ -130,14 +121,15 @@ class GithubCallbackProcessor(ConversationCallbackProcessor):
|
||||
|
||||
logger.info(f'[GitHub] Summary sent for conversation {conversation_id}')
|
||||
|
||||
# Mark callback as completed status
|
||||
# Mark callback as completed status - the outer session will commit this
|
||||
callback.status = CallbackStatus.COMPLETED
|
||||
callback.updated_at = datetime.now()
|
||||
with session_maker() as session:
|
||||
session.merge(callback)
|
||||
session.commit()
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f'[GitHub] Error processing conversation callback: {str(e)}'
|
||||
)
|
||||
# Mark callback as error to prevent infinite re-invocation
|
||||
# The outer session will commit this
|
||||
callback.status = CallbackStatus.ERROR
|
||||
callback.updated_at = datetime.now()
|
||||
|
||||
@@ -3,7 +3,6 @@ from datetime import datetime
|
||||
|
||||
from integrations.gitlab.gitlab_manager import GitlabManager
|
||||
from integrations.gitlab.gitlab_view import GitlabViewType
|
||||
from integrations.models import Message, SourceType
|
||||
from integrations.utils import (
|
||||
extract_summary_from_conversation_manager,
|
||||
get_summary_instruction,
|
||||
@@ -14,7 +13,7 @@ from storage.conversation_callback import (
|
||||
ConversationCallback,
|
||||
ConversationCallbackProcessor,
|
||||
)
|
||||
from storage.database import session_maker
|
||||
from storage.database import a_session_maker
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.schema.agent import AgentState
|
||||
@@ -28,8 +27,7 @@ gitlab_manager = GitlabManager(token_manager)
|
||||
|
||||
|
||||
class GitlabCallbackProcessor(ConversationCallbackProcessor):
|
||||
"""
|
||||
Processor for sending conversation summaries to GitLab.
|
||||
"""Processor for sending conversation summaries to GitLab.
|
||||
|
||||
This processor is used to send summaries of conversations to GitLab
|
||||
when agent state changes occur.
|
||||
@@ -39,22 +37,18 @@ class GitlabCallbackProcessor(ConversationCallbackProcessor):
|
||||
send_summary_instruction: bool = True
|
||||
|
||||
async def _send_message_to_gitlab(self, message: str) -> None:
|
||||
"""
|
||||
Send a message to GitLab.
|
||||
"""Send a message to GitLab.
|
||||
|
||||
Args:
|
||||
message: The message content to send to GitLab
|
||||
"""
|
||||
try:
|
||||
# Create a message object for GitHub
|
||||
message_obj = Message(source=SourceType.OPENHANDS, message=message)
|
||||
|
||||
# Get the token manager
|
||||
token_manager = TokenManager()
|
||||
gitlab_manager = GitlabManager(token_manager)
|
||||
|
||||
# Send the message
|
||||
await gitlab_manager.send_message(message_obj, self.gitlab_view)
|
||||
# Send the message directly as a string
|
||||
await gitlab_manager.send_message(message, self.gitlab_view)
|
||||
|
||||
logger.info(
|
||||
f'[GitLab] Sent summary message to {self.gitlab_view.full_repo_name}#{self.gitlab_view.issue_number}'
|
||||
@@ -111,9 +105,9 @@ class GitlabCallbackProcessor(ConversationCallbackProcessor):
|
||||
self.send_summary_instruction = False
|
||||
callback.set_processor(self)
|
||||
callback.updated_at = datetime.now()
|
||||
with session_maker() as session:
|
||||
async with a_session_maker() as session:
|
||||
session.merge(callback)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
return
|
||||
|
||||
# Extract the summary from the event store
|
||||
@@ -132,9 +126,9 @@ class GitlabCallbackProcessor(ConversationCallbackProcessor):
|
||||
# Mark callback as completed status
|
||||
callback.status = CallbackStatus.COMPLETED
|
||||
callback.updated_at = datetime.now()
|
||||
with session_maker() as session:
|
||||
async with a_session_maker() as session:
|
||||
session.merge(callback)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
|
||||
@@ -37,8 +37,7 @@ class JiraCallbackProcessor(ConversationCallbackProcessor):
|
||||
workspace_name: str
|
||||
|
||||
async def _send_comment_to_jira(self, message: str) -> None:
|
||||
"""
|
||||
Send a comment to Jira issue.
|
||||
"""Send a comment to Jira issue.
|
||||
|
||||
Args:
|
||||
message: The message content to send to Jira
|
||||
@@ -59,8 +58,9 @@ class JiraCallbackProcessor(ConversationCallbackProcessor):
|
||||
# Decrypt API key
|
||||
api_key = jira_manager.token_manager.decrypt_text(workspace.svc_acc_api_key)
|
||||
|
||||
# Send comment directly as a string
|
||||
await jira_manager.send_message(
|
||||
jira_manager.create_outgoing_message(msg=message),
|
||||
message,
|
||||
issue_key=self.issue_key,
|
||||
jira_cloud_id=workspace.jira_cloud_id,
|
||||
svc_acc_email=workspace.svc_acc_email,
|
||||
|
||||
@@ -37,8 +37,7 @@ class JiraDcCallbackProcessor(ConversationCallbackProcessor):
|
||||
base_api_url: str
|
||||
|
||||
async def _send_comment_to_jira_dc(self, message: str) -> None:
|
||||
"""
|
||||
Send a comment to Jira DC issue.
|
||||
"""Send a comment to Jira DC issue.
|
||||
|
||||
Args:
|
||||
message: The message content to send to Jira DC
|
||||
@@ -61,8 +60,9 @@ class JiraDcCallbackProcessor(ConversationCallbackProcessor):
|
||||
workspace.svc_acc_api_key
|
||||
)
|
||||
|
||||
# Send comment directly as a string
|
||||
await jira_dc_manager.send_message(
|
||||
jira_dc_manager.create_outgoing_message(msg=message),
|
||||
message,
|
||||
issue_key=self.issue_key,
|
||||
base_api_url=self.base_api_url,
|
||||
svc_acc_api_key=api_key,
|
||||
|
||||
@@ -36,8 +36,7 @@ class LinearCallbackProcessor(ConversationCallbackProcessor):
|
||||
workspace_name: str
|
||||
|
||||
async def _send_comment_to_linear(self, message: str) -> None:
|
||||
"""
|
||||
Send a comment to Linear issue.
|
||||
"""Send a comment to Linear issue.
|
||||
|
||||
Args:
|
||||
message: The message content to send to Linear
|
||||
@@ -60,9 +59,9 @@ class LinearCallbackProcessor(ConversationCallbackProcessor):
|
||||
workspace.svc_acc_api_key
|
||||
)
|
||||
|
||||
# Send comment
|
||||
# Send comment directly as a string
|
||||
await linear_manager.send_message(
|
||||
linear_manager.create_outgoing_message(msg=message),
|
||||
message,
|
||||
self.issue_id,
|
||||
api_key,
|
||||
)
|
||||
|
||||
@@ -26,8 +26,7 @@ slack_manager = SlackManager(token_manager)
|
||||
|
||||
|
||||
class SlackCallbackProcessor(ConversationCallbackProcessor):
|
||||
"""
|
||||
Processor for sending conversation summaries to Slack.
|
||||
"""Processor for sending conversation summaries to Slack.
|
||||
|
||||
This processor is used to send summaries of conversations to Slack channels
|
||||
when agent state changes occur.
|
||||
@@ -41,14 +40,13 @@ class SlackCallbackProcessor(ConversationCallbackProcessor):
|
||||
last_user_msg_id: int | None = None
|
||||
|
||||
async def _send_message_to_slack(self, message: str) -> None:
|
||||
"""
|
||||
Send a message to Slack using the conversation_manager's send_to_event_stream method.
|
||||
"""Send a message to Slack.
|
||||
|
||||
Args:
|
||||
message: The message content to send to Slack
|
||||
"""
|
||||
try:
|
||||
# Create a message object for Slack
|
||||
# Create a message object for Slack view creation (incoming message format)
|
||||
message_obj = Message(
|
||||
source=SourceType.SLACK,
|
||||
message={
|
||||
@@ -67,9 +65,8 @@ class SlackCallbackProcessor(ConversationCallbackProcessor):
|
||||
slack_view = SlackFactory.create_slack_view_from_payload(
|
||||
message_obj, slack_user, saas_user_auth
|
||||
)
|
||||
await slack_manager.send_message(
|
||||
slack_manager.create_outgoing_message(message), slack_view
|
||||
)
|
||||
# Send the message directly as a string
|
||||
await slack_manager.send_message(message, slack_view)
|
||||
|
||||
logger.info(
|
||||
f'[Slack] Sent summary message to channel {self.channel_id} '
|
||||
|
||||
68
enterprise/server/email_validation.py
Normal file
68
enterprise/server/email_validation.py
Normal file
@@ -0,0 +1,68 @@
|
||||
"""
|
||||
Email domain validation utilities for enterprise endpoints.
|
||||
"""
|
||||
|
||||
from fastapi import Depends, HTTPException, Request, status
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.server.user_auth import get_user_auth, get_user_id
|
||||
|
||||
|
||||
async def get_admin_user_id(
|
||||
request: Request, user_id: str | None = Depends(get_user_id)
|
||||
) -> str:
|
||||
"""
|
||||
Dependency that validates user has @openhands.dev email domain.
|
||||
|
||||
This dependency can be used in place of get_user_id for endpoints that
|
||||
should only be accessible to admin users. Currently, this is implemented
|
||||
by checking for @openhands.dev email domain.
|
||||
|
||||
TODO: In the future, this should be replaced with an explicit is_admin flag
|
||||
in user/org settings instead of relying on email domain validation.
|
||||
|
||||
Args:
|
||||
request: FastAPI request object
|
||||
user_id: User ID from get_user_id dependency
|
||||
|
||||
Returns:
|
||||
str: User ID if email domain is valid
|
||||
|
||||
Raises:
|
||||
HTTPException: 403 if email domain is not @openhands.dev
|
||||
HTTPException: 401 if user is not authenticated
|
||||
|
||||
Example:
|
||||
@router.post('/endpoint')
|
||||
async def create_resource(
|
||||
user_id: str = Depends(get_admin_user_id),
|
||||
):
|
||||
# Only admin users can access this endpoint
|
||||
pass
|
||||
"""
|
||||
if not user_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail='User not authenticated',
|
||||
)
|
||||
|
||||
user_auth = await get_user_auth(request)
|
||||
user_email = await user_auth.get_user_email()
|
||||
|
||||
if not user_email:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail='User email not available',
|
||||
)
|
||||
|
||||
if not user_email.endswith('@openhands.dev'):
|
||||
logger.warning(
|
||||
'Access denied - invalid email domain',
|
||||
extra={'user_id': user_id, 'email_domain': user_email.split('@')[-1]},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail='Access restricted to @openhands.dev users',
|
||||
)
|
||||
|
||||
return user_id
|
||||
@@ -51,6 +51,14 @@ def custom_json_serializer(obj, **kwargs):
|
||||
obj['stack_info'] = format_stack(stack_info)
|
||||
|
||||
result = json.dumps(obj, **kwargs)
|
||||
|
||||
# Swap out newlines to make things easier to read. This will produce
|
||||
# invalid json but means we can have similar logs in local development
|
||||
# to production, making things easier to correlate. Obviously,
|
||||
# LOG_JSON_FOR_CONSOLE should not be used in production environments.
|
||||
if LOG_JSON_FOR_CONSOLE:
|
||||
result = result.replace('\\n', '\n')
|
||||
|
||||
return result
|
||||
|
||||
|
||||
|
||||
@@ -44,11 +44,13 @@ class MyProcessor(MaintenanceTaskProcessor):
|
||||
### UserVersionUpgradeProcessor
|
||||
|
||||
Located in `user_version_upgrade_processor.py`, this processor:
|
||||
|
||||
- Handles up to 100 user IDs per task
|
||||
- Upgrades users with `user_version < CURRENT_USER_SETTINGS_VERSION`
|
||||
- Upgrades users with `user_version < ORG_SETTINGS_VERSION`
|
||||
- Uses `SaasSettingsStore.create_default_settings()` for upgrades
|
||||
|
||||
**Usage:**
|
||||
|
||||
```python
|
||||
from server.maintenance_task_processor.user_version_upgrade_processor import UserVersionUpgradeProcessor
|
||||
|
||||
@@ -144,22 +146,26 @@ task = create_maintenance_task(
|
||||
## Best Practices
|
||||
|
||||
### Processor Design
|
||||
|
||||
- Keep tasks short-running (under 1 minute)
|
||||
- Handle errors gracefully and return meaningful error information
|
||||
- Use batch processing for large datasets
|
||||
- Include progress information in the return dict
|
||||
|
||||
### Error Handling
|
||||
|
||||
- Always wrap your processor logic in try-catch blocks
|
||||
- Return structured error information
|
||||
- Log important events for debugging
|
||||
|
||||
### Performance
|
||||
|
||||
- Limit batch sizes to avoid long-running tasks
|
||||
- Use database sessions efficiently
|
||||
- Consider memory usage for large datasets
|
||||
|
||||
### Testing
|
||||
|
||||
- Create unit tests for your processors
|
||||
- Test error conditions
|
||||
- Verify the processor serialization/deserialization works correctly
|
||||
@@ -167,6 +173,7 @@ task = create_maintenance_task(
|
||||
## Database Patterns
|
||||
|
||||
The maintenance task system follows the repository's established patterns:
|
||||
|
||||
- Uses `session_maker()` for database operations
|
||||
- Wraps sync database operations in `call_sync_from_async` for async routes
|
||||
- Follows proper SQLAlchemy query patterns
|
||||
@@ -174,15 +181,18 @@ The maintenance task system follows the repository's established patterns:
|
||||
## Integration with Existing Systems
|
||||
|
||||
### User Management
|
||||
|
||||
- Integrates with the existing `UserSettings` model
|
||||
- Uses the current user versioning system (`CURRENT_USER_SETTINGS_VERSION`)
|
||||
- Uses the current user versioning system (`ORG_SETTINGS_VERSION`)
|
||||
- Maintains compatibility with existing user management workflows
|
||||
|
||||
### Authentication
|
||||
|
||||
- Admin endpoints use the existing SaaS authentication system
|
||||
- Requires users to have `admin = True` in their UserSettings
|
||||
|
||||
### Monitoring
|
||||
|
||||
- Tasks are logged with structured information
|
||||
- Status updates are tracked in the database
|
||||
- Error information is preserved for debugging
|
||||
@@ -206,6 +216,7 @@ The maintenance task system follows the repository's established patterns:
|
||||
## Future Enhancements
|
||||
|
||||
Potential improvements that could be added:
|
||||
|
||||
- Task dependencies and scheduling
|
||||
- Retry mechanisms for failed tasks
|
||||
- Real-time progress updates
|
||||
|
||||
@@ -1,155 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List
|
||||
|
||||
from server.constants import CURRENT_USER_SETTINGS_VERSION
|
||||
from server.logger import logger
|
||||
from storage.database import session_maker
|
||||
from storage.maintenance_task import MaintenanceTask, MaintenanceTaskProcessor
|
||||
from storage.saas_settings_store import SaasSettingsStore
|
||||
from storage.user_settings import UserSettings
|
||||
|
||||
from openhands.core.config import load_openhands_config
|
||||
|
||||
|
||||
class UserVersionUpgradeProcessor(MaintenanceTaskProcessor):
|
||||
"""
|
||||
Processor for upgrading user settings to the current version.
|
||||
|
||||
This processor takes a list of user IDs and upgrades any users
|
||||
whose user_version is less than CURRENT_USER_SETTINGS_VERSION.
|
||||
"""
|
||||
|
||||
user_ids: List[str]
|
||||
|
||||
async def __call__(self, task: MaintenanceTask) -> dict:
|
||||
"""
|
||||
Process user version upgrades for the specified user IDs.
|
||||
|
||||
Args:
|
||||
task: The maintenance task being processed
|
||||
|
||||
Returns:
|
||||
dict: Results containing successful and failed user IDs
|
||||
"""
|
||||
logger.info(
|
||||
'user_version_upgrade_processor:start',
|
||||
extra={
|
||||
'task_id': task.id,
|
||||
'user_count': len(self.user_ids),
|
||||
'current_version': CURRENT_USER_SETTINGS_VERSION,
|
||||
},
|
||||
)
|
||||
|
||||
if len(self.user_ids) > 100:
|
||||
raise ValueError(
|
||||
f'Too many user IDs: {len(self.user_ids)}. Maximum is 100.'
|
||||
)
|
||||
|
||||
config = load_openhands_config()
|
||||
|
||||
# Track results
|
||||
successful_upgrades = []
|
||||
failed_upgrades = []
|
||||
users_already_current = []
|
||||
|
||||
# Find users that need upgrading
|
||||
with session_maker() as session:
|
||||
users_to_upgrade = (
|
||||
session.query(UserSettings)
|
||||
.filter(
|
||||
UserSettings.keycloak_user_id.in_(self.user_ids),
|
||||
UserSettings.user_version < CURRENT_USER_SETTINGS_VERSION,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
# Track users that are already current
|
||||
users_needing_upgrade_ids = {u.keycloak_user_id for u in users_to_upgrade}
|
||||
users_already_current = [
|
||||
uid for uid in self.user_ids if uid not in users_needing_upgrade_ids
|
||||
]
|
||||
|
||||
logger.info(
|
||||
'user_version_upgrade_processor:found_users',
|
||||
extra={
|
||||
'task_id': task.id,
|
||||
'users_to_upgrade': len(users_to_upgrade),
|
||||
'users_already_current': len(users_already_current),
|
||||
'total_requested': len(self.user_ids),
|
||||
},
|
||||
)
|
||||
|
||||
# Process each user that needs upgrading
|
||||
for user_settings in users_to_upgrade:
|
||||
user_id = user_settings.keycloak_user_id
|
||||
old_version = user_settings.user_version
|
||||
|
||||
try:
|
||||
logger.info(
|
||||
'user_version_upgrade_processor:upgrading_user',
|
||||
extra={
|
||||
'task_id': task.id,
|
||||
'user_id': user_id,
|
||||
'old_version': old_version,
|
||||
'new_version': CURRENT_USER_SETTINGS_VERSION,
|
||||
},
|
||||
)
|
||||
|
||||
# Create SaasSettingsStore instance and upgrade
|
||||
settings_store = await SaasSettingsStore.get_instance(config, user_id)
|
||||
await settings_store.create_default_settings(user_settings)
|
||||
|
||||
successful_upgrades.append(
|
||||
{
|
||||
'user_id': user_id,
|
||||
'old_version': old_version,
|
||||
'new_version': CURRENT_USER_SETTINGS_VERSION,
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(
|
||||
'user_version_upgrade_processor:user_upgraded',
|
||||
extra={
|
||||
'task_id': task.id,
|
||||
'user_id': user_id,
|
||||
'old_version': old_version,
|
||||
'new_version': CURRENT_USER_SETTINGS_VERSION,
|
||||
},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
failed_upgrades.append(
|
||||
{'user_id': user_id, 'old_version': old_version, 'error': str(e)}
|
||||
)
|
||||
|
||||
logger.error(
|
||||
'user_version_upgrade_processor:user_upgrade_failed',
|
||||
extra={
|
||||
'task_id': task.id,
|
||||
'user_id': user_id,
|
||||
'old_version': old_version,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
|
||||
# Create result summary
|
||||
result = {
|
||||
'total_users': len(self.user_ids),
|
||||
'users_already_current': users_already_current,
|
||||
'successful_upgrades': successful_upgrades,
|
||||
'failed_upgrades': failed_upgrades,
|
||||
'summary': (
|
||||
f'Processed {len(self.user_ids)} users: '
|
||||
f'{len(successful_upgrades)} upgraded, '
|
||||
f'{len(users_already_current)} already current, '
|
||||
f'{len(failed_upgrades)} errors'
|
||||
),
|
||||
}
|
||||
|
||||
logger.info(
|
||||
'user_version_upgrade_processor:completed',
|
||||
extra={'task_id': task.id, 'result': result},
|
||||
)
|
||||
|
||||
return result
|
||||
@@ -1,7 +1,5 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from storage.api_key_store import ApiKeyStore
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
|
||||
@@ -24,7 +22,7 @@ from openhands.core.logger import openhands_logger as logger
|
||||
# NOTE: these details are specific to the MCP protocol
|
||||
class SaaSOpenHandsMCPConfig(OpenHandsMCPConfig):
|
||||
@staticmethod
|
||||
def create_default_mcp_server_config(
|
||||
async def create_default_mcp_server_config(
|
||||
host: str, config: 'OpenHandsConfig', user_id: str | None = None
|
||||
) -> tuple[MCPSHTTPServerConfig | None, list[MCPStdioServerConfig]]:
|
||||
"""
|
||||
@@ -36,13 +34,16 @@ class SaaSOpenHandsMCPConfig(OpenHandsMCPConfig):
|
||||
Returns:
|
||||
A tuple containing the default SSE server configuration and a list of MCP stdio server configurations
|
||||
"""
|
||||
from storage.api_key_store import ApiKeyStore
|
||||
|
||||
api_key_store = ApiKeyStore.get_instance()
|
||||
if user_id:
|
||||
api_key = api_key_store.retrieve_mcp_api_key(user_id)
|
||||
api_key = await api_key_store.retrieve_mcp_api_key(user_id)
|
||||
|
||||
if not api_key:
|
||||
api_key = api_key_store.create_api_key(user_id, 'MCP_API_KEY', None)
|
||||
api_key = await api_key_store.create_api_key(
|
||||
user_id, 'MCP_API_KEY', None
|
||||
)
|
||||
|
||||
if not api_key:
|
||||
logger.error(f'Could not provision MCP API Key for user: {user_id}')
|
||||
|
||||
@@ -103,11 +103,13 @@ class SetAuthCookieMiddleware:
|
||||
keycloak_auth_cookie = request.cookies.get('keycloak_auth')
|
||||
auth_header = request.headers.get('Authorization')
|
||||
mcp_auth_header = request.headers.get('X-Session-API-Key')
|
||||
accepted_tos = False
|
||||
api_auth_header = request.headers.get('X-Access-Token')
|
||||
accepted_tos: bool | None = False
|
||||
if (
|
||||
keycloak_auth_cookie is None
|
||||
and (auth_header is None or not auth_header.startswith('Bearer '))
|
||||
and mcp_auth_header is None
|
||||
and api_auth_header is None
|
||||
):
|
||||
raise NoCredentialsError
|
||||
|
||||
@@ -144,7 +146,7 @@ class SetAuthCookieMiddleware:
|
||||
# "if accepted_tos is not None" as there should not be any users with
|
||||
# accepted_tos equal to "None"
|
||||
if accepted_tos is False and request.url.path != '/api/accept_tos':
|
||||
logger.error('User has not accepted the terms of service')
|
||||
logger.warning('User has not accepted the terms of service')
|
||||
raise TosNotAcceptedError
|
||||
|
||||
def _should_attach(self, request: Request) -> bool:
|
||||
@@ -160,8 +162,10 @@ class SetAuthCookieMiddleware:
|
||||
'/api/billing/customer-setup-success',
|
||||
'/api/billing/stripe-webhook',
|
||||
'/api/email/resend',
|
||||
'/api/organizations/members/invite/accept',
|
||||
'/oauth/device/authorize',
|
||||
'/oauth/device/token',
|
||||
'/api/v1/web-client/config',
|
||||
)
|
||||
if path in ignore_paths:
|
||||
return False
|
||||
@@ -172,6 +176,10 @@ class SetAuthCookieMiddleware:
|
||||
):
|
||||
return False
|
||||
|
||||
# Webhooks access is controlled using separate API keys
|
||||
if path.startswith('/api/v1/webhooks/'):
|
||||
return False
|
||||
|
||||
is_mcp = path.startswith('/mcp')
|
||||
is_api_route = path.startswith('/api')
|
||||
return is_api_route or is_mcp
|
||||
|
||||
@@ -1,113 +1,87 @@
|
||||
from datetime import UTC, datetime
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from pydantic import BaseModel, field_validator
|
||||
from server.config import get_config
|
||||
from server.constants import (
|
||||
BYOR_KEY_VERIFICATION_TIMEOUT,
|
||||
LITE_LLM_API_KEY,
|
||||
LITE_LLM_API_URL,
|
||||
)
|
||||
from storage.api_key import ApiKey
|
||||
from storage.api_key_store import ApiKeyStore
|
||||
from storage.database import session_maker
|
||||
from storage.saas_settings_store import SaasSettingsStore
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
from storage.org_member import OrgMember
|
||||
from storage.org_member_store import OrgMemberStore
|
||||
from storage.org_service import OrgService
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.server.user_auth import get_user_id
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
from openhands.utils.http_session import httpx_verify_option
|
||||
|
||||
|
||||
# Helper functions for BYOR API key management
|
||||
async def get_byor_key_from_db(user_id: str) -> str | None:
|
||||
"""Get the BYOR key from the database for a user."""
|
||||
config = get_config()
|
||||
settings_store = SaasSettingsStore(
|
||||
user_id=user_id, session_maker=session_maker, config=config
|
||||
)
|
||||
user = await UserStore.get_user_by_id_async(user_id)
|
||||
if not user:
|
||||
return None
|
||||
|
||||
user_db_settings = await call_sync_from_async(
|
||||
settings_store.get_user_settings_by_keycloak_id, user_id
|
||||
)
|
||||
if user_db_settings and user_db_settings.llm_api_key_for_byor:
|
||||
return user_db_settings.llm_api_key_for_byor
|
||||
current_org_id = user.current_org_id
|
||||
current_org_member: OrgMember = None
|
||||
for org_member in user.org_members:
|
||||
if org_member.org_id == current_org_id:
|
||||
current_org_member = org_member
|
||||
break
|
||||
if not current_org_member:
|
||||
return None
|
||||
if current_org_member.llm_api_key_for_byor:
|
||||
return current_org_member.llm_api_key_for_byor.get_secret_value()
|
||||
return None
|
||||
|
||||
|
||||
async def store_byor_key_in_db(user_id: str, key: str) -> None:
|
||||
"""Store the BYOR key in the database for a user."""
|
||||
config = get_config()
|
||||
settings_store = SaasSettingsStore(
|
||||
user_id=user_id, session_maker=session_maker, config=config
|
||||
)
|
||||
user = await UserStore.get_user_by_id_async(user_id)
|
||||
if not user:
|
||||
return None
|
||||
|
||||
def _update_user_settings():
|
||||
with session_maker() as session:
|
||||
user_db_settings = settings_store.get_user_settings_by_keycloak_id(
|
||||
user_id, session
|
||||
)
|
||||
if user_db_settings:
|
||||
user_db_settings.llm_api_key_for_byor = key
|
||||
session.commit()
|
||||
logger.info(
|
||||
'Successfully stored BYOR key in user settings',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
'User settings not found when trying to store BYOR key',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
|
||||
await call_sync_from_async(_update_user_settings)
|
||||
current_org_id = user.current_org_id
|
||||
current_org_member: OrgMember = None
|
||||
for org_member in user.org_members:
|
||||
if org_member.org_id == current_org_id:
|
||||
current_org_member = org_member
|
||||
break
|
||||
if not current_org_member:
|
||||
return None
|
||||
current_org_member.llm_api_key_for_byor = key
|
||||
OrgMemberStore.update_org_member(current_org_member)
|
||||
|
||||
|
||||
async def generate_byor_key(user_id: str) -> str | None:
|
||||
"""Generate a new BYOR key for a user."""
|
||||
if not (LITE_LLM_API_KEY and LITE_LLM_API_URL):
|
||||
logger.warning(
|
||||
'LiteLLM API configuration not found', extra={'user_id': user_id}
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(
|
||||
verify=httpx_verify_option(),
|
||||
headers={
|
||||
'x-goog-api-key': LITE_LLM_API_KEY,
|
||||
},
|
||||
) as client:
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/key/generate',
|
||||
json={
|
||||
user = await UserStore.get_user_by_id_async(user_id)
|
||||
if not user:
|
||||
return None
|
||||
current_org_id = str(user.current_org_id)
|
||||
key = await LiteLlmManager.generate_key(
|
||||
user_id,
|
||||
current_org_id,
|
||||
f'BYOR Key - user {user_id}, org {current_org_id}',
|
||||
{'type': 'byor'},
|
||||
)
|
||||
|
||||
if key:
|
||||
logger.info(
|
||||
'Successfully generated new BYOR key',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'metadata': {'type': 'byor'},
|
||||
'key_alias': f'BYOR Key - user {user_id}',
|
||||
'key_length': len(key) if key else 0,
|
||||
'key_prefix': key[:10] + '...' if key and len(key) > 10 else key,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
response_json = response.json()
|
||||
key = response_json.get('key')
|
||||
|
||||
if key:
|
||||
logger.info(
|
||||
'Successfully generated new BYOR key',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'key_length': len(key) if key else 0,
|
||||
'key_prefix': key[:10] + '...'
|
||||
if key and len(key) > 10
|
||||
else key,
|
||||
},
|
||||
)
|
||||
return key
|
||||
else:
|
||||
logger.error(
|
||||
'Failed to generate BYOR LLM API key - no key in response',
|
||||
extra={'user_id': user_id, 'response_json': response_json},
|
||||
)
|
||||
return None
|
||||
return key
|
||||
else:
|
||||
logger.error(
|
||||
'Failed to generate BYOR LLM API key - no key in response',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'Error generating BYOR key',
|
||||
@@ -116,96 +90,25 @@ async def generate_byor_key(user_id: str) -> str | None:
|
||||
return None
|
||||
|
||||
|
||||
async def verify_byor_key_in_litellm(byor_key: str, user_id: str) -> bool:
|
||||
"""Verify that a BYOR key is valid in LiteLLM by making a lightweight API call.
|
||||
|
||||
Args:
|
||||
byor_key: The BYOR key to verify
|
||||
user_id: The user ID for logging purposes
|
||||
|
||||
Returns:
|
||||
True if the key is verified as valid, False if verification fails or key is invalid.
|
||||
Returns False on network errors/timeouts to ensure we don't return potentially invalid keys.
|
||||
"""
|
||||
if not (LITE_LLM_API_URL and byor_key):
|
||||
return False
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(
|
||||
verify=httpx_verify_option(),
|
||||
timeout=BYOR_KEY_VERIFICATION_TIMEOUT,
|
||||
) as client:
|
||||
# Make a lightweight request to verify the key
|
||||
# Using /v1/models endpoint as it's lightweight and requires authentication
|
||||
response = await client.get(
|
||||
f'{LITE_LLM_API_URL}/v1/models',
|
||||
headers={
|
||||
'Authorization': f'Bearer {byor_key}',
|
||||
},
|
||||
)
|
||||
|
||||
# Only 200 status code indicates valid key
|
||||
if response.status_code == 200:
|
||||
logger.debug(
|
||||
'BYOR key verification successful',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
return True
|
||||
|
||||
# All other status codes (401, 403, 500, etc.) are treated as invalid
|
||||
# This includes authentication errors and server errors
|
||||
logger.warning(
|
||||
'BYOR key verification failed - treating as invalid',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'status_code': response.status_code,
|
||||
'key_prefix': byor_key[:10] + '...'
|
||||
if len(byor_key) > 10
|
||||
else byor_key,
|
||||
},
|
||||
)
|
||||
return False
|
||||
|
||||
except (httpx.TimeoutException, Exception) as e:
|
||||
# Any exception (timeout, network error, etc.) means we can't verify
|
||||
# Return False to trigger regeneration rather than returning potentially invalid key
|
||||
logger.warning(
|
||||
'BYOR key verification error - treating as invalid to ensure key validity',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'error': str(e),
|
||||
'error_type': type(e).__name__,
|
||||
},
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
async def delete_byor_key_from_litellm(user_id: str, byor_key: str) -> bool:
|
||||
"""Delete the BYOR key from LiteLLM using the key directly."""
|
||||
if not (LITE_LLM_API_KEY and LITE_LLM_API_URL):
|
||||
logger.warning(
|
||||
'LiteLLM API configuration not found', extra={'user_id': user_id}
|
||||
)
|
||||
return False
|
||||
"""Delete the BYOR key from LiteLLM using the key directly.
|
||||
|
||||
Also attempts to delete by key alias if the key is not found,
|
||||
to clean up orphaned aliases that could block key regeneration.
|
||||
"""
|
||||
try:
|
||||
async with httpx.AsyncClient(
|
||||
verify=httpx_verify_option(),
|
||||
headers={
|
||||
'x-goog-api-key': LITE_LLM_API_KEY,
|
||||
},
|
||||
) as client:
|
||||
# Delete the key directly using the key value
|
||||
delete_url = f'{LITE_LLM_API_URL}/key/delete'
|
||||
delete_payload = {'keys': [byor_key]}
|
||||
# Get user to construct the key alias
|
||||
user = await UserStore.get_user_by_id_async(user_id)
|
||||
key_alias = None
|
||||
if user and user.current_org_id:
|
||||
key_alias = f'BYOR Key - user {user_id}, org {user.current_org_id}'
|
||||
|
||||
delete_response = await client.post(delete_url, json=delete_payload)
|
||||
delete_response.raise_for_status()
|
||||
logger.info(
|
||||
'Successfully deleted BYOR key from LiteLLM',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
return True
|
||||
await LiteLlmManager.delete_key(byor_key, key_alias=key_alias)
|
||||
logger.info(
|
||||
'Successfully deleted BYOR key from LiteLLM',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'Error deleting BYOR key from LiteLLM',
|
||||
@@ -233,9 +136,9 @@ class ApiKeyCreate(BaseModel):
|
||||
class ApiKeyResponse(BaseModel):
|
||||
id: int
|
||||
name: str | None = None
|
||||
created_at: str
|
||||
last_used_at: str | None = None
|
||||
expires_at: str | None = None
|
||||
created_at: datetime
|
||||
last_used_at: datetime | None = None
|
||||
expires_at: datetime | None = None
|
||||
|
||||
|
||||
class ApiKeyCreateResponse(ApiKeyResponse):
|
||||
@@ -246,58 +149,78 @@ class LlmApiKeyResponse(BaseModel):
|
||||
key: str | None
|
||||
|
||||
|
||||
@api_router.post('', response_model=ApiKeyCreateResponse)
|
||||
async def create_api_key(key_data: ApiKeyCreate, user_id: str = Depends(get_user_id)):
|
||||
class ByorPermittedResponse(BaseModel):
|
||||
permitted: bool
|
||||
|
||||
|
||||
class MessageResponse(BaseModel):
|
||||
message: str
|
||||
|
||||
|
||||
def api_key_to_response(key: ApiKey) -> ApiKeyResponse:
|
||||
"""Convert an ApiKey model to an ApiKeyResponse."""
|
||||
return ApiKeyResponse(
|
||||
id=key.id,
|
||||
name=key.name,
|
||||
created_at=key.created_at,
|
||||
last_used_at=key.last_used_at,
|
||||
expires_at=key.expires_at,
|
||||
)
|
||||
|
||||
|
||||
@api_router.get('/llm/byor/permitted', tags=['Keys'])
|
||||
async def check_byor_permitted(
|
||||
user_id: str = Depends(get_user_id),
|
||||
) -> ByorPermittedResponse:
|
||||
"""Check if BYOR key export is permitted for the user's current org."""
|
||||
try:
|
||||
permitted = await OrgService.check_byor_export_enabled(user_id)
|
||||
return ByorPermittedResponse(permitted=permitted)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'Error checking BYOR export permission', extra={'error': str(e)}
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to check BYOR export permission',
|
||||
)
|
||||
|
||||
|
||||
@api_router.post('', tags=['Keys'])
|
||||
async def create_api_key(
|
||||
key_data: ApiKeyCreate, user_id: str = Depends(get_user_id)
|
||||
) -> ApiKeyCreateResponse:
|
||||
"""Create a new API key for the authenticated user."""
|
||||
try:
|
||||
api_key = api_key_store.create_api_key(
|
||||
api_key = await api_key_store.create_api_key(
|
||||
user_id, key_data.name, key_data.expires_at
|
||||
)
|
||||
# Get the created key details
|
||||
keys = api_key_store.list_api_keys(user_id)
|
||||
keys = await api_key_store.list_api_keys(user_id)
|
||||
for key in keys:
|
||||
if key['name'] == key_data.name:
|
||||
return {
|
||||
**key,
|
||||
'key': api_key,
|
||||
'created_at': (
|
||||
key['created_at'].isoformat() if key['created_at'] else None
|
||||
),
|
||||
'last_used_at': (
|
||||
key['last_used_at'].isoformat() if key['last_used_at'] else None
|
||||
),
|
||||
'expires_at': (
|
||||
key['expires_at'].isoformat() if key['expires_at'] else None
|
||||
),
|
||||
}
|
||||
if key.name == key_data.name:
|
||||
return ApiKeyCreateResponse(
|
||||
id=key.id,
|
||||
name=key.name,
|
||||
key=api_key,
|
||||
created_at=key.created_at,
|
||||
last_used_at=key.last_used_at,
|
||||
expires_at=key.expires_at,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception('Error creating API key')
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to create API key',
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to create API key',
|
||||
)
|
||||
|
||||
|
||||
@api_router.get('', response_model=list[ApiKeyResponse])
|
||||
async def list_api_keys(user_id: str = Depends(get_user_id)):
|
||||
@api_router.get('', tags=['Keys'])
|
||||
async def list_api_keys(user_id: str = Depends(get_user_id)) -> list[ApiKeyResponse]:
|
||||
"""List all API keys for the authenticated user."""
|
||||
try:
|
||||
keys = api_key_store.list_api_keys(user_id)
|
||||
return [
|
||||
{
|
||||
**key,
|
||||
'created_at': (
|
||||
key['created_at'].isoformat() if key['created_at'] else None
|
||||
),
|
||||
'last_used_at': (
|
||||
key['last_used_at'].isoformat() if key['last_used_at'] else None
|
||||
),
|
||||
'expires_at': (
|
||||
key['expires_at'].isoformat() if key['expires_at'] else None
|
||||
),
|
||||
}
|
||||
for key in keys
|
||||
]
|
||||
keys = await api_key_store.list_api_keys(user_id)
|
||||
return [api_key_to_response(key) for key in keys]
|
||||
except Exception:
|
||||
logger.exception('Error listing API keys')
|
||||
raise HTTPException(
|
||||
@@ -306,16 +229,18 @@ async def list_api_keys(user_id: str = Depends(get_user_id)):
|
||||
)
|
||||
|
||||
|
||||
@api_router.delete('/{key_id}')
|
||||
async def delete_api_key(key_id: int, user_id: str = Depends(get_user_id)):
|
||||
@api_router.delete('/{key_id}', tags=['Keys'])
|
||||
async def delete_api_key(
|
||||
key_id: int, user_id: str = Depends(get_user_id)
|
||||
) -> MessageResponse:
|
||||
"""Delete an API key."""
|
||||
try:
|
||||
# First, verify the key belongs to the user
|
||||
keys = api_key_store.list_api_keys(user_id)
|
||||
keys = await api_key_store.list_api_keys(user_id)
|
||||
key_to_delete = None
|
||||
|
||||
for key in keys:
|
||||
if key['id'] == key_id:
|
||||
if key.id == key_id:
|
||||
key_to_delete = key
|
||||
break
|
||||
|
||||
@@ -326,14 +251,14 @@ async def delete_api_key(key_id: int, user_id: str = Depends(get_user_id)):
|
||||
)
|
||||
|
||||
# Delete the key
|
||||
success = api_key_store.delete_api_key_by_id(key_id)
|
||||
success = await api_key_store.delete_api_key_by_id(key_id)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to delete API key',
|
||||
)
|
||||
return {'message': 'API key deleted successfully'}
|
||||
return MessageResponse(message='API key deleted successfully')
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception:
|
||||
@@ -344,22 +269,33 @@ async def delete_api_key(key_id: int, user_id: str = Depends(get_user_id)):
|
||||
)
|
||||
|
||||
|
||||
@api_router.get('/llm/byor', response_model=LlmApiKeyResponse)
|
||||
async def get_llm_api_key_for_byor(user_id: str = Depends(get_user_id)):
|
||||
@api_router.get('/llm/byor', tags=['Keys'])
|
||||
async def get_llm_api_key_for_byor(
|
||||
user_id: str = Depends(get_user_id),
|
||||
) -> LlmApiKeyResponse:
|
||||
"""Get the LLM API key for BYOR (Bring Your Own Runtime) for the authenticated user.
|
||||
|
||||
This endpoint validates that the key exists in LiteLLM before returning it.
|
||||
If validation fails, it automatically generates a new key to ensure users
|
||||
always receive a working key.
|
||||
|
||||
Returns 402 Payment Required if BYOR export is not enabled for the user's org.
|
||||
"""
|
||||
try:
|
||||
# Check if BYOR export is enabled for the user's org
|
||||
if not await OrgService.check_byor_export_enabled(user_id):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||
detail='BYOR key export is not enabled. Purchase credits to enable this feature.',
|
||||
)
|
||||
|
||||
# Check if the BYOR key exists in the database
|
||||
byor_key = await get_byor_key_from_db(user_id)
|
||||
if byor_key:
|
||||
# Validate that the key is actually registered in LiteLLM
|
||||
is_valid = await verify_byor_key_in_litellm(byor_key, user_id)
|
||||
is_valid = await LiteLlmManager.verify_key(byor_key, user_id)
|
||||
if is_valid:
|
||||
return {'key': byor_key}
|
||||
return LlmApiKeyResponse(key=byor_key)
|
||||
else:
|
||||
# Key exists in DB but is invalid in LiteLLM - regenerate it
|
||||
logger.warning(
|
||||
@@ -384,7 +320,7 @@ async def get_llm_api_key_for_byor(user_id: str = Depends(get_user_id)):
|
||||
'Successfully generated and stored new BYOR key',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
return {'key': key}
|
||||
return LlmApiKeyResponse(key=key)
|
||||
else:
|
||||
logger.error(
|
||||
'Failed to generate new BYOR LLM API key',
|
||||
@@ -406,19 +342,22 @@ async def get_llm_api_key_for_byor(user_id: str = Depends(get_user_id)):
|
||||
)
|
||||
|
||||
|
||||
@api_router.post('/llm/byor/refresh', response_model=LlmApiKeyResponse)
|
||||
async def refresh_llm_api_key_for_byor(user_id: str = Depends(get_user_id)):
|
||||
"""Refresh the LLM API key for BYOR (Bring Your Own Runtime) for the authenticated user."""
|
||||
@api_router.post('/llm/byor/refresh', tags=['Keys'])
|
||||
async def refresh_llm_api_key_for_byor(
|
||||
user_id: str = Depends(get_user_id),
|
||||
) -> LlmApiKeyResponse:
|
||||
"""Refresh the LLM API key for BYOR (Bring Your Own Runtime) for the authenticated user.
|
||||
|
||||
Returns 402 Payment Required if BYOR export is not enabled for the user's org.
|
||||
"""
|
||||
logger.info('Starting BYOR LLM API key refresh', extra={'user_id': user_id})
|
||||
|
||||
try:
|
||||
if not (LITE_LLM_API_KEY and LITE_LLM_API_URL):
|
||||
logger.warning(
|
||||
'LiteLLM API configuration not found', extra={'user_id': user_id}
|
||||
)
|
||||
# Check if BYOR export is enabled for the user's org
|
||||
if not await OrgService.check_byor_export_enabled(user_id):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='LiteLLM API configuration not found',
|
||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||
detail='BYOR key export is not enabled. Purchase credits to enable this feature.',
|
||||
)
|
||||
|
||||
# Get the existing BYOR key from the database
|
||||
@@ -459,7 +398,7 @@ async def refresh_llm_api_key_for_byor(user_id: str = Depends(get_user_id)):
|
||||
'BYOR LLM API key refresh completed successfully',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
return {'key': key}
|
||||
return LlmApiKeyResponse(key=key)
|
||||
except HTTPException as he:
|
||||
logger.error(
|
||||
'HTTP exception during BYOR LLM API key refresh',
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import base64
|
||||
import json
|
||||
import uuid
|
||||
import warnings
|
||||
from datetime import datetime, timezone
|
||||
from typing import Annotated, Literal, Optional
|
||||
from urllib.parse import quote
|
||||
from uuid import UUID as parse_uuid
|
||||
|
||||
import posthog
|
||||
from fastapi import APIRouter, Header, HTTPException, Request, Response, status
|
||||
@@ -22,12 +24,20 @@ from server.auth.gitlab_sync import schedule_gitlab_repo_sync
|
||||
from server.auth.recaptcha_service import recaptcha_service
|
||||
from server.auth.saas_user_auth import SaasUserAuth
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.config import get_config, sign_token
|
||||
from server.config import sign_token
|
||||
from server.constants import IS_FEATURE_ENV
|
||||
from server.routes.event_webhook import _get_session_api_key, _get_user_id
|
||||
from storage.database import session_maker
|
||||
from storage.saas_settings_store import SaasSettingsStore
|
||||
from storage.user_settings import UserSettings
|
||||
from server.services.org_invitation_service import (
|
||||
EmailMismatchError,
|
||||
InvitationExpiredError,
|
||||
InvitationInvalidError,
|
||||
OrgInvitationService,
|
||||
UserAlreadyMemberError,
|
||||
)
|
||||
from sqlalchemy import select
|
||||
from storage.database import a_session_maker
|
||||
from storage.user import User
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.provider import ProviderHandler
|
||||
@@ -87,7 +97,8 @@ def get_cookie_domain(request: Request) -> str | None:
|
||||
# for now just use the full hostname except for staging stacks.
|
||||
return (
|
||||
None
|
||||
if (request.url.hostname or '').endswith('staging.all-hand.dev')
|
||||
if not request.url.hostname
|
||||
or request.url.hostname.endswith('staging.all-hands.dev')
|
||||
else request.url.hostname
|
||||
)
|
||||
|
||||
@@ -102,22 +113,40 @@ def get_cookie_samesite(request: Request) -> Literal['lax', 'strict']:
|
||||
)
|
||||
|
||||
|
||||
def _extract_oauth_state(state: str | None) -> tuple[str, str | None, str | None]:
|
||||
"""Extract redirect URL, reCAPTCHA token, and invitation token from OAuth state.
|
||||
|
||||
Returns:
|
||||
Tuple of (redirect_url, recaptcha_token, invitation_token).
|
||||
Tokens may be None.
|
||||
"""
|
||||
if not state:
|
||||
return '', None, None
|
||||
|
||||
try:
|
||||
# Try to decode as JSON (new format with reCAPTCHA and/or invitation)
|
||||
state_data = json.loads(base64.urlsafe_b64decode(state.encode()).decode())
|
||||
return (
|
||||
state_data.get('redirect_url', ''),
|
||||
state_data.get('recaptcha_token'),
|
||||
state_data.get('invitation_token'),
|
||||
)
|
||||
except Exception:
|
||||
# Old format - state is just the redirect URL
|
||||
return state, None, None
|
||||
|
||||
|
||||
# Keep alias for backward compatibility
|
||||
def _extract_recaptcha_state(state: str | None) -> tuple[str, str | None]:
|
||||
"""Extract redirect URL and reCAPTCHA token from OAuth state.
|
||||
|
||||
Deprecated: Use _extract_oauth_state instead.
|
||||
|
||||
Returns:
|
||||
Tuple of (redirect_url, recaptcha_token). Token may be None.
|
||||
"""
|
||||
if not state:
|
||||
return '', None
|
||||
|
||||
try:
|
||||
# Try to decode as JSON (new format with reCAPTCHA)
|
||||
state_data = json.loads(base64.urlsafe_b64decode(state.encode()).decode())
|
||||
return state_data.get('redirect_url', ''), state_data.get('recaptcha_token')
|
||||
except Exception:
|
||||
# Old format - state is just the redirect URL
|
||||
return state, None
|
||||
redirect_url, recaptcha_token, _ = _extract_oauth_state(state)
|
||||
return redirect_url, recaptcha_token
|
||||
|
||||
|
||||
@oauth_router.get('/keycloak/callback')
|
||||
@@ -128,8 +157,8 @@ async def keycloak_callback(
|
||||
error: Optional[str] = None,
|
||||
error_description: Optional[str] = None,
|
||||
):
|
||||
# Extract redirect URL and reCAPTCHA token from state
|
||||
redirect_url, recaptcha_token = _extract_recaptcha_state(state)
|
||||
# Extract redirect URL, reCAPTCHA token, and invitation token from state
|
||||
redirect_url, recaptcha_token, invitation_token = _extract_oauth_state(state)
|
||||
if not redirect_url:
|
||||
redirect_url = str(request.base_url)
|
||||
|
||||
@@ -174,6 +203,24 @@ async def keycloak_callback(
|
||||
|
||||
email = user_info.get('email')
|
||||
user_id = user_info['sub']
|
||||
user = await UserStore.get_user_by_id_async(user_id)
|
||||
if not user:
|
||||
user = await UserStore.create_user(user_id, user_info)
|
||||
else:
|
||||
# Existing user — gradually backfill contact_name if it still has a username-style value
|
||||
await UserStore.backfill_contact_name(user_id, user_info)
|
||||
await UserStore.backfill_user_email(user_id, user_info)
|
||||
|
||||
if not user:
|
||||
logger.error(f'Failed to authenticate user {user_info["preferred_username"]}')
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={
|
||||
'error': f'Failed to authenticate user {user_info["preferred_username"]}'
|
||||
},
|
||||
)
|
||||
|
||||
logger.info(f'Logging in user {str(user.id)} in org {user.current_org_id}')
|
||||
|
||||
# reCAPTCHA verification with Account Defender
|
||||
if RECAPTCHA_SITE_KEY:
|
||||
@@ -203,6 +250,7 @@ async def keycloak_callback(
|
||||
user_ip=user_ip,
|
||||
user_agent=user_agent,
|
||||
email=email,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
if not result.allowed:
|
||||
@@ -223,7 +271,7 @@ async def keycloak_callback(
|
||||
# Fail open - continue with login if reCAPTCHA service unavailable
|
||||
|
||||
# Check if email domain is blocked
|
||||
if email and domain_blocker.is_domain_blocked(email):
|
||||
if email and await domain_blocker.is_domain_blocked(email):
|
||||
logger.warning(
|
||||
f'Blocked authentication attempt for email: {email}, user_id: {user_id}'
|
||||
)
|
||||
@@ -282,8 +330,13 @@ async def keycloak_callback(
|
||||
from server.routes.email import verify_email
|
||||
|
||||
await verify_email(request=request, user_id=user_id, is_auth_flow=True)
|
||||
redirect_url = f'{request.base_url}login?email_verification_required=true&user_id={user_id}'
|
||||
response = RedirectResponse(redirect_url, status_code=302)
|
||||
verification_redirect_url = f'{request.base_url}login?email_verification_required=true&user_id={user_id}'
|
||||
# Preserve invitation token so it can be included in OAuth state after verification
|
||||
if invitation_token:
|
||||
verification_redirect_url = (
|
||||
f'{verification_redirect_url}&invitation_token={invitation_token}'
|
||||
)
|
||||
response = RedirectResponse(verification_redirect_url, status_code=302)
|
||||
return response
|
||||
|
||||
# default to github IDP for now.
|
||||
@@ -360,14 +413,78 @@ async def keycloak_callback(
|
||||
f'&state={state}'
|
||||
)
|
||||
|
||||
config = get_config()
|
||||
settings_store = SaasSettingsStore(
|
||||
user_id=user_id, session_maker=session_maker, config=config
|
||||
)
|
||||
user_settings = settings_store.get_user_settings_by_keycloak_id(user_id)
|
||||
has_accepted_tos = (
|
||||
user_settings is not None and user_settings.accepted_tos is not None
|
||||
)
|
||||
has_accepted_tos = user.accepted_tos is not None
|
||||
|
||||
# Process invitation token if present (after email verification but before TOS)
|
||||
if invitation_token:
|
||||
try:
|
||||
logger.info(
|
||||
'Processing invitation token during auth callback',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'invitation_token_prefix': invitation_token[:10] + '...',
|
||||
},
|
||||
)
|
||||
|
||||
await OrgInvitationService.accept_invitation(
|
||||
invitation_token, parse_uuid(user_id)
|
||||
)
|
||||
logger.info(
|
||||
'Invitation accepted during auth callback',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
|
||||
except InvitationExpiredError:
|
||||
logger.warning(
|
||||
'Invitation expired during auth callback',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
# Add query param to redirect URL
|
||||
if '?' in redirect_url:
|
||||
redirect_url = f'{redirect_url}&invitation_expired=true'
|
||||
else:
|
||||
redirect_url = f'{redirect_url}?invitation_expired=true'
|
||||
|
||||
except InvitationInvalidError as e:
|
||||
logger.warning(
|
||||
'Invalid invitation during auth callback',
|
||||
extra={'user_id': user_id, 'error': str(e)},
|
||||
)
|
||||
if '?' in redirect_url:
|
||||
redirect_url = f'{redirect_url}&invitation_invalid=true'
|
||||
else:
|
||||
redirect_url = f'{redirect_url}?invitation_invalid=true'
|
||||
|
||||
except UserAlreadyMemberError:
|
||||
logger.info(
|
||||
'User already member during invitation acceptance',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
if '?' in redirect_url:
|
||||
redirect_url = f'{redirect_url}&already_member=true'
|
||||
else:
|
||||
redirect_url = f'{redirect_url}?already_member=true'
|
||||
|
||||
except EmailMismatchError as e:
|
||||
logger.warning(
|
||||
'Email mismatch during auth callback invitation acceptance',
|
||||
extra={'user_id': user_id, 'error': str(e)},
|
||||
)
|
||||
if '?' in redirect_url:
|
||||
redirect_url = f'{redirect_url}&email_mismatch=true'
|
||||
else:
|
||||
redirect_url = f'{redirect_url}?email_mismatch=true'
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'Unexpected error processing invitation during auth callback',
|
||||
extra={'user_id': user_id, 'error': str(e)},
|
||||
)
|
||||
# Don't fail the login if invitation processing fails
|
||||
if '?' in redirect_url:
|
||||
redirect_url = f'{redirect_url}&invitation_error=true'
|
||||
else:
|
||||
redirect_url = f'{redirect_url}?invitation_error=true'
|
||||
|
||||
# If the user hasn't accepted the TOS, redirect to the TOS page
|
||||
if not has_accepted_tos:
|
||||
@@ -375,8 +492,12 @@ async def keycloak_callback(
|
||||
tos_redirect_url = (
|
||||
f'{request.base_url}accept-tos?redirect_url={encoded_redirect_url}'
|
||||
)
|
||||
if invitation_token:
|
||||
tos_redirect_url = f'{tos_redirect_url}&invitation_success=true'
|
||||
response = RedirectResponse(tos_redirect_url, status_code=302)
|
||||
else:
|
||||
if invitation_token:
|
||||
redirect_url = f'{redirect_url}&invitation_success=true'
|
||||
response = RedirectResponse(redirect_url, status_code=302)
|
||||
|
||||
set_response_cookie(
|
||||
@@ -430,7 +551,10 @@ async def keycloak_offline_callback(code: str, state: str, request: Request):
|
||||
user_id=user_info['sub'], offline_token=keycloak_refresh_token
|
||||
)
|
||||
|
||||
return RedirectResponse(state if state else request.base_url, status_code=302)
|
||||
redirect_url, _, _ = _extract_oauth_state(state)
|
||||
return RedirectResponse(
|
||||
redirect_url if redirect_url else request.base_url, status_code=302
|
||||
)
|
||||
|
||||
|
||||
@oauth_router.get('/github/callback')
|
||||
@@ -486,28 +610,23 @@ async def accept_tos(request: Request):
|
||||
redirect_url = body.get('redirect_url', str(request.base_url))
|
||||
|
||||
# Update user settings with TOS acceptance
|
||||
with session_maker() as session:
|
||||
user_settings = (
|
||||
session.query(UserSettings)
|
||||
.filter(UserSettings.keycloak_user_id == user_id)
|
||||
.first()
|
||||
accepted_tos: datetime = datetime.now(timezone.utc)
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(User).where(User.id == uuid.UUID(user_id))
|
||||
)
|
||||
|
||||
if user_settings:
|
||||
user_settings.accepted_tos = datetime.now(timezone.utc)
|
||||
session.merge(user_settings)
|
||||
else:
|
||||
# Create user settings if they don't exist
|
||||
user_settings = UserSettings(
|
||||
keycloak_user_id=user_id,
|
||||
accepted_tos=datetime.now(timezone.utc),
|
||||
user_version=0, # This will trigger a migration to the latest version on next load
|
||||
user = result.scalar_one_or_none()
|
||||
if not user:
|
||||
await session.rollback()
|
||||
logger.error('User for {user_id} not found.')
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={'error': 'User does not exist'},
|
||||
)
|
||||
session.add(user_settings)
|
||||
user.accepted_tos = accepted_tos
|
||||
await session.commit()
|
||||
|
||||
session.commit()
|
||||
|
||||
logger.info(f'User {user_id} accepted TOS')
|
||||
logger.info(f'User {user_id} accepted TOS')
|
||||
|
||||
response = JSONResponse(
|
||||
status_code=status.HTTP_200_OK, content={'redirect_url': redirect_url}
|
||||
|
||||
@@ -4,63 +4,42 @@ from datetime import UTC, datetime
|
||||
from decimal import Decimal
|
||||
from enum import Enum
|
||||
|
||||
import httpx
|
||||
import stripe
|
||||
from dateutil.relativedelta import relativedelta # type: ignore
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from fastapi.responses import JSONResponse, RedirectResponse
|
||||
from fastapi.responses import RedirectResponse
|
||||
from integrations import stripe_service
|
||||
from pydantic import BaseModel
|
||||
from server.config import get_config
|
||||
from server.constants import (
|
||||
LITE_LLM_API_KEY,
|
||||
LITE_LLM_API_URL,
|
||||
STRIPE_API_KEY,
|
||||
STRIPE_WEBHOOK_SECRET,
|
||||
SUBSCRIPTION_PRICE_DATA,
|
||||
get_default_litellm_model,
|
||||
)
|
||||
from server.constants import STRIPE_API_KEY
|
||||
from server.logger import logger
|
||||
from sqlalchemy import select
|
||||
from starlette.datastructures import URL
|
||||
from storage.billing_session import BillingSession
|
||||
from storage.database import session_maker
|
||||
from storage.saas_settings_store import SaasSettingsStore
|
||||
from storage.database import a_session_maker
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
from storage.org import Org
|
||||
from storage.subscription_access import SubscriptionAccess
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.app_server.config import get_global_config
|
||||
from openhands.server.user_auth import get_user_id
|
||||
from openhands.utils.http_session import httpx_verify_option
|
||||
|
||||
stripe.api_key = STRIPE_API_KEY
|
||||
billing_router = APIRouter(prefix='/api/billing')
|
||||
billing_router = APIRouter(prefix='/api/billing', tags=['Billing'])
|
||||
|
||||
|
||||
# TODO: Add a new app_mode named "ON_PREM" to support self-hosted customers instead of doing this
|
||||
# and members should comment out the "validate_saas_environment" function if they are developing and testing locally.
|
||||
def is_all_hands_saas_environment(request: Request) -> bool:
|
||||
"""Check if the current domain is an All Hands SaaS environment.
|
||||
|
||||
Args:
|
||||
request: FastAPI Request object
|
||||
|
||||
Returns:
|
||||
True if the current domain contains "all-hands.dev" or "openhands.dev" postfix
|
||||
async def validate_billing_enabled() -> None:
|
||||
"""
|
||||
hostname = request.url.hostname or ''
|
||||
return hostname.endswith('all-hands.dev') or hostname.endswith('openhands.dev')
|
||||
|
||||
|
||||
def validate_saas_environment(request: Request) -> None:
|
||||
"""Validate that the request is coming from an All Hands SaaS environment.
|
||||
|
||||
Args:
|
||||
request: FastAPI Request object
|
||||
|
||||
Raises:
|
||||
HTTPException: If the request is not from an All Hands SaaS environment
|
||||
Validate that the billing feature flag is enabled
|
||||
"""
|
||||
if not is_all_hands_saas_environment(request):
|
||||
config = get_global_config()
|
||||
web_client_config = await config.web_client.get_web_client_config()
|
||||
if not web_client_config.feature_flags.enable_billing:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail='Checkout sessions are only available for All Hands SaaS environments',
|
||||
detail=(
|
||||
'Billing is disabled in this environment. '
|
||||
'Please set OH_WEB_CLIENT_FEATURE_FLAGS_ENABLE_BILLING to enable billing.'
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -106,31 +85,20 @@ def calculate_credits(user_info: LiteLlmUserInfo) -> float:
|
||||
return max(max_budget - spend, 0.0)
|
||||
|
||||
|
||||
# Endpoint to retrieve user's current credit balance
|
||||
# Endpoint to retrieve the current organization's credit balance
|
||||
@billing_router.get('/credits')
|
||||
async def get_credits(user_id: str = Depends(get_user_id)) -> GetCreditsResponse:
|
||||
if not stripe_service.STRIPE_API_KEY:
|
||||
return GetCreditsResponse()
|
||||
try:
|
||||
async with httpx.AsyncClient(
|
||||
verify=httpx_verify_option(), timeout=15.0
|
||||
) as client:
|
||||
user_json = await _get_litellm_user(client, user_id)
|
||||
credits = calculate_credits(user_json['user_info'])
|
||||
return GetCreditsResponse(credits=Decimal('{:.2f}'.format(credits)))
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(
|
||||
f'litellm_get_user_failed: {type(e).__name__}: {e}',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'status_code': e.response.status_code,
|
||||
},
|
||||
exc_info=True,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to retrieve credit balance from billing service',
|
||||
)
|
||||
user = await UserStore.get_user_by_id_async(user_id)
|
||||
user_team_info = await LiteLlmManager.get_user_team_info(
|
||||
user_id, str(user.current_org_id)
|
||||
)
|
||||
max_budget, spend = LiteLlmManager.get_budget_from_team_info(
|
||||
user_team_info, user_id, str(user.current_org_id)
|
||||
)
|
||||
credits = max(max_budget - spend, 0)
|
||||
return GetCreditsResponse(credits=Decimal('{:.2f}'.format(credits)))
|
||||
|
||||
|
||||
# Endpoint to retrieve user's current subscription access
|
||||
@@ -139,16 +107,17 @@ async def get_subscription_access(
|
||||
user_id: str = Depends(get_user_id),
|
||||
) -> SubscriptionAccessResponse | None:
|
||||
"""Get details of the currently valid subscription for the user."""
|
||||
with session_maker() as session:
|
||||
async with a_session_maker() as session:
|
||||
now = datetime.now(UTC)
|
||||
subscription_access = (
|
||||
session.query(SubscriptionAccess)
|
||||
.filter(SubscriptionAccess.status == 'ACTIVE')
|
||||
.filter(SubscriptionAccess.user_id == user_id)
|
||||
.filter(SubscriptionAccess.start_at <= now)
|
||||
.filter(SubscriptionAccess.end_at >= now)
|
||||
.first()
|
||||
result = await session.execute(
|
||||
select(SubscriptionAccess).where(
|
||||
SubscriptionAccess.status == 'ACTIVE',
|
||||
SubscriptionAccess.user_id == user_id,
|
||||
SubscriptionAccess.start_at <= now,
|
||||
SubscriptionAccess.end_at >= now,
|
||||
)
|
||||
)
|
||||
subscription_access = result.scalar_one_or_none()
|
||||
if not subscription_access:
|
||||
return None
|
||||
return SubscriptionAccessResponse(
|
||||
@@ -165,79 +134,7 @@ async def get_subscription_access(
|
||||
async def has_payment_method(user_id: str = Depends(get_user_id)) -> bool:
|
||||
if not user_id:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED)
|
||||
return await stripe_service.has_payment_method(user_id)
|
||||
|
||||
|
||||
# Endpoint to cancel user's subscription
|
||||
@billing_router.post('/cancel-subscription')
|
||||
async def cancel_subscription(user_id: str = Depends(get_user_id)) -> JSONResponse:
|
||||
"""Cancel user's active subscription at the end of the current billing period."""
|
||||
if not user_id:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED)
|
||||
|
||||
with session_maker() as session:
|
||||
# Find the user's active subscription
|
||||
now = datetime.now(UTC)
|
||||
subscription_access = (
|
||||
session.query(SubscriptionAccess)
|
||||
.filter(SubscriptionAccess.status == 'ACTIVE')
|
||||
.filter(SubscriptionAccess.user_id == user_id)
|
||||
.filter(SubscriptionAccess.start_at <= now)
|
||||
.filter(SubscriptionAccess.end_at >= now)
|
||||
.filter(SubscriptionAccess.cancelled_at.is_(None)) # Not already cancelled
|
||||
.first()
|
||||
)
|
||||
|
||||
if not subscription_access:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail='No active subscription found',
|
||||
)
|
||||
|
||||
if not subscription_access.stripe_subscription_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail='Cannot cancel subscription: missing Stripe subscription ID',
|
||||
)
|
||||
|
||||
try:
|
||||
# Cancel the subscription in Stripe at period end
|
||||
await stripe.Subscription.modify_async(
|
||||
subscription_access.stripe_subscription_id, cancel_at_period_end=True
|
||||
)
|
||||
|
||||
# Update local database
|
||||
subscription_access.cancelled_at = datetime.now(UTC)
|
||||
session.merge(subscription_access)
|
||||
session.commit()
|
||||
|
||||
logger.info(
|
||||
'subscription_cancelled',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'stripe_subscription_id': subscription_access.stripe_subscription_id,
|
||||
'subscription_access_id': subscription_access.id,
|
||||
'end_at': subscription_access.end_at,
|
||||
},
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
{'status': 'success', 'message': 'Subscription cancelled successfully'}
|
||||
)
|
||||
|
||||
except stripe.StripeError as e:
|
||||
logger.error(
|
||||
'stripe_cancellation_failed',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'stripe_subscription_id': subscription_access.stripe_subscription_id,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f'Failed to cancel subscription: {str(e)}',
|
||||
)
|
||||
return await stripe_service.has_payment_method_by_user_id(user_id)
|
||||
|
||||
|
||||
# Endpoint to create a new setup intent in stripe
|
||||
@@ -245,17 +142,17 @@ async def cancel_subscription(user_id: str = Depends(get_user_id)) -> JSONRespon
|
||||
async def create_customer_setup_session(
|
||||
request: Request, user_id: str = Depends(get_user_id)
|
||||
) -> CreateBillingSessionResponse:
|
||||
validate_saas_environment(request)
|
||||
|
||||
customer_id = await stripe_service.find_or_create_customer(user_id)
|
||||
await validate_billing_enabled()
|
||||
customer_info = await stripe_service.find_or_create_customer_by_user_id(user_id)
|
||||
base_url = _get_base_url(request)
|
||||
checkout_session = await stripe.checkout.Session.create_async(
|
||||
customer=customer_id,
|
||||
customer=customer_info['customer_id'],
|
||||
mode='setup',
|
||||
payment_method_types=['card'],
|
||||
success_url=f'{request.base_url}?free_credits=success',
|
||||
cancel_url=f'{request.base_url}',
|
||||
success_url=f'{base_url}?setup=success',
|
||||
cancel_url=f'{base_url}',
|
||||
)
|
||||
return CreateBillingSessionResponse(redirect_url=checkout_session.url) # type: ignore[arg-type]
|
||||
return CreateBillingSessionResponse(redirect_url=checkout_session.url)
|
||||
|
||||
|
||||
# Endpoint to create a new Stripe checkout session for credit purchase
|
||||
@@ -265,11 +162,11 @@ async def create_checkout_session(
|
||||
request: Request,
|
||||
user_id: str = Depends(get_user_id),
|
||||
) -> CreateBillingSessionResponse:
|
||||
validate_saas_environment(request)
|
||||
|
||||
customer_id = await stripe_service.find_or_create_customer(user_id)
|
||||
await validate_billing_enabled()
|
||||
base_url = _get_base_url(request)
|
||||
customer_info = await stripe_service.find_or_create_customer_by_user_id(user_id)
|
||||
checkout_session = await stripe.checkout.Session.create_async(
|
||||
customer=customer_id,
|
||||
customer=customer_info['customer_id'],
|
||||
line_items=[
|
||||
{
|
||||
'price_data': {
|
||||
@@ -282,141 +179,52 @@ async def create_checkout_session(
|
||||
'tax_behavior': 'exclusive',
|
||||
},
|
||||
'quantity': 1,
|
||||
}
|
||||
},
|
||||
],
|
||||
mode='payment',
|
||||
payment_method_types=['card'],
|
||||
saved_payment_method_options={
|
||||
'payment_method_save': 'enabled',
|
||||
},
|
||||
success_url=f'{request.base_url}api/billing/success?session_id={{CHECKOUT_SESSION_ID}}',
|
||||
cancel_url=f'{request.base_url}api/billing/cancel?session_id={{CHECKOUT_SESSION_ID}}',
|
||||
success_url=f'{base_url}api/billing/success?session_id={{CHECKOUT_SESSION_ID}}',
|
||||
cancel_url=f'{base_url}api/billing/cancel?session_id={{CHECKOUT_SESSION_ID}}',
|
||||
)
|
||||
logger.info(
|
||||
'created_stripe_checkout_session',
|
||||
extra={
|
||||
'stripe_customer_id': customer_id,
|
||||
'stripe_customer_id': customer_info['customer_id'],
|
||||
'user_id': user_id,
|
||||
'org_id': customer_info['org_id'],
|
||||
'amount': body.amount,
|
||||
'checkout_session_id': checkout_session.id,
|
||||
},
|
||||
)
|
||||
with session_maker() as session:
|
||||
async with a_session_maker() as session:
|
||||
billing_session = BillingSession(
|
||||
id=checkout_session.id,
|
||||
user_id=user_id,
|
||||
org_id=customer_info['org_id'],
|
||||
price=body.amount,
|
||||
price_code='NA',
|
||||
billing_session_type=BillingSessionType.DIRECT_PAYMENT.value,
|
||||
)
|
||||
session.add(billing_session)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
return CreateBillingSessionResponse(redirect_url=checkout_session.url) # type: ignore[arg-type]
|
||||
|
||||
|
||||
@billing_router.post('/subscription-checkout-session')
|
||||
async def create_subscription_checkout_session(
|
||||
request: Request,
|
||||
billing_session_type: BillingSessionType = BillingSessionType.MONTHLY_SUBSCRIPTION,
|
||||
user_id: str = Depends(get_user_id),
|
||||
) -> CreateBillingSessionResponse:
|
||||
validate_saas_environment(request)
|
||||
|
||||
# Prevent duplicate subscriptions for the same user
|
||||
with session_maker() as session:
|
||||
now = datetime.now(UTC)
|
||||
existing_active_subscription = (
|
||||
session.query(SubscriptionAccess)
|
||||
.filter(SubscriptionAccess.status == 'ACTIVE')
|
||||
.filter(SubscriptionAccess.user_id == user_id)
|
||||
.filter(SubscriptionAccess.start_at <= now)
|
||||
.filter(SubscriptionAccess.end_at >= now)
|
||||
.filter(SubscriptionAccess.cancelled_at.is_(None)) # Not cancelled
|
||||
.first()
|
||||
)
|
||||
|
||||
if existing_active_subscription:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail='Cannot create subscription: User already has an active subscription that has not been cancelled',
|
||||
)
|
||||
|
||||
customer_id = await stripe_service.find_or_create_customer(user_id)
|
||||
subscription_price_data = SUBSCRIPTION_PRICE_DATA[billing_session_type.value]
|
||||
checkout_session = await stripe.checkout.Session.create_async(
|
||||
customer=customer_id,
|
||||
line_items=[
|
||||
{
|
||||
'price_data': subscription_price_data,
|
||||
'quantity': 1,
|
||||
}
|
||||
],
|
||||
mode='subscription',
|
||||
payment_method_types=['card'],
|
||||
saved_payment_method_options={
|
||||
'payment_method_save': 'enabled',
|
||||
},
|
||||
success_url=f'{request.base_url}api/billing/success?session_id={{CHECKOUT_SESSION_ID}}',
|
||||
cancel_url=f'{request.base_url}api/billing/cancel?session_id={{CHECKOUT_SESSION_ID}}',
|
||||
subscription_data={
|
||||
'metadata': {
|
||||
'user_id': user_id,
|
||||
'billing_session_type': billing_session_type.value,
|
||||
}
|
||||
},
|
||||
)
|
||||
logger.info(
|
||||
'created_stripe_subscription_checkout_session',
|
||||
extra={
|
||||
'stripe_customer_id': customer_id,
|
||||
'user_id': user_id,
|
||||
'checkout_session_id': checkout_session.id,
|
||||
'billing_session_type': billing_session_type.value,
|
||||
},
|
||||
)
|
||||
with session_maker() as session:
|
||||
billing_session = BillingSession(
|
||||
id=checkout_session.id,
|
||||
user_id=user_id,
|
||||
price=subscription_price_data['unit_amount'],
|
||||
price_code='NA',
|
||||
billing_session_type=billing_session_type.value,
|
||||
)
|
||||
session.add(billing_session)
|
||||
session.commit()
|
||||
|
||||
return CreateBillingSessionResponse(
|
||||
redirect_url=typing.cast(str, checkout_session.url)
|
||||
)
|
||||
|
||||
|
||||
@billing_router.get('/create-subscription-checkout-session')
|
||||
async def create_subscription_checkout_session_via_get(
|
||||
request: Request,
|
||||
billing_session_type: BillingSessionType = BillingSessionType.MONTHLY_SUBSCRIPTION,
|
||||
user_id: str = Depends(get_user_id),
|
||||
) -> RedirectResponse:
|
||||
"""Create a subscription checkout session using a GET request (For easier copy / paste to URL bar)."""
|
||||
validate_saas_environment(request)
|
||||
|
||||
response = await create_subscription_checkout_session(
|
||||
request, billing_session_type, user_id
|
||||
)
|
||||
return RedirectResponse(response.redirect_url)
|
||||
return CreateBillingSessionResponse(redirect_url=checkout_session.url)
|
||||
|
||||
|
||||
# Callback endpoint for successful Stripe payments - updates user credits and billing session status
|
||||
@billing_router.get('/success')
|
||||
async def success_callback(session_id: str, request: Request):
|
||||
# We can't use the auth cookie because of SameSite=strict
|
||||
with session_maker() as session:
|
||||
billing_session = (
|
||||
session.query(BillingSession)
|
||||
.filter(BillingSession.id == session_id)
|
||||
.filter(BillingSession.status == 'in_progress')
|
||||
.first()
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(BillingSession).where(
|
||||
BillingSession.id == session_id,
|
||||
BillingSession.status == 'in_progress',
|
||||
)
|
||||
)
|
||||
billing_session = result.scalar_one_or_none()
|
||||
|
||||
if billing_session is None:
|
||||
# Hopefully this never happens - we get a redirect from stripe where the session does not exist
|
||||
@@ -425,15 +233,6 @@ async def success_callback(session_id: str, request: Request):
|
||||
)
|
||||
raise HTTPException(status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
# Any non direct payment (Subscription) is processed in the invoice_payment.paid by the webhook
|
||||
if (
|
||||
billing_session.billing_session_type
|
||||
!= BillingSessionType.DIRECT_PAYMENT.value
|
||||
):
|
||||
return RedirectResponse(
|
||||
f'{request.base_url}settings?checkout=success', status_code=302
|
||||
)
|
||||
|
||||
stripe_session = stripe.checkout.Session.retrieve(session_id)
|
||||
if stripe_session.status != 'complete':
|
||||
# Hopefully this never happens - we get a redirect from stripe where the payment is not yet complete
|
||||
@@ -447,47 +246,61 @@ async def success_callback(session_id: str, request: Request):
|
||||
)
|
||||
raise HTTPException(status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
|
||||
# Update max budget in litellm
|
||||
user_json = await _get_litellm_user(client, billing_session.user_id)
|
||||
amount_subtotal = stripe_session.amount_subtotal or 0
|
||||
add_credits = amount_subtotal / 100
|
||||
new_max_budget = (
|
||||
(user_json.get('user_info') or {}).get('max_budget') or 0
|
||||
) + add_credits
|
||||
await _upsert_litellm_user(client, billing_session.user_id, new_max_budget)
|
||||
user = await UserStore.get_user_by_id_async(billing_session.user_id)
|
||||
user_team_info = await LiteLlmManager.get_user_team_info(
|
||||
billing_session.user_id, str(user.current_org_id)
|
||||
)
|
||||
amount_subtotal = stripe_session.amount_subtotal or 0
|
||||
add_credits = amount_subtotal / 100
|
||||
max_budget, _ = LiteLlmManager.get_budget_from_team_info(
|
||||
user_team_info, billing_session.user_id, str(user.current_org_id)
|
||||
)
|
||||
|
||||
# Store transaction status
|
||||
billing_session.status = 'completed'
|
||||
billing_session.price = amount_subtotal
|
||||
billing_session.updated_at = datetime.now(UTC)
|
||||
session.merge(billing_session)
|
||||
logger.info(
|
||||
'stripe_checkout_success',
|
||||
extra={
|
||||
'amount_subtotal': stripe_session.amount_subtotal,
|
||||
'user_id': billing_session.user_id,
|
||||
'checkout_session_id': billing_session.id,
|
||||
'stripe_customer_id': stripe_session.customer,
|
||||
},
|
||||
)
|
||||
session.commit()
|
||||
result = await session.execute(select(Org).where(Org.id == user.current_org_id))
|
||||
org = result.scalar_one_or_none()
|
||||
new_max_budget = max_budget + add_credits
|
||||
|
||||
await LiteLlmManager.update_team_and_users_budget(
|
||||
str(user.current_org_id), new_max_budget
|
||||
)
|
||||
|
||||
# Enable BYOR export for the org now that they've purchased credits
|
||||
if org:
|
||||
org.byor_export_enabled = True
|
||||
|
||||
# Store transaction status
|
||||
billing_session.status = 'completed'
|
||||
billing_session.price = add_credits
|
||||
billing_session.updated_at = datetime.now(UTC)
|
||||
session.merge(billing_session)
|
||||
logger.info(
|
||||
'stripe_checkout_success',
|
||||
extra={
|
||||
'amount_subtotal': stripe_session.amount_subtotal,
|
||||
'user_id': billing_session.user_id,
|
||||
'org_id': str(user.current_org_id),
|
||||
'checkout_session_id': billing_session.id,
|
||||
'stripe_customer_id': stripe_session.customer,
|
||||
},
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
return RedirectResponse(
|
||||
f'{request.base_url}settings/billing?checkout=success', status_code=302
|
||||
f'{_get_base_url(request)}settings/billing?checkout=success', status_code=302
|
||||
)
|
||||
|
||||
|
||||
# Callback endpoint for cancelled Stripe payments - updates billing session status
|
||||
@billing_router.get('/cancel')
|
||||
async def cancel_callback(session_id: str, request: Request):
|
||||
with session_maker() as session:
|
||||
billing_session = (
|
||||
session.query(BillingSession)
|
||||
.filter(BillingSession.id == session_id)
|
||||
.filter(BillingSession.status == 'in_progress')
|
||||
.first()
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(BillingSession).where(
|
||||
BillingSession.id == session_id,
|
||||
BillingSession.status == 'in_progress',
|
||||
)
|
||||
)
|
||||
billing_session = result.scalar_one_or_none()
|
||||
if billing_session:
|
||||
logger.info(
|
||||
'stripe_checkout_cancel',
|
||||
@@ -499,208 +312,16 @@ async def cancel_callback(session_id: str, request: Request):
|
||||
billing_session.status = 'cancelled'
|
||||
billing_session.updated_at = datetime.now(UTC)
|
||||
session.merge(billing_session)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
# Redirect credit purchases to billing screen, subscriptions to LLM settings
|
||||
if (
|
||||
billing_session.billing_session_type
|
||||
== BillingSessionType.DIRECT_PAYMENT.value
|
||||
):
|
||||
return RedirectResponse(
|
||||
f'{request.base_url}settings/billing?checkout=cancel',
|
||||
status_code=302,
|
||||
)
|
||||
else:
|
||||
return RedirectResponse(
|
||||
f'{request.base_url}settings?checkout=cancel', status_code=302
|
||||
)
|
||||
|
||||
# If no billing session found, default to LLM settings (subscription flow)
|
||||
return RedirectResponse(
|
||||
f'{request.base_url}settings?checkout=cancel', status_code=302
|
||||
f'{_get_base_url(request)}settings/billing?checkout=cancel', status_code=302
|
||||
)
|
||||
|
||||
|
||||
@billing_router.post('/stripe-webhook')
|
||||
async def stripe_webhook(request: Request) -> JSONResponse:
|
||||
"""Endpoint for stripe webhooks."""
|
||||
payload = await request.body()
|
||||
sig_header = request.headers.get('stripe-signature')
|
||||
|
||||
try:
|
||||
event = stripe.Webhook.construct_event(
|
||||
payload, sig_header, STRIPE_WEBHOOK_SECRET
|
||||
)
|
||||
except ValueError as e:
|
||||
# Invalid payload
|
||||
raise HTTPException(status_code=400, detail=f'Invalid payload: {e}')
|
||||
except stripe.SignatureVerificationError as e:
|
||||
# Invalid signature
|
||||
raise HTTPException(status_code=400, detail=f'Invalid signature: {e}')
|
||||
|
||||
# Handle the event
|
||||
logger.info('stripe_webhook_event', extra={'event': event})
|
||||
event_type = event['type']
|
||||
if event_type == 'invoice.paid':
|
||||
invoice = event['data']['object']
|
||||
amount_paid = invoice.amount_paid
|
||||
metadata = invoice.parent.subscription_details.metadata # type: ignore
|
||||
billing_session_type = metadata.billing_session_type
|
||||
assert (
|
||||
amount_paid == SUBSCRIPTION_PRICE_DATA[billing_session_type]['unit_amount']
|
||||
)
|
||||
user_id = metadata.user_id
|
||||
|
||||
start_at = datetime.now(UTC)
|
||||
if billing_session_type == BillingSessionType.MONTHLY_SUBSCRIPTION.value:
|
||||
end_at = start_at + relativedelta(months=1)
|
||||
else:
|
||||
raise ValueError(f'unknown_billing_session_type:{billing_session_type}')
|
||||
|
||||
with session_maker() as session:
|
||||
subscription_access = SubscriptionAccess(
|
||||
status='ACTIVE',
|
||||
user_id=user_id,
|
||||
start_at=start_at,
|
||||
end_at=end_at,
|
||||
amount_paid=amount_paid,
|
||||
stripe_invoice_payment_id=invoice.payment_intent,
|
||||
stripe_subscription_id=invoice.subscription, # Store Stripe subscription ID
|
||||
)
|
||||
session.add(subscription_access)
|
||||
session.commit()
|
||||
elif event_type == 'customer.subscription.updated':
|
||||
subscription = event['data']['object']
|
||||
subscription_id = subscription['id']
|
||||
|
||||
# Handle subscription cancellation
|
||||
if subscription.get('cancel_at_period_end') is True:
|
||||
with session_maker() as session:
|
||||
subscription_access = (
|
||||
session.query(SubscriptionAccess)
|
||||
.filter(
|
||||
SubscriptionAccess.stripe_subscription_id == subscription_id
|
||||
)
|
||||
.filter(SubscriptionAccess.status == 'ACTIVE')
|
||||
.first()
|
||||
)
|
||||
|
||||
if subscription_access and not subscription_access.cancelled_at:
|
||||
subscription_access.cancelled_at = datetime.now(UTC)
|
||||
session.merge(subscription_access)
|
||||
session.commit()
|
||||
|
||||
logger.info(
|
||||
'subscription_cancelled_via_webhook',
|
||||
extra={
|
||||
'stripe_subscription_id': subscription_id,
|
||||
'user_id': subscription_access.user_id,
|
||||
'subscription_access_id': subscription_access.id,
|
||||
},
|
||||
)
|
||||
elif event_type == 'customer.subscription.deleted':
|
||||
subscription = event['data']['object']
|
||||
subscription_id = subscription['id']
|
||||
|
||||
with session_maker() as session:
|
||||
subscription_access = (
|
||||
session.query(SubscriptionAccess)
|
||||
.filter(SubscriptionAccess.stripe_subscription_id == subscription_id)
|
||||
.filter(SubscriptionAccess.status == 'ACTIVE')
|
||||
.first()
|
||||
)
|
||||
|
||||
if subscription_access:
|
||||
subscription_access.status = 'DISABLED'
|
||||
subscription_access.updated_at = datetime.now(UTC)
|
||||
session.merge(subscription_access)
|
||||
session.commit()
|
||||
|
||||
# Reset user settings to free tier defaults
|
||||
reset_user_to_free_tier_settings(subscription_access.user_id)
|
||||
|
||||
logger.info(
|
||||
'subscription_expired_reset_to_free_tier',
|
||||
extra={
|
||||
'stripe_subscription_id': subscription_id,
|
||||
'user_id': subscription_access.user_id,
|
||||
'subscription_access_id': subscription_access.id,
|
||||
},
|
||||
)
|
||||
else:
|
||||
logger.info('stripe_webhook_unhandled_event_type', extra={'type': event_type})
|
||||
|
||||
return JSONResponse({'status': 'success'})
|
||||
|
||||
|
||||
def reset_user_to_free_tier_settings(user_id: str) -> None:
|
||||
"""Reset user settings to free tier defaults when subscription ends."""
|
||||
config = get_config()
|
||||
settings_store = SaasSettingsStore(
|
||||
user_id=user_id, session_maker=session_maker, config=config
|
||||
)
|
||||
|
||||
with session_maker() as session:
|
||||
user_settings = settings_store.get_user_settings_by_keycloak_id(
|
||||
user_id, session
|
||||
)
|
||||
|
||||
if user_settings:
|
||||
user_settings.llm_model = get_default_litellm_model()
|
||||
user_settings.llm_api_key = None
|
||||
user_settings.llm_api_key_for_byor = None
|
||||
user_settings.llm_base_url = LITE_LLM_API_URL
|
||||
user_settings.max_budget_per_task = None
|
||||
user_settings.confirmation_mode = False
|
||||
user_settings.enable_solvability_analysis = False
|
||||
user_settings.security_analyzer = 'llm'
|
||||
user_settings.agent = 'CodeActAgent'
|
||||
user_settings.language = 'en'
|
||||
user_settings.enable_default_condenser = True
|
||||
user_settings.enable_sound_notifications = False
|
||||
user_settings.enable_proactive_conversation_starters = True
|
||||
user_settings.user_consents_to_analytics = False
|
||||
|
||||
session.merge(user_settings)
|
||||
session.commit()
|
||||
|
||||
logger.info(
|
||||
'user_settings_reset_to_free_tier',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'reset_timestamp': datetime.now(UTC).isoformat(),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def _get_litellm_user(client: httpx.AsyncClient, user_id: str) -> dict:
|
||||
"""Get a user from litellm with the id matching that given.
|
||||
|
||||
If no such user exists, returns a dummy user in the format:
|
||||
`{'user_id': '<USER_ID>', 'user_info': {'spend': 0}, 'keys': [], 'teams': []}`
|
||||
"""
|
||||
response = await client.get(
|
||||
f'{LITE_LLM_API_URL}/user/info?user_id={user_id}',
|
||||
headers={
|
||||
'x-goog-api-key': LITE_LLM_API_KEY,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
|
||||
async def _upsert_litellm_user(
|
||||
client: httpx.AsyncClient, user_id: str, max_budget: float
|
||||
):
|
||||
"""Insert / Update a user in litellm."""
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/user/update',
|
||||
headers={
|
||||
'x-goog-api-key': LITE_LLM_API_KEY,
|
||||
},
|
||||
json={
|
||||
'user_id': user_id,
|
||||
'max_budget': max_budget,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
def _get_base_url(request: Request) -> URL:
|
||||
# Never send any part of the credit card process over a non secure connection
|
||||
base_url = request.base_url
|
||||
if base_url.hostname != 'localhost':
|
||||
base_url = base_url.replace(scheme='https')
|
||||
return base_url
|
||||
|
||||
@@ -5,8 +5,8 @@ from threading import Thread
|
||||
|
||||
from fastapi import APIRouter, FastAPI
|
||||
from sqlalchemy import func, select
|
||||
from storage.database import a_session_maker, engine, session_maker
|
||||
from storage.user_settings import UserSettings
|
||||
from storage.database import a_session_maker, get_engine, session_maker
|
||||
from storage.user import User
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.utils.async_utils import wait_all
|
||||
@@ -47,6 +47,7 @@ def add_debugging_routes(api: FastAPI):
|
||||
- checked_out: Number of connections currently in use
|
||||
- overflow: Number of overflow connections created beyond pool_size
|
||||
"""
|
||||
engine = get_engine()
|
||||
return {
|
||||
'checked_in': engine.pool.checkedin(),
|
||||
'checked_out': engine.pool.checkedout(),
|
||||
@@ -127,8 +128,9 @@ def _db_check(delay: int):
|
||||
delay: Number of seconds to hold the database connection
|
||||
"""
|
||||
with session_maker() as session:
|
||||
num_users = session.query(UserSettings).count()
|
||||
num_users = session.query(User).count()
|
||||
time.sleep(delay)
|
||||
engine = get_engine()
|
||||
logger.info(
|
||||
'check',
|
||||
extra={
|
||||
@@ -155,7 +157,7 @@ async def _a_db_check(delay: int):
|
||||
delay: Number of seconds to hold the database connection
|
||||
"""
|
||||
async with a_session_maker() as a_session:
|
||||
stmt = select(func.count(UserSettings.id))
|
||||
stmt = select(func.count(User.id))
|
||||
num_users = await a_session.execute(stmt)
|
||||
await asyncio.sleep(delay)
|
||||
logger.info(f'a_num_users:{num_users.scalar_one()}')
|
||||
|
||||
@@ -8,6 +8,7 @@ from server.auth.keycloak_manager import get_keycloak_admin
|
||||
from server.auth.saas_user_auth import SaasUserAuth
|
||||
from server.routes.auth import set_response_cookie
|
||||
from server.utils.rate_limit_utils import check_rate_limit_by_user_id
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.server.user_auth import get_user_id
|
||||
@@ -62,6 +63,10 @@ async def update_email(
|
||||
},
|
||||
)
|
||||
|
||||
await UserStore.update_user_email(
|
||||
user_id=user_id, email=email, email_verified=False
|
||||
)
|
||||
|
||||
user_auth: SaasUserAuth = await get_user_auth(request)
|
||||
await user_auth.refresh() # refresh so access token has updated email
|
||||
user_auth.email = email
|
||||
@@ -144,6 +149,7 @@ async def verified_email(request: Request):
|
||||
user_auth: SaasUserAuth = await get_user_auth(request)
|
||||
await user_auth.refresh() # refresh so access token has updated email
|
||||
user_auth.email_verified = True
|
||||
await UserStore.update_user_email(user_id=user_auth.user_id, email_verified=True)
|
||||
scheme = 'http' if request.url.hostname == 'localhost' else 'https'
|
||||
redirect_uri = f'{scheme}://{request.url.netloc}/settings/user'
|
||||
response = RedirectResponse(redirect_uri, status_code=302)
|
||||
|
||||
@@ -21,7 +21,7 @@ from server.utils.conversation_callback_utils import (
|
||||
update_conversation_stats,
|
||||
)
|
||||
from storage.database import session_maker
|
||||
from storage.stored_conversation_metadata import StoredConversationMetadata
|
||||
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
|
||||
|
||||
from openhands.server.shared import conversation_manager
|
||||
|
||||
@@ -226,12 +226,12 @@ def _parse_conversation_id_and_subpath(path: str) -> Tuple[str, str]:
|
||||
|
||||
def _get_user_id(conversation_id: str) -> str:
|
||||
with session_maker() as session:
|
||||
conversation_metadata = (
|
||||
session.query(StoredConversationMetadata)
|
||||
.filter(StoredConversationMetadata.conversation_id == conversation_id)
|
||||
conversation_metadata_saas = (
|
||||
session.query(StoredConversationMetadataSaas)
|
||||
.filter(StoredConversationMetadataSaas.conversation_id == conversation_id)
|
||||
.first()
|
||||
)
|
||||
return conversation_metadata.user_id
|
||||
return str(conversation_metadata_saas.user_id)
|
||||
|
||||
|
||||
async def _get_session_api_key(user_id: str, conversation_id: str) -> str | None:
|
||||
|
||||
@@ -3,16 +3,22 @@ from typing import Any, Dict, List, Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.future import select
|
||||
from storage.database import session_maker
|
||||
from storage.database import a_session_maker
|
||||
from storage.feedback import ConversationFeedback
|
||||
from storage.stored_conversation_metadata import StoredConversationMetadata
|
||||
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
|
||||
|
||||
from openhands.events.event_store import EventStore
|
||||
from openhands.server.dependencies import get_dependencies
|
||||
from openhands.server.shared import file_store
|
||||
from openhands.server.user_auth import get_user_id
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
|
||||
router = APIRouter(prefix='/feedback', tags=['feedback'])
|
||||
# We use the get_dependencies method here to signal to the OpenAPI docs that this endpoint
|
||||
# is protected. The actual protection is provided by SetAuthCookieMiddleware
|
||||
# TODO: It may be an error by you can actually post feedback to a conversation you don't
|
||||
# own right now - maybe this is useful in the context of public shared conversations?
|
||||
router = APIRouter(
|
||||
prefix='/feedback', tags=['feedback'], dependencies=get_dependencies()
|
||||
)
|
||||
|
||||
|
||||
async def get_event_ids(conversation_id: str, user_id: str) -> List[int]:
|
||||
@@ -30,23 +36,19 @@ async def get_event_ids(conversation_id: str, user_id: str) -> List[int]:
|
||||
"""
|
||||
|
||||
# Verify the conversation belongs to the user
|
||||
def _verify_conversation():
|
||||
with session_maker() as session:
|
||||
metadata = (
|
||||
session.query(StoredConversationMetadata)
|
||||
.filter(
|
||||
StoredConversationMetadata.conversation_id == conversation_id,
|
||||
StoredConversationMetadata.user_id == user_id,
|
||||
)
|
||||
.first()
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(StoredConversationMetadataSaas).where(
|
||||
StoredConversationMetadataSaas.conversation_id == conversation_id,
|
||||
StoredConversationMetadataSaas.user_id == user_id,
|
||||
)
|
||||
)
|
||||
metadata = result.scalars().first()
|
||||
if not metadata:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f'Conversation {conversation_id} not found',
|
||||
)
|
||||
if not metadata:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f'Conversation {conversation_id} not found',
|
||||
)
|
||||
|
||||
await call_sync_from_async(_verify_conversation)
|
||||
|
||||
# Create an event store to access the events directly
|
||||
# This works even when the conversation is not running
|
||||
@@ -96,12 +98,9 @@ async def submit_conversation_feedback(feedback: FeedbackRequest):
|
||||
)
|
||||
|
||||
# Add to database
|
||||
def _save_feedback():
|
||||
with session_maker() as session:
|
||||
session.add(new_feedback)
|
||||
session.commit()
|
||||
|
||||
await call_sync_from_async(_save_feedback)
|
||||
async with a_session_maker() as session:
|
||||
session.add(new_feedback)
|
||||
await session.commit()
|
||||
|
||||
return {'status': 'success', 'message': 'Feedback submitted successfully'}
|
||||
|
||||
@@ -120,30 +119,27 @@ async def get_batch_feedback(conversation_id: str, user_id: str = Depends(get_us
|
||||
return {}
|
||||
|
||||
# Query for existing feedback for all events
|
||||
def _check_feedback():
|
||||
with session_maker() as session:
|
||||
result = session.execute(
|
||||
select(ConversationFeedback).where(
|
||||
ConversationFeedback.conversation_id == conversation_id,
|
||||
ConversationFeedback.event_id.in_(event_ids),
|
||||
)
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(ConversationFeedback).where(
|
||||
ConversationFeedback.conversation_id == conversation_id,
|
||||
ConversationFeedback.event_id.in_(event_ids),
|
||||
)
|
||||
)
|
||||
|
||||
# Create a mapping of event_id to feedback
|
||||
feedback_map = {
|
||||
feedback.event_id: {
|
||||
'exists': True,
|
||||
'rating': feedback.rating,
|
||||
'reason': feedback.reason,
|
||||
}
|
||||
for feedback in result.scalars()
|
||||
# Create a mapping of event_id to feedback
|
||||
feedback_map = {
|
||||
feedback.event_id: {
|
||||
'exists': True,
|
||||
'rating': feedback.rating,
|
||||
'reason': feedback.reason,
|
||||
}
|
||||
for feedback in result.scalars()
|
||||
}
|
||||
|
||||
# Build response including all events
|
||||
response = {}
|
||||
for event_id in event_ids:
|
||||
response[str(event_id)] = feedback_map.get(event_id, {'exists': False})
|
||||
# Build response including all events
|
||||
response = {}
|
||||
for event_id in event_ids:
|
||||
response[str(event_id)] = feedback_map.get(event_id, {'exists': False})
|
||||
|
||||
return response
|
||||
|
||||
return await call_sync_from_async(_check_feedback)
|
||||
return response
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import zlib
|
||||
from base64 import b64decode, b64encode
|
||||
from urllib.parse import parse_qs, urlencode, urlparse
|
||||
|
||||
@@ -51,7 +52,11 @@ def add_github_proxy_routes(app: FastAPI):
|
||||
state_payload = json.dumps(
|
||||
[query_params['state'][0], query_params['redirect_uri'][0]]
|
||||
)
|
||||
state = b64encode(_fernet().encrypt(state_payload.encode())).decode()
|
||||
# Compress before encrypting to reduce URL length
|
||||
# This is critical for feature deployments where reCAPTCHA tokens in state
|
||||
# can cause "URL too long" errors from GitHub
|
||||
compressed_payload = zlib.compress(state_payload.encode())
|
||||
state = b64encode(_fernet().encrypt(compressed_payload)).decode()
|
||||
query_params['state'] = [state]
|
||||
query_params['redirect_uri'] = [
|
||||
f'https://{request.url.netloc}/github-proxy/callback'
|
||||
@@ -67,7 +72,9 @@ def add_github_proxy_routes(app: FastAPI):
|
||||
parsed_url = urlparse(str(request.url))
|
||||
query_params = parse_qs(parsed_url.query)
|
||||
state = query_params['state'][0]
|
||||
decrypted_state = _fernet().decrypt(b64decode(state.encode())).decode()
|
||||
# Decrypt and decompress (reverse of github_proxy_start)
|
||||
decrypted_payload = _fernet().decrypt(b64decode(state.encode()))
|
||||
decrypted_state = zlib.decompress(decrypted_payload).decode()
|
||||
|
||||
# Build query Params
|
||||
state, redirect_uri = json.loads(decrypted_state)
|
||||
|
||||
@@ -1,19 +1,22 @@
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import uuid
|
||||
from urllib.parse import urlparse
|
||||
from urllib.parse import urlencode, urlparse
|
||||
|
||||
import requests
|
||||
from fastapi import APIRouter, BackgroundTasks, HTTPException, Request, status
|
||||
from fastapi import APIRouter, BackgroundTasks, Header, HTTPException, Request, status
|
||||
from fastapi.responses import JSONResponse, RedirectResponse
|
||||
from integrations.jira.jira_manager import JiraManager
|
||||
from integrations.models import Message, SourceType
|
||||
from integrations.utils import HOST_URL
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from server.auth.constants import JIRA_CLIENT_ID, JIRA_CLIENT_SECRET
|
||||
from server.auth.saas_user_auth import SaasUserAuth
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.constants import WEB_HOST
|
||||
from storage.jira_workspace import JiraWorkspace
|
||||
from storage.redis import create_redis_client
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
@@ -24,7 +27,7 @@ JIRA_WEBHOOKS_ENABLED = os.environ.get('JIRA_WEBHOOKS_ENABLED', '0') in (
|
||||
'1',
|
||||
'true',
|
||||
)
|
||||
JIRA_REDIRECT_URI = f'https://{WEB_HOST}/integration/jira/callback'
|
||||
JIRA_REDIRECT_URI = f'{HOST_URL}/integration/jira/callback'
|
||||
JIRA_SCOPES = 'read:me read:jira-user read:jira-work'
|
||||
JIRA_AUTH_URL = 'https://auth.atlassian.com/authorize'
|
||||
JIRA_TOKEN_URL = 'https://auth.atlassian.com/oauth/token'
|
||||
@@ -122,6 +125,63 @@ jira_manager = JiraManager(token_manager)
|
||||
redis_client = create_redis_client()
|
||||
|
||||
|
||||
async def verify_jira_signature(body: bytes, signature: str, payload: dict):
|
||||
"""
|
||||
Verify Jira webhook signature.
|
||||
|
||||
Args:
|
||||
body: Raw request body bytes
|
||||
signature: Signature from x-hub-signature header (format: "sha256=<hash>")
|
||||
payload: Parsed JSON payload from webhook
|
||||
|
||||
Raises:
|
||||
HTTPException: 403 if signature verification fails or workspace is invalid
|
||||
|
||||
Returns:
|
||||
None (raises exception on failure)
|
||||
"""
|
||||
|
||||
if not signature:
|
||||
raise HTTPException(
|
||||
status_code=403, detail='x-hub-signature header is missing!'
|
||||
)
|
||||
|
||||
workspace_name = jira_manager.get_workspace_name_from_payload(payload)
|
||||
if workspace_name is None:
|
||||
logger.warning('[Jira] No workspace name found in webhook payload')
|
||||
raise HTTPException(
|
||||
status_code=403, detail='Workspace name not found in payload'
|
||||
)
|
||||
|
||||
workspace: (
|
||||
JiraWorkspace | None
|
||||
) = await jira_manager.integration_store.get_workspace_by_name(workspace_name)
|
||||
|
||||
if workspace is None:
|
||||
logger.warning(f'[Jira] Could not identify workspace {workspace_name}')
|
||||
raise HTTPException(status_code=403, detail='Unidentified workspace')
|
||||
|
||||
if workspace.status != 'active':
|
||||
logger.warning(
|
||||
'[Jira] Workspace is inactive',
|
||||
extra={
|
||||
'jira_workspace_id': workspace.id,
|
||||
'parsed_workspace_name': workspace.name,
|
||||
'status': workspace.status,
|
||||
},
|
||||
)
|
||||
|
||||
raise HTTPException(status_code=403, detail='Workspace is inactive')
|
||||
|
||||
webhook_secret = token_manager.decrypt_text(workspace.webhook_secret)
|
||||
expected_signature = hmac.new(
|
||||
webhook_secret.encode(), body, hashlib.sha256
|
||||
).hexdigest()
|
||||
|
||||
if not hmac.compare_digest(expected_signature, signature):
|
||||
raise HTTPException(status_code=403, detail="Request signatures didn't match!")
|
||||
|
||||
|
||||
async def _handle_workspace_link_creation(
|
||||
user_id: str, jira_user_id: str, target_workspace: str
|
||||
):
|
||||
@@ -216,6 +276,7 @@ async def _validate_workspace_update_permissions(user_id: str, target_workspace:
|
||||
async def jira_events(
|
||||
request: Request,
|
||||
background_tasks: BackgroundTasks,
|
||||
x_hub_signature: str = Header(None),
|
||||
):
|
||||
"""Handle Jira webhook events."""
|
||||
# Check if Jira webhooks are enabled
|
||||
@@ -227,13 +288,15 @@ async def jira_events(
|
||||
)
|
||||
|
||||
try:
|
||||
signature_valid, signature, payload = await jira_manager.validate_request(
|
||||
request
|
||||
)
|
||||
parts = x_hub_signature.split('=', 1)
|
||||
if not (len(parts) == 2 and parts[1]):
|
||||
raise HTTPException(status_code=403, detail='Malformed x-hub-signature!')
|
||||
|
||||
if not signature_valid:
|
||||
logger.warning('[Jira] Invalid webhook signature')
|
||||
raise HTTPException(status_code=403, detail='Invalid webhook signature!')
|
||||
signature = parts[1]
|
||||
body = await request.body()
|
||||
payload = await request.json()
|
||||
|
||||
await verify_jira_signature(body, signature, payload)
|
||||
|
||||
# Check for duplicate requests using Redis
|
||||
key = f'jira:{signature}'
|
||||
@@ -245,10 +308,11 @@ async def jira_events(
|
||||
logger.info(f'Processing new Jira webhook event: {signature}')
|
||||
redis_client.setex(key, 300, '1')
|
||||
|
||||
# Process the webhook
|
||||
# Process the webhook in background after returning response.
|
||||
# Note: For async functions, BackgroundTasks runs them in the same event loop
|
||||
# (not a thread pool), so asyncpg connections work correctly.
|
||||
message_payload = {'payload': payload}
|
||||
message = Message(source=SourceType.JIRA, message=message_payload)
|
||||
|
||||
background_tasks.add_task(jira_manager.receive_message, message)
|
||||
|
||||
return JSONResponse({'success': True})
|
||||
@@ -308,9 +372,7 @@ async def create_jira_workspace(request: Request, workspace_data: JiraWorkspaceC
|
||||
'prompt': 'consent',
|
||||
}
|
||||
|
||||
auth_url = (
|
||||
f"{JIRA_AUTH_URL}?{'&'.join([f'{k}={v}' for k, v in auth_params.items()])}"
|
||||
)
|
||||
auth_url = f'{JIRA_AUTH_URL}?{urlencode(auth_params)}'
|
||||
|
||||
return JSONResponse(
|
||||
content={
|
||||
@@ -369,9 +431,7 @@ async def create_workspace_link(request: Request, link_data: JiraLinkCreate):
|
||||
'response_type': 'code',
|
||||
'prompt': 'consent',
|
||||
}
|
||||
auth_url = (
|
||||
f"{JIRA_AUTH_URL}?{'&'.join([f'{k}={v}' for k, v in auth_params.items()])}"
|
||||
)
|
||||
auth_url = f'{JIRA_AUTH_URL}?{urlencode(auth_params)}'
|
||||
|
||||
return JSONResponse(
|
||||
content={
|
||||
|
||||
@@ -2,7 +2,7 @@ import json
|
||||
import os
|
||||
import re
|
||||
import uuid
|
||||
from urllib.parse import urlparse
|
||||
from urllib.parse import urlencode, urlparse
|
||||
|
||||
import requests
|
||||
from fastapi import (
|
||||
@@ -316,7 +316,7 @@ async def create_jira_dc_workspace(
|
||||
'response_type': 'code',
|
||||
}
|
||||
|
||||
auth_url = f"{JIRA_DC_AUTH_URL}?{'&'.join([f'{k}={v}' for k, v in auth_params.items()])}"
|
||||
auth_url = f'{JIRA_DC_AUTH_URL}?{urlencode(auth_params)}'
|
||||
|
||||
return JSONResponse(
|
||||
content={
|
||||
@@ -436,7 +436,7 @@ async def create_workspace_link(request: Request, link_data: JiraDcLinkCreate):
|
||||
'state': state,
|
||||
'response_type': 'code',
|
||||
}
|
||||
auth_url = f"{JIRA_DC_AUTH_URL}?{'&'.join([f'{k}={v}' for k, v in auth_params.items()])}"
|
||||
auth_url = f'{JIRA_DC_AUTH_URL}?{urlencode(auth_params)}'
|
||||
|
||||
return JSONResponse(
|
||||
content={
|
||||
|
||||
@@ -15,7 +15,6 @@ from integrations.slack.slack_manager import SlackManager
|
||||
from integrations.utils import (
|
||||
HOST_URL,
|
||||
)
|
||||
from pydantic import SecretStr
|
||||
from server.auth.constants import (
|
||||
KEYCLOAK_CLIENT_ID,
|
||||
KEYCLOAK_REALM_NAME,
|
||||
@@ -35,6 +34,7 @@ from slack_sdk.web.async_client import AsyncWebClient
|
||||
from storage.database import session_maker
|
||||
from storage.slack_team_store import SlackTeamStore
|
||||
from storage.slack_user import SlackUser
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.integrations.service_types import ProviderType
|
||||
from openhands.server.shared import config, sio
|
||||
@@ -79,6 +79,14 @@ async def install_callback(
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
if not config.jwt_secret:
|
||||
logger.error('slack_install_callback_error JWT not configured.')
|
||||
return _html_response(
|
||||
title='Error',
|
||||
description=html.escape('JWT not configured'),
|
||||
status_code=500,
|
||||
)
|
||||
|
||||
try:
|
||||
client = AsyncWebClient() # no prepared token needed for this
|
||||
# Complete the installation by calling oauth.v2.access API method
|
||||
@@ -94,16 +102,17 @@ async def install_callback(
|
||||
|
||||
# Create a state variable for keycloak oauth
|
||||
payload = {}
|
||||
jwt_secret: SecretStr = config.jwt_secret # type: ignore[assignment]
|
||||
if state:
|
||||
payload = jwt.decode(
|
||||
state, jwt_secret.get_secret_value(), algorithms=['HS256']
|
||||
state, config.jwt_secret.get_secret_value(), algorithms=['HS256']
|
||||
)
|
||||
payload['slack_user_id'] = authed_user.get('id')
|
||||
payload['bot_access_token'] = bot_access_token
|
||||
payload['team_id'] = team_id
|
||||
|
||||
state = jwt.encode(payload, jwt_secret.get_secret_value(), algorithm='HS256')
|
||||
state = jwt.encode(
|
||||
payload, config.jwt_secret.get_secret_value(), algorithm='HS256'
|
||||
)
|
||||
|
||||
# Redirect into keycloak
|
||||
scope = quote('openid email profile offline_access')
|
||||
@@ -149,9 +158,16 @@ async def keycloak_callback(
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
jwt_secret: SecretStr = config.jwt_secret # type: ignore[assignment]
|
||||
if not config.jwt_secret:
|
||||
logger.error('problem_retrieving_keycloak_tokens JWT not configured.')
|
||||
return _html_response(
|
||||
title='Error',
|
||||
description=html.escape('JWT not configured'),
|
||||
status_code=500,
|
||||
)
|
||||
|
||||
payload: dict[str, str] = jwt.decode(
|
||||
state, jwt_secret.get_secret_value(), algorithms=['HS256']
|
||||
state, config.jwt_secret.get_secret_value(), algorithms=['HS256']
|
||||
)
|
||||
slack_user_id = payload['slack_user_id']
|
||||
bot_access_token = payload['bot_access_token']
|
||||
@@ -180,6 +196,13 @@ async def keycloak_callback(
|
||||
|
||||
user_info = await token_manager.get_user_info(keycloak_access_token)
|
||||
keycloak_user_id = user_info['sub']
|
||||
user = await UserStore.get_user_by_id_async(keycloak_user_id)
|
||||
if not user:
|
||||
return _html_response(
|
||||
title='Failed to authenticate.',
|
||||
description=f'Please re-login into <a href="{HOST_URL}" style="color:#ecedee;text-decoration:underline;">OpenHands Cloud</a>. Then try <a href="https://docs.all-hands.dev/usage/cloud/slack-installation" style="color:#ecedee;text-decoration:underline;">installing the OpenHands Slack App</a> again',
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
# These tokens are offline access tokens - store them!
|
||||
await token_manager.store_offline_token(keycloak_user_id, keycloak_refresh_token)
|
||||
@@ -211,6 +234,7 @@ async def keycloak_callback(
|
||||
slack_display_name = slack_user_info.data['user']['profile']['display_name']
|
||||
slack_user = SlackUser(
|
||||
keycloak_user_id=keycloak_user_id,
|
||||
org_id=user.current_org_id,
|
||||
slack_user_id=slack_user_id,
|
||||
slack_display_name=slack_display_name,
|
||||
)
|
||||
@@ -305,7 +329,7 @@ async def on_form_interaction(request: Request, background_tasks: BackgroundTask
|
||||
|
||||
body = await request.body()
|
||||
form = await request.form()
|
||||
payload = json.loads(form.get('payload')) # type: ignore[arg-type]
|
||||
payload = json.loads(form.get('payload'))
|
||||
|
||||
logger.info('slack_on_form_interaction', extra={'payload': payload})
|
||||
|
||||
|
||||
@@ -7,7 +7,6 @@ from fastapi import APIRouter, Depends, Form, HTTPException, Request, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
from storage.api_key_store import ApiKeyStore
|
||||
from storage.database import session_maker
|
||||
from storage.device_code_store import DeviceCodeStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
@@ -21,7 +20,7 @@ DEVICE_CODE_EXPIRES_IN = 600 # 10 minutes
|
||||
DEVICE_TOKEN_POLL_INTERVAL = 5 # seconds
|
||||
|
||||
API_KEY_NAME = 'Device Link Access Key'
|
||||
KEY_EXPIRATION_TIME = timedelta(days=1) # Key expires in 24 hours
|
||||
KEY_EXPIRATION_TIME = timedelta(days=7) # Key expires in a week
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Models
|
||||
@@ -54,7 +53,7 @@ class DeviceTokenErrorResponse(BaseModel):
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
oauth_device_router = APIRouter(prefix='/oauth/device')
|
||||
device_code_store = DeviceCodeStore(session_maker)
|
||||
device_code_store = DeviceCodeStore()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -90,7 +89,7 @@ async def device_authorization(
|
||||
) -> DeviceAuthorizationResponse:
|
||||
"""Start device flow by generating device and user codes."""
|
||||
try:
|
||||
device_code_entry = device_code_store.create_device_code(
|
||||
device_code_entry = await device_code_store.create_device_code(
|
||||
expires_in=DEVICE_CODE_EXPIRES_IN,
|
||||
)
|
||||
|
||||
@@ -125,7 +124,7 @@ async def device_authorization(
|
||||
async def device_token(device_code: str = Form(...)):
|
||||
"""Poll for a token until the user authorizes or the code expires."""
|
||||
try:
|
||||
device_code_entry = device_code_store.get_by_device_code(device_code)
|
||||
device_code_entry = await device_code_store.get_by_device_code(device_code)
|
||||
|
||||
if not device_code_entry:
|
||||
return _oauth_error(
|
||||
@@ -138,7 +137,9 @@ async def device_token(device_code: str = Form(...)):
|
||||
is_too_fast, current_interval = device_code_entry.check_rate_limit()
|
||||
if is_too_fast:
|
||||
# Update poll time and increase interval
|
||||
device_code_store.update_poll_time(device_code, increase_interval=True)
|
||||
await device_code_store.update_poll_time(
|
||||
device_code, increase_interval=True
|
||||
)
|
||||
logger.warning(
|
||||
'Client polling too fast, returning slow_down error',
|
||||
extra={
|
||||
@@ -154,7 +155,7 @@ async def device_token(device_code: str = Form(...)):
|
||||
)
|
||||
|
||||
# Update poll time for successful rate limit check
|
||||
device_code_store.update_poll_time(device_code, increase_interval=False)
|
||||
await device_code_store.update_poll_time(device_code, increase_interval=False)
|
||||
|
||||
if device_code_entry.is_expired():
|
||||
return _oauth_error(
|
||||
@@ -181,7 +182,7 @@ async def device_token(device_code: str = Form(...)):
|
||||
# Retrieve the specific API key for this device using the user_code
|
||||
api_key_store = ApiKeyStore.get_instance()
|
||||
device_key_name = f'{API_KEY_NAME} ({device_code_entry.user_code})'
|
||||
device_api_key = api_key_store.retrieve_api_key_by_name(
|
||||
device_api_key = await api_key_store.retrieve_api_key_by_name(
|
||||
device_code_entry.keycloak_user_id, device_key_name
|
||||
)
|
||||
|
||||
@@ -238,7 +239,7 @@ async def device_verification_authenticated(
|
||||
)
|
||||
|
||||
# Validate device code
|
||||
device_code_entry = device_code_store.get_by_user_code(user_code)
|
||||
device_code_entry = await device_code_store.get_by_user_code(user_code)
|
||||
if not device_code_entry:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
@@ -252,7 +253,7 @@ async def device_verification_authenticated(
|
||||
)
|
||||
|
||||
# First, authorize the device code
|
||||
success = device_code_store.authorize_device_code(
|
||||
success = await device_code_store.authorize_device_code(
|
||||
user_code=user_code,
|
||||
user_id=user_id,
|
||||
)
|
||||
@@ -272,7 +273,7 @@ async def device_verification_authenticated(
|
||||
try:
|
||||
# Create a unique API key for this device using user_code in the name
|
||||
device_key_name = f'{API_KEY_NAME} ({user_code})'
|
||||
api_key_store.create_api_key(
|
||||
await api_key_store.create_api_key(
|
||||
user_id,
|
||||
name=device_key_name,
|
||||
expires_at=datetime.now(UTC) + KEY_EXPIRATION_TIME,
|
||||
@@ -289,7 +290,7 @@ async def device_verification_authenticated(
|
||||
# Clean up: revert the device authorization since API key creation failed
|
||||
# This prevents the device from being in an authorized state without an API key
|
||||
try:
|
||||
device_code_store.deny_device_code(user_code)
|
||||
await device_code_store.deny_device_code(user_code)
|
||||
logger.info(
|
||||
'Reverted device authorization due to API key creation failure',
|
||||
extra={'user_code': user_code, 'user_id': user_id},
|
||||
|
||||
122
enterprise/server/routes/org_invitation_models.py
Normal file
122
enterprise/server/routes/org_invitation_models.py
Normal file
@@ -0,0 +1,122 @@
|
||||
"""
|
||||
Pydantic models and custom exceptions for organization invitations.
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel, EmailStr
|
||||
from storage.org_invitation import OrgInvitation
|
||||
from storage.role_store import RoleStore
|
||||
|
||||
|
||||
class InvitationError(Exception):
|
||||
"""Base exception for invitation errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class InvitationAlreadyExistsError(InvitationError):
|
||||
"""Raised when a pending invitation already exists for the email."""
|
||||
|
||||
def __init__(
|
||||
self, message: str = 'A pending invitation already exists for this email'
|
||||
):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class UserAlreadyMemberError(InvitationError):
|
||||
"""Raised when the user is already a member of the organization."""
|
||||
|
||||
def __init__(self, message: str = 'User is already a member of this organization'):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class InvitationExpiredError(InvitationError):
|
||||
"""Raised when the invitation has expired."""
|
||||
|
||||
def __init__(self, message: str = 'Invitation has expired'):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class InvitationInvalidError(InvitationError):
|
||||
"""Raised when the invitation is invalid or revoked."""
|
||||
|
||||
def __init__(self, message: str = 'Invitation is no longer valid'):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class InsufficientPermissionError(InvitationError):
|
||||
"""Raised when the user lacks permission to perform the action."""
|
||||
|
||||
def __init__(self, message: str = 'Insufficient permission'):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class EmailMismatchError(InvitationError):
|
||||
"""Raised when the accepting user's email doesn't match the invitation email."""
|
||||
|
||||
def __init__(self, message: str = 'Your email does not match the invitation'):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class InvitationCreate(BaseModel):
|
||||
"""Request model for creating invitation(s)."""
|
||||
|
||||
emails: list[EmailStr]
|
||||
role: str = 'member' # Default to member role
|
||||
|
||||
|
||||
class InvitationResponse(BaseModel):
|
||||
"""Response model for invitation details."""
|
||||
|
||||
id: int
|
||||
email: str
|
||||
role: str
|
||||
status: str
|
||||
created_at: str
|
||||
expires_at: str
|
||||
inviter_email: str | None = None
|
||||
|
||||
@classmethod
|
||||
def from_invitation(
|
||||
cls,
|
||||
invitation: OrgInvitation,
|
||||
inviter_email: str | None = None,
|
||||
) -> 'InvitationResponse':
|
||||
"""Create an InvitationResponse from an OrgInvitation entity.
|
||||
|
||||
Args:
|
||||
invitation: The invitation entity to convert
|
||||
inviter_email: Optional email of the inviter
|
||||
|
||||
Returns:
|
||||
InvitationResponse: The response model instance
|
||||
"""
|
||||
role_name = ''
|
||||
if invitation.role:
|
||||
role_name = invitation.role.name
|
||||
elif invitation.role_id:
|
||||
role = RoleStore.get_role_by_id(invitation.role_id)
|
||||
role_name = role.name if role else ''
|
||||
|
||||
return cls(
|
||||
id=invitation.id,
|
||||
email=invitation.email,
|
||||
role=role_name,
|
||||
status=invitation.status,
|
||||
created_at=invitation.created_at.isoformat(),
|
||||
expires_at=invitation.expires_at.isoformat(),
|
||||
inviter_email=inviter_email,
|
||||
)
|
||||
|
||||
|
||||
class InvitationFailure(BaseModel):
|
||||
"""Response model for a failed invitation."""
|
||||
|
||||
email: str
|
||||
error: str
|
||||
|
||||
|
||||
class BatchInvitationResponse(BaseModel):
|
||||
"""Response model for batch invitation creation."""
|
||||
|
||||
successful: list[InvitationResponse]
|
||||
failed: list[InvitationFailure]
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user