mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-04-29 03:00:45 -04:00
Compare commits
192 Commits
openhands/
...
add-isolat
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7c39df1d4b | ||
|
|
f107e21d26 | ||
|
|
516591c012 | ||
|
|
12d6da8130 | ||
|
|
9efb67a3bd | ||
|
|
38f2728cfa | ||
|
|
c5ef7a5944 | ||
|
|
20366ba973 | ||
|
|
df03a56888 | ||
|
|
d202c90f5f | ||
|
|
fab48fe864 | ||
|
|
a196881ab0 | ||
|
|
ca2c9546ad | ||
|
|
704fc6dd69 | ||
|
|
6630d5dc4e | ||
|
|
0e7fefca7e | ||
|
|
4020448d64 | ||
|
|
2fdd4d084a | ||
|
|
aba5d54a86 | ||
|
|
6710a39621 | ||
|
|
7addb78158 | ||
|
|
8afa6cf51b | ||
|
|
1289688b64 | ||
|
|
e349d37b8c | ||
|
|
6fec7b729d | ||
|
|
fccc6f3196 | ||
|
|
cd05434d7f | ||
|
|
9e7b74ea32 | ||
|
|
7447cfdb3d | ||
|
|
297af05d53 | ||
|
|
b8f387df94 | ||
|
|
fc67f39b74 | ||
|
|
bc8922d3f9 | ||
|
|
37d58bba4d | ||
|
|
037a2dca8f | ||
|
|
b5920eece6 | ||
|
|
a81bef8cdf | ||
|
|
450aa3b527 | ||
|
|
4646439108 | ||
|
|
4decd8b3e9 | ||
|
|
818f743dc7 | ||
|
|
f402371b27 | ||
|
|
92b1fca719 | ||
|
|
8de13457c3 | ||
|
|
f89e41ac30 | ||
|
|
9b0029c5bb | ||
|
|
3f247952fa | ||
|
|
8f94b68ea1 | ||
|
|
eb616dfae4 | ||
|
|
dc360c8a5c | ||
|
|
26c636d63e | ||
|
|
3ec8d70d04 | ||
|
|
694ac74bb9 | ||
|
|
7ee20067a8 | ||
|
|
054c5b666f | ||
|
|
0ff7329424 | ||
|
|
86c590cdc3 | ||
|
|
319677e629 | ||
|
|
f8b566b858 | ||
|
|
f9694858fb | ||
|
|
7880c39ede | ||
|
|
b5e00f577c | ||
|
|
2631294e79 | ||
|
|
47776ae2ad | ||
|
|
0ad411e162 | ||
|
|
7bc56e0d74 | ||
|
|
e450a3a603 | ||
|
|
5f06aad131 | ||
|
|
26ca1cf2d7 | ||
|
|
75c9a09ad1 | ||
|
|
139a5f7caf | ||
|
|
4caa72d080 | ||
|
|
2f2a1c5c58 | ||
|
|
37e0f7fd6e | ||
|
|
b012176c9c | ||
|
|
a5e1a9fd99 | ||
|
|
0b0d77bcdf | ||
|
|
3791a76216 | ||
|
|
b921f06e2b | ||
|
|
07b8391605 | ||
|
|
2ec03b8c55 | ||
|
|
17e32af6fe | ||
|
|
4b303ec9b4 | ||
|
|
eb954164a5 | ||
|
|
8beb9b4638 | ||
|
|
0c1c2163b1 | ||
|
|
dd2a62c992 | ||
|
|
b40f55a328 | ||
|
|
4e0d553380 | ||
|
|
42c40d75b1 | ||
|
|
f3d9faef34 | ||
|
|
6e30c62078 | ||
|
|
f29161b7f3 | ||
|
|
7d084db6d7 | ||
|
|
134c122026 | ||
|
|
0ab08e93a6 | ||
|
|
523b40dbfc | ||
|
|
d3586bf820 | ||
|
|
e3dbb00d4e | ||
|
|
e11b2008f3 | ||
|
|
6a5b915088 | ||
|
|
a02b5a6c0e | ||
|
|
a5c5133961 | ||
|
|
3b3b05dc33 | ||
|
|
eea1e7f4e1 | ||
|
|
7d6392f793 | ||
|
|
ec3c33afac | ||
|
|
e2d990f3a0 | ||
|
|
f258eafa37 | ||
|
|
19634f364e | ||
|
|
aa6446038c | ||
|
|
dbddc1868e | ||
|
|
cd967ef4bc | ||
|
|
eb847de7ec | ||
|
|
c3e91baa53 | ||
|
|
d2003c83fb | ||
|
|
7c0a939d96 | ||
|
|
e34c13ea3c | ||
|
|
1f35a73cc4 | ||
|
|
267528fa82 | ||
|
|
49f360d021 | ||
|
|
f45b86a396 | ||
|
|
9520da668c | ||
|
|
9d19292619 | ||
|
|
fc9a87550d | ||
|
|
490d3dba10 | ||
|
|
d7bf698d1e | ||
|
|
d655049934 | ||
|
|
6357b46001 | ||
|
|
5ed1dde2e9 | ||
|
|
a68576b876 | ||
|
|
722124ae83 | ||
|
|
186f4423e0 | ||
|
|
44578664ed | ||
|
|
9efe6eb776 | ||
|
|
6d137e883f | ||
|
|
2889f736d9 | ||
|
|
531683abae | ||
|
|
fab64a51b7 | ||
|
|
cc18a18874 | ||
|
|
7525a95af0 | ||
|
|
640f50d525 | ||
|
|
6f2f85073d | ||
|
|
9f3b2425ec | ||
|
|
1ebc3ab04e | ||
|
|
9bd0566e4e | ||
|
|
d82972e126 | ||
|
|
e1b94732a8 | ||
|
|
baf323a26c | ||
|
|
cc7eef9fc0 | ||
|
|
c9a2a6c17f | ||
|
|
2a857a676f | ||
|
|
cf7096e80d | ||
|
|
cfd27b1dce | ||
|
|
c36b628879 | ||
|
|
a34cc6b7e7 | ||
|
|
d70006717e | ||
|
|
bf57a3ac6d | ||
|
|
ffc77fe229 | ||
|
|
82082fcee3 | ||
|
|
8d1f8c24f3 | ||
|
|
0369bc77dd | ||
|
|
1ef111d954 | ||
|
|
69db41aa1d | ||
|
|
a7118ddda6 | ||
|
|
86494cdd90 | ||
|
|
101aa68424 | ||
|
|
47b225d76d | ||
|
|
06758d352a | ||
|
|
6dc6f9514e | ||
|
|
08519c2e44 | ||
|
|
cc1e4b8c4a | ||
|
|
0d6ff3ac50 | ||
|
|
b15ffa29a5 | ||
|
|
5f2ce8e18a | ||
|
|
8f90374f49 | ||
|
|
4c38beb456 | ||
|
|
02f009e6b5 | ||
|
|
fed53185ac | ||
|
|
5cdebc3ed5 | ||
|
|
947fc2f616 | ||
|
|
939242fc22 | ||
|
|
f787f6a089 | ||
|
|
f687bcccf7 | ||
|
|
ba06aa3c0c | ||
|
|
36f516b337 | ||
|
|
3d4805f4b1 | ||
|
|
bf178fcc0e | ||
|
|
7c41d6f30f | ||
|
|
7906b38ded | ||
|
|
d74b0e3fc6 | ||
|
|
07b6ce5ed0 |
33
.github/pull_request_template.md
vendored
33
.github/pull_request_template.md
vendored
@@ -1,12 +1,31 @@
|
||||
- [ ] This change is worth documenting at https://docs.all-hands.dev/
|
||||
- [ ] Include this change in the Release Notes. If checked, you **must** provide an **end-user friendly** description for your change below
|
||||
## Summary of PR
|
||||
|
||||
**End-user friendly description of the problem this fixes or functionality this introduces.**
|
||||
<!-- Summarize what the PR does, explaining any non-trivial design decisions. -->
|
||||
|
||||
## Change Type
|
||||
|
||||
---
|
||||
**Summarize what the PR does, explaining any non-trivial design decisions.**
|
||||
<!-- Choose the types that apply to your PR and remove the rest. -->
|
||||
|
||||
- [ ] Bug fix
|
||||
- [ ] New feature
|
||||
- [ ] Breaking change
|
||||
- [ ] Refactor
|
||||
- [ ] Other (dependency update, docs, typo fixes, etc.)
|
||||
|
||||
---
|
||||
**Link of any specific issues this addresses:**
|
||||
## Checklist
|
||||
|
||||
- [ ] I have read and reviewed the code and I understand what the code is doing.
|
||||
- [ ] I have tested the code to the best of my ability and ensured it works as expected.
|
||||
|
||||
## Fixes
|
||||
|
||||
<!-- If this resolves an issue, link it here so it will close automatically upon merge. -->
|
||||
|
||||
Resolves #(issue)
|
||||
|
||||
## Release Notes
|
||||
|
||||
<!-- Check the box if this change is worth adding to the release notes. If checked, you must provide an
|
||||
end-user friendly description for your change below the checkbox. -->
|
||||
|
||||
- [ ] Include this change in the Release Notes.
|
||||
|
||||
6
.github/scripts/update_pr_description.sh
vendored
6
.github/scripts/update_pr_description.sh
vendored
@@ -13,12 +13,12 @@ DOCKER_RUN_COMMAND="docker run -it --rm \
|
||||
-p 3000:3000 \
|
||||
-v /var/run/docker.sock:/var/run/docker.sock \
|
||||
--add-host host.docker.internal:host-gateway \
|
||||
-e SANDBOX_RUNTIME_CONTAINER_IMAGE=docker.all-hands.dev/all-hands-ai/runtime:${SHORT_SHA}-nikolaik \
|
||||
-e SANDBOX_RUNTIME_CONTAINER_IMAGE=docker.openhands.dev/openhands/runtime:${SHORT_SHA}-nikolaik \
|
||||
--name openhands-app-${SHORT_SHA} \
|
||||
docker.all-hands.dev/all-hands-ai/openhands:${SHORT_SHA}"
|
||||
docker.openhands.dev/openhands/openhands:${SHORT_SHA}"
|
||||
|
||||
# Define the uvx command
|
||||
UVX_RUN_COMMAND="uvx --python 3.12 --from git+https://github.com/All-Hands-AI/OpenHands@${BRANCH_NAME}#subdirectory=openhands-cli openhands"
|
||||
UVX_RUN_COMMAND="uvx --python 3.12 --from git+https://github.com/OpenHands/OpenHands@${BRANCH_NAME}#subdirectory=openhands-cli openhands"
|
||||
|
||||
# Get the current PR body
|
||||
PR_BODY=$(gh pr view "$PR_NUMBER" --json body --jq .body)
|
||||
|
||||
2
.github/workflows/dispatch-to-docs.yml
vendored
2
.github/workflows/dispatch-to-docs.yml
vendored
@@ -12,7 +12,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
repo: ["All-Hands-AI/docs"]
|
||||
repo: ["OpenHands/docs"]
|
||||
steps:
|
||||
- name: Push to docs repo
|
||||
uses: peter-evans/repository-dispatch@v3
|
||||
|
||||
2
.github/workflows/enterprise-preview.yml
vendored
2
.github/workflows/enterprise-preview.yml
vendored
@@ -26,4 +26,4 @@ jobs:
|
||||
-H "Authorization: Bearer ${{ secrets.PAT_TOKEN }}" \
|
||||
-H "Accept: application/vnd.github+json" \
|
||||
-d "{\"ref\": \"main\", \"inputs\": {\"openhandsPrNumber\": \"${{ github.event.pull_request.number }}\", \"deployEnvironment\": \"feature\", \"enterpriseImageTag\": \"pr-${{ github.event.pull_request.number }}\" }}" \
|
||||
https://api.github.com/repos/All-Hands-AI/deploy/actions/workflows/deploy.yaml/dispatches
|
||||
https://api.github.com/repos/OpenHands/deploy/actions/workflows/deploy.yaml/dispatches
|
||||
|
||||
6
.github/workflows/ghcr-build.yml
vendored
6
.github/workflows/ghcr-build.yml
vendored
@@ -37,7 +37,6 @@ jobs:
|
||||
shell: bash
|
||||
id: define-base-images
|
||||
run: |
|
||||
# Only build nikolaik on PRs, otherwise build both nikolaik and ubuntu.
|
||||
if [[ "$GITHUB_EVENT_NAME" == "pull_request" ]]; then
|
||||
json=$(jq -n -c '[
|
||||
{ image: "nikolaik/python-nodejs:python3.12-nodejs22", tag: "nikolaik" },
|
||||
@@ -46,7 +45,6 @@ jobs:
|
||||
else
|
||||
json=$(jq -n -c '[
|
||||
{ image: "nikolaik/python-nodejs:python3.12-nodejs22", tag: "nikolaik" },
|
||||
{ image: "ghcr.io/all-hands-ai/python-nodejs:python3.13-nodejs22-trixie", tag: "trixie" },
|
||||
{ image: "ubuntu:24.04", tag: "ubuntu" }
|
||||
]')
|
||||
fi
|
||||
@@ -200,7 +198,7 @@ jobs:
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: ghcr.io/all-hands-ai/enterprise-server
|
||||
images: ghcr.io/openhands/enterprise-server
|
||||
tags: |
|
||||
type=ref,event=branch
|
||||
type=ref,event=pr
|
||||
@@ -252,7 +250,7 @@ jobs:
|
||||
-H "Authorization: Bearer ${{ secrets.PAT_TOKEN }}" \
|
||||
-H "Accept: application/vnd.github+json" \
|
||||
-d "{\"ref\": \"main\", \"inputs\": {\"openhandsPrNumber\": \"${{ github.event.pull_request.number }}\", \"deployEnvironment\": \"feature\", \"enterpriseImageTag\": \"pr-${{ github.event.pull_request.number }}\" }}" \
|
||||
https://api.github.com/repos/All-Hands-AI/deploy/actions/workflows/deploy.yaml/dispatches
|
||||
https://api.github.com/repos/OpenHands/deploy/actions/workflows/deploy.yaml/dispatches
|
||||
|
||||
# Run unit tests with the Docker runtime Docker images as root
|
||||
test_runtime_root:
|
||||
|
||||
4
.github/workflows/openhands-resolver.yml
vendored
4
.github/workflows/openhands-resolver.yml
vendored
@@ -201,7 +201,7 @@ jobs:
|
||||
issue_number: ${{ env.ISSUE_NUMBER }},
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
body: `[OpenHands](https://github.com/All-Hands-AI/OpenHands) started fixing the ${issueType}! You can monitor the progress [here](https://github.com/${context.repo.owner}/${context.repo.repo}/actions/runs/${context.runId}).`
|
||||
body: `[OpenHands](https://github.com/OpenHands/OpenHands) started fixing the ${issueType}! You can monitor the progress [here](https://github.com/${context.repo.owner}/${context.repo.repo}/actions/runs/${context.runId}).`
|
||||
});
|
||||
|
||||
- name: Install OpenHands
|
||||
@@ -233,7 +233,7 @@ jobs:
|
||||
if (isExperimentalLabel || isIssueCommentExperimental || isReviewCommentExperimental) {
|
||||
console.log("Installing experimental OpenHands...");
|
||||
|
||||
await exec.exec("pip install git+https://github.com/all-hands-ai/openhands.git");
|
||||
await exec.exec("pip install git+https://github.com/openhands/openhands.git");
|
||||
} else {
|
||||
console.log("Installing from requirements.txt...");
|
||||
|
||||
|
||||
2
.github/workflows/run-eval.yml
vendored
2
.github/workflows/run-eval.yml
vendored
@@ -101,7 +101,7 @@ jobs:
|
||||
-H "Authorization: Bearer ${{ secrets.PAT_TOKEN }}" \
|
||||
-H "Accept: application/vnd.github+json" \
|
||||
-d "{\"ref\": \"main\", \"inputs\": {\"github-repo\": \"${{ steps.eval_params.outputs.repo_url }}\", \"github-branch\": \"${{ steps.eval_params.outputs.eval_branch }}\", \"pr-number\": \"${PR_NUMBER}\", \"eval-instances\": \"${{ steps.eval_params.outputs.eval_instances }}\"}}" \
|
||||
https://api.github.com/repos/All-Hands-AI/evaluation/actions/workflows/create-branch.yml/dispatches
|
||||
https://api.github.com/repos/OpenHands/evaluation/actions/workflows/create-branch.yml/dispatches
|
||||
|
||||
# Send Slack message
|
||||
if [[ "${{ github.event_name }}" == "pull_request" ]]; then
|
||||
|
||||
@@ -83,6 +83,116 @@ VSCode Extension:
|
||||
- Use `vscode.window.createOutputChannel()` for debug logging instead of `showErrorMessage()` popups
|
||||
- Pre-commit process runs both frontend and backend checks when committing extension changes
|
||||
|
||||
## Enterprise Directory
|
||||
|
||||
The `enterprise/` directory contains additional functionality that extends the open-source OpenHands codebase. This includes:
|
||||
- Authentication and user management (Keycloak integration)
|
||||
- Database migrations (Alembic)
|
||||
- Integration services (GitHub, GitLab, Jira, Linear, Slack)
|
||||
- Billing and subscription management (Stripe)
|
||||
- Telemetry and analytics (PostHog, custom metrics framework)
|
||||
|
||||
### Enterprise Development Setup
|
||||
|
||||
**Prerequisites:**
|
||||
- Python 3.12
|
||||
- Poetry (for dependency management)
|
||||
- Node.js 22.x (for frontend)
|
||||
- Docker (optional)
|
||||
|
||||
**Setup Steps:**
|
||||
1. First, build the main OpenHands project: `make build`
|
||||
2. Then install enterprise dependencies: `cd enterprise && poetry install --with dev,test` (This can take a very long time. Be patient.)
|
||||
3. Set up enterprise pre-commit hooks: `poetry run pre-commit install --config ./dev_config/python/.pre-commit-config.yaml`
|
||||
|
||||
**Running Enterprise Tests:**
|
||||
```bash
|
||||
# Enterprise unit tests (full suite)
|
||||
PYTHONPATH=".:$PYTHONPATH" poetry run --project=enterprise pytest --forked -n auto -s -p no:ddtrace -p no:ddtrace.pytest_bdd -p no:ddtrace.pytest_benchmark ./enterprise/tests/unit --cov=enterprise --cov-branch
|
||||
|
||||
# Test specific modules (faster for development)
|
||||
cd enterprise
|
||||
PYTHONPATH=".:$PYTHONPATH" poetry run pytest tests/unit/telemetry/ --confcutdir=tests/unit/telemetry
|
||||
|
||||
# Enterprise linting (IMPORTANT: use --show-diff-on-failure to match GitHub CI)
|
||||
poetry run pre-commit run --all-files --show-diff-on-failure --config ./dev_config/python/.pre-commit-config.yaml
|
||||
```
|
||||
|
||||
**Running Enterprise Server:**
|
||||
```bash
|
||||
cd enterprise
|
||||
make start-backend # Development mode with hot reload
|
||||
# or
|
||||
make run # Full application (backend + frontend)
|
||||
```
|
||||
|
||||
**Key Configuration Files:**
|
||||
- `enterprise/pyproject.toml` - Enterprise-specific dependencies
|
||||
- `enterprise/Makefile` - Enterprise build and run commands
|
||||
- `enterprise/dev_config/python/` - Linting and type checking configuration
|
||||
- `enterprise/migrations/` - Database migration files
|
||||
|
||||
**Database Migrations:**
|
||||
Enterprise uses Alembic for database migrations. When making schema changes:
|
||||
1. Create migration files in `enterprise/migrations/versions/`
|
||||
2. Test migrations thoroughly
|
||||
3. The CI will check for migration conflicts on PRs
|
||||
|
||||
**Integration Development:**
|
||||
The enterprise codebase includes integrations for:
|
||||
- **GitHub** - PR management, webhooks, app installations
|
||||
- **GitLab** - Similar to GitHub but for GitLab instances
|
||||
- **Jira** - Issue tracking and project management
|
||||
- **Linear** - Modern issue tracking
|
||||
- **Slack** - Team communication and notifications
|
||||
|
||||
Each integration follows a consistent pattern with service classes, storage models, and API endpoints.
|
||||
|
||||
**Important Notes:**
|
||||
- Enterprise code is licensed under Polyform Free Trial License (30-day limit)
|
||||
- The enterprise server extends the OSS server through dynamic imports
|
||||
- Database changes require careful migration planning in `enterprise/migrations/`
|
||||
- Always test changes in both OSS and enterprise contexts
|
||||
- Use the enterprise-specific Makefile commands for development
|
||||
|
||||
**Enterprise Testing Best Practices:**
|
||||
|
||||
**Database Testing:**
|
||||
- Use SQLite in-memory databases (`sqlite:///:memory:`) for unit tests instead of real PostgreSQL
|
||||
- Create module-specific `conftest.py` files with database fixtures
|
||||
- Mock external database connections in unit tests to avoid dependency on running services
|
||||
- Use real database connections only for integration tests
|
||||
|
||||
**Import Patterns:**
|
||||
- Use relative imports without `enterprise.` prefix in enterprise code
|
||||
- Example: `from storage.database import session_maker` not `from enterprise.storage.database import session_maker`
|
||||
- This ensures code works in both OSS and enterprise contexts
|
||||
|
||||
**Test Structure:**
|
||||
- Place tests in `enterprise/tests/unit/` following the same structure as the source code
|
||||
- Use `--confcutdir=tests/unit/[module]` when testing specific modules
|
||||
- Create comprehensive fixtures for complex objects (databases, external services)
|
||||
- Write platform-agnostic tests (avoid hardcoded OS-specific assertions)
|
||||
|
||||
**Mocking Strategy:**
|
||||
- Use `AsyncMock` for async operations and `MagicMock` for complex objects
|
||||
- Mock all external dependencies (databases, APIs, file systems) in unit tests
|
||||
- Use `patch` with correct import paths (e.g., `telemetry.registry.logger` not `enterprise.telemetry.registry.logger`)
|
||||
- Test both success and failure scenarios with proper error handling
|
||||
|
||||
**Coverage Goals:**
|
||||
- Aim for 90%+ test coverage on new enterprise modules
|
||||
- Focus on critical business logic and error handling paths
|
||||
- Use `--cov-report=term-missing` to identify uncovered lines
|
||||
|
||||
**Troubleshooting:**
|
||||
- If tests fail, ensure all dependencies are installed: `poetry install --with dev,test`
|
||||
- For database issues, check migration status and run migrations if needed
|
||||
- For frontend issues, ensure the main OpenHands frontend is built: `make build`
|
||||
- Check logs in the `logs/` directory for runtime issues
|
||||
- If tests fail with import errors, verify `PYTHONPATH=".:$PYTHONPATH"` is set
|
||||
- **If GitHub CI fails but local linting passes**: Always use `--show-diff-on-failure` flag to match CI behavior exactly
|
||||
|
||||
## Template for Github Pull Request
|
||||
|
||||
If you are starting a pull request (PR), please follow the template in `.github/pull_request_template.md`.
|
||||
|
||||
16
.vscode/settings.json
vendored
16
.vscode/settings.json
vendored
@@ -3,4 +3,20 @@
|
||||
"files.eol": "\n",
|
||||
"files.trimTrailingWhitespace": true,
|
||||
"files.insertFinalNewline": true,
|
||||
|
||||
"python.defaultInterpreterPath": "./.venv/bin/python",
|
||||
"python.terminal.activateEnvironment": true,
|
||||
"python.analysis.autoImportCompletions": true,
|
||||
"python.analysis.autoSearchPaths": true,
|
||||
"python.analysis.extraPaths": [
|
||||
"./.venv/lib/python3.12/site-packages"
|
||||
],
|
||||
"python.analysis.packageIndexDepths": [
|
||||
{
|
||||
"name": "openhands",
|
||||
"depth": 10,
|
||||
"includeAllSymbols": true
|
||||
}
|
||||
],
|
||||
"python.analysis.stubPath": "./.venv/lib/python3.12/site-packages",
|
||||
}
|
||||
|
||||
@@ -124,7 +124,7 @@ These Slack etiquette guidelines are designed to foster an inclusive, respectful
|
||||
- Post questions or discussions in the most relevant channel (e.g., for [slack - #general](https://openhands-ai.slack.com/archives/C06P5NCGSFP) for general topics, [slack - #questions](https://openhands-ai.slack.com/archives/C06U8UTKSAD) for queries/questions.
|
||||
- When asking for help or raising issues, include necessary details like links, screenshots, or clear explanations to provide context.
|
||||
- Keep discussions in public channels whenever possible to allow others to benefit from the conversation, unless the matter is sensitive or private.
|
||||
- Always adhere to [our standards](https://github.com/All-Hands-AI/OpenHands/blob/main/CODE_OF_CONDUCT.md#our-standards) to ensure a welcoming and collaborative environment.
|
||||
- Always adhere to [our standards](https://github.com/OpenHands/OpenHands/blob/main/CODE_OF_CONDUCT.md#our-standards) to ensure a welcoming and collaborative environment.
|
||||
- If you choose to mute a channel, consider setting up alerts for topics that still interest you to stay engaged. For Slack, Go to Settings → Notifications → My Keywords to add specific keywords that will notify you when mentioned. For example, if you're here for discussions about LLMs, mute the channel if it’s too busy, but set notifications to alert you only when “LLMs” appears in messages.
|
||||
|
||||
## Attribution
|
||||
|
||||
@@ -8,7 +8,7 @@ If this resonates with you, we'd love to have you join us in our quest!
|
||||
|
||||
## 🤝 How to Join
|
||||
|
||||
Check out our [How to Join the Community section.](https://github.com/All-Hands-AI/OpenHands?tab=readme-ov-file#-how-to-join-the-community)
|
||||
Check out our [How to Join the Community section.](https://github.com/OpenHands/OpenHands?tab=readme-ov-file#-how-to-join-the-community)
|
||||
|
||||
## 💪 Becoming a Contributor
|
||||
|
||||
|
||||
@@ -13,15 +13,15 @@ To understand the codebase, please refer to the README in each module:
|
||||
|
||||
## Setting up Your Development Environment
|
||||
|
||||
We have a separate doc [Development.md](https://github.com/All-Hands-AI/OpenHands/blob/main/Development.md) that tells you how to set up a development workflow.
|
||||
We have a separate doc [Development.md](https://github.com/OpenHands/OpenHands/blob/main/Development.md) that tells you how to set up a development workflow.
|
||||
|
||||
## How Can I Contribute?
|
||||
|
||||
There are many ways that you can contribute:
|
||||
|
||||
1. **Download and use** OpenHands, and send [issues](https://github.com/All-Hands-AI/OpenHands/issues) when you encounter something that isn't working or a feature that you'd like to see.
|
||||
1. **Download and use** OpenHands, and send [issues](https://github.com/OpenHands/OpenHands/issues) when you encounter something that isn't working or a feature that you'd like to see.
|
||||
2. **Send feedback** after each session by [clicking the thumbs-up thumbs-down buttons](https://docs.all-hands.dev/usage/feedback), so we can see where things are working and failing, and also build an open dataset for training code agents.
|
||||
3. **Improve the Codebase** by sending [PRs](#sending-pull-requests-to-openhands) (see details below). In particular, we have some [good first issues](https://github.com/All-Hands-AI/OpenHands/labels/good%20first%20issue) that may be ones to start on.
|
||||
3. **Improve the Codebase** by sending [PRs](#sending-pull-requests-to-openhands) (see details below). In particular, we have some [good first issues](https://github.com/OpenHands/OpenHands/labels/good%20first%20issue) that may be ones to start on.
|
||||
|
||||
## What Can I Build?
|
||||
Here are a few ways you can help improve the codebase.
|
||||
@@ -35,7 +35,7 @@ of the application, please open an issue first, or better, join the #eng-ui-ux c
|
||||
to gather consensus from our design team first.
|
||||
|
||||
#### Improving the agent
|
||||
Our main agent is the CodeAct agent. You can [see its prompts here](https://github.com/All-Hands-AI/OpenHands/tree/main/openhands/agenthub/codeact_agent).
|
||||
Our main agent is the CodeAct agent. You can [see its prompts here](https://github.com/OpenHands/OpenHands/tree/main/openhands/agenthub/codeact_agent).
|
||||
|
||||
Changes to these prompts, and to the underlying behavior in Python, can have a huge impact on user experience.
|
||||
You can try modifying the prompts to see how they change the behavior of the agent as you use the app
|
||||
@@ -54,7 +54,7 @@ The agent needs a place to run code and commands. When you run OpenHands on your
|
||||
to do this by default. But there are other ways of creating a sandbox for the agent.
|
||||
|
||||
If you work for a company that provides a cloud-based runtime, you could help us add support for that runtime
|
||||
by implementing the [interface specified here](https://github.com/All-Hands-AI/OpenHands/blob/main/openhands/runtime/base.py).
|
||||
by implementing the [interface specified here](https://github.com/OpenHands/OpenHands/blob/main/openhands/runtime/base.py).
|
||||
|
||||
#### Testing
|
||||
When you write code, it is also good to write tests. Please navigate to the [`./tests`](./tests) folder to see existing test suites.
|
||||
@@ -84,7 +84,7 @@ For example, a PR title could be:
|
||||
- `refactor: modify package path`
|
||||
- `feat(frontend): xxxx`, where `(frontend)` means that this PR mainly focuses on the frontend component.
|
||||
|
||||
You may also check out previous PRs in the [PR list](https://github.com/All-Hands-AI/OpenHands/pulls).
|
||||
You may also check out previous PRs in the [PR list](https://github.com/OpenHands/OpenHands/pulls).
|
||||
|
||||
### Pull Request description
|
||||
- If your PR is small (such as a typo fix), you can go brief.
|
||||
@@ -97,7 +97,7 @@ please include a short message that we can add to our changelog.
|
||||
|
||||
### Opening Issues
|
||||
|
||||
If you notice any bugs or have any feature requests please open them via the [issues page](https://github.com/All-Hands-AI/OpenHands/issues). We will triage based on how critical the bug is or how potentially useful the improvement is, discuss, and implement the ones that the community has interest/effort for.
|
||||
If you notice any bugs or have any feature requests please open them via the [issues page](https://github.com/OpenHands/OpenHands/issues). We will triage based on how critical the bug is or how potentially useful the improvement is, discuss, and implement the ones that the community has interest/effort for.
|
||||
|
||||
Further, if you see an issue you like, please leave a "thumbs-up" or a comment, which will help us prioritize.
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
## Contributors
|
||||
|
||||
We would like to thank all the [contributors](https://github.com/All-Hands-AI/OpenHands/graphs/contributors) who have helped make OpenHands possible. We greatly appreciate your dedication and hard work.
|
||||
We would like to thank all the [contributors](https://github.com/OpenHands/OpenHands/graphs/contributors) who have helped make OpenHands possible. We greatly appreciate your dedication and hard work.
|
||||
|
||||
## Open Source Projects
|
||||
|
||||
@@ -14,7 +14,7 @@ OpenHands includes and adapts the following open source projects. We are gratefu
|
||||
|
||||
#### [Aider](https://github.com/paul-gauthier/aider)
|
||||
- License: Apache License 2.0
|
||||
- Description: AI pair programming tool. OpenHands has adapted and integrated its linter module for code-related tasks in [`agentskills utilities`](https://github.com/All-Hands-AI/OpenHands/tree/main/openhands/runtime/plugins/agent_skills/utils/aider)
|
||||
- Description: AI pair programming tool. OpenHands has adapted and integrated its linter module for code-related tasks in [`agentskills utilities`](https://github.com/OpenHands/OpenHands/tree/main/openhands/runtime/plugins/agent_skills/utils/aider)
|
||||
|
||||
#### [BrowserGym](https://github.com/ServiceNow/BrowserGym)
|
||||
- License: Apache License 2.0
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
This guide is for people working on OpenHands and editing the source code.
|
||||
If you wish to contribute your changes, check out the
|
||||
[CONTRIBUTING.md](https://github.com/All-Hands-AI/OpenHands/blob/main/CONTRIBUTING.md)
|
||||
[CONTRIBUTING.md](https://github.com/OpenHands/OpenHands/blob/main/CONTRIBUTING.md)
|
||||
on how to clone and setup the project initially before moving on. Otherwise,
|
||||
you can clone the OpenHands project directly.
|
||||
|
||||
@@ -159,7 +159,7 @@ poetry run pytest ./tests/unit/test_*.py
|
||||
To reduce build time (e.g., if no changes were made to the client-runtime component), you can use an existing Docker
|
||||
container image by setting the SANDBOX_RUNTIME_CONTAINER_IMAGE environment variable to the desired Docker image.
|
||||
|
||||
Example: `export SANDBOX_RUNTIME_CONTAINER_IMAGE=ghcr.io/all-hands-ai/runtime:0.59-nikolaik`
|
||||
Example: `export SANDBOX_RUNTIME_CONTAINER_IMAGE=ghcr.io/openhands/runtime:0.60-nikolaik`
|
||||
|
||||
## Develop inside Docker container
|
||||
|
||||
@@ -193,7 +193,7 @@ Here's a guide to the important documentation files in the repository:
|
||||
- [/README.md](./README.md): Main project overview, features, and basic setup instructions
|
||||
- [/Development.md](./Development.md) (this file): Comprehensive guide for developers working on OpenHands
|
||||
- [/CONTRIBUTING.md](./CONTRIBUTING.md): Guidelines for contributing to the project, including code style and PR process
|
||||
- [/docs/DOC_STYLE_GUIDE.md](./docs/DOC_STYLE_GUIDE.md): Standards for writing and maintaining project documentation
|
||||
- [DOC_STYLE_GUIDE.md](https://github.com/All-Hands-AI/docs/blob/main/openhands/DOC_STYLE_GUIDE.md): Standards for writing and maintaining project documentation
|
||||
- [/openhands/README.md](./openhands/README.md): Details about the backend Python implementation
|
||||
- [/frontend/README.md](./frontend/README.md): Frontend React application setup and development guide
|
||||
- [/containers/README.md](./containers/README.md): Information about Docker containers and deployment
|
||||
|
||||
42
README.md
42
README.md
@@ -7,26 +7,26 @@
|
||||
|
||||
|
||||
<div align="center">
|
||||
<a href="https://github.com/All-Hands-AI/OpenHands/graphs/contributors"><img src="https://img.shields.io/github/contributors/All-Hands-AI/OpenHands?style=for-the-badge&color=blue" alt="Contributors"></a>
|
||||
<a href="https://github.com/All-Hands-AI/OpenHands/stargazers"><img src="https://img.shields.io/github/stars/All-Hands-AI/OpenHands?style=for-the-badge&color=blue" alt="Stargazers"></a>
|
||||
<a href="https://github.com/All-Hands-AI/OpenHands/blob/main/LICENSE"><img src="https://img.shields.io/github/license/All-Hands-AI/OpenHands?style=for-the-badge&color=blue" alt="MIT License"></a>
|
||||
<a href="https://github.com/OpenHands/OpenHands/graphs/contributors"><img src="https://img.shields.io/github/contributors/OpenHands/OpenHands?style=for-the-badge&color=blue" alt="Contributors"></a>
|
||||
<a href="https://github.com/OpenHands/OpenHands/stargazers"><img src="https://img.shields.io/github/stars/OpenHands/OpenHands?style=for-the-badge&color=blue" alt="Stargazers"></a>
|
||||
<a href="https://github.com/OpenHands/OpenHands/blob/main/LICENSE"><img src="https://img.shields.io/github/license/OpenHands/OpenHands?style=for-the-badge&color=blue" alt="MIT License"></a>
|
||||
<br/>
|
||||
<a href="https://all-hands.dev/joinslack"><img src="https://img.shields.io/badge/Slack-Join%20Us-red?logo=slack&logoColor=white&style=for-the-badge" alt="Join our Slack community"></a>
|
||||
<a href="https://github.com/All-Hands-AI/OpenHands/blob/main/CREDITS.md"><img src="https://img.shields.io/badge/Project-Credits-blue?style=for-the-badge&color=FFE165&logo=github&logoColor=white" alt="Credits"></a>
|
||||
<a href="https://github.com/OpenHands/OpenHands/blob/main/CREDITS.md"><img src="https://img.shields.io/badge/Project-Credits-blue?style=for-the-badge&color=FFE165&logo=github&logoColor=white" alt="Credits"></a>
|
||||
<br/>
|
||||
<a href="https://docs.all-hands.dev/usage/getting-started"><img src="https://img.shields.io/badge/Documentation-000?logo=googledocs&logoColor=FFE165&style=for-the-badge" alt="Check out the documentation"></a>
|
||||
<a href="https://arxiv.org/abs/2407.16741"><img src="https://img.shields.io/badge/Paper%20on%20Arxiv-000?logoColor=FFE165&logo=arxiv&style=for-the-badge" alt="Paper on Arxiv"></a>
|
||||
<a href="https://docs.google.com/spreadsheets/d/1wOUdFCMyY6Nt0AIqF705KN4JKOWgeI4wUGUP60krXXs/edit?gid=0#gid=0"><img src="https://img.shields.io/badge/Benchmark%20score-000?logoColor=FFE165&logo=huggingface&style=for-the-badge" alt="Evaluation Benchmark Score"></a>
|
||||
|
||||
<!-- Keep these links. Translations will automatically update with the README. -->
|
||||
<a href="https://www.readme-i18n.com/All-Hands-AI/OpenHands?lang=de">Deutsch</a> |
|
||||
<a href="https://www.readme-i18n.com/All-Hands-AI/OpenHands?lang=es">Español</a> |
|
||||
<a href="https://www.readme-i18n.com/All-Hands-AI/OpenHands?lang=fr">français</a> |
|
||||
<a href="https://www.readme-i18n.com/All-Hands-AI/OpenHands?lang=ja">日本語</a> |
|
||||
<a href="https://www.readme-i18n.com/All-Hands-AI/OpenHands?lang=ko">한국어</a> |
|
||||
<a href="https://www.readme-i18n.com/All-Hands-AI/OpenHands?lang=pt">Português</a> |
|
||||
<a href="https://www.readme-i18n.com/All-Hands-AI/OpenHands?lang=ru">Русский</a> |
|
||||
<a href="https://www.readme-i18n.com/All-Hands-AI/OpenHands?lang=zh">中文</a>
|
||||
<a href="https://www.readme-i18n.com/OpenHands/OpenHands?lang=de">Deutsch</a> |
|
||||
<a href="https://www.readme-i18n.com/OpenHands/OpenHands?lang=es">Español</a> |
|
||||
<a href="https://www.readme-i18n.com/OpenHands/OpenHands?lang=fr">français</a> |
|
||||
<a href="https://www.readme-i18n.com/OpenHands/OpenHands?lang=ja">日本語</a> |
|
||||
<a href="https://www.readme-i18n.com/OpenHands/OpenHands?lang=ko">한국어</a> |
|
||||
<a href="https://www.readme-i18n.com/OpenHands/OpenHands?lang=pt">Português</a> |
|
||||
<a href="https://www.readme-i18n.com/OpenHands/OpenHands?lang=ru">Русский</a> |
|
||||
<a href="https://www.readme-i18n.com/OpenHands/OpenHands?lang=zh">中文</a>
|
||||
|
||||
<hr>
|
||||
</div>
|
||||
@@ -82,17 +82,17 @@ You'll find OpenHands running at [http://localhost:3000](http://localhost:3000)
|
||||
You can also run OpenHands directly with Docker:
|
||||
|
||||
```bash
|
||||
docker pull docker.all-hands.dev/all-hands-ai/runtime:0.59-nikolaik
|
||||
docker pull docker.openhands.dev/openhands/runtime:0.60-nikolaik
|
||||
|
||||
docker run -it --rm --pull=always \
|
||||
-e SANDBOX_RUNTIME_CONTAINER_IMAGE=docker.all-hands.dev/all-hands-ai/runtime:0.59-nikolaik \
|
||||
-e SANDBOX_RUNTIME_CONTAINER_IMAGE=docker.openhands.dev/openhands/runtime:0.60-nikolaik \
|
||||
-e LOG_ALL_EVENTS=true \
|
||||
-v /var/run/docker.sock:/var/run/docker.sock \
|
||||
-v ~/.openhands:/.openhands \
|
||||
-p 3000:3000 \
|
||||
--add-host host.docker.internal:host-gateway \
|
||||
--name openhands-app \
|
||||
docker.all-hands.dev/all-hands-ai/openhands:0.59
|
||||
docker.openhands.dev/openhands/openhands:0.60
|
||||
```
|
||||
|
||||
</details>
|
||||
@@ -119,7 +119,7 @@ system requirements and more information.
|
||||
> It is not appropriate for multi-tenant deployments where multiple users share the same instance. There is no built-in authentication, isolation, or scalability.
|
||||
>
|
||||
> If you're interested in running OpenHands in a multi-tenant environment, check out the source-available, commercially-licensed
|
||||
> [OpenHands Cloud Helm Chart](https://github.com/all-Hands-AI/OpenHands-cloud)
|
||||
> [OpenHands Cloud Helm Chart](https://github.com/openHands/OpenHands-cloud)
|
||||
|
||||
You can [connect OpenHands to your local filesystem](https://docs.all-hands.dev/usage/runtimes/docker#connecting-to-your-filesystem),
|
||||
interact with it via a [friendly CLI](https://docs.all-hands.dev/usage/how-to/cli-mode),
|
||||
@@ -128,7 +128,7 @@ or run it on tagged issues with [a github action](https://docs.all-hands.dev/usa
|
||||
|
||||
Visit [Running OpenHands](https://docs.all-hands.dev/usage/installation) for more information and setup instructions.
|
||||
|
||||
If you want to modify the OpenHands source code, check out [Development.md](https://github.com/All-Hands-AI/OpenHands/blob/main/Development.md).
|
||||
If you want to modify the OpenHands source code, check out [Development.md](https://github.com/OpenHands/OpenHands/blob/main/Development.md).
|
||||
|
||||
Having issues? The [Troubleshooting Guide](https://docs.all-hands.dev/usage/troubleshooting) can help.
|
||||
|
||||
@@ -146,17 +146,17 @@ OpenHands is a community-driven project, and we welcome contributions from every
|
||||
through Slack, so this is the best place to start, but we also are happy to have you contact us on Github:
|
||||
|
||||
- [Join our Slack workspace](https://all-hands.dev/joinslack) - Here we talk about research, architecture, and future development.
|
||||
- [Read or post Github Issues](https://github.com/All-Hands-AI/OpenHands/issues) - Check out the issues we're working on, or add your own ideas.
|
||||
- [Read or post Github Issues](https://github.com/OpenHands/OpenHands/issues) - Check out the issues we're working on, or add your own ideas.
|
||||
|
||||
See more about the community in [COMMUNITY.md](./COMMUNITY.md) or find details on contributing in [CONTRIBUTING.md](./CONTRIBUTING.md).
|
||||
|
||||
## 📈 Progress
|
||||
|
||||
See the monthly OpenHands roadmap [here](https://github.com/orgs/All-Hands-AI/projects/1) (updated at the maintainer's meeting at the end of each month).
|
||||
See the monthly OpenHands roadmap [here](https://github.com/orgs/OpenHands/projects/1) (updated at the maintainer's meeting at the end of each month).
|
||||
|
||||
<p align="center">
|
||||
<a href="https://star-history.com/#All-Hands-AI/OpenHands&Date">
|
||||
<img src="https://api.star-history.com/svg?repos=All-Hands-AI/OpenHands&type=Date" width="500" alt="Star History Chart">
|
||||
<a href="https://star-history.com/#OpenHands/OpenHands&Date">
|
||||
<img src="https://api.star-history.com/svg?repos=OpenHands/OpenHands&type=Date" width="500" alt="Star History Chart">
|
||||
</a>
|
||||
</p>
|
||||
|
||||
|
||||
@@ -189,7 +189,7 @@ model = "gpt-4o"
|
||||
# Whether to use native tool calling if supported by the model. Can be true, false, or None by default, which chooses the model's default behavior based on the evaluation.
|
||||
# ATTENTION: Based on evaluation, enabling native function calling may lead to worse results
|
||||
# in some scenarios. Use with caution and consider testing with your specific use case.
|
||||
# https://github.com/All-Hands-AI/OpenHands/pull/4711
|
||||
# https://github.com/OpenHands/OpenHands/pull/4711
|
||||
#native_tool_calling = None
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
DOCKER_REGISTRY=ghcr.io
|
||||
DOCKER_ORG=all-hands-ai
|
||||
DOCKER_ORG=openhands
|
||||
DOCKER_IMAGE=openhands
|
||||
DOCKER_BASE_DIR="."
|
||||
|
||||
@@ -104,6 +104,9 @@ RUN apt-get update && apt-get install -y \
|
||||
&& apt-get clean \
|
||||
&& apt-get autoremove -y
|
||||
|
||||
# mark /app as safe git directory to avoid pre-commit errors
|
||||
RUN git config --system --add safe.directory /app
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# cache build dependencies
|
||||
|
||||
@@ -12,7 +12,7 @@ services:
|
||||
- SANDBOX_API_HOSTNAME=host.docker.internal
|
||||
- DOCKER_HOST_ADDR=host.docker.internal
|
||||
#
|
||||
- SANDBOX_RUNTIME_CONTAINER_IMAGE=${SANDBOX_RUNTIME_CONTAINER_IMAGE:-ghcr.io/all-hands-ai/runtime:0.59-nikolaik}
|
||||
- SANDBOX_RUNTIME_CONTAINER_IMAGE=${SANDBOX_RUNTIME_CONTAINER_IMAGE:-ghcr.io/openhands/runtime:0.60-nikolaik}
|
||||
- SANDBOX_USER_ID=${SANDBOX_USER_ID:-1234}
|
||||
- WORKSPACE_MOUNT_PATH=${WORKSPACE_BASE:-$PWD/workspace}
|
||||
ports:
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
DOCKER_REGISTRY=ghcr.io
|
||||
DOCKER_ORG=all-hands-ai
|
||||
DOCKER_ORG=openhands
|
||||
DOCKER_BASE_DIR="./containers/runtime"
|
||||
DOCKER_IMAGE=runtime
|
||||
# These variables will be appended by the runtime_build.py script
|
||||
|
||||
@@ -7,7 +7,7 @@ services:
|
||||
image: openhands:latest
|
||||
container_name: openhands-app-${DATE:-}
|
||||
environment:
|
||||
- SANDBOX_RUNTIME_CONTAINER_IMAGE=${SANDBOX_RUNTIME_CONTAINER_IMAGE:-docker.all-hands.dev/all-hands-ai/runtime:0.59-nikolaik}
|
||||
- SANDBOX_RUNTIME_CONTAINER_IMAGE=${SANDBOX_RUNTIME_CONTAINER_IMAGE:-docker.openhands.dev/openhands/runtime:0.60-nikolaik}
|
||||
#- SANDBOX_USER_ID=${SANDBOX_USER_ID:-1234} # enable this only if you want a specific non-root sandbox user but you will have to manually adjust permissions of ~/.openhands for this user
|
||||
- WORKSPACE_MOUNT_PATH=${WORKSPACE_BASE:-$PWD/workspace}
|
||||
ports:
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
ARG OPENHANDS_VERSION=latest
|
||||
ARG BASE="ghcr.io/all-hands-ai/openhands"
|
||||
ARG BASE="ghcr.io/openhands/openhands"
|
||||
FROM ${BASE}:${OPENHANDS_VERSION}
|
||||
|
||||
# Datadog labels
|
||||
|
||||
@@ -2,7 +2,7 @@ BACKEND_HOST ?= "127.0.0.1"
|
||||
BACKEND_PORT = 3000
|
||||
BACKEND_HOST_PORT = "$(BACKEND_HOST):$(BACKEND_PORT)"
|
||||
FRONTEND_PORT = 3001
|
||||
OPENHANDS_PATH ?= "../../OpenHands"
|
||||
OPENHANDS_PATH ?= ".."
|
||||
OPENHANDS := $(OPENHANDS_PATH)
|
||||
OPENHANDS_FRONTEND_PATH = $(OPENHANDS)/frontend/build
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
This directory contains the enterprise server used by [OpenHands Cloud](https://github.com/All-Hands-AI/OpenHands-Cloud/). The official, public version of OpenHands Cloud is available at
|
||||
[app.all-hands.dev](https://app.all-hands.dev).
|
||||
|
||||
You may also want to check out the MIT-licensed [OpenHands](https://github.com/All-Hands-AI/OpenHands)
|
||||
You may also want to check out the MIT-licensed [OpenHands](https://github.com/OpenHands/OpenHands)
|
||||
|
||||
## Extension of OpenHands (OSS)
|
||||
|
||||
@@ -16,7 +16,7 @@ The code in `/enterprise` directory builds on top of open source (OSS) code, ext
|
||||
|
||||
- Enterprise stacks on top of OSS. For example, the middleware in enterprise is stacked right on top of the middlewares in OSS. In `SAAS`, the middleware from BOTH repos will be present and running (which can sometimes cause conflicts)
|
||||
|
||||
- Enterprise overrides the implementation in OSS (only one is present at a time). For example, the server config SaasServerConfig which overrides [`ServerConfig`](https://github.com/All-Hands-AI/OpenHands/blob/main/openhands/server/config/server_config.py#L8) on OSS. This is done through dynamic imports ([see here](https://github.com/All-Hands-AI/OpenHands/blob/main/openhands/server/config/server_config.py#L37-#L45))
|
||||
- Enterprise overrides the implementation in OSS (only one is present at a time). For example, the server config SaasServerConfig which overrides [`ServerConfig`](https://github.com/OpenHands/OpenHands/blob/main/openhands/server/config/server_config.py#L8) on OSS. This is done through dynamic imports ([see here](https://github.com/OpenHands/OpenHands/blob/main/openhands/server/config/server_config.py#L37-#L45))
|
||||
|
||||
Key areas that change on `SAAS` are
|
||||
|
||||
|
||||
@@ -0,0 +1,856 @@
|
||||
# OpenHands Enterprise Usage Telemetry Service
|
||||
|
||||
## Table of Contents
|
||||
|
||||
1. [Introduction](#1-introduction)
|
||||
- 1.1 [Problem Statement](#11-problem-statement)
|
||||
- 1.2 [Proposed Solution](#12-proposed-solution)
|
||||
2. [User Interface](#2-user-interface)
|
||||
- 2.1 [License Warning Banner](#21-license-warning-banner)
|
||||
- 2.2 [Administrator Experience](#22-administrator-experience)
|
||||
3. [Other Context](#3-other-context)
|
||||
- 3.1 [Replicated Platform Integration](#31-replicated-platform-integration)
|
||||
- 3.2 [Administrator Email Detection Strategy](#32-administrator-email-detection-strategy)
|
||||
- 3.3 [Metrics Collection Framework](#33-metrics-collection-framework)
|
||||
4. [Technical Design](#4-technical-design)
|
||||
- 4.1 [Database Schema](#41-database-schema)
|
||||
- 4.1.1 [Telemetry Metrics Table](#411-telemetry-metrics-table)
|
||||
- 4.1.2 [Telemetry Identity Table](#412-telemetry-identity-table)
|
||||
- 4.2 [Metrics Collection Framework](#42-metrics-collection-framework)
|
||||
- 4.2.1 [Base Collector Interface](#421-base-collector-interface)
|
||||
- 4.2.2 [Collector Registry](#422-collector-registry)
|
||||
- 4.2.3 [Example Collector Implementation](#423-example-collector-implementation)
|
||||
- 4.3 [Collection and Upload System](#43-collection-and-upload-system)
|
||||
- 4.3.1 [Metrics Collection Processor](#431-metrics-collection-processor)
|
||||
- 4.3.2 [Replicated Upload Processor](#432-replicated-upload-processor)
|
||||
- 4.4 [License Warning System](#44-license-warning-system)
|
||||
- 4.4.1 [License Status Endpoint](#441-license-status-endpoint)
|
||||
- 4.4.2 [UI Integration](#442-ui-integration)
|
||||
- 4.5 [Cronjob Configuration](#45-cronjob-configuration)
|
||||
- 4.5.1 [Collection Cronjob](#451-collection-cronjob)
|
||||
- 4.5.2 [Upload Cronjob](#452-upload-cronjob)
|
||||
5. [Implementation Plan](#5-implementation-plan)
|
||||
- 5.1 [Database Schema and Models (M1)](#51-database-schema-and-models-m1)
|
||||
- 5.1.1 [OpenHands - Database Migration](#511-openhands---database-migration)
|
||||
- 5.1.2 [OpenHands - Model Tests](#512-openhands---model-tests)
|
||||
- 5.2 [Metrics Collection Framework (M2)](#52-metrics-collection-framework-m2)
|
||||
- 5.2.1 [OpenHands - Core Collection Framework](#521-openhands---core-collection-framework)
|
||||
- 5.2.2 [OpenHands - Example Collectors](#522-openhands---example-collectors)
|
||||
- 5.2.3 [OpenHands - Framework Tests](#523-openhands---framework-tests)
|
||||
- 5.3 [Collection and Upload Processors (M3)](#53-collection-and-upload-processors-m3)
|
||||
- 5.3.1 [OpenHands - Collection Processor](#531-openhands---collection-processor)
|
||||
- 5.3.2 [OpenHands - Upload Processor](#532-openhands---upload-processor)
|
||||
- 5.3.3 [OpenHands - Integration Tests](#533-openhands---integration-tests)
|
||||
- 5.4 [License Warning API (M4)](#54-license-warning-api-m4)
|
||||
- 5.4.1 [OpenHands - License Status API](#541-openhands---license-status-api)
|
||||
- 5.4.2 [OpenHands - API Integration](#542-openhands---api-integration)
|
||||
- 5.5 [UI Warning Banner (M5)](#55-ui-warning-banner-m5)
|
||||
- 5.5.1 [OpenHands - UI Warning Banner](#551-openhands---ui-warning-banner)
|
||||
- 5.5.2 [OpenHands - UI Integration](#552-openhands---ui-integration)
|
||||
- 5.6 [Helm Chart Deployment Configuration (M6)](#56-helm-chart-deployment-configuration-m6)
|
||||
- 5.6.1 [OpenHands-Cloud - Cronjob Manifests](#561-openhands-cloud---cronjob-manifests)
|
||||
- 5.6.2 [OpenHands-Cloud - Configuration Management](#562-openhands-cloud---configuration-management)
|
||||
- 5.7 [Documentation and Enhanced Collectors (M7)](#57-documentation-and-enhanced-collectors-m7)
|
||||
- 5.7.1 [OpenHands - Advanced Collectors](#571-openhands---advanced-collectors)
|
||||
- 5.7.2 [OpenHands - Monitoring and Testing](#572-openhands---monitoring-and-testing)
|
||||
- 5.7.3 [OpenHands - Technical Documentation](#573-openhands---technical-documentation)
|
||||
|
||||
## 1. Introduction
|
||||
|
||||
### 1.1 Problem Statement
|
||||
|
||||
OpenHands Enterprise (OHE) helm charts are publicly available but not open source, creating a visibility gap for the sales team. Unknown users can install and use OHE without the vendor's knowledge, preventing proper customer engagement and sales pipeline management. Without usage telemetry, the vendor cannot identify potential customers, track installation health, or proactively support users who may need assistance.
|
||||
|
||||
### 1.2 Proposed Solution
|
||||
|
||||
We propose implementing a comprehensive telemetry service that leverages the Replicated metrics platform and Python SDK to track OHE installations and usage. The solution provides automatic customer discovery, instance monitoring, and usage metrics collection while maintaining a clear license compliance pathway.
|
||||
|
||||
The system consists of three main components: (1) a pluggable metrics collection framework that allows developers to easily define and register custom metrics collectors, (2) automated cronjobs that periodically collect metrics and upload them to Replicated's vendor portal, and (3) a license compliance warning system that displays UI notifications when telemetry uploads fail, indicating potential license expiration.
|
||||
|
||||
The design ensures that telemetry cannot be easily disabled without breaking core OHE functionality by tying the warning system to environment variables that are essential for OHE operation. This approach balances user transparency with business requirements for customer visibility.
|
||||
|
||||
## 2. User Interface
|
||||
|
||||
### 2.1 License Warning Banner
|
||||
|
||||
When telemetry uploads fail for more than 4 days, users will see a prominent warning banner in the OpenHands Enterprise UI:
|
||||
|
||||
```
|
||||
⚠️ Your OpenHands Enterprise license will expire in 30 days. Please contact support if this issue persists.
|
||||
```
|
||||
|
||||
The banner appears at the top of all pages and cannot be permanently dismissed while the condition persists. Users can temporarily dismiss it, but it will reappear on page refresh until telemetry uploads resume successfully.
|
||||
|
||||
### 2.2 Administrator Experience
|
||||
|
||||
System administrators will not need to configure the telemetry system manually. The service automatically:
|
||||
|
||||
1. **Detects OHE installations** using existing required environment variables (`GITHUB_APP_CLIENT_ID`, `KEYCLOAK_SERVER_URL`, etc.)
|
||||
|
||||
2. **Generates unique customer identifiers** using administrator contact information:
|
||||
- Customer email: Determined by the following priority order:
|
||||
1. `OPENHANDS_ADMIN_EMAIL` environment variable (if set in helm values)
|
||||
2. Email of the first user who accepted Terms of Service (earliest `accepted_tos` timestamp)
|
||||
- Instance ID: Automatically generated by Replicated SDK using machine fingerprinting (IOPlatformUUID on macOS, D-Bus machine ID on Linux, Machine GUID on Windows)
|
||||
- **No Fallback**: If neither email source is available, telemetry collection is skipped until at least one user exists
|
||||
|
||||
3. **Collects and uploads metrics transparently** in the background via weekly collection and daily upload cronjobs
|
||||
|
||||
4. **Displays warnings only when necessary** for license compliance - no notifications appear during normal operation
|
||||
|
||||
## 3. Other Context
|
||||
|
||||
### 3.1 Replicated Platform Integration
|
||||
|
||||
The Replicated platform provides vendor-hosted infrastructure for collecting customer and instance telemetry. The Python SDK handles authentication, state management, and reliable metric delivery. Key concepts:
|
||||
|
||||
- **Customer**: Represents a unique OHE installation, identified by email or installation fingerprint
|
||||
- **Instance**: Represents a specific deployment of OHE for a customer
|
||||
- **Metrics**: Custom key-value data points collected from the installation
|
||||
- **Status**: Instance health indicators (running, degraded, updating, etc.)
|
||||
|
||||
The SDK automatically handles machine fingerprinting, local state caching, and retry logic for failed uploads.
|
||||
|
||||
### 3.2 Administrator Email Detection Strategy
|
||||
|
||||
To identify the appropriate administrator contact for sales outreach, the system uses a three-tier approach that avoids performance penalties on user authentication:
|
||||
|
||||
**Tier 1: Explicit Configuration** - The `OPENHANDS_ADMIN_EMAIL` environment variable allows administrators to explicitly specify the contact email during deployment.
|
||||
|
||||
**Tier 2: First Active User Detection** - If no explicit email is configured, the system identifies the first user who accepted Terms of Service (earliest `accepted_tos` timestamp with a valid email). This represents the first person to actively engage with the system and is very likely the administrator or installer.
|
||||
|
||||
**No Fallback Needed** - If neither email source is available, telemetry collection is skipped entirely. This ensures we only report meaningful usage data when there are actual active users.
|
||||
|
||||
**Performance Optimization**: The admin email determination is performed only during telemetry upload attempts, ensuring zero performance impact on user login flows.
|
||||
|
||||
### 3.3 Metrics Collection Framework
|
||||
|
||||
The proposed collector framework allows developers to define metrics in a single file change:
|
||||
|
||||
```python
|
||||
@register_collector("user_activity")
|
||||
class UserActivityCollector(MetricsCollector):
|
||||
def collect(self) -> Dict[str, Any]:
|
||||
# Query database and return metrics
|
||||
return {"active_users_7d": count, "conversations_created": total}
|
||||
```
|
||||
|
||||
Collectors are automatically discovered and executed by the collection cronjob, making the system extensible without modifying core collection logic.
|
||||
|
||||
## 4. Technical Design
|
||||
|
||||
### 4.1 Database Schema
|
||||
|
||||
#### 4.1.1 Telemetry Metrics Table
|
||||
|
||||
Stores collected metrics with transmission status tracking:
|
||||
|
||||
```sql
|
||||
CREATE TABLE telemetry_metrics (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
collected_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
metrics_data JSONB NOT NULL,
|
||||
uploaded_at TIMESTAMP WITH TIME ZONE NULL,
|
||||
upload_attempts INTEGER DEFAULT 0,
|
||||
last_upload_error TEXT NULL,
|
||||
created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
CREATE INDEX idx_telemetry_metrics_collected_at ON telemetry_metrics(collected_at);
|
||||
CREATE INDEX idx_telemetry_metrics_uploaded_at ON telemetry_metrics(uploaded_at);
|
||||
```
|
||||
|
||||
#### 4.1.2 Telemetry Identity Table
|
||||
|
||||
Stores persistent identity information that must survive container restarts:
|
||||
|
||||
```sql
|
||||
CREATE TABLE telemetry_identity (
|
||||
id INTEGER PRIMARY KEY DEFAULT 1,
|
||||
customer_id VARCHAR(255) NULL,
|
||||
instance_id VARCHAR(255) NULL,
|
||||
created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
|
||||
CONSTRAINT single_identity_row CHECK (id = 1)
|
||||
);
|
||||
```
|
||||
|
||||
**Design Rationale:**
|
||||
- **Separation of Concerns**: Identity data (customer_id, instance_id) is separated from operational data
|
||||
- **Persistent vs Computed**: Only data that cannot be reliably recomputed is persisted
|
||||
- **Upload Tracking**: Upload timestamps are tied directly to the metrics they represent
|
||||
- **Simplified Queries**: System state can be derived from metrics table (e.g., `MAX(uploaded_at)` for last successful upload)
|
||||
|
||||
### 4.2 Metrics Collection Framework
|
||||
|
||||
#### 4.2.1 Base Collector Interface
|
||||
|
||||
```python
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Any, List
|
||||
from dataclasses import dataclass
|
||||
|
||||
@dataclass
|
||||
class MetricResult:
|
||||
key: str
|
||||
value: Any
|
||||
|
||||
class MetricsCollector(ABC):
|
||||
"""Base class for metrics collectors."""
|
||||
|
||||
@abstractmethod
|
||||
def collect(self) -> List[MetricResult]:
|
||||
"""Collect metrics and return results."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def collector_name(self) -> str:
|
||||
"""Unique name for this collector."""
|
||||
pass
|
||||
|
||||
def should_collect(self) -> bool:
|
||||
"""Override to add collection conditions."""
|
||||
return True
|
||||
```
|
||||
|
||||
#### 4.2.2 Collector Registry
|
||||
|
||||
```python
|
||||
from typing import Dict, Type, List
|
||||
import importlib
|
||||
import pkgutil
|
||||
|
||||
class CollectorRegistry:
|
||||
"""Registry for metrics collectors."""
|
||||
|
||||
def __init__(self):
|
||||
self._collectors: Dict[str, Type[MetricsCollector]] = {}
|
||||
|
||||
def register(self, collector_class: Type[MetricsCollector]) -> None:
|
||||
"""Register a collector class."""
|
||||
collector = collector_class()
|
||||
self._collectors[collector.collector_name] = collector_class
|
||||
|
||||
def get_all_collectors(self) -> List[MetricsCollector]:
|
||||
"""Get instances of all registered collectors."""
|
||||
return [cls() for cls in self._collectors.values()]
|
||||
|
||||
def discover_collectors(self, package_path: str) -> None:
|
||||
"""Auto-discover collectors in a package."""
|
||||
# Implementation to scan for @register_collector decorators
|
||||
pass
|
||||
|
||||
# Global registry instance
|
||||
collector_registry = CollectorRegistry()
|
||||
|
||||
def register_collector(name: str):
|
||||
"""Decorator to register a collector."""
|
||||
def decorator(cls: Type[MetricsCollector]) -> Type[MetricsCollector]:
|
||||
collector_registry.register(cls)
|
||||
return cls
|
||||
return decorator
|
||||
```
|
||||
|
||||
#### 4.2.3 Example Collector Implementation
|
||||
|
||||
```python
|
||||
@register_collector("system_metrics")
|
||||
class SystemMetricsCollector(MetricsCollector):
|
||||
"""Collects basic system and usage metrics."""
|
||||
|
||||
@property
|
||||
def collector_name(self) -> str:
|
||||
return "system_metrics"
|
||||
|
||||
def collect(self) -> List[MetricResult]:
|
||||
results = []
|
||||
|
||||
# Collect user count
|
||||
with session_maker() as session:
|
||||
user_count = session.query(UserSettings).count()
|
||||
results.append(MetricResult(
|
||||
key="total_users",
|
||||
value=user_count
|
||||
))
|
||||
|
||||
# Collect conversation count (last 30 days)
|
||||
thirty_days_ago = datetime.now(timezone.utc) - timedelta(days=30)
|
||||
conversation_count = session.query(StoredConversationMetadata)\
|
||||
.filter(StoredConversationMetadata.created_at >= thirty_days_ago)\
|
||||
.count()
|
||||
|
||||
results.append(MetricResult(
|
||||
key="conversations_30d",
|
||||
value=conversation_count
|
||||
))
|
||||
|
||||
return results
|
||||
```
|
||||
|
||||
### 4.3 Collection and Upload System
|
||||
|
||||
#### 4.3.1 Metrics Collection Processor
|
||||
|
||||
```python
|
||||
class TelemetryCollectionProcessor(MaintenanceTaskProcessor):
|
||||
"""Maintenance task processor for collecting metrics."""
|
||||
|
||||
collection_interval_days: int = 7
|
||||
|
||||
async def __call__(self, task: MaintenanceTask) -> dict:
|
||||
"""Collect metrics from all registered collectors."""
|
||||
|
||||
# Check if collection is needed
|
||||
if not self._should_collect():
|
||||
return {"status": "skipped", "reason": "too_recent"}
|
||||
|
||||
# Collect metrics from all registered collectors
|
||||
all_metrics = {}
|
||||
collector_results = {}
|
||||
|
||||
for collector in collector_registry.get_all_collectors():
|
||||
try:
|
||||
if collector.should_collect():
|
||||
results = collector.collect()
|
||||
for result in results:
|
||||
all_metrics[result.key] = result.value
|
||||
collector_results[collector.collector_name] = len(results)
|
||||
except Exception as e:
|
||||
logger.error(f"Collector {collector.collector_name} failed: {e}")
|
||||
collector_results[collector.collector_name] = f"error: {e}"
|
||||
|
||||
# Store metrics in database
|
||||
with session_maker() as session:
|
||||
telemetry_record = TelemetryMetrics(
|
||||
metrics_data=all_metrics,
|
||||
collected_at=datetime.now(timezone.utc)
|
||||
)
|
||||
session.add(telemetry_record)
|
||||
session.commit()
|
||||
|
||||
# Note: No need to track last_collection_at separately
|
||||
# Can be derived from MAX(collected_at) in telemetry_metrics
|
||||
|
||||
return {
|
||||
"status": "completed",
|
||||
"metrics_collected": len(all_metrics),
|
||||
"collectors_run": collector_results
|
||||
}
|
||||
|
||||
def _should_collect(self) -> bool:
|
||||
"""Check if collection is needed based on interval."""
|
||||
with session_maker() as session:
|
||||
# Get last collection time from metrics table
|
||||
last_collected = session.query(func.max(TelemetryMetrics.collected_at)).scalar()
|
||||
if not last_collected:
|
||||
return True
|
||||
|
||||
time_since_last = datetime.now(timezone.utc) - last_collected
|
||||
return time_since_last.days >= self.collection_interval_days
|
||||
```
|
||||
|
||||
#### 4.3.2 Replicated Upload Processor
|
||||
|
||||
```python
|
||||
from replicated import AsyncReplicatedClient, InstanceStatus
|
||||
|
||||
class TelemetryUploadProcessor(MaintenanceTaskProcessor):
|
||||
"""Maintenance task processor for uploading metrics to Replicated."""
|
||||
|
||||
replicated_publishable_key: str
|
||||
replicated_app_slug: str
|
||||
|
||||
async def __call__(self, task: MaintenanceTask) -> dict:
|
||||
"""Upload pending metrics to Replicated."""
|
||||
|
||||
# Get pending metrics
|
||||
with session_maker() as session:
|
||||
pending_metrics = session.query(TelemetryMetrics)\
|
||||
.filter(TelemetryMetrics.uploaded_at.is_(None))\
|
||||
.order_by(TelemetryMetrics.collected_at)\
|
||||
.all()
|
||||
|
||||
if not pending_metrics:
|
||||
return {"status": "no_pending_metrics"}
|
||||
|
||||
# Get admin email - skip if not available
|
||||
admin_email = self._get_admin_email()
|
||||
if not admin_email:
|
||||
logger.info("Skipping telemetry upload - no admin email available")
|
||||
return {
|
||||
"status": "skipped",
|
||||
"reason": "no_admin_email",
|
||||
"total_processed": 0
|
||||
}
|
||||
|
||||
uploaded_count = 0
|
||||
failed_count = 0
|
||||
|
||||
async with AsyncReplicatedClient(
|
||||
publishable_key=self.replicated_publishable_key,
|
||||
app_slug=self.replicated_app_slug
|
||||
) as client:
|
||||
|
||||
# Get or create customer and instance
|
||||
customer = await client.customer.get_or_create(
|
||||
email_address=admin_email
|
||||
)
|
||||
instance = await customer.get_or_create_instance()
|
||||
|
||||
# Store customer/instance IDs for future use
|
||||
await self._update_telemetry_identity(customer.customer_id, instance.instance_id)
|
||||
|
||||
# Upload each metric batch
|
||||
for metric_record in pending_metrics:
|
||||
try:
|
||||
# Send individual metrics
|
||||
for key, value in metric_record.metrics_data.items():
|
||||
await instance.send_metric(key, value)
|
||||
|
||||
# Update instance status
|
||||
await instance.set_status(InstanceStatus.RUNNING)
|
||||
|
||||
# Mark as uploaded
|
||||
with session_maker() as session:
|
||||
record = session.query(TelemetryMetrics)\
|
||||
.filter(TelemetryMetrics.id == metric_record.id)\
|
||||
.first()
|
||||
if record:
|
||||
record.uploaded_at = datetime.now(timezone.utc)
|
||||
session.commit()
|
||||
|
||||
uploaded_count += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to upload metrics {metric_record.id}: {e}")
|
||||
|
||||
# Update error info
|
||||
with session_maker() as session:
|
||||
record = session.query(TelemetryMetrics)\
|
||||
.filter(TelemetryMetrics.id == metric_record.id)\
|
||||
.first()
|
||||
if record:
|
||||
record.upload_attempts += 1
|
||||
record.last_upload_error = str(e)
|
||||
session.commit()
|
||||
|
||||
failed_count += 1
|
||||
|
||||
# Note: No need to track last_successful_upload_at separately
|
||||
# Can be derived from MAX(uploaded_at) in telemetry_metrics
|
||||
|
||||
return {
|
||||
"status": "completed",
|
||||
"uploaded": uploaded_count,
|
||||
"failed": failed_count,
|
||||
"total_processed": len(pending_metrics)
|
||||
}
|
||||
|
||||
def _get_admin_email(self) -> str | None:
|
||||
"""Get administrator email for customer identification."""
|
||||
# 1. Check environment variable first
|
||||
env_admin_email = os.getenv('OPENHANDS_ADMIN_EMAIL')
|
||||
if env_admin_email:
|
||||
logger.info("Using admin email from environment variable")
|
||||
return env_admin_email
|
||||
|
||||
# 2. Use first active user's email (earliest accepted_tos)
|
||||
with session_maker() as session:
|
||||
first_user = session.query(UserSettings)\
|
||||
.filter(UserSettings.email.isnot(None))\
|
||||
.filter(UserSettings.accepted_tos.isnot(None))\
|
||||
.order_by(UserSettings.accepted_tos.asc())\
|
||||
.first()
|
||||
|
||||
if first_user and first_user.email:
|
||||
logger.info(f"Using first active user email: {first_user.email}")
|
||||
return first_user.email
|
||||
|
||||
# No admin email available - skip telemetry
|
||||
logger.info("No admin email available - skipping telemetry collection")
|
||||
return None
|
||||
|
||||
async def _update_telemetry_identity(self, customer_id: str, instance_id: str) -> None:
|
||||
"""Update or create telemetry identity record."""
|
||||
with session_maker() as session:
|
||||
identity = session.query(TelemetryIdentity).first()
|
||||
if not identity:
|
||||
identity = TelemetryIdentity()
|
||||
session.add(identity)
|
||||
|
||||
identity.customer_id = customer_id
|
||||
identity.instance_id = instance_id
|
||||
session.commit()
|
||||
```
|
||||
|
||||
### 4.4 License Warning System
|
||||
|
||||
#### 4.4.1 License Status Endpoint
|
||||
|
||||
```python
|
||||
from fastapi import APIRouter
|
||||
from datetime import datetime, timezone, timedelta
|
||||
|
||||
license_router = APIRouter()
|
||||
|
||||
@license_router.get("/license-status")
|
||||
async def get_license_status():
|
||||
"""Get license warning status for UI display."""
|
||||
|
||||
# Only show warnings for OHE installations
|
||||
if not _is_openhands_enterprise():
|
||||
return {"warn": False, "message": ""}
|
||||
|
||||
with session_maker() as session:
|
||||
# Get last successful upload time from metrics table
|
||||
last_upload = session.query(func.max(TelemetryMetrics.uploaded_at))\
|
||||
.filter(TelemetryMetrics.uploaded_at.isnot(None))\
|
||||
.scalar()
|
||||
|
||||
if not last_upload:
|
||||
# No successful uploads yet - show warning after 4 days
|
||||
return {
|
||||
"warn": True,
|
||||
"message": "OpenHands Enterprise license verification pending. Please ensure network connectivity."
|
||||
}
|
||||
|
||||
# Check if last successful upload was more than 4 days ago
|
||||
days_since_upload = (datetime.now(timezone.utc) - last_upload).days
|
||||
|
||||
if days_since_upload > 4:
|
||||
# Find oldest unsent batch
|
||||
oldest_unsent = session.query(TelemetryMetrics)\
|
||||
.filter(TelemetryMetrics.uploaded_at.is_(None))\
|
||||
.order_by(TelemetryMetrics.collected_at)\
|
||||
.first()
|
||||
|
||||
if oldest_unsent:
|
||||
# Calculate expiration date (oldest unsent + 34 days)
|
||||
expiration_date = oldest_unsent.collected_at + timedelta(days=34)
|
||||
days_until_expiration = (expiration_date - datetime.now(timezone.utc)).days
|
||||
|
||||
if days_until_expiration <= 0:
|
||||
message = "Your OpenHands Enterprise license has expired. Please contact support immediately."
|
||||
else:
|
||||
message = f"Your OpenHands Enterprise license will expire in {days_until_expiration} days. Please contact support if this issue persists."
|
||||
|
||||
return {"warn": True, "message": message}
|
||||
|
||||
return {"warn": False, "message": ""}
|
||||
|
||||
def _is_openhands_enterprise() -> bool:
|
||||
"""Detect if this is an OHE installation."""
|
||||
# Check for required OHE environment variables
|
||||
required_vars = [
|
||||
'GITHUB_APP_CLIENT_ID',
|
||||
'KEYCLOAK_SERVER_URL',
|
||||
'KEYCLOAK_REALM_NAME'
|
||||
]
|
||||
|
||||
return all(os.getenv(var) for var in required_vars)
|
||||
```
|
||||
|
||||
#### 4.4.2 UI Integration
|
||||
|
||||
The frontend will poll the license status endpoint and display warnings using the existing banner component pattern:
|
||||
|
||||
```typescript
|
||||
// New component: LicenseWarningBanner.tsx
|
||||
interface LicenseStatus {
|
||||
warn: boolean;
|
||||
message: string;
|
||||
}
|
||||
|
||||
export function LicenseWarningBanner() {
|
||||
const [licenseStatus, setLicenseStatus] = useState<LicenseStatus>({ warn: false, message: "" });
|
||||
|
||||
useEffect(() => {
|
||||
const checkLicenseStatus = async () => {
|
||||
try {
|
||||
const response = await fetch('/api/license-status');
|
||||
const status = await response.json();
|
||||
setLicenseStatus(status);
|
||||
} catch (error) {
|
||||
console.error('Failed to check license status:', error);
|
||||
}
|
||||
};
|
||||
|
||||
// Check immediately and then every hour
|
||||
checkLicenseStatus();
|
||||
const interval = setInterval(checkLicenseStatus, 60 * 60 * 1000);
|
||||
|
||||
return () => clearInterval(interval);
|
||||
}, []);
|
||||
|
||||
if (!licenseStatus.warn) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="bg-red-600 text-white p-4 rounded flex items-center justify-between">
|
||||
<div className="flex items-center">
|
||||
<FaExclamationTriangle className="mr-3" />
|
||||
<span>{licenseStatus.message}</span>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
```
|
||||
|
||||
### 4.5 Cronjob Configuration
|
||||
|
||||
The cronjob configurations will be deployed via the OpenHands-Cloud helm charts.
|
||||
|
||||
#### 4.5.1 Collection Cronjob
|
||||
|
||||
The collection cronjob runs weekly to gather metrics:
|
||||
|
||||
```yaml
|
||||
# charts/openhands/templates/telemetry-collection-cronjob.yaml
|
||||
apiVersion: batch/v1
|
||||
kind: CronJob
|
||||
metadata:
|
||||
name: {{ include "openhands.fullname" . }}-telemetry-collection
|
||||
labels:
|
||||
{{- include "openhands.labels" . | nindent 4 }}
|
||||
spec:
|
||||
schedule: "0 2 * * 0" # Weekly on Sunday at 2 AM
|
||||
jobTemplate:
|
||||
spec:
|
||||
template:
|
||||
spec:
|
||||
containers:
|
||||
- name: telemetry-collector
|
||||
image: "{{ .Values.image.repository }}:{{ .Values.image.tag }}"
|
||||
env:
|
||||
{{- include "openhands.env" . | nindent 12 }}
|
||||
command:
|
||||
- python
|
||||
- -c
|
||||
- |
|
||||
from enterprise.storage.maintenance_task import MaintenanceTask, MaintenanceTaskStatus
|
||||
from enterprise.storage.database import session_maker
|
||||
from enterprise.server.telemetry.collection_processor import TelemetryCollectionProcessor
|
||||
|
||||
# Create collection task
|
||||
processor = TelemetryCollectionProcessor()
|
||||
task = MaintenanceTask()
|
||||
task.set_processor(processor)
|
||||
task.status = MaintenanceTaskStatus.PENDING
|
||||
|
||||
with session_maker() as session:
|
||||
session.add(task)
|
||||
session.commit()
|
||||
restartPolicy: OnFailure
|
||||
```
|
||||
|
||||
#### 4.5.2 Upload Cronjob
|
||||
|
||||
The upload cronjob runs daily to send metrics to Replicated:
|
||||
|
||||
```yaml
|
||||
# charts/openhands/templates/telemetry-upload-cronjob.yaml
|
||||
apiVersion: batch/v1
|
||||
kind: CronJob
|
||||
metadata:
|
||||
name: {{ include "openhands.fullname" . }}-telemetry-upload
|
||||
labels:
|
||||
{{- include "openhands.labels" . | nindent 4 }}
|
||||
spec:
|
||||
schedule: "0 3 * * *" # Daily at 3 AM
|
||||
jobTemplate:
|
||||
spec:
|
||||
template:
|
||||
spec:
|
||||
containers:
|
||||
- name: telemetry-uploader
|
||||
image: "{{ .Values.image.repository }}:{{ .Values.image.tag }}"
|
||||
env:
|
||||
{{- include "openhands.env" . | nindent 12 }}
|
||||
- name: REPLICATED_PUBLISHABLE_KEY
|
||||
valueFrom:
|
||||
secretKeyRef:
|
||||
name: {{ include "openhands.fullname" . }}-replicated-config
|
||||
key: publishable-key
|
||||
- name: REPLICATED_APP_SLUG
|
||||
value: {{ .Values.telemetry.replicatedAppSlug | default "openhands-enterprise" | quote }}
|
||||
command:
|
||||
- python
|
||||
- -c
|
||||
- |
|
||||
from enterprise.storage.maintenance_task import MaintenanceTask, MaintenanceTaskStatus
|
||||
from enterprise.storage.database import session_maker
|
||||
from enterprise.server.telemetry.upload_processor import TelemetryUploadProcessor
|
||||
import os
|
||||
|
||||
# Create upload task
|
||||
processor = TelemetryUploadProcessor(
|
||||
replicated_publishable_key=os.getenv('REPLICATED_PUBLISHABLE_KEY'),
|
||||
replicated_app_slug=os.getenv('REPLICATED_APP_SLUG', 'openhands-enterprise')
|
||||
)
|
||||
task = MaintenanceTask()
|
||||
task.set_processor(processor)
|
||||
task.status = MaintenanceTaskStatus.PENDING
|
||||
|
||||
with session_maker() as session:
|
||||
session.add(task)
|
||||
session.commit()
|
||||
restartPolicy: OnFailure
|
||||
```
|
||||
|
||||
## 5. Implementation Plan
|
||||
|
||||
All implementation must pass existing lints and tests. New functionality requires comprehensive unit tests with >90% coverage. Integration tests should verify end-to-end telemetry flow including collection, storage, upload, and warning display.
|
||||
|
||||
### 5.1 Database Schema and Models (M1)
|
||||
|
||||
**Repository**: OpenHands
|
||||
Establish the foundational database schema and SQLAlchemy models for telemetry data storage.
|
||||
|
||||
#### 5.1.1 OpenHands - Database Migration
|
||||
|
||||
- [ ] `enterprise/migrations/versions/077_create_telemetry_tables.py`
|
||||
- [ ] `enterprise/storage/telemetry_metrics.py`
|
||||
- [ ] `enterprise/storage/telemetry_config.py`
|
||||
|
||||
#### 5.1.2 OpenHands - Model Tests
|
||||
|
||||
- [ ] `enterprise/tests/unit/storage/test_telemetry_metrics.py`
|
||||
- [ ] `enterprise/tests/unit/storage/test_telemetry_config.py`
|
||||
|
||||
**Demo**: Database tables created and models can store/retrieve telemetry data.
|
||||
|
||||
### 5.2 Metrics Collection Framework (M2)
|
||||
|
||||
**Repository**: OpenHands
|
||||
Implement the pluggable metrics collection system with registry and base classes.
|
||||
|
||||
#### 5.2.1 OpenHands - Core Collection Framework
|
||||
|
||||
- [ ] `enterprise/server/telemetry/__init__.py`
|
||||
- [ ] `enterprise/server/telemetry/collector_base.py`
|
||||
- [ ] `enterprise/server/telemetry/collector_registry.py`
|
||||
- [ ] `enterprise/server/telemetry/decorators.py`
|
||||
|
||||
#### 5.2.2 OpenHands - Example Collectors
|
||||
|
||||
- [ ] `enterprise/server/telemetry/collectors/__init__.py`
|
||||
- [ ] `enterprise/server/telemetry/collectors/system_metrics.py`
|
||||
- [ ] `enterprise/server/telemetry/collectors/user_activity.py`
|
||||
|
||||
#### 5.2.3 OpenHands - Framework Tests
|
||||
|
||||
- [ ] `enterprise/tests/unit/telemetry/test_collector_base.py`
|
||||
- [ ] `enterprise/tests/unit/telemetry/test_collector_registry.py`
|
||||
- [ ] `enterprise/tests/unit/telemetry/test_system_metrics.py`
|
||||
|
||||
**Demo**: Developers can create new collectors with a single file change using the @register_collector decorator.
|
||||
|
||||
### 5.3 Collection and Upload Processors (M3)
|
||||
|
||||
**Repository**: OpenHands
|
||||
Implement maintenance task processors for collecting metrics and uploading to Replicated.
|
||||
|
||||
#### 5.3.1 OpenHands - Collection Processor
|
||||
|
||||
- [ ] `enterprise/server/telemetry/collection_processor.py`
|
||||
- [ ] `enterprise/tests/unit/telemetry/test_collection_processor.py`
|
||||
|
||||
#### 5.3.2 OpenHands - Upload Processor
|
||||
|
||||
- [ ] `enterprise/server/telemetry/upload_processor.py`
|
||||
- [ ] `enterprise/tests/unit/telemetry/test_upload_processor.py`
|
||||
|
||||
#### 5.3.3 OpenHands - Integration Tests
|
||||
|
||||
- [ ] `enterprise/tests/integration/test_telemetry_flow.py`
|
||||
|
||||
**Demo**: Metrics are automatically collected weekly and uploaded daily to Replicated vendor portal.
|
||||
|
||||
### 5.4 License Warning API (M4)
|
||||
|
||||
**Repository**: OpenHands
|
||||
Implement the license status endpoint for the warning system.
|
||||
|
||||
#### 5.4.1 OpenHands - License Status API
|
||||
|
||||
- [ ] `enterprise/server/routes/license.py`
|
||||
- [ ] `enterprise/tests/unit/routes/test_license.py`
|
||||
|
||||
#### 5.4.2 OpenHands - API Integration
|
||||
|
||||
- [ ] Update `enterprise/saas_server.py` to include license router
|
||||
|
||||
**Demo**: License status API returns warning status based on telemetry upload success.
|
||||
|
||||
### 5.5 UI Warning Banner (M5)
|
||||
|
||||
**Repository**: OpenHands
|
||||
Implement the frontend warning banner component and integration.
|
||||
|
||||
#### 5.5.1 OpenHands - UI Warning Banner
|
||||
|
||||
- [ ] `frontend/src/components/features/license/license-warning-banner.tsx`
|
||||
- [ ] `frontend/src/components/features/license/license-warning-banner.test.tsx`
|
||||
|
||||
#### 5.5.2 OpenHands - UI Integration
|
||||
|
||||
- [ ] Update main UI layout to include license warning banner
|
||||
- [ ] Add license status polling service
|
||||
|
||||
**Demo**: License warnings appear in UI when telemetry uploads fail for >4 days, with accurate expiration countdown.
|
||||
|
||||
### 5.6 Helm Chart Deployment Configuration (M6)
|
||||
|
||||
**Repository**: OpenHands-Cloud
|
||||
Create Kubernetes cronjob configurations and deployment scripts.
|
||||
|
||||
#### 5.6.1 OpenHands-Cloud - Cronjob Manifests
|
||||
|
||||
- [ ] `charts/openhands/templates/telemetry-collection-cronjob.yaml`
|
||||
- [ ] `charts/openhands/templates/telemetry-upload-cronjob.yaml`
|
||||
|
||||
#### 5.6.2 OpenHands-Cloud - Configuration Management
|
||||
|
||||
- [ ] `charts/openhands/templates/replicated-secret.yaml`
|
||||
- [ ] Update `charts/openhands/values.yaml` with telemetry configuration options:
|
||||
```yaml
|
||||
# Add to values.yaml
|
||||
telemetry:
|
||||
enabled: true
|
||||
replicatedAppSlug: "openhands-enterprise"
|
||||
adminEmail: "" # Optional: admin email for customer identification
|
||||
|
||||
# Add to deployment environment variables
|
||||
env:
|
||||
OPENHANDS_ADMIN_EMAIL: "{{ .Values.telemetry.adminEmail }}"
|
||||
```
|
||||
|
||||
**Demo**: Complete telemetry system deployed via helm chart with configurable collection intervals and Replicated integration.
|
||||
|
||||
### 5.7 Documentation and Enhanced Collectors (M7)
|
||||
|
||||
**Repository**: OpenHands
|
||||
Add comprehensive metrics collectors, monitoring capabilities, and documentation.
|
||||
|
||||
#### 5.7.1 OpenHands - Advanced Collectors
|
||||
|
||||
- [ ] `enterprise/server/telemetry/collectors/conversation_metrics.py`
|
||||
- [ ] `enterprise/server/telemetry/collectors/integration_usage.py`
|
||||
- [ ] `enterprise/server/telemetry/collectors/performance_metrics.py`
|
||||
|
||||
#### 5.7.2 OpenHands - Monitoring and Testing
|
||||
|
||||
- [ ] `enterprise/server/telemetry/monitoring.py`
|
||||
- [ ] `enterprise/tests/e2e/test_telemetry_system.py`
|
||||
- [ ] Performance tests for large-scale metric collection
|
||||
|
||||
#### 5.7.3 OpenHands - Technical Documentation
|
||||
|
||||
- [ ] `enterprise/server/telemetry/README.md`
|
||||
- [ ] Update deployment documentation with telemetry configuration instructions
|
||||
- [ ] Add troubleshooting guide for telemetry issues
|
||||
|
||||
**Demo**: Rich telemetry data flowing to vendor portal with comprehensive monitoring, alerting for system health, and complete documentation.
|
||||
274
enterprise/enterprise_local/README.md
Normal file
274
enterprise/enterprise_local/README.md
Normal file
@@ -0,0 +1,274 @@
|
||||
# Instructions for developing SAAS locally
|
||||
|
||||
You have a few options here, which are expanded on below:
|
||||
|
||||
- A simple local development setup, with live reloading for both OSS and this repo
|
||||
- A more complex setup that includes Redis
|
||||
- An even more complex setup that includes GitHub events
|
||||
|
||||
## Prerequisites
|
||||
|
||||
Before starting, make sure you have the following tools installed:
|
||||
|
||||
### Required for all options:
|
||||
|
||||
- [gcloud CLI](https://cloud.google.com/sdk/docs/install) - For authentication and secrets management
|
||||
- [sops](https://github.com/mozilla/sops) - For secrets decryption
|
||||
- macOS: `brew install sops`
|
||||
- Linux: `sudo apt-get install sops` or download from GitHub releases
|
||||
- Windows: Install via Chocolatey `choco install sops` or download from GitHub releases
|
||||
|
||||
### Additional requirements for enabling GitHub webhook events
|
||||
|
||||
- make
|
||||
- Python development tools (build-essential, python3-dev)
|
||||
- [ngrok](https://ngrok.com/download) - For creating tunnels to localhost
|
||||
|
||||
## Option 1: Simple local development
|
||||
|
||||
This option will allow you to modify the both the OSS code and the code in this repo,
|
||||
and see the changes in real-time.
|
||||
|
||||
This option works best for most scenarios. The only thing it's missing is
|
||||
the GitHub events webhook, which is not necessary for most development.
|
||||
|
||||
### 1. OpenHands location
|
||||
|
||||
The open source OpenHands repo should be cloned as a sibling directory,
|
||||
in `../OpenHands`. This is hard-coded in the pyproject.toml (edit if necessary)
|
||||
|
||||
If you're doing this the first time, you may need to run
|
||||
|
||||
```
|
||||
poetry update openhands-ai
|
||||
```
|
||||
|
||||
### 2. Set up env
|
||||
|
||||
First run this to retrieve Github App secrets
|
||||
|
||||
```
|
||||
gcloud auth application-default login
|
||||
gcloud config set project global-432717
|
||||
local/decrypt_env.sh
|
||||
```
|
||||
|
||||
Now run this to generate a `.env` file, which will used to run SAAS locally
|
||||
|
||||
```
|
||||
python -m pip install PyYAML
|
||||
export LITE_LLM_API_KEY=<your LLM API key>
|
||||
python enterprise_local/convert_to_env.py
|
||||
```
|
||||
|
||||
You'll also need to set up the runtime image, so that the dev server doesn't try to rebuild it.
|
||||
|
||||
```
|
||||
export SANDBOX_RUNTIME_CONTAINER_IMAGE=ghcr.io/openhands/runtime:main-nikolaik
|
||||
docker pull $SANDBOX_RUNTIME_CONTAINER_IMAGE
|
||||
```
|
||||
|
||||
By default the application will log in json, you can override.
|
||||
|
||||
```
|
||||
export LOG_PLAIN_TEXT=1
|
||||
```
|
||||
|
||||
### 3. Start the OpenHands frontend
|
||||
|
||||
Start the frontend like you normally would in the open source OpenHands repo.
|
||||
|
||||
### 4. Start the SaaS backend
|
||||
|
||||
```
|
||||
make build
|
||||
|
||||
make start-backend
|
||||
```
|
||||
|
||||
You should have a server running on `localhost:3000`, similar to the open source backend.
|
||||
Oauth should work properly.
|
||||
|
||||
## Option 2: With Redis
|
||||
|
||||
Follow all the steps above, then setup redis:
|
||||
|
||||
```bash
|
||||
docker run -p 6379:6379 --name openhands-redis -d redis
|
||||
export REDIS_HOST=host.docker.internal # you may want this to be localhost
|
||||
export REDIS_PORT=6379
|
||||
```
|
||||
|
||||
## Option 3: Work with GitHub events
|
||||
|
||||
### 1. Setup env file
|
||||
|
||||
(see above)
|
||||
|
||||
### 2. Build OSS Openhands
|
||||
|
||||
Develop on [Openhands](https://github.com/All-Hands-AI/OpenHands) locally. When ready, run the following inside Openhands repo (not the Deploy repo)
|
||||
|
||||
```
|
||||
docker build -f containers/app/Dockerfile -t openhands .
|
||||
```
|
||||
|
||||
### 3. Build SAAS Openhands
|
||||
|
||||
Build the SAAS image locally inside Deploy repo. Note that `openhands` is the name of the image built in Step 2
|
||||
|
||||
```
|
||||
docker build -t openhands-saas ./app/ --build-arg BASE="openhands"
|
||||
```
|
||||
|
||||
### 4. Create a tunnel
|
||||
|
||||
Run in a separate terminal
|
||||
|
||||
```
|
||||
ngrok http 3000
|
||||
```
|
||||
|
||||
There will be a line
|
||||
|
||||
```
|
||||
Forwarding https://bc71-2603-7000-5000-1575-e4a6-697b-589e-5801.ngrok-free.app
|
||||
```
|
||||
|
||||
Remember this URL as it will be used in Step 5 and 6
|
||||
|
||||
### 5. Setup Staging Github App callback/webhook urls
|
||||
|
||||
Using the URL found in Step 4, add another callback URL (`https://bc71-2603-7000-5000-1575-e4a6-697b-589e-5801.ngrok-free.app/oauth/github/callback`)
|
||||
|
||||
### 6. Run
|
||||
|
||||
This is the last step! Run SAAS openhands locally using
|
||||
|
||||
```
|
||||
docker run --env-file ./app/.env -p 3000:3000 openhands-saas
|
||||
```
|
||||
|
||||
Note `--env-file` is what injects the `.env` file created in Step 1
|
||||
|
||||
Visit the tunnel domain found in Step 4 to run the app (`https://bc71-2603-7000-5000-1575-e4a6-697b-589e-5801.ngrok-free.app`)
|
||||
|
||||
### Local Debugging with VSCode
|
||||
|
||||
Local Development necessitates running a version of OpenHands that is as similar as possible to the version running in the SAAS Environment. Before running these steps, it is assumed you have a local development version of the OSS OpenHands project running.
|
||||
|
||||
#### Redis
|
||||
|
||||
A Local redis instance is required for clustered communication between server nodes. The standard docker instance will suffice.
|
||||
`docker run -it -p 6379:6379 --name my-redis -d redis`
|
||||
|
||||
#### Postgres
|
||||
|
||||
A Local postgres instance is required. I used the official docker image:
|
||||
`docker run -p 5432:5432 --name my-postgres -e POSTGRES_USER=postgres -e POSTGRES_PASSWORD=postgres -e POSTGRES_DB=openhands -d postgres`
|
||||
Run the alembic migrations:
|
||||
`poetry run alembic upgrade head `
|
||||
|
||||
#### VSCode launch.json
|
||||
|
||||
The VSCode launch.json below sets up 2 servers to test clustering, running independently on localhost:3030 and localhost:3031. Running only the server on 3030 is usually sufficient unless tests of the clustered functionality are required. Secrets may be harvested directly from staging by connecting...
|
||||
`kubectl exec --stdin --tty <POD_NAME> -n <NAMESPACE> -- /bin/bash`
|
||||
And then invoking `printenv`. NOTE: _DO NOT DO THIS WITH PROD!!!_ (Hopefully by the time you read this, nobody will have access.)
|
||||
|
||||
```
|
||||
{
|
||||
"configurations": [
|
||||
{
|
||||
"name": "Python Debugger: Python File",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "${file}"
|
||||
},
|
||||
{
|
||||
"name": "OpenHands Deploy",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "uvicorn",
|
||||
"args": [
|
||||
"saas_server:app",
|
||||
"--reload",
|
||||
"--host",
|
||||
"0.0.0.0",
|
||||
"--port",
|
||||
"3030"
|
||||
],
|
||||
"env": {
|
||||
"DEBUG": "1",
|
||||
"FILE_STORE": "local",
|
||||
"REDIS_HOST": "localhost:6379",
|
||||
"OPENHANDS": "<YOUR LOCAL OSS OPENHANDS DIR>",
|
||||
"FRONTEND_DIRECTORY": "<YOUR LOCAL OSS OPENHANDS DIR>/frontend/build",
|
||||
"SANDBOX_RUNTIME_CONTAINER_IMAGE": "ghcr.io/openhands/runtime:main-nikolaik",
|
||||
"FILE_STORE_PATH": "<YOUR HOME DIRECTORY>>/.openhands-state",
|
||||
"OPENHANDS_CONFIG_CLS": "server.config.SaaSServerConfig",
|
||||
"GITHUB_APP_ID": "1062351",
|
||||
"GITHUB_APP_PRIVATE_KEY": "<GITHUB PRIVATE KEY>",
|
||||
"GITHUB_APP_CLIENT_ID": "Iv23lis7eUWDQHIq8US0",
|
||||
"GITHUB_APP_CLIENT_SECRET": "<GITHUB CLIENT SECRET>",
|
||||
"POSTHOG_CLIENT_KEY": "<POSTHOG CLIENT KEY>",
|
||||
"LITE_LLM_API_URL": "https://llm-proxy.staging.all-hands.dev",
|
||||
"LITE_LLM_TEAM_ID": "62ea39c4-8886-44f3-b7ce-07ed4fe42d2c",
|
||||
"LITE_LLM_API_KEY": "<LITE LLM API KEY>"
|
||||
},
|
||||
"justMyCode": false,
|
||||
"cwd": "${workspaceFolder}/app"
|
||||
},
|
||||
{
|
||||
"name": "OpenHands Deploy 2",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "uvicorn",
|
||||
"args": [
|
||||
"saas_server:app",
|
||||
"--reload",
|
||||
"--host",
|
||||
"0.0.0.0",
|
||||
"--port",
|
||||
"3031"
|
||||
],
|
||||
"env": {
|
||||
"DEBUG": "1",
|
||||
"FILE_STORE": "local",
|
||||
"REDIS_HOST": "localhost:6379",
|
||||
"OPENHANDS": "<YOUR LOCAL OSS OPENHANDS DIR>",
|
||||
"FRONTEND_DIRECTORY": "<YOUR LOCAL OSS OPENHANDS DIR>/frontend/build",
|
||||
"SANDBOX_RUNTIME_CONTAINER_IMAGE": "ghcr.io/openhands/runtime:main-nikolaik",
|
||||
"FILE_STORE_PATH": "<YOUR HOME DIRECTORY>>/.openhands-state",
|
||||
"OPENHANDS_CONFIG_CLS": "server.config.SaaSServerConfig",
|
||||
"GITHUB_APP_ID": "1062351",
|
||||
"GITHUB_APP_PRIVATE_KEY": "<GITHUB PRIVATE KEY>",
|
||||
"GITHUB_APP_CLIENT_ID": "Iv23lis7eUWDQHIq8US0",
|
||||
"GITHUB_APP_CLIENT_SECRET": "<GITHUB CLIENT SECRET>",
|
||||
"POSTHOG_CLIENT_KEY": "<POSTHOG CLIENT KEY>",
|
||||
"LITE_LLM_API_URL": "https://llm-proxy.staging.all-hands.dev",
|
||||
"LITE_LLM_TEAM_ID": "62ea39c4-8886-44f3-b7ce-07ed4fe42d2c",
|
||||
"LITE_LLM_API_KEY": "<LITE LLM API KEY>"
|
||||
},
|
||||
"justMyCode": false,
|
||||
"cwd": "${workspaceFolder}/app"
|
||||
},
|
||||
{
|
||||
"name": "Unit Tests",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "pytest",
|
||||
"args": [
|
||||
"./tests/unit",
|
||||
//"./tests/unit/test_clustered_conversation_manager.py",
|
||||
"--durations=0"
|
||||
],
|
||||
"env": {
|
||||
"DEBUG": "1"
|
||||
},
|
||||
"justMyCode": false,
|
||||
"cwd": "${workspaceFolder}/app"
|
||||
},
|
||||
// set working directory...
|
||||
]
|
||||
}
|
||||
```
|
||||
127
enterprise/enterprise_local/convert_to_env.py
Normal file
127
enterprise/enterprise_local/convert_to_env.py
Normal file
@@ -0,0 +1,127 @@
|
||||
import base64
|
||||
import os
|
||||
import sys
|
||||
|
||||
import yaml
|
||||
|
||||
|
||||
def convert_yaml_to_env(yaml_file, target_parameters, output_env_file, prefix):
|
||||
"""Converts a YAML file into .env file format for specified target parameters under 'stringData' and 'data'.
|
||||
|
||||
:param yaml_file: Path to the YAML file.
|
||||
:param target_parameters: List of keys to extract from the YAML file.
|
||||
:param output_env_file: Path to the output .env file.
|
||||
:param prefix: Prefix for environment variables.
|
||||
"""
|
||||
try:
|
||||
# Load the YAML file
|
||||
with open(yaml_file, 'r') as file:
|
||||
yaml_data = yaml.safe_load(file)
|
||||
|
||||
# Extract sections
|
||||
string_data = yaml_data.get('stringData', None)
|
||||
data = yaml_data.get('data', None)
|
||||
|
||||
if string_data:
|
||||
env_source = string_data
|
||||
process_base64 = False
|
||||
elif data:
|
||||
env_source = data
|
||||
process_base64 = True
|
||||
else:
|
||||
print(
|
||||
"Error: Neither 'stringData' nor 'data' section found in the YAML file."
|
||||
)
|
||||
return
|
||||
|
||||
env_lines = []
|
||||
|
||||
for param in target_parameters:
|
||||
if param in env_source:
|
||||
value = env_source[param]
|
||||
if process_base64:
|
||||
try:
|
||||
decoded_value = base64.b64decode(value).decode('utf-8')
|
||||
formatted_value = (
|
||||
decoded_value.replace('\n', '\\n')
|
||||
if '\n' in decoded_value
|
||||
else decoded_value
|
||||
)
|
||||
except Exception as decode_error:
|
||||
print(f"Error decoding base64 for '{param}': {decode_error}")
|
||||
continue
|
||||
else:
|
||||
formatted_value = (
|
||||
value.replace('\n', '\\n')
|
||||
if isinstance(value, str) and '\n' in value
|
||||
else value
|
||||
)
|
||||
|
||||
new_key = prefix + param.upper().replace('-', '_')
|
||||
env_lines.append(f'{new_key}={formatted_value}')
|
||||
else:
|
||||
print(
|
||||
f"Warning: Parameter '{param}' not found in the selected section."
|
||||
)
|
||||
|
||||
# Write to the .env file
|
||||
with open(output_env_file, 'a') as env_file:
|
||||
env_file.write('\n'.join(env_lines) + '\n')
|
||||
|
||||
except Exception as e:
|
||||
print(f'Error: {e}')
|
||||
|
||||
|
||||
lite_llm_api_key = os.getenv('LITE_LLM_API_KEY')
|
||||
if not lite_llm_api_key:
|
||||
print('Set the LITE_LLM_API_KEY environment variable to your API key')
|
||||
sys.exit(1)
|
||||
|
||||
yaml_file = 'github_decrypted.yaml'
|
||||
target_parameters = ['client-id', 'client-secret', 'webhook-secret', 'private-key']
|
||||
output_env_file = './enterprise/.env'
|
||||
|
||||
if os.path.exists(output_env_file):
|
||||
os.remove(output_env_file)
|
||||
convert_yaml_to_env(yaml_file, target_parameters, output_env_file, 'GITHUB_APP_')
|
||||
os.remove(yaml_file)
|
||||
|
||||
yaml_file = 'keycloak_realm_decrypted.yaml'
|
||||
target_parameters = ['client-id', 'client-secret', 'provider-name', 'realm-name']
|
||||
convert_yaml_to_env(yaml_file, target_parameters, output_env_file, 'KEYCLOAK_')
|
||||
os.remove(yaml_file)
|
||||
|
||||
yaml_file = 'keycloak_admin_decrypted.yaml'
|
||||
target_parameters = ['admin-password']
|
||||
convert_yaml_to_env(yaml_file, target_parameters, output_env_file, 'KEYCLOAK_')
|
||||
os.remove(yaml_file)
|
||||
|
||||
lines = []
|
||||
lines.append('KEYCLOAK_SERVER_URL=https://auth.staging.all-hands.dev/')
|
||||
lines.append('KEYCLOAK_SERVER_URL_EXT=https://auth.staging.all-hands.dev/')
|
||||
lines.append('OPENHANDS_CONFIG_CLS=server.config.SaaSServerConfig')
|
||||
lines.append(
|
||||
'OPENHANDS_GITHUB_SERVICE_CLS=integrations.github.github_service.SaaSGitHubService'
|
||||
)
|
||||
lines.append(
|
||||
'OPENHANDS_GITLAB_SERVICE_CLS=integrations.gitlab.gitlab_service.SaaSGitLabService'
|
||||
)
|
||||
lines.append(
|
||||
'OPENHANDS_BITBUCKET_SERVICE_CLS=integrations.bitbucket.bitbucket_service.SaaSBitBucketService'
|
||||
)
|
||||
lines.append(
|
||||
'OPENHANDS_CONVERSATION_VALIDATOR_CLS=storage.saas_conversation_validator.SaasConversationValidator'
|
||||
)
|
||||
lines.append('POSTHOG_CLIENT_KEY=test')
|
||||
lines.append('ENABLE_PROACTIVE_CONVERSATION_STARTERS=true')
|
||||
lines.append('MAX_CONCURRENT_CONVERSATIONS=10')
|
||||
lines.append('LITE_LLM_API_URL=https://llm-proxy.eval.all-hands.dev')
|
||||
lines.append('LITELLM_DEFAULT_MODEL=litellm_proxy/claude-sonnet-4-20250514')
|
||||
lines.append(f'LITE_LLM_API_KEY={lite_llm_api_key}')
|
||||
lines.append('LOCAL_DEPLOYMENT=true')
|
||||
lines.append('DB_HOST=localhost')
|
||||
|
||||
with open(output_env_file, 'a') as env_file:
|
||||
env_file.write('\n'.join(lines))
|
||||
|
||||
print(f'.env file created at: {output_env_file}')
|
||||
27
enterprise/enterprise_local/decrypt_env.sh
Normal file
27
enterprise/enterprise_local/decrypt_env.sh
Normal file
@@ -0,0 +1,27 @@
|
||||
#!/bin/bash
|
||||
set -euo pipefail
|
||||
|
||||
# Check if DEPLOY_DIR argument was provided
|
||||
if [ $# -lt 1 ]; then
|
||||
echo "Usage: $0 <DEPLOY_DIR>"
|
||||
echo "Example: $0 /path/to/deploy"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Normalize path (remove trailing slash)
|
||||
DEPLOY_DIR="${DEPLOY_DIR%/}"
|
||||
|
||||
# Function to decrypt and rename
|
||||
decrypt_and_move() {
|
||||
local secret_path="$1"
|
||||
local output_name="$2"
|
||||
|
||||
${DEPLOY_DIR}/scripts/decrypt.sh "${DEPLOY_DIR}/${secret_path}"
|
||||
mv decrypted.yaml "${output_name}"
|
||||
echo "Moved decrypted.yaml to ${output_name}"
|
||||
}
|
||||
|
||||
# Decrypt each secret file
|
||||
decrypt_and_move "openhands/envs/feature/secrets/github-app.yaml" "github_decrypted.yaml"
|
||||
decrypt_and_move "openhands/envs/staging/secrets/keycloak-realm.yaml" "keycloak_realm_decrypted.yaml"
|
||||
decrypt_and_move "openhands/envs/staging/secrets/keycloak-admin.yaml" "keycloak_admin_decrypted.yaml"
|
||||
@@ -1,18 +1,47 @@
|
||||
from uuid import UUID
|
||||
|
||||
from experiments.constants import (
|
||||
ENABLE_EXPERIMENT_MANAGER,
|
||||
EXPERIMENT_SYSTEM_PROMPT_EXPERIMENT,
|
||||
)
|
||||
from experiments.experiment_versions import (
|
||||
handle_condenser_max_step_experiment,
|
||||
handle_system_prompt_experiment,
|
||||
)
|
||||
from experiments.experiment_versions._004_condenser_max_step_experiment import (
|
||||
handle_condenser_max_step_experiment__v1,
|
||||
)
|
||||
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.experiments.experiment_manager import ExperimentManager
|
||||
from openhands.sdk import Agent
|
||||
from openhands.server.session.conversation_init_data import ConversationInitData
|
||||
|
||||
|
||||
class SaaSExperimentManager(ExperimentManager):
|
||||
@staticmethod
|
||||
def run_agent_variant_tests__v1(
|
||||
user_id: str | None, conversation_id: UUID, agent: Agent
|
||||
) -> Agent:
|
||||
if not ENABLE_EXPERIMENT_MANAGER:
|
||||
logger.info(
|
||||
'experiment_manager:run_conversation_variant_test:skipped',
|
||||
extra={'reason': 'experiment_manager_disabled'},
|
||||
)
|
||||
return agent
|
||||
|
||||
agent = handle_condenser_max_step_experiment__v1(
|
||||
user_id, conversation_id, agent
|
||||
)
|
||||
|
||||
if EXPERIMENT_SYSTEM_PROMPT_EXPERIMENT:
|
||||
agent = agent.model_copy(
|
||||
update={'system_prompt_filename': 'system_prompt_long_horizon.j2'}
|
||||
)
|
||||
|
||||
return agent
|
||||
|
||||
@staticmethod
|
||||
def run_conversation_variant_test(
|
||||
user_id, conversation_id, conversation_settings
|
||||
|
||||
@@ -5,12 +5,18 @@ This module contains the handler for the condenser max step experiment that test
|
||||
different max_size values for the condenser configuration.
|
||||
"""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
import posthog
|
||||
from experiments.constants import EXPERIMENT_CONDENSER_MAX_STEP
|
||||
from server.constants import IS_FEATURE_ENV
|
||||
from storage.experiment_assignment_store import ExperimentAssignmentStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.sdk import Agent
|
||||
from openhands.sdk.context.condenser import (
|
||||
LLMSummarizingCondenser,
|
||||
)
|
||||
from openhands.server.session.conversation_init_data import ConversationInitData
|
||||
|
||||
|
||||
@@ -190,3 +196,37 @@ def handle_condenser_max_step_experiment(
|
||||
return conversation_settings
|
||||
|
||||
return conversation_settings
|
||||
|
||||
|
||||
def handle_condenser_max_step_experiment__v1(
|
||||
user_id: str | None,
|
||||
conversation_id: UUID,
|
||||
agent: Agent,
|
||||
) -> Agent:
|
||||
enabled_variant = _get_condenser_max_step_variant(user_id, str(conversation_id))
|
||||
|
||||
if enabled_variant is None:
|
||||
return agent
|
||||
|
||||
if enabled_variant == 'control':
|
||||
condenser_max_size = 120
|
||||
elif enabled_variant == 'treatment':
|
||||
condenser_max_size = 80
|
||||
else:
|
||||
logger.error(
|
||||
'condenser_max_step_experiment:unknown_variant',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'convo_id': conversation_id,
|
||||
'variant': enabled_variant,
|
||||
'reason': 'unknown variant; returning original conversation settings',
|
||||
},
|
||||
)
|
||||
return agent
|
||||
|
||||
condenser_llm = agent.llm.model_copy(update={'usage_id': 'condenser'})
|
||||
condenser = LLMSummarizingCondenser(
|
||||
llm=condenser_llm, max_size=condenser_max_size, keep_first=4
|
||||
)
|
||||
|
||||
return agent.model_copy(update={'condenser': condenser})
|
||||
|
||||
@@ -31,7 +31,7 @@ from server.utils.conversation_callback_utils import register_callback_processor
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.provider import ProviderToken, ProviderType
|
||||
from openhands.server.types import LLMAuthenticationError, MissingSettingsError
|
||||
from openhands.storage.data_models.user_secrets import UserSecrets
|
||||
from openhands.storage.data_models.secrets import Secrets
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
|
||||
|
||||
@@ -250,7 +250,7 @@ class GithubManager(Manager):
|
||||
f'[GitHub] Creating new conversation for user {user_info.username}'
|
||||
)
|
||||
|
||||
secret_store = UserSecrets(
|
||||
secret_store = Secrets(
|
||||
provider_tokens=MappingProxyType(
|
||||
{
|
||||
ProviderType.GITHUB: ProviderToken(
|
||||
|
||||
@@ -22,9 +22,9 @@ from server.auth.constants import GITHUB_APP_CLIENT_ID, GITHUB_APP_PRIVATE_KEY
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.config import get_config
|
||||
from storage.database import session_maker
|
||||
from storage.org_store import OrgStore
|
||||
from storage.proactive_conversation_store import ProactiveConversationStore
|
||||
from storage.saas_secrets_store import SaasSecretsStore
|
||||
from storage.user_settings import UserSettings
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.github.github_service import GithubServiceImpl
|
||||
@@ -61,18 +61,15 @@ async def get_user_proactive_conversation_setting(user_id: str | None) -> bool:
|
||||
if not user_id:
|
||||
return False
|
||||
|
||||
# Check global setting first - if disabled globally, return False
|
||||
if not ENABLE_PROACTIVE_CONVERSATION_STARTERS:
|
||||
return False
|
||||
|
||||
def _get_setting():
|
||||
with session_maker() as session:
|
||||
settings = (
|
||||
session.query(UserSettings)
|
||||
.filter(UserSettings.keycloak_user_id == user_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not settings or settings.enable_proactive_conversation_starters is None:
|
||||
return False
|
||||
|
||||
return settings.enable_proactive_conversation_starters
|
||||
org = OrgStore.get_current_org_from_keycloak_user_id(user_id)
|
||||
if not org:
|
||||
return False
|
||||
return bool(org.enable_proactive_conversation_starters)
|
||||
|
||||
return await call_sync_from_async(_get_setting)
|
||||
|
||||
@@ -131,6 +128,7 @@ class GithubIssue(ResolverViewInterface):
|
||||
issue_body=self.description,
|
||||
previous_comments=self.previous_comments,
|
||||
)
|
||||
|
||||
return user_instructions, conversation_instructions
|
||||
|
||||
async def _get_user_secrets(self):
|
||||
@@ -142,8 +140,7 @@ class GithubIssue(ResolverViewInterface):
|
||||
return user_secrets.custom_secrets if user_secrets else None
|
||||
|
||||
async def initialize_new_conversation(self) -> ConversationMetadata:
|
||||
# FIXME: Handle if initialize_conversation returns None
|
||||
conversation_metadata: ConversationMetadata = await initialize_conversation( # type: ignore[assignment]
|
||||
conversation_metadata: ConversationMetadata = await initialize_conversation(
|
||||
user_id=self.user_info.keycloak_user_id,
|
||||
conversation_id=None,
|
||||
selected_repository=self.full_repo_name,
|
||||
@@ -151,6 +148,7 @@ class GithubIssue(ResolverViewInterface):
|
||||
conversation_trigger=ConversationTrigger.RESOLVER,
|
||||
git_provider=ProviderType.GITHUB,
|
||||
)
|
||||
|
||||
self.conversation_id = conversation_metadata.conversation_id
|
||||
return conversation_metadata
|
||||
|
||||
@@ -196,7 +194,6 @@ class GithubIssueComment(GithubIssue):
|
||||
conversation_instructions_template = jinja_env.get_template(
|
||||
'issue_conversation_instructions.j2'
|
||||
)
|
||||
|
||||
conversation_instructions = conversation_instructions_template.render(
|
||||
issue_number=self.issue_number,
|
||||
issue_title=self.title,
|
||||
@@ -233,8 +230,7 @@ class GithubPRComment(GithubIssueComment):
|
||||
return user_instructions, conversation_instructions
|
||||
|
||||
async def initialize_new_conversation(self) -> ConversationMetadata:
|
||||
# FIXME: Handle if initialize_conversation returns None
|
||||
conversation_metadata: ConversationMetadata = await initialize_conversation( # type: ignore[assignment]
|
||||
conversation_metadata: ConversationMetadata = await initialize_conversation(
|
||||
user_id=self.user_info.keycloak_user_id,
|
||||
conversation_id=None,
|
||||
selected_repository=self.full_repo_name,
|
||||
@@ -280,7 +276,6 @@ class GithubInlinePRComment(GithubPRComment):
|
||||
conversation_instructions_template = jinja_env.get_template(
|
||||
'pr_update_conversation_instructions.j2'
|
||||
)
|
||||
|
||||
conversation_instructions = conversation_instructions_template.render(
|
||||
pr_number=self.issue_number,
|
||||
pr_title=self.title,
|
||||
|
||||
@@ -25,7 +25,7 @@ from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.gitlab.gitlab_service import GitLabServiceImpl
|
||||
from openhands.integrations.provider import ProviderToken, ProviderType
|
||||
from openhands.server.types import LLMAuthenticationError, MissingSettingsError
|
||||
from openhands.storage.data_models.user_secrets import UserSecrets
|
||||
from openhands.storage.data_models.secrets import Secrets
|
||||
|
||||
|
||||
class GitlabManager(Manager):
|
||||
@@ -198,7 +198,7 @@ class GitlabManager(Manager):
|
||||
f'[GitLab] Creating new conversation for user {user_info.username}'
|
||||
)
|
||||
|
||||
secret_store = UserSecrets(
|
||||
secret_store = Secrets(
|
||||
provider_tokens=MappingProxyType(
|
||||
{
|
||||
ProviderType.GITLAB: ProviderToken(
|
||||
|
||||
@@ -32,6 +32,7 @@ from openhands.integrations.service_types import Repository
|
||||
from openhands.server.shared import server_config
|
||||
from openhands.server.types import LLMAuthenticationError, MissingSettingsError
|
||||
from openhands.server.user_auth.user_auth import UserAuth
|
||||
from openhands.utils.http_session import httpx_verify_option
|
||||
|
||||
JIRA_CLOUD_API_URL = 'https://api.atlassian.com/ex/jira'
|
||||
|
||||
@@ -408,7 +409,7 @@ class JiraManager(Manager):
|
||||
svc_acc_api_key: str,
|
||||
) -> Tuple[str, str]:
|
||||
url = f'{JIRA_CLOUD_API_URL}/{jira_cloud_id}/rest/api/2/issue/{job_context.issue_key}'
|
||||
async with httpx.AsyncClient() as client:
|
||||
async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
|
||||
response = await client.get(url, auth=(svc_acc_email, svc_acc_api_key))
|
||||
response.raise_for_status()
|
||||
issue_payload = response.json()
|
||||
@@ -443,7 +444,7 @@ class JiraManager(Manager):
|
||||
f'{JIRA_CLOUD_API_URL}/{jira_cloud_id}/rest/api/2/issue/{issue_key}/comment'
|
||||
)
|
||||
data = {'body': message.message}
|
||||
async with httpx.AsyncClient() as client:
|
||||
async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
|
||||
response = await client.post(
|
||||
url, auth=(svc_acc_email, svc_acc_api_key), json=data
|
||||
)
|
||||
|
||||
@@ -57,7 +57,7 @@ class JiraNewConversationView(JiraViewInterface):
|
||||
raise StartingConvoException('No repository selected for this conversation')
|
||||
|
||||
provider_tokens = await self.saas_user_auth.get_provider_tokens()
|
||||
user_secrets = await self.saas_user_auth.get_user_secrets()
|
||||
user_secrets = await self.saas_user_auth.get_secrets()
|
||||
instructions, user_msg = self._get_instructions(jinja_env)
|
||||
|
||||
try:
|
||||
@@ -132,8 +132,10 @@ class JiraExistingConversationView(JiraViewInterface):
|
||||
conversation_store = await ConversationStoreImpl.get_instance(
|
||||
config, user_id
|
||||
)
|
||||
metadata = await conversation_store.get_metadata(self.conversation_id)
|
||||
if not metadata:
|
||||
|
||||
try:
|
||||
await conversation_store.get_metadata(self.conversation_id)
|
||||
except FileNotFoundError:
|
||||
raise StartingConvoException('Conversation no longer exists.')
|
||||
|
||||
provider_tokens = await self.saas_user_auth.get_provider_tokens()
|
||||
|
||||
@@ -34,6 +34,7 @@ from openhands.integrations.service_types import Repository
|
||||
from openhands.server.shared import server_config
|
||||
from openhands.server.types import LLMAuthenticationError, MissingSettingsError
|
||||
from openhands.server.user_auth.user_auth import UserAuth
|
||||
from openhands.utils.http_session import httpx_verify_option
|
||||
|
||||
|
||||
class JiraDcManager(Manager):
|
||||
@@ -422,7 +423,7 @@ class JiraDcManager(Manager):
|
||||
"""Get issue details from Jira DC API."""
|
||||
url = f'{job_context.base_api_url}/rest/api/2/issue/{job_context.issue_key}'
|
||||
headers = {'Authorization': f'Bearer {svc_acc_api_key}'}
|
||||
async with httpx.AsyncClient() as client:
|
||||
async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
|
||||
response = await client.get(url, headers=headers)
|
||||
response.raise_for_status()
|
||||
issue_payload = response.json()
|
||||
@@ -452,7 +453,7 @@ class JiraDcManager(Manager):
|
||||
url = f'{base_api_url}/rest/api/2/issue/{issue_key}/comment'
|
||||
headers = {'Authorization': f'Bearer {svc_acc_api_key}'}
|
||||
data = {'body': message.message}
|
||||
async with httpx.AsyncClient() as client:
|
||||
async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
|
||||
response = await client.post(url, headers=headers, json=data)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
@@ -60,7 +60,7 @@ class JiraDcNewConversationView(JiraDcViewInterface):
|
||||
raise StartingConvoException('No repository selected for this conversation')
|
||||
|
||||
provider_tokens = await self.saas_user_auth.get_provider_tokens()
|
||||
user_secrets = await self.saas_user_auth.get_user_secrets()
|
||||
user_secrets = await self.saas_user_auth.get_secrets()
|
||||
instructions, user_msg = self._get_instructions(jinja_env)
|
||||
|
||||
try:
|
||||
@@ -135,8 +135,10 @@ class JiraDcExistingConversationView(JiraDcViewInterface):
|
||||
conversation_store = await ConversationStoreImpl.get_instance(
|
||||
config, user_id
|
||||
)
|
||||
metadata = await conversation_store.get_metadata(self.conversation_id)
|
||||
if not metadata:
|
||||
|
||||
try:
|
||||
await conversation_store.get_metadata(self.conversation_id)
|
||||
except FileNotFoundError:
|
||||
raise StartingConvoException('Conversation no longer exists.')
|
||||
|
||||
provider_tokens = await self.saas_user_auth.get_provider_tokens()
|
||||
|
||||
@@ -31,6 +31,7 @@ from openhands.integrations.service_types import Repository
|
||||
from openhands.server.shared import server_config
|
||||
from openhands.server.types import LLMAuthenticationError, MissingSettingsError
|
||||
from openhands.server.user_auth.user_auth import UserAuth
|
||||
from openhands.utils.http_session import httpx_verify_option
|
||||
|
||||
|
||||
class LinearManager(Manager):
|
||||
@@ -408,7 +409,7 @@ class LinearManager(Manager):
|
||||
async def _query_api(self, query: str, variables: Dict, api_key: str) -> Dict:
|
||||
"""Query Linear GraphQL API."""
|
||||
headers = {'Authorization': api_key}
|
||||
async with httpx.AsyncClient() as client:
|
||||
async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
|
||||
response = await client.post(
|
||||
self.api_url,
|
||||
headers=headers,
|
||||
|
||||
@@ -57,7 +57,7 @@ class LinearNewConversationView(LinearViewInterface):
|
||||
raise StartingConvoException('No repository selected for this conversation')
|
||||
|
||||
provider_tokens = await self.saas_user_auth.get_provider_tokens()
|
||||
user_secrets = await self.saas_user_auth.get_user_secrets()
|
||||
user_secrets = await self.saas_user_auth.get_secrets()
|
||||
instructions, user_msg = self._get_instructions(jinja_env)
|
||||
|
||||
try:
|
||||
@@ -132,8 +132,10 @@ class LinearExistingConversationView(LinearViewInterface):
|
||||
conversation_store = await ConversationStoreImpl.get_instance(
|
||||
config, user_id
|
||||
)
|
||||
metadata = await conversation_store.get_metadata(self.conversation_id)
|
||||
if not metadata:
|
||||
|
||||
try:
|
||||
await conversation_store.get_metadata(self.conversation_id)
|
||||
except FileNotFoundError:
|
||||
raise StartingConvoException('Conversation no longer exists.')
|
||||
|
||||
provider_tokens = await self.saas_user_auth.get_provider_tokens()
|
||||
|
||||
@@ -87,7 +87,7 @@ class SlackManager(Manager):
|
||||
return slack_user, saas_user_auth
|
||||
|
||||
def _infer_repo_from_message(self, user_msg: str) -> str | None:
|
||||
# Regular expression to match patterns like "All-Hands-AI/OpenHands" or "deploy repo"
|
||||
# Regular expression to match patterns like "OpenHands/OpenHands" or "deploy repo"
|
||||
pattern = r'([a-zA-Z0-9_-]+/[a-zA-Z0-9_-]+)|([a-zA-Z0-9_-]+)(?=\s+repo)'
|
||||
match = re.search(pattern, user_msg)
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.schema.agent import AgentState
|
||||
from openhands.events.action import MessageAction
|
||||
from openhands.events.serialization.event import event_to_dict
|
||||
from openhands.integrations.provider import ProviderHandler
|
||||
from openhands.server.services.conversation_service import (
|
||||
create_new_conversation,
|
||||
setup_init_conversation_settings,
|
||||
@@ -166,6 +167,7 @@ class SlackNewConversationView(SlackViewInterface):
|
||||
'channel_id': self.channel_id,
|
||||
'conversation_id': self.conversation_id,
|
||||
'keycloak_user_id': user_info.keycloak_user_id,
|
||||
'org_id': user_info.org_id,
|
||||
'parent_id': self.thread_ts or self.message_ts,
|
||||
},
|
||||
)
|
||||
@@ -173,6 +175,7 @@ class SlackNewConversationView(SlackViewInterface):
|
||||
conversation_id=self.conversation_id,
|
||||
channel_id=self.channel_id,
|
||||
keycloak_user_id=user_info.keycloak_user_id,
|
||||
org_id=user_info.org_id,
|
||||
parent_id=self.thread_ts
|
||||
or self.message_ts, # conversations can start in a thread reply as well; we should always references the parent's (root level msg's) message ID
|
||||
)
|
||||
@@ -185,22 +188,30 @@ class SlackNewConversationView(SlackViewInterface):
|
||||
self._verify_necessary_values_are_set()
|
||||
|
||||
provider_tokens = await self.saas_user_auth.get_provider_tokens()
|
||||
user_secrets = await self.saas_user_auth.get_user_secrets()
|
||||
user_secrets = await self.saas_user_auth.get_secrets()
|
||||
user_instructions, conversation_instructions = self._get_instructions(jinja)
|
||||
|
||||
# Determine git provider from repository
|
||||
git_provider = None
|
||||
if self.selected_repo and provider_tokens:
|
||||
provider_handler = ProviderHandler(provider_tokens)
|
||||
repository = await provider_handler.verify_repo_provider(self.selected_repo)
|
||||
git_provider = repository.git_provider
|
||||
|
||||
agent_loop_info = await create_new_conversation(
|
||||
user_id=self.slack_to_openhands_user.keycloak_user_id,
|
||||
git_provider_tokens=provider_tokens,
|
||||
selected_repository=self.selected_repo,
|
||||
selected_branch=None,
|
||||
initial_user_msg=user_instructions,
|
||||
conversation_instructions=conversation_instructions
|
||||
if conversation_instructions
|
||||
else None,
|
||||
conversation_instructions=(
|
||||
conversation_instructions if conversation_instructions else None
|
||||
),
|
||||
image_urls=None,
|
||||
replay_json=None,
|
||||
conversation_trigger=ConversationTrigger.SLACK,
|
||||
custom_secrets=user_secrets.custom_secrets if user_secrets else None,
|
||||
git_provider=git_provider,
|
||||
)
|
||||
|
||||
self.conversation_id = agent_loop_info.conversation_id
|
||||
@@ -263,8 +274,10 @@ class SlackUpdateExistingConversationView(SlackNewConversationView):
|
||||
# Check if conversation has been deleted
|
||||
# Update logic when soft delete is implemented
|
||||
conversation_store = await ConversationStoreImpl.get_instance(config, user_id)
|
||||
metadata = await conversation_store.get_metadata(self.conversation_id)
|
||||
if not metadata:
|
||||
|
||||
try:
|
||||
await conversation_store.get_metadata(self.conversation_id)
|
||||
except FileNotFoundError:
|
||||
raise StartingConvoException('Conversation no longer exists.')
|
||||
|
||||
provider_tokens = await saas_user_auth.get_provider_tokens()
|
||||
@@ -293,10 +306,10 @@ class SlackUpdateExistingConversationView(SlackNewConversationView):
|
||||
if not agent_state or agent_state == AgentState.LOADING:
|
||||
raise StartingConvoException('Conversation is still starting')
|
||||
|
||||
user_msg, _ = self._get_instructions(jinja)
|
||||
user_msg_action = MessageAction(content=user_msg)
|
||||
instructions, _ = self._get_instructions(jinja)
|
||||
user_msg = MessageAction(content=instructions)
|
||||
await conversation_manager.send_event_to_conversation(
|
||||
self.conversation_id, event_to_dict(user_msg_action)
|
||||
self.conversation_id, event_to_dict(user_msg)
|
||||
)
|
||||
|
||||
return self.conversation_id
|
||||
|
||||
@@ -1,19 +1,22 @@
|
||||
from uuid import UUID
|
||||
|
||||
import stripe
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.constants import STRIPE_API_KEY
|
||||
from server.logger import logger
|
||||
from sqlalchemy.orm import Session
|
||||
from storage.database import session_maker
|
||||
from storage.org import Org
|
||||
from storage.org_store import OrgStore
|
||||
from storage.stripe_customer import StripeCustomer
|
||||
|
||||
stripe.api_key = STRIPE_API_KEY
|
||||
|
||||
|
||||
async def find_customer_id_by_user_id(user_id: str) -> str | None:
|
||||
# First search our own DB...
|
||||
async def find_customer_id_by_org_id(org_id: UUID) -> str | None:
|
||||
with session_maker() as session:
|
||||
stripe_customer = (
|
||||
session.query(StripeCustomer)
|
||||
.filter(StripeCustomer.keycloak_user_id == user_id)
|
||||
.filter(StripeCustomer.org_id == org_id)
|
||||
.first()
|
||||
)
|
||||
if stripe_customer:
|
||||
@@ -21,46 +24,72 @@ async def find_customer_id_by_user_id(user_id: str) -> str | None:
|
||||
|
||||
# If that fails, fallback to stripe
|
||||
search_result = await stripe.Customer.search_async(
|
||||
query=f"metadata['user_id']:'{user_id}'",
|
||||
query=f"metadata['org_id']:'{str(org_id)}'",
|
||||
)
|
||||
data = search_result.data
|
||||
if not data:
|
||||
logger.info('no_customer_for_user_id', extra={'user_id': user_id})
|
||||
logger.info(
|
||||
'no_customer_for_org_id',
|
||||
extra={'org_id': str(org_id)},
|
||||
)
|
||||
return None
|
||||
return data[0].id # type: ignore [attr-defined]
|
||||
|
||||
|
||||
async def find_or_create_customer(user_id: str) -> str:
|
||||
customer_id = await find_customer_id_by_user_id(user_id)
|
||||
if customer_id:
|
||||
return customer_id
|
||||
logger.info('creating_customer', extra={'user_id': user_id})
|
||||
async def find_customer_id_by_user_id(user_id: str) -> str | None:
|
||||
# First search our own DB...
|
||||
org = OrgStore.get_current_org_from_keycloak_user_id(user_id)
|
||||
if not org:
|
||||
logger.warning(f'Org not found for user {user_id}')
|
||||
return None
|
||||
customer_id = await find_customer_id_by_org_id(org.id)
|
||||
return customer_id
|
||||
|
||||
# Get the user info from keycloak
|
||||
token_manager = TokenManager()
|
||||
user_info = await token_manager.get_user_info_from_user_id(user_id) or {}
|
||||
|
||||
async def find_or_create_customer_by_user_id(user_id: str) -> dict | None:
|
||||
# Get the current org for the user
|
||||
org = OrgStore.get_current_org_from_keycloak_user_id(user_id)
|
||||
if not org:
|
||||
logger.warning(f'Org not found for user {user_id}')
|
||||
return None
|
||||
|
||||
customer_id = await find_customer_id_by_org_id(org.id)
|
||||
if customer_id:
|
||||
return {'customer_id': customer_id, 'org_id': str(org.id)}
|
||||
logger.info(
|
||||
'creating_customer',
|
||||
extra={'user_id': user_id, 'org_id': str(org.id)},
|
||||
)
|
||||
|
||||
# Create the customer in stripe
|
||||
customer = await stripe.Customer.create_async(
|
||||
email=str(user_info.get('email', '')),
|
||||
metadata={'user_id': user_id},
|
||||
email=org.contact_email,
|
||||
metadata={'org_id': str(org.id)},
|
||||
)
|
||||
|
||||
# Save the stripe customer in the local db
|
||||
with session_maker() as session:
|
||||
session.add(
|
||||
StripeCustomer(keycloak_user_id=user_id, stripe_customer_id=customer.id)
|
||||
StripeCustomer(
|
||||
keycloak_user_id=user_id,
|
||||
org_id=org.id,
|
||||
stripe_customer_id=customer.id,
|
||||
)
|
||||
)
|
||||
session.commit()
|
||||
|
||||
logger.info(
|
||||
'created_customer',
|
||||
extra={'user_id': user_id, 'stripe_customer_id': customer.id},
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org.id),
|
||||
'stripe_customer_id': customer.id,
|
||||
},
|
||||
)
|
||||
return customer.id
|
||||
return {'customer_id': customer.id, 'org_id': str(org.id)}
|
||||
|
||||
|
||||
async def has_payment_method(user_id: str) -> bool:
|
||||
async def has_payment_method_by_user_id(user_id: str) -> bool:
|
||||
customer_id = await find_customer_id_by_user_id(user_id)
|
||||
if customer_id is None:
|
||||
return False
|
||||
@@ -71,3 +100,28 @@ async def has_payment_method(user_id: str) -> bool:
|
||||
f'has_payment_method:{user_id}:{customer_id}:{bool(payment_methods.data)}'
|
||||
)
|
||||
return bool(payment_methods.data)
|
||||
|
||||
|
||||
async def migrate_customer(session: Session, user_id: str, org: Org):
|
||||
stripe_customer = (
|
||||
session.query(StripeCustomer)
|
||||
.filter(StripeCustomer.keycloak_user_id == user_id)
|
||||
.first()
|
||||
)
|
||||
if stripe_customer is None:
|
||||
return
|
||||
stripe_customer.org_id = org.id
|
||||
customer = await stripe.Customer.modify_async(
|
||||
id=stripe_customer.stripe_customer_id,
|
||||
email=org.contact_email,
|
||||
metadata={'user_id': '', 'org_id': str(org.id)},
|
||||
)
|
||||
|
||||
logger.info(
|
||||
'migrated_customer',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org.id),
|
||||
'stripe_customer_id': customer.id,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -381,7 +381,7 @@ def infer_repo_from_message(user_msg: str) -> list[str]:
|
||||
# Captures: protocol, domain, owner, repo (with optional .git extension)
|
||||
git_url_pattern = r'https?://(?:github\.com|gitlab\.com|bitbucket\.org)/([a-zA-Z0-9_.-]+)/([a-zA-Z0-9_.-]+?)(?:\.git)?(?:[/?#].*?)?(?=\s|$|[^\w.-])'
|
||||
|
||||
# Pattern to match direct owner/repo mentions (e.g., "All-Hands-AI/OpenHands")
|
||||
# Pattern to match direct owner/repo mentions (e.g., "OpenHands/OpenHands")
|
||||
# Must be surrounded by word boundaries or specific characters to avoid false positives
|
||||
direct_pattern = (
|
||||
r'(?:^|\s|[\[\(\'"])([a-zA-Z0-9_.-]+)/([a-zA-Z0-9_.-]+)(?=\s|$|[\]\)\'",.])'
|
||||
|
||||
@@ -20,6 +20,8 @@ down_revision = '059'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
# TODO: decide whether to modify this for orgs or users
|
||||
|
||||
|
||||
def upgrade():
|
||||
"""
|
||||
@@ -28,8 +30,10 @@ def upgrade():
|
||||
|
||||
This replaces the functionality of the removed admin maintenance endpoint.
|
||||
"""
|
||||
# Import here to avoid circular imports
|
||||
from server.constants import CURRENT_USER_SETTINGS_VERSION
|
||||
|
||||
# Hardcoded value to prevent migration failures when constant is removed from codebase
|
||||
# This migration has already run in production, so we use the value that was current at the time
|
||||
CURRENT_USER_SETTINGS_VERSION = 4
|
||||
|
||||
# Create a connection and bind it to a session
|
||||
connection = op.get_bind()
|
||||
|
||||
27
enterprise/migrations/versions/077_drop_settings_table.py
Normal file
27
enterprise/migrations/versions/077_drop_settings_table.py
Normal file
@@ -0,0 +1,27 @@
|
||||
"""drop settings table
|
||||
|
||||
Revision ID: 077
|
||||
Revises: 076
|
||||
Create Date: 2025-10-21 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '077'
|
||||
down_revision: Union[str, None] = '076'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Drop the deprecated settings table."""
|
||||
op.execute('DROP TABLE IF EXISTS settings')
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""No-op downgrade since the settings table is deprecated."""
|
||||
pass
|
||||
129
enterprise/migrations/versions/078_create_telemetry_tables.py
Normal file
129
enterprise/migrations/versions/078_create_telemetry_tables.py
Normal file
@@ -0,0 +1,129 @@
|
||||
"""create telemetry tables
|
||||
|
||||
Revision ID: 078
|
||||
Revises: 077
|
||||
Create Date: 2025-10-21
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '078'
|
||||
down_revision: Union[str, None] = '077'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Create telemetry tables for metrics collection and configuration."""
|
||||
# Create telemetry_metrics table
|
||||
op.create_table(
|
||||
'telemetry_metrics',
|
||||
sa.Column(
|
||||
'id',
|
||||
sa.String(), # UUID as string
|
||||
nullable=False,
|
||||
primary_key=True,
|
||||
),
|
||||
sa.Column(
|
||||
'collected_at',
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.text('CURRENT_TIMESTAMP'),
|
||||
),
|
||||
sa.Column(
|
||||
'metrics_data',
|
||||
sa.JSON(),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
'uploaded_at',
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column(
|
||||
'upload_attempts',
|
||||
sa.Integer(),
|
||||
nullable=False,
|
||||
server_default='0',
|
||||
),
|
||||
sa.Column(
|
||||
'last_upload_error',
|
||||
sa.Text(),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column(
|
||||
'created_at',
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.text('CURRENT_TIMESTAMP'),
|
||||
),
|
||||
sa.Column(
|
||||
'updated_at',
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.text('CURRENT_TIMESTAMP'),
|
||||
),
|
||||
)
|
||||
|
||||
# Create indexes for telemetry_metrics
|
||||
op.create_index(
|
||||
'ix_telemetry_metrics_collected_at', 'telemetry_metrics', ['collected_at']
|
||||
)
|
||||
op.create_index(
|
||||
'ix_telemetry_metrics_uploaded_at', 'telemetry_metrics', ['uploaded_at']
|
||||
)
|
||||
|
||||
# Create telemetry_replicated_identity table (minimal persistent identity data)
|
||||
op.create_table(
|
||||
'telemetry_replicated_identity',
|
||||
sa.Column(
|
||||
'id',
|
||||
sa.Integer(),
|
||||
nullable=False,
|
||||
primary_key=True,
|
||||
server_default='1',
|
||||
),
|
||||
sa.Column(
|
||||
'customer_id',
|
||||
sa.String(255),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column(
|
||||
'instance_id',
|
||||
sa.String(255),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column(
|
||||
'created_at',
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.text('CURRENT_TIMESTAMP'),
|
||||
),
|
||||
sa.Column(
|
||||
'updated_at',
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.text('CURRENT_TIMESTAMP'),
|
||||
),
|
||||
)
|
||||
|
||||
# Add constraint to ensure single row in telemetry_replicated_identity
|
||||
op.create_check_constraint(
|
||||
'single_identity_row', 'telemetry_replicated_identity', 'id = 1'
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Drop telemetry tables."""
|
||||
# Drop indexes first
|
||||
op.drop_index('ix_telemetry_metrics_uploaded_at', 'telemetry_metrics')
|
||||
op.drop_index('ix_telemetry_metrics_collected_at', 'telemetry_metrics')
|
||||
|
||||
# Drop tables
|
||||
op.drop_table('telemetry_replicated_identity')
|
||||
op.drop_table('telemetry_metrics')
|
||||
@@ -0,0 +1,39 @@
|
||||
"""rename user_secrets table to custom_secrets
|
||||
|
||||
Revision ID: 079
|
||||
Revises: 078
|
||||
Create Date: 2025-10-27 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '079'
|
||||
down_revision: Union[str, None] = '078'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Rename the table from user_secrets to custom_secrets
|
||||
op.rename_table('user_secrets', 'custom_secrets')
|
||||
|
||||
# Rename the index to match the new table name
|
||||
op.drop_index('idx_user_secrets_keycloak_user_id', 'custom_secrets')
|
||||
op.create_index(
|
||||
'idx_custom_secrets_keycloak_user_id', 'custom_secrets', ['keycloak_user_id']
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Rename the index back to the original name
|
||||
op.drop_index('idx_custom_secrets_keycloak_user_id', 'custom_secrets')
|
||||
op.create_index(
|
||||
'idx_user_secrets_keycloak_user_id', 'custom_secrets', ['keycloak_user_id']
|
||||
)
|
||||
|
||||
# Rename the table back from custom_secrets to user_secrets
|
||||
op.rename_table('custom_secrets', 'user_secrets')
|
||||
252
enterprise/migrations/versions/080_create_org_tables.py
Normal file
252
enterprise/migrations/versions/080_create_org_tables.py
Normal file
@@ -0,0 +1,252 @@
|
||||
"""create org tables from pgerd schema
|
||||
|
||||
Revision ID: 080
|
||||
Revises: 079
|
||||
Create Date: 2025-01-07 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '080'
|
||||
down_revision: Union[str, None] = '079'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.execute('CREATE EXTENSION IF NOT EXISTS pgcrypto;')
|
||||
# Remove current settings table
|
||||
op.execute('DROP TABLE IF EXISTS settings')
|
||||
|
||||
# Add migration_status column to user_settings table
|
||||
op.add_column(
|
||||
'user_settings',
|
||||
sa.Column('migration_status', sa.Boolean, nullable=True, default=False),
|
||||
)
|
||||
|
||||
# Create role table
|
||||
op.create_table(
|
||||
'role',
|
||||
sa.Column('id', sa.Integer, sa.Identity(), primary_key=True),
|
||||
sa.Column('name', sa.String, nullable=False),
|
||||
sa.Column('rank', sa.Integer, nullable=False),
|
||||
sa.UniqueConstraint('name', name='role_name_unique'),
|
||||
)
|
||||
|
||||
# 1. Create default roles
|
||||
print('Creating default roles...')
|
||||
op.execute(
|
||||
sa.text("""
|
||||
INSERT INTO role (name, rank) VALUES ('admin', 1), ('user', 1000)
|
||||
ON CONFLICT (name) DO NOTHING;
|
||||
""")
|
||||
)
|
||||
|
||||
# Create org table with settings fields
|
||||
op.create_table(
|
||||
'org',
|
||||
sa.Column(
|
||||
'id',
|
||||
postgresql.UUID(as_uuid=True),
|
||||
primary_key=True,
|
||||
server_default=sa.text('gen_random_uuid()'),
|
||||
),
|
||||
sa.Column('name', sa.String, nullable=False),
|
||||
sa.Column('contact_name', sa.String, nullable=True),
|
||||
sa.Column('contact_email', sa.String, nullable=True),
|
||||
sa.Column('conversation_expiration', sa.Integer, nullable=True),
|
||||
# Settings fields moved to org table
|
||||
sa.Column('agent', sa.String, nullable=True),
|
||||
sa.Column('default_max_iterations', sa.Integer, nullable=True),
|
||||
sa.Column('security_analyzer', sa.String, nullable=True),
|
||||
sa.Column('confirmation_mode', sa.Boolean, nullable=True, default=False),
|
||||
sa.Column('default_llm_model', sa.String, nullable=True),
|
||||
sa.Column('_default_llm_api_key_for_byor', sa.String, nullable=True),
|
||||
sa.Column('default_llm_base_url', sa.String, nullable=True),
|
||||
sa.Column('remote_runtime_resource_factor', sa.Integer, nullable=True),
|
||||
sa.Column('enable_default_condenser', sa.Boolean, nullable=False, default=True),
|
||||
sa.Column('billing_margin', sa.Float, nullable=True),
|
||||
sa.Column(
|
||||
'enable_proactive_conversation_starters',
|
||||
sa.Boolean,
|
||||
nullable=False,
|
||||
default=True,
|
||||
),
|
||||
sa.Column('sandbox_base_container_image', sa.String, nullable=True),
|
||||
sa.Column('sandbox_runtime_container_image', sa.String, nullable=True),
|
||||
sa.Column('org_version', sa.Integer, nullable=False, default=0),
|
||||
sa.Column('mcp_config', sa.JSON, nullable=True),
|
||||
sa.Column('_search_api_key', sa.String, nullable=True),
|
||||
sa.Column('_sandbox_api_key', sa.String, nullable=True),
|
||||
sa.Column('max_budget_per_task', sa.Float, nullable=True),
|
||||
sa.Column(
|
||||
'enable_solvability_analysis', sa.Boolean, nullable=True, default=False
|
||||
),
|
||||
sa.UniqueConstraint('name', name='org_name_unique'),
|
||||
)
|
||||
|
||||
# Create user table with user-specific settings fields
|
||||
op.create_table(
|
||||
'user',
|
||||
sa.Column(
|
||||
'id',
|
||||
postgresql.UUID(as_uuid=True),
|
||||
primary_key=True,
|
||||
server_default=sa.text('gen_random_uuid()'),
|
||||
),
|
||||
sa.Column('current_org_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('role_id', sa.Integer, nullable=True),
|
||||
sa.Column('accepted_tos', sa.DateTime, nullable=True),
|
||||
sa.Column(
|
||||
'enable_sound_notifications', sa.Boolean, nullable=True, default=False
|
||||
),
|
||||
sa.Column('language', sa.String, nullable=True),
|
||||
sa.Column('user_consents_to_analytics', sa.Boolean, nullable=True),
|
||||
sa.Column('email', sa.String, nullable=True),
|
||||
sa.Column('email_verified', sa.Boolean, nullable=True),
|
||||
sa.ForeignKeyConstraint(
|
||||
['current_org_id'], ['org.id'], name='current_org_fkey'
|
||||
),
|
||||
sa.ForeignKeyConstraint(['role_id'], ['role.id'], name='user_role_fkey'),
|
||||
)
|
||||
|
||||
# Create org_member table (junction table for many-to-many relationship)
|
||||
op.create_table(
|
||||
'org_member',
|
||||
sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('user_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('role_id', sa.Integer, nullable=False),
|
||||
sa.Column('_llm_api_key', sa.String, nullable=False),
|
||||
sa.Column('max_iterations', sa.Integer, nullable=True),
|
||||
sa.Column('llm_model', sa.String, nullable=True),
|
||||
sa.Column('_llm_api_key_for_byor', sa.String, nullable=True),
|
||||
sa.Column('llm_base_url', sa.String, nullable=True),
|
||||
sa.Column('status', sa.String, nullable=True),
|
||||
sa.ForeignKeyConstraint(['org_id'], ['org.id'], name='om_org_fkey'),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['user.id'], name='om_user_fkey'),
|
||||
sa.ForeignKeyConstraint(['role_id'], ['role.id'], name='om_role_fkey'),
|
||||
sa.PrimaryKeyConstraint('org_id', 'user_id'),
|
||||
)
|
||||
|
||||
# Add org_id column to existing tables
|
||||
# billing_sessions
|
||||
op.add_column(
|
||||
'billing_sessions',
|
||||
sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=True),
|
||||
)
|
||||
op.create_foreign_key(
|
||||
'billing_sessions_org_fkey', 'billing_sessions', 'org', ['org_id'], ['id']
|
||||
)
|
||||
|
||||
# Create conversation_metadata_saas table
|
||||
op.create_table(
|
||||
'conversation_metadata_saas',
|
||||
sa.Column('conversation_id', sa.String(), nullable=False),
|
||||
sa.Column('user_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
['user_id'], ['user.id'], name='conversation_metadata_saas_user_fkey'
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
['org_id'], ['org.id'], name='conversation_metadata_saas_org_fkey'
|
||||
),
|
||||
sa.PrimaryKeyConstraint('conversation_id'),
|
||||
)
|
||||
|
||||
# custom_secrets
|
||||
op.add_column(
|
||||
'custom_secrets',
|
||||
sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=True),
|
||||
)
|
||||
op.create_foreign_key(
|
||||
'custom_secrets_org_fkey', 'custom_secrets', 'org', ['org_id'], ['id']
|
||||
)
|
||||
|
||||
# api_keys
|
||||
op.add_column(
|
||||
'api_keys', sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=True)
|
||||
)
|
||||
op.create_foreign_key('api_keys_org_fkey', 'api_keys', 'org', ['org_id'], ['id'])
|
||||
|
||||
# slack_conversation
|
||||
op.add_column(
|
||||
'slack_conversation',
|
||||
sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=True),
|
||||
)
|
||||
op.create_foreign_key(
|
||||
'slack_conversation_org_fkey', 'slack_conversation', 'org', ['org_id'], ['id']
|
||||
)
|
||||
|
||||
# slack_users
|
||||
op.add_column(
|
||||
'slack_users', sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=True)
|
||||
)
|
||||
op.create_foreign_key(
|
||||
'slack_users_org_fkey', 'slack_users', 'org', ['org_id'], ['id']
|
||||
)
|
||||
|
||||
# stripe_customers
|
||||
op.alter_column(
|
||||
'stripe_customers',
|
||||
'keycloak_user_id',
|
||||
existing_type=sa.String(),
|
||||
nullable=True,
|
||||
)
|
||||
op.add_column(
|
||||
'stripe_customers',
|
||||
sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=True),
|
||||
)
|
||||
op.create_foreign_key(
|
||||
'stripe_customers_org_fkey', 'stripe_customers', 'org', ['org_id'], ['id']
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop migration_status column from user_settings table
|
||||
op.drop_column('user_settings', 'migration_status')
|
||||
|
||||
# Drop foreign keys and columns added to existing tables
|
||||
op.drop_constraint(
|
||||
'stripe_customers_org_fkey', 'stripe_customers', type_='foreignkey'
|
||||
)
|
||||
op.drop_column('stripe_customers', 'org_id')
|
||||
op.alter_column(
|
||||
'stripe_customers',
|
||||
'keycloak_user_id',
|
||||
existing_type=sa.String(),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
op.drop_constraint('slack_users_org_fkey', 'slack_users', type_='foreignkey')
|
||||
op.drop_column('slack_users', 'org_id')
|
||||
|
||||
op.drop_constraint(
|
||||
'slack_conversation_org_fkey', 'slack_conversation', type_='foreignkey'
|
||||
)
|
||||
op.drop_column('slack_conversation', 'org_id')
|
||||
|
||||
op.drop_constraint('api_keys_org_fkey', 'api_keys', type_='foreignkey')
|
||||
op.drop_column('api_keys', 'org_id')
|
||||
|
||||
op.drop_constraint('custom_secrets_org_fkey', 'custom_secrets', type_='foreignkey')
|
||||
op.drop_column('custom_secrets', 'org_id')
|
||||
|
||||
# Drop conversation_metadata_saas table
|
||||
op.drop_table('conversation_metadata_saas')
|
||||
|
||||
op.drop_constraint(
|
||||
'billing_sessions_org_fkey', 'billing_sessions', type_='foreignkey'
|
||||
)
|
||||
op.drop_column('billing_sessions', 'org_id')
|
||||
|
||||
# Drop tables in reverse order due to foreign key constraints
|
||||
op.drop_table('org_member')
|
||||
op.drop_table('user')
|
||||
op.drop_table('org')
|
||||
op.drop_table('role')
|
||||
10206
enterprise/poetry.lock
generated
10206
enterprise/poetry.lock
generated
File diff suppressed because one or more lines are too long
@@ -11,7 +11,7 @@ description = "Deploy OpenHands"
|
||||
authors = [ "OpenHands" ]
|
||||
license = "POLYFORM"
|
||||
readme = "README.md"
|
||||
repository = "https://github.com/All-Hands-AI/OpenHands"
|
||||
repository = "https://github.com/OpenHands/OpenHands"
|
||||
packages = [
|
||||
{ include = "server" },
|
||||
{ include = "storage" },
|
||||
|
||||
@@ -4,6 +4,10 @@ from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
# Ensure SAAS configuration is used
|
||||
if not os.getenv('OPENHANDS_CONFIG_CLS'):
|
||||
os.environ['OPENHANDS_CONFIG_CLS'] = 'server.config.SaaSServerConfig'
|
||||
|
||||
import socketio # noqa: E402
|
||||
from fastapi import Request, status # noqa: E402
|
||||
from fastapi.middleware.cors import CORSMiddleware # noqa: E402
|
||||
|
||||
@@ -31,7 +31,7 @@ from openhands.integrations.provider import (
|
||||
)
|
||||
from openhands.server.settings import Settings
|
||||
from openhands.server.user_auth.user_auth import AuthType, UserAuth
|
||||
from openhands.storage.data_models.user_secrets import UserSecrets
|
||||
from openhands.storage.data_models.secrets import Secrets
|
||||
from openhands.storage.settings.settings_store import SettingsStore
|
||||
|
||||
token_manager = TokenManager()
|
||||
@@ -52,7 +52,7 @@ class SaasUserAuth(UserAuth):
|
||||
settings_store: SaasSettingsStore | None = None
|
||||
secrets_store: SaasSecretsStore | None = None
|
||||
_settings: Settings | None = None
|
||||
_user_secrets: UserSecrets | None = None
|
||||
_secrets: Secrets | None = None
|
||||
accepted_tos: bool | None = None
|
||||
auth_type: AuthType = AuthType.COOKIE
|
||||
|
||||
@@ -102,7 +102,6 @@ class SaasUserAuth(UserAuth):
|
||||
return settings
|
||||
settings_store = await self.get_user_settings_store()
|
||||
settings = await settings_store.load()
|
||||
# If load() returned None, should settings be created?
|
||||
if settings:
|
||||
settings.email = self.email
|
||||
settings.email_verified = self.email_verified
|
||||
@@ -119,13 +118,13 @@ class SaasUserAuth(UserAuth):
|
||||
self.secrets_store = secrets_store
|
||||
return secrets_store
|
||||
|
||||
async def get_user_secrets(self):
|
||||
user_secrets = self._user_secrets
|
||||
async def get_secrets(self):
|
||||
user_secrets = self._secrets
|
||||
if user_secrets:
|
||||
return user_secrets
|
||||
secrets_store = await self.get_secrets_store()
|
||||
user_secrets = await secrets_store.load()
|
||||
self._user_secrets = user_secrets
|
||||
self._secrets = user_secrets
|
||||
return user_secrets
|
||||
|
||||
async def get_access_token(self) -> SecretStr | None:
|
||||
@@ -148,7 +147,7 @@ class SaasUserAuth(UserAuth):
|
||||
if not access_token:
|
||||
raise AuthError()
|
||||
|
||||
user_secrets = await self.get_user_secrets()
|
||||
user_secrets = await self.get_secrets()
|
||||
|
||||
try:
|
||||
# TODO: I think we can do this in a single request if we refactor
|
||||
|
||||
@@ -37,6 +37,7 @@ from storage.offline_token_store import OfflineTokenStore
|
||||
from tenacity import RetryCallState, retry, retry_if_exception_type, stop_after_attempt
|
||||
|
||||
from openhands.integrations.service_types import ProviderType
|
||||
from openhands.utils.http_session import httpx_verify_option
|
||||
|
||||
|
||||
def _before_sleep_callback(retry_state: RetryCallState) -> None:
|
||||
@@ -191,7 +192,7 @@ class TokenManager:
|
||||
access_token: str,
|
||||
idp: ProviderType,
|
||||
) -> dict[str, str | int]:
|
||||
async with httpx.AsyncClient() as client:
|
||||
async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
|
||||
base_url = KEYCLOAK_SERVER_URL_EXT if self.external else KEYCLOAK_SERVER_URL
|
||||
url = f'{base_url}/realms/{KEYCLOAK_REALM_NAME}/broker/{idp.value}/token'
|
||||
headers = {
|
||||
@@ -265,7 +266,9 @@ class TokenManager:
|
||||
self._check_expiration_and_refresh
|
||||
)
|
||||
if not token_info:
|
||||
logger.info(f'No tokens for user: {username}, identity provider: {idp}')
|
||||
logger.error(
|
||||
f'No tokens for user: {username}, identity provider: {idp}'
|
||||
)
|
||||
raise ValueError(
|
||||
f'No tokens for user: {username}, identity provider: {idp}'
|
||||
)
|
||||
@@ -293,11 +296,12 @@ class TokenManager:
|
||||
refresh_token_expires_at: int,
|
||||
) -> dict[str, str | int] | None:
|
||||
current_time = int(time.time())
|
||||
# expire access_token ten minutes before actual expiration
|
||||
# expire access_token four hours before actual expiration
|
||||
# This ensures tokens are refreshed on resume to have at least 4 hours validity
|
||||
access_expired = (
|
||||
False
|
||||
if access_token_expires_at == 0
|
||||
else access_token_expires_at < current_time + 600
|
||||
else access_token_expires_at < current_time + 14400
|
||||
)
|
||||
refresh_expired = (
|
||||
False
|
||||
@@ -349,7 +353,7 @@ class TokenManager:
|
||||
'refresh_token': refresh_token,
|
||||
'grant_type': 'refresh_token',
|
||||
}
|
||||
async with httpx.AsyncClient() as client:
|
||||
async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
|
||||
response = await client.post(url, data=payload)
|
||||
response.raise_for_status()
|
||||
logger.info('Successfully refreshed GitHub token')
|
||||
@@ -375,7 +379,7 @@ class TokenManager:
|
||||
'refresh_token': refresh_token,
|
||||
'grant_type': 'refresh_token',
|
||||
}
|
||||
async with httpx.AsyncClient() as client:
|
||||
async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
|
||||
response = await client.post(url, data=payload)
|
||||
response.raise_for_status()
|
||||
logger.info('Successfully refreshed GitLab token')
|
||||
@@ -403,7 +407,7 @@ class TokenManager:
|
||||
'refresh_token': refresh_token,
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
|
||||
response = await client.post(url, data=data, headers=headers)
|
||||
response.raise_for_status()
|
||||
logger.info('Successfully refreshed Bitbucket token')
|
||||
|
||||
@@ -9,7 +9,7 @@ from server.logger import logger
|
||||
from server.utils.conversation_callback_utils import invoke_conversation_callbacks
|
||||
from storage.database import session_maker
|
||||
from storage.saas_settings_store import SaasSettingsStore
|
||||
from storage.stored_conversation_metadata import StoredConversationMetadata
|
||||
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
|
||||
|
||||
from openhands.core.config import LLMConfig
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
@@ -525,16 +525,18 @@ class ClusteredConversationManager(StandaloneConversationManager):
|
||||
)
|
||||
# Look up the user_id from the database
|
||||
with session_maker() as session:
|
||||
conversation_metadata = (
|
||||
session.query(StoredConversationMetadata)
|
||||
conversation_metadata_saas = (
|
||||
session.query(StoredConversationMetadataSaas)
|
||||
.filter(
|
||||
StoredConversationMetadata.conversation_id
|
||||
StoredConversationMetadataSaas.conversation_id
|
||||
== conversation_id
|
||||
)
|
||||
.first()
|
||||
)
|
||||
user_id = (
|
||||
conversation_metadata.user_id if conversation_metadata else None
|
||||
str(conversation_metadata_saas.user_id)
|
||||
if conversation_metadata_saas
|
||||
else None
|
||||
)
|
||||
# Handle the stopped conversation asynchronously
|
||||
asyncio.create_task(
|
||||
|
||||
@@ -66,6 +66,7 @@ class SaaSServerConfig(ServerConfig):
|
||||
github_client_id: str = os.environ.get('GITHUB_APP_CLIENT_ID', '')
|
||||
enable_billing = os.environ.get('ENABLE_BILLING', 'false') == 'true'
|
||||
hide_llm_settings = os.environ.get('HIDE_LLM_SETTINGS', 'false') == 'true'
|
||||
stripe_publishable_key: str = os.environ.get('STRIPE_PUBLISHABLE_KEY', '')
|
||||
auth_url: str | None = os.environ.get('AUTH_URL')
|
||||
settings_store_class: str = 'storage.saas_settings_store.SaasSettingsStore'
|
||||
secret_store_class: str = 'storage.saas_secrets_store.SaasSecretsStore'
|
||||
@@ -168,6 +169,7 @@ class SaaSServerConfig(ServerConfig):
|
||||
'APP_SLUG': self.app_slug,
|
||||
'GITHUB_CLIENT_ID': self.github_client_id,
|
||||
'POSTHOG_CLIENT_KEY': self.posthog_client_key,
|
||||
'STRIPE_PUBLISHABLE_KEY': self.stripe_publishable_key,
|
||||
'FEATURE_FLAGS': {
|
||||
'ENABLE_BILLING': self.enable_billing,
|
||||
'HIDE_LLM_SETTINGS': self.hide_llm_settings,
|
||||
|
||||
@@ -19,8 +19,8 @@ IS_LOCAL_ENV = bool(HOST == 'localhost')
|
||||
DEFAULT_BILLING_MARGIN = float(os.environ.get('DEFAULT_BILLING_MARGIN', '1.0'))
|
||||
|
||||
# Map of user settings versions to their corresponding default LLM models
|
||||
# This ensures that CURRENT_USER_SETTINGS_VERSION and LITELLM_DEFAULT_MODEL stay in sync
|
||||
USER_SETTINGS_VERSION_TO_MODEL = {
|
||||
# This ensures that PERSONAL_WORKSPACE_VERSION_TO_MODEL and LITELLM_DEFAULT_MODEL stay in sync
|
||||
PERSONAL_WORKSPACE_VERSION_TO_MODEL = {
|
||||
1: 'claude-3-5-sonnet-20241022',
|
||||
2: 'claude-3-7-sonnet-20250219',
|
||||
3: 'claude-sonnet-4-20250514',
|
||||
@@ -30,29 +30,17 @@ USER_SETTINGS_VERSION_TO_MODEL = {
|
||||
LITELLM_DEFAULT_MODEL = os.getenv('LITELLM_DEFAULT_MODEL')
|
||||
|
||||
# Current user settings version - this should be the latest key in USER_SETTINGS_VERSION_TO_MODEL
|
||||
CURRENT_USER_SETTINGS_VERSION = max(USER_SETTINGS_VERSION_TO_MODEL.keys())
|
||||
ORG_SETTINGS_VERSION = max(PERSONAL_WORKSPACE_VERSION_TO_MODEL.keys())
|
||||
PERSONAL_WORKSPACE_VERSION = max(PERSONAL_WORKSPACE_VERSION_TO_MODEL.keys())
|
||||
|
||||
LITE_LLM_API_URL = os.environ.get(
|
||||
'LITE_LLM_API_URL', 'https://llm-proxy.app.all-hands.dev'
|
||||
)
|
||||
LITE_LLM_TEAM_ID = os.environ.get('LITE_LLM_TEAM_ID', None)
|
||||
LITE_LLM_API_KEY = os.environ.get('LITE_LLM_API_KEY', None)
|
||||
SUBSCRIPTION_PRICE_DATA = {
|
||||
'MONTHLY_SUBSCRIPTION': {
|
||||
'unit_amount': 2000,
|
||||
'currency': 'usd',
|
||||
'product_data': {
|
||||
'name': 'OpenHands Monthly',
|
||||
'tax_code': 'txcd_10000000',
|
||||
},
|
||||
'tax_behavior': 'exclusive',
|
||||
'recurring': {'interval': 'month', 'interval_count': 1},
|
||||
},
|
||||
}
|
||||
|
||||
DEFAULT_INITIAL_BUDGET = float(os.environ.get('DEFAULT_INITIAL_BUDGET', '20'))
|
||||
STRIPE_API_KEY = os.environ.get('STRIPE_API_KEY', None)
|
||||
STRIPE_WEBHOOK_SECRET = os.environ.get('STRIPE_WEBHOOK_SECRET', None)
|
||||
REQUIRE_PAYMENT = os.environ.get('REQUIRE_PAYMENT', '0') in ('1', 'true')
|
||||
|
||||
SLACK_CLIENT_ID = os.environ.get('SLACK_CLIENT_ID', None)
|
||||
@@ -102,5 +90,5 @@ def get_default_litellm_model():
|
||||
"""
|
||||
if LITELLM_DEFAULT_MODEL:
|
||||
return LITELLM_DEFAULT_MODEL
|
||||
model = USER_SETTINGS_VERSION_TO_MODEL[CURRENT_USER_SETTINGS_VERSION]
|
||||
model = PERSONAL_WORKSPACE_VERSION_TO_MODEL[PERSONAL_WORKSPACE_VERSION]
|
||||
return build_litellm_proxy_model_path(model)
|
||||
|
||||
@@ -44,11 +44,13 @@ class MyProcessor(MaintenanceTaskProcessor):
|
||||
### UserVersionUpgradeProcessor
|
||||
|
||||
Located in `user_version_upgrade_processor.py`, this processor:
|
||||
|
||||
- Handles up to 100 user IDs per task
|
||||
- Upgrades users with `user_version < CURRENT_USER_SETTINGS_VERSION`
|
||||
- Upgrades users with `user_version < ORG_SETTINGS_VERSION`
|
||||
- Uses `SaasSettingsStore.create_default_settings()` for upgrades
|
||||
|
||||
**Usage:**
|
||||
|
||||
```python
|
||||
from server.maintenance_task_processor.user_version_upgrade_processor import UserVersionUpgradeProcessor
|
||||
|
||||
@@ -144,22 +146,26 @@ task = create_maintenance_task(
|
||||
## Best Practices
|
||||
|
||||
### Processor Design
|
||||
|
||||
- Keep tasks short-running (under 1 minute)
|
||||
- Handle errors gracefully and return meaningful error information
|
||||
- Use batch processing for large datasets
|
||||
- Include progress information in the return dict
|
||||
|
||||
### Error Handling
|
||||
|
||||
- Always wrap your processor logic in try-catch blocks
|
||||
- Return structured error information
|
||||
- Log important events for debugging
|
||||
|
||||
### Performance
|
||||
|
||||
- Limit batch sizes to avoid long-running tasks
|
||||
- Use database sessions efficiently
|
||||
- Consider memory usage for large datasets
|
||||
|
||||
### Testing
|
||||
|
||||
- Create unit tests for your processors
|
||||
- Test error conditions
|
||||
- Verify the processor serialization/deserialization works correctly
|
||||
@@ -167,6 +173,7 @@ task = create_maintenance_task(
|
||||
## Database Patterns
|
||||
|
||||
The maintenance task system follows the repository's established patterns:
|
||||
|
||||
- Uses `session_maker()` for database operations
|
||||
- Wraps sync database operations in `call_sync_from_async` for async routes
|
||||
- Follows proper SQLAlchemy query patterns
|
||||
@@ -174,15 +181,18 @@ The maintenance task system follows the repository's established patterns:
|
||||
## Integration with Existing Systems
|
||||
|
||||
### User Management
|
||||
|
||||
- Integrates with the existing `UserSettings` model
|
||||
- Uses the current user versioning system (`CURRENT_USER_SETTINGS_VERSION`)
|
||||
- Uses the current user versioning system (`ORG_SETTINGS_VERSION`)
|
||||
- Maintains compatibility with existing user management workflows
|
||||
|
||||
### Authentication
|
||||
|
||||
- Admin endpoints use the existing SaaS authentication system
|
||||
- Requires users to have `admin = True` in their UserSettings
|
||||
|
||||
### Monitoring
|
||||
|
||||
- Tasks are logged with structured information
|
||||
- Status updates are tracked in the database
|
||||
- Error information is preserved for debugging
|
||||
@@ -206,6 +216,7 @@ The maintenance task system follows the repository's established patterns:
|
||||
## Future Enhancements
|
||||
|
||||
Potential improvements that could be added:
|
||||
|
||||
- Task dependencies and scheduling
|
||||
- Retry mechanisms for failed tasks
|
||||
- Real-time progress updates
|
||||
|
||||
@@ -1,155 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List
|
||||
|
||||
from server.constants import CURRENT_USER_SETTINGS_VERSION
|
||||
from server.logger import logger
|
||||
from storage.database import session_maker
|
||||
from storage.maintenance_task import MaintenanceTask, MaintenanceTaskProcessor
|
||||
from storage.saas_settings_store import SaasSettingsStore
|
||||
from storage.user_settings import UserSettings
|
||||
|
||||
from openhands.core.config import load_openhands_config
|
||||
|
||||
|
||||
class UserVersionUpgradeProcessor(MaintenanceTaskProcessor):
|
||||
"""
|
||||
Processor for upgrading user settings to the current version.
|
||||
|
||||
This processor takes a list of user IDs and upgrades any users
|
||||
whose user_version is less than CURRENT_USER_SETTINGS_VERSION.
|
||||
"""
|
||||
|
||||
user_ids: List[str]
|
||||
|
||||
async def __call__(self, task: MaintenanceTask) -> dict:
|
||||
"""
|
||||
Process user version upgrades for the specified user IDs.
|
||||
|
||||
Args:
|
||||
task: The maintenance task being processed
|
||||
|
||||
Returns:
|
||||
dict: Results containing successful and failed user IDs
|
||||
"""
|
||||
logger.info(
|
||||
'user_version_upgrade_processor:start',
|
||||
extra={
|
||||
'task_id': task.id,
|
||||
'user_count': len(self.user_ids),
|
||||
'current_version': CURRENT_USER_SETTINGS_VERSION,
|
||||
},
|
||||
)
|
||||
|
||||
if len(self.user_ids) > 100:
|
||||
raise ValueError(
|
||||
f'Too many user IDs: {len(self.user_ids)}. Maximum is 100.'
|
||||
)
|
||||
|
||||
config = load_openhands_config()
|
||||
|
||||
# Track results
|
||||
successful_upgrades = []
|
||||
failed_upgrades = []
|
||||
users_already_current = []
|
||||
|
||||
# Find users that need upgrading
|
||||
with session_maker() as session:
|
||||
users_to_upgrade = (
|
||||
session.query(UserSettings)
|
||||
.filter(
|
||||
UserSettings.keycloak_user_id.in_(self.user_ids),
|
||||
UserSettings.user_version < CURRENT_USER_SETTINGS_VERSION,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
# Track users that are already current
|
||||
users_needing_upgrade_ids = {u.keycloak_user_id for u in users_to_upgrade}
|
||||
users_already_current = [
|
||||
uid for uid in self.user_ids if uid not in users_needing_upgrade_ids
|
||||
]
|
||||
|
||||
logger.info(
|
||||
'user_version_upgrade_processor:found_users',
|
||||
extra={
|
||||
'task_id': task.id,
|
||||
'users_to_upgrade': len(users_to_upgrade),
|
||||
'users_already_current': len(users_already_current),
|
||||
'total_requested': len(self.user_ids),
|
||||
},
|
||||
)
|
||||
|
||||
# Process each user that needs upgrading
|
||||
for user_settings in users_to_upgrade:
|
||||
user_id = user_settings.keycloak_user_id
|
||||
old_version = user_settings.user_version
|
||||
|
||||
try:
|
||||
logger.info(
|
||||
'user_version_upgrade_processor:upgrading_user',
|
||||
extra={
|
||||
'task_id': task.id,
|
||||
'user_id': user_id,
|
||||
'old_version': old_version,
|
||||
'new_version': CURRENT_USER_SETTINGS_VERSION,
|
||||
},
|
||||
)
|
||||
|
||||
# Create SaasSettingsStore instance and upgrade
|
||||
settings_store = await SaasSettingsStore.get_instance(config, user_id)
|
||||
await settings_store.create_default_settings(user_settings)
|
||||
|
||||
successful_upgrades.append(
|
||||
{
|
||||
'user_id': user_id,
|
||||
'old_version': old_version,
|
||||
'new_version': CURRENT_USER_SETTINGS_VERSION,
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(
|
||||
'user_version_upgrade_processor:user_upgraded',
|
||||
extra={
|
||||
'task_id': task.id,
|
||||
'user_id': user_id,
|
||||
'old_version': old_version,
|
||||
'new_version': CURRENT_USER_SETTINGS_VERSION,
|
||||
},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
failed_upgrades.append(
|
||||
{'user_id': user_id, 'old_version': old_version, 'error': str(e)}
|
||||
)
|
||||
|
||||
logger.error(
|
||||
'user_version_upgrade_processor:user_upgrade_failed',
|
||||
extra={
|
||||
'task_id': task.id,
|
||||
'user_id': user_id,
|
||||
'old_version': old_version,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
|
||||
# Create result summary
|
||||
result = {
|
||||
'total_users': len(self.user_ids),
|
||||
'users_already_current': users_already_current,
|
||||
'successful_upgrades': successful_upgrades,
|
||||
'failed_upgrades': failed_upgrades,
|
||||
'summary': (
|
||||
f'Processed {len(self.user_ids)} users: '
|
||||
f'{len(successful_upgrades)} upgraded, '
|
||||
f'{len(users_already_current)} already current, '
|
||||
f'{len(failed_upgrades)} errors'
|
||||
),
|
||||
}
|
||||
|
||||
logger.info(
|
||||
'user_version_upgrade_processor:completed',
|
||||
extra={'task_id': task.id, 'result': result},
|
||||
)
|
||||
|
||||
return result
|
||||
@@ -1,12 +1,10 @@
|
||||
from datetime import UTC, datetime
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from pydantic import BaseModel, field_validator
|
||||
from server.constants import LITE_LLM_API_KEY, LITE_LLM_API_URL
|
||||
from storage.api_key_store import ApiKeyStore
|
||||
from storage.database import session_maker
|
||||
from storage.user_settings import UserSettings
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
from storage.org_store import OrgStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.server.user_auth import get_user_id
|
||||
@@ -18,15 +16,14 @@ async def get_byor_key_from_db(user_id: str) -> str | None:
|
||||
"""Get the BYOR key from the database for a user."""
|
||||
|
||||
def _get_byor_key():
|
||||
with session_maker() as session:
|
||||
user_db_settings = (
|
||||
session.query(UserSettings)
|
||||
.filter(UserSettings.keycloak_user_id == user_id)
|
||||
.first()
|
||||
)
|
||||
if user_db_settings and user_db_settings.llm_api_key_for_byor:
|
||||
return user_db_settings.llm_api_key_for_byor
|
||||
return None
|
||||
org = OrgStore.get_current_org_from_keycloak_user_id(user_id)
|
||||
if not org:
|
||||
return None
|
||||
return (
|
||||
org.default_llm_api_key_for_byor.get_secret_value()
|
||||
if org.default_llm_api_key_for_byor
|
||||
else None
|
||||
)
|
||||
|
||||
return await call_sync_from_async(_get_byor_key)
|
||||
|
||||
@@ -35,72 +32,42 @@ async def store_byor_key_in_db(user_id: str, key: str) -> None:
|
||||
"""Store the BYOR key in the database for a user."""
|
||||
|
||||
def _update_user_settings():
|
||||
with session_maker() as session:
|
||||
user_db_settings = (
|
||||
session.query(UserSettings)
|
||||
.filter(UserSettings.keycloak_user_id == user_id)
|
||||
.first()
|
||||
org = OrgStore.get_current_org_from_keycloak_user_id(user_id)
|
||||
if not org:
|
||||
logger.warning(
|
||||
'Org not found when trying to store BYOR key for user',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
if user_db_settings:
|
||||
user_db_settings.llm_api_key_for_byor = key
|
||||
session.commit()
|
||||
logger.info(
|
||||
'Successfully stored BYOR key in user settings',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
'User settings not found when trying to store BYOR key',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
return
|
||||
OrgStore.update_org(org.id, {'llm_api_key_for_byor': key})
|
||||
|
||||
await call_sync_from_async(_update_user_settings)
|
||||
|
||||
|
||||
async def generate_byor_key(user_id: str) -> str | None:
|
||||
"""Generate a new BYOR key for a user."""
|
||||
if not (LITE_LLM_API_KEY and LITE_LLM_API_URL):
|
||||
logger.warning(
|
||||
'LiteLLM API configuration not found', extra={'user_id': user_id}
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(
|
||||
headers={
|
||||
'x-goog-api-key': LITE_LLM_API_KEY,
|
||||
}
|
||||
) as client:
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/key/generate',
|
||||
json={
|
||||
key = await LiteLlmManager.generate_key(
|
||||
user_id, None, f'BYOR Key - user {user_id}', {'type': 'byor'}
|
||||
)
|
||||
|
||||
if key:
|
||||
logger.info(
|
||||
'Successfully generated new BYOR key',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'metadata': {'type': 'byor'},
|
||||
'key_alias': f'BYOR Key - user {user_id}',
|
||||
'key_length': len(key) if key else 0,
|
||||
'key_prefix': key[:10] + '...' if key and len(key) > 10 else key,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
response_json = response.json()
|
||||
key = response_json.get('key')
|
||||
|
||||
if key:
|
||||
logger.info(
|
||||
'Successfully generated new BYOR key',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'key_length': len(key) if key else 0,
|
||||
'key_prefix': key[:10] + '...'
|
||||
if key and len(key) > 10
|
||||
else key,
|
||||
},
|
||||
)
|
||||
return key
|
||||
else:
|
||||
logger.error(
|
||||
'Failed to generate BYOR LLM API key - no key in response',
|
||||
extra={'user_id': user_id, 'response_json': response_json},
|
||||
)
|
||||
return None
|
||||
return key
|
||||
else:
|
||||
logger.error(
|
||||
'Failed to generate BYOR LLM API key - no key in response',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'Error generating BYOR key',
|
||||
@@ -111,29 +78,14 @@ async def generate_byor_key(user_id: str) -> str | None:
|
||||
|
||||
async def delete_byor_key_from_litellm(user_id: str, byor_key: str) -> bool:
|
||||
"""Delete the BYOR key from LiteLLM using the key directly."""
|
||||
if not (LITE_LLM_API_KEY and LITE_LLM_API_URL):
|
||||
logger.warning(
|
||||
'LiteLLM API configuration not found', extra={'user_id': user_id}
|
||||
)
|
||||
return False
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(
|
||||
headers={
|
||||
'x-goog-api-key': LITE_LLM_API_KEY,
|
||||
}
|
||||
) as client:
|
||||
# Delete the key directly using the key value
|
||||
delete_url = f'{LITE_LLM_API_URL}/key/delete'
|
||||
delete_payload = {'keys': [byor_key]}
|
||||
|
||||
delete_response = await client.post(delete_url, json=delete_payload)
|
||||
delete_response.raise_for_status()
|
||||
logger.info(
|
||||
'Successfully deleted BYOR key from LiteLLM',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
return True
|
||||
await LiteLlmManager.delete_key(byor_key)
|
||||
logger.info(
|
||||
'Successfully deleted BYOR key from LiteLLM',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'Error deleting BYOR key from LiteLLM',
|
||||
@@ -311,15 +263,6 @@ async def refresh_llm_api_key_for_byor(user_id: str = Depends(get_user_id)):
|
||||
logger.info('Starting BYOR LLM API key refresh', extra={'user_id': user_id})
|
||||
|
||||
try:
|
||||
if not (LITE_LLM_API_KEY and LITE_LLM_API_URL):
|
||||
logger.warning(
|
||||
'LiteLLM API configuration not found', extra={'user_id': user_id}
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='LiteLLM API configuration not found',
|
||||
)
|
||||
|
||||
# Get the existing BYOR key from the database
|
||||
existing_byor_key = await get_byor_key_from_db(user_id)
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import uuid
|
||||
import warnings
|
||||
from datetime import datetime, timezone
|
||||
from typing import Annotated, Literal, Optional
|
||||
@@ -20,7 +21,9 @@ from server.config import sign_token
|
||||
from server.constants import IS_FEATURE_ENV
|
||||
from server.routes.event_webhook import _get_session_api_key, _get_user_id
|
||||
from storage.database import session_maker
|
||||
from storage.user import User
|
||||
from storage.user_settings import UserSettings
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.provider import ProviderHandler
|
||||
@@ -80,7 +83,7 @@ def get_cookie_domain(request: Request) -> str | None:
|
||||
# for now just use the full hostname except for staging stacks.
|
||||
return (
|
||||
None
|
||||
if (request.url.hostname or '').endswith('staging.all-hand.dev')
|
||||
if request.url.hostname.endswith('staging.all-hand.dev')
|
||||
else request.url.hostname
|
||||
)
|
||||
|
||||
@@ -138,6 +141,32 @@ async def keycloak_callback(
|
||||
)
|
||||
|
||||
user_id = user_info['sub']
|
||||
user = UserStore.get_user_by_id(user_id)
|
||||
if not user:
|
||||
user_settings = None
|
||||
with session_maker() as session:
|
||||
user_settings = (
|
||||
session.query(UserSettings)
|
||||
.filter(UserSettings.keycloak_user_id == user_id)
|
||||
.first()
|
||||
)
|
||||
if user_settings:
|
||||
user = await UserStore.migrate_user(user_id, user_settings, user_info)
|
||||
else:
|
||||
# new user
|
||||
user = await UserStore.create_user(user_id, user_info)
|
||||
|
||||
if not user:
|
||||
logger.error(f'Failed to authenticate user {user_info["preferred_username"]}')
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={
|
||||
'error': f'Failed to authenticate user {user_info["preferred_username"]}'
|
||||
},
|
||||
)
|
||||
|
||||
logger.info(f'Logging in user {str(user.id)} in org {user.current_org_id}')
|
||||
|
||||
# default to github IDP for now.
|
||||
# TODO: remove default once Keycloak is updated universally with the new attribute.
|
||||
idp: str = user_info.get('identity_provider', ProviderType.GITHUB.value)
|
||||
@@ -174,17 +203,19 @@ async def keycloak_callback(
|
||||
posthog_user_id = f'FEATURE_{user_id}' if IS_FEATURE_ENV else user_id
|
||||
|
||||
try:
|
||||
posthog.set(
|
||||
distinct_id=posthog_user_id,
|
||||
properties={
|
||||
'user_id': posthog_user_id,
|
||||
'original_user_id': user_id,
|
||||
'is_feature_env': IS_FEATURE_ENV,
|
||||
posthog.identify(
|
||||
posthog_user_id,
|
||||
{
|
||||
'$set': {
|
||||
'user_id': posthog_user_id, # Explicitly set as property
|
||||
'original_user_id': user_id, # Store the original user_id
|
||||
'is_feature_env': IS_FEATURE_ENV, # Track if this is a feature environment
|
||||
}
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'auth:posthog_set:failed',
|
||||
'auth:posthog_identify:failed',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'error': str(e),
|
||||
@@ -212,17 +243,7 @@ async def keycloak_callback(
|
||||
f'&state={state}'
|
||||
)
|
||||
|
||||
has_accepted_tos = False
|
||||
with session_maker() as session:
|
||||
user_settings = (
|
||||
session.query(UserSettings)
|
||||
.filter(UserSettings.keycloak_user_id == user_id)
|
||||
.first()
|
||||
)
|
||||
has_accepted_tos = (
|
||||
user_settings is not None and user_settings.accepted_tos is not None
|
||||
)
|
||||
|
||||
has_accepted_tos = user.accepted_tos is not None
|
||||
# If the user hasn't accepted the TOS, redirect to the TOS page
|
||||
if not has_accepted_tos:
|
||||
encoded_redirect_url = quote(redirect_url, safe='')
|
||||
@@ -341,24 +362,15 @@ async def accept_tos(request: Request):
|
||||
|
||||
# Update user settings with TOS acceptance
|
||||
with session_maker() as session:
|
||||
user_settings = (
|
||||
session.query(UserSettings)
|
||||
.filter(UserSettings.keycloak_user_id == user_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if user_settings:
|
||||
user_settings.accepted_tos = datetime.now(timezone.utc)
|
||||
session.merge(user_settings)
|
||||
else:
|
||||
# Create user settings if they don't exist
|
||||
user_settings = UserSettings(
|
||||
keycloak_user_id=user_id,
|
||||
accepted_tos=datetime.now(timezone.utc),
|
||||
user_version=0, # This will trigger a migration to the latest version on next load
|
||||
user = session.query(User).filter(User.id == uuid.UUID(user_id)).first()
|
||||
if not user:
|
||||
session.rollback()
|
||||
logger.error('User for {user_id} not found.')
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={'error': 'User does not exist'},
|
||||
)
|
||||
session.add(user_settings)
|
||||
|
||||
user.accepted_tos = datetime.now(timezone.utc)
|
||||
session.commit()
|
||||
|
||||
logger.info(f'User {user_id} accepted TOS')
|
||||
|
||||
@@ -2,28 +2,20 @@
|
||||
import typing
|
||||
from datetime import UTC, datetime
|
||||
from decimal import Decimal
|
||||
from enum import Enum
|
||||
|
||||
import httpx
|
||||
import stripe
|
||||
from dateutil.relativedelta import relativedelta # type: ignore
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from fastapi.responses import JSONResponse, RedirectResponse
|
||||
from fastapi.responses import RedirectResponse
|
||||
from integrations import stripe_service
|
||||
from pydantic import BaseModel
|
||||
from server.constants import (
|
||||
LITE_LLM_API_KEY,
|
||||
LITE_LLM_API_URL,
|
||||
STRIPE_API_KEY,
|
||||
STRIPE_WEBHOOK_SECRET,
|
||||
SUBSCRIPTION_PRICE_DATA,
|
||||
get_default_litellm_model,
|
||||
)
|
||||
from server.logger import logger
|
||||
from storage.billing_session import BillingSession
|
||||
from storage.database import session_maker
|
||||
from storage.subscription_access import SubscriptionAccess
|
||||
from storage.user_settings import UserSettings
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.server.user_auth import get_user_id
|
||||
|
||||
@@ -31,23 +23,41 @@ stripe.api_key = STRIPE_API_KEY
|
||||
billing_router = APIRouter(prefix='/api/billing')
|
||||
|
||||
|
||||
class BillingSessionType(Enum):
|
||||
DIRECT_PAYMENT = 'DIRECT_PAYMENT'
|
||||
MONTHLY_SUBSCRIPTION = 'MONTHLY_SUBSCRIPTION'
|
||||
# TODO: Add a new app_mode named "ON_PREM" to support self-hosted customers instead of doing this
|
||||
# and members should comment out the "validate_saas_environment" function if they are developing and testing locally.
|
||||
def is_all_hands_saas_environment(request: Request) -> bool:
|
||||
"""Check if the current domain is an All Hands SaaS environment.
|
||||
|
||||
Args:
|
||||
request: FastAPI Request object
|
||||
|
||||
Returns:
|
||||
True if the current domain contains "all-hands.dev" or "openhands.dev" postfix
|
||||
"""
|
||||
hostname = request.url.hostname or ''
|
||||
return hostname.endswith('all-hands.dev') or hostname.endswith('openhands.dev')
|
||||
|
||||
|
||||
def validate_saas_environment(request: Request) -> None:
|
||||
"""Validate that the request is coming from an All Hands SaaS environment.
|
||||
|
||||
Args:
|
||||
request: FastAPI Request object
|
||||
|
||||
Raises:
|
||||
HTTPException: If the request is not from an All Hands SaaS environment
|
||||
"""
|
||||
if not is_all_hands_saas_environment(request):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail='Checkout sessions are only available for All Hands SaaS environments',
|
||||
)
|
||||
|
||||
|
||||
class GetCreditsResponse(BaseModel):
|
||||
credits: Decimal | None = None
|
||||
|
||||
|
||||
class SubscriptionAccessResponse(BaseModel):
|
||||
start_at: datetime
|
||||
end_at: datetime
|
||||
created_at: datetime
|
||||
cancelled_at: datetime | None = None
|
||||
stripe_subscription_id: str | None = None
|
||||
|
||||
|
||||
class CreateCheckoutSessionRequest(BaseModel):
|
||||
amount: int
|
||||
|
||||
@@ -78,117 +88,23 @@ def calculate_credits(user_info: LiteLlmUserInfo) -> float:
|
||||
async def get_credits(user_id: str = Depends(get_user_id)) -> GetCreditsResponse:
|
||||
if not stripe_service.STRIPE_API_KEY:
|
||||
return GetCreditsResponse()
|
||||
async with httpx.AsyncClient() as client:
|
||||
user_json = await _get_litellm_user(client, user_id)
|
||||
credits = calculate_credits(user_json['user_info'])
|
||||
user = UserStore.get_user_by_id(user_id)
|
||||
user_team_info = await LiteLlmManager.get_user_team_info(
|
||||
user_id, str(user.current_org_id)
|
||||
)
|
||||
# Update to use calculate_credits
|
||||
spend = user_team_info.get('spend', 0)
|
||||
max_budget = (user_team_info.get('litellm_budget_table') or {}).get('max_budget', 0)
|
||||
credits = max(max_budget - spend, 0)
|
||||
return GetCreditsResponse(credits=Decimal('{:.2f}'.format(credits)))
|
||||
|
||||
|
||||
# Endpoint to retrieve user's current subscription access
|
||||
@billing_router.get('/subscription-access')
|
||||
async def get_subscription_access(
|
||||
user_id: str = Depends(get_user_id),
|
||||
) -> SubscriptionAccessResponse | None:
|
||||
"""Get details of the currently valid subscription for the user."""
|
||||
with session_maker() as session:
|
||||
now = datetime.now(UTC)
|
||||
subscription_access = (
|
||||
session.query(SubscriptionAccess)
|
||||
.filter(SubscriptionAccess.status == 'ACTIVE')
|
||||
.filter(SubscriptionAccess.user_id == user_id)
|
||||
.filter(SubscriptionAccess.start_at <= now)
|
||||
.filter(SubscriptionAccess.end_at >= now)
|
||||
.first()
|
||||
)
|
||||
if not subscription_access:
|
||||
return None
|
||||
return SubscriptionAccessResponse(
|
||||
start_at=subscription_access.start_at,
|
||||
end_at=subscription_access.end_at,
|
||||
created_at=subscription_access.created_at,
|
||||
cancelled_at=subscription_access.cancelled_at,
|
||||
stripe_subscription_id=subscription_access.stripe_subscription_id,
|
||||
)
|
||||
|
||||
|
||||
# Endpoint to check if a user has entered a payment method into stripe
|
||||
@billing_router.post('/has-payment-method')
|
||||
async def has_payment_method(user_id: str = Depends(get_user_id)) -> bool:
|
||||
if not user_id:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED)
|
||||
return await stripe_service.has_payment_method(user_id)
|
||||
|
||||
|
||||
# Endpoint to cancel user's subscription
|
||||
@billing_router.post('/cancel-subscription')
|
||||
async def cancel_subscription(user_id: str = Depends(get_user_id)) -> JSONResponse:
|
||||
"""Cancel user's active subscription at the end of the current billing period."""
|
||||
if not user_id:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED)
|
||||
|
||||
with session_maker() as session:
|
||||
# Find the user's active subscription
|
||||
now = datetime.now(UTC)
|
||||
subscription_access = (
|
||||
session.query(SubscriptionAccess)
|
||||
.filter(SubscriptionAccess.status == 'ACTIVE')
|
||||
.filter(SubscriptionAccess.user_id == user_id)
|
||||
.filter(SubscriptionAccess.start_at <= now)
|
||||
.filter(SubscriptionAccess.end_at >= now)
|
||||
.filter(SubscriptionAccess.cancelled_at.is_(None)) # Not already cancelled
|
||||
.first()
|
||||
)
|
||||
|
||||
if not subscription_access:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail='No active subscription found',
|
||||
)
|
||||
|
||||
if not subscription_access.stripe_subscription_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail='Cannot cancel subscription: missing Stripe subscription ID',
|
||||
)
|
||||
|
||||
try:
|
||||
# Cancel the subscription in Stripe at period end
|
||||
await stripe.Subscription.modify_async(
|
||||
subscription_access.stripe_subscription_id, cancel_at_period_end=True
|
||||
)
|
||||
|
||||
# Update local database
|
||||
subscription_access.cancelled_at = datetime.now(UTC)
|
||||
session.merge(subscription_access)
|
||||
session.commit()
|
||||
|
||||
logger.info(
|
||||
'subscription_cancelled',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'stripe_subscription_id': subscription_access.stripe_subscription_id,
|
||||
'subscription_access_id': subscription_access.id,
|
||||
'end_at': subscription_access.end_at,
|
||||
},
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
{'status': 'success', 'message': 'Subscription cancelled successfully'}
|
||||
)
|
||||
|
||||
except stripe.StripeError as e:
|
||||
logger.error(
|
||||
'stripe_cancellation_failed',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'stripe_subscription_id': subscription_access.stripe_subscription_id,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f'Failed to cancel subscription: {str(e)}',
|
||||
)
|
||||
return await stripe_service.has_payment_method_by_user_id(user_id)
|
||||
|
||||
|
||||
# Endpoint to create a new setup intent in stripe
|
||||
@@ -196,15 +112,16 @@ async def cancel_subscription(user_id: str = Depends(get_user_id)) -> JSONRespon
|
||||
async def create_customer_setup_session(
|
||||
request: Request, user_id: str = Depends(get_user_id)
|
||||
) -> CreateBillingSessionResponse:
|
||||
customer_id = await stripe_service.find_or_create_customer(user_id)
|
||||
validate_saas_environment(request)
|
||||
customer_info = await stripe_service.find_or_create_customer_by_user_id(user_id)
|
||||
checkout_session = await stripe.checkout.Session.create_async(
|
||||
customer=customer_id,
|
||||
customer=customer_info['customer_id'],
|
||||
mode='setup',
|
||||
payment_method_types=['card'],
|
||||
success_url=f'{request.base_url}?free_credits=success',
|
||||
cancel_url=f'{request.base_url}',
|
||||
)
|
||||
return CreateBillingSessionResponse(redirect_url=checkout_session.url) # type: ignore[arg-type]
|
||||
return CreateBillingSessionResponse(redirect_url=checkout_session.url)
|
||||
|
||||
|
||||
# Endpoint to create a new Stripe checkout session for credit purchase
|
||||
@@ -214,9 +131,11 @@ async def create_checkout_session(
|
||||
request: Request,
|
||||
user_id: str = Depends(get_user_id),
|
||||
) -> CreateBillingSessionResponse:
|
||||
customer_id = await stripe_service.find_or_create_customer(user_id)
|
||||
validate_saas_environment(request)
|
||||
|
||||
customer_info = await stripe_service.find_or_create_customer_by_user_id(user_id)
|
||||
checkout_session = await stripe.checkout.Session.create_async(
|
||||
customer=customer_id,
|
||||
customer=customer_info['customer_id'],
|
||||
line_items=[
|
||||
{
|
||||
'price_data': {
|
||||
@@ -229,7 +148,7 @@ async def create_checkout_session(
|
||||
'tax_behavior': 'exclusive',
|
||||
},
|
||||
'quantity': 1,
|
||||
}
|
||||
},
|
||||
],
|
||||
mode='payment',
|
||||
payment_method_types=['card'],
|
||||
@@ -242,8 +161,9 @@ async def create_checkout_session(
|
||||
logger.info(
|
||||
'created_stripe_checkout_session',
|
||||
extra={
|
||||
'stripe_customer_id': customer_id,
|
||||
'stripe_customer_id': customer_info['customer_id'],
|
||||
'user_id': user_id,
|
||||
'org_id': customer_info['org_id'],
|
||||
'amount': body.amount,
|
||||
'checkout_session_id': checkout_session.id,
|
||||
},
|
||||
@@ -252,101 +172,14 @@ async def create_checkout_session(
|
||||
billing_session = BillingSession(
|
||||
id=checkout_session.id,
|
||||
user_id=user_id,
|
||||
org_id=customer_info['org_id'],
|
||||
price=body.amount,
|
||||
price_code='NA',
|
||||
billing_session_type=BillingSessionType.DIRECT_PAYMENT.value,
|
||||
)
|
||||
session.add(billing_session)
|
||||
session.commit()
|
||||
|
||||
return CreateBillingSessionResponse(redirect_url=checkout_session.url) # type: ignore[arg-type]
|
||||
|
||||
|
||||
@billing_router.post('/subscription-checkout-session')
|
||||
async def create_subscription_checkout_session(
|
||||
request: Request,
|
||||
billing_session_type: BillingSessionType = BillingSessionType.MONTHLY_SUBSCRIPTION,
|
||||
user_id: str = Depends(get_user_id),
|
||||
) -> CreateBillingSessionResponse:
|
||||
# Prevent duplicate subscriptions for the same user
|
||||
with session_maker() as session:
|
||||
now = datetime.now(UTC)
|
||||
existing_active_subscription = (
|
||||
session.query(SubscriptionAccess)
|
||||
.filter(SubscriptionAccess.status == 'ACTIVE')
|
||||
.filter(SubscriptionAccess.user_id == user_id)
|
||||
.filter(SubscriptionAccess.start_at <= now)
|
||||
.filter(SubscriptionAccess.end_at >= now)
|
||||
.filter(SubscriptionAccess.cancelled_at.is_(None)) # Not cancelled
|
||||
.first()
|
||||
)
|
||||
|
||||
if existing_active_subscription:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail='Cannot create subscription: User already has an active subscription that has not been cancelled',
|
||||
)
|
||||
|
||||
customer_id = await stripe_service.find_or_create_customer(user_id)
|
||||
subscription_price_data = SUBSCRIPTION_PRICE_DATA[billing_session_type.value]
|
||||
checkout_session = await stripe.checkout.Session.create_async(
|
||||
customer=customer_id,
|
||||
line_items=[
|
||||
{
|
||||
'price_data': subscription_price_data,
|
||||
'quantity': 1,
|
||||
}
|
||||
],
|
||||
mode='subscription',
|
||||
payment_method_types=['card'],
|
||||
saved_payment_method_options={
|
||||
'payment_method_save': 'enabled',
|
||||
},
|
||||
success_url=f'{request.base_url}api/billing/success?session_id={{CHECKOUT_SESSION_ID}}',
|
||||
cancel_url=f'{request.base_url}api/billing/cancel?session_id={{CHECKOUT_SESSION_ID}}',
|
||||
subscription_data={
|
||||
'metadata': {
|
||||
'user_id': user_id,
|
||||
'billing_session_type': billing_session_type.value,
|
||||
}
|
||||
},
|
||||
)
|
||||
logger.info(
|
||||
'created_stripe_subscription_checkout_session',
|
||||
extra={
|
||||
'stripe_customer_id': customer_id,
|
||||
'user_id': user_id,
|
||||
'checkout_session_id': checkout_session.id,
|
||||
'billing_session_type': billing_session_type.value,
|
||||
},
|
||||
)
|
||||
with session_maker() as session:
|
||||
billing_session = BillingSession(
|
||||
id=checkout_session.id,
|
||||
user_id=user_id,
|
||||
price=subscription_price_data['unit_amount'],
|
||||
price_code='NA',
|
||||
billing_session_type=billing_session_type.value,
|
||||
)
|
||||
session.add(billing_session)
|
||||
session.commit()
|
||||
|
||||
return CreateBillingSessionResponse(
|
||||
redirect_url=typing.cast(str, checkout_session.url)
|
||||
)
|
||||
|
||||
|
||||
@billing_router.get('/create-subscription-checkout-session')
|
||||
async def create_subscription_checkout_session_via_get(
|
||||
request: Request,
|
||||
billing_session_type: BillingSessionType = BillingSessionType.MONTHLY_SUBSCRIPTION,
|
||||
user_id: str = Depends(get_user_id),
|
||||
) -> RedirectResponse:
|
||||
"""Create a subscription checkout session using a GET request (For easier copy / paste to URL bar)."""
|
||||
response = await create_subscription_checkout_session(
|
||||
request, billing_session_type, user_id
|
||||
)
|
||||
return RedirectResponse(response.redirect_url)
|
||||
return CreateBillingSessionResponse(redirect_url=checkout_session.url)
|
||||
|
||||
|
||||
# Callback endpoint for successful Stripe payments - updates user credits and billing session status
|
||||
@@ -368,15 +201,6 @@ async def success_callback(session_id: str, request: Request):
|
||||
)
|
||||
raise HTTPException(status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
# Any non direct payment (Subscription) is processed in the invoice_payment.paid by the webhook
|
||||
if (
|
||||
billing_session.billing_session_type
|
||||
!= BillingSessionType.DIRECT_PAYMENT.value
|
||||
):
|
||||
return RedirectResponse(
|
||||
f'{request.base_url}settings?checkout=success', status_code=302
|
||||
)
|
||||
|
||||
stripe_session = stripe.checkout.Session.retrieve(session_id)
|
||||
if stripe_session.status != 'complete':
|
||||
# Hopefully this never happens - we get a redirect from stripe where the payment is not yet complete
|
||||
@@ -390,31 +214,37 @@ async def success_callback(session_id: str, request: Request):
|
||||
)
|
||||
raise HTTPException(status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
# Update max budget in litellm
|
||||
user_json = await _get_litellm_user(client, billing_session.user_id)
|
||||
amount_subtotal = stripe_session.amount_subtotal or 0
|
||||
add_credits = amount_subtotal / 100
|
||||
new_max_budget = (
|
||||
(user_json.get('user_info') or {}).get('max_budget') or 0
|
||||
) + add_credits
|
||||
await _upsert_litellm_user(client, billing_session.user_id, new_max_budget)
|
||||
user = UserStore.get_user_by_id(billing_session.user_id)
|
||||
user_team_info = await LiteLlmManager.get_user_team_info(
|
||||
billing_session.user_id, str(user.current_org_id)
|
||||
)
|
||||
amount_subtotal = stripe_session.amount_subtotal or 0
|
||||
add_credits = amount_subtotal / 100
|
||||
max_budget = (user_team_info.get('litellm_budget_table') or {}).get(
|
||||
'max_budget', 0
|
||||
)
|
||||
new_max_budget = max_budget + add_credits
|
||||
|
||||
# Store transaction status
|
||||
billing_session.status = 'completed'
|
||||
billing_session.price = amount_subtotal
|
||||
billing_session.updated_at = datetime.now(UTC)
|
||||
session.merge(billing_session)
|
||||
logger.info(
|
||||
'stripe_checkout_success',
|
||||
extra={
|
||||
'amount_subtotal': stripe_session.amount_subtotal,
|
||||
'user_id': billing_session.user_id,
|
||||
'checkout_session_id': billing_session.id,
|
||||
'stripe_customer_id': stripe_session.customer,
|
||||
},
|
||||
)
|
||||
session.commit()
|
||||
await LiteLlmManager.update_team_and_users_budget(
|
||||
str(user.current_org_id), new_max_budget
|
||||
)
|
||||
|
||||
# Store transaction status
|
||||
billing_session.status = 'completed'
|
||||
billing_session.price = add_credits
|
||||
billing_session.updated_at = datetime.now(UTC)
|
||||
session.merge(billing_session)
|
||||
logger.info(
|
||||
'stripe_checkout_success',
|
||||
extra={
|
||||
'amount_subtotal': stripe_session.amount_subtotal,
|
||||
'user_id': billing_session.user_id,
|
||||
'org_id': str(user.current_org_id),
|
||||
'checkout_session_id': billing_session.id,
|
||||
'stripe_customer_id': stripe_session.customer,
|
||||
},
|
||||
)
|
||||
session.commit()
|
||||
|
||||
return RedirectResponse(
|
||||
f'{request.base_url}settings/billing?checkout=success', status_code=302
|
||||
@@ -444,203 +274,6 @@ async def cancel_callback(session_id: str, request: Request):
|
||||
session.merge(billing_session)
|
||||
session.commit()
|
||||
|
||||
# Redirect credit purchases to billing screen, subscriptions to LLM settings
|
||||
if (
|
||||
billing_session.billing_session_type
|
||||
== BillingSessionType.DIRECT_PAYMENT.value
|
||||
):
|
||||
return RedirectResponse(
|
||||
f'{request.base_url}settings/billing?checkout=cancel',
|
||||
status_code=302,
|
||||
)
|
||||
else:
|
||||
return RedirectResponse(
|
||||
f'{request.base_url}settings?checkout=cancel', status_code=302
|
||||
)
|
||||
|
||||
# If no billing session found, default to LLM settings (subscription flow)
|
||||
return RedirectResponse(
|
||||
f'{request.base_url}settings?checkout=cancel', status_code=302
|
||||
f'{request.base_url}settings/billing?checkout=cancel', status_code=302
|
||||
)
|
||||
|
||||
|
||||
@billing_router.post('/stripe-webhook')
|
||||
async def stripe_webhook(request: Request) -> JSONResponse:
|
||||
"""Endpoint for stripe webhooks."""
|
||||
payload = await request.body()
|
||||
sig_header = request.headers.get('stripe-signature')
|
||||
|
||||
try:
|
||||
event = stripe.Webhook.construct_event(
|
||||
payload, sig_header, STRIPE_WEBHOOK_SECRET
|
||||
)
|
||||
except ValueError as e:
|
||||
# Invalid payload
|
||||
raise HTTPException(status_code=400, detail=f'Invalid payload: {e}')
|
||||
except stripe.SignatureVerificationError as e:
|
||||
# Invalid signature
|
||||
raise HTTPException(status_code=400, detail=f'Invalid signature: {e}')
|
||||
|
||||
# Handle the event
|
||||
logger.info('stripe_webhook_event', extra={'event': event})
|
||||
event_type = event['type']
|
||||
if event_type == 'invoice.paid':
|
||||
invoice = event['data']['object']
|
||||
amount_paid = invoice.amount_paid
|
||||
metadata = invoice.parent.subscription_details.metadata # type: ignore
|
||||
billing_session_type = metadata.billing_session_type
|
||||
assert (
|
||||
amount_paid == SUBSCRIPTION_PRICE_DATA[billing_session_type]['unit_amount']
|
||||
)
|
||||
user_id = metadata.user_id
|
||||
|
||||
start_at = datetime.now(UTC)
|
||||
if billing_session_type == BillingSessionType.MONTHLY_SUBSCRIPTION.value:
|
||||
end_at = start_at + relativedelta(months=1)
|
||||
else:
|
||||
raise ValueError(f'unknown_billing_session_type:{billing_session_type}')
|
||||
|
||||
with session_maker() as session:
|
||||
subscription_access = SubscriptionAccess(
|
||||
status='ACTIVE',
|
||||
user_id=user_id,
|
||||
start_at=start_at,
|
||||
end_at=end_at,
|
||||
amount_paid=amount_paid,
|
||||
stripe_invoice_payment_id=invoice.payment_intent,
|
||||
stripe_subscription_id=invoice.subscription, # Store Stripe subscription ID
|
||||
)
|
||||
session.add(subscription_access)
|
||||
session.commit()
|
||||
elif event_type == 'customer.subscription.updated':
|
||||
subscription = event['data']['object']
|
||||
subscription_id = subscription['id']
|
||||
|
||||
# Handle subscription cancellation
|
||||
if subscription.get('cancel_at_period_end') is True:
|
||||
with session_maker() as session:
|
||||
subscription_access = (
|
||||
session.query(SubscriptionAccess)
|
||||
.filter(
|
||||
SubscriptionAccess.stripe_subscription_id == subscription_id
|
||||
)
|
||||
.filter(SubscriptionAccess.status == 'ACTIVE')
|
||||
.first()
|
||||
)
|
||||
|
||||
if subscription_access and not subscription_access.cancelled_at:
|
||||
subscription_access.cancelled_at = datetime.now(UTC)
|
||||
session.merge(subscription_access)
|
||||
session.commit()
|
||||
|
||||
logger.info(
|
||||
'subscription_cancelled_via_webhook',
|
||||
extra={
|
||||
'stripe_subscription_id': subscription_id,
|
||||
'user_id': subscription_access.user_id,
|
||||
'subscription_access_id': subscription_access.id,
|
||||
},
|
||||
)
|
||||
elif event_type == 'customer.subscription.deleted':
|
||||
subscription = event['data']['object']
|
||||
subscription_id = subscription['id']
|
||||
|
||||
with session_maker() as session:
|
||||
subscription_access = (
|
||||
session.query(SubscriptionAccess)
|
||||
.filter(SubscriptionAccess.stripe_subscription_id == subscription_id)
|
||||
.filter(SubscriptionAccess.status == 'ACTIVE')
|
||||
.first()
|
||||
)
|
||||
|
||||
if subscription_access:
|
||||
subscription_access.status = 'DISABLED'
|
||||
subscription_access.updated_at = datetime.now(UTC)
|
||||
session.merge(subscription_access)
|
||||
session.commit()
|
||||
|
||||
# Reset user settings to free tier defaults
|
||||
reset_user_to_free_tier_settings(subscription_access.user_id)
|
||||
|
||||
logger.info(
|
||||
'subscription_expired_reset_to_free_tier',
|
||||
extra={
|
||||
'stripe_subscription_id': subscription_id,
|
||||
'user_id': subscription_access.user_id,
|
||||
'subscription_access_id': subscription_access.id,
|
||||
},
|
||||
)
|
||||
else:
|
||||
logger.info('stripe_webhook_unhandled_event_type', extra={'type': event_type})
|
||||
|
||||
return JSONResponse({'status': 'success'})
|
||||
|
||||
|
||||
def reset_user_to_free_tier_settings(user_id: str) -> None:
|
||||
"""Reset user settings to free tier defaults when subscription ends."""
|
||||
with session_maker() as session:
|
||||
user_settings = (
|
||||
session.query(UserSettings)
|
||||
.filter(UserSettings.keycloak_user_id == user_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if user_settings:
|
||||
user_settings.llm_model = get_default_litellm_model()
|
||||
user_settings.llm_api_key = None
|
||||
user_settings.llm_api_key_for_byor = None
|
||||
user_settings.llm_base_url = LITE_LLM_API_URL
|
||||
user_settings.max_budget_per_task = None
|
||||
user_settings.confirmation_mode = False
|
||||
user_settings.enable_solvability_analysis = False
|
||||
user_settings.security_analyzer = 'llm'
|
||||
user_settings.agent = 'CodeActAgent'
|
||||
user_settings.language = 'en'
|
||||
user_settings.enable_default_condenser = True
|
||||
user_settings.enable_sound_notifications = False
|
||||
user_settings.enable_proactive_conversation_starters = True
|
||||
user_settings.user_consents_to_analytics = False
|
||||
|
||||
session.merge(user_settings)
|
||||
session.commit()
|
||||
|
||||
logger.info(
|
||||
'user_settings_reset_to_free_tier',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'reset_timestamp': datetime.now(UTC).isoformat(),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def _get_litellm_user(client: httpx.AsyncClient, user_id: str) -> dict:
|
||||
"""Get a user from litellm with the id matching that given.
|
||||
|
||||
If no such user exists, returns a dummy user in the format:
|
||||
`{'user_id': '<USER_ID>', 'user_info': {'spend': 0}, 'keys': [], 'teams': []}`
|
||||
"""
|
||||
response = await client.get(
|
||||
f'{LITE_LLM_API_URL}/user/info?user_id={user_id}',
|
||||
headers={
|
||||
'x-goog-api-key': LITE_LLM_API_KEY,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
|
||||
async def _upsert_litellm_user(
|
||||
client: httpx.AsyncClient, user_id: str, max_budget: float
|
||||
):
|
||||
"""Insert / Update a user in litellm."""
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/user/update',
|
||||
headers={
|
||||
'x-goog-api-key': LITE_LLM_API_KEY,
|
||||
},
|
||||
json={
|
||||
'user_id': user_id,
|
||||
'max_budget': max_budget,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@@ -6,7 +6,7 @@ from threading import Thread
|
||||
from fastapi import APIRouter, FastAPI
|
||||
from sqlalchemy import func, select
|
||||
from storage.database import a_session_maker, engine, session_maker
|
||||
from storage.user_settings import UserSettings
|
||||
from storage.user import User
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.utils.async_utils import wait_all
|
||||
@@ -127,7 +127,7 @@ def _db_check(delay: int):
|
||||
delay: Number of seconds to hold the database connection
|
||||
"""
|
||||
with session_maker() as session:
|
||||
num_users = session.query(UserSettings).count()
|
||||
num_users = session.query(User).count()
|
||||
time.sleep(delay)
|
||||
logger.info(
|
||||
'check',
|
||||
@@ -155,7 +155,7 @@ async def _a_db_check(delay: int):
|
||||
delay: Number of seconds to hold the database connection
|
||||
"""
|
||||
async with a_session_maker() as a_session:
|
||||
stmt = select(func.count(UserSettings.id))
|
||||
stmt = select(func.count(User.id))
|
||||
num_users = await a_session.execute(stmt)
|
||||
await asyncio.sleep(delay)
|
||||
logger.info(f'a_num_users:{num_users.scalar_one()}')
|
||||
|
||||
@@ -21,7 +21,7 @@ from server.utils.conversation_callback_utils import (
|
||||
update_conversation_stats,
|
||||
)
|
||||
from storage.database import session_maker
|
||||
from storage.stored_conversation_metadata import StoredConversationMetadata
|
||||
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
|
||||
|
||||
from openhands.server.shared import conversation_manager
|
||||
|
||||
@@ -226,12 +226,12 @@ def _parse_conversation_id_and_subpath(path: str) -> Tuple[str, str]:
|
||||
|
||||
def _get_user_id(conversation_id: str) -> str:
|
||||
with session_maker() as session:
|
||||
conversation_metadata = (
|
||||
session.query(StoredConversationMetadata)
|
||||
.filter(StoredConversationMetadata.conversation_id == conversation_id)
|
||||
conversation_metadata_saas = (
|
||||
session.query(StoredConversationMetadataSaas)
|
||||
.filter(StoredConversationMetadataSaas.conversation_id == conversation_id)
|
||||
.first()
|
||||
)
|
||||
return conversation_metadata.user_id
|
||||
return str(conversation_metadata_saas.user_id)
|
||||
|
||||
|
||||
async def _get_session_api_key(user_id: str, conversation_id: str) -> str | None:
|
||||
|
||||
@@ -5,7 +5,7 @@ from pydantic import BaseModel, Field
|
||||
from sqlalchemy.future import select
|
||||
from storage.database import session_maker
|
||||
from storage.feedback import ConversationFeedback
|
||||
from storage.stored_conversation_metadata import StoredConversationMetadata
|
||||
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
|
||||
|
||||
from openhands.events.event_store import EventStore
|
||||
from openhands.server.shared import file_store
|
||||
@@ -33,10 +33,10 @@ async def get_event_ids(conversation_id: str, user_id: str) -> List[int]:
|
||||
def _verify_conversation():
|
||||
with session_maker() as session:
|
||||
metadata = (
|
||||
session.query(StoredConversationMetadata)
|
||||
session.query(StoredConversationMetadataSaas)
|
||||
.filter(
|
||||
StoredConversationMetadata.conversation_id == conversation_id,
|
||||
StoredConversationMetadata.user_id == user_id,
|
||||
StoredConversationMetadataSaas.conversation_id == conversation_id,
|
||||
StoredConversationMetadataSaas.user_id == user_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
@@ -11,6 +11,7 @@ from fastapi.responses import RedirectResponse
|
||||
from server.logger import logger
|
||||
|
||||
from openhands.server.shared import config
|
||||
from openhands.utils.http_session import httpx_verify_option
|
||||
|
||||
GITHUB_PROXY_ENDPOINTS = bool(os.environ.get('GITHUB_PROXY_ENDPOINTS'))
|
||||
|
||||
@@ -87,7 +88,7 @@ def add_github_proxy_routes(app: FastAPI):
|
||||
]
|
||||
body = urlencode(query_params, doseq=True)
|
||||
url = 'https://github.com/login/oauth/access_token'
|
||||
async with httpx.AsyncClient() as client:
|
||||
async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
|
||||
response = await client.post(url, content=body)
|
||||
return Response(
|
||||
response.content,
|
||||
@@ -101,7 +102,7 @@ def add_github_proxy_routes(app: FastAPI):
|
||||
logger.info(f'github_proxy_post:1:{path}')
|
||||
body = await request.body()
|
||||
url = f'https://github.com/{path}'
|
||||
async with httpx.AsyncClient() as client:
|
||||
async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
|
||||
response = await client.post(url, content=body, headers=request.headers)
|
||||
return Response(
|
||||
response.content,
|
||||
|
||||
@@ -15,7 +15,6 @@ from integrations.slack.slack_manager import SlackManager
|
||||
from integrations.utils import (
|
||||
HOST_URL,
|
||||
)
|
||||
from pydantic import SecretStr
|
||||
from server.auth.constants import (
|
||||
KEYCLOAK_CLIENT_ID,
|
||||
KEYCLOAK_REALM_NAME,
|
||||
@@ -35,6 +34,8 @@ from slack_sdk.web.async_client import AsyncWebClient
|
||||
from storage.database import session_maker
|
||||
from storage.slack_team_store import SlackTeamStore
|
||||
from storage.slack_user import SlackUser
|
||||
from storage.user_settings import UserSettings
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.integrations.service_types import ProviderType
|
||||
from openhands.server.shared import config, sio
|
||||
@@ -79,6 +80,14 @@ async def install_callback(
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
if not config.jwt_secret:
|
||||
logger.error('slack_install_callback_error JWT not configured.')
|
||||
return _html_response(
|
||||
title='Error',
|
||||
description=html.escape('JWT not configured'),
|
||||
status_code=500,
|
||||
)
|
||||
|
||||
try:
|
||||
client = AsyncWebClient() # no prepared token needed for this
|
||||
# Complete the installation by calling oauth.v2.access API method
|
||||
@@ -94,16 +103,17 @@ async def install_callback(
|
||||
|
||||
# Create a state variable for keycloak oauth
|
||||
payload = {}
|
||||
jwt_secret: SecretStr = config.jwt_secret # type: ignore[assignment]
|
||||
if state:
|
||||
payload = jwt.decode(
|
||||
state, jwt_secret.get_secret_value(), algorithms=['HS256']
|
||||
state, config.jwt_secret.get_secret_value(), algorithms=['HS256']
|
||||
)
|
||||
payload['slack_user_id'] = authed_user.get('id')
|
||||
payload['bot_access_token'] = bot_access_token
|
||||
payload['team_id'] = team_id
|
||||
|
||||
state = jwt.encode(payload, jwt_secret.get_secret_value(), algorithm='HS256')
|
||||
state = jwt.encode(
|
||||
payload, config.jwt_secret.get_secret_value(), algorithm='HS256'
|
||||
)
|
||||
|
||||
# Redirect into keycloak
|
||||
scope = quote('openid email profile offline_access')
|
||||
@@ -149,9 +159,16 @@ async def keycloak_callback(
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
jwt_secret: SecretStr = config.jwt_secret # type: ignore[assignment]
|
||||
if not config.jwt_secret:
|
||||
logger.error('problem_retrieving_keycloak_tokens JWT not configured.')
|
||||
return _html_response(
|
||||
title='Error',
|
||||
description=html.escape('JWT not configured'),
|
||||
status_code=500,
|
||||
)
|
||||
|
||||
payload: dict[str, str] = jwt.decode(
|
||||
state, jwt_secret.get_secret_value(), algorithms=['HS256']
|
||||
state, config.jwt_secret.get_secret_value(), algorithms=['HS256']
|
||||
)
|
||||
slack_user_id = payload['slack_user_id']
|
||||
bot_access_token = payload['bot_access_token']
|
||||
@@ -180,6 +197,22 @@ async def keycloak_callback(
|
||||
|
||||
user_info = await token_manager.get_user_info(keycloak_access_token)
|
||||
keycloak_user_id = user_info['sub']
|
||||
user = UserStore.get_user_by_id(keycloak_user_id)
|
||||
if not user:
|
||||
user_settings = None
|
||||
with session_maker() as session:
|
||||
user_settings = (
|
||||
session.query(UserSettings)
|
||||
.filter(UserSettings.keycloak_user_id == keycloak_user_id)
|
||||
.first()
|
||||
)
|
||||
if not user_settings:
|
||||
return _html_response(
|
||||
title='Failed to authenticate.',
|
||||
description=f'Please re-login into <a href="{HOST_URL}" style="color:#ecedee;text-decoration:underline;">OpenHands Cloud</a>. Then try <a href="https://docs.all-hands.dev/usage/cloud/slack-installation" style="color:#ecedee;text-decoration:underline;">installing the OpenHands Slack App</a> again',
|
||||
status_code=400,
|
||||
)
|
||||
user = await UserStore.migrate_user(keycloak_user_id, user_settings, user_info)
|
||||
|
||||
# These tokens are offline access tokens - store them!
|
||||
await token_manager.store_offline_token(keycloak_user_id, keycloak_refresh_token)
|
||||
@@ -211,6 +244,7 @@ async def keycloak_callback(
|
||||
slack_display_name = slack_user_info.data['user']['profile']['display_name']
|
||||
slack_user = SlackUser(
|
||||
keycloak_user_id=keycloak_user_id,
|
||||
org_id=user.current_org_id,
|
||||
slack_user_id=slack_user_id,
|
||||
slack_display_name=slack_display_name,
|
||||
)
|
||||
@@ -305,7 +339,7 @@ async def on_form_interaction(request: Request, background_tasks: BackgroundTask
|
||||
|
||||
body = await request.body()
|
||||
form = await request.form()
|
||||
payload = json.loads(form.get('payload')) # type: ignore[arg-type]
|
||||
payload = json.loads(form.get('payload'))
|
||||
|
||||
logger.info('slack_on_form_interaction', extra={'payload': payload})
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@ from sqlalchemy import orm
|
||||
from storage.api_key_store import ApiKeyStore
|
||||
from storage.database import session_maker
|
||||
from storage.stored_conversation_metadata import StoredConversationMetadata
|
||||
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
|
||||
|
||||
from openhands.controller.agent import Agent
|
||||
from openhands.core.config import LLMConfig, OpenHandsConfig
|
||||
@@ -52,6 +53,7 @@ from openhands.storage.locations import (
|
||||
get_conversation_events_dir,
|
||||
)
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
from openhands.utils.http_session import httpx_verify_option
|
||||
from openhands.utils.import_utils import get_impl
|
||||
from openhands.utils.shutdown_listener import should_continue
|
||||
from openhands.utils.utils import create_registry_and_conversation_stats
|
||||
@@ -266,9 +268,10 @@ class SaasNestedConversationManager(ConversationManager):
|
||||
):
|
||||
logger.info('starting_nested_conversation', extra={'sid': sid})
|
||||
async with httpx.AsyncClient(
|
||||
verify=httpx_verify_option(),
|
||||
headers={
|
||||
'X-Session-API-Key': session_api_key,
|
||||
}
|
||||
},
|
||||
) as client:
|
||||
await self._setup_nested_settings(client, api_url, settings)
|
||||
await self._setup_provider_tokens(client, api_url, settings)
|
||||
@@ -484,9 +487,10 @@ class SaasNestedConversationManager(ConversationManager):
|
||||
raise ValueError(f'no_such_conversation:{sid}')
|
||||
nested_url = self._get_nested_url_for_runtime(runtime['runtime_id'], sid)
|
||||
async with httpx.AsyncClient(
|
||||
verify=httpx_verify_option(),
|
||||
headers={
|
||||
'X-Session-API-Key': runtime['session_api_key'],
|
||||
}
|
||||
},
|
||||
) as client:
|
||||
response = await client.post(f'{nested_url}/events', json=data)
|
||||
response.raise_for_status()
|
||||
@@ -522,16 +526,18 @@ class SaasNestedConversationManager(ConversationManager):
|
||||
"""
|
||||
|
||||
with session_maker() as session:
|
||||
conversation_metadata = (
|
||||
session.query(StoredConversationMetadata)
|
||||
.filter(StoredConversationMetadata.conversation_id == conversation_id)
|
||||
conversation_metadata_saas = (
|
||||
session.query(StoredConversationMetadataSaas)
|
||||
.filter(
|
||||
StoredConversationMetadataSaas.conversation_id == conversation_id
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not conversation_metadata:
|
||||
if not conversation_metadata_saas:
|
||||
raise ValueError(f'No conversation found {conversation_id}')
|
||||
|
||||
return conversation_metadata.user_id
|
||||
return str(conversation_metadata_saas.user_id)
|
||||
|
||||
async def _get_runtime_status_from_nested_runtime(
|
||||
self, session_api_key: Any | None, nested_url: str, conversation_id: str
|
||||
@@ -551,9 +557,10 @@ class SaasNestedConversationManager(ConversationManager):
|
||||
return None
|
||||
|
||||
async with httpx.AsyncClient(
|
||||
verify=httpx_verify_option(),
|
||||
headers={
|
||||
'X-Session-API-Key': session_api_key,
|
||||
}
|
||||
},
|
||||
) as client:
|
||||
# Query the nested runtime for conversation info
|
||||
response = await client.get(nested_url)
|
||||
@@ -828,6 +835,7 @@ class SaasNestedConversationManager(ConversationManager):
|
||||
@contextlib.asynccontextmanager
|
||||
async def _httpx_client(self):
|
||||
async with httpx.AsyncClient(
|
||||
verify=httpx_verify_option(),
|
||||
headers={'X-API-Key': self.config.sandbox.api_key or ''},
|
||||
timeout=_HTTP_TIMEOUT,
|
||||
) as client:
|
||||
@@ -853,9 +861,17 @@ class SaasNestedConversationManager(ConversationManager):
|
||||
with session_maker() as session:
|
||||
# Only include conversations updated in the past week
|
||||
one_week_ago = datetime.now(UTC) - timedelta(days=7)
|
||||
query = session.query(StoredConversationMetadata.conversation_id).filter(
|
||||
StoredConversationMetadata.user_id == user_id,
|
||||
StoredConversationMetadata.last_updated_at >= one_week_ago,
|
||||
query = (
|
||||
session.query(StoredConversationMetadata.conversation_id)
|
||||
.join(
|
||||
StoredConversationMetadataSaas,
|
||||
StoredConversationMetadata.conversation_id
|
||||
== StoredConversationMetadataSaas.conversation_id,
|
||||
)
|
||||
.filter(
|
||||
StoredConversationMetadataSaas.user_id == user_id,
|
||||
StoredConversationMetadata.last_updated_at >= one_week_ago,
|
||||
)
|
||||
)
|
||||
user_conversation_ids = set(query)
|
||||
return user_conversation_ids
|
||||
@@ -929,11 +945,16 @@ class SaasNestedConversationManager(ConversationManager):
|
||||
.filter(StoredConversationMetadata.conversation_id == conversation_id)
|
||||
.first()
|
||||
)
|
||||
if conversation_metadata is None:
|
||||
conversation_metadata_saas = (
|
||||
session.query(StoredConversationMetadataSaas)
|
||||
.filter(StoredConversationMetadataSaas.conversation_id == conversation_id)
|
||||
.first()
|
||||
)
|
||||
if conversation_metadata is None or conversation_metadata_saas is None:
|
||||
# Conversation is running in different server
|
||||
return
|
||||
|
||||
user_id = conversation_metadata.user_id
|
||||
user_id = conversation_metadata_saas.user_id
|
||||
|
||||
# Get the id of the next event which is not present
|
||||
events_dir = get_conversation_events_dir(
|
||||
|
||||
@@ -0,0 +1,85 @@
|
||||
from storage.api_key import ApiKey
|
||||
from storage.auth_tokens import AuthTokens
|
||||
from storage.billing_session import BillingSession
|
||||
from storage.billing_session_type import BillingSessionType
|
||||
from storage.conversation_callback import CallbackStatus, ConversationCallback
|
||||
from storage.conversation_work import ConversationWork
|
||||
from storage.experiment_assignment import ExperimentAssignment
|
||||
from storage.feedback import ConversationFeedback, Feedback
|
||||
from storage.github_app_installation import GithubAppInstallation
|
||||
from storage.gitlab_webhook import GitlabWebhook, WebhookStatus
|
||||
from storage.jira_conversation import JiraConversation
|
||||
from storage.jira_dc_conversation import JiraDcConversation
|
||||
from storage.jira_dc_user import JiraDcUser
|
||||
from storage.jira_dc_workspace import JiraDcWorkspace
|
||||
from storage.jira_user import JiraUser
|
||||
from storage.jira_workspace import JiraWorkspace
|
||||
from storage.linear_conversation import LinearConversation
|
||||
from storage.linear_user import LinearUser
|
||||
from storage.linear_workspace import LinearWorkspace
|
||||
from storage.maintenance_task import MaintenanceTask, MaintenanceTaskStatus
|
||||
from storage.openhands_pr import OpenhandsPR
|
||||
from storage.org import Org
|
||||
from storage.org_member import OrgMember
|
||||
from storage.proactive_convos import ProactiveConversation
|
||||
from storage.role import Role
|
||||
from storage.slack_conversation import SlackConversation
|
||||
from storage.slack_team import SlackTeam
|
||||
from storage.slack_user import SlackUser
|
||||
from storage.stored_conversation_metadata import StoredConversationMetadata
|
||||
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
|
||||
from storage.stored_custom_secrets import StoredCustomSecrets
|
||||
from storage.stored_offline_token import StoredOfflineToken
|
||||
from storage.stored_repository import StoredRepository
|
||||
from storage.stripe_customer import StripeCustomer
|
||||
from storage.subscription_access import SubscriptionAccess
|
||||
from storage.subscription_access_status import SubscriptionAccessStatus
|
||||
from storage.user import User
|
||||
from storage.user_repo_map import UserRepositoryMap
|
||||
from storage.user_settings import UserSettings
|
||||
|
||||
__all__ = [
|
||||
'ApiKey',
|
||||
'AuthTokens',
|
||||
'BillingSession',
|
||||
'BillingSessionType',
|
||||
'CallbackStatus',
|
||||
'ConversationCallback',
|
||||
'ConversationFeedback',
|
||||
'StoredConversationMetadataSaas',
|
||||
'ConversationWork',
|
||||
'ExperimentAssignment',
|
||||
'Feedback',
|
||||
'GithubAppInstallation',
|
||||
'GitlabWebhook',
|
||||
'JiraConversation',
|
||||
'JiraDcConversation',
|
||||
'JiraDcUser',
|
||||
'JiraDcWorkspace',
|
||||
'JiraUser',
|
||||
'JiraWorkspace',
|
||||
'LinearConversation',
|
||||
'LinearUser',
|
||||
'LinearWorkspace',
|
||||
'MaintenanceTask',
|
||||
'MaintenanceTaskStatus',
|
||||
'OpenhandsPR',
|
||||
'Org',
|
||||
'OrgMember',
|
||||
'ProactiveConversation',
|
||||
'Role',
|
||||
'SlackConversation',
|
||||
'SlackTeam',
|
||||
'SlackUser',
|
||||
'StoredConversationMetadata',
|
||||
'StoredOfflineToken',
|
||||
'StoredRepository',
|
||||
'StoredCustomSecrets',
|
||||
'StripeCustomer',
|
||||
'SubscriptionAccess',
|
||||
'SubscriptionAccessStatus',
|
||||
'User',
|
||||
'UserRepositoryMap',
|
||||
'UserSettings',
|
||||
'WebhookStatus',
|
||||
]
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from sqlalchemy import Column, DateTime, Integer, String, text
|
||||
from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, text
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
@@ -11,9 +13,13 @@ class ApiKey(Base):
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
key = Column(String(255), nullable=False, unique=True, index=True)
|
||||
user_id = Column(String(255), nullable=False, index=True)
|
||||
org_id = Column(UUID(as_uuid=True), ForeignKey('org.id'), nullable=True)
|
||||
name = Column(String(255), nullable=True)
|
||||
created_at = Column(
|
||||
DateTime, server_default=text('CURRENT_TIMESTAMP'), nullable=False
|
||||
)
|
||||
last_used_at = Column(DateTime, nullable=True)
|
||||
expires_at = Column(DateTime, nullable=True)
|
||||
|
||||
# Relationships
|
||||
org = relationship('Org', back_populates='api_keys')
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import DECIMAL, Column, DateTime, Enum, String
|
||||
from sqlalchemy import DECIMAL, Column, DateTime, Enum, ForeignKey, String
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
@@ -11,9 +13,9 @@ class BillingSession(Base): # type: ignore
|
||||
"""
|
||||
|
||||
__tablename__ = 'billing_sessions'
|
||||
|
||||
id = Column(String, primary_key=True)
|
||||
user_id = Column(String, nullable=False)
|
||||
org_id = Column(UUID(as_uuid=True), ForeignKey('org.id'), nullable=True)
|
||||
status = Column(
|
||||
Enum(
|
||||
'in_progress',
|
||||
@@ -24,15 +26,6 @@ class BillingSession(Base): # type: ignore
|
||||
),
|
||||
default='in_progress',
|
||||
)
|
||||
billing_session_type = Column(
|
||||
Enum(
|
||||
'DIRECT_PAYMENT',
|
||||
'MONTHLY_SUBSCRIPTION',
|
||||
name='billing_session_type_enum',
|
||||
),
|
||||
nullable=False,
|
||||
default='DIRECT_PAYMENT',
|
||||
)
|
||||
price = Column(DECIMAL(19, 4), nullable=False)
|
||||
price_code = Column(String, nullable=False)
|
||||
created_at = Column(
|
||||
@@ -43,3 +36,6 @@ class BillingSession(Base): # type: ignore
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(UTC), # type: ignore[attr-defined]
|
||||
)
|
||||
|
||||
# Relationships
|
||||
org = relationship('Org', back_populates='billing_sessions')
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
@@ -7,6 +8,9 @@ from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.pool import NullPool
|
||||
from sqlalchemy.util import await_only
|
||||
|
||||
# Check if we're running in a test environment
|
||||
IS_TESTING = 'pytest' in sys.modules
|
||||
|
||||
DB_HOST = os.environ.get('DB_HOST', 'localhost') # for non-GCP environments
|
||||
DB_PORT = os.environ.get('DB_PORT', '5432') # for non-GCP environments
|
||||
DB_USER = os.environ.get('DB_USER', 'postgres')
|
||||
|
||||
91
enterprise/storage/encrypt_utils.py
Normal file
91
enterprise/storage/encrypt_utils.py
Normal file
@@ -0,0 +1,91 @@
|
||||
import binascii
|
||||
import hashlib
|
||||
from base64 import b64decode, b64encode
|
||||
|
||||
from cryptography.fernet import Fernet
|
||||
from pydantic import SecretStr
|
||||
from server.config import get_config
|
||||
|
||||
_fernet = None
|
||||
|
||||
|
||||
def encrypt_model(encrypt_keys: list, model_instance) -> dict:
|
||||
return encrypt_kwargs(encrypt_keys, model_to_kwargs(model_instance))
|
||||
|
||||
|
||||
def decrypt_model(decrypt_keys: list, model_instance) -> dict:
|
||||
return decrypt_kwargs(decrypt_keys, model_to_kwargs(model_instance))
|
||||
|
||||
|
||||
def encrypt_kwargs(encrypt_keys: list, kwargs: dict) -> dict:
|
||||
fernet = get_fernet()
|
||||
for key, value in kwargs.items():
|
||||
if value is None:
|
||||
continue
|
||||
|
||||
if isinstance(value, dict):
|
||||
encrypt_kwargs(encrypt_keys, value)
|
||||
continue
|
||||
|
||||
if key in encrypt_keys:
|
||||
if isinstance(value, SecretStr):
|
||||
value = b64encode(
|
||||
fernet.encrypt(value.get_secret_value().encode())
|
||||
).decode()
|
||||
else:
|
||||
value = b64encode(fernet.encrypt(value.encode())).decode()
|
||||
kwargs[key] = value
|
||||
return kwargs
|
||||
|
||||
|
||||
def decrypt_kwargs(encrypt_keys: list, kwargs: dict) -> dict:
|
||||
fernet = get_fernet()
|
||||
for key, value in kwargs.items():
|
||||
try:
|
||||
if value is None:
|
||||
continue
|
||||
if key in encrypt_keys:
|
||||
if isinstance(value, SecretStr):
|
||||
value = fernet.decrypt(
|
||||
b64decode(value.get_secret_value().encode())
|
||||
).decode()
|
||||
else:
|
||||
value = fernet.decrypt(b64decode(value.encode())).decode()
|
||||
kwargs[key] = value
|
||||
except binascii.Error:
|
||||
pass # Key is in legacy format...
|
||||
return kwargs
|
||||
|
||||
|
||||
def encrypt_value(value: str | SecretStr) -> str:
|
||||
if isinstance(value, SecretStr):
|
||||
return b64encode(
|
||||
get_fernet().encrypt(value.get_secret_value().encode())
|
||||
).decode()
|
||||
else:
|
||||
return b64encode(get_fernet().encrypt(value.encode())).decode()
|
||||
|
||||
|
||||
def decrypt_value(value: str | SecretStr) -> str:
|
||||
if isinstance(value, SecretStr):
|
||||
return (
|
||||
get_fernet().decrypt(b64decode(value.get_secret_value().encode())).decode()
|
||||
)
|
||||
else:
|
||||
return get_fernet().decrypt(b64decode(value.encode())).decode()
|
||||
|
||||
|
||||
def get_fernet():
|
||||
global _fernet
|
||||
if _fernet is None:
|
||||
jwt_secret = get_config().jwt_secret.get_secret_value()
|
||||
fernet_key = b64encode(hashlib.sha256(jwt_secret.encode()).digest())
|
||||
_fernet = Fernet(fernet_key)
|
||||
return _fernet
|
||||
|
||||
|
||||
def model_to_kwargs(model_instance):
|
||||
return {
|
||||
column.name: getattr(model_instance, column.name)
|
||||
for column in model_instance.__table__.columns
|
||||
}
|
||||
@@ -1,7 +1,16 @@
|
||||
import sys
|
||||
from enum import IntEnum
|
||||
|
||||
from sqlalchemy import ARRAY, Boolean, Column, DateTime, Integer, String, Text, text
|
||||
from sqlalchemy import (
|
||||
ARRAY,
|
||||
Boolean,
|
||||
Column,
|
||||
DateTime,
|
||||
Integer,
|
||||
String,
|
||||
Text,
|
||||
text,
|
||||
)
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
|
||||
634
enterprise/storage/lite_llm_manager.py
Normal file
634
enterprise/storage/lite_llm_manager.py
Normal file
@@ -0,0 +1,634 @@
|
||||
"""
|
||||
Store class for managing organizational settings.
|
||||
"""
|
||||
|
||||
import functools
|
||||
import os
|
||||
from typing import Any, Awaitable, Callable
|
||||
|
||||
import httpx
|
||||
from pydantic import SecretStr
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.constants import (
|
||||
DEFAULT_INITIAL_BUDGET,
|
||||
LITE_LLM_API_KEY,
|
||||
LITE_LLM_API_URL,
|
||||
LITE_LLM_TEAM_ID,
|
||||
ORG_SETTINGS_VERSION,
|
||||
get_default_litellm_model,
|
||||
)
|
||||
from server.logger import logger
|
||||
from storage.user_settings import UserSettings
|
||||
|
||||
from openhands.server.settings import Settings
|
||||
|
||||
|
||||
class LiteLlmManager:
|
||||
"""Manage LiteLLM interactions."""
|
||||
|
||||
@staticmethod
|
||||
async def create_entries(
|
||||
org_id: str,
|
||||
keycloak_user_id: str,
|
||||
oss_settings: Settings,
|
||||
) -> Settings | None:
|
||||
logger.info(
|
||||
'SettingsStore:update_settings_with_litellm_default:start',
|
||||
extra={'org_id': org_id, 'user_id': keycloak_user_id},
|
||||
)
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return None
|
||||
local_deploy = os.environ.get('LOCAL_DEPLOYMENT', None)
|
||||
key = LITE_LLM_API_KEY
|
||||
if not local_deploy:
|
||||
# Get user info to add to litellm
|
||||
token_manager = TokenManager()
|
||||
keycloak_user_info = (
|
||||
await token_manager.get_user_info_from_user_id(keycloak_user_id) or {}
|
||||
)
|
||||
|
||||
async with httpx.AsyncClient(
|
||||
headers={
|
||||
'x-goog-api-key': LITE_LLM_API_KEY,
|
||||
}
|
||||
) as client:
|
||||
await LiteLlmManager._create_team(
|
||||
client, keycloak_user_id, org_id, DEFAULT_INITIAL_BUDGET
|
||||
)
|
||||
|
||||
await LiteLlmManager._create_user(
|
||||
client, keycloak_user_info.get('email'), keycloak_user_id
|
||||
)
|
||||
|
||||
await LiteLlmManager._add_user_to_team(
|
||||
client, keycloak_user_id, org_id, DEFAULT_INITIAL_BUDGET
|
||||
)
|
||||
|
||||
key = await LiteLlmManager._generate_key(
|
||||
client,
|
||||
keycloak_user_id,
|
||||
org_id,
|
||||
f'OpenHands Cloud - user {keycloak_user_id}',
|
||||
None,
|
||||
)
|
||||
|
||||
oss_settings.agent = 'CodeActAgent'
|
||||
# Use the model corresponding to the current user settings version
|
||||
oss_settings.llm_model = get_default_litellm_model()
|
||||
oss_settings.llm_api_key = SecretStr(key)
|
||||
oss_settings.llm_base_url = LITE_LLM_API_URL
|
||||
return oss_settings
|
||||
|
||||
@staticmethod
|
||||
async def migrate_entries(
|
||||
org_id: str,
|
||||
keycloak_user_id: str,
|
||||
user_settings: UserSettings,
|
||||
) -> UserSettings | None:
|
||||
logger.info(
|
||||
'SettingsStore:umigrate_lite_llm_entries:start',
|
||||
extra={'org_id': org_id, 'user_id': keycloak_user_id},
|
||||
)
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return None
|
||||
local_deploy = os.environ.get('LOCAL_DEPLOYMENT', None)
|
||||
key = LITE_LLM_API_KEY
|
||||
if not local_deploy:
|
||||
# Get user info to add to litellm
|
||||
token_manager = TokenManager()
|
||||
keycloak_user_info = (
|
||||
await token_manager.get_user_info_from_user_id(keycloak_user_id) or {}
|
||||
)
|
||||
|
||||
async with httpx.AsyncClient(
|
||||
headers={
|
||||
'x-goog-api-key': LITE_LLM_API_KEY,
|
||||
}
|
||||
) as client:
|
||||
user_json = await LiteLlmManager._get_user(client, keycloak_user_id)
|
||||
if not user_json:
|
||||
return None
|
||||
user_info = user_json['user_info']
|
||||
max_budget = user_info.get('max_budget', 0.0)
|
||||
if not max_budget:
|
||||
# if max_budget is None, then we've already migrated the User
|
||||
return None
|
||||
spend = user_info.get('spend', 0.0)
|
||||
credits = max(max_budget - spend, 0.0)
|
||||
|
||||
await LiteLlmManager._create_team(
|
||||
client, keycloak_user_id, org_id, credits
|
||||
)
|
||||
|
||||
await LiteLlmManager._delete_user(client, keycloak_user_id)
|
||||
|
||||
await LiteLlmManager._create_user(
|
||||
client, keycloak_user_info.get('email'), keycloak_user_id
|
||||
)
|
||||
|
||||
await LiteLlmManager._add_user_to_team(
|
||||
client, keycloak_user_id, org_id, credits
|
||||
)
|
||||
|
||||
key = await LiteLlmManager._generate_key(
|
||||
client,
|
||||
keycloak_user_id,
|
||||
org_id,
|
||||
f'OpenHands Cloud - user {keycloak_user_id}',
|
||||
None,
|
||||
)
|
||||
|
||||
user_settings.agent = 'CodeActAgent'
|
||||
# Use the model corresponding to the current user settings version
|
||||
user_settings.llm_model = get_default_litellm_model()
|
||||
user_settings.llm_api_key = SecretStr(key)
|
||||
user_settings.llm_base_url = LITE_LLM_API_URL
|
||||
return user_settings
|
||||
|
||||
@staticmethod
|
||||
async def update_team_and_users_budget(
|
||||
team_id: str,
|
||||
max_budget: float,
|
||||
):
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return
|
||||
async with httpx.AsyncClient(
|
||||
headers={
|
||||
'x-goog-api-key': LITE_LLM_API_KEY,
|
||||
}
|
||||
) as client:
|
||||
await LiteLlmManager._update_team(client, team_id, None, max_budget)
|
||||
team_info = await LiteLlmManager._get_team(client, team_id)
|
||||
if not team_info:
|
||||
return None
|
||||
for membership in team_info.get('team_memberships', []):
|
||||
user_id = membership.get('user_id')
|
||||
if not user_id:
|
||||
continue
|
||||
await LiteLlmManager._update_user_in_team(
|
||||
client, user_id, team_id, max_budget
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _create_team(
|
||||
client: httpx.AsyncClient,
|
||||
team_alias: str,
|
||||
team_id: str,
|
||||
max_budget: float,
|
||||
):
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/team/new',
|
||||
json={
|
||||
'team_id': team_id,
|
||||
'team_alias': team_alias,
|
||||
'models': [],
|
||||
'max_budget': max_budget,
|
||||
'spend': 0,
|
||||
'metadata': {
|
||||
'version': ORG_SETTINGS_VERSION,
|
||||
'model': get_default_litellm_model(),
|
||||
},
|
||||
},
|
||||
)
|
||||
# Team failed to create in litellm - this is an unforseen error state...
|
||||
if not response.is_success:
|
||||
if (
|
||||
response.status_code == 400
|
||||
and 'already exists. Please use a different team id' in response.text
|
||||
):
|
||||
# team already exists, so update, then return
|
||||
await LiteLlmManager._update_team(
|
||||
client, team_id, team_alias, max_budget
|
||||
)
|
||||
return
|
||||
logger.error(
|
||||
'error_creating_litellm_team',
|
||||
extra={
|
||||
'status_code': response.status_code,
|
||||
'text': response.text,
|
||||
'team_id': [team_id],
|
||||
'max_budget': max_budget,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
async def _get_team(client: httpx.AsyncClient, team_id: str) -> dict | None:
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return None
|
||||
"""Get a team from litellm with the id matching that given."""
|
||||
response = await client.get(
|
||||
f'{LITE_LLM_API_URL}/team/info?team_id={team_id}',
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
@staticmethod
|
||||
async def _update_team(
|
||||
client: httpx.AsyncClient,
|
||||
team_id: str,
|
||||
team_alias: str | None,
|
||||
max_budget: float | None,
|
||||
):
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return
|
||||
json_data: dict[str, Any] = {
|
||||
'team_id': team_id,
|
||||
'metadata': {
|
||||
'version': ORG_SETTINGS_VERSION,
|
||||
'model': get_default_litellm_model(),
|
||||
},
|
||||
}
|
||||
|
||||
if max_budget is not None:
|
||||
json_data['max_budget'] = max_budget
|
||||
|
||||
if team_alias is not None:
|
||||
json_data['team_alias'] = team_alias
|
||||
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/team/update',
|
||||
json=json_data,
|
||||
)
|
||||
|
||||
# Team failed to update in litellm - this is an unforseen error state...
|
||||
if not response.is_success:
|
||||
logger.error(
|
||||
'error_updating_litellm_team',
|
||||
extra={
|
||||
'status_code': response.status_code,
|
||||
'text': response.text,
|
||||
'team_id': [team_id],
|
||||
'max_budget': max_budget,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
async def _create_user(
|
||||
client: httpx.AsyncClient,
|
||||
email: str | None,
|
||||
keycloak_user_id: str,
|
||||
):
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/user/new',
|
||||
json={
|
||||
'user_email': email,
|
||||
'models': [],
|
||||
'user_id': keycloak_user_id,
|
||||
'teams': [LITE_LLM_TEAM_ID],
|
||||
'auto_create_key': False,
|
||||
'send_invite_email': False,
|
||||
},
|
||||
)
|
||||
if not response.is_success:
|
||||
logger.warning(
|
||||
'duplicate_user_email',
|
||||
extra={
|
||||
'user_id': keycloak_user_id,
|
||||
'email': email,
|
||||
},
|
||||
)
|
||||
# Litellm insists on unique email addresses - it is possible the email address was registered with a different user.
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/user/new',
|
||||
json={
|
||||
'user_email': None,
|
||||
'models': [],
|
||||
'user_id': keycloak_user_id,
|
||||
'teams': [LITE_LLM_TEAM_ID],
|
||||
'auto_create_key': False,
|
||||
'send_invite_email': False,
|
||||
},
|
||||
)
|
||||
|
||||
# User failed to create in litellm - this is an unforseen error state...
|
||||
if not response.is_success:
|
||||
if response.status_code == 400 and 'already exists' in response.text:
|
||||
# user already exists, just return
|
||||
return
|
||||
logger.error(
|
||||
'error_creating_litellm_user',
|
||||
extra={
|
||||
'status_code': response.status_code,
|
||||
'text': response.text,
|
||||
'user_id': [keycloak_user_id],
|
||||
'email': None,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
async def _get_user(client: httpx.AsyncClient, user_id: str) -> dict | None:
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return None
|
||||
"""Get a user from litellm with the id matching that given."""
|
||||
response = await client.get(
|
||||
f'{LITE_LLM_API_URL}/user/info?user_id={user_id}',
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
@staticmethod
|
||||
async def _update_user(
|
||||
client: httpx.AsyncClient,
|
||||
keycloak_user_id: str,
|
||||
):
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/user/update',
|
||||
json={
|
||||
'user_id': keycloak_user_id,
|
||||
'metadata': {
|
||||
'version': ORG_SETTINGS_VERSION,
|
||||
'model': get_default_litellm_model(),
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
if not response.is_success:
|
||||
logger.error(
|
||||
'error_updating_litellm_user',
|
||||
extra={
|
||||
'status_code': response.status_code,
|
||||
'text': response.text,
|
||||
'user_id': [keycloak_user_id],
|
||||
'email': None,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
async def _delete_user(
|
||||
client: httpx.AsyncClient,
|
||||
keycloak_user_id: str,
|
||||
):
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/user/delete', json={'user_ids': [keycloak_user_id]}
|
||||
)
|
||||
|
||||
if not response.is_success:
|
||||
logger.error(
|
||||
'error_deleting_litellm_user',
|
||||
extra={
|
||||
'status_code': response.status_code,
|
||||
'text': response.text,
|
||||
'user_id': [keycloak_user_id],
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
async def _add_user_to_team(
|
||||
client: httpx.AsyncClient,
|
||||
keycloak_user_id: str,
|
||||
team_id: str,
|
||||
max_budget: float,
|
||||
):
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/team/member_add',
|
||||
json={
|
||||
'team_id': team_id,
|
||||
'member': {'user_id': keycloak_user_id, 'role': 'user'},
|
||||
'max_budget_in_team': max_budget,
|
||||
},
|
||||
)
|
||||
# Failed to add user to team - this is an unforseen error state...
|
||||
if not response.is_success:
|
||||
logger.error(
|
||||
'error_adding_litellm_user_to_team',
|
||||
extra={
|
||||
'status_code': response.status_code,
|
||||
'text': response.text,
|
||||
'user_id': [keycloak_user_id],
|
||||
'team_id': [team_id],
|
||||
'max_budget': max_budget,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
async def _get_user_team_info(
|
||||
client: httpx.AsyncClient,
|
||||
keycloak_user_id: str,
|
||||
team_id: str,
|
||||
) -> dict | None:
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return None
|
||||
team_info = await LiteLlmManager._get_team(client, team_id)
|
||||
if not team_info:
|
||||
return None
|
||||
|
||||
# Filter team_memberships based on team_id and keycloak_user_id
|
||||
user_membership = next(
|
||||
(
|
||||
membership
|
||||
for membership in team_info.get('team_memberships', [])
|
||||
if membership.get('user_id') == keycloak_user_id
|
||||
and membership.get('team_id') == team_id
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
return user_membership
|
||||
|
||||
@staticmethod
|
||||
async def _update_user_in_team(
|
||||
client: httpx.AsyncClient,
|
||||
keycloak_user_id: str,
|
||||
team_id: str,
|
||||
max_budget: float,
|
||||
):
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/team/member_update',
|
||||
json={
|
||||
'team_id': team_id,
|
||||
'user_id': keycloak_user_id,
|
||||
'max_budget_in_team': max_budget,
|
||||
},
|
||||
)
|
||||
# Failed to update user in team - this is an unforseen error state...
|
||||
if not response.is_success:
|
||||
logger.error(
|
||||
'error_updating_litellm_user_in_team',
|
||||
extra={
|
||||
'status_code': response.status_code,
|
||||
'text': response.text,
|
||||
'user_id': [keycloak_user_id],
|
||||
'team_id': [team_id],
|
||||
'max_budget': max_budget,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
async def _generate_key(
|
||||
client: httpx.AsyncClient,
|
||||
keycloak_user_id: str,
|
||||
team_id: str | None,
|
||||
key_alias: str | None,
|
||||
metadata: dict | None,
|
||||
) -> str | None:
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return None
|
||||
json_data: dict[str, Any] = {
|
||||
'user_id': keycloak_user_id,
|
||||
'models': [],
|
||||
}
|
||||
|
||||
if team_id is not None:
|
||||
json_data['team_id'] = team_id
|
||||
|
||||
if key_alias is not None:
|
||||
json_data['key_alias'] = key_alias
|
||||
|
||||
if metadata is not None:
|
||||
json_data['metadata'] = metadata
|
||||
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/key/generate',
|
||||
json=json_data,
|
||||
)
|
||||
# Failed to generate user key for team - this is an unforseen error state...
|
||||
if not response.is_success:
|
||||
logger.error(
|
||||
'error_generate_user_team_key',
|
||||
extra={
|
||||
'status_code': response.status_code,
|
||||
'text': response.text,
|
||||
'user_id': [keycloak_user_id],
|
||||
'team_id': [team_id],
|
||||
'key_alias': [key_alias],
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
response_json = response.json()
|
||||
key = response_json['key']
|
||||
logger.info(
|
||||
'LiteLlmManager:_lite_llm_generate_user_team_key:key_created',
|
||||
extra={
|
||||
'user_id': keycloak_user_id,
|
||||
'team_id': [team_id],
|
||||
'key_alias': [key_alias],
|
||||
},
|
||||
)
|
||||
return key
|
||||
|
||||
@staticmethod
|
||||
async def _get_key_info(
|
||||
client: httpx.AsyncClient,
|
||||
org_id: int,
|
||||
keycloak_user_id: str,
|
||||
) -> dict | None:
|
||||
from storage.user_store import UserStore
|
||||
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return None
|
||||
user = UserStore.get_user_by_id(keycloak_user_id)
|
||||
if not user:
|
||||
return {}
|
||||
|
||||
org_member = None
|
||||
for om in user.org_members:
|
||||
if om.org_id == org_id:
|
||||
org_member = om
|
||||
break
|
||||
if not org_member or not org_member.llm_api_key:
|
||||
return {}
|
||||
response = await client.get(
|
||||
f'{LITE_LLM_API_URL}/key/info?key={org_member.llm_api_key}'
|
||||
)
|
||||
response.raise_for_status()
|
||||
response_json = response.json()
|
||||
key_info = response_json.get('info')
|
||||
if not key_info:
|
||||
return {}
|
||||
return {
|
||||
'key_max_budget': key_info.get('max_budget'),
|
||||
'key_spend': key_info.get('spend'),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
async def _delete_key(
|
||||
client: httpx.AsyncClient,
|
||||
key_id: str,
|
||||
):
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/key/delete',
|
||||
json={
|
||||
'keys': [key_id],
|
||||
},
|
||||
)
|
||||
# Failed to key...
|
||||
if not response.is_success:
|
||||
if response.status_code == 404:
|
||||
# key doesn't exist, just return
|
||||
return
|
||||
logger.error(
|
||||
'error_deleting_key',
|
||||
extra={
|
||||
'status_code': response.status_code,
|
||||
'text': response.text,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
logger.info(
|
||||
'LiteLlmManager:_delete_key:key_deleted',
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def with_http_client(
|
||||
internal_fn: Callable[..., Awaitable[Any]],
|
||||
) -> Callable[..., Awaitable[Any]]:
|
||||
@functools.wraps(internal_fn)
|
||||
async def wrapper(*args, **kwargs):
|
||||
async with httpx.AsyncClient(
|
||||
headers={'x-goog-api-key': LITE_LLM_API_KEY}
|
||||
) as client:
|
||||
return await internal_fn(client, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
# Public methods with injected client
|
||||
create_team = staticmethod(with_http_client(_create_team))
|
||||
get_team = staticmethod(with_http_client(_get_team))
|
||||
update_team = staticmethod(with_http_client(_update_team))
|
||||
create_user = staticmethod(with_http_client(_create_user))
|
||||
get_user = staticmethod(with_http_client(_get_user))
|
||||
update_user = staticmethod(with_http_client(_update_user))
|
||||
delete_user = staticmethod(with_http_client(_delete_user))
|
||||
add_user_to_team = staticmethod(with_http_client(_add_user_to_team))
|
||||
get_user_team_info = staticmethod(with_http_client(_get_user_team_info))
|
||||
update_user_in_team = staticmethod(with_http_client(_update_user_in_team))
|
||||
generate_key = staticmethod(with_http_client(_generate_key))
|
||||
get_key_info = staticmethod(with_http_client(_get_key_info))
|
||||
delete_key = staticmethod(with_http_client(_delete_key))
|
||||
111
enterprise/storage/org.py
Normal file
111
enterprise/storage/org.py
Normal file
@@ -0,0 +1,111 @@
|
||||
"""
|
||||
SQLAlchemy model for Organization.
|
||||
"""
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import SecretStr
|
||||
from server.constants import DEFAULT_BILLING_MARGIN
|
||||
from sqlalchemy import JSON, UUID, Boolean, Column, Float, Integer, String
|
||||
from sqlalchemy.orm import relationship
|
||||
from storage.base import Base
|
||||
from storage.encrypt_utils import decrypt_value, encrypt_value
|
||||
|
||||
|
||||
class Org(Base): # type: ignore
|
||||
"""Organization model."""
|
||||
|
||||
__tablename__ = 'org'
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid4)
|
||||
name = Column(String, nullable=False, unique=True)
|
||||
contact_name = Column(String, nullable=True)
|
||||
contact_email = Column(String, nullable=True)
|
||||
agent = Column(String, nullable=True)
|
||||
default_max_iterations = Column(Integer, nullable=True)
|
||||
security_analyzer = Column(String, nullable=True)
|
||||
confirmation_mode = Column(Boolean, nullable=True, default=False)
|
||||
default_llm_model = Column(String, nullable=True)
|
||||
_default_llm_api_key_for_byor = Column(String, nullable=True)
|
||||
default_llm_base_url = Column(String, nullable=True)
|
||||
remote_runtime_resource_factor = Column(Integer, nullable=True)
|
||||
enable_default_condenser = Column(Boolean, nullable=False, default=True)
|
||||
billing_margin = Column(Float, nullable=True, default=DEFAULT_BILLING_MARGIN)
|
||||
enable_proactive_conversation_starters = Column(
|
||||
Boolean, nullable=False, default=True
|
||||
)
|
||||
sandbox_base_container_image = Column(String, nullable=True)
|
||||
sandbox_runtime_container_image = Column(String, nullable=True)
|
||||
org_version = Column(Integer, nullable=False, default=0)
|
||||
mcp_config = Column(JSON, nullable=True)
|
||||
_search_api_key = Column(String, nullable=True)
|
||||
_sandbox_api_key = Column(String, nullable=True)
|
||||
max_budget_per_task = Column(Float, nullable=True)
|
||||
enable_solvability_analysis = Column(Boolean, nullable=True, default=False)
|
||||
conversation_expiration = Column(Integer, nullable=True)
|
||||
|
||||
# Relationships
|
||||
org_members = relationship('OrgMember', back_populates='org')
|
||||
current_users = relationship('User', back_populates='current_org')
|
||||
billing_sessions = relationship('BillingSession', back_populates='org')
|
||||
stored_conversation_metadata_saas = relationship(
|
||||
'StoredConversationMetadataSaas', back_populates='org'
|
||||
)
|
||||
user_secrets = relationship('StoredCustomSecrets', back_populates='org')
|
||||
api_keys = relationship('ApiKey', back_populates='org')
|
||||
slack_conversations = relationship('SlackConversation', back_populates='org')
|
||||
slack_users = relationship('SlackUser', back_populates='org')
|
||||
stripe_customers = relationship('StripeCustomer', back_populates='org')
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
# Handle known SQLAlchemy columns directly
|
||||
for key in list(kwargs):
|
||||
if hasattr(self.__class__, key):
|
||||
setattr(self, key, kwargs.pop(key))
|
||||
|
||||
# Handle custom property-style fields
|
||||
if 'llm_api_key_for_byor' in kwargs:
|
||||
self.default_llm_api_key_for_byor = kwargs.pop('llm_api_key_for_byor')
|
||||
if 'search_api_key' in kwargs:
|
||||
self.search_api_key = kwargs.pop('search_api_key')
|
||||
if 'sandbox_api_key' in kwargs:
|
||||
self.sandbox_api_key = kwargs.pop('sandbox_api_key')
|
||||
|
||||
if kwargs:
|
||||
raise TypeError(f'Unexpected keyword arguments: {list(kwargs.keys())}')
|
||||
|
||||
@property
|
||||
def default_llm_api_key_for_byor(self) -> SecretStr | None:
|
||||
if self._default_llm_api_key_for_byor:
|
||||
decrypted = decrypt_value(self._default_llm_api_key_for_byor)
|
||||
return SecretStr(decrypted)
|
||||
return None
|
||||
|
||||
@default_llm_api_key_for_byor.setter
|
||||
def default_llm_api_key_for_byor(self, value: str | SecretStr | None):
|
||||
raw = value.get_secret_value() if isinstance(value, SecretStr) else value
|
||||
self._default_llm_api_key_for_byor = encrypt_value(raw) if raw else None
|
||||
|
||||
@property
|
||||
def search_api_key(self) -> SecretStr | None:
|
||||
if self._search_api_key:
|
||||
decrypted = decrypt_value(self._search_api_key)
|
||||
return SecretStr(decrypted)
|
||||
return None
|
||||
|
||||
@search_api_key.setter
|
||||
def search_api_key(self, value: str | SecretStr | None):
|
||||
raw = value.get_secret_value() if isinstance(value, SecretStr) else value
|
||||
self._search_api_key = encrypt_value(raw) if raw else None
|
||||
|
||||
@property
|
||||
def sandbox_api_key(self) -> SecretStr | None:
|
||||
if self._sandbox_api_key:
|
||||
decrypted = decrypt_value(self._sandbox_api_key)
|
||||
return SecretStr(decrypted)
|
||||
return None
|
||||
|
||||
@sandbox_api_key.setter
|
||||
def sandbox_api_key(self, value: str | SecretStr | None):
|
||||
raw = value.get_secret_value() if isinstance(value, SecretStr) else value
|
||||
self._sandbox_api_key = encrypt_value(raw) if raw else None
|
||||
65
enterprise/storage/org_member.py
Normal file
65
enterprise/storage/org_member.py
Normal file
@@ -0,0 +1,65 @@
|
||||
"""
|
||||
SQLAlchemy model for Organization-Member relationship.
|
||||
"""
|
||||
|
||||
from pydantic import SecretStr
|
||||
from sqlalchemy import UUID, Column, ForeignKey, Integer, String
|
||||
from sqlalchemy.orm import relationship
|
||||
from storage.base import Base
|
||||
from storage.encrypt_utils import decrypt_value, encrypt_value
|
||||
|
||||
|
||||
class OrgMember(Base): # type: ignore
|
||||
"""Junction table for organization-member relationships with roles."""
|
||||
|
||||
__tablename__ = 'org_member'
|
||||
|
||||
org_id = Column(UUID(as_uuid=True), ForeignKey('org.id'), primary_key=True)
|
||||
user_id = Column(UUID(as_uuid=True), ForeignKey('user.id'), primary_key=True)
|
||||
role_id = Column(Integer, ForeignKey('role.id'), nullable=False)
|
||||
_llm_api_key = Column(String, nullable=False)
|
||||
max_iterations = Column(Integer, nullable=True)
|
||||
llm_model = Column(String, nullable=True)
|
||||
_llm_api_key_for_byor = Column(String, nullable=True)
|
||||
llm_base_url = Column(String, nullable=True)
|
||||
status = Column(String, nullable=True)
|
||||
|
||||
# Relationships
|
||||
org = relationship('Org', back_populates='org_members')
|
||||
user = relationship('User', back_populates='org_members')
|
||||
role = relationship('Role', back_populates='org_members')
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
# Handle known SQLAlchemy columns directly
|
||||
for key in list(kwargs):
|
||||
if hasattr(self.__class__, key):
|
||||
setattr(self, key, kwargs.pop(key))
|
||||
|
||||
# Handle custom property-style fields
|
||||
if 'llm_api_key' in kwargs:
|
||||
self.llm_api_key = kwargs.pop('llm_api_key')
|
||||
|
||||
if kwargs:
|
||||
raise TypeError(f'Unexpected keyword arguments: {list(kwargs.keys())}')
|
||||
|
||||
@property
|
||||
def llm_api_key(self) -> SecretStr:
|
||||
decrypted = decrypt_value(self._llm_api_key)
|
||||
return SecretStr(decrypted)
|
||||
|
||||
@llm_api_key.setter
|
||||
def llm_api_key(self, value: str | SecretStr):
|
||||
raw = value.get_secret_value() if isinstance(value, SecretStr) else value
|
||||
self._llm_api_key = encrypt_value(raw)
|
||||
|
||||
@property
|
||||
def llm_api_key_for_byor(self) -> SecretStr | None:
|
||||
if self._llm_api_key_for_byor:
|
||||
decrypted = decrypt_value(self._llm_api_key_for_byor)
|
||||
return SecretStr(decrypted)
|
||||
return None
|
||||
|
||||
@llm_api_key_for_byor.setter
|
||||
def llm_api_key_for_byor(self, value: str | SecretStr | None):
|
||||
raw = value.get_secret_value() if isinstance(value, SecretStr) else value
|
||||
self._llm_api_key_for_byor = encrypt_value(raw) if raw else None
|
||||
97
enterprise/storage/org_member_store.py
Normal file
97
enterprise/storage/org_member_store.py
Normal file
@@ -0,0 +1,97 @@
|
||||
"""
|
||||
Store class for managing organization-member relationships.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from storage.database import session_maker
|
||||
from storage.org_member import OrgMember
|
||||
|
||||
|
||||
class OrgMemberStore:
|
||||
"""Store for managing organization-member relationships."""
|
||||
|
||||
@staticmethod
|
||||
def add_user_to_org(
|
||||
org_id: UUID,
|
||||
user_id: UUID,
|
||||
role_id: int,
|
||||
llm_api_key: str,
|
||||
status: Optional[str] = None,
|
||||
) -> OrgMember:
|
||||
"""Add a user to an organization with a specific role."""
|
||||
with session_maker() as session:
|
||||
org_member = OrgMember(
|
||||
org_id=org_id,
|
||||
user_id=user_id,
|
||||
role_id=role_id,
|
||||
llm_api_key=llm_api_key,
|
||||
status=status,
|
||||
)
|
||||
session.add(org_member)
|
||||
session.commit()
|
||||
session.refresh(org_member)
|
||||
return org_member
|
||||
|
||||
@staticmethod
|
||||
def get_org_member(org_id: UUID, user_id: int) -> Optional[OrgMember]:
|
||||
"""Get organization-user relationship."""
|
||||
with session_maker() as session:
|
||||
return (
|
||||
session.query(OrgMember)
|
||||
.filter(OrgMember.org_id == org_id, OrgMember.user_id == user_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_user_orgs(user_id: int) -> list[OrgMember]:
|
||||
"""Get all organizations for a user."""
|
||||
with session_maker() as session:
|
||||
return session.query(OrgMember).filter(OrgMember.user_id == user_id).all()
|
||||
|
||||
@staticmethod
|
||||
def get_org_members(org_id: UUID) -> list[OrgMember]:
|
||||
"""Get all users in an organization."""
|
||||
with session_maker() as session:
|
||||
return session.query(OrgMember).filter(OrgMember.org_id == org_id).all()
|
||||
|
||||
@staticmethod
|
||||
def update_user_role_in_org(
|
||||
org_id: UUID, user_id: int, role_id: int, status: Optional[str] = None
|
||||
) -> Optional[OrgMember]:
|
||||
"""Update user's role in an organization."""
|
||||
with session_maker() as session:
|
||||
org_member = (
|
||||
session.query(OrgMember)
|
||||
.filter(OrgMember.org_id == org_id, OrgMember.user_id == user_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not org_member:
|
||||
return None
|
||||
|
||||
org_member.role_id = role_id
|
||||
if status is not None:
|
||||
org_member.status = status
|
||||
|
||||
session.commit()
|
||||
session.refresh(org_member)
|
||||
return org_member
|
||||
|
||||
@staticmethod
|
||||
def remove_user_from_org(org_id: UUID, user_id: int) -> bool:
|
||||
"""Remove a user from an organization."""
|
||||
with session_maker() as session:
|
||||
org_member = (
|
||||
session.query(OrgMember)
|
||||
.filter(OrgMember.org_id == org_id, OrgMember.user_id == user_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not org_member:
|
||||
return False
|
||||
|
||||
session.delete(org_member)
|
||||
session.commit()
|
||||
return True
|
||||
109
enterprise/storage/org_store.py
Normal file
109
enterprise/storage/org_store.py
Normal file
@@ -0,0 +1,109 @@
|
||||
"""
|
||||
Store class for managing organizations.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from server.constants import ORG_SETTINGS_VERSION, get_default_litellm_model
|
||||
from sqlalchemy.orm import joinedload
|
||||
from storage.database import session_maker
|
||||
from storage.org import Org
|
||||
from storage.user import User
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.storage.data_models.settings import Settings
|
||||
|
||||
|
||||
class OrgStore:
|
||||
"""Store for managing organizations."""
|
||||
|
||||
@staticmethod
|
||||
def create_org(
|
||||
kwargs: dict,
|
||||
) -> Org:
|
||||
"""Create a new organization."""
|
||||
with session_maker() as session:
|
||||
org = Org(**kwargs)
|
||||
org.org_version = ORG_SETTINGS_VERSION
|
||||
org.default_llm_model = get_default_litellm_model()
|
||||
session.add(org)
|
||||
session.commit()
|
||||
session.refresh(org)
|
||||
return org
|
||||
|
||||
@staticmethod
|
||||
def get_org_by_id(org_id: UUID) -> Org | None:
|
||||
"""Get organization by ID."""
|
||||
with session_maker() as session:
|
||||
return session.query(Org).filter(Org.id == org_id).first()
|
||||
|
||||
@staticmethod
|
||||
def get_current_org_from_keycloak_user_id(keycloak_user_id: str) -> Org | None:
|
||||
with session_maker() as session:
|
||||
user = (
|
||||
session.query(User)
|
||||
.options(joinedload(User.org_members))
|
||||
.filter(User.id == uuid.UUID(keycloak_user_id))
|
||||
.first()
|
||||
)
|
||||
if not user:
|
||||
logger.warning(f'User not found for ID {keycloak_user_id}')
|
||||
return None
|
||||
org_id = user.current_org_id
|
||||
org = session.query(Org).filter(Org.id == org_id).first()
|
||||
if not org:
|
||||
logger.warning(
|
||||
f'Org not found for ID {org_id} as the current org for user {keycloak_user_id}'
|
||||
)
|
||||
return None
|
||||
return org
|
||||
|
||||
@staticmethod
|
||||
def get_org_by_name(name: str) -> Org | None:
|
||||
"""Get organization by name."""
|
||||
with session_maker() as session:
|
||||
return session.query(Org).filter(Org.name == name).first()
|
||||
|
||||
@staticmethod
|
||||
def list_orgs() -> list[Org]:
|
||||
"""List all organizations."""
|
||||
with session_maker() as session:
|
||||
orgs = session.query(Org).all()
|
||||
return orgs
|
||||
|
||||
@staticmethod
|
||||
def update_org(
|
||||
org_id: UUID,
|
||||
kwargs: dict,
|
||||
) -> Optional[Org]:
|
||||
"""Update organization details."""
|
||||
with session_maker() as session:
|
||||
org = session.query(Org).filter(Org.id == org_id).first()
|
||||
if not org:
|
||||
return None
|
||||
|
||||
if 'org_id' in kwargs:
|
||||
kwargs.pop('org_id')
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(org, key):
|
||||
setattr(org, key, value)
|
||||
|
||||
session.commit()
|
||||
session.refresh(org)
|
||||
return org
|
||||
|
||||
@staticmethod
|
||||
def get_kwargs_from_settings(settings: Settings):
|
||||
kwargs = {
|
||||
c.name: getattr(settings, normalized)
|
||||
for c in Org.__table__.columns
|
||||
if (
|
||||
normalized := c.name.removeprefix('_default_')
|
||||
.removeprefix('default_')
|
||||
.lstrip('_')
|
||||
)
|
||||
and hasattr(settings, normalized)
|
||||
}
|
||||
return kwargs
|
||||
21
enterprise/storage/role.py
Normal file
21
enterprise/storage/role.py
Normal file
@@ -0,0 +1,21 @@
|
||||
"""
|
||||
SQLAlchemy model for Role.
|
||||
"""
|
||||
|
||||
from sqlalchemy import Column, Identity, Integer, String
|
||||
from sqlalchemy.orm import relationship
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
class Role(Base): # type: ignore
|
||||
"""Role model for user permissions."""
|
||||
|
||||
__tablename__ = 'role'
|
||||
|
||||
id = Column(Integer, Identity(), primary_key=True)
|
||||
name = Column(String, nullable=False, unique=True)
|
||||
rank = Column(Integer, nullable=False)
|
||||
|
||||
# Relationships
|
||||
users = relationship('User', back_populates='role')
|
||||
org_members = relationship('OrgMember', back_populates='role')
|
||||
40
enterprise/storage/role_store.py
Normal file
40
enterprise/storage/role_store.py
Normal file
@@ -0,0 +1,40 @@
|
||||
"""
|
||||
Store class for managing roles.
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from storage.database import session_maker
|
||||
from storage.role import Role
|
||||
|
||||
|
||||
class RoleStore:
|
||||
"""Store for managing roles."""
|
||||
|
||||
@staticmethod
|
||||
def create_role(name: str, rank: int) -> Role:
|
||||
"""Create a new role."""
|
||||
with session_maker() as session:
|
||||
role = Role(name=name, rank=rank)
|
||||
session.add(role)
|
||||
session.commit()
|
||||
session.refresh(role)
|
||||
return role
|
||||
|
||||
@staticmethod
|
||||
def get_role_by_id(role_id: int) -> Optional[Role]:
|
||||
"""Get role by ID."""
|
||||
with session_maker() as session:
|
||||
return session.query(Role).filter(Role.id == role_id).first()
|
||||
|
||||
@staticmethod
|
||||
def get_role_by_name(name: str) -> Optional[Role]:
|
||||
"""Get role by name."""
|
||||
with session_maker() as session:
|
||||
return session.query(Role).filter(Role.name == name).first()
|
||||
|
||||
@staticmethod
|
||||
def list_roles() -> List[Role]:
|
||||
"""List all roles."""
|
||||
with session_maker() as session:
|
||||
return session.query(Role).order_by(Role.rank).all()
|
||||
339
enterprise/storage/saas_app_conversation_info_injector.py
Normal file
339
enterprise/storage/saas_app_conversation_info_injector.py
Normal file
@@ -0,0 +1,339 @@
|
||||
"""Enterprise injector for SQLAppConversationInfoService with SAAS filtering."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import AsyncGenerator
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import Request
|
||||
from sqlalchemy import func, select
|
||||
from storage.stored_conversation_metadata import StoredConversationMetadata
|
||||
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
|
||||
|
||||
from openhands.app_server.app_conversation.app_conversation_info_service import (
|
||||
AppConversationInfoService,
|
||||
AppConversationInfoServiceInjector,
|
||||
)
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
AppConversationInfo,
|
||||
AppConversationInfoPage,
|
||||
AppConversationSortOrder,
|
||||
)
|
||||
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
|
||||
SQLAppConversationInfoService,
|
||||
)
|
||||
from openhands.app_server.services.injector import InjectorState
|
||||
|
||||
|
||||
class SaasSQLAppConversationInfoService(SQLAppConversationInfoService):
|
||||
"""Extended SQLAppConversationInfoService with user-based filtering and SAAS metadata handling."""
|
||||
|
||||
async def _secure_select(self):
|
||||
query = (
|
||||
select(StoredConversationMetadata)
|
||||
.join(
|
||||
StoredConversationMetadataSaas,
|
||||
StoredConversationMetadata.conversation_id
|
||||
== StoredConversationMetadataSaas.conversation_id,
|
||||
)
|
||||
.where(StoredConversationMetadata.conversation_version == 'V1')
|
||||
)
|
||||
|
||||
user_id_str = await self.user_context.get_user_id()
|
||||
if user_id_str:
|
||||
user_id_uuid = UUID(user_id_str)
|
||||
query = query.where(StoredConversationMetadataSaas.user_id == user_id_uuid)
|
||||
|
||||
return query
|
||||
|
||||
async def _secure_select_with_saas_metadata(self):
|
||||
"""Select query that includes SAAS metadata for retrieving user_id."""
|
||||
query = (
|
||||
select(StoredConversationMetadata, StoredConversationMetadataSaas)
|
||||
.join(
|
||||
StoredConversationMetadataSaas,
|
||||
StoredConversationMetadata.conversation_id
|
||||
== StoredConversationMetadataSaas.conversation_id,
|
||||
)
|
||||
.where(StoredConversationMetadata.conversation_version == 'V1')
|
||||
)
|
||||
|
||||
user_id_str = await self.user_context.get_user_id()
|
||||
if user_id_str:
|
||||
user_id_uuid = UUID(user_id_str)
|
||||
query = query.where(StoredConversationMetadataSaas.user_id == user_id_uuid)
|
||||
|
||||
return query
|
||||
|
||||
async def search_app_conversation_info(
|
||||
self,
|
||||
title__contains: str | None = None,
|
||||
created_at__gte: datetime | None = None,
|
||||
created_at__lt: datetime | None = None,
|
||||
updated_at__gte: datetime | None = None,
|
||||
updated_at__lt: datetime | None = None,
|
||||
sort_order: AppConversationSortOrder = AppConversationSortOrder.CREATED_AT_DESC,
|
||||
page_id: str | None = None,
|
||||
limit: int = 100,
|
||||
) -> AppConversationInfoPage:
|
||||
"""Search for conversations with user_id from SAAS metadata."""
|
||||
query = await self._secure_select_with_saas_metadata()
|
||||
|
||||
query = self._apply_filters_with_saas_metadata(
|
||||
query=query,
|
||||
title__contains=title__contains,
|
||||
created_at__gte=created_at__gte,
|
||||
created_at__lt=created_at__lt,
|
||||
updated_at__gte=updated_at__gte,
|
||||
updated_at__lt=updated_at__lt,
|
||||
)
|
||||
|
||||
# Add sort order
|
||||
if sort_order == AppConversationSortOrder.CREATED_AT:
|
||||
query = query.order_by(StoredConversationMetadata.created_at)
|
||||
elif sort_order == AppConversationSortOrder.CREATED_AT_DESC:
|
||||
query = query.order_by(StoredConversationMetadata.created_at.desc())
|
||||
elif sort_order == AppConversationSortOrder.UPDATED_AT:
|
||||
query = query.order_by(StoredConversationMetadata.last_updated_at)
|
||||
elif sort_order == AppConversationSortOrder.UPDATED_AT_DESC:
|
||||
query = query.order_by(StoredConversationMetadata.last_updated_at.desc())
|
||||
elif sort_order == AppConversationSortOrder.TITLE:
|
||||
query = query.order_by(StoredConversationMetadata.title)
|
||||
elif sort_order == AppConversationSortOrder.TITLE_DESC:
|
||||
query = query.order_by(StoredConversationMetadata.title.desc())
|
||||
|
||||
# Apply pagination
|
||||
if page_id is not None:
|
||||
try:
|
||||
offset = int(page_id)
|
||||
query = query.offset(offset)
|
||||
except ValueError:
|
||||
# If page_id is not a valid integer, start from beginning
|
||||
offset = 0
|
||||
else:
|
||||
offset = 0
|
||||
|
||||
# Apply limit and get one extra to check if there are more results
|
||||
query = query.limit(limit + 1)
|
||||
|
||||
result = await self.db_session.execute(query)
|
||||
rows = result.all()
|
||||
|
||||
# Check if there are more results
|
||||
has_more = len(rows) > limit
|
||||
if has_more:
|
||||
rows = rows[:limit]
|
||||
|
||||
items = [
|
||||
self._to_info_with_user_id(stored_metadata, saas_metadata)
|
||||
for stored_metadata, saas_metadata in rows
|
||||
]
|
||||
|
||||
# Calculate next page ID
|
||||
next_page_id = None
|
||||
if has_more:
|
||||
next_page_id = str(offset + limit)
|
||||
|
||||
return AppConversationInfoPage(items=items, next_page_id=next_page_id)
|
||||
|
||||
async def count_app_conversation_info(
|
||||
self,
|
||||
title__contains: str | None = None,
|
||||
created_at__gte: datetime | None = None,
|
||||
created_at__lt: datetime | None = None,
|
||||
updated_at__gte: datetime | None = None,
|
||||
updated_at__lt: datetime | None = None,
|
||||
) -> int:
|
||||
"""Count conversations matching the given filters with SAAS metadata."""
|
||||
query = (
|
||||
select(func.count(StoredConversationMetadata.conversation_id))
|
||||
.select_from(
|
||||
StoredConversationMetadata.join(
|
||||
StoredConversationMetadataSaas,
|
||||
StoredConversationMetadata.conversation_id
|
||||
== StoredConversationMetadataSaas.conversation_id,
|
||||
)
|
||||
)
|
||||
.where(StoredConversationMetadata.conversation_version == 'V1')
|
||||
)
|
||||
|
||||
# Apply user filtering
|
||||
user_id_str = await self.user_context.get_user_id()
|
||||
if user_id_str:
|
||||
user_id_uuid = UUID(user_id_str)
|
||||
query = query.where(StoredConversationMetadataSaas.user_id == user_id_uuid)
|
||||
|
||||
query = self._apply_filters_with_saas_metadata(
|
||||
query=query,
|
||||
title__contains=title__contains,
|
||||
created_at__gte=created_at__gte,
|
||||
created_at__lt=created_at__lt,
|
||||
updated_at__gte=updated_at__gte,
|
||||
updated_at__lt=updated_at__lt,
|
||||
)
|
||||
|
||||
result = await self.db_session.execute(query)
|
||||
count = result.scalar()
|
||||
return count or 0
|
||||
|
||||
def _apply_filters_with_saas_metadata(
|
||||
self,
|
||||
query,
|
||||
title__contains: str | None = None,
|
||||
created_at__gte: datetime | None = None,
|
||||
created_at__lt: datetime | None = None,
|
||||
updated_at__gte: datetime | None = None,
|
||||
updated_at__lt: datetime | None = None,
|
||||
):
|
||||
"""Apply filters to query that includes SAAS metadata."""
|
||||
# Apply the same filters as the base class
|
||||
conditions = []
|
||||
if title__contains is not None:
|
||||
conditions.append(
|
||||
StoredConversationMetadata.title.like(f'%{title__contains}%')
|
||||
)
|
||||
|
||||
if created_at__gte is not None:
|
||||
conditions.append(StoredConversationMetadata.created_at >= created_at__gte)
|
||||
|
||||
if created_at__lt is not None:
|
||||
conditions.append(StoredConversationMetadata.created_at < created_at__lt)
|
||||
|
||||
if updated_at__gte is not None:
|
||||
conditions.append(
|
||||
StoredConversationMetadata.last_updated_at >= updated_at__gte
|
||||
)
|
||||
|
||||
if updated_at__lt is not None:
|
||||
conditions.append(
|
||||
StoredConversationMetadata.last_updated_at < updated_at__lt
|
||||
)
|
||||
|
||||
if conditions:
|
||||
query = query.where(*conditions)
|
||||
return query
|
||||
|
||||
async def get_app_conversation_info(
|
||||
self, conversation_id: UUID
|
||||
) -> AppConversationInfo | None:
|
||||
"""Get conversation info with user_id from SAAS metadata."""
|
||||
query = await self._secure_select_with_saas_metadata()
|
||||
query = query.where(
|
||||
StoredConversationMetadata.conversation_id == str(conversation_id)
|
||||
)
|
||||
result_set = await self.db_session.execute(query)
|
||||
result = result_set.first()
|
||||
if result:
|
||||
stored_metadata, saas_metadata = result
|
||||
return self._to_info_with_user_id(stored_metadata, saas_metadata)
|
||||
return None
|
||||
|
||||
async def batch_get_app_conversation_info(
|
||||
self, conversation_ids: list[UUID]
|
||||
) -> list[AppConversationInfo | None]:
|
||||
"""Batch get conversation info with user_id from SAAS metadata."""
|
||||
conversation_id_strs = [
|
||||
str(conversation_id) for conversation_id in conversation_ids
|
||||
]
|
||||
query = await self._secure_select_with_saas_metadata()
|
||||
query = query.where(
|
||||
StoredConversationMetadata.conversation_id.in_(conversation_id_strs)
|
||||
)
|
||||
result = await self.db_session.execute(query)
|
||||
rows = result.all()
|
||||
|
||||
# Create a mapping of conversation_id to (metadata, saas_metadata)
|
||||
info_by_id = {}
|
||||
for stored_metadata, saas_metadata in rows:
|
||||
info_by_id[stored_metadata.conversation_id] = (
|
||||
stored_metadata,
|
||||
saas_metadata,
|
||||
)
|
||||
|
||||
results: list[AppConversationInfo | None] = []
|
||||
for conversation_id in conversation_id_strs:
|
||||
if conversation_id in info_by_id:
|
||||
stored_metadata, saas_metadata = info_by_id[conversation_id]
|
||||
results.append(
|
||||
self._to_info_with_user_id(stored_metadata, saas_metadata)
|
||||
)
|
||||
else:
|
||||
results.append(None)
|
||||
|
||||
return results
|
||||
|
||||
async def save_app_conversation_info(
|
||||
self, info: AppConversationInfo
|
||||
) -> AppConversationInfo:
|
||||
"""Save conversation info and create/update SAAS metadata with user_id and org_id."""
|
||||
# Save the base conversation metadata
|
||||
await super().save_app_conversation_info(info)
|
||||
|
||||
# Get current user_id for SAAS metadata
|
||||
user_id_str = await self.user_context.get_user_id()
|
||||
if user_id_str:
|
||||
# Convert string user_id to UUID
|
||||
user_id_uuid = UUID(user_id_str)
|
||||
|
||||
# Check if SAAS metadata already exists
|
||||
saas_query = select(StoredConversationMetadataSaas).where(
|
||||
StoredConversationMetadataSaas.conversation_id == str(info.id)
|
||||
)
|
||||
result = await self.db_session.execute(saas_query)
|
||||
existing_saas_metadata = result.scalar_one_or_none()
|
||||
|
||||
if existing_saas_metadata:
|
||||
# Update existing SAAS metadata
|
||||
existing_saas_metadata.user_id = user_id_uuid
|
||||
# Keep existing org_id or set to user_id if not specified
|
||||
if not existing_saas_metadata.org_id:
|
||||
existing_saas_metadata.org_id = user_id_uuid
|
||||
else:
|
||||
# Create new SAAS metadata
|
||||
# Set org_id to user_id as specified in requirements
|
||||
saas_metadata = StoredConversationMetadataSaas(
|
||||
conversation_id=str(info.id),
|
||||
user_id=user_id_uuid,
|
||||
org_id=user_id_uuid, # Set org_id to user_id as it will not be specified
|
||||
)
|
||||
self.db_session.add(saas_metadata)
|
||||
|
||||
await self.db_session.commit()
|
||||
|
||||
return info
|
||||
|
||||
def _to_info_with_user_id(
|
||||
self,
|
||||
stored: StoredConversationMetadata,
|
||||
saas_metadata: StoredConversationMetadataSaas,
|
||||
) -> AppConversationInfo:
|
||||
"""Convert stored metadata to AppConversationInfo with user_id from SAAS metadata."""
|
||||
# Use the base _to_info method to get the basic info
|
||||
info = self._to_info(stored)
|
||||
|
||||
# Override the created_by_user_id with the user_id from SAAS metadata
|
||||
info.created_by_user_id = (
|
||||
str(saas_metadata.user_id) if saas_metadata.user_id else None
|
||||
)
|
||||
|
||||
return info
|
||||
|
||||
|
||||
class SaasAppConversationInfoServiceInjector(AppConversationInfoServiceInjector):
|
||||
"""Enterprise injector for SQLAppConversationInfoService with SAAS filtering."""
|
||||
|
||||
async def inject(
|
||||
self, state: InjectorState, request: Request | None = None
|
||||
) -> AsyncGenerator[AppConversationInfoService, None]:
|
||||
from openhands.app_server.config import (
|
||||
get_db_session,
|
||||
get_user_context,
|
||||
)
|
||||
|
||||
async with (
|
||||
get_user_context(state, request) as user_context,
|
||||
get_db_session(state, request) as db_session,
|
||||
):
|
||||
service = SaasSQLAppConversationInfoService(
|
||||
db_session=db_session, user_context=user_context
|
||||
)
|
||||
yield service
|
||||
@@ -4,10 +4,13 @@ import dataclasses
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from storage.database import session_maker
|
||||
from storage.stored_conversation_metadata import StoredConversationMetadata
|
||||
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
from openhands.integrations.provider import ProviderType
|
||||
@@ -29,11 +32,28 @@ logger = logging.getLogger(__name__)
|
||||
class SaasConversationStore(ConversationStore):
|
||||
user_id: str
|
||||
session_maker: sessionmaker
|
||||
org_id: UUID | None = None # will be fetched automatically
|
||||
|
||||
def __init__(self, user_id: str, session_maker: sessionmaker):
|
||||
self.user_id = user_id
|
||||
self.session_maker = session_maker
|
||||
user = UserStore.get_user_by_id(user_id)
|
||||
if not user:
|
||||
logger.error(f'No user found by ID {user_id}')
|
||||
raise ValueError(f'No user found by ID {user_id}')
|
||||
self.org_id = user.current_org_id
|
||||
|
||||
def _select_by_id(self, session, conversation_id: str):
|
||||
# Join StoredConversationMetadata with ConversationMetadataSaas to filter by user/org
|
||||
return (
|
||||
session.query(StoredConversationMetadata)
|
||||
.filter(StoredConversationMetadata.user_id == self.user_id)
|
||||
.join(
|
||||
StoredConversationMetadataSaas,
|
||||
StoredConversationMetadata.conversation_id
|
||||
== StoredConversationMetadataSaas.conversation_id,
|
||||
)
|
||||
.filter(StoredConversationMetadataSaas.user_id == UUID(self.user_id))
|
||||
.filter(StoredConversationMetadataSaas.org_id == self.org_id)
|
||||
.filter(StoredConversationMetadata.conversation_id == conversation_id)
|
||||
)
|
||||
|
||||
@@ -41,7 +61,6 @@ class SaasConversationStore(ConversationStore):
|
||||
kwargs = {
|
||||
c.name: getattr(conversation_metadata, c.name)
|
||||
for c in StoredConversationMetadata.__table__.columns
|
||||
if c.name != 'github_user_id' # Skip github_user_id field
|
||||
}
|
||||
# TODO: I'm not sure why the timezone is not set on the dates coming back out of the db
|
||||
kwargs['created_at'] = kwargs['created_at'].replace(tzinfo=UTC)
|
||||
@@ -52,6 +71,8 @@ class SaasConversationStore(ConversationStore):
|
||||
# Convert string to ProviderType enum
|
||||
kwargs['git_provider'] = ProviderType(kwargs['git_provider'])
|
||||
|
||||
kwargs['user_id'] = self.user_id
|
||||
|
||||
# Remove V1 attributes
|
||||
kwargs.pop('max_budget_per_task', None)
|
||||
kwargs.pop('cache_read_tokens', None)
|
||||
@@ -64,7 +85,10 @@ class SaasConversationStore(ConversationStore):
|
||||
|
||||
async def save_metadata(self, metadata: ConversationMetadata):
|
||||
kwargs = dataclasses.asdict(metadata)
|
||||
kwargs['user_id'] = self.user_id
|
||||
|
||||
# Remove user_id and org_id from kwargs since they're no longer in StoredConversationMetadata
|
||||
kwargs.pop('user_id', None)
|
||||
kwargs.pop('org_id', None)
|
||||
|
||||
# Convert ProviderType enum to string for storage
|
||||
if kwargs.get('git_provider') is not None:
|
||||
@@ -78,7 +102,31 @@ class SaasConversationStore(ConversationStore):
|
||||
|
||||
def _save_metadata():
|
||||
with self.session_maker() as session:
|
||||
# Save the main conversation metadata
|
||||
session.merge(stored_metadata)
|
||||
|
||||
# Create or update the SaaS metadata record
|
||||
saas_metadata = (
|
||||
session.query(StoredConversationMetadataSaas)
|
||||
.filter(
|
||||
StoredConversationMetadataSaas.conversation_id
|
||||
== stored_metadata.conversation_id
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not saas_metadata:
|
||||
saas_metadata = StoredConversationMetadataSaas(
|
||||
conversation_id=stored_metadata.conversation_id,
|
||||
user_id=UUID(self.user_id),
|
||||
org_id=self.org_id,
|
||||
)
|
||||
session.add(saas_metadata)
|
||||
else:
|
||||
# Update existing record
|
||||
saas_metadata.user_id = UUID(self.user_id)
|
||||
saas_metadata.org_id = self.org_id
|
||||
|
||||
session.commit()
|
||||
|
||||
await call_sync_from_async(_save_metadata)
|
||||
@@ -98,7 +146,18 @@ class SaasConversationStore(ConversationStore):
|
||||
async def delete_metadata(self, conversation_id: str) -> None:
|
||||
def _delete_metadata():
|
||||
with self.session_maker() as session:
|
||||
self._select_by_id(session, conversation_id).delete()
|
||||
# Delete the main conversation metadata
|
||||
session.query(StoredConversationMetadata).filter(
|
||||
StoredConversationMetadata.conversation_id == conversation_id,
|
||||
).delete()
|
||||
|
||||
# Delete the SaaS metadata record
|
||||
session.query(StoredConversationMetadataSaas).filter(
|
||||
StoredConversationMetadataSaas.conversation_id == conversation_id,
|
||||
StoredConversationMetadataSaas.user_id == UUID(self.user_id),
|
||||
StoredConversationMetadataSaas.org_id == self.org_id,
|
||||
).delete()
|
||||
|
||||
session.commit()
|
||||
|
||||
await call_sync_from_async(_delete_metadata)
|
||||
@@ -122,7 +181,15 @@ class SaasConversationStore(ConversationStore):
|
||||
with self.session_maker() as session:
|
||||
conversations = (
|
||||
session.query(StoredConversationMetadata)
|
||||
.filter(StoredConversationMetadata.user_id == self.user_id)
|
||||
.join(
|
||||
StoredConversationMetadataSaas,
|
||||
StoredConversationMetadata.conversation_id
|
||||
== StoredConversationMetadataSaas.conversation_id,
|
||||
)
|
||||
.filter(
|
||||
StoredConversationMetadataSaas.user_id == UUID(self.user_id)
|
||||
)
|
||||
.filter(StoredConversationMetadataSaas.org_id == self.org_id)
|
||||
.order_by(StoredConversationMetadata.created_at.desc())
|
||||
.offset(offset)
|
||||
.limit(limit + 1)
|
||||
|
||||
@@ -7,11 +7,11 @@ from dataclasses import dataclass
|
||||
from cryptography.fernet import Fernet
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from storage.database import session_maker
|
||||
from storage.stored_user_secrets import StoredUserSecrets
|
||||
from storage.stored_custom_secrets import StoredCustomSecrets
|
||||
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.storage.data_models.user_secrets import UserSecrets
|
||||
from openhands.storage.data_models.secrets import Secrets
|
||||
from openhands.storage.secrets.secrets_store import SecretsStore
|
||||
|
||||
|
||||
@@ -21,20 +21,20 @@ class SaasSecretsStore(SecretsStore):
|
||||
session_maker: sessionmaker
|
||||
config: OpenHandsConfig
|
||||
|
||||
async def load(self) -> UserSecrets | None:
|
||||
async def load(self) -> Secrets | None:
|
||||
if not self.user_id:
|
||||
return None
|
||||
|
||||
with self.session_maker() as session:
|
||||
# Fetch all secrets for the given user ID
|
||||
settings = (
|
||||
session.query(StoredUserSecrets)
|
||||
.filter(StoredUserSecrets.keycloak_user_id == self.user_id)
|
||||
session.query(StoredCustomSecrets)
|
||||
.filter(StoredCustomSecrets.keycloak_user_id == self.user_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
if not settings:
|
||||
return UserSecrets()
|
||||
return Secrets()
|
||||
|
||||
kwargs = {}
|
||||
for secret in settings:
|
||||
@@ -45,14 +45,14 @@ class SaasSecretsStore(SecretsStore):
|
||||
|
||||
self._decrypt_kwargs(kwargs)
|
||||
|
||||
return UserSecrets(custom_secrets=kwargs) # type: ignore[arg-type]
|
||||
return Secrets(custom_secrets=kwargs) # type: ignore[arg-type]
|
||||
|
||||
async def store(self, item: UserSecrets):
|
||||
async def store(self, item: Secrets):
|
||||
with self.session_maker() as session:
|
||||
# Incoming secrets are always the most updated ones
|
||||
# Delete all existing records and override with incoming ones
|
||||
session.query(StoredUserSecrets).filter(
|
||||
StoredUserSecrets.keycloak_user_id == self.user_id
|
||||
session.query(StoredCustomSecrets).filter(
|
||||
StoredCustomSecrets.keycloak_user_id == self.user_id
|
||||
).delete()
|
||||
|
||||
# Prepare the new secrets data
|
||||
@@ -74,7 +74,7 @@ class SaasSecretsStore(SecretsStore):
|
||||
|
||||
# Add the new secrets
|
||||
for secret_name, secret_value, description in secret_tuples:
|
||||
new_secret = StoredUserSecrets(
|
||||
new_secret = StoredCustomSecrets(
|
||||
keycloak_user_id=self.user_id,
|
||||
secret_name=secret_name,
|
||||
secret_value=secret_value,
|
||||
|
||||
@@ -2,317 +2,179 @@ from __future__ import annotations
|
||||
|
||||
import binascii
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
from base64 import b64decode, b64encode
|
||||
from dataclasses import dataclass
|
||||
|
||||
import httpx
|
||||
from cryptography.fernet import Fernet
|
||||
from integrations import stripe_service
|
||||
from pydantic import SecretStr
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.constants import (
|
||||
CURRENT_USER_SETTINGS_VERSION,
|
||||
DEFAULT_INITIAL_BUDGET,
|
||||
LITE_LLM_API_KEY,
|
||||
LITE_LLM_API_URL,
|
||||
LITE_LLM_TEAM_ID,
|
||||
REQUIRE_PAYMENT,
|
||||
get_default_litellm_model,
|
||||
)
|
||||
from server.logger import logger
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.orm import joinedload, sessionmaker
|
||||
from storage.database import session_maker
|
||||
from storage.stored_settings import StoredSettings
|
||||
from storage.org import Org
|
||||
from storage.org_member import OrgMember
|
||||
from storage.org_store import OrgStore
|
||||
from storage.user import User
|
||||
from storage.user_settings import UserSettings
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
from openhands.server.settings import Settings
|
||||
from openhands.storage import get_file_store
|
||||
from openhands.storage.settings.settings_store import SettingsStore
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
from openhands.storage.settings.settings_store import SettingsStore as OssSettingsStore
|
||||
|
||||
|
||||
@dataclass
|
||||
class SaasSettingsStore(SettingsStore):
|
||||
class SaasSettingsStore(OssSettingsStore):
|
||||
user_id: str
|
||||
session_maker: sessionmaker
|
||||
config: OpenHandsConfig
|
||||
ENCRYPT_VALUES = ['llm_api_key', 'llm_api_key_for_byor', 'search_api_key']
|
||||
|
||||
async def load(self) -> Settings | None:
|
||||
if not self.user_id:
|
||||
def get_user_settings_by_keycloak_id(
|
||||
self, keycloak_user_id: str, session=None
|
||||
) -> UserSettings | None:
|
||||
"""
|
||||
Get UserSettings by keycloak_user_id.
|
||||
|
||||
Args:
|
||||
keycloak_user_id: The keycloak user ID to search for
|
||||
session: Optional existing database session. If not provided, creates a new one.
|
||||
|
||||
Returns:
|
||||
UserSettings object if found, None otherwise
|
||||
"""
|
||||
if not keycloak_user_id:
|
||||
return None
|
||||
with self.session_maker() as session:
|
||||
settings = (
|
||||
session.query(UserSettings)
|
||||
.filter(UserSettings.keycloak_user_id == self.user_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not settings or settings.user_version != CURRENT_USER_SETTINGS_VERSION:
|
||||
logger.info(
|
||||
'saas_settings_store:load:triggering_migration',
|
||||
extra={'user_id': self.user_id},
|
||||
def _get_settings():
|
||||
if session:
|
||||
# Use provided session
|
||||
return (
|
||||
session.query(UserSettings)
|
||||
.filter(UserSettings.keycloak_user_id == keycloak_user_id)
|
||||
.first()
|
||||
)
|
||||
return await self.create_default_settings(settings)
|
||||
kwargs = {
|
||||
c.name: getattr(settings, c.name)
|
||||
for c in UserSettings.__table__.columns
|
||||
if c.name in Settings.model_fields
|
||||
}
|
||||
self._decrypt_kwargs(kwargs)
|
||||
settings = Settings(**kwargs)
|
||||
return settings
|
||||
|
||||
async def store(self, item: Settings):
|
||||
with self.session_maker() as session:
|
||||
existing = None
|
||||
kwargs = {}
|
||||
if item:
|
||||
kwargs = item.model_dump(context={'expose_secrets': True})
|
||||
self._encrypt_kwargs(kwargs)
|
||||
query = session.query(UserSettings).filter(
|
||||
UserSettings.keycloak_user_id == self.user_id
|
||||
)
|
||||
|
||||
# First check if we have an existing entry in the new table
|
||||
existing = query.first()
|
||||
|
||||
kwargs = {
|
||||
key: value
|
||||
for key, value in kwargs.items()
|
||||
if key in UserSettings.__table__.columns
|
||||
}
|
||||
if existing:
|
||||
# Update existing entry
|
||||
for key, value in kwargs.items():
|
||||
setattr(existing, key, value)
|
||||
existing.user_version = CURRENT_USER_SETTINGS_VERSION
|
||||
session.merge(existing)
|
||||
else:
|
||||
kwargs['keycloak_user_id'] = self.user_id
|
||||
kwargs['user_version'] = CURRENT_USER_SETTINGS_VERSION
|
||||
kwargs.pop('secrets_store', None) # Don't save secrets_store to db
|
||||
settings = UserSettings(**kwargs)
|
||||
session.add(settings)
|
||||
session.commit()
|
||||
|
||||
async def create_default_settings(self, user_settings: UserSettings | None):
|
||||
logger.info(
|
||||
'saas_settings_store:create_default_settings:start',
|
||||
extra={'user_id': self.user_id},
|
||||
)
|
||||
# You must log in before you get default settings
|
||||
if not self.user_id:
|
||||
return None
|
||||
|
||||
# Only users that have specified a payment method get default settings
|
||||
if REQUIRE_PAYMENT and not await stripe_service.has_payment_method(
|
||||
self.user_id
|
||||
):
|
||||
logger.info(
|
||||
'saas_settings_store:create_default_settings:no_payment',
|
||||
extra={'user_id': self.user_id},
|
||||
)
|
||||
return None
|
||||
settings: Settings | None = None
|
||||
if user_settings is None:
|
||||
settings = Settings(
|
||||
language='en',
|
||||
enable_proactive_conversation_starters=True,
|
||||
)
|
||||
elif isinstance(user_settings, UserSettings):
|
||||
# Convert UserSettings (SQLAlchemy model) to Settings (Pydantic model)
|
||||
kwargs = {
|
||||
c.name: getattr(user_settings, c.name)
|
||||
for c in UserSettings.__table__.columns
|
||||
if c.name in Settings.model_fields
|
||||
}
|
||||
self._decrypt_kwargs(kwargs)
|
||||
settings = Settings(**kwargs)
|
||||
|
||||
if settings:
|
||||
settings = await self.update_settings_with_litellm_default(settings)
|
||||
if settings is None:
|
||||
logger.info(
|
||||
'saas_settings_store:create_default_settings:litellm_update_failed',
|
||||
extra={'user_id': self.user_id},
|
||||
)
|
||||
return None
|
||||
|
||||
await self.store(settings)
|
||||
return settings
|
||||
|
||||
def load_legacy_db_settings(self, github_user_id: str) -> Settings | None:
|
||||
if not github_user_id:
|
||||
return None
|
||||
|
||||
with self.session_maker() as session:
|
||||
settings = (
|
||||
session.query(StoredSettings)
|
||||
.filter(StoredSettings.id == github_user_id)
|
||||
.first()
|
||||
)
|
||||
if settings is None:
|
||||
return None
|
||||
|
||||
logger.info(
|
||||
'saas_settings_store:load_legacy_db_settings:found',
|
||||
extra={'github_user_id': github_user_id},
|
||||
)
|
||||
kwargs = {
|
||||
c.name: getattr(settings, c.name)
|
||||
for c in StoredSettings.__table__.columns
|
||||
if c.name in Settings.model_fields
|
||||
}
|
||||
self._decrypt_kwargs(kwargs)
|
||||
del kwargs['secrets_store']
|
||||
settings = Settings(**kwargs)
|
||||
return settings
|
||||
|
||||
async def load_legacy_file_store_settings(self, github_user_id: str):
|
||||
if not github_user_id:
|
||||
return None
|
||||
|
||||
file_store = get_file_store(self.config.file_store, self.config.file_store_path)
|
||||
path = f'users/github/{github_user_id}/settings.json'
|
||||
|
||||
try:
|
||||
json_str = await call_sync_from_async(file_store.read, path)
|
||||
logger.info(
|
||||
'saas_settings_store:load_legacy_file_store_settings:found',
|
||||
extra={'github_user_id': github_user_id},
|
||||
)
|
||||
kwargs = json.loads(json_str)
|
||||
self._decrypt_kwargs(kwargs)
|
||||
settings = Settings(**kwargs)
|
||||
return settings
|
||||
except FileNotFoundError:
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'saas_settings_store:load_legacy_file_store_settings:error',
|
||||
extra={'github_user_id': github_user_id, 'error': str(e)},
|
||||
)
|
||||
return None
|
||||
|
||||
async def update_settings_with_litellm_default(
|
||||
self, settings: Settings
|
||||
) -> Settings | None:
|
||||
logger.info(
|
||||
'saas_settings_store:update_settings_with_litellm_default:start',
|
||||
extra={'user_id': self.user_id},
|
||||
)
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
return None
|
||||
local_deploy = os.environ.get('LOCAL_DEPLOYMENT', None)
|
||||
key = LITE_LLM_API_KEY
|
||||
if not local_deploy:
|
||||
# Get user info to add to litellm
|
||||
token_manager = TokenManager()
|
||||
keycloak_user_info = (
|
||||
await token_manager.get_user_info_from_user_id(self.user_id) or {}
|
||||
)
|
||||
|
||||
async with httpx.AsyncClient(
|
||||
headers={
|
||||
'x-goog-api-key': LITE_LLM_API_KEY,
|
||||
}
|
||||
) as client:
|
||||
# Get the previous max budget to prevent accidental loss
|
||||
# In Litellm a get always succeeds, regardless of whether the user actually exists
|
||||
response = await client.get(
|
||||
f'{LITE_LLM_API_URL}/user/info?user_id={self.user_id}'
|
||||
)
|
||||
response.raise_for_status()
|
||||
response_json = response.json()
|
||||
user_info = response_json.get('user_info') or {}
|
||||
logger.info(
|
||||
f'creating_litellm_user: {self.user_id}; prev_max_budget: {user_info.get("max_budget")}; prev_metadata: {user_info.get("metadata")}'
|
||||
)
|
||||
max_budget = user_info.get('max_budget') or DEFAULT_INITIAL_BUDGET
|
||||
spend = user_info.get('spend') or 0
|
||||
|
||||
with session_maker() as session:
|
||||
user_settings = (
|
||||
session.query(UserSettings)
|
||||
.filter(UserSettings.keycloak_user_id == self.user_id)
|
||||
# Create new session
|
||||
with self.session_maker() as new_session:
|
||||
return (
|
||||
new_session.query(UserSettings)
|
||||
.filter(UserSettings.keycloak_user_id == keycloak_user_id)
|
||||
.first()
|
||||
)
|
||||
# In upgrade to V4, we no longer use billing margin, but instead apply this directly
|
||||
# in litellm. The default billing marign was 2 before this (hence the magic numbers below)
|
||||
if (
|
||||
user_settings
|
||||
and user_settings.user_version < 4
|
||||
and user_settings.billing_margin
|
||||
and user_settings.billing_margin != 1.0
|
||||
):
|
||||
billing_margin = user_settings.billing_margin
|
||||
logger.info(
|
||||
'user_settings_v4_budget_upgrade',
|
||||
extra={
|
||||
'max_budget': max_budget,
|
||||
'billing_margin': billing_margin,
|
||||
'spend': spend,
|
||||
},
|
||||
)
|
||||
max_budget *= billing_margin
|
||||
spend *= billing_margin
|
||||
user_settings.billing_margin = 1.0
|
||||
session.commit()
|
||||
|
||||
email = keycloak_user_info.get('email')
|
||||
return _get_settings()
|
||||
|
||||
# We explicitly delete here to guard against odd inherited settings on upgrade.
|
||||
# We don't care if this fails with a 404
|
||||
await client.post(
|
||||
f'{LITE_LLM_API_URL}/user/delete', json={'user_ids': [self.user_id]}
|
||||
async def load(self) -> Settings | None:
|
||||
user = UserStore.get_user_by_id(self.user_id)
|
||||
if not user:
|
||||
# Check if we need to migrate from user_settings
|
||||
user_settings = None
|
||||
with session_maker() as session:
|
||||
user_settings = (
|
||||
session.query(UserSettings)
|
||||
.filter(
|
||||
UserSettings.keycloak_user_id == self.user_id,
|
||||
UserSettings.migration_status.is_(False),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if user_settings:
|
||||
user = await UserStore.migrate_user(self.user_id, user_settings)
|
||||
else:
|
||||
logger.error(f'User not found for ID {self.user_id}')
|
||||
return None
|
||||
|
||||
# Create the new litellm user
|
||||
response = await self._create_user_in_lite_llm(
|
||||
client, email, max_budget, spend
|
||||
org_id = user.current_org_id
|
||||
org_member: OrgMember = None
|
||||
for om in user.org_members:
|
||||
if om.org_id == org_id:
|
||||
org_member = om
|
||||
break
|
||||
if not org_member or not org_member.llm_api_key:
|
||||
return None
|
||||
org = OrgStore.get_org_by_id(org_id)
|
||||
if not org:
|
||||
logger.error(
|
||||
f'Org not found for ID {org_id} as the current org for user {self.user_id}'
|
||||
)
|
||||
return None
|
||||
kwargs = {
|
||||
**{
|
||||
normalized: getattr(org, c.name)
|
||||
for c in Org.__table__.columns
|
||||
if (
|
||||
normalized := c.name.removeprefix('_default_')
|
||||
.removeprefix('default_')
|
||||
.lstrip('_')
|
||||
)
|
||||
if not response.is_success:
|
||||
logger.warning(
|
||||
'duplicate_user_email',
|
||||
extra={'user_id': self.user_id, 'email': email},
|
||||
)
|
||||
# Litellm insists on unique email addresses - it is possible the email address was registered with a different user.
|
||||
response = await self._create_user_in_lite_llm(
|
||||
client, None, max_budget, spend
|
||||
)
|
||||
in Settings.model_fields
|
||||
},
|
||||
**{
|
||||
normalized: getattr(user, c.name)
|
||||
for c in User.__table__.columns
|
||||
if (normalized := c.name.lstrip('_')) in Settings.model_fields
|
||||
},
|
||||
}
|
||||
kwargs['llm_api_key'] = org_member.llm_api_key
|
||||
if org_member.max_iterations:
|
||||
kwargs['max_iterations'] = org_member.max_iterations
|
||||
if org_member.llm_model:
|
||||
kwargs['llm_model'] = org_member.llm_model
|
||||
if org_member.llm_api_key_for_byor:
|
||||
kwargs['llm_api_key_for_byor'] = org_member.llm_api_key_for_byor
|
||||
if org_member.llm_base_url:
|
||||
kwargs['llm_base_url'] = org_member.llm_base_url
|
||||
|
||||
# User failed to create in litellm - this is an unforseen error state...
|
||||
if not response.is_success:
|
||||
logger.error(
|
||||
'error_creating_litellm_user',
|
||||
extra={
|
||||
'status_code': response.status_code,
|
||||
'text': response.text,
|
||||
'user_id': [self.user_id],
|
||||
'email': email,
|
||||
'max_budget': max_budget,
|
||||
'spend': spend,
|
||||
},
|
||||
settings = Settings(**kwargs)
|
||||
return settings
|
||||
|
||||
async def store(self, item: Settings):
|
||||
# Call the static store method from SettingsStore
|
||||
with self.session_maker() as session:
|
||||
if not item:
|
||||
return None
|
||||
kwargs = item.model_dump(context={'expose_secrets': True})
|
||||
user = (
|
||||
session.query(User)
|
||||
.options(joinedload(User.org_members))
|
||||
.filter(User.id == uuid.UUID(self.user_id))
|
||||
).first()
|
||||
|
||||
if not user:
|
||||
# Check if we need to migrate from user_settings
|
||||
user_settings = None
|
||||
with session_maker() as session:
|
||||
user_settings = self.get_user_settings_by_keycloak_id(
|
||||
self.user_id, session
|
||||
)
|
||||
if user_settings:
|
||||
user = await UserStore.migrate_user(self.user_id, user_settings)
|
||||
else:
|
||||
logger.error(f'User not found for ID {self.user_id}')
|
||||
return None
|
||||
|
||||
response_json = response.json()
|
||||
key = response_json['key']
|
||||
|
||||
logger.info(
|
||||
'saas_settings_store:update_settings_with_litellm_default:user_created',
|
||||
extra={'user_id': self.user_id},
|
||||
org_id = user.current_org_id
|
||||
org_member = None
|
||||
for om in user.org_members:
|
||||
if om.org_id == org_id:
|
||||
org_member = om
|
||||
break
|
||||
if not org_member or not org_member.llm_api_key:
|
||||
return None
|
||||
org = session.query(Org).filter(Org.id == org_id).first()
|
||||
if not org:
|
||||
logger.error(
|
||||
f'Org not found for ID {org_id} as the current org for user {self.user_id}'
|
||||
)
|
||||
return None
|
||||
|
||||
settings.agent = 'CodeActAgent'
|
||||
# Use the model corresponding to the current user settings version
|
||||
settings.llm_model = get_default_litellm_model()
|
||||
settings.llm_api_key = SecretStr(key)
|
||||
settings.llm_base_url = LITE_LLM_API_URL
|
||||
return settings
|
||||
for model in (user, org, org_member):
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(model, key):
|
||||
setattr(model, key, value)
|
||||
|
||||
session.commit()
|
||||
|
||||
@classmethod
|
||||
async def get_instance(
|
||||
@@ -323,6 +185,9 @@ class SaasSettingsStore(SettingsStore):
|
||||
logger.debug(f'saas_settings_store.get_instance::{user_id}')
|
||||
return SaasSettingsStore(user_id, session_maker, config)
|
||||
|
||||
def _should_encrypt(self, key):
|
||||
return key in self.ENCRYPT_VALUES
|
||||
|
||||
def _decrypt_kwargs(self, kwargs: dict):
|
||||
fernet = self._fernet()
|
||||
for key, value in kwargs.items():
|
||||
@@ -365,29 +230,3 @@ class SaasSettingsStore(SettingsStore):
|
||||
jwt_secret = self.config.jwt_secret.get_secret_value()
|
||||
fernet_key = b64encode(hashlib.sha256(jwt_secret.encode()).digest())
|
||||
return Fernet(fernet_key)
|
||||
|
||||
def _should_encrypt(self, key: str) -> bool:
|
||||
return key in ('llm_api_key', 'llm_api_key_for_byor', 'search_api_key')
|
||||
|
||||
async def _create_user_in_lite_llm(
|
||||
self, client: httpx.AsyncClient, email: str | None, max_budget: int, spend: int
|
||||
):
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/user/new',
|
||||
json={
|
||||
'user_email': email,
|
||||
'models': [],
|
||||
'max_budget': max_budget,
|
||||
'spend': spend,
|
||||
'user_id': str(self.user_id),
|
||||
'teams': [LITE_LLM_TEAM_ID],
|
||||
'auto_create_key': True,
|
||||
'send_invite_email': False,
|
||||
'metadata': {
|
||||
'version': CURRENT_USER_SETTINGS_VERSION,
|
||||
'model': get_default_litellm_model(),
|
||||
},
|
||||
'key_alias': f'OpenHands Cloud - user {self.user_id}',
|
||||
},
|
||||
)
|
||||
return response
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from sqlalchemy import Column, Identity, Integer, String
|
||||
from sqlalchemy import Column, ForeignKey, Identity, Integer, String
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
@@ -8,4 +10,8 @@ class SlackConversation(Base): # type: ignore
|
||||
conversation_id = Column(String, nullable=False, index=True)
|
||||
channel_id = Column(String, nullable=False)
|
||||
keycloak_user_id = Column(String, nullable=False)
|
||||
org_id = Column(UUID(as_uuid=True), ForeignKey('org.id'), nullable=True)
|
||||
parent_id = Column(String, nullable=True, index=True)
|
||||
|
||||
# Relationships
|
||||
org = relationship('Org', back_populates='slack_conversations')
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from sqlalchemy import Column, DateTime, Identity, Integer, String, text
|
||||
from sqlalchemy import Column, DateTime, ForeignKey, Identity, Integer, String, text
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
@@ -6,6 +8,7 @@ class SlackUser(Base): # type: ignore
|
||||
__tablename__ = 'slack_users'
|
||||
id = Column(Integer, Identity(), primary_key=True)
|
||||
keycloak_user_id = Column(String, nullable=False, index=True)
|
||||
org_id = Column(UUID(as_uuid=True), ForeignKey('org.id'), nullable=True)
|
||||
slack_user_id = Column(String, nullable=False, index=True)
|
||||
slack_display_name = Column(String, nullable=False)
|
||||
created_at = Column(
|
||||
@@ -13,3 +16,6 @@ class SlackUser(Base): # type: ignore
|
||||
server_default=text('CURRENT_TIMESTAMP'),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Relationships
|
||||
org = relationship('Org', back_populates='slack_users')
|
||||
|
||||
@@ -4,5 +4,4 @@ from openhands.app_server.app_conversation.sql_app_conversation_info_service imp
|
||||
|
||||
StoredConversationMetadata = _StoredConversationMetadata
|
||||
|
||||
|
||||
__all__ = ['StoredConversationMetadata']
|
||||
|
||||
28
enterprise/storage/stored_conversation_metadata_saas.py
Normal file
28
enterprise/storage/stored_conversation_metadata_saas.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""
|
||||
SQLAlchemy model for ConversationMetadataSaas.
|
||||
|
||||
This model stores the SaaS-specific metadata for conversations,
|
||||
containing only the conversation_id, user_id, and org_id.
|
||||
"""
|
||||
|
||||
from sqlalchemy import UUID as SQL_UUID
|
||||
from sqlalchemy import Column, ForeignKey, String
|
||||
from sqlalchemy.orm import relationship
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
class StoredConversationMetadataSaas(Base): # type: ignore
|
||||
"""SaaS conversation metadata model containing user and org associations."""
|
||||
|
||||
__tablename__ = 'conversation_metadata_saas'
|
||||
|
||||
conversation_id = Column(String, primary_key=True)
|
||||
user_id = Column(SQL_UUID(as_uuid=True), ForeignKey('user.id'), nullable=False)
|
||||
org_id = Column(SQL_UUID(as_uuid=True), ForeignKey('org.id'), nullable=False)
|
||||
|
||||
# Relationships
|
||||
user = relationship('User', back_populates='stored_conversation_metadata_saas')
|
||||
org = relationship('Org', back_populates='stored_conversation_metadata_saas')
|
||||
|
||||
|
||||
__all__ = ['StoredConversationMetadataSaas']
|
||||
17
enterprise/storage/stored_custom_secrets.py
Normal file
17
enterprise/storage/stored_custom_secrets.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from sqlalchemy import Column, ForeignKey, Identity, Integer, String
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
class StoredCustomSecrets(Base): # type: ignore
|
||||
__tablename__ = 'custom_secrets'
|
||||
id = Column(Integer, Identity(), primary_key=True)
|
||||
keycloak_user_id = Column(String, nullable=True, index=True)
|
||||
org_id = Column(UUID(as_uuid=True), ForeignKey('org.id'), nullable=True)
|
||||
secret_name = Column(String, nullable=False)
|
||||
secret_value = Column(String, nullable=False)
|
||||
description = Column(String, nullable=True)
|
||||
|
||||
# Relationships
|
||||
org = relationship('Org', back_populates='user_secrets')
|
||||
@@ -1,29 +0,0 @@
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import JSON, Boolean, Column, Float, Integer, String
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
class StoredSettings(Base): # type: ignore
|
||||
"""
|
||||
Legacy user settings storage. This should be considered deprecated - use UserSettings isntead
|
||||
"""
|
||||
|
||||
__tablename__ = 'settings'
|
||||
id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
language = Column(String, nullable=True)
|
||||
agent = Column(String, nullable=True)
|
||||
max_iterations = Column(Integer, nullable=True)
|
||||
security_analyzer = Column(String, nullable=True)
|
||||
confirmation_mode = Column(Boolean, nullable=True, default=False)
|
||||
llm_model = Column(String, nullable=True)
|
||||
llm_api_key = Column(String, nullable=True)
|
||||
llm_base_url = Column(String, nullable=True)
|
||||
remote_runtime_resource_factor = Column(Integer, nullable=True)
|
||||
enable_default_condenser = Column(Boolean, nullable=False, default=True)
|
||||
user_consents_to_analytics = Column(Boolean, nullable=True)
|
||||
margin = Column(Float, nullable=True)
|
||||
enable_sound_notifications = Column(Boolean, nullable=True, default=False)
|
||||
sandbox_base_container_image = Column(String, nullable=True)
|
||||
sandbox_runtime_container_image = Column(String, nullable=True)
|
||||
secrets_store = Column(JSON, nullable=True)
|
||||
@@ -1,11 +0,0 @@
|
||||
from sqlalchemy import Column, Identity, Integer, String
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
class StoredUserSecrets(Base): # type: ignore
|
||||
__tablename__ = 'user_secrets'
|
||||
id = Column(Integer, Identity(), primary_key=True)
|
||||
keycloak_user_id = Column(String, nullable=True, index=True)
|
||||
secret_name = Column(String, nullable=False)
|
||||
secret_value = Column(String, nullable=False)
|
||||
description = Column(String, nullable=True)
|
||||
@@ -1,4 +1,6 @@
|
||||
from sqlalchemy import Column, DateTime, Integer, String, text
|
||||
from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, text
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
@@ -13,6 +15,7 @@ class StripeCustomer(Base): # type: ignore
|
||||
__tablename__ = 'stripe_customers'
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
keycloak_user_id = Column(String, nullable=False)
|
||||
org_id = Column(UUID(as_uuid=True), ForeignKey('org.id'), nullable=True)
|
||||
stripe_customer_id = Column(String, nullable=False)
|
||||
created_at = Column(
|
||||
DateTime, server_default=text('CURRENT_TIMESTAMP'), nullable=False
|
||||
@@ -23,3 +26,6 @@ class StripeCustomer(Base): # type: ignore
|
||||
onupdate=text('CURRENT_TIMESTAMP'),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Relationships
|
||||
org = relationship('Org', back_populates='stripe_customers')
|
||||
|
||||
98
enterprise/storage/telemetry_identity.py
Normal file
98
enterprise/storage/telemetry_identity.py
Normal file
@@ -0,0 +1,98 @@
|
||||
"""SQLAlchemy model for telemetry identity.
|
||||
|
||||
This model stores persistent identity information that must survive container restarts
|
||||
for the OpenHands Enterprise Telemetry Service.
|
||||
"""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import CheckConstraint, Column, DateTime, Integer, String
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
class TelemetryIdentity(Base): # type: ignore
|
||||
"""Stores persistent identity information for telemetry.
|
||||
|
||||
This table is designed to contain exactly one row (enforced by database constraint)
|
||||
that maintains only the identity data that cannot be reliably recomputed:
|
||||
- customer_id: Established relationship with Replicated
|
||||
- instance_id: Generated once, must remain stable
|
||||
|
||||
Operational data like timestamps are derived from the telemetry_metrics table.
|
||||
"""
|
||||
|
||||
__tablename__ = 'telemetry_replicated_identity'
|
||||
__table_args__ = (CheckConstraint('id = 1', name='single_identity_row'),)
|
||||
|
||||
id = Column(Integer, primary_key=True, default=1)
|
||||
customer_id = Column(String(255), nullable=True)
|
||||
instance_id = Column(String(255), nullable=True)
|
||||
created_at = Column(
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(UTC),
|
||||
nullable=False,
|
||||
)
|
||||
updated_at = Column(
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(UTC),
|
||||
onupdate=lambda: datetime.now(UTC),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
customer_id: Optional[str] = None,
|
||||
instance_id: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize telemetry identity.
|
||||
|
||||
Args:
|
||||
customer_id: Unique identifier for the customer
|
||||
instance_id: Unique identifier for this OpenHands instance
|
||||
**kwargs: Additional keyword arguments for SQLAlchemy
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
|
||||
# Set defaults for fields that would normally be set by SQLAlchemy
|
||||
now = datetime.now(UTC)
|
||||
if not hasattr(self, 'created_at') or self.created_at is None:
|
||||
self.created_at = now
|
||||
if not hasattr(self, 'updated_at') or self.updated_at is None:
|
||||
self.updated_at = now
|
||||
|
||||
# Force id to be 1 to maintain single-row constraint
|
||||
self.id = 1
|
||||
self.customer_id = customer_id
|
||||
self.instance_id = instance_id
|
||||
|
||||
def set_customer_info(
|
||||
self,
|
||||
customer_id: Optional[str] = None,
|
||||
instance_id: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Update customer and instance identification information.
|
||||
|
||||
Args:
|
||||
customer_id: Unique identifier for the customer
|
||||
instance_id: Unique identifier for this OpenHands instance
|
||||
"""
|
||||
if customer_id is not None:
|
||||
self.customer_id = customer_id
|
||||
if instance_id is not None:
|
||||
self.instance_id = instance_id
|
||||
|
||||
@property
|
||||
def has_customer_info(self) -> bool:
|
||||
"""Check if customer identification information is configured."""
|
||||
return bool(self.customer_id and self.instance_id)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"<TelemetryIdentity(customer_id='{self.customer_id}', "
|
||||
f"instance_id='{self.instance_id}')>"
|
||||
)
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
112
enterprise/storage/telemetry_metrics.py
Normal file
112
enterprise/storage/telemetry_metrics.py
Normal file
@@ -0,0 +1,112 @@
|
||||
"""SQLAlchemy model for telemetry metrics data.
|
||||
|
||||
This model stores individual metric collection records with upload tracking
|
||||
and retry logic for the OpenHands Enterprise Telemetry Service.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from sqlalchemy import JSON, Column, DateTime, Integer, String, Text
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
class TelemetryMetrics(Base): # type: ignore
|
||||
"""Stores collected telemetry metrics with upload tracking.
|
||||
|
||||
Each record represents a single metrics collection event with associated
|
||||
metadata for upload status and retry logic.
|
||||
"""
|
||||
|
||||
__tablename__ = 'telemetry_metrics'
|
||||
|
||||
id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
collected_at = Column(
|
||||
DateTime(timezone=True),
|
||||
nullable=False,
|
||||
default=lambda: datetime.now(UTC),
|
||||
index=True,
|
||||
)
|
||||
metrics_data = Column(JSON, nullable=False)
|
||||
uploaded_at = Column(DateTime(timezone=True), nullable=True, index=True)
|
||||
upload_attempts = Column(Integer, nullable=False, default=0)
|
||||
last_upload_error = Column(Text, nullable=True)
|
||||
created_at = Column(
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(UTC),
|
||||
nullable=False,
|
||||
)
|
||||
updated_at = Column(
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(UTC),
|
||||
onupdate=lambda: datetime.now(UTC),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
metrics_data: Dict[str, Any],
|
||||
collected_at: Optional[datetime] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize a new telemetry metrics record.
|
||||
|
||||
Args:
|
||||
metrics_data: Dictionary containing the collected metrics
|
||||
collected_at: Timestamp when metrics were collected (defaults to now)
|
||||
**kwargs: Additional keyword arguments for SQLAlchemy
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
|
||||
# Set defaults for fields that would normally be set by SQLAlchemy
|
||||
now = datetime.now(UTC)
|
||||
if not hasattr(self, 'id') or self.id is None:
|
||||
self.id = str(uuid.uuid4())
|
||||
if not hasattr(self, 'upload_attempts') or self.upload_attempts is None:
|
||||
self.upload_attempts = 0
|
||||
if not hasattr(self, 'created_at') or self.created_at is None:
|
||||
self.created_at = now
|
||||
if not hasattr(self, 'updated_at') or self.updated_at is None:
|
||||
self.updated_at = now
|
||||
|
||||
self.metrics_data = metrics_data
|
||||
if collected_at:
|
||||
self.collected_at = collected_at
|
||||
elif not hasattr(self, 'collected_at') or self.collected_at is None:
|
||||
self.collected_at = now
|
||||
|
||||
def mark_uploaded(self) -> None:
|
||||
"""Mark this metrics record as successfully uploaded."""
|
||||
self.uploaded_at = datetime.now(UTC)
|
||||
self.last_upload_error = None
|
||||
|
||||
def mark_upload_failed(self, error_message: str) -> None:
|
||||
"""Mark this metrics record as having failed upload.
|
||||
|
||||
Args:
|
||||
error_message: Description of the upload failure
|
||||
"""
|
||||
self.upload_attempts += 1
|
||||
self.last_upload_error = error_message
|
||||
self.uploaded_at = None
|
||||
|
||||
@property
|
||||
def is_uploaded(self) -> bool:
|
||||
"""Check if this metrics record has been successfully uploaded."""
|
||||
return self.uploaded_at is not None
|
||||
|
||||
@property
|
||||
def needs_retry(self) -> bool:
|
||||
"""Check if this metrics record needs upload retry (failed but not too many attempts)."""
|
||||
return not self.is_uploaded and self.upload_attempts < 3
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"<TelemetryMetrics(id='{self.id}', "
|
||||
f"collected_at='{self.collected_at}', "
|
||||
f'uploaded={self.is_uploaded})>'
|
||||
)
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
41
enterprise/storage/user.py
Normal file
41
enterprise/storage/user.py
Normal file
@@ -0,0 +1,41 @@
|
||||
"""
|
||||
SQLAlchemy model for User.
|
||||
"""
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import (
|
||||
UUID,
|
||||
Boolean,
|
||||
Column,
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
Integer,
|
||||
String,
|
||||
)
|
||||
from sqlalchemy.orm import relationship
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
class User(Base): # type: ignore
|
||||
"""User model with organizational relationships."""
|
||||
|
||||
__tablename__ = 'user'
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid4)
|
||||
current_org_id = Column(UUID(as_uuid=True), ForeignKey('org.id'), nullable=False)
|
||||
role_id = Column(Integer, ForeignKey('role.id'), nullable=True)
|
||||
accepted_tos = Column(DateTime, nullable=True)
|
||||
enable_sound_notifications = Column(Boolean, nullable=True)
|
||||
language = Column(String, nullable=True)
|
||||
user_consents_to_analytics = Column(Boolean, nullable=True)
|
||||
email = Column(String, nullable=True)
|
||||
email_verified = Column(Boolean, nullable=True)
|
||||
|
||||
# Relationships
|
||||
role = relationship('Role', back_populates='users')
|
||||
org_members = relationship('OrgMember', back_populates='user')
|
||||
current_org = relationship('Org', back_populates='current_users')
|
||||
stored_conversation_metadata_saas = relationship(
|
||||
'StoredConversationMetadataSaas', back_populates='user'
|
||||
)
|
||||
@@ -38,3 +38,6 @@ class UserSettings(Base): # type: ignore
|
||||
email_verified = Column(Boolean, nullable=True)
|
||||
git_user_name = Column(String, nullable=True)
|
||||
git_user_email = Column(String, nullable=True)
|
||||
migration_status = Column(
|
||||
Boolean, nullable=True, default=False
|
||||
) # False = not migrated, True = migrated
|
||||
|
||||
228
enterprise/storage/user_store.py
Normal file
228
enterprise/storage/user_store.py
Normal file
@@ -0,0 +1,228 @@
|
||||
"""
|
||||
Store class for managing users.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from integrations.stripe_service import migrate_customer
|
||||
from server.logger import logger
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import joinedload
|
||||
from storage.database import session_maker
|
||||
from storage.encrypt_utils import decrypt_model
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
from storage.org import Org
|
||||
from storage.org_member import OrgMember
|
||||
from storage.org_store import OrgStore
|
||||
from storage.role_store import RoleStore
|
||||
from storage.user import User
|
||||
from storage.user_settings import UserSettings
|
||||
|
||||
from openhands.storage.data_models.settings import Settings
|
||||
|
||||
|
||||
class UserStore:
|
||||
"""Store for managing users."""
|
||||
|
||||
@staticmethod
|
||||
async def create_user(
|
||||
keycloak_user_id: str,
|
||||
user_info: dict,
|
||||
role_id: Optional[int] = None,
|
||||
) -> User | None:
|
||||
"""Create a new user."""
|
||||
with session_maker() as session:
|
||||
# create personal org
|
||||
org = Org(
|
||||
id=uuid.UUID(keycloak_user_id),
|
||||
name=f'user_{keycloak_user_id}_org',
|
||||
contact_name=user_info['preferred_username'],
|
||||
contact_email=user_info['email'],
|
||||
)
|
||||
session.add(org)
|
||||
|
||||
settings = await UserStore.create_default_settings(
|
||||
org_id=str(org.id), keycloak_user_id=keycloak_user_id
|
||||
)
|
||||
|
||||
if not settings:
|
||||
return None
|
||||
|
||||
org_kwargs = OrgStore.get_kwargs_from_settings(settings)
|
||||
for key, value in org_kwargs.items():
|
||||
if hasattr(org, key):
|
||||
setattr(org, key, value)
|
||||
|
||||
user_kwargs = UserStore.get_kwargs_from_settings(settings)
|
||||
user = User(
|
||||
id=uuid.UUID(keycloak_user_id),
|
||||
current_org_id=org.id,
|
||||
role_id=role_id,
|
||||
**user_kwargs,
|
||||
)
|
||||
session.add(user)
|
||||
|
||||
role = RoleStore.get_role_by_name('admin')
|
||||
|
||||
org_member = OrgMember(
|
||||
org_id=org.id,
|
||||
user_id=user.id,
|
||||
role_id=role.id, # admin of your own org.
|
||||
llm_api_key=settings.llm_api_key, # type: ignore[union-attr]
|
||||
status='active',
|
||||
)
|
||||
session.add(org_member)
|
||||
session.commit()
|
||||
session.refresh(user)
|
||||
user.org_members # load org_members
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
async def migrate_user(
|
||||
keycloak_user_id: str,
|
||||
user_settings: UserSettings,
|
||||
user_info: dict,
|
||||
) -> User:
|
||||
if not keycloak_user_id or not user_settings:
|
||||
return None
|
||||
|
||||
# Check if user is already migrated to prevent double migration
|
||||
if user_settings.migration_status is True:
|
||||
logger.warning(f'User {keycloak_user_id} already migrated, skipping')
|
||||
return UserStore.get_user_by_id(keycloak_user_id)
|
||||
kwargs = decrypt_model(
|
||||
[
|
||||
'llm_api_key',
|
||||
'llm_api_key_for_byor',
|
||||
'search_api_key',
|
||||
'sandbox_api_key',
|
||||
],
|
||||
user_settings,
|
||||
)
|
||||
decrypted_user_settings = UserSettings(**kwargs)
|
||||
with session_maker() as session:
|
||||
# create personal org
|
||||
org = Org(
|
||||
id=uuid.UUID(keycloak_user_id),
|
||||
name=f'user_{keycloak_user_id}_org',
|
||||
contact_name=user_info['preferred_username'],
|
||||
contact_email=user_info['email'],
|
||||
)
|
||||
session.add(org)
|
||||
|
||||
await LiteLlmManager.migrate_entries(
|
||||
str(org.id), keycloak_user_id, decrypted_user_settings
|
||||
)
|
||||
|
||||
await migrate_customer(session, keycloak_user_id, org)
|
||||
|
||||
org_kwargs = {
|
||||
c.name: getattr(decrypted_user_settings, c.name)
|
||||
for c in Org.__table__.columns
|
||||
if c.name != 'id' and hasattr(decrypted_user_settings, c.name)
|
||||
}
|
||||
for key, value in org_kwargs.items():
|
||||
if hasattr(org, key):
|
||||
setattr(org, key, value)
|
||||
|
||||
user_kwargs = {
|
||||
c.name: getattr(decrypted_user_settings, c.name)
|
||||
for c in User.__table__.columns
|
||||
if c.name != 'id' and hasattr(decrypted_user_settings, c.name)
|
||||
}
|
||||
user = User(
|
||||
id=uuid.UUID(keycloak_user_id),
|
||||
current_org_id=org.id,
|
||||
role_id=None,
|
||||
**user_kwargs,
|
||||
)
|
||||
session.add(user)
|
||||
|
||||
role = RoleStore.get_role_by_name('admin')
|
||||
|
||||
org_member = OrgMember(
|
||||
org_id=org.id,
|
||||
user_id=user.id,
|
||||
role_id=role.id, # admin of your own org.
|
||||
llm_api_key=decrypted_user_settings.llm_api_key, # type: ignore[union-attr]
|
||||
status='active',
|
||||
)
|
||||
session.add(org_member)
|
||||
|
||||
# Mark the old user_settings as migrated instead of deleting
|
||||
user_settings.migration_status = True
|
||||
|
||||
# need to migrate conversation metadata
|
||||
session.execute(
|
||||
text("""
|
||||
INSERT INTO conversation_metadata_saas (conversation_id, user_id, org_id)
|
||||
SELECT
|
||||
conversation_id,
|
||||
CASE
|
||||
WHEN user_id ~ '^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$'
|
||||
THEN user_id::uuid
|
||||
ELSE gen_random_uuid()
|
||||
END AS user_id,
|
||||
COALESCE(org_id, gen_random_uuid()) AS org_id
|
||||
FROM conversation_metadata
|
||||
WHERE user_id IS NOT NULL
|
||||
""")
|
||||
)
|
||||
|
||||
session.commit()
|
||||
session.refresh(user)
|
||||
user.org_members # load org_members
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
def get_user_by_id(keycloak_user_id: str) -> Optional[User]:
|
||||
"""Get user by Keycloak user ID."""
|
||||
with session_maker() as session:
|
||||
return (
|
||||
session.query(User)
|
||||
.options(joinedload(User.org_members))
|
||||
.filter(User.id == uuid.UUID(keycloak_user_id))
|
||||
.first()
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def list_users() -> list[User]:
|
||||
"""List all users."""
|
||||
with session_maker() as session:
|
||||
return session.query(User).all()
|
||||
|
||||
@staticmethod
|
||||
async def create_default_settings(
|
||||
org_id: str, keycloak_user_id: str
|
||||
) -> Optional[Settings]:
|
||||
logger.info(
|
||||
'UserStore:create_default_settings:start',
|
||||
extra={'org_id': org_id, 'user_id': keycloak_user_id},
|
||||
)
|
||||
# You must log in before you get default settings
|
||||
if not org_id:
|
||||
return None
|
||||
|
||||
settings = Settings(language='en', enable_proactive_conversation_starters=True)
|
||||
|
||||
settings = await LiteLlmManager.create_entries(
|
||||
org_id, keycloak_user_id, settings
|
||||
)
|
||||
if not settings:
|
||||
logger.info(
|
||||
'UserStore:create_default_settings:litellm_create_failed',
|
||||
extra={'org_id': org_id},
|
||||
)
|
||||
return None
|
||||
|
||||
return settings
|
||||
|
||||
@staticmethod
|
||||
def get_kwargs_from_settings(settings: Settings):
|
||||
kwargs = {
|
||||
c.name: getattr(settings, normalized)
|
||||
for c in User.__table__.columns
|
||||
if (normalized := c.name.lstrip('_')) and hasattr(settings, normalized)
|
||||
}
|
||||
return kwargs
|
||||
@@ -1,10 +1,9 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from server.constants import CURRENT_USER_SETTINGS_VERSION
|
||||
from server.maintenance_task_processor.user_version_upgrade_processor import (
|
||||
UserVersionUpgradeProcessor,
|
||||
)
|
||||
from server.constants import ORG_SETTINGS_VERSION
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from storage.base import Base
|
||||
@@ -14,12 +13,16 @@ from storage.billing_session import BillingSession
|
||||
from storage.conversation_work import ConversationWork
|
||||
from storage.feedback import Feedback
|
||||
from storage.github_app_installation import GithubAppInstallation
|
||||
from storage.maintenance_task import MaintenanceTask, MaintenanceTaskStatus
|
||||
from storage.org import Org
|
||||
from storage.org_member import OrgMember
|
||||
from storage.role import Role
|
||||
from storage.stored_conversation_metadata import StoredConversationMetadata
|
||||
from storage.stored_conversation_metadata_saas import (
|
||||
StoredConversationMetadataSaas,
|
||||
)
|
||||
from storage.stored_offline_token import StoredOfflineToken
|
||||
from storage.stored_settings import StoredSettings
|
||||
from storage.stripe_customer import StripeCustomer
|
||||
from storage.user_settings import UserSettings
|
||||
from storage.user import User
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -68,7 +71,6 @@ def add_minimal_fixtures(session_maker):
|
||||
session.add(
|
||||
StoredConversationMetadata(
|
||||
conversation_id='mock-conversation-id',
|
||||
user_id='mock-user-id',
|
||||
created_at=datetime.fromisoformat('2025-03-07'),
|
||||
last_updated_at=datetime.fromisoformat('2025-03-08'),
|
||||
accumulated_cost=5.25,
|
||||
@@ -77,6 +79,13 @@ def add_minimal_fixtures(session_maker):
|
||||
total_tokens=750,
|
||||
)
|
||||
)
|
||||
session.add(
|
||||
StoredConversationMetadataSaas(
|
||||
conversation_id='mock-conversation-id',
|
||||
user_id=UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'),
|
||||
org_id=UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'),
|
||||
)
|
||||
)
|
||||
session.add(
|
||||
StoredOfflineToken(
|
||||
user_id='mock-user-id',
|
||||
@@ -85,7 +94,38 @@ def add_minimal_fixtures(session_maker):
|
||||
updated_at=datetime.fromisoformat('2025-03-08'),
|
||||
)
|
||||
)
|
||||
session.add(StoredSettings(id='mock-user-id', user_consents_to_analytics=True))
|
||||
session.add(
|
||||
Org(
|
||||
id=uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'),
|
||||
name='mock-org',
|
||||
org_version=ORG_SETTINGS_VERSION,
|
||||
enable_default_condenser=True,
|
||||
enable_proactive_conversation_starters=True,
|
||||
)
|
||||
)
|
||||
session.add(
|
||||
Role(
|
||||
id=1,
|
||||
name='admin',
|
||||
rank=1,
|
||||
)
|
||||
)
|
||||
session.add(
|
||||
User(
|
||||
id=uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'),
|
||||
current_org_id=uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'),
|
||||
user_consents_to_analytics=True,
|
||||
)
|
||||
)
|
||||
session.add(
|
||||
OrgMember(
|
||||
org_id=uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'),
|
||||
user_id=uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'),
|
||||
role_id=1,
|
||||
llm_api_key='mock-api-key',
|
||||
status='active',
|
||||
)
|
||||
)
|
||||
session.add(
|
||||
StripeCustomer(
|
||||
keycloak_user_id='mock-user-id',
|
||||
@@ -94,13 +134,6 @@ def add_minimal_fixtures(session_maker):
|
||||
updated_at=datetime.fromisoformat('2025-03-10'),
|
||||
)
|
||||
)
|
||||
session.add(
|
||||
UserSettings(
|
||||
keycloak_user_id='mock-user-id',
|
||||
user_consents_to_analytics=True,
|
||||
user_version=CURRENT_USER_SETTINGS_VERSION,
|
||||
)
|
||||
)
|
||||
session.add(
|
||||
ConversationWork(
|
||||
conversation_id='mock-conversation-id',
|
||||
@@ -109,17 +142,6 @@ def add_minimal_fixtures(session_maker):
|
||||
updated_at=datetime.fromisoformat('2025-03-08'),
|
||||
)
|
||||
)
|
||||
maintenance_task = MaintenanceTask(
|
||||
status=MaintenanceTaskStatus.PENDING,
|
||||
)
|
||||
maintenance_task.set_processor(
|
||||
UserVersionUpgradeProcessor(
|
||||
user_ids=['mock-user-id'],
|
||||
created_at=datetime.fromisoformat('2025-03-07'),
|
||||
updated_at=datetime.fromisoformat('2025-03-08'),
|
||||
)
|
||||
)
|
||||
session.add(maintenance_task)
|
||||
session.commit()
|
||||
|
||||
|
||||
|
||||
1
enterprise/tests/unit/experiments/__init__.py
Normal file
1
enterprise/tests/unit/experiments/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Unit tests for experiments module."""
|
||||
@@ -0,0 +1,137 @@
|
||||
# tests/test_condenser_max_step_experiment_v1.py
|
||||
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
from experiments.experiment_manager import SaaSExperimentManager
|
||||
|
||||
# SUT imports (update the module path if needed)
|
||||
from experiments.experiment_versions._004_condenser_max_step_experiment import (
|
||||
handle_condenser_max_step_experiment__v1,
|
||||
)
|
||||
from pydantic import SecretStr
|
||||
|
||||
from openhands.sdk import LLM, Agent
|
||||
from openhands.sdk.context.condenser import LLMSummarizingCondenser
|
||||
|
||||
|
||||
def make_agent() -> Agent:
|
||||
"""Build a minimal valid Agent."""
|
||||
llm = LLM(
|
||||
usage_id='primary-llm',
|
||||
model='provider/model',
|
||||
api_key=SecretStr('sk-test'),
|
||||
)
|
||||
return Agent(llm=llm)
|
||||
|
||||
|
||||
def _patch_variant(monkeypatch, return_value):
|
||||
"""Patch the internal variant getter to return a specific value."""
|
||||
monkeypatch.setattr(
|
||||
'experiments.experiment_versions._004_condenser_max_step_experiment._get_condenser_max_step_variant',
|
||||
lambda user_id, conv_id: return_value,
|
||||
raising=True,
|
||||
)
|
||||
|
||||
|
||||
def test_control_variant_sets_condenser_with_max_size_120(monkeypatch):
|
||||
_patch_variant(monkeypatch, 'control')
|
||||
agent = make_agent()
|
||||
conv_id = uuid4()
|
||||
|
||||
result = handle_condenser_max_step_experiment__v1('user-1', conv_id, agent)
|
||||
|
||||
# Should be a new Agent instance with a condenser installed
|
||||
assert result is not agent
|
||||
assert isinstance(result.condenser, LLMSummarizingCondenser)
|
||||
|
||||
# The condenser should have its own LLM (usage_id overridden to "condenser")
|
||||
assert result.condenser.llm.usage_id == 'condenser'
|
||||
# The original agent LLM remains unchanged
|
||||
assert agent.llm.usage_id == 'primary-llm'
|
||||
|
||||
# Control: max_size = 120, keep_first = 4
|
||||
assert result.condenser.max_size == 120
|
||||
assert result.condenser.keep_first == 4
|
||||
|
||||
|
||||
def test_treatment_variant_sets_condenser_with_max_size_80(monkeypatch):
|
||||
_patch_variant(monkeypatch, 'treatment')
|
||||
agent = make_agent()
|
||||
conv_id = uuid4()
|
||||
|
||||
result = handle_condenser_max_step_experiment__v1('user-2', conv_id, agent)
|
||||
|
||||
assert result is not agent
|
||||
assert isinstance(result.condenser, LLMSummarizingCondenser)
|
||||
assert result.condenser.llm.usage_id == 'condenser'
|
||||
assert result.condenser.max_size == 80
|
||||
assert result.condenser.keep_first == 4
|
||||
|
||||
|
||||
def test_none_variant_returns_original_agent_without_changes(monkeypatch):
|
||||
_patch_variant(monkeypatch, None)
|
||||
agent = make_agent()
|
||||
conv_id = uuid4()
|
||||
|
||||
result = handle_condenser_max_step_experiment__v1('user-3', conv_id, agent)
|
||||
|
||||
# No changes—same instance and no condenser attribute added
|
||||
assert result is agent
|
||||
assert getattr(result, 'condenser', None) is None
|
||||
|
||||
|
||||
def test_unknown_variant_returns_original_agent_without_changes(monkeypatch):
|
||||
_patch_variant(monkeypatch, 'weird-variant')
|
||||
agent = make_agent()
|
||||
conv_id = uuid4()
|
||||
|
||||
result = handle_condenser_max_step_experiment__v1('user-4', conv_id, agent)
|
||||
|
||||
assert result is agent
|
||||
assert getattr(result, 'condenser', None) is None
|
||||
|
||||
|
||||
@patch('experiments.experiment_manager.handle_condenser_max_step_experiment__v1')
|
||||
@patch('experiments.experiment_manager.ENABLE_EXPERIMENT_MANAGER', False)
|
||||
def test_run_agent_variant_tests_v1_noop_when_manager_disabled(
|
||||
mock_handle_condenser,
|
||||
):
|
||||
"""If ENABLE_EXPERIMENT_MANAGER is False, the method returns the exact same agent and does not call the handler."""
|
||||
agent = make_agent()
|
||||
conv_id = uuid4()
|
||||
|
||||
result = SaaSExperimentManager.run_agent_variant_tests__v1(
|
||||
user_id='user-123',
|
||||
conversation_id=conv_id,
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
# Same object returned (no copy)
|
||||
assert result is agent
|
||||
# Handler should not have been called
|
||||
mock_handle_condenser.assert_not_called()
|
||||
|
||||
|
||||
@patch('experiments.experiment_manager.ENABLE_EXPERIMENT_MANAGER', True)
|
||||
@patch('experiments.experiment_manager.EXPERIMENT_SYSTEM_PROMPT_EXPERIMENT', True)
|
||||
def test_run_agent_variant_tests_v1_calls_handler_and_sets_system_prompt(monkeypatch):
|
||||
"""When enabled, it should call the condenser experiment handler and set the long-horizon system prompt."""
|
||||
agent = make_agent()
|
||||
conv_id = uuid4()
|
||||
|
||||
_patch_variant(monkeypatch, 'treatment')
|
||||
|
||||
result: Agent = SaaSExperimentManager.run_agent_variant_tests__v1(
|
||||
user_id='user-abc',
|
||||
conversation_id=conv_id,
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
# Should be a different instance than the original (copied after handler runs)
|
||||
assert result is not agent
|
||||
assert result.system_prompt_filename == 'system_prompt_long_horizon.j2'
|
||||
|
||||
# The condenser returned by the handler must be preserved after the system-prompt override copy
|
||||
assert isinstance(result.condenser, LLMSummarizingCondenser)
|
||||
assert result.condenser.max_size == 80
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user