Compare commits

...

89 Commits

Author SHA1 Message Date
openhands 1bad8f1ebc Fix: Handle missing openhands_prs table in OpenhandsPRStore
When the enrich_user_interaction_data cronjob runs, it queries the
openhands_prs table. If the database migrations haven't been run yet,
this causes a ProgrammingError because the table doesn't exist.

This fix wraps the database query in get_unprocessed_prs() with a
try-except block to catch the ProgrammingError and log a warning
instead of crashing. This allows the cronjob to complete gracefully
even if the database isn't fully initialized.

Co-authored-by: openhands <openhands@all-hands.dev>
2026-01-06 00:32:39 +00:00
Tim O'Farrell acc0e893e3 Bump openhands to 1.7.4 (#12269) 2026-01-05 21:40:42 +00:00
Xingyao Wang a8098505c2 Add litellm_extra_body metadata for V1 conversations (#12266)
Co-authored-by: openhands <openhands@all-hands.dev>
2026-01-06 03:27:06 +08:00
sp.wack 9b834bf660 feat(frontend): create useAppTitle hook for dynamic document titles (#12224) 2026-01-05 23:17:53 +04:00
Xingyao Wang 5744f6602b Handle expired Keycloak session with user-friendly error message (#12168)
Co-authored-by: openhands <openhands@all-hands.dev>
2026-01-05 15:04:36 +00:00
Neha Prasad 4a82768e6d feat: add empty state to Changes tab with icon and message (#12237)
Co-authored-by: amanape <83104063+amanape@users.noreply.github.com>
2026-01-05 14:22:47 +00:00
Hiep Le 6f86e589c8 feat: allow manual reinstallation for gitlab resolver (#12184) 2026-01-05 12:05:20 +07:00
shanemort1982 5bd8695ab8 feat: Add configurable sandbox host_port and container_url_pattern for remote access (#12255)
Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: Tim O'Farrell <tofarr@gmail.com>
2026-01-04 20:26:16 -07:00
Tim O'Farrell 8c73c87583 Add extra_hosts support to agent-server containers (#12236)
Co-authored-by: openhands <openhands@all-hands.dev>
2026-01-03 05:41:31 +00:00
Graham Neubig 40c25cd1ce fix: use Auth.Token for PyGithub authentication (#12248)
Co-authored-by: openhands <openhands@all-hands.dev>
2026-01-02 21:14:16 -05:00
Graham Neubig 2ebde2529d fix: Handle LiteLLM v1.80+ 404 response for new users (#12250)
Co-authored-by: openhands <openhands@all-hands.dev>
2026-01-02 22:18:47 +00:00
Graham Neubig cdc42130e1 fix: replace deprecated get_matching_events with search_events (#12249)
Co-authored-by: openhands <openhands@all-hands.dev>
2026-01-02 21:59:17 +00:00
Graham Neubig 903c047015 Replace deprecated PyPDF2 with pypdf (#12203)
Co-authored-by: openhands <openhands@all-hands.dev>
2026-01-02 21:47:52 +00:00
Graham Neubig ee2ad16442 fix: update pythonjsonlogger.jsonlogger to pythonjsonlogger.json (#12247)
Co-authored-by: openhands <openhands@all-hands.dev>
2026-01-02 16:13:00 -05:00
dependabot[bot] a96b47e481 chore(deps): bump posthog-js from 1.312.0 to 1.313.0 in /frontend in the version-all group (#12241)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-01-02 22:30:44 +04:00
Hiep Le 5a08277184 fix(backend): stabilize gitlab resolver in saas (#12231) 2026-01-03 01:25:28 +07:00
Hiep Le 63d5ceada6 feat(backend): block tld (#12240)
Co-authored-by: sp.wack <83104063+amanape@users.noreply.github.com>
2026-01-03 00:42:22 +07:00
Mohammed Abdulai 1bae1fc4e6 doc: correct Slack channel to #dev-ui-ux (#12239)
Co-authored-by: Mohammed Abdulai <nurud43@gmail.com>
2026-01-02 15:28:08 +01:00
Engel Nyst 15bc78f4c1 Remove VSCode extension integration from OpenHands repo (#12234) 2026-01-01 19:28:05 +01:00
dependabot[bot] 437046f5a4 chore(deps): bump the version-all group in /frontend with 2 updates (#12232)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-01-01 19:29:41 +02:00
Cesar Garcia 714459d6eb fix: run stale issues workflow on upstream repository only (#12162)
Co-authored-by: mamoodi <mamoodiha@gmail.com>
2025-12-31 17:50:48 +00:00
Bharath A V f9b316453d fix: prevent nested buttons in tooltip button (#12177)
Co-authored-by: amanape <83104063+amanape@users.noreply.github.com>
2025-12-31 16:08:37 +00:00
Ryanakml 96d073ee5b fix(frontend): add missing onClose prop to conversation panel modals (#12219)
Co-authored-by: amanape <83104063+amanape@users.noreply.github.com>
2025-12-31 15:29:03 +00:00
Osama Mabkhot f7d416ac8e refactor(frontend): remove HeroUI BaseModal and migrate MetricsModal (#12174)
Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: amanape <83104063+amanape@users.noreply.github.com>
2025-12-31 15:18:58 +00:00
yunbae b7d5f903cf fix(frontend): Agent Tools & Metadata not available for V1 conversations (#12180)
Co-authored-by: amanape <83104063+amanape@users.noreply.github.com>
2025-12-31 19:08:09 +04:00
yunbae 2734a5a52d fix(frontend): show stop action button for running or starting conversations (#12215) 2025-12-31 19:07:09 +04:00
dependabot[bot] 51868ffac6 chore(deps): bump @tanstack/react-query from 5.90.15 to 5.90.16 in /frontend in the version-all group (#12225)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-12-31 14:44:20 +00:00
Aaron Sequeira 4c0f0a1e9b feat: Support Tau-Bench and BFCL evaluation benchmarks (#11953)
Co-authored-by: openhands <openhands@all-hands.dev>
2025-12-31 03:12:50 +00:00
dependabot[bot] 82e0aa7924 chore(deps): bump ncipollo/release-action from 1.16.0 to 1.20.0 (#11851)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Graham Neubig <neubig@gmail.com>
2025-12-31 03:02:48 +00:00
Eliot Jones 9043aa69d8 refactor: Update expected cygnal output format (#12060) 2025-12-30 22:01:36 -05:00
dependabot[bot] 23d379fa41 build(deps): bump node from 24.8-trixie-slim to 25.2-trixie-slim in /containers/app (#11756)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Graham Neubig <neubig@gmail.com>
2025-12-30 21:18:13 -05:00
Neha Prasad 6f9c0aa3b1 fix: display conversation title in delete confirmation modal (#11818)
Co-authored-by: amanape <83104063+amanape@users.noreply.github.com>
2025-12-30 20:59:30 -05:00
Xingyao Wang 232dcf4991 fix(ci): update PAT_TOKEN to ALLHANDS_BOT_GITHUB_PAT for enterprise preview (#12216)
Co-authored-by: openhands <openhands@all-hands.dev>
2025-12-31 04:41:45 +08:00
Hiep Le ffdd95305f fix(backend): invalid api key (#12217) 2025-12-31 02:05:43 +07:00
sp.wack bfe8275963 hotfix(test): add top-level mock for custom-toast-handlers in conversation-panel tests (#12220) 2025-12-30 19:04:29 +00:00
OpenHands Bot 06a97fc382 Bump SDK packages to v1.7.3 (#12218)
Co-authored-by: Tim O'Farrell <tofarr@gmail.com>
2025-12-30 18:47:14 +00:00
Graham Neubig b5758b1604 Update GithubIntegration to use auth=Auth.AppAuth() (#12204)
Co-authored-by: openhands <openhands@all-hands.dev>
2025-12-30 12:59:51 -05:00
mamoodi 3ae09680d6 Release 1.1.0 (#12212) 2025-12-30 11:35:14 -05:00
sp.wack 0e5f4325be hotfix(frontend): set terminal background color for xterm.js 6.0.0 compatibility (#12213) 2025-12-30 14:58:58 +00:00
dependabot[bot] 64d4085612 chore(deps): bump the version-all group in /frontend with 2 updates (#12211)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-12-30 18:52:39 +04:00
sp.wack 103e3ead0a hotfix(frontend): validate git changes response is array before mapping (#12208) 2025-12-30 12:33:09 +00:00
dependabot[bot] d5e83d0f06 chore(deps): bump peter-evans/create-or-update-comment from 4 to 5 (#12192)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Engel Nyst <engel.nyst@gmail.com>
2025-12-29 23:50:40 +00:00
dependabot[bot] 443918af3c chore(deps): bump docker/setup-qemu-action from 3.6.0 to 3.7.0 (#12193)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-12-30 00:25:56 +01:00
dependabot[bot] 910646d11f chore(deps): bump actions/cache from 4 to 5 (#12191)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-12-30 00:25:17 +01:00
Engel Nyst d9d19043f1 chore: Mark V0 legacy files with clear headers and V1 pointers (#12165)
Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: Rohit Malhotra <rohitvinodmalhotra@gmail.com>
2025-12-30 00:21:29 +01:00
Graham Neubig 4dec38c7ce fix(event-webhook): Improve error logging with exception type and stack trace (#12202)
Co-authored-by: openhands <openhands@all-hands.dev>
2025-12-29 18:09:20 -05:00
Graham Neubig c3f51d9dbe fix(billing): Add error handling for LiteLLM API failures in get_credits (#12201)
Co-authored-by: openhands <openhands@all-hands.dev>
2025-12-29 23:01:55 +00:00
chuckbutkus ecbd3ae749 Fix local dev deployments (#12198) 2025-12-29 16:18:02 -05:00
Hiep Le 8ee1394e8c feat: add button to authentication modal to resend verification email (#12179) 2025-12-30 02:12:14 +07:00
Tim O'Farrell d628e1f20a feat: Add frontend support for public conversation sharing (#12047)
Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: sp.wack <83104063+amanape@users.noreply.github.com>
2025-12-29 12:04:06 -07:00
sp.wack 1480d4acb0 fix(frontend): deduplicate events on WebSocket reconnect (#12197) 2025-12-29 19:03:48 +00:00
Hiep Le 58a70e8b0d fix(backend): preserve users custom llm settings during settings migrations (#12134)
Co-authored-by: Xingyao Wang <xingyao@all-hands.dev>
2025-12-29 23:28:20 +07:00
Hiep Le 49e46a5fa1 refactor(backend): remove <sub> in slack response (#12135) 2025-12-29 23:27:48 +07:00
Hiep Le 2cf6494773 fix(backend): install_gitlab_webhooks.py is not functioning as expected (#12185) 2025-12-29 23:27:31 +07:00
Hiep Le d3afbfa447 refactor(backend): add description field support for secrets (v1 conversations) (#12080) 2025-12-29 22:43:07 +07:00
Hiep Le 8d69b4066f fix(backend): exception occurs when running the latest code from the main branch (v1 conversations) (#12183) 2025-12-29 09:57:14 -05:00
dependabot[bot] 2261281656 chore(deps): bump @tanstack/react-query from 5.90.12 to 5.90.14 in /frontend in the version-all group (#12189)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-12-29 14:33:52 +00:00
sp.wack d68b2cdd1a hotfix(frontend): fix provider type import (#12187) 2025-12-29 18:01:22 +04:00
dependabot[bot] c70ecc8fe3 chore(deps): bump the version-all group across 1 directory with 6 updates (#12161)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: amanape <83104063+amanape@users.noreply.github.com>
2025-12-29 13:54:58 +00:00
Pedro Henrique a3e85e2c2d test: Add MC/DC tests for loop pattern detector (stuck_detector) (#11600)
Co-authored-by: Engel Nyst <engel.nyst@gmail.com>
Co-authored-by: openhands <openhands@all-hands.dev>
2025-12-29 14:15:15 +01:00
Hiep Le 3bef4e6c2d refactor(frontend): update the error message for email addresses containing + during signup (#12178) 2025-12-29 19:36:28 +07:00
Engel Nyst 97654e6a5e Configurable conda/mamba channel_alias for runtime builds (#11516)
Co-authored-by: openhands <openhands@all-hands.dev>
2025-12-29 00:40:57 +01:00
Tim O'Farrell 30114666ad Bump the SDK to 1.7.1 (#12182)
Co-authored-by: openhands <openhands@all-hands.dev>
2025-12-28 18:57:08 +00:00
dependabot[bot] ee50f333ba chore(deps): bump actions/upload-artifact from 4 to 5 (#11805)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Graham Neubig <neubig@gmail.com>
2025-12-28 09:51:34 -05:00
dependabot[bot] 09d1748a14 build(deps): bump actions/setup-python from 5 to 6 (#11755)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Graham Neubig <neubig@gmail.com>
2025-12-28 09:49:17 -05:00
dependabot[bot] 81519343c4 chore(deps): bump actions/download-artifact from 4 to 6 (#11524)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Graham Neubig <neubig@gmail.com>
2025-12-28 09:49:02 -05:00
dependabot[bot] f742811e81 chore(deps): bump actions/setup-node from 4 to 6 (#11442)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Graham Neubig <neubig@gmail.com>
2025-12-28 08:58:26 -05:00
johba f8e4b5562e Forgejo integration (#11111)
Co-authored-by: johba <admin@noreply.localhost>
Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: johba <johba@harb.eth>
Co-authored-by: enyst <engel.nyst@gmail.com>
Co-authored-by: Graham Neubig <neubig@gmail.com>
Co-authored-by: MrGeorgen <65063405+MrGeorgen@users.noreply.github.com>
Co-authored-by: MrGeorgen <moinl6162@gmail.com>
2025-12-27 15:57:31 -05:00
Tim O'Farrell cb1d1f8a0d Fix install-hooks CronJob failing when gitlab_webhook table doesn't exist (#12167)
Co-authored-by: openhands <openhands@all-hands.dev>
2025-12-26 10:53:21 -07:00
Tim O'Farrell a829d10213 ALL-4634: implement public conversation sharing feature (#12044)
Co-authored-by: openhands <openhands@all-hands.dev>
2025-12-26 10:02:01 -07:00
Tim O'Farrell cb8c1fa263 ALL-4627 Database Fixes (#12156)
Co-authored-by: openhands <openhands@all-hands.dev>
2025-12-26 09:19:51 -07:00
lif c80f70392f fix(frontend): clean up console warnings in test suite (#12004)
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
Co-authored-by: amanape <83104063+amanape@users.noreply.github.com>
2025-12-25 22:26:12 +04:00
Guy Elsmore-Paddock 94e6490a79 Use tini as Docker Runtime Init to Ensure Zombie Processes Get Reaped (#12133)
Co-authored-by: Tim O'Farrell <tofarr@gmail.com>
2025-12-25 06:16:52 +00:00
Tim O'Farrell 09af93a02a Agent server env override (#12068)
Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: Engel Nyst <engel.nyst@gmail.com>
2025-12-25 03:55:06 +00:00
shanemort1982 5407ea55aa Fix WebSocket localhost bug by passing DOCKER_HOST_ADDR to runtime containers (#12113)
Co-authored-by: openhands <openhands@all-hands.dev>
2025-12-24 14:26:45 -07:00
Tim O'Farrell fe1026ee8a Fix for re-creating deleted conversation (#12152) 2025-12-24 12:13:29 -07:00
Tim O'Farrell 6d14ce420e Implement Export feature for V1 conversations with comprehensive unit tests (#12030)
Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: hieptl <hieptl.developer@gmail.com>
2025-12-24 17:50:57 +00:00
lif 36fe23aea3 fix(llm): retry LiteLLM bad gateway errors (#12117) 2025-12-24 06:37:12 -05:00
sp.wack 9049b95792 docs(frontend): React Router testing guide (#12145) 2025-12-24 14:21:55 +04:00
Hiep Le e2b2aa52cd feat: require email verification for new signups (#12123) 2025-12-24 14:56:02 +07:00
Tim O'Farrell dc99c7b62e Fix SQLAlchemy result handling in get_sandbox_by_session_api_key (#12148)
Co-authored-by: openhands <openhands@all-hands.dev>
2025-12-24 00:11:16 +00:00
Tim O'Farrell 8bc1a47a78 Fix for error in get_sandbox_by_session_api_key (#12147) 2025-12-23 22:18:36 +00:00
Tim O'Farrell 8d0e7a92b8 ALL-4636 Resolution for connection leaks (#12144)
Co-authored-by: openhands <openhands@all-hands.dev>
2025-12-23 19:02:56 +00:00
Hiep Le f6e7628bff feat: prevent signups using email addresses with a plus sign and enforce the existing email pattern (#12124) 2025-12-24 01:48:05 +07:00
sp.wack fae83230ee docs(frontend): Add API services guide for frontend development (#12132) 2025-12-23 12:57:55 +00:00
sp.wack a9d2f72d72 docs(frontend): Add MSW testing guide for frontend development (#12131) 2025-12-23 16:32:27 +04:00
Tim O'Farrell 2b8f779b65 fix: Runtime pods fail to start due to missing Playwright browser path (#12130) 2025-12-22 17:04:10 +00:00
Hiep Le 10edb28729 fix(frontend): llm settings view resets to basic after saving (#12097) 2025-12-22 23:00:57 +07:00
Hiep Le 5553d3ca2e feat: support blocking specific email domains (#12115) 2025-12-21 19:49:11 +07:00
407 changed files with 21931 additions and 13325 deletions
+1 -1
View File
@@ -15,7 +15,7 @@ jobs:
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
uses: actions/setup-python@v6
with:
python-version: "3.12"
+4 -4
View File
@@ -27,7 +27,7 @@ jobs:
poetry-version: 2.1.3
- name: Set up Python
uses: actions/setup-python@v5
uses: actions/setup-python@v6
with:
python-version: '3.12'
cache: 'poetry'
@@ -38,7 +38,7 @@ jobs:
sudo apt-get install -y libgtk-3-0 libnotify4 libnss3 libxss1 libxtst6 xauth xvfb libgbm1 libasound2t64 netcat-openbsd
- name: Setup Node.js
uses: actions/setup-node@v4
uses: actions/setup-node@v6
with:
node-version: '22'
cache: 'npm'
@@ -192,7 +192,7 @@ jobs:
- name: Upload test results
if: always()
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v6
with:
name: playwright-report
path: tests/e2e/test-results/
@@ -200,7 +200,7 @@ jobs:
- name: Upload OpenHands logs
if: always()
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v6
with:
name: openhands-logs
path: |
@@ -43,7 +43,7 @@ jobs:
⚠️ This PR contains **migrations**
- name: Comment warning on PR
uses: peter-evans/create-or-update-comment@v4
uses: peter-evans/create-or-update-comment@v5
with:
issue-number: ${{ github.event.pull_request.number }}
comment-id: ${{ steps.find-comment.outputs.comment-id }}
+1 -1
View File
@@ -23,7 +23,7 @@ jobs:
- name: Trigger remote job
run: |
curl --fail-with-body -sS -X POST \
-H "Authorization: Bearer ${{ secrets.PAT_TOKEN }}" \
-H "Authorization: Bearer ${{ secrets.ALLHANDS_BOT_GITHUB_PAT }}" \
-H "Accept: application/vnd.github+json" \
-d "{\"ref\": \"main\", \"inputs\": {\"openhandsPrNumber\": \"${{ github.event.pull_request.number }}\", \"deployEnvironment\": \"feature\", \"enterpriseImageTag\": \"pr-${{ github.event.pull_request.number }}\" }}" \
https://api.github.com/repos/OpenHands/deploy/actions/workflows/deploy.yaml/dispatches
+1 -1
View File
@@ -39,7 +39,7 @@ jobs:
working-directory: ./frontend
run: npx playwright test --project=chromium
- name: Upload Playwright report
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v6
if: always()
with:
name: playwright-report
+6 -6
View File
@@ -64,7 +64,7 @@ jobs:
with:
ref: ${{ github.event.pull_request.head.sha }}
- name: Set up QEMU
uses: docker/setup-qemu-action@v3.6.0
uses: docker/setup-qemu-action@v3.7.0
with:
image: tonistiigi/binfmt:latest
- name: Login to GHCR
@@ -102,7 +102,7 @@ jobs:
with:
ref: ${{ github.event.pull_request.head.sha }}
- name: Set up QEMU
uses: docker/setup-qemu-action@v3.6.0
uses: docker/setup-qemu-action@v3.7.0
with:
image: tonistiigi/binfmt:latest
- name: Login to GHCR
@@ -161,7 +161,7 @@ jobs:
context: containers/runtime
- name: Upload runtime source for fork
if: github.event.pull_request.head.repo.fork
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v6
with:
name: runtime-src-${{ matrix.base_image.tag }}
path: containers/runtime
@@ -247,7 +247,7 @@ jobs:
- name: Trigger remote job
run: |
curl --fail-with-body -sS -X POST \
-H "Authorization: Bearer ${{ secrets.PAT_TOKEN }}" \
-H "Authorization: Bearer ${{ secrets.ALLHANDS_BOT_GITHUB_PAT }}" \
-H "Accept: application/vnd.github+json" \
-d "{\"ref\": \"main\", \"inputs\": {\"openhandsPrNumber\": \"${{ github.event.pull_request.number }}\", \"deployEnvironment\": \"feature\", \"enterpriseImageTag\": \"pr-${{ github.event.pull_request.number }}\" }}" \
https://api.github.com/repos/OpenHands/deploy/actions/workflows/deploy.yaml/dispatches
@@ -268,7 +268,7 @@ jobs:
uses: docker/setup-buildx-action@v3
- name: Download runtime source for fork
if: github.event.pull_request.head.repo.fork
uses: actions/download-artifact@v4
uses: actions/download-artifact@v6
with:
name: runtime-src-${{ matrix.base_image.tag }}
path: containers/runtime
@@ -330,7 +330,7 @@ jobs:
uses: docker/setup-buildx-action@v3
- name: Download runtime source for fork
if: github.event.pull_request.head.repo.fork
uses: actions/download-artifact@v4
uses: actions/download-artifact@v6
with:
name: runtime-src-${{ matrix.base_image.tag }}
path: containers/runtime
+3 -3
View File
@@ -89,7 +89,7 @@ jobs:
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
uses: actions/setup-python@v6
with:
python-version: "3.12"
- name: Upgrade pip
@@ -118,7 +118,7 @@ jobs:
contains(github.event.review.body, '@openhands-agent-exp')
)
)
uses: actions/cache@v4
uses: actions/cache@v5
with:
path: ${{ env.pythonLocation }}/lib/python3.12/site-packages/*
key: ${{ runner.os }}-pip-openhands-resolver-${{ hashFiles('/tmp/requirements.txt') }}
@@ -269,7 +269,7 @@ jobs:
fi
- name: Upload output.jsonl as artifact
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v6
if: always() # Upload even if the previous steps fail
with:
name: resolver-output
+3 -3
View File
@@ -63,7 +63,7 @@ jobs:
env:
COVERAGE_FILE: ".coverage.runtime.${{ matrix.python_version }}"
- name: Store coverage file
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v6
with:
name: coverage-openhands
path: |
@@ -95,7 +95,7 @@ jobs:
env:
COVERAGE_FILE: ".coverage.enterprise.${{ matrix.python_version }}"
- name: Store coverage file
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v6
with:
name: coverage-enterprise
path: ".coverage.enterprise.${{ matrix.python_version }}"
@@ -113,7 +113,7 @@ jobs:
steps:
- uses: actions/checkout@v4
- uses: actions/download-artifact@v5
- uses: actions/download-artifact@v6
id: download
with:
pattern: coverage-*
+1
View File
@@ -9,6 +9,7 @@ on:
jobs:
stale:
runs-on: blacksmith-4vcpu-ubuntu-2204
if: github.repository == 'OpenHands/OpenHands'
steps:
- uses: actions/stale@v9
with:
@@ -1,156 +0,0 @@
# Workflow that validates the VSCode extension builds correctly
name: VSCode Extension CI
# * Always run on "main"
# * Run on PRs that have changes in the VSCode extension folder or this workflow
# * Run on tags that start with "ext-v"
on:
push:
branches:
- main
tags:
- 'ext-v*'
pull_request:
paths:
- 'openhands/integrations/vscode/**'
- 'build_vscode.py'
- '.github/workflows/vscode-extension-build.yml'
# If triggered by a PR, it will be in the same group. However, each commit on main will be in its own unique group
concurrency:
group: ${{ github.workflow }}-${{ (github.head_ref && github.ref) || github.run_id }}
cancel-in-progress: true
jobs:
# Validate VSCode extension builds correctly
validate-vscode-extension:
name: Validate VSCode Extension Build
runs-on: blacksmith-4vcpu-ubuntu-2204
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Set up Node.js
uses: useblacksmith/setup-node@v5
with:
node-version: '22'
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.12'
- name: Install VSCode extension dependencies
working-directory: ./openhands/integrations/vscode
run: npm ci
- name: Build VSCode extension via build_vscode.py
run: python build_vscode.py
env:
# Ensure we don't skip the build
SKIP_VSCODE_BUILD: ""
- name: Validate .vsix file
run: |
# Verify the .vsix was created and is valid
if [ -f "openhands/integrations/vscode/openhands-vscode-0.0.1.vsix" ]; then
echo "✅ VSCode extension built successfully"
ls -la openhands/integrations/vscode/openhands-vscode-0.0.1.vsix
# Basic validation that the .vsix is a valid zip file
echo "🔍 Validating .vsix structure..."
file openhands/integrations/vscode/openhands-vscode-0.0.1.vsix
unzip -t openhands/integrations/vscode/openhands-vscode-0.0.1.vsix
echo "✅ VSCode extension validation passed"
else
echo "❌ VSCode extension build failed - .vsix not found"
exit 1
fi
- name: Upload VSCode extension artifact
uses: actions/upload-artifact@v4
with:
name: vscode-extension
path: openhands/integrations/vscode/openhands-vscode-0.0.1.vsix
retention-days: 7
- name: Comment on PR with artifact link
if: github.event_name == 'pull_request'
uses: actions/github-script@v7
with:
script: |
const fs = require('fs');
const path = require('path');
// Get file size for display
const vsixPath = 'openhands/integrations/vscode/openhands-vscode-0.0.1.vsix';
const stats = fs.statSync(vsixPath);
const fileSizeKB = Math.round(stats.size / 1024);
const comment = `## 🔧 VSCode Extension Built Successfully!
The VSCode extension has been built and is ready for testing.
**📦 Download**: [openhands-vscode-0.0.1.vsix](https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}) (${fileSizeKB} KB)
**🚀 To install**:
1. Download the artifact from the workflow run above
2. In VSCode: \`Ctrl+Shift+P\` → "Extensions: Install from VSIX..."
3. Select the downloaded \`.vsix\` file
**✅ Tested with**: Node.js 22
**🔍 Validation**: File structure and integrity verified
---
*Built from commit ${{ github.sha }}*`;
// Check if we already commented on this PR and delete it
const { data: comments } = await github.rest.issues.listComments({
owner: context.repo.owner,
repo: context.repo.repo,
issue_number: context.issue.number,
});
const botComment = comments.find(comment =>
comment.user.login === 'github-actions[bot]' &&
comment.body.includes('VSCode Extension Built Successfully')
);
if (botComment) {
await github.rest.issues.deleteComment({
owner: context.repo.owner,
repo: context.repo.repo,
comment_id: botComment.id,
});
}
// Create a new comment
await github.rest.issues.createComment({
owner: context.repo.owner,
repo: context.repo.repo,
issue_number: context.issue.number,
body: comment
});
release:
name: Create GitHub Release
runs-on: blacksmith-4vcpu-ubuntu-2204
needs: validate-vscode-extension
if: startsWith(github.ref, 'refs/tags/ext-v')
steps:
- name: Download .vsix artifact
uses: actions/download-artifact@v4
with:
name: vscode-extension
path: ./
- name: Create Release
uses: ncipollo/release-action@v1.16.0
with:
artifacts: "*.vsix"
token: ${{ secrets.GITHUB_TOKEN }}
draft: true
allowUpdates: true
-51
View File
@@ -13,7 +13,6 @@ STAGED_FILES=$(git diff --cached --name-only)
# Check if any files match specific patterns
has_frontend_changes=false
has_backend_changes=false
has_vscode_changes=false
# Check each file individually to avoid issues with grep
for file in $STAGED_FILES; do
@@ -21,17 +20,12 @@ for file in $STAGED_FILES; do
has_frontend_changes=true
elif [[ $file == openhands/* || $file == evaluation/* || $file == tests/* ]]; then
has_backend_changes=true
# Check for VSCode extension changes (subset of backend changes)
if [[ $file == openhands/integrations/vscode/* ]]; then
has_vscode_changes=true
fi
fi
done
echo "Analyzing changes..."
echo "- Frontend changes: $has_frontend_changes"
echo "- Backend changes: $has_backend_changes"
echo "- VSCode extension changes: $has_vscode_changes"
# Run frontend linting if needed
if [ "$has_frontend_changes" = true ]; then
@@ -92,51 +86,6 @@ else
echo "Skipping backend checks (no backend changes detected)."
fi
# Run VSCode extension checks if needed
if [ "$has_vscode_changes" = true ]; then
# Check if we're in a CI environment
if [ -n "$CI" ]; then
echo "Skipping VSCode extension checks (CI environment detected)."
echo "WARNING: VSCode extension files have changed but checks are being skipped."
echo "Please run VSCode extension checks manually before submitting your PR."
else
echo "Running VSCode extension checks..."
if [ -d "openhands/integrations/vscode" ]; then
cd openhands/integrations/vscode || exit 1
echo "Running npm lint:fix..."
npm run lint:fix
if [ $? -ne 0 ]; then
echo "VSCode extension linting failed. Please fix the issues before committing."
EXIT_CODE=1
else
echo "VSCode extension linting passed!"
fi
echo "Running npm typecheck..."
npm run typecheck
if [ $? -ne 0 ]; then
echo "VSCode extension type checking failed. Please fix the issues before committing."
EXIT_CODE=1
else
echo "VSCode extension type checking passed!"
fi
echo "Running npm compile..."
npm run compile
if [ $? -ne 0 ]; then
echo "VSCode extension compilation failed. Please fix the issues before committing."
EXIT_CODE=1
else
echo "VSCode extension compilation passed!"
fi
cd ../../..
fi
fi
else
echo "Skipping VSCode extension checks (no VSCode extension changes detected)."
fi
# If no specific code changes detected, run basic checks
if [ "$has_frontend_changes" = false ] && [ "$has_backend_changes" = false ]; then
+1 -1
View File
@@ -31,7 +31,7 @@ We're always looking to improve the look and feel of the application. If you've
for something that's bugging you, feel free to open up a PR that changes the [`./frontend`](./frontend) directory.
If you're looking to make a bigger change, add a new UI element, or significantly alter the style
of the application, please open an issue first, or better, join the #eng-ui-ux channel in our Slack
of the application, please open an issue first, or better, join the #dev-ui-ux channel in our Slack
to gather consensus from our design team first.
#### Improving the agent
+1 -1
View File
@@ -161,7 +161,7 @@ poetry run pytest ./tests/unit/test_*.py
To reduce build time (e.g., if no changes were made to the client-runtime component), you can use an existing Docker
container image by setting the SANDBOX_RUNTIME_CONTAINER_IMAGE environment variable to the desired Docker image.
Example: `export SANDBOX_RUNTIME_CONTAINER_IMAGE=ghcr.io/openhands/runtime:1.0-nikolaik`
Example: `export SANDBOX_RUNTIME_CONTAINER_IMAGE=ghcr.io/openhands/runtime:1.1-nikolaik`
## Develop inside Docker container
-113
View File
@@ -1,113 +0,0 @@
import os
import pathlib
import subprocess
# This script is intended to be run by Poetry during the build process.
# Define the expected name of the .vsix file based on the extension's package.json
# This should match the name and version in openhands-vscode/package.json
EXTENSION_NAME = 'openhands-vscode'
EXTENSION_VERSION = '0.0.1'
VSIX_FILENAME = f'{EXTENSION_NAME}-{EXTENSION_VERSION}.vsix'
# Paths
ROOT_DIR = pathlib.Path(__file__).parent.resolve()
VSCODE_EXTENSION_DIR = ROOT_DIR / 'openhands' / 'integrations' / 'vscode'
def check_node_version():
"""Check if Node.js version is sufficient for building the extension."""
try:
result = subprocess.run(
['node', '--version'], capture_output=True, text=True, check=True
)
version_str = result.stdout.strip()
# Extract major version number (e.g., "v12.22.9" -> 12)
major_version = int(version_str.lstrip('v').split('.')[0])
return major_version >= 18 # Align with frontend actual usage (18.20.1)
except (subprocess.CalledProcessError, FileNotFoundError, ValueError):
return False
def build_vscode_extension():
"""Builds the VS Code extension."""
vsix_path = VSCODE_EXTENSION_DIR / VSIX_FILENAME
# Check if VSCode extension build is disabled via environment variable
if os.environ.get('SKIP_VSCODE_BUILD', '').lower() in ('1', 'true', 'yes'):
print('--- Skipping VS Code extension build (SKIP_VSCODE_BUILD is set) ---')
if vsix_path.exists():
print(f'--- Using existing VS Code extension: {vsix_path} ---')
else:
print('--- No pre-built VS Code extension found ---')
return
# Check Node.js version - if insufficient, use pre-built extension as fallback
if not check_node_version():
print('--- Warning: Node.js version < 18 detected or Node.js not found ---')
print('--- Skipping VS Code extension build (requires Node.js >= 18) ---')
print('--- Using pre-built extension if available ---')
if not vsix_path.exists():
print('--- Warning: No pre-built VS Code extension found ---')
print('--- VS Code extension will not be available ---')
else:
print(f'--- Using pre-built VS Code extension: {vsix_path} ---')
return
print(f'--- Building VS Code extension in {VSCODE_EXTENSION_DIR} ---')
try:
# Ensure npm dependencies are installed
print('--- Running npm install for VS Code extension ---')
subprocess.run(
['npm', 'install'],
cwd=VSCODE_EXTENSION_DIR,
check=True,
shell=os.name == 'nt',
)
# Package the extension
print(f'--- Packaging VS Code extension ({VSIX_FILENAME}) ---')
subprocess.run(
['npm', 'run', 'package-vsix'],
cwd=VSCODE_EXTENSION_DIR,
check=True,
shell=os.name == 'nt',
)
# Verify the generated .vsix file exists
if not vsix_path.exists():
raise FileNotFoundError(
f'VS Code extension package not found after build: {vsix_path}'
)
print(f'--- VS Code extension built successfully: {vsix_path} ---')
except subprocess.CalledProcessError as e:
print(f'--- Warning: Failed to build VS Code extension: {e} ---')
print('--- Continuing without building extension ---')
if not vsix_path.exists():
print('--- Warning: No pre-built VS Code extension found ---')
print('--- VS Code extension will not be available ---')
def build(setup_kwargs):
"""This function is called by Poetry during the build process.
`setup_kwargs` is a dictionary that will be passed to `setuptools.setup()`.
"""
print('--- Running custom Poetry build script (build_vscode.py) ---')
# Build the VS Code extension and place the .vsix file
build_vscode_extension()
# Poetry will handle including files based on pyproject.toml `include` patterns.
# Ensure openhands/integrations/vscode/*.vsix is included there.
print('--- Custom Poetry build script (build_vscode.py) finished ---')
if __name__ == '__main__':
print('Running build_vscode.py directly for testing VS Code extension packaging...')
build_vscode_extension()
print('Direct execution of build_vscode.py finished.')
+1 -1
View File
@@ -1,5 +1,5 @@
ARG OPENHANDS_BUILD_VERSION=dev
FROM node:24.8-trixie-slim AS frontend-builder
FROM node:25.2-trixie-slim AS frontend-builder
WORKDIR /app
+1 -1
View File
@@ -12,7 +12,7 @@ services:
- SANDBOX_API_HOSTNAME=host.docker.internal
- DOCKER_HOST_ADDR=host.docker.internal
#
- SANDBOX_RUNTIME_CONTAINER_IMAGE=${SANDBOX_RUNTIME_CONTAINER_IMAGE:-ghcr.io/openhands/runtime:1.0-nikolaik}
- SANDBOX_RUNTIME_CONTAINER_IMAGE=${SANDBOX_RUNTIME_CONTAINER_IMAGE:-ghcr.io/openhands/runtime:1.1-nikolaik}
- SANDBOX_USER_ID=${SANDBOX_USER_ID:-1234}
- WORKSPACE_MOUNT_PATH=${WORKSPACE_BASE:-$PWD/workspace}
ports:
+1 -1
View File
@@ -7,7 +7,7 @@ services:
image: openhands:latest
container_name: openhands-app-${DATE:-}
environment:
- SANDBOX_RUNTIME_CONTAINER_IMAGE=${SANDBOX_RUNTIME_CONTAINER_IMAGE:-docker.openhands.dev/openhands/runtime:1.0-nikolaik}
- SANDBOX_RUNTIME_CONTAINER_IMAGE=${SANDBOX_RUNTIME_CONTAINER_IMAGE:-docker.openhands.dev/openhands/runtime:1.1-nikolaik}
#- SANDBOX_USER_ID=${SANDBOX_USER_ID:-1234} # enable this only if you want a specific non-root sandbox user but you will have to manually adjust permissions of ~/.openhands for this user
- WORKSPACE_MOUNT_PATH=${WORKSPACE_BASE:-$PWD/workspace}
ports:
+1 -1
View File
@@ -50,7 +50,7 @@ First run this to retrieve Github App secrets
```
gcloud auth application-default login
gcloud config set project global-432717
local/decrypt_env.sh
enterprise_local/decrypt_env.sh /path/to/root/of/deploy/repo
```
Now run this to generate a `.env` file, which will used to run SAAS locally
+2 -2
View File
@@ -4,12 +4,12 @@ set -euo pipefail
# Check if DEPLOY_DIR argument was provided
if [ $# -lt 1 ]; then
echo "Usage: $0 <DEPLOY_DIR>"
echo "Example: $0 /path/to/deploy"
echo "Example: $0 /path/to/root/of/deploy/repo"
exit 1
fi
# Normalize path (remove trailing slash)
DEPLOY_DIR="${DEPLOY_DIR%/}"
DEPLOY_DIR="${1%/}"
# Function to decrypt and rename
decrypt_and_move() {
@@ -6,7 +6,7 @@ from datetime import datetime
from enum import Enum
from typing import Any
from github import Github, GithubIntegration
from github import Auth, Github, GithubIntegration
from integrations.github.github_view import (
GithubIssue,
)
@@ -84,7 +84,7 @@ class GitHubDataCollector:
# self.full_saved_pr_path = 'github_data/prs/{}-{}/data.json'
self.full_saved_pr_path = 'prs/github/{}-{}/data.json'
self.github_integration = GithubIntegration(
GITHUB_APP_CLIENT_ID, GITHUB_APP_PRIVATE_KEY
auth=Auth.AppAuth(GITHUB_APP_CLIENT_ID, GITHUB_APP_PRIVATE_KEY)
)
self.conversation_id = None
@@ -143,7 +143,7 @@ class GitHubDataCollector:
try:
installation_token = self._get_installation_access_token(installation_id)
with Github(installation_token) as github_client:
with Github(auth=Auth.Token(installation_token)) as github_client:
repo = github_client.get_repo(repo_name)
issue = repo.get_issue(issue_number)
comments = []
@@ -237,7 +237,7 @@ class GitHubDataCollector:
def _get_pr_commits(self, installation_id: str, repo_name: str, pr_number: int):
commits = []
installation_token = self._get_installation_access_token(installation_id)
with Github(installation_token) as github_client:
with Github(auth=Auth.Token(installation_token)) as github_client:
repo = github_client.get_repo(repo_name)
pr = repo.get_pull(pr_number)
@@ -1,6 +1,6 @@
from types import MappingProxyType
from github import Github, GithubIntegration
from github import Auth, Github, GithubIntegration
from integrations.github.data_collector import GitHubDataCollector
from integrations.github.github_solvability import summarize_issue_solvability
from integrations.github.github_view import (
@@ -21,6 +21,7 @@ from integrations.utils import (
CONVERSATION_URL,
HOST_URL,
OPENHANDS_RESOLVER_TEMPLATES_DIR,
get_session_expired_message,
)
from integrations.v1_utils import get_saas_user_auth
from jinja2 import Environment, FileSystemLoader
@@ -31,7 +32,11 @@ from server.utils.conversation_callback_utils import register_callback_processor
from openhands.core.logger import openhands_logger as logger
from openhands.integrations.provider import ProviderToken, ProviderType
from openhands.server.types import LLMAuthenticationError, MissingSettingsError
from openhands.server.types import (
LLMAuthenticationError,
MissingSettingsError,
SessionExpiredError,
)
from openhands.storage.data_models.secrets import Secrets
from openhands.utils.async_utils import call_sync_from_async
@@ -43,7 +48,7 @@ class GithubManager(Manager):
self.token_manager = token_manager
self.data_collector = data_collector
self.github_integration = GithubIntegration(
GITHUB_APP_CLIENT_ID, GITHUB_APP_PRIVATE_KEY
auth=Auth.AppAuth(GITHUB_APP_CLIENT_ID, GITHUB_APP_PRIVATE_KEY)
)
self.jinja_env = Environment(
@@ -77,7 +82,7 @@ class GithubManager(Manager):
reaction: The reaction to add (e.g. "eyes", "+1", "-1", "laugh", "confused", "heart", "hooray", "rocket")
installation_token: GitHub installation access token for API access
"""
with Github(installation_token) as github_client:
with Github(auth=Auth.Token(installation_token)) as github_client:
repo = github_client.get_repo(github_view.full_repo_name)
# Add reaction based on view type
if isinstance(github_view, GithubInlinePRComment):
@@ -199,7 +204,7 @@ class GithubManager(Manager):
outgoing_message = message.message
if isinstance(github_view, GithubInlinePRComment):
with Github(installation_token) as github_client:
with Github(auth=Auth.Token(installation_token)) as github_client:
repo = github_client.get_repo(github_view.full_repo_name)
pr = repo.get_pull(github_view.issue_number)
pr.create_review_comment_reply(
@@ -211,7 +216,7 @@ class GithubManager(Manager):
or isinstance(github_view, GithubIssueComment)
or isinstance(github_view, GithubIssue)
):
with Github(installation_token) as github_client:
with Github(auth=Auth.Token(installation_token)) as github_client:
repo = github_client.get_repo(github_view.full_repo_name)
issue = repo.get_issue(number=github_view.issue_number)
issue.create_comment(outgoing_message)
@@ -342,6 +347,13 @@ class GithubManager(Manager):
msg_info = f'@{user_info.username} please set a valid LLM API key in [OpenHands Cloud]({HOST_URL}) before starting a job.'
except SessionExpiredError as e:
logger.warning(
f'[GitHub] Session expired for user {user_info.username}: {str(e)}'
)
msg_info = get_session_expired_message(user_info.username)
msg = self.create_outgoing_message(msg_info)
await self.send_message(msg, github_view)
@@ -1,7 +1,7 @@
import asyncio
import time
from github import Github
from github import Auth, Github
from integrations.github.github_view import (
GithubInlinePRComment,
GithubIssueComment,
@@ -47,7 +47,7 @@ def fetch_github_issue_context(
context_parts.append(f'Title: {github_view.title}')
context_parts.append(f'Description:\n{github_view.description}')
with Github(user_token) as github_client:
with Github(auth=Auth.Token(user_token)) as github_client:
repo = github_client.get_repo(github_view.full_repo_name)
issue = repo.get_issue(github_view.issue_number)
if issue.labels:
@@ -1,7 +1,7 @@
from dataclasses import dataclass
from uuid import UUID, uuid4
from github import Github, GithubIntegration
from github import Auth, Github, GithubIntegration
from github.Issue import Issue
from integrations.github.github_types import (
WorkflowRun,
@@ -729,13 +729,13 @@ class GithubFactory:
def _interact_with_github() -> Issue | None:
with GithubIntegration(
GITHUB_APP_CLIENT_ID, GITHUB_APP_PRIVATE_KEY
auth=Auth.AppAuth(GITHUB_APP_CLIENT_ID, GITHUB_APP_PRIVATE_KEY)
) as integration:
access_token = integration.get_access_token(
payload['installation']['id']
).token
with Github(access_token) as gh:
with Github(auth=Auth.Token(access_token)) as gh:
repo = gh.get_repo(selected_repo)
login = (
payload['organization']['login']
@@ -867,12 +867,12 @@ class GithubFactory:
access_token = ''
with GithubIntegration(
GITHUB_APP_CLIENT_ID, GITHUB_APP_PRIVATE_KEY
auth=Auth.AppAuth(GITHUB_APP_CLIENT_ID, GITHUB_APP_PRIVATE_KEY)
) as integration:
access_token = integration.get_access_token(installation_id).token
head_ref = None
with Github(access_token) as gh:
with Github(auth=Auth.Token(access_token)) as gh:
repo = gh.get_repo(selected_repo)
pull_request = repo.get_pull(issue_number)
head_ref = pull_request.head.ref
@@ -15,6 +15,7 @@ from integrations.utils import (
CONVERSATION_URL,
HOST_URL,
OPENHANDS_RESOLVER_TEMPLATES_DIR,
get_session_expired_message,
)
from jinja2 import Environment, FileSystemLoader
from pydantic import SecretStr
@@ -24,7 +25,11 @@ from server.utils.conversation_callback_utils import register_callback_processor
from openhands.core.logger import openhands_logger as logger
from openhands.integrations.gitlab.gitlab_service import GitLabServiceImpl
from openhands.integrations.provider import ProviderToken, ProviderType
from openhands.server.types import LLMAuthenticationError, MissingSettingsError
from openhands.server.types import (
LLMAuthenticationError,
MissingSettingsError,
SessionExpiredError,
)
from openhands.storage.data_models.secrets import Secrets
@@ -249,6 +254,13 @@ class GitlabManager(Manager):
msg_info = f'@{user_info.username} please set a valid LLM API key in [OpenHands Cloud]({HOST_URL}) before starting a job.'
except SessionExpiredError as e:
logger.warning(
f'[GitLab] Session expired for user {user_info.username}: {str(e)}'
)
msg_info = get_session_expired_message(user_info.username)
# Send the acknowledgment message
msg = self.create_outgoing_message(msg_info)
await self.send_message(msg, gitlab_view)
@@ -80,22 +80,52 @@ class SaaSGitLabService(GitLabService):
logger.warning('external_auth_token and user_id not set!')
return gitlab_token
async def get_owned_groups(self) -> list[dict]:
async def get_owned_groups(self, min_access_level: int = 40) -> list[dict]:
"""
Get all groups for which the current user is the owner.
Get all top-level groups where the current user has admin access.
This method supports pagination and fetches all groups where the user has
at least the specified access level.
Args:
min_access_level: Minimum access level required (default: 40 for Maintainer or Owner)
- 40: Maintainer or Owner
- 50: Owner only
Returns:
list[dict]: A list of groups owned by the current user.
list[dict]: A list of groups where user has the specified access level or higher.
"""
url = f'{self.BASE_URL}/groups'
params = {'owned': 'true', 'per_page': 100, 'top_level_only': 'true'}
groups_with_admin_access = []
page = 1
per_page = 100
try:
response, headers = await self._make_request(url, params)
return response
except Exception:
logger.warning('Error fetching owned groups', exc_info=True)
return []
while True:
try:
url = f'{self.BASE_URL}/groups'
params = {
'page': str(page),
'per_page': str(per_page),
'min_access_level': min_access_level,
'top_level_only': 'true',
}
response, headers = await self._make_request(url, params)
if not response:
break
groups_with_admin_access.extend(response)
page += 1
# Check if we've reached the last page
link_header = headers.get('Link', '')
if 'rel="next"' not in link_header:
break
except Exception:
logger.warning(f'Error fetching groups on page {page}', exc_info=True)
break
return groups_with_admin_access
async def add_owned_projects_and_groups_to_db(self, owned_personal_projects):
"""
@@ -527,3 +557,55 @@ class SaaSGitLabService(GitLabService):
await self._make_request(url=url, params=params, method=RequestMethod.POST)
except Exception as e:
logger.exception(f'[GitLab]: Reply to MR failed {e}')
async def get_user_resources_with_admin_access(
self,
) -> tuple[list[dict], list[dict]]:
"""
Get all projects and groups where the current user has admin access (maintainer or owner).
Returns:
tuple[list[dict], list[dict]]: A tuple containing:
- list of projects where user has admin access
- list of groups where user has admin access
"""
projects_with_admin_access = []
groups_with_admin_access = []
# Fetch all projects the user is a member of
page = 1
per_page = 100
while True:
try:
url = f'{self.BASE_URL}/projects'
params = {
'page': str(page),
'per_page': str(per_page),
'membership': 1,
'min_access_level': 40, # Maintainer or Owner
}
response, headers = await self._make_request(url, params)
if not response:
break
projects_with_admin_access.extend(response)
page += 1
# Check if we've reached the last page
link_header = headers.get('Link', '')
if 'rel="next"' not in link_header:
break
except Exception:
logger.warning(f'Error fetching projects on page {page}', exc_info=True)
break
# Fetch all groups where user is owner or maintainer
groups_with_admin_access = await self.get_owned_groups(min_access_level=40)
logger.info(
f'Found {len(projects_with_admin_access)} projects and {len(groups_with_admin_access)} groups with admin access'
)
return projects_with_admin_access, groups_with_admin_access
@@ -0,0 +1,199 @@
"""Shared utilities for GitLab webhook installation.
This module contains reusable functions and classes for installing GitLab webhooks
that can be used by both the cron job and API routes.
"""
from typing import cast
from uuid import uuid4
from integrations.types import GitLabResourceType
from integrations.utils import GITLAB_WEBHOOK_URL
from storage.gitlab_webhook import GitlabWebhook, WebhookStatus
from storage.gitlab_webhook_store import GitlabWebhookStore
from openhands.core.logger import openhands_logger as logger
from openhands.integrations.service_types import GitService
# Webhook configuration constants
WEBHOOK_NAME = 'OpenHands Resolver'
SCOPES: list[str] = [
'note_events',
'merge_requests_events',
'confidential_issues_events',
'issues_events',
'confidential_note_events',
'job_events',
'pipeline_events',
]
class BreakLoopException(Exception):
"""Exception raised when webhook installation conditions are not met or rate limited."""
pass
async def verify_webhook_conditions(
gitlab_service: type[GitService],
resource_type: GitLabResourceType,
resource_id: str,
webhook_store: GitlabWebhookStore,
webhook: GitlabWebhook,
) -> None:
"""
Verify all conditions are met for webhook installation.
Raises BreakLoopException if any condition fails or rate limited.
Args:
gitlab_service: GitLab service instance
resource_type: Type of resource (PROJECT or GROUP)
resource_id: ID of the resource
webhook_store: Webhook store instance
webhook: Webhook object to verify
"""
from integrations.gitlab.gitlab_service import SaaSGitLabService
gitlab_service = cast(type[SaaSGitLabService], gitlab_service)
# Check if resource exists
does_resource_exist, status = await gitlab_service.check_resource_exists(
resource_type, resource_id
)
logger.info(
'Does resource exists',
extra={
'does_resource_exist': does_resource_exist,
'status': status,
'resource_id': resource_id,
'resource_type': resource_type,
},
)
if status == WebhookStatus.RATE_LIMITED:
raise BreakLoopException()
if not does_resource_exist and status != WebhookStatus.RATE_LIMITED:
await webhook_store.delete_webhook(webhook)
raise BreakLoopException()
# Check if user has admin access
(
is_user_admin_of_resource,
status,
) = await gitlab_service.check_user_has_admin_access_to_resource(
resource_type, resource_id
)
logger.info(
'Is user admin',
extra={
'is_user_admin': is_user_admin_of_resource,
'status': status,
'resource_id': resource_id,
'resource_type': resource_type,
},
)
if status == WebhookStatus.RATE_LIMITED:
raise BreakLoopException()
if not is_user_admin_of_resource:
await webhook_store.delete_webhook(webhook)
raise BreakLoopException()
# Check if webhook already exists
(
does_webhook_exist_on_resource,
status,
) = await gitlab_service.check_webhook_exists_on_resource(
resource_type, resource_id, GITLAB_WEBHOOK_URL
)
logger.info(
'Does webhook already exist',
extra={
'does_webhook_exist_on_resource': does_webhook_exist_on_resource,
'status': status,
'resource_id': resource_id,
'resource_type': resource_type,
},
)
if status == WebhookStatus.RATE_LIMITED:
raise BreakLoopException()
if does_webhook_exist_on_resource != webhook.webhook_exists:
await webhook_store.update_webhook(
webhook, {'webhook_exists': does_webhook_exist_on_resource}
)
if does_webhook_exist_on_resource:
raise BreakLoopException()
async def install_webhook_on_resource(
gitlab_service: type[GitService],
resource_type: GitLabResourceType,
resource_id: str,
webhook_store: GitlabWebhookStore,
webhook: GitlabWebhook,
) -> tuple[str | None, WebhookStatus | None]:
"""
Install webhook on a GitLab resource.
Args:
gitlab_service: GitLab service instance
resource_type: Type of resource (PROJECT or GROUP)
resource_id: ID of the resource
webhook_store: Webhook store instance
webhook: Webhook object to install
Returns:
Tuple of (webhook_id, status)
"""
from integrations.gitlab.gitlab_service import SaaSGitLabService
gitlab_service = cast(type[SaaSGitLabService], gitlab_service)
webhook_secret = f'{webhook.user_id}-{str(uuid4())}'
webhook_uuid = f'{str(uuid4())}'
webhook_id, status = await gitlab_service.install_webhook(
resource_type=resource_type,
resource_id=resource_id,
webhook_name=WEBHOOK_NAME,
webhook_url=GITLAB_WEBHOOK_URL,
webhook_secret=webhook_secret,
webhook_uuid=webhook_uuid,
scopes=SCOPES,
)
logger.info(
'Creating new webhook',
extra={
'webhook_id': webhook_id,
'status': status,
'resource_id': resource_id,
'resource_type': resource_type,
},
)
if status == WebhookStatus.RATE_LIMITED:
raise BreakLoopException()
if webhook_id:
await webhook_store.update_webhook(
webhook=webhook,
update_fields={
'webhook_secret': webhook_secret,
'webhook_exists': True, # webhook was created
'webhook_url': GITLAB_WEBHOOK_URL,
'scopes': SCOPES,
'webhook_uuid': webhook_uuid, # required to identify which webhook installation is sending payload
},
)
logger.info(
f'Installed webhook for {webhook.user_id} on {resource_type}:{resource_id}'
)
return webhook_id, status
+10 -1
View File
@@ -17,6 +17,7 @@ from integrations.utils import (
HOST_URL,
OPENHANDS_RESOLVER_TEMPLATES_DIR,
filter_potential_repos_by_user_msg,
get_session_expired_message,
)
from jinja2 import Environment, FileSystemLoader
from server.auth.saas_user_auth import get_user_auth_from_keycloak_id
@@ -30,7 +31,11 @@ from openhands.core.logger import openhands_logger as logger
from openhands.integrations.provider import ProviderHandler
from openhands.integrations.service_types import Repository
from openhands.server.shared import server_config
from openhands.server.types import LLMAuthenticationError, MissingSettingsError
from openhands.server.types import (
LLMAuthenticationError,
MissingSettingsError,
SessionExpiredError,
)
from openhands.server.user_auth.user_auth import UserAuth
from openhands.utils.http_session import httpx_verify_option
@@ -380,6 +385,10 @@ class JiraManager(Manager):
logger.warning(f'[Jira] LLM authentication error: {str(e)}')
msg_info = f'Please set a valid LLM API key in [OpenHands Cloud]({HOST_URL}) before starting a job.'
except SessionExpiredError as e:
logger.warning(f'[Jira] Session expired: {str(e)}')
msg_info = get_session_expired_message()
except Exception as e:
logger.error(
f'[Jira] Unexpected error starting job: {str(e)}', exc_info=True
@@ -19,6 +19,7 @@ from integrations.utils import (
HOST_URL,
OPENHANDS_RESOLVER_TEMPLATES_DIR,
filter_potential_repos_by_user_msg,
get_session_expired_message,
)
from jinja2 import Environment, FileSystemLoader
from server.auth.saas_user_auth import get_user_auth_from_keycloak_id
@@ -32,7 +33,11 @@ from openhands.core.logger import openhands_logger as logger
from openhands.integrations.provider import ProviderHandler
from openhands.integrations.service_types import Repository
from openhands.server.shared import server_config
from openhands.server.types import LLMAuthenticationError, MissingSettingsError
from openhands.server.types import (
LLMAuthenticationError,
MissingSettingsError,
SessionExpiredError,
)
from openhands.server.user_auth.user_auth import UserAuth
from openhands.utils.http_session import httpx_verify_option
@@ -397,6 +402,10 @@ class JiraDcManager(Manager):
logger.warning(f'[Jira DC] LLM authentication error: {str(e)}')
msg_info = f'Please set a valid LLM API key in [OpenHands Cloud]({HOST_URL}) before starting a job.'
except SessionExpiredError as e:
logger.warning(f'[Jira DC] Session expired: {str(e)}')
msg_info = get_session_expired_message()
except Exception as e:
logger.error(
f'[Jira DC] Unexpected error starting job: {str(e)}', exc_info=True
@@ -16,6 +16,7 @@ from integrations.utils import (
HOST_URL,
OPENHANDS_RESOLVER_TEMPLATES_DIR,
filter_potential_repos_by_user_msg,
get_session_expired_message,
)
from jinja2 import Environment, FileSystemLoader
from server.auth.saas_user_auth import get_user_auth_from_keycloak_id
@@ -29,7 +30,11 @@ from openhands.core.logger import openhands_logger as logger
from openhands.integrations.provider import ProviderHandler
from openhands.integrations.service_types import Repository
from openhands.server.shared import server_config
from openhands.server.types import LLMAuthenticationError, MissingSettingsError
from openhands.server.types import (
LLMAuthenticationError,
MissingSettingsError,
SessionExpiredError,
)
from openhands.server.user_auth.user_auth import UserAuth
from openhands.utils.http_session import httpx_verify_option
@@ -387,6 +392,10 @@ class LinearManager(Manager):
logger.warning(f'[Linear] LLM authentication error: {str(e)}')
msg_info = f'Please set a valid LLM API key in [OpenHands Cloud]({HOST_URL}) before starting a job.'
except SessionExpiredError as e:
logger.warning(f'[Linear] Session expired: {str(e)}')
msg_info = get_session_expired_message()
except Exception as e:
logger.error(
f'[Linear] Unexpected error starting job: {str(e)}', exc_info=True
+13 -1
View File
@@ -14,6 +14,7 @@ from integrations.slack.slack_view import (
from integrations.utils import (
HOST_URL,
OPENHANDS_RESOLVER_TEMPLATES_DIR,
get_session_expired_message,
)
from jinja2 import Environment, FileSystemLoader
from pydantic import SecretStr
@@ -29,7 +30,11 @@ from openhands.core.logger import openhands_logger as logger
from openhands.integrations.provider import ProviderHandler
from openhands.integrations.service_types import Repository
from openhands.server.shared import config, server_config
from openhands.server.types import LLMAuthenticationError, MissingSettingsError
from openhands.server.types import (
LLMAuthenticationError,
MissingSettingsError,
SessionExpiredError,
)
from openhands.server.user_auth.user_auth import UserAuth
authorize_url_generator = AuthorizeUrlGenerator(
@@ -352,6 +357,13 @@ class SlackManager(Manager):
msg_info = f'@{user_info.slack_display_name} please set a valid LLM API key in [OpenHands Cloud]({HOST_URL}) before starting a job.'
except SessionExpiredError as e:
logger.warning(
f'[Slack] Session expired for user {user_info.slack_display_name}: {str(e)}'
)
msg_info = get_session_expired_message(user_info.slack_display_name)
except StartingConvoException as e:
msg_info = str(e)
+69 -25
View File
@@ -20,6 +20,7 @@ from openhands.events.action import (
AgentFinishAction,
MessageAction,
)
from openhands.events.event_filter import EventFilter
from openhands.events.event_store_abc import EventStoreABC
from openhands.events.observation.agent import AgentStateChangedObservation
from openhands.integrations.service_types import Repository
@@ -46,6 +47,27 @@ ENABLE_PROACTIVE_CONVERSATION_STARTERS = (
os.getenv('ENABLE_PROACTIVE_CONVERSATION_STARTERS', 'false').lower() == 'true'
)
def get_session_expired_message(username: str | None = None) -> str:
"""Get a user-friendly session expired message.
Used by integrations to notify users when their Keycloak offline session
has expired.
Args:
username: Optional username to mention in the message. If provided,
the message will include @username prefix (used by Git providers
like GitHub, GitLab, Slack). If None, returns a generic message
(used by Jira, Jira DC, Linear).
Returns:
A formatted session expired message
"""
if username:
return f'@{username} your session has expired. Please login again at [OpenHands Cloud]({HOST_URL}) and try again.'
return f'Your session has expired. Please login again at [OpenHands Cloud]({HOST_URL}) and try again.'
# Toggle for solvability report feature
ENABLE_SOLVABILITY_ANALYSIS = (
os.getenv('ENABLE_SOLVABILITY_ANALYSIS', 'false').lower() == 'true'
@@ -203,18 +225,35 @@ def get_summary_for_agent_state(
def get_final_agent_observation(
event_store: EventStoreABC,
) -> list[AgentStateChangedObservation]:
return event_store.get_matching_events(
source=EventSource.ENVIRONMENT,
event_types=(AgentStateChangedObservation,),
limit=1,
reverse=True,
events = list(
event_store.search_events(
filter=EventFilter(
source=EventSource.ENVIRONMENT,
include_types=(AgentStateChangedObservation,),
),
limit=1,
reverse=True,
)
)
result = [e for e in events if isinstance(e, AgentStateChangedObservation)]
assert len(result) == len(events)
return result
def get_last_user_msg(event_store: EventStoreABC) -> list[MessageAction]:
return event_store.get_matching_events(
source=EventSource.USER, event_types=(MessageAction,), limit=1, reverse='true'
events = list(
event_store.search_events(
filter=EventFilter(
source=EventSource.USER,
include_types=(MessageAction,),
),
limit=1,
reverse=True,
)
)
result = [e for e in events if isinstance(e, MessageAction)]
assert len(result) == len(events)
return result
def extract_summary_from_event_store(
@@ -226,18 +265,22 @@ def extract_summary_from_event_store(
conversation_link = CONVERSATION_URL.format(conversation_id)
summary_instruction = get_summary_instruction()
instruction_event: list[MessageAction] = event_store.get_matching_events(
query=json.dumps(summary_instruction),
source=EventSource.USER,
event_types=(MessageAction,),
limit=1,
reverse=True,
instruction_events = list(
event_store.search_events(
filter=EventFilter(
query=json.dumps(summary_instruction),
source=EventSource.USER,
include_types=(MessageAction,),
),
limit=1,
reverse=True,
)
)
final_agent_observation = get_final_agent_observation(event_store)
# Find summary instruction event ID
if len(instruction_event) == 0:
if not instruction_events:
logger.warning(
'no_instruction_event_found', extra={'conversation_id': conversation_id}
)
@@ -245,19 +288,19 @@ def extract_summary_from_event_store(
final_agent_observation, conversation_link
) # Agent did not receive summary instruction
event_id: int = instruction_event[0].id
agent_messages: list[MessageAction | AgentFinishAction] = (
event_store.get_matching_events(
start_id=event_id,
source=EventSource.AGENT,
event_types=(MessageAction, AgentFinishAction),
reverse=True,
summary_events = list(
event_store.search_events(
filter=EventFilter(
source=EventSource.AGENT,
include_types=(MessageAction, AgentFinishAction),
),
limit=1,
reverse=True,
start_id=instruction_events[0].id,
)
)
if len(agent_messages) == 0:
if not summary_events:
logger.warning(
'no_agent_messages_found', extra={'conversation_id': conversation_id}
)
@@ -265,10 +308,11 @@ def extract_summary_from_event_store(
final_agent_observation, conversation_link
) # Agent failed to generate summary
summary_event: MessageAction | AgentFinishAction = agent_messages[0]
summary_event = summary_events[0]
if isinstance(summary_event, MessageAction):
return summary_event.content
assert isinstance(summary_event, AgentFinishAction)
return summary_event.final_thought
@@ -321,7 +365,7 @@ def append_conversation_footer(message: str, conversation_id: str) -> str:
The message with the conversation footer appended
"""
conversation_link = CONVERSATION_URL.format(conversation_id)
footer = f'\n\n<sub>[View full conversation]({conversation_link})</sub>'
footer = f'\n\n[View full conversation]({conversation_link})'
return message + footer
@@ -0,0 +1,41 @@
"""add public column to conversation_metadata
Revision ID: 085
Revises: 084
Create Date: 2025-01-27 00:00:00.000000
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = '085'
down_revision: Union[str, None] = '084'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema."""
op.add_column(
'conversation_metadata',
sa.Column('public', sa.Boolean(), nullable=True),
)
op.create_index(
op.f('ix_conversation_metadata_public'),
'conversation_metadata',
['public'],
unique=False,
)
def downgrade() -> None:
"""Downgrade schema."""
op.drop_index(
op.f('ix_conversation_metadata_public'),
table_name='conversation_metadata',
)
op.drop_column('conversation_metadata', 'public')
+159 -138
View File
@@ -4517,14 +4517,14 @@ dev = ["Sphinx (>=5.1.1)", "black (==24.8.0)", "build (>=0.10.0)", "coverage[tom
[[package]]
name = "libtmux"
version = "0.46.2"
version = "0.53.0"
description = "Typed library that provides an ORM wrapper for tmux, a terminal multiplexer."
optional = false
python-versions = "<4.0,>=3.9"
python-versions = "<4.0,>=3.10"
groups = ["main"]
files = [
{file = "libtmux-0.46.2-py3-none-any.whl", hash = "sha256:6c32dbf22bde8e5e33b2714a4295f6e838dc640f337cd4c085a044f6828c7793"},
{file = "libtmux-0.46.2.tar.gz", hash = "sha256:9a398fec5d714129c8344555d466e1a903dfc0f741ba07aabe75a8ceb25c5dda"},
{file = "libtmux-0.53.0-py3-none-any.whl", hash = "sha256:024b7ae6a12aae55358e8feb914c8632b3ab9bd61c0987c53559643c6a58ee4f"},
{file = "libtmux-0.53.0.tar.gz", hash = "sha256:1d19af4cea0c19543954d7e7317c7025c0739b029cccbe3b843212fae238f1bd"},
]
[[package]]
@@ -4558,25 +4558,25 @@ valkey = ["valkey (>=6)"]
[[package]]
name = "litellm"
version = "1.80.7"
version = "1.80.11"
description = "Library to easily interface with LLM API providers"
optional = false
python-versions = "<4.0,>=3.9"
groups = ["main"]
files = [
{file = "litellm-1.80.7-py3-none-any.whl", hash = "sha256:f7d993f78c1e0e4e1202b2a925cc6540b55b6e5fb055dd342d88b145ab3102ed"},
{file = "litellm-1.80.7.tar.gz", hash = "sha256:3977a8d195aef842d01c18bf9e22984829363c6a4b54daf9a43c9dd9f190b42c"},
{file = "litellm-1.80.11-py3-none-any.whl", hash = "sha256:406283d66ead77dc7ff0e0b2559c80e9e497d8e7c2257efb1cb9210a20d09d54"},
{file = "litellm-1.80.11.tar.gz", hash = "sha256:c9fc63e7acb6360363238fe291bcff1488c59ff66020416d8376c0ee56414a19"},
]
[package.dependencies]
aiohttp = ">=3.10"
click = "*"
fastuuid = ">=0.13.0"
grpcio = ">=1.62.3,<1.68.0"
grpcio = {version = ">=1.62.3,<1.68.0", markers = "python_version < \"3.14\""}
httpx = ">=0.23.0"
importlib-metadata = ">=6.8.0"
jinja2 = ">=3.1.2,<4.0.0"
jsonschema = ">=4.22.0,<5.0.0"
jsonschema = ">=4.23.0,<5.0.0"
openai = ">=2.8.0"
pydantic = ">=2.5.0,<3.0.0"
python-dotenv = ">=0.2.0"
@@ -4587,7 +4587,7 @@ tokenizers = "*"
caching = ["diskcache (>=5.6.1,<6.0.0)"]
extra-proxy = ["azure-identity (>=1.15.0,<2.0.0) ; python_version >= \"3.9\"", "azure-keyvault-secrets (>=4.8.0,<5.0.0)", "google-cloud-iam (>=2.19.1,<3.0.0)", "google-cloud-kms (>=2.21.3,<3.0.0)", "prisma (==0.11.0)", "redisvl (>=0.4.1,<0.5.0) ; python_version >= \"3.9\" and python_version < \"3.14\"", "resend (>=0.8.0)"]
mlflow = ["mlflow (>3.1.4) ; python_version >= \"3.10\""]
proxy = ["PyJWT (>=2.10.1,<3.0.0) ; python_version >= \"3.9\"", "apscheduler (>=3.10.4,<4.0.0)", "azure-identity (>=1.15.0,<2.0.0) ; python_version >= \"3.9\"", "azure-storage-blob (>=12.25.1,<13.0.0)", "backoff", "boto3 (==1.36.0)", "cryptography", "fastapi (>=0.120.1)", "fastapi-sso (>=0.16.0,<0.17.0)", "gunicorn (>=23.0.0,<24.0.0)", "litellm-enterprise (==0.1.22)", "litellm-proxy-extras (==0.4.9)", "mcp (>=1.21.2,<2.0.0) ; python_version >= \"3.10\"", "orjson (>=3.9.7,<4.0.0)", "polars (>=1.31.0,<2.0.0) ; python_version >= \"3.10\"", "pynacl (>=1.5.0,<2.0.0)", "python-multipart (>=0.0.18,<0.0.19)", "pyyaml (>=6.0.1,<7.0.0)", "rich (==13.7.1)", "rq", "soundfile (>=0.12.1,<0.13.0)", "uvicorn (>=0.31.1,<0.32.0)", "uvloop (>=0.21.0,<0.22.0) ; sys_platform != \"win32\"", "websockets (>=15.0.1,<16.0.0)"]
proxy = ["PyJWT (>=2.10.1,<3.0.0) ; python_version >= \"3.9\"", "apscheduler (>=3.10.4,<4.0.0)", "azure-identity (>=1.15.0,<2.0.0) ; python_version >= \"3.9\"", "azure-storage-blob (>=12.25.1,<13.0.0)", "backoff", "boto3 (==1.36.0)", "cryptography", "fastapi (>=0.120.1)", "fastapi-sso (>=0.16.0,<0.17.0)", "gunicorn (>=23.0.0,<24.0.0)", "litellm-enterprise (==0.1.27)", "litellm-proxy-extras (==0.4.16)", "mcp (>=1.21.2,<2.0.0) ; python_version >= \"3.10\"", "orjson (>=3.9.7,<4.0.0)", "polars (>=1.31.0,<2.0.0) ; python_version >= \"3.10\"", "pynacl (>=1.5.0,<2.0.0)", "python-multipart (>=0.0.18,<0.0.19)", "pyyaml (>=6.0.1,<7.0.0)", "rich (==13.7.1)", "rq", "soundfile (>=0.12.1,<0.13.0)", "uvicorn (>=0.31.1,<0.32.0)", "uvloop (>=0.21.0,<0.22.0) ; sys_platform != \"win32\"", "websockets (>=15.0.1,<16.0.0)"]
semantic-router = ["semantic-router (>=0.1.12) ; python_version >= \"3.9\" and python_version < \"3.14\""]
utils = ["numpydoc"]
@@ -5836,14 +5836,14 @@ llama = ["llama-index (>=0.12.29,<0.13.0)", "llama-index-core (>=0.12.29,<0.13.0
[[package]]
name = "openhands-agent-server"
version = "1.6.0"
version = "1.7.4"
description = "OpenHands Agent Server - REST/WebSocket interface for OpenHands AI Agent"
optional = false
python-versions = ">=3.12"
groups = ["main"]
files = [
{file = "openhands_agent_server-1.6.0-py3-none-any.whl", hash = "sha256:e6ae865ac3e7a96b234e10a0faad23f6210e025bbf7721cb66bc7a71d160848c"},
{file = "openhands_agent_server-1.6.0.tar.gz", hash = "sha256:44ce7694ae2d4bb0666d318ef13e6618bd4dc73022c60354839fe6130e67d02a"},
{file = "openhands_agent_server-1.7.4-py3-none-any.whl", hash = "sha256:997b3dc5243a1ba105f5bd9b0b5bc0cd590c5aa79cd609f23f841218e5f77393"},
{file = "openhands_agent_server-1.7.4.tar.gz", hash = "sha256:0491cf2a5d596610364cbbe9360412bc10a66ae71c0466ab64fd264826e6f1d8"},
]
[package.dependencies]
@@ -5860,7 +5860,7 @@ wsproto = ">=1.2.0"
[[package]]
name = "openhands-ai"
version = "0.0.0-post.5687+7853b41ad"
version = "0.0.0-post.5803+a8098505c"
description = "OpenHands: Code Less, Make More"
optional = false
python-versions = "^3.12,<3.14"
@@ -5896,15 +5896,15 @@ json-repair = "*"
jupyter_kernel_gateway = "*"
kubernetes = "^33.1.0"
libtmux = ">=0.46.2"
litellm = ">=1.74.3, <=1.80.7, !=1.64.4, !=1.67.*"
litellm = ">=1.74.3, !=1.64.4, !=1.67.*"
lmnr = "^0.7.20"
memory-profiler = "^0.61.0"
numpy = "*"
openai = "2.8.0"
openhands-aci = "0.3.2"
openhands-agent-server = "1.6.0"
openhands-sdk = "1.6.0"
openhands-tools = "1.6.0"
openhands-agent-server = "1.7.4"
openhands-sdk = "1.7.4"
openhands-tools = "1.7.4"
opentelemetry-api = "^1.33.1"
opentelemetry-exporter-otlp-proto-grpc = "^1.33.1"
pathspec = "^0.12.1"
@@ -5921,7 +5921,6 @@ pygithub = "^2.5.0"
pyjwt = "^2.9.0"
pylatexenc = "*"
pypdf = "^6.0.0"
PyPDF2 = "*"
python-docx = "*"
python-dotenv = "*"
python-frontmatter = "^1.1.0"
@@ -5960,23 +5959,23 @@ url = ".."
[[package]]
name = "openhands-sdk"
version = "1.6.0"
version = "1.7.4"
description = "OpenHands SDK - Core functionality for building AI agents"
optional = false
python-versions = ">=3.12"
groups = ["main"]
files = [
{file = "openhands_sdk-1.6.0-py3-none-any.whl", hash = "sha256:94d2f87fb35406373da6728ae2d88584137f9e9b67fa0e940444c72f2e44e7d3"},
{file = "openhands_sdk-1.6.0.tar.gz", hash = "sha256:f45742350e3874a7f5b08befc4a9d5adc7e4454f7ab5f8391c519eee3116090f"},
{file = "openhands_sdk-1.7.4-py3-none-any.whl", hash = "sha256:b57511a0467bd3fa64e8cccb7e8026f95e10ee7c5b148335eaa762a32aad8369"},
{file = "openhands_sdk-1.7.4.tar.gz", hash = "sha256:f8e63f996a13d2ea41447384b77a4ffebeb9e85aa54fafcf584f97f7cdc2cd9b"},
]
[package.dependencies]
deprecation = ">=2.1.0"
fastmcp = ">=2.11.3"
httpx = ">=0.27.0"
litellm = ">=1.80.7"
litellm = ">=1.80.10"
lmnr = ">=0.7.24"
pydantic = ">=2.11.7"
pydantic = ">=2.12.5"
python-frontmatter = ">=1.1.0"
python-json-logger = ">=3.3.0"
tenacity = ">=9.1.2"
@@ -5987,14 +5986,14 @@ boto3 = ["boto3 (>=1.35.0)"]
[[package]]
name = "openhands-tools"
version = "1.6.0"
version = "1.7.4"
description = "OpenHands Tools - Runtime tools for AI agents"
optional = false
python-versions = ">=3.12"
groups = ["main"]
files = [
{file = "openhands_tools-1.6.0-py3-none-any.whl", hash = "sha256:176556d44186536751b23fe052d3505492cc2afb8d52db20fb7a2cc0169cd57a"},
{file = "openhands_tools-1.6.0.tar.gz", hash = "sha256:d07ba31050fd4a7891a4c48388aa53ce9f703e17064ddbd59146d6c77e5980b3"},
{file = "openhands_tools-1.7.4-py3-none-any.whl", hash = "sha256:b6a9b04bc59610087d6df789054c966df176c16371fc9c0b0f333ba09f5710d1"},
{file = "openhands_tools-1.7.4.tar.gz", hash = "sha256:776b570da0e86ae48c7815e9adb3839e953e2f4cab7295184ce15849348c52e7"},
]
[package.dependencies]
@@ -6003,7 +6002,7 @@ binaryornot = ">=0.4.4"
browser-use = ">=0.8.0"
cachetools = "*"
func-timeout = ">=4.3.5"
libtmux = ">=0.46.2"
libtmux = ">=0.53.0"
openhands-sdk = "*"
pydantic = ">=2.11.7"
tom-swe = ">=1.0.3"
@@ -7255,22 +7254,22 @@ markers = {test = "platform_python_implementation == \"CPython\" and sys_platfor
[[package]]
name = "pydantic"
version = "2.11.7"
version = "2.12.5"
description = "Data validation using Python type hints"
optional = false
python-versions = ">=3.9"
groups = ["main", "test"]
files = [
{file = "pydantic-2.11.7-py3-none-any.whl", hash = "sha256:dde5df002701f6de26248661f6835bbe296a47bf73990135c7d07ce741b9623b"},
{file = "pydantic-2.11.7.tar.gz", hash = "sha256:d989c3c6cb79469287b1569f7447a17848c998458d49ebe294e975b9baf0f0db"},
{file = "pydantic-2.12.5-py3-none-any.whl", hash = "sha256:e561593fccf61e8a20fc46dfc2dfe075b8be7d0188df33f221ad1f0139180f9d"},
{file = "pydantic-2.12.5.tar.gz", hash = "sha256:4d351024c75c0f085a9febbb665ce8c0c6ec5d30e903bdb6394b7ede26aebb49"},
]
[package.dependencies]
annotated-types = ">=0.6.0"
email-validator = {version = ">=2.0.0", optional = true, markers = "extra == \"email\""}
pydantic-core = "2.33.2"
typing-extensions = ">=4.12.2"
typing-inspection = ">=0.4.0"
pydantic-core = "2.41.5"
typing-extensions = ">=4.14.1"
typing-inspection = ">=0.4.2"
[package.extras]
email = ["email-validator (>=2.0.0)"]
@@ -7278,115 +7277,137 @@ timezone = ["tzdata ; python_version >= \"3.9\" and platform_system == \"Windows
[[package]]
name = "pydantic-core"
version = "2.33.2"
version = "2.41.5"
description = "Core functionality for Pydantic validation and serialization"
optional = false
python-versions = ">=3.9"
groups = ["main", "test"]
files = [
{file = "pydantic_core-2.33.2-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:2b3d326aaef0c0399d9afffeb6367d5e26ddc24d351dbc9c636840ac355dc5d8"},
{file = "pydantic_core-2.33.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:0e5b2671f05ba48b94cb90ce55d8bdcaaedb8ba00cc5359f6810fc918713983d"},
{file = "pydantic_core-2.33.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0069c9acc3f3981b9ff4cdfaf088e98d83440a4c7ea1bc07460af3d4dc22e72d"},
{file = "pydantic_core-2.33.2-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d53b22f2032c42eaaf025f7c40c2e3b94568ae077a606f006d206a463bc69572"},
{file = "pydantic_core-2.33.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:0405262705a123b7ce9f0b92f123334d67b70fd1f20a9372b907ce1080c7ba02"},
{file = "pydantic_core-2.33.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4b25d91e288e2c4e0662b8038a28c6a07eaac3e196cfc4ff69de4ea3db992a1b"},
{file = "pydantic_core-2.33.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6bdfe4b3789761f3bcb4b1ddf33355a71079858958e3a552f16d5af19768fef2"},
{file = "pydantic_core-2.33.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:efec8db3266b76ef9607c2c4c419bdb06bf335ae433b80816089ea7585816f6a"},
{file = "pydantic_core-2.33.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:031c57d67ca86902726e0fae2214ce6770bbe2f710dc33063187a68744a5ecac"},
{file = "pydantic_core-2.33.2-cp310-cp310-musllinux_1_1_armv7l.whl", hash = "sha256:f8de619080e944347f5f20de29a975c2d815d9ddd8be9b9b7268e2e3ef68605a"},
{file = "pydantic_core-2.33.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:73662edf539e72a9440129f231ed3757faab89630d291b784ca99237fb94db2b"},
{file = "pydantic_core-2.33.2-cp310-cp310-win32.whl", hash = "sha256:0a39979dcbb70998b0e505fb1556a1d550a0781463ce84ebf915ba293ccb7e22"},
{file = "pydantic_core-2.33.2-cp310-cp310-win_amd64.whl", hash = "sha256:b0379a2b24882fef529ec3b4987cb5d003b9cda32256024e6fe1586ac45fc640"},
{file = "pydantic_core-2.33.2-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:4c5b0a576fb381edd6d27f0a85915c6daf2f8138dc5c267a57c08a62900758c7"},
{file = "pydantic_core-2.33.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e799c050df38a639db758c617ec771fd8fb7a5f8eaaa4b27b101f266b216a246"},
{file = "pydantic_core-2.33.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dc46a01bf8d62f227d5ecee74178ffc448ff4e5197c756331f71efcc66dc980f"},
{file = "pydantic_core-2.33.2-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:a144d4f717285c6d9234a66778059f33a89096dfb9b39117663fd8413d582dcc"},
{file = "pydantic_core-2.33.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:73cf6373c21bc80b2e0dc88444f41ae60b2f070ed02095754eb5a01df12256de"},
{file = "pydantic_core-2.33.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3dc625f4aa79713512d1976fe9f0bc99f706a9dee21dfd1810b4bbbf228d0e8a"},
{file = "pydantic_core-2.33.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:881b21b5549499972441da4758d662aeea93f1923f953e9cbaff14b8b9565aef"},
{file = "pydantic_core-2.33.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:bdc25f3681f7b78572699569514036afe3c243bc3059d3942624e936ec93450e"},
{file = "pydantic_core-2.33.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:fe5b32187cbc0c862ee201ad66c30cf218e5ed468ec8dc1cf49dec66e160cc4d"},
{file = "pydantic_core-2.33.2-cp311-cp311-musllinux_1_1_armv7l.whl", hash = "sha256:bc7aee6f634a6f4a95676fcb5d6559a2c2a390330098dba5e5a5f28a2e4ada30"},
{file = "pydantic_core-2.33.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:235f45e5dbcccf6bd99f9f472858849f73d11120d76ea8707115415f8e5ebebf"},
{file = "pydantic_core-2.33.2-cp311-cp311-win32.whl", hash = "sha256:6368900c2d3ef09b69cb0b913f9f8263b03786e5b2a387706c5afb66800efd51"},
{file = "pydantic_core-2.33.2-cp311-cp311-win_amd64.whl", hash = "sha256:1e063337ef9e9820c77acc768546325ebe04ee38b08703244c1309cccc4f1bab"},
{file = "pydantic_core-2.33.2-cp311-cp311-win_arm64.whl", hash = "sha256:6b99022f1d19bc32a4c2a0d544fc9a76e3be90f0b3f4af413f87d38749300e65"},
{file = "pydantic_core-2.33.2-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:a7ec89dc587667f22b6a0b6579c249fca9026ce7c333fc142ba42411fa243cdc"},
{file = "pydantic_core-2.33.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:3c6db6e52c6d70aa0d00d45cdb9b40f0433b96380071ea80b09277dba021ddf7"},
{file = "pydantic_core-2.33.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4e61206137cbc65e6d5256e1166f88331d3b6238e082d9f74613b9b765fb9025"},
{file = "pydantic_core-2.33.2-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:eb8c529b2819c37140eb51b914153063d27ed88e3bdc31b71198a198e921e011"},
{file = "pydantic_core-2.33.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c52b02ad8b4e2cf14ca7b3d918f3eb0ee91e63b3167c32591e57c4317e134f8f"},
{file = "pydantic_core-2.33.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:96081f1605125ba0855dfda83f6f3df5ec90c61195421ba72223de35ccfb2f88"},
{file = "pydantic_core-2.33.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f57a69461af2a5fa6e6bbd7a5f60d3b7e6cebb687f55106933188e79ad155c1"},
{file = "pydantic_core-2.33.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:572c7e6c8bb4774d2ac88929e3d1f12bc45714ae5ee6d9a788a9fb35e60bb04b"},
{file = "pydantic_core-2.33.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:db4b41f9bd95fbe5acd76d89920336ba96f03e149097365afe1cb092fceb89a1"},
{file = "pydantic_core-2.33.2-cp312-cp312-musllinux_1_1_armv7l.whl", hash = "sha256:fa854f5cf7e33842a892e5c73f45327760bc7bc516339fda888c75ae60edaeb6"},
{file = "pydantic_core-2.33.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:5f483cfb75ff703095c59e365360cb73e00185e01aaea067cd19acffd2ab20ea"},
{file = "pydantic_core-2.33.2-cp312-cp312-win32.whl", hash = "sha256:9cb1da0f5a471435a7bc7e439b8a728e8b61e59784b2af70d7c169f8dd8ae290"},
{file = "pydantic_core-2.33.2-cp312-cp312-win_amd64.whl", hash = "sha256:f941635f2a3d96b2973e867144fde513665c87f13fe0e193c158ac51bfaaa7b2"},
{file = "pydantic_core-2.33.2-cp312-cp312-win_arm64.whl", hash = "sha256:cca3868ddfaccfbc4bfb1d608e2ccaaebe0ae628e1416aeb9c4d88c001bb45ab"},
{file = "pydantic_core-2.33.2-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:1082dd3e2d7109ad8b7da48e1d4710c8d06c253cbc4a27c1cff4fbcaa97a9e3f"},
{file = "pydantic_core-2.33.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:f517ca031dfc037a9c07e748cefd8d96235088b83b4f4ba8939105d20fa1dcd6"},
{file = "pydantic_core-2.33.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0a9f2c9dd19656823cb8250b0724ee9c60a82f3cdf68a080979d13092a3b0fef"},
{file = "pydantic_core-2.33.2-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:2b0a451c263b01acebe51895bfb0e1cc842a5c666efe06cdf13846c7418caa9a"},
{file = "pydantic_core-2.33.2-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1ea40a64d23faa25e62a70ad163571c0b342b8bf66d5fa612ac0dec4f069d916"},
{file = "pydantic_core-2.33.2-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0fb2d542b4d66f9470e8065c5469ec676978d625a8b7a363f07d9a501a9cb36a"},
{file = "pydantic_core-2.33.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9fdac5d6ffa1b5a83bca06ffe7583f5576555e6c8b3a91fbd25ea7780f825f7d"},
{file = "pydantic_core-2.33.2-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:04a1a413977ab517154eebb2d326da71638271477d6ad87a769102f7c2488c56"},
{file = "pydantic_core-2.33.2-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:c8e7af2f4e0194c22b5b37205bfb293d166a7344a5b0d0eaccebc376546d77d5"},
{file = "pydantic_core-2.33.2-cp313-cp313-musllinux_1_1_armv7l.whl", hash = "sha256:5c92edd15cd58b3c2d34873597a1e20f13094f59cf88068adb18947df5455b4e"},
{file = "pydantic_core-2.33.2-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:65132b7b4a1c0beded5e057324b7e16e10910c106d43675d9bd87d4f38dde162"},
{file = "pydantic_core-2.33.2-cp313-cp313-win32.whl", hash = "sha256:52fb90784e0a242bb96ec53f42196a17278855b0f31ac7c3cc6f5c1ec4811849"},
{file = "pydantic_core-2.33.2-cp313-cp313-win_amd64.whl", hash = "sha256:c083a3bdd5a93dfe480f1125926afcdbf2917ae714bdb80b36d34318b2bec5d9"},
{file = "pydantic_core-2.33.2-cp313-cp313-win_arm64.whl", hash = "sha256:e80b087132752f6b3d714f041ccf74403799d3b23a72722ea2e6ba2e892555b9"},
{file = "pydantic_core-2.33.2-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:61c18fba8e5e9db3ab908620af374db0ac1baa69f0f32df4f61ae23f15e586ac"},
{file = "pydantic_core-2.33.2-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95237e53bb015f67b63c91af7518a62a8660376a6a0db19b89acc77a4d6199f5"},
{file = "pydantic_core-2.33.2-cp313-cp313t-win_amd64.whl", hash = "sha256:c2fc0a768ef76c15ab9238afa6da7f69895bb5d1ee83aeea2e3509af4472d0b9"},
{file = "pydantic_core-2.33.2-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:a2b911a5b90e0374d03813674bf0a5fbbb7741570dcd4b4e85a2e48d17def29d"},
{file = "pydantic_core-2.33.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:6fa6dfc3e4d1f734a34710f391ae822e0a8eb8559a85c6979e14e65ee6ba2954"},
{file = "pydantic_core-2.33.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c54c939ee22dc8e2d545da79fc5381f1c020d6d3141d3bd747eab59164dc89fb"},
{file = "pydantic_core-2.33.2-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:53a57d2ed685940a504248187d5685e49eb5eef0f696853647bf37c418c538f7"},
{file = "pydantic_core-2.33.2-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:09fb9dd6571aacd023fe6aaca316bd01cf60ab27240d7eb39ebd66a3a15293b4"},
{file = "pydantic_core-2.33.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0e6116757f7959a712db11f3e9c0a99ade00a5bbedae83cb801985aa154f071b"},
{file = "pydantic_core-2.33.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8d55ab81c57b8ff8548c3e4947f119551253f4e3787a7bbc0b6b3ca47498a9d3"},
{file = "pydantic_core-2.33.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:c20c462aa4434b33a2661701b861604913f912254e441ab8d78d30485736115a"},
{file = "pydantic_core-2.33.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:44857c3227d3fb5e753d5fe4a3420d6376fa594b07b621e220cd93703fe21782"},
{file = "pydantic_core-2.33.2-cp39-cp39-musllinux_1_1_armv7l.whl", hash = "sha256:eb9b459ca4df0e5c87deb59d37377461a538852765293f9e6ee834f0435a93b9"},
{file = "pydantic_core-2.33.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:9fcd347d2cc5c23b06de6d3b7b8275be558a0c90549495c699e379a80bf8379e"},
{file = "pydantic_core-2.33.2-cp39-cp39-win32.whl", hash = "sha256:83aa99b1285bc8f038941ddf598501a86f1536789740991d7d8756e34f1e74d9"},
{file = "pydantic_core-2.33.2-cp39-cp39-win_amd64.whl", hash = "sha256:f481959862f57f29601ccced557cc2e817bce7533ab8e01a797a48b49c9692b3"},
{file = "pydantic_core-2.33.2-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:5c4aa4e82353f65e548c476b37e64189783aa5384903bfea4f41580f255fddfa"},
{file = "pydantic_core-2.33.2-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:d946c8bf0d5c24bf4fe333af284c59a19358aa3ec18cb3dc4370080da1e8ad29"},
{file = "pydantic_core-2.33.2-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:87b31b6846e361ef83fedb187bb5b4372d0da3f7e28d85415efa92d6125d6e6d"},
{file = "pydantic_core-2.33.2-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aa9d91b338f2df0508606f7009fde642391425189bba6d8c653afd80fd6bb64e"},
{file = "pydantic_core-2.33.2-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2058a32994f1fde4ca0480ab9d1e75a0e8c87c22b53a3ae66554f9af78f2fe8c"},
{file = "pydantic_core-2.33.2-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:0e03262ab796d986f978f79c943fc5f620381be7287148b8010b4097f79a39ec"},
{file = "pydantic_core-2.33.2-pp310-pypy310_pp73-musllinux_1_1_armv7l.whl", hash = "sha256:1a8695a8d00c73e50bff9dfda4d540b7dee29ff9b8053e38380426a85ef10052"},
{file = "pydantic_core-2.33.2-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:fa754d1850735a0b0e03bcffd9d4b4343eb417e47196e4485d9cca326073a42c"},
{file = "pydantic_core-2.33.2-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:a11c8d26a50bfab49002947d3d237abe4d9e4b5bdc8846a63537b6488e197808"},
{file = "pydantic_core-2.33.2-pp311-pypy311_pp73-macosx_10_12_x86_64.whl", hash = "sha256:dd14041875d09cc0f9308e37a6f8b65f5585cf2598a53aa0123df8b129d481f8"},
{file = "pydantic_core-2.33.2-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:d87c561733f66531dced0da6e864f44ebf89a8fba55f31407b00c2f7f9449593"},
{file = "pydantic_core-2.33.2-pp311-pypy311_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2f82865531efd18d6e07a04a17331af02cb7a651583c418df8266f17a63c6612"},
{file = "pydantic_core-2.33.2-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2bfb5112df54209d820d7bf9317c7a6c9025ea52e49f46b6a2060104bba37de7"},
{file = "pydantic_core-2.33.2-pp311-pypy311_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:64632ff9d614e5eecfb495796ad51b0ed98c453e447a76bcbeeb69615079fc7e"},
{file = "pydantic_core-2.33.2-pp311-pypy311_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:f889f7a40498cc077332c7ab6b4608d296d852182211787d4f3ee377aaae66e8"},
{file = "pydantic_core-2.33.2-pp311-pypy311_pp73-musllinux_1_1_armv7l.whl", hash = "sha256:de4b83bb311557e439b9e186f733f6c645b9417c84e2eb8203f3f820a4b988bf"},
{file = "pydantic_core-2.33.2-pp311-pypy311_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:82f68293f055f51b51ea42fafc74b6aad03e70e191799430b90c13d643059ebb"},
{file = "pydantic_core-2.33.2-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:329467cecfb529c925cf2bbd4d60d2c509bc2fb52a20c1045bf09bb70971a9c1"},
{file = "pydantic_core-2.33.2-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:87acbfcf8e90ca885206e98359d7dca4bcbb35abdc0ff66672a293e1d7a19101"},
{file = "pydantic_core-2.33.2-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:7f92c15cd1e97d4b12acd1cc9004fa092578acfa57b67ad5e43a197175d01a64"},
{file = "pydantic_core-2.33.2-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d3f26877a748dc4251cfcfda9dfb5f13fcb034f5308388066bcfe9031b63ae7d"},
{file = "pydantic_core-2.33.2-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dac89aea9af8cd672fa7b510e7b8c33b0bba9a43186680550ccf23020f32d535"},
{file = "pydantic_core-2.33.2-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:970919794d126ba8645f3837ab6046fb4e72bbc057b3709144066204c19a455d"},
{file = "pydantic_core-2.33.2-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:3eb3fe62804e8f859c49ed20a8451342de53ed764150cb14ca71357c765dc2a6"},
{file = "pydantic_core-2.33.2-pp39-pypy39_pp73-musllinux_1_1_armv7l.whl", hash = "sha256:3abcd9392a36025e3bd55f9bd38d908bd17962cc49bc6da8e7e96285336e2bca"},
{file = "pydantic_core-2.33.2-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:3a1c81334778f9e3af2f8aeb7a960736e5cab1dfebfb26aabca09afd2906c039"},
{file = "pydantic_core-2.33.2-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:2807668ba86cb38c6817ad9bc66215ab8584d1d304030ce4f0887336f28a5e27"},
{file = "pydantic_core-2.33.2.tar.gz", hash = "sha256:7cb8bc3605c29176e1b105350d2e6474142d7c1bd1d9327c4a9bdb46bf827acc"},
{file = "pydantic_core-2.41.5-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:77b63866ca88d804225eaa4af3e664c5faf3568cea95360d21f4725ab6e07146"},
{file = "pydantic_core-2.41.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:dfa8a0c812ac681395907e71e1274819dec685fec28273a28905df579ef137e2"},
{file = "pydantic_core-2.41.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5921a4d3ca3aee735d9fd163808f5e8dd6c6972101e4adbda9a4667908849b97"},
{file = "pydantic_core-2.41.5-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e25c479382d26a2a41b7ebea1043564a937db462816ea07afa8a44c0866d52f9"},
{file = "pydantic_core-2.41.5-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f547144f2966e1e16ae626d8ce72b4cfa0caedc7fa28052001c94fb2fcaa1c52"},
{file = "pydantic_core-2.41.5-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6f52298fbd394f9ed112d56f3d11aabd0d5bd27beb3084cc3d8ad069483b8941"},
{file = "pydantic_core-2.41.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:100baa204bb412b74fe285fb0f3a385256dad1d1879f0a5cb1499ed2e83d132a"},
{file = "pydantic_core-2.41.5-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:05a2c8852530ad2812cb7914dc61a1125dc4e06252ee98e5638a12da6cc6fb6c"},
{file = "pydantic_core-2.41.5-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:29452c56df2ed968d18d7e21f4ab0ac55e71dc59524872f6fc57dcf4a3249ed2"},
{file = "pydantic_core-2.41.5-cp310-cp310-musllinux_1_1_armv7l.whl", hash = "sha256:d5160812ea7a8a2ffbe233d8da666880cad0cbaf5d4de74ae15c313213d62556"},
{file = "pydantic_core-2.41.5-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:df3959765b553b9440adfd3c795617c352154e497a4eaf3752555cfb5da8fc49"},
{file = "pydantic_core-2.41.5-cp310-cp310-win32.whl", hash = "sha256:1f8d33a7f4d5a7889e60dc39856d76d09333d8a6ed0f5f1190635cbec70ec4ba"},
{file = "pydantic_core-2.41.5-cp310-cp310-win_amd64.whl", hash = "sha256:62de39db01b8d593e45871af2af9e497295db8d73b085f6bfd0b18c83c70a8f9"},
{file = "pydantic_core-2.41.5-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:a3a52f6156e73e7ccb0f8cced536adccb7042be67cb45f9562e12b319c119da6"},
{file = "pydantic_core-2.41.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:7f3bf998340c6d4b0c9a2f02d6a400e51f123b59565d74dc60d252ce888c260b"},
{file = "pydantic_core-2.41.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:378bec5c66998815d224c9ca994f1e14c0c21cb95d2f52b6021cc0b2a58f2a5a"},
{file = "pydantic_core-2.41.5-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e7b576130c69225432866fe2f4a469a85a54ade141d96fd396dffcf607b558f8"},
{file = "pydantic_core-2.41.5-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6cb58b9c66f7e4179a2d5e0f849c48eff5c1fca560994d6eb6543abf955a149e"},
{file = "pydantic_core-2.41.5-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:88942d3a3dff3afc8288c21e565e476fc278902ae4d6d134f1eeda118cc830b1"},
{file = "pydantic_core-2.41.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f31d95a179f8d64d90f6831d71fa93290893a33148d890ba15de25642c5d075b"},
{file = "pydantic_core-2.41.5-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:c1df3d34aced70add6f867a8cf413e299177e0c22660cc767218373d0779487b"},
{file = "pydantic_core-2.41.5-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:4009935984bd36bd2c774e13f9a09563ce8de4abaa7226f5108262fa3e637284"},
{file = "pydantic_core-2.41.5-cp311-cp311-musllinux_1_1_armv7l.whl", hash = "sha256:34a64bc3441dc1213096a20fe27e8e128bd3ff89921706e83c0b1ac971276594"},
{file = "pydantic_core-2.41.5-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:c9e19dd6e28fdcaa5a1de679aec4141f691023916427ef9bae8584f9c2fb3b0e"},
{file = "pydantic_core-2.41.5-cp311-cp311-win32.whl", hash = "sha256:2c010c6ded393148374c0f6f0bf89d206bf3217f201faa0635dcd56bd1520f6b"},
{file = "pydantic_core-2.41.5-cp311-cp311-win_amd64.whl", hash = "sha256:76ee27c6e9c7f16f47db7a94157112a2f3a00e958bc626e2f4ee8bec5c328fbe"},
{file = "pydantic_core-2.41.5-cp311-cp311-win_arm64.whl", hash = "sha256:4bc36bbc0b7584de96561184ad7f012478987882ebf9f9c389b23f432ea3d90f"},
{file = "pydantic_core-2.41.5-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:f41a7489d32336dbf2199c8c0a215390a751c5b014c2c1c5366e817202e9cdf7"},
{file = "pydantic_core-2.41.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:070259a8818988b9a84a449a2a7337c7f430a22acc0859c6b110aa7212a6d9c0"},
{file = "pydantic_core-2.41.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e96cea19e34778f8d59fe40775a7a574d95816eb150850a85a7a4c8f4b94ac69"},
{file = "pydantic_core-2.41.5-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ed2e99c456e3fadd05c991f8f437ef902e00eedf34320ba2b0842bd1c3ca3a75"},
{file = "pydantic_core-2.41.5-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:65840751b72fbfd82c3c640cff9284545342a4f1eb1586ad0636955b261b0b05"},
{file = "pydantic_core-2.41.5-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e536c98a7626a98feb2d3eaf75944ef6f3dbee447e1f841eae16f2f0a72d8ddc"},
{file = "pydantic_core-2.41.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eceb81a8d74f9267ef4081e246ffd6d129da5d87e37a77c9bde550cb04870c1c"},
{file = "pydantic_core-2.41.5-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d38548150c39b74aeeb0ce8ee1d8e82696f4a4e16ddc6de7b1d8823f7de4b9b5"},
{file = "pydantic_core-2.41.5-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:c23e27686783f60290e36827f9c626e63154b82b116d7fe9adba1fda36da706c"},
{file = "pydantic_core-2.41.5-cp312-cp312-musllinux_1_1_armv7l.whl", hash = "sha256:482c982f814460eabe1d3bb0adfdc583387bd4691ef00b90575ca0d2b6fe2294"},
{file = "pydantic_core-2.41.5-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:bfea2a5f0b4d8d43adf9d7b8bf019fb46fdd10a2e5cde477fbcb9d1fa08c68e1"},
{file = "pydantic_core-2.41.5-cp312-cp312-win32.whl", hash = "sha256:b74557b16e390ec12dca509bce9264c3bbd128f8a2c376eaa68003d7f327276d"},
{file = "pydantic_core-2.41.5-cp312-cp312-win_amd64.whl", hash = "sha256:1962293292865bca8e54702b08a4f26da73adc83dd1fcf26fbc875b35d81c815"},
{file = "pydantic_core-2.41.5-cp312-cp312-win_arm64.whl", hash = "sha256:1746d4a3d9a794cacae06a5eaaccb4b8643a131d45fbc9af23e353dc0a5ba5c3"},
{file = "pydantic_core-2.41.5-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:941103c9be18ac8daf7b7adca8228f8ed6bb7a1849020f643b3a14d15b1924d9"},
{file = "pydantic_core-2.41.5-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:112e305c3314f40c93998e567879e887a3160bb8689ef3d2c04b6cc62c33ac34"},
{file = "pydantic_core-2.41.5-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0cbaad15cb0c90aa221d43c00e77bb33c93e8d36e0bf74760cd00e732d10a6a0"},
{file = "pydantic_core-2.41.5-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:03ca43e12fab6023fc79d28ca6b39b05f794ad08ec2feccc59a339b02f2b3d33"},
{file = "pydantic_core-2.41.5-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:dc799088c08fa04e43144b164feb0c13f9a0bc40503f8df3e9fde58a3c0c101e"},
{file = "pydantic_core-2.41.5-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:97aeba56665b4c3235a0e52b2c2f5ae9cd071b8a8310ad27bddb3f7fb30e9aa2"},
{file = "pydantic_core-2.41.5-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:406bf18d345822d6c21366031003612b9c77b3e29ffdb0f612367352aab7d586"},
{file = "pydantic_core-2.41.5-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b93590ae81f7010dbe380cdeab6f515902ebcbefe0b9327cc4804d74e93ae69d"},
{file = "pydantic_core-2.41.5-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:01a3d0ab748ee531f4ea6c3e48ad9dac84ddba4b0d82291f87248f2f9de8d740"},
{file = "pydantic_core-2.41.5-cp313-cp313-musllinux_1_1_armv7l.whl", hash = "sha256:6561e94ba9dacc9c61bce40e2d6bdc3bfaa0259d3ff36ace3b1e6901936d2e3e"},
{file = "pydantic_core-2.41.5-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:915c3d10f81bec3a74fbd4faebe8391013ba61e5a1a8d48c4455b923bdda7858"},
{file = "pydantic_core-2.41.5-cp313-cp313-win32.whl", hash = "sha256:650ae77860b45cfa6e2cdafc42618ceafab3a2d9a3811fcfbd3bbf8ac3c40d36"},
{file = "pydantic_core-2.41.5-cp313-cp313-win_amd64.whl", hash = "sha256:79ec52ec461e99e13791ec6508c722742ad745571f234ea6255bed38c6480f11"},
{file = "pydantic_core-2.41.5-cp313-cp313-win_arm64.whl", hash = "sha256:3f84d5c1b4ab906093bdc1ff10484838aca54ef08de4afa9de0f5f14d69639cd"},
{file = "pydantic_core-2.41.5-cp314-cp314-macosx_10_12_x86_64.whl", hash = "sha256:3f37a19d7ebcdd20b96485056ba9e8b304e27d9904d233d7b1015db320e51f0a"},
{file = "pydantic_core-2.41.5-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:1d1d9764366c73f996edd17abb6d9d7649a7eb690006ab6adbda117717099b14"},
{file = "pydantic_core-2.41.5-cp314-cp314-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:25e1c2af0fce638d5f1988b686f3b3ea8cd7de5f244ca147c777769e798a9cd1"},
{file = "pydantic_core-2.41.5-cp314-cp314-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:506d766a8727beef16b7adaeb8ee6217c64fc813646b424d0804d67c16eddb66"},
{file = "pydantic_core-2.41.5-cp314-cp314-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4819fa52133c9aa3c387b3328f25c1facc356491e6135b459f1de698ff64d869"},
{file = "pydantic_core-2.41.5-cp314-cp314-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2b761d210c9ea91feda40d25b4efe82a1707da2ef62901466a42492c028553a2"},
{file = "pydantic_core-2.41.5-cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:22f0fb8c1c583a3b6f24df2470833b40207e907b90c928cc8d3594b76f874375"},
{file = "pydantic_core-2.41.5-cp314-cp314-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2782c870e99878c634505236d81e5443092fba820f0373997ff75f90f68cd553"},
{file = "pydantic_core-2.41.5-cp314-cp314-musllinux_1_1_aarch64.whl", hash = "sha256:0177272f88ab8312479336e1d777f6b124537d47f2123f89cb37e0accea97f90"},
{file = "pydantic_core-2.41.5-cp314-cp314-musllinux_1_1_armv7l.whl", hash = "sha256:63510af5e38f8955b8ee5687740d6ebf7c2a0886d15a6d65c32814613681bc07"},
{file = "pydantic_core-2.41.5-cp314-cp314-musllinux_1_1_x86_64.whl", hash = "sha256:e56ba91f47764cc14f1daacd723e3e82d1a89d783f0f5afe9c364b8bb491ccdb"},
{file = "pydantic_core-2.41.5-cp314-cp314-win32.whl", hash = "sha256:aec5cf2fd867b4ff45b9959f8b20ea3993fc93e63c7363fe6851424c8a7e7c23"},
{file = "pydantic_core-2.41.5-cp314-cp314-win_amd64.whl", hash = "sha256:8e7c86f27c585ef37c35e56a96363ab8de4e549a95512445b85c96d3e2f7c1bf"},
{file = "pydantic_core-2.41.5-cp314-cp314-win_arm64.whl", hash = "sha256:e672ba74fbc2dc8eea59fb6d4aed6845e6905fc2a8afe93175d94a83ba2a01a0"},
{file = "pydantic_core-2.41.5-cp314-cp314t-macosx_10_12_x86_64.whl", hash = "sha256:8566def80554c3faa0e65ac30ab0932b9e3a5cd7f8323764303d468e5c37595a"},
{file = "pydantic_core-2.41.5-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:b80aa5095cd3109962a298ce14110ae16b8c1aece8b72f9dafe81cf597ad80b3"},
{file = "pydantic_core-2.41.5-cp314-cp314t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3006c3dd9ba34b0c094c544c6006cc79e87d8612999f1a5d43b769b89181f23c"},
{file = "pydantic_core-2.41.5-cp314-cp314t-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:72f6c8b11857a856bcfa48c86f5368439f74453563f951e473514579d44aa612"},
{file = "pydantic_core-2.41.5-cp314-cp314t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5cb1b2f9742240e4bb26b652a5aeb840aa4b417c7748b6f8387927bc6e45e40d"},
{file = "pydantic_core-2.41.5-cp314-cp314t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:bd3d54f38609ff308209bd43acea66061494157703364ae40c951f83ba99a1a9"},
{file = "pydantic_core-2.41.5-cp314-cp314t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2ff4321e56e879ee8d2a879501c8e469414d948f4aba74a2d4593184eb326660"},
{file = "pydantic_core-2.41.5-cp314-cp314t-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d0d2568a8c11bf8225044aa94409e21da0cb09dcdafe9ecd10250b2baad531a9"},
{file = "pydantic_core-2.41.5-cp314-cp314t-musllinux_1_1_aarch64.whl", hash = "sha256:a39455728aabd58ceabb03c90e12f71fd30fa69615760a075b9fec596456ccc3"},
{file = "pydantic_core-2.41.5-cp314-cp314t-musllinux_1_1_armv7l.whl", hash = "sha256:239edca560d05757817c13dc17c50766136d21f7cd0fac50295499ae24f90fdf"},
{file = "pydantic_core-2.41.5-cp314-cp314t-musllinux_1_1_x86_64.whl", hash = "sha256:2a5e06546e19f24c6a96a129142a75cee553cc018ffee48a460059b1185f4470"},
{file = "pydantic_core-2.41.5-cp314-cp314t-win32.whl", hash = "sha256:b4ececa40ac28afa90871c2cc2b9ffd2ff0bf749380fbdf57d165fd23da353aa"},
{file = "pydantic_core-2.41.5-cp314-cp314t-win_amd64.whl", hash = "sha256:80aa89cad80b32a912a65332f64a4450ed00966111b6615ca6816153d3585a8c"},
{file = "pydantic_core-2.41.5-cp314-cp314t-win_arm64.whl", hash = "sha256:35b44f37a3199f771c3eaa53051bc8a70cd7b54f333531c59e29fd4db5d15008"},
{file = "pydantic_core-2.41.5-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:8bfeaf8735be79f225f3fefab7f941c712aaca36f1128c9d7e2352ee1aa87bdf"},
{file = "pydantic_core-2.41.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:346285d28e4c8017da95144c7f3acd42740d637ff41946af5ce6e5e420502dd5"},
{file = "pydantic_core-2.41.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a75dafbf87d6276ddc5b2bf6fae5254e3d0876b626eb24969a574fff9149ee5d"},
{file = "pydantic_core-2.41.5-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:7b93a4d08587e2b7e7882de461e82b6ed76d9026ce91ca7915e740ecc7855f60"},
{file = "pydantic_core-2.41.5-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e8465ab91a4bd96d36dde3263f06caa6a8a6019e4113f24dc753d79a8b3a3f82"},
{file = "pydantic_core-2.41.5-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:299e0a22e7ae2b85c1a57f104538b2656e8ab1873511fd718a1c1c6f149b77b5"},
{file = "pydantic_core-2.41.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:707625ef0983fcfb461acfaf14de2067c5942c6bb0f3b4c99158bed6fedd3cf3"},
{file = "pydantic_core-2.41.5-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:f41eb9797986d6ebac5e8edff36d5cef9de40def462311b3eb3eeded1431e425"},
{file = "pydantic_core-2.41.5-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:0384e2e1021894b1ff5a786dbf94771e2986ebe2869533874d7e43bc79c6f504"},
{file = "pydantic_core-2.41.5-cp39-cp39-musllinux_1_1_armv7l.whl", hash = "sha256:f0cd744688278965817fd0839c4a4116add48d23890d468bc436f78beb28abf5"},
{file = "pydantic_core-2.41.5-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:753e230374206729bf0a807954bcc6c150d3743928a73faffee51ac6557a03c3"},
{file = "pydantic_core-2.41.5-cp39-cp39-win32.whl", hash = "sha256:873e0d5b4fb9b89ef7c2d2a963ea7d02879d9da0da8d9d4933dee8ee86a8b460"},
{file = "pydantic_core-2.41.5-cp39-cp39-win_amd64.whl", hash = "sha256:e4f4a984405e91527a0d62649ee21138f8e3d0ef103be488c1dc11a80d7f184b"},
{file = "pydantic_core-2.41.5-graalpy311-graalpy242_311_native-macosx_10_12_x86_64.whl", hash = "sha256:b96d5f26b05d03cc60f11a7761a5ded1741da411e7fe0909e27a5e6a0cb7b034"},
{file = "pydantic_core-2.41.5-graalpy311-graalpy242_311_native-macosx_11_0_arm64.whl", hash = "sha256:634e8609e89ceecea15e2d61bc9ac3718caaaa71963717bf3c8f38bfde64242c"},
{file = "pydantic_core-2.41.5-graalpy311-graalpy242_311_native-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:93e8740d7503eb008aa2df04d3b9735f845d43ae845e6dcd2be0b55a2da43cd2"},
{file = "pydantic_core-2.41.5-graalpy311-graalpy242_311_native-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f15489ba13d61f670dcc96772e733aad1a6f9c429cc27574c6cdaed82d0146ad"},
{file = "pydantic_core-2.41.5-graalpy312-graalpy250_312_native-macosx_10_12_x86_64.whl", hash = "sha256:7da7087d756b19037bc2c06edc6c170eeef3c3bafcb8f532ff17d64dc427adfd"},
{file = "pydantic_core-2.41.5-graalpy312-graalpy250_312_native-macosx_11_0_arm64.whl", hash = "sha256:aabf5777b5c8ca26f7824cb4a120a740c9588ed58df9b2d196ce92fba42ff8dc"},
{file = "pydantic_core-2.41.5-graalpy312-graalpy250_312_native-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c007fe8a43d43b3969e8469004e9845944f1a80e6acd47c150856bb87f230c56"},
{file = "pydantic_core-2.41.5-graalpy312-graalpy250_312_native-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:76d0819de158cd855d1cbb8fcafdf6f5cf1eb8e470abe056d5d161106e38062b"},
{file = "pydantic_core-2.41.5-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:b5819cd790dbf0c5eb9f82c73c16b39a65dd6dd4d1439dcdea7816ec9adddab8"},
{file = "pydantic_core-2.41.5-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:5a4e67afbc95fa5c34cf27d9089bca7fcab4e51e57278d710320a70b956d1b9a"},
{file = "pydantic_core-2.41.5-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ece5c59f0ce7d001e017643d8d24da587ea1f74f6993467d85ae8a5ef9d4f42b"},
{file = "pydantic_core-2.41.5-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:16f80f7abe3351f8ea6858914ddc8c77e02578544a0ebc15b4c2e1a0e813b0b2"},
{file = "pydantic_core-2.41.5-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:33cb885e759a705b426baada1fe68cbb0a2e68e34c5d0d0289a364cf01709093"},
{file = "pydantic_core-2.41.5-pp310-pypy310_pp73-musllinux_1_1_armv7l.whl", hash = "sha256:c8d8b4eb992936023be7dee581270af5c6e0697a8559895f527f5b7105ecd36a"},
{file = "pydantic_core-2.41.5-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:242a206cd0318f95cd21bdacff3fcc3aab23e79bba5cac3db5a841c9ef9c6963"},
{file = "pydantic_core-2.41.5-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:d3a978c4f57a597908b7e697229d996d77a6d3c94901e9edee593adada95ce1a"},
{file = "pydantic_core-2.41.5-pp311-pypy311_pp73-macosx_10_12_x86_64.whl", hash = "sha256:b2379fa7ed44ddecb5bfe4e48577d752db9fc10be00a6b7446e9663ba143de26"},
{file = "pydantic_core-2.41.5-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:266fb4cbf5e3cbd0b53669a6d1b039c45e3ce651fd5442eff4d07c2cc8d66808"},
{file = "pydantic_core-2.41.5-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:58133647260ea01e4d0500089a8c4f07bd7aa6ce109682b1426394988d8aaacc"},
{file = "pydantic_core-2.41.5-pp311-pypy311_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:287dad91cfb551c363dc62899a80e9e14da1f0e2b6ebde82c806612ca2a13ef1"},
{file = "pydantic_core-2.41.5-pp311-pypy311_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:03b77d184b9eb40240ae9fd676ca364ce1085f203e1b1256f8ab9984dca80a84"},
{file = "pydantic_core-2.41.5-pp311-pypy311_pp73-musllinux_1_1_armv7l.whl", hash = "sha256:a668ce24de96165bb239160b3d854943128f4334822900534f2fe947930e5770"},
{file = "pydantic_core-2.41.5-pp311-pypy311_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:f14f8f046c14563f8eb3f45f499cc658ab8d10072961e07225e507adb700e93f"},
{file = "pydantic_core-2.41.5-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:56121965f7a4dc965bff783d70b907ddf3d57f6eba29b6d2e5dabfaf07799c51"},
{file = "pydantic_core-2.41.5.tar.gz", hash = "sha256:08daa51ea16ad373ffd5e7606252cc32f07bc72b28284b6bc9c6df804816476e"},
]
[package.dependencies]
typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0"
typing-extensions = ">=4.14.1"
[[package]]
name = "pydantic-settings"
@@ -13625,14 +13646,14 @@ files = [
[[package]]
name = "typing-inspection"
version = "0.4.1"
version = "0.4.2"
description = "Runtime typing introspection tools"
optional = false
python-versions = ">=3.9"
groups = ["main", "test"]
files = [
{file = "typing_inspection-0.4.1-py3-none-any.whl", hash = "sha256:389055682238f53b04f7badcb49b989835495a96700ced5dab2d8feae4b26f51"},
{file = "typing_inspection-0.4.1.tar.gz", hash = "sha256:6ae134cc0203c33377d43188d4064e9b357dba58cff3185f22924610e70a9d28"},
{file = "typing_inspection-0.4.2-py3-none-any.whl", hash = "sha256:4ed1cacbdc298c220f1bd249ed5287caa16f34d44ef4e9c3d0cbad5b521545e7"},
{file = "typing_inspection-0.4.2.tar.gz", hash = "sha256:ba561c48a67c5958007083d386c3295464928b01faa735ab8547c5692e87f464"},
]
[package.dependencies]
+9
View File
@@ -37,6 +37,12 @@ from server.routes.mcp_patch import patch_mcp_server # noqa: E402
from server.routes.oauth_device import oauth_device_router # noqa: E402
from server.routes.readiness import readiness_router # noqa: E402
from server.routes.user import saas_user_router # noqa: E402
from server.sharing.shared_conversation_router import ( # noqa: E402
router as shared_conversation_router,
)
from server.sharing.shared_event_router import ( # noqa: E402
router as shared_event_router,
)
from openhands.server.app import app as base_app # noqa: E402
from openhands.server.listen_socket import sio # noqa: E402
@@ -66,6 +72,8 @@ base_app.include_router(saas_user_router) # Add additional route SAAS user call
base_app.include_router(
billing_router
) # Add routes for credit management and Stripe payment integration
base_app.include_router(shared_conversation_router)
base_app.include_router(shared_event_router)
# Add GitHub integration router only if GITHUB_APP_CLIENT_ID is set
if GITHUB_APP_CLIENT_ID:
@@ -99,6 +107,7 @@ base_app.include_router(
event_webhook_router
) # Add routes for Events in nested runtimes
base_app.add_middleware(
CORSMiddleware,
allow_origins=PERMITTED_CORS_ORIGINS,
+5
View File
@@ -38,3 +38,8 @@ ROLE_CHECK_ENABLED = os.getenv('ROLE_CHECK_ENABLED', 'false').lower() in (
'y',
'on',
)
BLOCKED_EMAIL_DOMAINS = [
domain.strip().lower()
for domain in os.getenv('BLOCKED_EMAIL_DOMAINS', '').split(',')
if domain.strip()
]
+75
View File
@@ -0,0 +1,75 @@
from server.auth.constants import BLOCKED_EMAIL_DOMAINS
from openhands.core.logger import openhands_logger as logger
class DomainBlocker:
def __init__(self) -> None:
logger.debug('Initializing DomainBlocker')
self.blocked_domains: list[str] = BLOCKED_EMAIL_DOMAINS
if self.blocked_domains:
logger.info(
f'Successfully loaded {len(self.blocked_domains)} blocked email domains: {self.blocked_domains}'
)
def is_active(self) -> bool:
"""Check if domain blocking is enabled"""
return bool(self.blocked_domains)
def _extract_domain(self, email: str) -> str | None:
"""Extract and normalize email domain from email address"""
if not email:
return None
try:
# Extract domain part after @
if '@' not in email:
return None
domain = email.split('@')[1].strip().lower()
return domain if domain else None
except Exception:
logger.debug(f'Error extracting domain from email: {email}', exc_info=True)
return None
def is_domain_blocked(self, email: str) -> bool:
"""Check if email domain is blocked
Supports blocking:
- Exact domains: 'example.com' blocks 'user@example.com'
- Subdomains: 'example.com' blocks 'user@subdomain.example.com'
- TLDs: '.us' blocks 'user@company.us' and 'user@subdomain.company.us'
"""
if not self.is_active():
return False
if not email:
logger.debug('No email provided for domain check')
return False
domain = self._extract_domain(email)
if not domain:
logger.debug(f'Could not extract domain from email: {email}')
return False
# Check if domain matches any blocked pattern
for blocked_pattern in self.blocked_domains:
if blocked_pattern.startswith('.'):
# TLD pattern (e.g., '.us') - check if domain ends with it
if domain.endswith(blocked_pattern):
logger.warning(
f'Email domain {domain} is blocked by TLD pattern {blocked_pattern} for email: {email}'
)
return True
else:
# Full domain pattern (e.g., 'example.com')
# Block exact match or subdomains
if domain == blocked_pattern or domain.endswith(f'.{blocked_pattern}'):
logger.warning(
f'Email domain {domain} is blocked by domain pattern {blocked_pattern} for email: {email}'
)
return True
logger.debug(f'Email domain {domain} is not blocked')
return False
domain_blocker = DomainBlocker()
+109
View File
@@ -0,0 +1,109 @@
"""Email validation utilities for preventing duplicate signups with + modifier."""
import re
def extract_base_email(email: str) -> str | None:
"""Extract base email from an email address.
For emails with + modifier, extracts the base email (local part before + and @, plus domain).
For emails without + modifier, returns the email as-is.
Examples:
extract_base_email("joe+test@example.com") -> "joe@example.com"
extract_base_email("joe@example.com") -> "joe@example.com"
extract_base_email("joe+openhands+test@example.com") -> "joe@example.com"
Args:
email: The email address to process
Returns:
The base email address, or None if email format is invalid
"""
if not email or '@' not in email:
return None
try:
local_part, domain = email.rsplit('@', 1)
# Extract the part before + if it exists
base_local = local_part.split('+', 1)[0]
return f'{base_local}@{domain}'
except (ValueError, AttributeError):
return None
def has_plus_modifier(email: str) -> bool:
"""Check if an email address contains a + modifier.
Args:
email: The email address to check
Returns:
True if email contains + before @, False otherwise
"""
if not email or '@' not in email:
return False
try:
local_part, _ = email.rsplit('@', 1)
return '+' in local_part
except (ValueError, AttributeError):
return False
def matches_base_email(email: str, base_email: str) -> bool:
"""Check if an email matches a base email pattern.
An email matches if:
- It is exactly the base email (e.g., joe@example.com)
- It has the same base local part and domain, with or without + modifier
(e.g., joe+test@example.com matches base joe@example.com)
Args:
email: The email address to check
base_email: The base email to match against
Returns:
True if email matches the base pattern, False otherwise
"""
if not email or not base_email:
return False
# Extract base from both emails for comparison
email_base = extract_base_email(email)
base_email_normalized = extract_base_email(base_email)
if not email_base or not base_email_normalized:
return False
# Emails match if they have the same base
return email_base.lower() == base_email_normalized.lower()
def get_base_email_regex_pattern(base_email: str) -> re.Pattern | None:
"""Generate a regex pattern to match emails with the same base.
For base_email "joe@example.com", the pattern will match:
- joe@example.com
- joe+anything@example.com
Args:
base_email: The base email address
Returns:
A compiled regex pattern, or None if base_email is invalid
"""
base = extract_base_email(base_email)
if not base:
return None
try:
local_part, domain = base.rsplit('@', 1)
# Escape special regex characters in local part and domain
escaped_local = re.escape(local_part)
escaped_domain = re.escape(domain)
# Pattern: joe@example.com OR joe+anything@example.com
pattern = rf'^{escaped_local}(\+[^@\s]+)?@{escaped_domain}$'
return re.compile(pattern, re.IGNORECASE)
except (ValueError, AttributeError):
return None
+15 -2
View File
@@ -13,6 +13,7 @@ from server.auth.auth_error import (
ExpiredError,
NoCredentialsError,
)
from server.auth.domain_blocker import domain_blocker
from server.auth.token_manager import TokenManager
from server.config import get_config
from server.logger import logger
@@ -153,8 +154,10 @@ class SaasUserAuth(UserAuth):
try:
# TODO: I think we can do this in a single request if we refactor
with session_maker() as session:
tokens = session.query(AuthTokens).where(
AuthTokens.keycloak_user_id == self.user_id
tokens = (
session.query(AuthTokens)
.where(AuthTokens.keycloak_user_id == self.user_id)
.all()
)
for token in tokens:
@@ -312,6 +315,16 @@ async def saas_user_auth_from_signed_token(signed_token: str) -> SaasUserAuth:
user_id = access_token_payload['sub']
email = access_token_payload['email']
email_verified = access_token_payload['email_verified']
# Check if email domain is blocked
if email and domain_blocker.is_active() and domain_blocker.is_domain_blocked(email):
logger.warning(
f'Blocked authentication attempt for existing user with email: {email}'
)
raise AuthError(
'Access denied: Your email domain is not allowed to access this service'
)
logger.debug('saas_user_auth_from_signed_token:return')
return SaasUserAuth(
+236
View File
@@ -1,3 +1,4 @@
import asyncio
import base64
import hashlib
import json
@@ -13,6 +14,7 @@ from keycloak.exceptions import (
KeycloakAuthenticationError,
KeycloakConnectionError,
KeycloakError,
KeycloakPostError,
)
from server.auth.constants import (
BITBUCKET_APP_CLIENT_ID,
@@ -25,6 +27,11 @@ from server.auth.constants import (
KEYCLOAK_SERVER_URL,
KEYCLOAK_SERVER_URL_EXT,
)
from server.auth.email_validation import (
extract_base_email,
get_base_email_regex_pattern,
matches_base_email,
)
from server.auth.keycloak_manager import get_keycloak_admin, get_keycloak_openid
from server.config import get_config
from server.logger import logger
@@ -37,6 +44,7 @@ from storage.offline_token_store import OfflineTokenStore
from tenacity import RetryCallState, retry, retry_if_exception_type, stop_after_attempt
from openhands.integrations.service_types import ProviderType
from openhands.server.types import SessionExpiredError
from openhands.utils.http_session import httpx_verify_option
@@ -459,6 +467,14 @@ class TokenManager:
except KeycloakConnectionError:
logger.exception('KeycloakConnectionError when refreshing token')
raise
except KeycloakPostError as e:
error_message = str(e)
if 'invalid_grant' in error_message or 'session not found' in error_message:
logger.warning(f'User session expired or invalid: {error_message}')
raise SessionExpiredError(
'Your session has expired. Please login again.'
) from e
raise
@retry(
stop=stop_after_attempt(2),
@@ -509,6 +525,183 @@ class TokenManager:
logger.info(f'Got user ID {keycloak_user_id} from email: {email}')
return keycloak_user_id
async def _query_users_by_wildcard_pattern(
self, local_part: str, domain: str
) -> dict[str, dict]:
"""Query Keycloak for users matching a wildcard email pattern.
Tries multiple query methods to find users with emails matching
the pattern {local_part}*@{domain}. This catches the base email
and all + modifier variants.
Args:
local_part: The local part of the email (before @)
domain: The domain part of the email (after @)
Returns:
Dictionary mapping user IDs to user objects
"""
keycloak_admin = get_keycloak_admin(self.external)
all_users = {}
# Query for users with emails matching the base pattern using wildcard
# Pattern: {local_part}*@{domain} - catches base email and all + variants
# This may also catch unintended matches (e.g., joesmith@example.com), but
# they will be filtered out by the regex pattern check later
# Use 'search' parameter for Keycloak 26+ (better wildcard support)
wildcard_queries = [
{'search': f'{local_part}*@{domain}'}, # Try 'search' parameter first
{'q': f'email:{local_part}*@{domain}'}, # Fallback to 'q' parameter
]
for query_params in wildcard_queries:
try:
users = await keycloak_admin.a_get_users(query_params)
for user in users:
all_users[user.get('id')] = user
break # Success, no need to try fallback
except Exception as e:
logger.debug(
f'Wildcard query failed with {list(query_params.keys())[0]}: {e}'
)
continue # Try next query method
return all_users
def _find_duplicate_in_users(
self, users: dict[str, dict], base_email: str, current_user_id: str
) -> bool:
"""Check if any user in the provided list matches the base email pattern.
Filters users to find duplicates that match the base email pattern,
excluding the current user.
Args:
users: Dictionary mapping user IDs to user objects
base_email: The base email to match against
current_user_id: The user ID to exclude from the check
Returns:
True if a duplicate is found, False otherwise
"""
regex_pattern = get_base_email_regex_pattern(base_email)
if not regex_pattern:
logger.warning(
f'Could not generate regex pattern for base email: {base_email}'
)
# Fallback to simple matching
for user in users.values():
user_email = user.get('email', '').lower()
if (
user_email
and user.get('id') != current_user_id
and matches_base_email(user_email, base_email)
):
logger.info(
f'Found duplicate email: {user_email} matches base {base_email}'
)
return True
else:
for user in users.values():
user_email = user.get('email', '')
if (
user_email
and user.get('id') != current_user_id
and regex_pattern.match(user_email)
):
logger.info(
f'Found duplicate email: {user_email} matches base {base_email}'
)
return True
return False
@retry(
stop=stop_after_attempt(2),
retry=retry_if_exception_type(KeycloakConnectionError),
before_sleep=_before_sleep_callback,
)
async def check_duplicate_base_email(
self, email: str, current_user_id: str
) -> bool:
"""Check if a user with the same base email already exists.
This method checks for duplicate signups using email + modifier.
It checks if any user exists with the same base email, regardless of whether
the provided email has a + modifier or not.
Examples:
- If email is "joe+test@example.com", it checks for existing users with
base email "joe@example.com" (e.g., "joe@example.com", "joe+1@example.com")
- If email is "joe@example.com", it checks for existing users with
base email "joe@example.com" (e.g., "joe+1@example.com", "joe+test@example.com")
Args:
email: The email address to check (may or may not contain + modifier)
current_user_id: The user ID of the current user (to exclude from check)
Returns:
True if a duplicate is found (excluding current user), False otherwise
"""
if not email:
return False
base_email = extract_base_email(email)
if not base_email:
logger.warning(f'Could not extract base email from: {email}')
return False
try:
local_part, domain = base_email.rsplit('@', 1)
users = await self._query_users_by_wildcard_pattern(local_part, domain)
return self._find_duplicate_in_users(users, base_email, current_user_id)
except KeycloakConnectionError:
logger.exception('KeycloakConnectionError when checking duplicate email')
raise
except Exception as e:
logger.exception(f'Unexpected error checking duplicate email: {e}')
# On any error, allow signup to proceed (fail open)
return False
@retry(
stop=stop_after_attempt(2),
retry=retry_if_exception_type(KeycloakConnectionError),
before_sleep=_before_sleep_callback,
)
async def delete_keycloak_user(self, user_id: str) -> bool:
"""Delete a user from Keycloak.
This method is used to clean up user accounts that were created
but should not exist (e.g., duplicate email signups).
Args:
user_id: The Keycloak user ID to delete
Returns:
True if deletion was successful, False otherwise
"""
try:
keycloak_admin = get_keycloak_admin(self.external)
# Use the sync method (python-keycloak doesn't have async delete_user)
# Run it in a thread executor to avoid blocking the event loop
await asyncio.to_thread(keycloak_admin.delete_user, user_id)
logger.info(f'Successfully deleted Keycloak user {user_id}')
return True
except KeycloakConnectionError:
logger.exception(f'KeycloakConnectionError when deleting user {user_id}')
raise
except KeycloakError as e:
# User might not exist or already deleted
logger.warning(
f'KeycloakError when deleting user {user_id}: {e}',
extra={'user_id': user_id, 'error': str(e)},
)
return False
except Exception as e:
logger.exception(f'Unexpected error deleting Keycloak user {user_id}: {e}')
return False
async def get_user_info_from_user_id(self, user_id: str) -> dict | None:
keycloak_admin = get_keycloak_admin(self.external)
user = await keycloak_admin.a_get_user(user_id)
@@ -527,6 +720,49 @@ class TokenManager:
github_id = github_ids[0]
return github_id
async def disable_keycloak_user(
self, user_id: str, email: str | None = None
) -> None:
"""Disable a Keycloak user account.
Args:
user_id: The Keycloak user ID to disable
email: Optional email address for logging purposes
This method attempts to disable the user account but will not raise exceptions.
Errors are logged but do not prevent the operation from completing.
"""
try:
keycloak_admin = get_keycloak_admin(self.external)
# Get current user to preserve other fields
user = await keycloak_admin.a_get_user(user_id)
if user:
# Update user with enabled=False to disable the account
await keycloak_admin.a_update_user(
user_id=user_id,
payload={
'enabled': False,
'username': user.get('username', ''),
'email': user.get('email', ''),
'emailVerified': user.get('emailVerified', False),
},
)
email_str = f', email: {email}' if email else ''
logger.info(
f'Disabled Keycloak account for user_id: {user_id}{email_str}'
)
else:
logger.warning(
f'User not found in Keycloak when attempting to disable: {user_id}'
)
except Exception as e:
# Log error but don't raise - the caller should handle the blocking regardless
email_str = f', email: {email}' if email else ''
logger.error(
f'Failed to disable Keycloak account for user_id: {user_id}{email_str}: {str(e)}',
exc_info=True,
)
def store_org_token(self, installation_id: int, installation_token: str):
"""Store a GitHub App installation token.
+2
View File
@@ -38,6 +38,8 @@ LITE_LLM_API_URL = os.environ.get(
)
LITE_LLM_TEAM_ID = os.environ.get('LITE_LLM_TEAM_ID', None)
LITE_LLM_API_KEY = os.environ.get('LITE_LLM_API_KEY', None)
# Timeout in seconds for BYOR key verification requests to LiteLLM
BYOR_KEY_VERIFICATION_TIMEOUT = 5.0
SUBSCRIPTION_PRICE_DATA = {
'MONTHLY_SUBSCRIPTION': {
'unit_amount': 2000,
+1
View File
@@ -159,6 +159,7 @@ class SetAuthCookieMiddleware:
'/api/billing/cancel',
'/api/billing/customer-setup-success',
'/api/billing/stripe-webhook',
'/api/email/resend',
'/oauth/device/authorize',
'/oauth/device/token',
)
+101 -4
View File
@@ -4,7 +4,11 @@ import httpx
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel, field_validator
from server.config import get_config
from server.constants import LITE_LLM_API_KEY, LITE_LLM_API_URL
from server.constants import (
BYOR_KEY_VERIFICATION_TIMEOUT,
LITE_LLM_API_KEY,
LITE_LLM_API_URL,
)
from storage.api_key_store import ApiKeyStore
from storage.database import session_maker
from storage.saas_settings_store import SaasSettingsStore
@@ -112,6 +116,70 @@ async def generate_byor_key(user_id: str) -> str | None:
return None
async def verify_byor_key_in_litellm(byor_key: str, user_id: str) -> bool:
"""Verify that a BYOR key is valid in LiteLLM by making a lightweight API call.
Args:
byor_key: The BYOR key to verify
user_id: The user ID for logging purposes
Returns:
True if the key is verified as valid, False if verification fails or key is invalid.
Returns False on network errors/timeouts to ensure we don't return potentially invalid keys.
"""
if not (LITE_LLM_API_URL and byor_key):
return False
try:
async with httpx.AsyncClient(
verify=httpx_verify_option(),
timeout=BYOR_KEY_VERIFICATION_TIMEOUT,
) as client:
# Make a lightweight request to verify the key
# Using /v1/models endpoint as it's lightweight and requires authentication
response = await client.get(
f'{LITE_LLM_API_URL}/v1/models',
headers={
'Authorization': f'Bearer {byor_key}',
},
)
# Only 200 status code indicates valid key
if response.status_code == 200:
logger.debug(
'BYOR key verification successful',
extra={'user_id': user_id},
)
return True
# All other status codes (401, 403, 500, etc.) are treated as invalid
# This includes authentication errors and server errors
logger.warning(
'BYOR key verification failed - treating as invalid',
extra={
'user_id': user_id,
'status_code': response.status_code,
'key_prefix': byor_key[:10] + '...'
if len(byor_key) > 10
else byor_key,
},
)
return False
except (httpx.TimeoutException, Exception) as e:
# Any exception (timeout, network error, etc.) means we can't verify
# Return False to trigger regeneration rather than returning potentially invalid key
logger.warning(
'BYOR key verification error - treating as invalid to ensure key validity',
extra={
'user_id': user_id,
'error': str(e),
'error_type': type(e).__name__,
},
)
return False
async def delete_byor_key_from_litellm(user_id: str, byor_key: str) -> bool:
"""Delete the BYOR key from LiteLLM using the key directly."""
if not (LITE_LLM_API_KEY and LITE_LLM_API_URL):
@@ -278,18 +346,44 @@ async def delete_api_key(key_id: int, user_id: str = Depends(get_user_id)):
@api_router.get('/llm/byor', response_model=LlmApiKeyResponse)
async def get_llm_api_key_for_byor(user_id: str = Depends(get_user_id)):
"""Get the LLM API key for BYOR (Bring Your Own Runtime) for the authenticated user."""
"""Get the LLM API key for BYOR (Bring Your Own Runtime) for the authenticated user.
This endpoint validates that the key exists in LiteLLM before returning it.
If validation fails, it automatically generates a new key to ensure users
always receive a working key.
"""
try:
# Check if the BYOR key exists in the database
byor_key = await get_byor_key_from_db(user_id)
if byor_key:
return {'key': byor_key}
# Validate that the key is actually registered in LiteLLM
is_valid = await verify_byor_key_in_litellm(byor_key, user_id)
if is_valid:
return {'key': byor_key}
else:
# Key exists in DB but is invalid in LiteLLM - regenerate it
logger.warning(
'BYOR key found in database but invalid in LiteLLM - regenerating',
extra={
'user_id': user_id,
'key_prefix': byor_key[:10] + '...'
if len(byor_key) > 10
else byor_key,
},
)
# Delete the invalid key from LiteLLM (best effort, don't fail if it doesn't exist)
await delete_byor_key_from_litellm(user_id, byor_key)
# Fall through to generate a new key
# If not, generate a new key for BYOR
# Generate a new key for BYOR (either no key exists or validation failed)
key = await generate_byor_key(user_id)
if key:
# Store the key in the database
await store_byor_key_in_db(user_id, key)
logger.info(
'Successfully generated and stored new BYOR key',
extra={'user_id': user_id},
)
return {'key': key}
else:
logger.error(
@@ -301,6 +395,9 @@ async def get_llm_api_key_for_byor(user_id: str = Depends(get_user_id)):
detail='Failed to generate new BYOR LLM API key',
)
except HTTPException:
# Re-raise HTTP exceptions as-is
raise
except Exception as e:
logger.exception('Error retrieving BYOR LLM API key', extra={'error': str(e)})
raise HTTPException(
+70
View File
@@ -14,6 +14,7 @@ from server.auth.constants import (
KEYCLOAK_SERVER_URL_EXT,
ROLE_CHECK_ENABLED,
)
from server.auth.domain_blocker import domain_blocker
from server.auth.gitlab_sync import schedule_gitlab_repo_sync
from server.auth.saas_user_auth import SaasUserAuth
from server.auth.token_manager import TokenManager
@@ -145,7 +146,76 @@ async def keycloak_callback(
content={'error': 'Missing user ID or username in response'},
)
email = user_info.get('email')
user_id = user_info['sub']
# Check if email domain is blocked
email = user_info.get('email')
if email and domain_blocker.is_active() and domain_blocker.is_domain_blocked(email):
logger.warning(
f'Blocked authentication attempt for email: {email}, user_id: {user_id}'
)
# Disable the Keycloak account
await token_manager.disable_keycloak_user(user_id, email)
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={
'error': 'Access denied: Your email domain is not allowed to access this service'
},
)
# Check for duplicate email with + modifier
if email:
try:
has_duplicate = await token_manager.check_duplicate_base_email(
email, user_id
)
if has_duplicate:
logger.warning(
f'Blocked signup attempt for email {email} - duplicate base email found',
extra={'user_id': user_id, 'email': email},
)
# Delete the Keycloak user that was automatically created during OAuth
# This prevents orphaned accounts in Keycloak
# The delete_keycloak_user method already handles all errors internally
deletion_success = await token_manager.delete_keycloak_user(user_id)
if deletion_success:
logger.info(
f'Deleted Keycloak user {user_id} after detecting duplicate email {email}'
)
else:
logger.warning(
f'Failed to delete Keycloak user {user_id} after detecting duplicate email {email}. '
f'User may need to be manually cleaned up.'
)
# Redirect to home page with query parameter indicating the issue
home_url = f'{request.base_url}?duplicated_email=true'
return RedirectResponse(home_url, status_code=302)
except Exception as e:
# Log error but allow signup to proceed (fail open)
logger.error(
f'Error checking duplicate email for {email}: {e}',
extra={'user_id': user_id, 'email': email},
)
# Check email verification status
email_verified = user_info.get('email_verified', False)
if not email_verified:
# Send verification email
# Import locally to avoid circular import with email.py
from server.routes.email import verify_email
await verify_email(request=request, user_id=user_id, is_auth_flow=True)
redirect_url = (
f'{request.base_url}?email_verification_required=true&user_id={user_id}'
)
response = RedirectResponse(redirect_url, status_code=302)
return response
# default to github IDP for now.
# TODO: remove default once Keycloak is updated universally with the new attribute.
idp: str = user_info.get('identity_provider', ProviderType.GITHUB.value)
+18 -4
View File
@@ -111,10 +111,24 @@ def calculate_credits(user_info: LiteLlmUserInfo) -> float:
async def get_credits(user_id: str = Depends(get_user_id)) -> GetCreditsResponse:
if not stripe_service.STRIPE_API_KEY:
return GetCreditsResponse()
async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
user_json = await _get_litellm_user(client, user_id)
credits = calculate_credits(user_json['user_info'])
return GetCreditsResponse(credits=Decimal('{:.2f}'.format(credits)))
try:
async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
user_json = await _get_litellm_user(client, user_id)
credits = calculate_credits(user_json['user_info'])
return GetCreditsResponse(credits=Decimal('{:.2f}'.format(credits)))
except httpx.HTTPStatusError as e:
logger.error(
f'litellm_get_user_failed: {type(e).__name__}: {e}',
extra={
'user_id': user_id,
'status_code': e.response.status_code,
},
exc_info=True,
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail='Failed to retrieve credit balance from billing service',
)
# Endpoint to retrieve user's current subscription access
+47 -6
View File
@@ -7,6 +7,7 @@ from server.auth.constants import KEYCLOAK_CLIENT_ID
from server.auth.keycloak_manager import get_keycloak_admin
from server.auth.saas_user_auth import SaasUserAuth
from server.routes.auth import set_response_cookie
from server.utils.rate_limit_utils import check_rate_limit_by_user_id
from openhands.core.logger import openhands_logger as logger
from openhands.server.user_auth import get_user_id
@@ -28,6 +29,11 @@ class EmailUpdate(BaseModel):
return v
class ResendEmailVerificationRequest(BaseModel):
user_id: str | None = None
is_auth_flow: bool = False
@api_router.post('')
async def update_email(
email_data: EmailUpdate, request: Request, user_id: str = Depends(get_user_id)
@@ -74,7 +80,7 @@ async def update_email(
accepted_tos=user_auth.accepted_tos,
)
await _verify_email(request=request, user_id=user_id)
await verify_email(request=request, user_id=user_id)
logger.info(f'Updating email address for {user_id} to {email}')
return response
@@ -90,9 +96,41 @@ async def update_email(
)
@api_router.put('/verify')
async def verify_email(request: Request, user_id: str = Depends(get_user_id)):
await _verify_email(request=request, user_id=user_id)
@api_router.put('/resend')
async def resend_email_verification(
request: Request,
body: ResendEmailVerificationRequest | None = None,
):
# Get user_id from body if provided, otherwise from auth
user_id: str | None = None
if body and body.user_id:
user_id = body.user_id
else:
try:
user_id = await get_user_id(request)
except Exception:
pass
if not user_id:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail='user_id is required in request body or user must be authenticated',
)
# Check rate limit (uses user_id if available, otherwise falls back to IP)
# Use 30 seconds for user-based rate limiting to match frontend cooldown
await check_rate_limit_by_user_id(
request=request,
key_prefix='email_resend',
user_id=user_id,
user_rate_limit_seconds=30,
ip_rate_limit_seconds=60, # 1 minute for IP-based limiting (more lenient)
)
# Get is_auth_flow from body if provided, default to False
is_auth_flow = body.is_auth_flow if body else False
await verify_email(request=request, user_id=user_id, is_auth_flow=is_auth_flow)
logger.info(f'Resending verification email for {user_id}')
return JSONResponse(
@@ -124,10 +162,13 @@ async def verified_email(request: Request):
return response
async def _verify_email(request: Request, user_id: str):
async def verify_email(request: Request, user_id: str, is_auth_flow: bool = False):
keycloak_admin = get_keycloak_admin()
scheme = 'http' if request.url.hostname == 'localhost' else 'https'
redirect_uri = f'{scheme}://{request.url.netloc}/api/email/verified'
if is_auth_flow:
redirect_uri = f'{scheme}://{request.url.netloc}?email_verified=true'
else:
redirect_uri = f'{scheme}://{request.url.netloc}/api/email/verified'
logger.info(f'Redirect URI: {redirect_uri}')
await keycloak_admin.a_send_verify_email(
user_id=user_id,
+2 -2
View File
@@ -134,12 +134,12 @@ async def _process_batch_operations_background(
)
except Exception as e:
logger.error(
'error_processing_batch_operation',
f'error_processing_batch_operation: {type(e).__name__}: {e}',
extra={
'path': batch_op.path,
'method': str(batch_op.method),
'error': str(e),
},
exc_info=True,
)
+302 -1
View File
@@ -1,15 +1,28 @@
import asyncio
import hashlib
import json
from fastapi import APIRouter, Header, HTTPException, Request
from fastapi import APIRouter, Depends, Header, HTTPException, Request, status
from fastapi.responses import JSONResponse
from integrations.gitlab.gitlab_manager import GitlabManager
from integrations.gitlab.gitlab_service import SaaSGitLabService
from integrations.gitlab.webhook_installation import (
BreakLoopException,
install_webhook_on_resource,
verify_webhook_conditions,
)
from integrations.models import Message, SourceType
from integrations.types import GitLabResourceType
from integrations.utils import GITLAB_WEBHOOK_URL
from pydantic import BaseModel
from server.auth.token_manager import TokenManager
from storage.gitlab_webhook import GitlabWebhook
from storage.gitlab_webhook_store import GitlabWebhookStore
from openhands.core.logger import openhands_logger as logger
from openhands.integrations.gitlab.gitlab_service import GitLabServiceImpl
from openhands.server.shared import sio
from openhands.server.user_auth import get_user_id
gitlab_integration_router = APIRouter(prefix='/integration')
webhook_store = GitlabWebhookStore()
@@ -18,6 +31,37 @@ token_manager = TokenManager()
gitlab_manager = GitlabManager(token_manager)
# Request/Response models
class ResourceIdentifier(BaseModel):
type: GitLabResourceType
id: str
class ReinstallWebhookRequest(BaseModel):
resource: ResourceIdentifier
class ResourceWithWebhookStatus(BaseModel):
id: str
name: str
full_path: str
type: str
webhook_installed: bool
webhook_uuid: str | None
last_synced: str | None
class GitLabResourcesResponse(BaseModel):
resources: list[ResourceWithWebhookStatus]
class ResourceInstallationResult(BaseModel):
resource_id: str
resource_type: str
success: bool
error: str | None
async def verify_gitlab_signature(
header_webhook_secret: str, webhook_uuid: str, user_id: str
):
@@ -83,3 +127,260 @@ async def gitlab_events(
except Exception as e:
logger.exception(f'Error processing GitLab event: {e}')
return JSONResponse(status_code=400, content={'error': 'Invalid payload.'})
@gitlab_integration_router.get('/gitlab/resources')
async def get_gitlab_resources(
user_id: str = Depends(get_user_id),
) -> GitLabResourcesResponse:
"""Get all GitLab projects and groups where the user has admin access.
Returns a list of resources with their webhook installation status.
"""
try:
# Get GitLab service for the user
gitlab_service = GitLabServiceImpl(external_auth_id=user_id)
if not isinstance(gitlab_service, SaaSGitLabService):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail='Only SaaS GitLab service is supported',
)
# Fetch projects and groups with admin access
projects, groups = await gitlab_service.get_user_resources_with_admin_access()
# Filter out projects that belong to a group (nested projects)
# We only want top-level personal projects since group webhooks cover nested projects
filtered_projects = [
project
for project in projects
if project.get('namespace', {}).get('kind') != 'group'
]
# Extract IDs for bulk fetching
project_ids = [str(project['id']) for project in filtered_projects]
group_ids = [str(group['id']) for group in groups]
# Bulk fetch webhook records from database (organization-wide)
(
project_webhook_map,
group_webhook_map,
) = await webhook_store.get_webhooks_by_resources(project_ids, group_ids)
# Parallelize GitLab API calls to check webhook status for all resources
async def check_project_webhook(project):
project_id = str(project['id'])
webhook_exists, _ = await gitlab_service.check_webhook_exists_on_resource(
GitLabResourceType.PROJECT, project_id, GITLAB_WEBHOOK_URL
)
return project_id, webhook_exists
async def check_group_webhook(group):
group_id = str(group['id'])
webhook_exists, _ = await gitlab_service.check_webhook_exists_on_resource(
GitLabResourceType.GROUP, group_id, GITLAB_WEBHOOK_URL
)
return group_id, webhook_exists
# Gather all API calls in parallel
project_checks = [
check_project_webhook(project) for project in filtered_projects
]
group_checks = [check_group_webhook(group) for group in groups]
# Execute all checks concurrently
all_results = await asyncio.gather(*(project_checks + group_checks))
# Split results back into projects and groups
num_projects = len(filtered_projects)
project_results = all_results[:num_projects]
group_results = all_results[num_projects:]
# Build response
resources = []
# Add projects with their webhook status
for project, (project_id, webhook_exists) in zip(
filtered_projects, project_results
):
webhook = project_webhook_map.get(project_id)
resources.append(
ResourceWithWebhookStatus(
id=project_id,
name=project.get('name', ''),
full_path=project.get('path_with_namespace', ''),
type='project',
webhook_installed=webhook_exists,
webhook_uuid=webhook.webhook_uuid if webhook else None,
last_synced=(
webhook.last_synced.isoformat()
if webhook and webhook.last_synced
else None
),
)
)
# Add groups with their webhook status
for group, (group_id, webhook_exists) in zip(groups, group_results):
webhook = group_webhook_map.get(group_id)
resources.append(
ResourceWithWebhookStatus(
id=group_id,
name=group.get('name', ''),
full_path=group.get('full_path', ''),
type='group',
webhook_installed=webhook_exists,
webhook_uuid=webhook.webhook_uuid if webhook else None,
last_synced=(
webhook.last_synced.isoformat()
if webhook and webhook.last_synced
else None
),
)
)
logger.info(
'Retrieved GitLab resources',
extra={
'user_id': user_id,
'project_count': len(projects),
'group_count': len(groups),
},
)
return GitLabResourcesResponse(resources=resources)
except HTTPException:
raise
except Exception as e:
logger.exception(f'Error retrieving GitLab resources: {e}')
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail='Failed to retrieve GitLab resources',
)
@gitlab_integration_router.post('/gitlab/reinstall-webhook')
async def reinstall_gitlab_webhook(
body: ReinstallWebhookRequest,
user_id: str = Depends(get_user_id),
) -> ResourceInstallationResult:
"""Reinstall GitLab webhook for a specific resource immediately.
This endpoint validates permissions, resets webhook status in the database,
and immediately installs the webhook on the specified resource.
"""
try:
# Get GitLab service for the user
gitlab_service = GitLabServiceImpl(external_auth_id=user_id)
if not isinstance(gitlab_service, SaaSGitLabService):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail='Only SaaS GitLab service is supported',
)
resource_id = body.resource.id
resource_type = body.resource.type
# Check if user has admin access to this resource
(
has_admin_access,
check_status,
) = await gitlab_service.check_user_has_admin_access_to_resource(
resource_type, resource_id
)
if not has_admin_access:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail='User does not have admin access to this resource',
)
# Reset webhook in database (organization-wide, not user-specific)
# This allows any admin user to reinstall webhooks
await webhook_store.reset_webhook_for_reinstallation_by_resource(
resource_type, resource_id, user_id
)
# Get or create webhook record (without user_id filter)
webhook = await webhook_store.get_webhook_by_resource_only(
resource_type, resource_id
)
if not webhook:
# Create new webhook record
webhook = GitlabWebhook(
user_id=user_id, # Track who created it
project_id=resource_id
if resource_type == GitLabResourceType.PROJECT
else None,
group_id=resource_id
if resource_type == GitLabResourceType.GROUP
else None,
webhook_exists=False,
)
await webhook_store.store_webhooks([webhook])
# Fetch it again to get the ID (without user_id filter)
webhook = await webhook_store.get_webhook_by_resource_only(
resource_type, resource_id
)
# Verify conditions and install webhook
try:
await verify_webhook_conditions(
gitlab_service=gitlab_service,
resource_type=resource_type,
resource_id=resource_id,
webhook_store=webhook_store,
webhook=webhook,
)
# Install the webhook
webhook_id, install_status = await install_webhook_on_resource(
gitlab_service=gitlab_service,
resource_type=resource_type,
resource_id=resource_id,
webhook_store=webhook_store,
webhook=webhook,
)
if webhook_id:
logger.info(
'GitLab webhook reinstalled successfully',
extra={
'user_id': user_id,
'resource_type': resource_type.value,
'resource_id': resource_id,
},
)
return ResourceInstallationResult(
resource_id=resource_id,
resource_type=resource_type.value,
success=True,
error=None,
)
else:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail='Failed to install webhook',
)
except BreakLoopException:
# Conditions not met or webhook already exists
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail='Webhook installation conditions not met or webhook already exists',
)
except HTTPException:
raise
except Exception as e:
logger.exception(f'Error reinstalling GitLab webhook: {e}')
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail='Failed to reinstall webhook',
)
@@ -12,6 +12,8 @@ from typing import Any, cast
import httpx
import socketio
from pydantic import SecretStr
from server.auth.token_manager import TokenManager
from server.constants import PERMITTED_CORS_ORIGINS, WEB_HOST
from server.utils.conversation_callback_utils import (
process_event,
@@ -29,7 +31,11 @@ from openhands.core.logger import openhands_logger as logger
from openhands.events.action import MessageAction
from openhands.events.event_store import EventStore
from openhands.events.serialization.event import event_to_dict
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE, ProviderHandler
from openhands.integrations.provider import (
PROVIDER_TOKEN_TYPE,
ProviderHandler,
ProviderToken,
)
from openhands.runtime.impl.remote.remote_runtime import RemoteRuntime
from openhands.runtime.plugins.vscode import VSCodeRequirement
from openhands.runtime.runtime_status import RuntimeStatus
@@ -228,6 +234,102 @@ class SaasNestedConversationManager(ConversationManager):
status=status,
)
async def _refresh_provider_tokens_after_runtime_init(
self, settings: Settings, sid: str, user_id: str | None = None
) -> Settings:
"""Refresh provider tokens after runtime initialization.
During runtime initialization, tokens may be refreshed by Runtime.__init__().
This method retrieves the fresh tokens from the database and creates a new
settings object with updated tokens to avoid sending stale tokens to the
nested runtime.
The method handles two scenarios:
1. ProviderToken has user_id (IDP user ID, e.g., GitLab user ID)
→ Uses get_idp_token_from_idp_user_id()
2. ProviderToken has no user_id but Keycloak user_id is available
→ Uses load_offline_token() + get_idp_token_from_offline_token()
Args:
settings: The conversation settings that may contain provider tokens
sid: The session ID for logging purposes
user_id: The Keycloak user ID (optional, used as fallback when
ProviderToken.user_id is not available)
Returns:
Updated settings with fresh provider tokens, or original settings
if no update is needed
"""
if not isinstance(settings, ConversationInitData):
return settings
if not settings.git_provider_tokens:
return settings
token_manager = TokenManager()
updated_tokens = {}
tokens_refreshed = 0
tokens_failed = 0
for provider_type, provider_token in settings.git_provider_tokens.items():
fresh_token = None
try:
if provider_token.user_id:
# Case 1: We have IDP user ID (e.g., GitLab user ID '32546706')
# Get the token that was just refreshed during runtime initialization
fresh_token = await token_manager.get_idp_token_from_idp_user_id(
provider_token.user_id, provider_type
)
elif user_id:
# Case 2: We have Keycloak user ID but no IDP user ID
# This happens in web UI flow where ProviderToken.user_id is None
offline_token = await token_manager.load_offline_token(user_id)
if offline_token:
fresh_token = (
await token_manager.get_idp_token_from_offline_token(
offline_token, provider_type
)
)
if fresh_token:
updated_tokens[provider_type] = ProviderToken(
token=SecretStr(fresh_token),
user_id=provider_token.user_id,
host=provider_token.host,
)
tokens_refreshed += 1
else:
# Keep original token if we couldn't get a fresh one
updated_tokens[provider_type] = provider_token
except Exception as e:
# If refresh fails, use original token to prevent conversation startup failure
logger.warning(
f'Failed to refresh {provider_type.value} token: {e}',
extra={'session_id': sid, 'provider': provider_type.value},
exc_info=True,
)
updated_tokens[provider_type] = provider_token
tokens_failed += 1
# Create new ConversationInitData with updated tokens
# We cannot modify the frozen field directly, so we create a new object
updated_settings = settings.model_copy(
update={'git_provider_tokens': MappingProxyType(updated_tokens)}
)
logger.info(
'Updated provider tokens after runtime creation',
extra={
'session_id': sid,
'providers': [p.value for p in updated_tokens.keys()],
'refreshed': tokens_refreshed,
'failed': tokens_failed,
},
)
return updated_settings
async def _start_agent_loop(
self, sid, settings, user_id, initial_user_msg=None, replay_json=None
):
@@ -249,6 +351,11 @@ class SaasNestedConversationManager(ConversationManager):
session_api_key = runtime.session.headers['X-Session-API-Key']
# Update provider tokens with fresh ones after runtime creation
settings = await self._refresh_provider_tokens_after_runtime_init(
settings, sid, user_id
)
await self._start_conversation(
sid,
user_id,
@@ -333,7 +440,12 @@ class SaasNestedConversationManager(ConversationManager):
async def _setup_provider_tokens(
self, client: httpx.AsyncClient, api_url: str, settings: Settings
):
"""Setup provider tokens for the nested conversation."""
"""Setup provider tokens for the nested conversation.
Note: Token validation happens in the nested runtime. If tokens are revoked,
the nested runtime will return 401. The caller should handle token refresh
and retry if needed.
"""
provider_handler = self._get_provider_handler(settings)
provider_tokens = provider_handler.provider_tokens
if provider_tokens:
@@ -804,6 +916,8 @@ class SaasNestedConversationManager(ConversationManager):
env_vars['ENABLE_V1'] = '0'
env_vars['SU_TO_USER'] = SU_TO_USER
env_vars['DISABLE_VSCODE_PLUGIN'] = str(DISABLE_VSCODE_PLUGIN).lower()
env_vars['BROWSERGYM_DOWNLOAD_DIR'] = '/workspace/.downloads/'
env_vars['PLAYWRIGHT_BROWSERS_PATH'] = '/opt/playwright-browsers'
# We need this for LLM traces tracking to identify the source of the LLM calls
env_vars['WEB_HOST'] = WEB_HOST
+20
View File
@@ -0,0 +1,20 @@
# Sharing Package
This package contains functionality for sharing conversations.
## Components
- **shared.py**: Data models for shared conversations
- **shared_conversation_info_service.py**: Service interface for accessing shared conversation info
- **sql_shared_conversation_info_service.py**: SQL implementation of the shared conversation info service
- **shared_event_service.py**: Service interface for accessing shared events
- **shared_event_service_impl.py**: Implementation of the shared event service
- **shared_conversation_router.py**: REST API endpoints for shared conversations
- **shared_event_router.py**: REST API endpoints for shared events
## Features
- Read-only access to shared conversations
- Event access for shared conversations
- Search and filtering capabilities
- Pagination support
@@ -0,0 +1,142 @@
"""Implementation of SharedEventService.
This implementation provides read-only access to events from shared conversations:
- Validates that the conversation is shared before returning events
- Uses existing EventService for actual event retrieval
- Uses SharedConversationInfoService for shared conversation validation
"""
from __future__ import annotations
import logging
from dataclasses import dataclass
from datetime import datetime
from typing import AsyncGenerator
from uuid import UUID
from fastapi import Request
from server.sharing.shared_conversation_info_service import (
SharedConversationInfoService,
)
from server.sharing.shared_event_service import (
SharedEventService,
SharedEventServiceInjector,
)
from server.sharing.sql_shared_conversation_info_service import (
SQLSharedConversationInfoService,
)
from openhands.agent_server.models import EventPage, EventSortOrder
from openhands.app_server.event.event_service import EventService
from openhands.app_server.event_callback.event_callback_models import EventKind
from openhands.app_server.services.injector import InjectorState
from openhands.sdk import Event
logger = logging.getLogger(__name__)
@dataclass
class SharedEventServiceImpl(SharedEventService):
"""Implementation of SharedEventService that validates shared access."""
shared_conversation_info_service: SharedConversationInfoService
event_service: EventService
async def get_shared_event(
self, conversation_id: UUID, event_id: str
) -> Event | None:
"""Given a conversation_id and event_id, retrieve an event if the conversation is shared."""
# First check if the conversation is shared
shared_conversation_info = (
await self.shared_conversation_info_service.get_shared_conversation_info(
conversation_id
)
)
if shared_conversation_info is None:
return None
# If conversation is shared, get the event
return await self.event_service.get_event(event_id)
async def search_shared_events(
self,
conversation_id: UUID,
kind__eq: EventKind | None = None,
timestamp__gte: datetime | None = None,
timestamp__lt: datetime | None = None,
sort_order: EventSortOrder = EventSortOrder.TIMESTAMP,
page_id: str | None = None,
limit: int = 100,
) -> EventPage:
"""Search events for a specific shared conversation."""
# First check if the conversation is shared
shared_conversation_info = (
await self.shared_conversation_info_service.get_shared_conversation_info(
conversation_id
)
)
if shared_conversation_info is None:
# Return empty page if conversation is not shared
return EventPage(items=[], next_page_id=None)
# If conversation is shared, search events for this conversation
return await self.event_service.search_events(
conversation_id__eq=conversation_id,
kind__eq=kind__eq,
timestamp__gte=timestamp__gte,
timestamp__lt=timestamp__lt,
sort_order=sort_order,
page_id=page_id,
limit=limit,
)
async def count_shared_events(
self,
conversation_id: UUID,
kind__eq: EventKind | None = None,
timestamp__gte: datetime | None = None,
timestamp__lt: datetime | None = None,
sort_order: EventSortOrder = EventSortOrder.TIMESTAMP,
) -> int:
"""Count events for a specific shared conversation."""
# First check if the conversation is shared
shared_conversation_info = (
await self.shared_conversation_info_service.get_shared_conversation_info(
conversation_id
)
)
if shared_conversation_info is None:
return 0
# If conversation is shared, count events for this conversation
return await self.event_service.count_events(
conversation_id__eq=conversation_id,
kind__eq=kind__eq,
timestamp__gte=timestamp__gte,
timestamp__lt=timestamp__lt,
sort_order=sort_order,
)
class SharedEventServiceImplInjector(SharedEventServiceInjector):
async def inject(
self, state: InjectorState, request: Request | None = None
) -> AsyncGenerator[SharedEventService, None]:
# Define inline to prevent circular lookup
from openhands.app_server.config import (
get_db_session,
get_event_service,
)
async with (
get_db_session(state, request) as db_session,
get_event_service(state, request) as event_service,
):
shared_conversation_info_service = SQLSharedConversationInfoService(
db_session=db_session
)
service = SharedEventServiceImpl(
shared_conversation_info_service=shared_conversation_info_service,
event_service=event_service,
)
yield service
@@ -0,0 +1,66 @@
import asyncio
from abc import ABC, abstractmethod
from datetime import datetime
from uuid import UUID
from server.sharing.shared_conversation_models import (
SharedConversation,
SharedConversationPage,
SharedConversationSortOrder,
)
from openhands.app_server.services.injector import Injector
from openhands.sdk.utils.models import DiscriminatedUnionMixin
class SharedConversationInfoService(ABC):
"""Service for accessing shared conversation info without user restrictions."""
@abstractmethod
async def search_shared_conversation_info(
self,
title__contains: str | None = None,
created_at__gte: datetime | None = None,
created_at__lt: datetime | None = None,
updated_at__gte: datetime | None = None,
updated_at__lt: datetime | None = None,
sort_order: SharedConversationSortOrder = SharedConversationSortOrder.CREATED_AT_DESC,
page_id: str | None = None,
limit: int = 100,
include_sub_conversations: bool = False,
) -> SharedConversationPage:
"""Search for shared conversations."""
@abstractmethod
async def count_shared_conversation_info(
self,
title__contains: str | None = None,
created_at__gte: datetime | None = None,
created_at__lt: datetime | None = None,
updated_at__gte: datetime | None = None,
updated_at__lt: datetime | None = None,
) -> int:
"""Count shared conversations."""
@abstractmethod
async def get_shared_conversation_info(
self, conversation_id: UUID
) -> SharedConversation | None:
"""Get a single shared conversation info, returning None if missing or not shared."""
async def batch_get_shared_conversation_info(
self, conversation_ids: list[UUID]
) -> list[SharedConversation | None]:
"""Get a batch of shared conversation info, return None for any missing or non-shared."""
return await asyncio.gather(
*[
self.get_shared_conversation_info(conversation_id)
for conversation_id in conversation_ids
]
)
class SharedConversationInfoServiceInjector(
DiscriminatedUnionMixin, Injector[SharedConversationInfoService], ABC
):
pass
@@ -0,0 +1,56 @@
from datetime import datetime
from enum import Enum
# Simplified imports to avoid dependency chain issues
# from openhands.integrations.service_types import ProviderType
# from openhands.sdk.llm import MetricsSnapshot
# from openhands.storage.data_models.conversation_metadata import ConversationTrigger
# For now, use Any to avoid import issues
from typing import Any
from uuid import uuid4
from pydantic import BaseModel, Field
from openhands.agent_server.utils import OpenHandsUUID, utc_now
ProviderType = Any
MetricsSnapshot = Any
ConversationTrigger = Any
class SharedConversation(BaseModel):
"""Shared conversation info model with all fields from AppConversationInfo."""
id: OpenHandsUUID = Field(default_factory=uuid4)
created_by_user_id: str | None
sandbox_id: str
selected_repository: str | None = None
selected_branch: str | None = None
git_provider: ProviderType | None = None
title: str | None = None
pr_number: list[int] = Field(default_factory=list)
llm_model: str | None = None
metrics: MetricsSnapshot | None = None
parent_conversation_id: OpenHandsUUID | None = None
sub_conversation_ids: list[OpenHandsUUID] = Field(default_factory=list)
created_at: datetime = Field(default_factory=utc_now)
updated_at: datetime = Field(default_factory=utc_now)
class SharedConversationSortOrder(Enum):
CREATED_AT = 'CREATED_AT'
CREATED_AT_DESC = 'CREATED_AT_DESC'
UPDATED_AT = 'UPDATED_AT'
UPDATED_AT_DESC = 'UPDATED_AT_DESC'
TITLE = 'TITLE'
TITLE_DESC = 'TITLE_DESC'
class SharedConversationPage(BaseModel):
items: list[SharedConversation]
next_page_id: str | None = None
@@ -0,0 +1,135 @@
"""Shared Conversation router for OpenHands Server."""
from datetime import datetime
from typing import Annotated
from uuid import UUID
from fastapi import APIRouter, Depends, Query
from server.sharing.shared_conversation_info_service import (
SharedConversationInfoService,
)
from server.sharing.shared_conversation_models import (
SharedConversation,
SharedConversationPage,
SharedConversationSortOrder,
)
from server.sharing.sql_shared_conversation_info_service import (
SQLSharedConversationInfoServiceInjector,
)
router = APIRouter(prefix='/api/shared-conversations', tags=['Sharing'])
shared_conversation_info_service_dependency = Depends(
SQLSharedConversationInfoServiceInjector().depends
)
# Read methods
@router.get('/search')
async def search_shared_conversations(
title__contains: Annotated[
str | None,
Query(title='Filter by title containing this string'),
] = None,
created_at__gte: Annotated[
datetime | None,
Query(title='Filter by created_at greater than or equal to this datetime'),
] = None,
created_at__lt: Annotated[
datetime | None,
Query(title='Filter by created_at less than this datetime'),
] = None,
updated_at__gte: Annotated[
datetime | None,
Query(title='Filter by updated_at greater than or equal to this datetime'),
] = None,
updated_at__lt: Annotated[
datetime | None,
Query(title='Filter by updated_at less than this datetime'),
] = None,
sort_order: Annotated[
SharedConversationSortOrder,
Query(title='Sort order for results'),
] = SharedConversationSortOrder.CREATED_AT_DESC,
page_id: Annotated[
str | None,
Query(title='Optional next_page_id from the previously returned page'),
] = None,
limit: Annotated[
int,
Query(
title='The max number of results in the page',
gt=0,
lte=100,
),
] = 100,
include_sub_conversations: Annotated[
bool,
Query(
title='If True, include sub-conversations in the results. If False (default), exclude all sub-conversations.'
),
] = False,
shared_conversation_service: SharedConversationInfoService = shared_conversation_info_service_dependency,
) -> SharedConversationPage:
"""Search / List shared conversations."""
assert limit > 0
assert limit <= 100
return await shared_conversation_service.search_shared_conversation_info(
title__contains=title__contains,
created_at__gte=created_at__gte,
created_at__lt=created_at__lt,
updated_at__gte=updated_at__gte,
updated_at__lt=updated_at__lt,
sort_order=sort_order,
page_id=page_id,
limit=limit,
include_sub_conversations=include_sub_conversations,
)
@router.get('/count')
async def count_shared_conversations(
title__contains: Annotated[
str | None,
Query(title='Filter by title containing this string'),
] = None,
created_at__gte: Annotated[
datetime | None,
Query(title='Filter by created_at greater than or equal to this datetime'),
] = None,
created_at__lt: Annotated[
datetime | None,
Query(title='Filter by created_at less than this datetime'),
] = None,
updated_at__gte: Annotated[
datetime | None,
Query(title='Filter by updated_at greater than or equal to this datetime'),
] = None,
updated_at__lt: Annotated[
datetime | None,
Query(title='Filter by updated_at less than this datetime'),
] = None,
shared_conversation_service: SharedConversationInfoService = shared_conversation_info_service_dependency,
) -> int:
"""Count shared conversations matching the given filters."""
return await shared_conversation_service.count_shared_conversation_info(
title__contains=title__contains,
created_at__gte=created_at__gte,
created_at__lt=created_at__lt,
updated_at__gte=updated_at__gte,
updated_at__lt=updated_at__lt,
)
@router.get('')
async def batch_get_shared_conversations(
ids: Annotated[list[str], Query()],
shared_conversation_service: SharedConversationInfoService = shared_conversation_info_service_dependency,
) -> list[SharedConversation | None]:
"""Get a batch of shared conversations given their ids. Return None for any missing or non-shared."""
assert len(ids) <= 100
uuids = [UUID(id_) for id_ in ids]
shared_conversation_info = (
await shared_conversation_service.batch_get_shared_conversation_info(uuids)
)
return shared_conversation_info
@@ -0,0 +1,126 @@
"""Shared Event router for OpenHands Server."""
from datetime import datetime
from typing import Annotated
from uuid import UUID
from fastapi import APIRouter, Depends, Query
from server.sharing.filesystem_shared_event_service import (
SharedEventServiceImplInjector,
)
from server.sharing.shared_event_service import SharedEventService
from openhands.agent_server.models import EventPage, EventSortOrder
from openhands.app_server.event_callback.event_callback_models import EventKind
from openhands.sdk import Event
router = APIRouter(prefix='/api/shared-events', tags=['Sharing'])
shared_event_service_dependency = Depends(SharedEventServiceImplInjector().depends)
# Read methods
@router.get('/search')
async def search_shared_events(
conversation_id: Annotated[
str,
Query(title='Conversation ID to search events for'),
],
kind__eq: Annotated[
EventKind | None,
Query(title='Optional filter by event kind'),
] = None,
timestamp__gte: Annotated[
datetime | None,
Query(title='Optional filter by timestamp greater than or equal to'),
] = None,
timestamp__lt: Annotated[
datetime | None,
Query(title='Optional filter by timestamp less than'),
] = None,
sort_order: Annotated[
EventSortOrder,
Query(title='Sort order for results'),
] = EventSortOrder.TIMESTAMP,
page_id: Annotated[
str | None,
Query(title='Optional next_page_id from the previously returned page'),
] = None,
limit: Annotated[
int,
Query(title='The max number of results in the page', gt=0, lte=100),
] = 100,
shared_event_service: SharedEventService = shared_event_service_dependency,
) -> EventPage:
"""Search / List events for a shared conversation."""
assert limit > 0
assert limit <= 100
return await shared_event_service.search_shared_events(
conversation_id=UUID(conversation_id),
kind__eq=kind__eq,
timestamp__gte=timestamp__gte,
timestamp__lt=timestamp__lt,
sort_order=sort_order,
page_id=page_id,
limit=limit,
)
@router.get('/count')
async def count_shared_events(
conversation_id: Annotated[
str,
Query(title='Conversation ID to count events for'),
],
kind__eq: Annotated[
EventKind | None,
Query(title='Optional filter by event kind'),
] = None,
timestamp__gte: Annotated[
datetime | None,
Query(title='Optional filter by timestamp greater than or equal to'),
] = None,
timestamp__lt: Annotated[
datetime | None,
Query(title='Optional filter by timestamp less than'),
] = None,
sort_order: Annotated[
EventSortOrder,
Query(title='Sort order for results'),
] = EventSortOrder.TIMESTAMP,
shared_event_service: SharedEventService = shared_event_service_dependency,
) -> int:
"""Count events for a shared conversation matching the given filters."""
return await shared_event_service.count_shared_events(
conversation_id=UUID(conversation_id),
kind__eq=kind__eq,
timestamp__gte=timestamp__gte,
timestamp__lt=timestamp__lt,
sort_order=sort_order,
)
@router.get('')
async def batch_get_shared_events(
conversation_id: Annotated[
UUID,
Query(title='Conversation ID to get events for'),
],
id: Annotated[list[str], Query()],
shared_event_service: SharedEventService = shared_event_service_dependency,
) -> list[Event | None]:
"""Get a batch of events for a shared conversation given their ids, returning null for any missing event."""
assert len(id) <= 100
events = await shared_event_service.batch_get_shared_events(conversation_id, id)
return events
@router.get('/{conversation_id}/{event_id}')
async def get_shared_event(
conversation_id: UUID,
event_id: str,
shared_event_service: SharedEventService = shared_event_service_dependency,
) -> Event | None:
"""Get a single event from a shared conversation by conversation_id and event_id."""
return await shared_event_service.get_shared_event(conversation_id, event_id)
@@ -0,0 +1,64 @@
import asyncio
import logging
from abc import ABC, abstractmethod
from datetime import datetime
from uuid import UUID
from openhands.agent_server.models import EventPage, EventSortOrder
from openhands.app_server.event_callback.event_callback_models import EventKind
from openhands.app_server.services.injector import Injector
from openhands.sdk import Event
from openhands.sdk.utils.models import DiscriminatedUnionMixin
_logger = logging.getLogger(__name__)
class SharedEventService(ABC):
"""Event Service for getting events from shared conversations only."""
@abstractmethod
async def get_shared_event(
self, conversation_id: UUID, event_id: str
) -> Event | None:
"""Given a conversation_id and event_id, retrieve an event if the conversation is shared."""
@abstractmethod
async def search_shared_events(
self,
conversation_id: UUID,
kind__eq: EventKind | None = None,
timestamp__gte: datetime | None = None,
timestamp__lt: datetime | None = None,
sort_order: EventSortOrder = EventSortOrder.TIMESTAMP,
page_id: str | None = None,
limit: int = 100,
) -> EventPage:
"""Search events for a specific shared conversation."""
@abstractmethod
async def count_shared_events(
self,
conversation_id: UUID,
kind__eq: EventKind | None = None,
timestamp__gte: datetime | None = None,
timestamp__lt: datetime | None = None,
sort_order: EventSortOrder = EventSortOrder.TIMESTAMP,
) -> int:
"""Count events for a specific shared conversation."""
async def batch_get_shared_events(
self, conversation_id: UUID, event_ids: list[str]
) -> list[Event | None]:
"""Given a conversation_id and list of event_ids, get events if the conversation is shared."""
return await asyncio.gather(
*[
self.get_shared_event(conversation_id, event_id)
for event_id in event_ids
]
)
class SharedEventServiceInjector(
DiscriminatedUnionMixin, Injector[SharedEventService], ABC
):
pass
@@ -0,0 +1,282 @@
"""SQL implementation of SharedConversationInfoService.
This implementation provides read-only access to shared conversations:
- Direct database access without user permission checks
- Filters only conversations marked as shared (currently public)
- Full async/await support using SQL async db_sessions
"""
from __future__ import annotations
import logging
from dataclasses import dataclass
from datetime import UTC, datetime
from typing import AsyncGenerator
from uuid import UUID
from fastapi import Request
from server.sharing.shared_conversation_info_service import (
SharedConversationInfoService,
SharedConversationInfoServiceInjector,
)
from server.sharing.shared_conversation_models import (
SharedConversation,
SharedConversationPage,
SharedConversationSortOrder,
)
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
StoredConversationMetadata,
)
from openhands.app_server.services.injector import InjectorState
from openhands.integrations.provider import ProviderType
from openhands.sdk.llm import MetricsSnapshot
from openhands.sdk.llm.utils.metrics import TokenUsage
logger = logging.getLogger(__name__)
@dataclass
class SQLSharedConversationInfoService(SharedConversationInfoService):
"""SQL implementation of SharedConversationInfoService for shared conversations only."""
db_session: AsyncSession
async def search_shared_conversation_info(
self,
title__contains: str | None = None,
created_at__gte: datetime | None = None,
created_at__lt: datetime | None = None,
updated_at__gte: datetime | None = None,
updated_at__lt: datetime | None = None,
sort_order: SharedConversationSortOrder = SharedConversationSortOrder.CREATED_AT_DESC,
page_id: str | None = None,
limit: int = 100,
include_sub_conversations: bool = False,
) -> SharedConversationPage:
"""Search for shared conversations."""
query = self._public_select()
# Conditionally exclude sub-conversations based on the parameter
if not include_sub_conversations:
# Exclude sub-conversations (only include top-level conversations)
query = query.where(
StoredConversationMetadata.parent_conversation_id.is_(None)
)
query = self._apply_filters(
query=query,
title__contains=title__contains,
created_at__gte=created_at__gte,
created_at__lt=created_at__lt,
updated_at__gte=updated_at__gte,
updated_at__lt=updated_at__lt,
)
# Add sort order
if sort_order == SharedConversationSortOrder.CREATED_AT:
query = query.order_by(StoredConversationMetadata.created_at)
elif sort_order == SharedConversationSortOrder.CREATED_AT_DESC:
query = query.order_by(StoredConversationMetadata.created_at.desc())
elif sort_order == SharedConversationSortOrder.UPDATED_AT:
query = query.order_by(StoredConversationMetadata.last_updated_at)
elif sort_order == SharedConversationSortOrder.UPDATED_AT_DESC:
query = query.order_by(StoredConversationMetadata.last_updated_at.desc())
elif sort_order == SharedConversationSortOrder.TITLE:
query = query.order_by(StoredConversationMetadata.title)
elif sort_order == SharedConversationSortOrder.TITLE_DESC:
query = query.order_by(StoredConversationMetadata.title.desc())
# Apply pagination
if page_id is not None:
try:
offset = int(page_id)
query = query.offset(offset)
except ValueError:
# If page_id is not a valid integer, start from beginning
offset = 0
else:
offset = 0
# Apply limit and get one extra to check if there are more results
query = query.limit(limit + 1)
result = await self.db_session.execute(query)
rows = result.scalars().all()
# Check if there are more results
has_more = len(rows) > limit
if has_more:
rows = rows[:limit]
items = [self._to_shared_conversation(row) for row in rows]
# Calculate next page ID
next_page_id = None
if has_more:
next_page_id = str(offset + limit)
return SharedConversationPage(items=items, next_page_id=next_page_id)
async def count_shared_conversation_info(
self,
title__contains: str | None = None,
created_at__gte: datetime | None = None,
created_at__lt: datetime | None = None,
updated_at__gte: datetime | None = None,
updated_at__lt: datetime | None = None,
) -> int:
"""Count shared conversations matching the given filters."""
from sqlalchemy import func
query = select(func.count(StoredConversationMetadata.conversation_id))
# Only include shared conversations
query = query.where(StoredConversationMetadata.public == True) # noqa: E712
query = query.where(StoredConversationMetadata.conversation_version == 'V1')
query = self._apply_filters(
query=query,
title__contains=title__contains,
created_at__gte=created_at__gte,
created_at__lt=created_at__lt,
updated_at__gte=updated_at__gte,
updated_at__lt=updated_at__lt,
)
result = await self.db_session.execute(query)
return result.scalar() or 0
async def get_shared_conversation_info(
self, conversation_id: UUID
) -> SharedConversation | None:
"""Get a single public conversation info, returning None if missing or not shared."""
query = self._public_select().where(
StoredConversationMetadata.conversation_id == str(conversation_id)
)
result = await self.db_session.execute(query)
stored = result.scalar_one_or_none()
if stored is None:
return None
return self._to_shared_conversation(stored)
def _public_select(self):
"""Create a select query that only returns public conversations."""
query = select(StoredConversationMetadata).where(
StoredConversationMetadata.conversation_version == 'V1'
)
# Only include conversations marked as public
query = query.where(StoredConversationMetadata.public == True) # noqa: E712
return query
def _apply_filters(
self,
query,
title__contains: str | None = None,
created_at__gte: datetime | None = None,
created_at__lt: datetime | None = None,
updated_at__gte: datetime | None = None,
updated_at__lt: datetime | None = None,
):
"""Apply common filters to a query."""
if title__contains is not None:
query = query.where(
StoredConversationMetadata.title.contains(title__contains)
)
if created_at__gte is not None:
query = query.where(
StoredConversationMetadata.created_at >= created_at__gte
)
if created_at__lt is not None:
query = query.where(StoredConversationMetadata.created_at < created_at__lt)
if updated_at__gte is not None:
query = query.where(
StoredConversationMetadata.last_updated_at >= updated_at__gte
)
if updated_at__lt is not None:
query = query.where(
StoredConversationMetadata.last_updated_at < updated_at__lt
)
return query
def _to_shared_conversation(
self,
stored: StoredConversationMetadata,
sub_conversation_ids: list[UUID] | None = None,
) -> SharedConversation:
"""Convert StoredConversationMetadata to SharedConversation."""
# V1 conversations should always have a sandbox_id
sandbox_id = stored.sandbox_id
assert sandbox_id is not None
# Rebuild token usage
token_usage = TokenUsage(
prompt_tokens=stored.prompt_tokens,
completion_tokens=stored.completion_tokens,
cache_read_tokens=stored.cache_read_tokens,
cache_write_tokens=stored.cache_write_tokens,
context_window=stored.context_window,
per_turn_token=stored.per_turn_token,
)
# Rebuild metrics object
metrics = MetricsSnapshot(
accumulated_cost=stored.accumulated_cost,
max_budget_per_task=stored.max_budget_per_task,
accumulated_token_usage=token_usage,
)
# Get timestamps
created_at = self._fix_timezone(stored.created_at)
updated_at = self._fix_timezone(stored.last_updated_at)
return SharedConversation(
id=UUID(stored.conversation_id),
created_by_user_id=stored.user_id if stored.user_id else None,
sandbox_id=stored.sandbox_id,
selected_repository=stored.selected_repository,
selected_branch=stored.selected_branch,
git_provider=(
ProviderType(stored.git_provider) if stored.git_provider else None
),
title=stored.title,
pr_number=stored.pr_number,
llm_model=stored.llm_model,
metrics=metrics,
parent_conversation_id=(
UUID(stored.parent_conversation_id)
if stored.parent_conversation_id
else None
),
sub_conversation_ids=sub_conversation_ids or [],
created_at=created_at,
updated_at=updated_at,
)
def _fix_timezone(self, value: datetime) -> datetime:
"""Sqlite does not store timezones - and since we can't update the existing models
we assume UTC if the timezone is missing."""
if not value.tzinfo:
value = value.replace(tzinfo=UTC)
return value
class SQLSharedConversationInfoServiceInjector(SharedConversationInfoServiceInjector):
async def inject(
self, state: InjectorState, request: Request | None = None
) -> AsyncGenerator[SharedConversationInfoService, None]:
# Define inline to prevent circular lookup
from openhands.app_server.config import get_db_session
async with get_db_session(state, request) as db_session:
service = SQLSharedConversationInfoService(db_session=db_session)
yield service
@@ -0,0 +1,83 @@
from fastapi import HTTPException, Request, status
from openhands.core.logger import openhands_logger as logger
from openhands.server.shared import sio
# Rate limiting constants
RATE_LIMIT_USER_SECONDS = 120 # 2 minutes per user_id
RATE_LIMIT_IP_SECONDS = 300 # 5 minutes per IP address
async def check_rate_limit_by_user_id(
request: Request,
key_prefix: str,
user_id: str | None,
user_rate_limit_seconds: int = RATE_LIMIT_USER_SECONDS,
ip_rate_limit_seconds: int = RATE_LIMIT_IP_SECONDS,
) -> None:
"""
Check rate limit for requests, using user_id when available, falling back to IP address.
Uses Redis to store rate limit keys with expiration. If a key already exists,
it means the rate limit is active and the request will be rejected.
Args:
request: FastAPI Request object
key_prefix: Prefix for the Redis key (e.g., "email_resend")
user_id: User ID if available, None otherwise
user_rate_limit_seconds: Rate limit window in seconds for user_id-based limiting (default: 120)
ip_rate_limit_seconds: Rate limit window in seconds for IP-based limiting (default: 300)
Raises:
HTTPException: If rate limit is exceeded (429 status code)
"""
try:
redis = sio.manager.redis
if not redis:
# If Redis is unavailable, log warning and allow request (fail open)
logger.warning('Redis unavailable for rate limiting, allowing request')
return
if user_id:
# Rate limit by user_id (primary method)
rate_limit_key = f'{key_prefix}:{user_id}'
rate_limit_seconds = user_rate_limit_seconds
else:
# Fallback to IP address rate limiting
client_ip = request.client.host if request.client else 'unknown'
rate_limit_key = f'{key_prefix}:ip:{client_ip}'
rate_limit_seconds = ip_rate_limit_seconds
# Try to set the key with expiration. If it already exists (nx=True fails),
# it means the rate limit is active
created = await redis.set(rate_limit_key, 1, nx=True, ex=rate_limit_seconds)
if not created:
logger.info(
f'Rate limit exceeded for {rate_limit_key}',
extra={
'user_id': user_id,
'ip': request.client.host if request.client else 'unknown',
},
)
# Format error message based on duration
if rate_limit_seconds < 60:
wait_message = f'{rate_limit_seconds} seconds'
elif rate_limit_seconds % 60 == 0:
wait_message = f'{rate_limit_seconds // 60} minute{"s" if rate_limit_seconds // 60 != 1 else ""}'
else:
minutes = rate_limit_seconds // 60
seconds = rate_limit_seconds % 60
wait_message = f'{minutes} minute{"s" if minutes != 1 else ""} and {seconds} second{"s" if seconds != 1 else ""}'
raise HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail=f'Too many requests. Please wait {wait_message} before trying again.',
)
except HTTPException:
# Re-raise HTTPException (rate limit exceeded)
raise
except Exception as e:
# Log error but allow request (fail open) to avoid blocking legitimate users
logger.warning(f'Error checking rate limit: {e}', exc_info=True)
return
+10 -2
View File
@@ -19,17 +19,23 @@ GCP_REGION = os.environ.get('GCP_REGION')
POOL_SIZE = int(os.environ.get('DB_POOL_SIZE', '25'))
MAX_OVERFLOW = int(os.environ.get('DB_MAX_OVERFLOW', '10'))
POOL_RECYCLE = int(os.environ.get('DB_POOL_RECYCLE', '1800'))
# Initialize Cloud SQL Connector once at module level for GCP environments.
_connector = None
def _get_db_engine():
if GCP_DB_INSTANCE: # GCP environments
def get_db_connection():
global _connector
from google.cloud.sql.connector import Connector
connector = Connector()
if not _connector:
_connector = Connector()
instance_string = f'{GCP_PROJECT}:{GCP_REGION}:{GCP_DB_INSTANCE}'
return connector.connect(
return _connector.connect(
instance_string, 'pg8000', user=DB_USER, password=DB_PASS, db=DB_NAME
)
@@ -38,6 +44,7 @@ def _get_db_engine():
creator=get_db_connection,
pool_size=POOL_SIZE,
max_overflow=MAX_OVERFLOW,
pool_recycle=POOL_RECYCLE,
pool_pre_ping=True,
)
else:
@@ -48,6 +55,7 @@ def _get_db_engine():
host_string,
pool_size=POOL_SIZE,
max_overflow=MAX_OVERFLOW,
pool_recycle=POOL_RECYCLE,
pool_pre_ping=True,
)
+121
View File
@@ -220,6 +220,127 @@ class GitlabWebhookStore:
return webhooks[0].webhook_secret
return None
async def get_webhook_by_resource_only(
self, resource_type: GitLabResourceType, resource_id: str
) -> GitlabWebhook | None:
"""Get a webhook by resource without filtering by user_id.
This allows any admin user in the organization to manage webhooks,
not just the original installer.
Args:
resource_type: The type of resource (PROJECT or GROUP)
resource_id: The ID of the resource
Returns:
GitlabWebhook object if found, None otherwise
"""
async with self.a_session_maker() as session:
if resource_type == GitLabResourceType.PROJECT:
query = select(GitlabWebhook).where(
GitlabWebhook.project_id == resource_id
)
else: # GROUP
query = select(GitlabWebhook).where(
GitlabWebhook.group_id == resource_id
)
result = await session.execute(query)
webhook = result.scalars().first()
return webhook
async def get_webhooks_by_resources(
self, project_ids: list[str], group_ids: list[str]
) -> tuple[dict[str, GitlabWebhook], dict[str, GitlabWebhook]]:
"""Bulk fetch webhooks for multiple resources.
This is more efficient than fetching one at a time in a loop.
Args:
project_ids: List of project IDs to fetch
group_ids: List of group IDs to fetch
Returns:
Tuple of (project_webhook_map, group_webhook_map)
"""
async with self.a_session_maker() as session:
project_webhook_map = {}
group_webhook_map = {}
# Fetch all project webhooks in one query
if project_ids:
project_query = select(GitlabWebhook).where(
GitlabWebhook.project_id.in_(project_ids)
)
result = await session.execute(project_query)
project_webhooks = result.scalars().all()
project_webhook_map = {wh.project_id: wh for wh in project_webhooks}
# Fetch all group webhooks in one query
if group_ids:
group_query = select(GitlabWebhook).where(
GitlabWebhook.group_id.in_(group_ids)
)
result = await session.execute(group_query)
group_webhooks = result.scalars().all()
group_webhook_map = {wh.group_id: wh for wh in group_webhooks}
return project_webhook_map, group_webhook_map
async def reset_webhook_for_reinstallation_by_resource(
self, resource_type: GitLabResourceType, resource_id: str, updating_user_id: str
) -> bool:
"""Reset webhook for reinstallation without filtering by user_id.
This allows any admin user to reset webhooks, and updates the user_id
to track who last modified it.
Args:
resource_type: The type of resource (PROJECT or GROUP)
resource_id: The ID of the resource
updating_user_id: The user ID performing the update (for audit purposes)
Returns:
True if webhook was reset, False if not found
"""
async with self.a_session_maker() as session:
async with session.begin():
if resource_type == GitLabResourceType.PROJECT:
update_statement = (
update(GitlabWebhook)
.where(GitlabWebhook.project_id == resource_id)
.values(
webhook_exists=False,
webhook_uuid=None,
user_id=updating_user_id, # Update to track who modified it
)
)
else: # GROUP
update_statement = (
update(GitlabWebhook)
.where(GitlabWebhook.group_id == resource_id)
.values(
webhook_exists=False,
webhook_uuid=None,
user_id=updating_user_id, # Update to track who modified it
)
)
result = await session.execute(update_statement)
rows_updated = result.rowcount
logger.info(
'Reset webhook for reinstallation (organization-wide)',
extra={
'updating_user_id': updating_user_id,
'resource_type': resource_type.value,
'resource_id': resource_id,
'rows_updated': rows_updated,
},
)
return rows_updated > 0
@classmethod
async def get_instance(cls) -> GitlabWebhookStore:
"""Get an instance of the GitlabWebhookStore.
+21 -13
View File
@@ -2,6 +2,7 @@ from dataclasses import dataclass
from datetime import datetime
from sqlalchemy import and_, desc
from sqlalchemy.exc import ProgrammingError
from sqlalchemy.orm import sessionmaker
from storage.database import session_maker
from storage.openhands_pr import OpenhandsPR
@@ -135,22 +136,29 @@ class OpenhandsPRStore:
Returns:
List of OpenhandsPR objects that need processing
"""
with self.session_maker() as session:
unprocessed_prs = (
session.query(OpenhandsPR)
.filter(
and_(
~OpenhandsPR.processed,
OpenhandsPR.process_attempts < max_retries,
OpenhandsPR.provider == ProviderType.GITHUB.value,
try:
with self.session_maker() as session:
unprocessed_prs = (
session.query(OpenhandsPR)
.filter(
and_(
~OpenhandsPR.processed,
OpenhandsPR.process_attempts < max_retries,
OpenhandsPR.provider == ProviderType.GITHUB.value,
)
)
.order_by(desc(OpenhandsPR.updated_at))
.limit(limit)
.all()
)
.order_by(desc(OpenhandsPR.updated_at))
.limit(limit)
.all()
)
return unprocessed_prs
return unprocessed_prs
except ProgrammingError as e:
logger.warning(
f'Could not query openhands_prs table - it may not exist yet. '
f'Run database migrations first. Error: {e}'
)
return []
@classmethod
def get_instance(cls):
@@ -61,6 +61,7 @@ class SaasConversationStore(ConversationStore):
kwargs.pop('context_window', None)
kwargs.pop('per_turn_token', None)
kwargs.pop('parent_conversation_id', None)
kwargs.pop('public')
return ConversationMetadata(**kwargs)
+90 -13
View File
@@ -19,6 +19,7 @@ from server.constants import (
LITE_LLM_API_URL,
LITE_LLM_TEAM_ID,
REQUIRE_PAYMENT,
USER_SETTINGS_VERSION_TO_MODEL,
get_default_litellm_model,
)
from server.logger import logger
@@ -202,6 +203,53 @@ class SaasSettingsStore(SettingsStore):
)
return None
def _has_custom_settings(
self, settings: Settings, old_user_version: int | None
) -> bool:
"""
Check if user has custom LLM settings that should be preserved.
Returns True if user customized either model or base_url.
Args:
settings: The user's current settings
old_user_version: The user's old settings version, if any
Returns:
True if user has custom settings, False if using old defaults
"""
# Normalize values
user_model = (
settings.llm_model.strip()
if settings.llm_model and settings.llm_model.strip()
else None
)
user_base_url = (
settings.llm_base_url.strip()
if settings.llm_base_url and settings.llm_base_url.strip()
else None
)
# Custom base_url = definitely custom settings (BYOK)
if user_base_url and user_base_url != LITE_LLM_API_URL:
return True
# No model set = using defaults
if not user_model:
return False
# Check if model matches old version's default
if (
old_user_version
and old_user_version < CURRENT_USER_SETTINGS_VERSION
and old_user_version in USER_SETTINGS_VERSION_TO_MODEL
):
old_default_base = USER_SETTINGS_VERSION_TO_MODEL[old_user_version]
user_model_base = user_model.split('/')[-1]
if user_model_base == old_default_base:
return False # Matches old default
return True # Custom model
async def update_settings_with_litellm_default(
self, settings: Settings
) -> Settings | None:
@@ -213,6 +261,17 @@ class SaasSettingsStore(SettingsStore):
return None
local_deploy = os.environ.get('LOCAL_DEPLOYMENT', None)
key = LITE_LLM_API_KEY
# Check if user has custom settings
has_custom = self._has_custom_settings(settings, settings.user_version)
# Determine model to use (needed before LiteLLM user creation)
llm_model_to_use = (
settings.llm_model
if has_custom and settings.llm_model
else get_default_litellm_model()
)
if not local_deploy:
# Get user info to add to litellm
token_manager = TokenManager()
@@ -226,14 +285,21 @@ class SaasSettingsStore(SettingsStore):
'x-goog-api-key': LITE_LLM_API_KEY,
},
) as client:
# Get the previous max budget to prevent accidental loss
# In Litellm a get always succeeds, regardless of whether the user actually exists
# Get the previous max budget to prevent accidental loss.
#
# LiteLLM v1.80+ returns 404 for non-existent users (previously returned empty user_info)
response = await client.get(
f'{LITE_LLM_API_URL}/user/info?user_id={self.user_id}'
)
response.raise_for_status()
response_json = response.json()
user_info = response_json.get('user_info') or {}
user_info: dict
if response.status_code == 404:
# New user - doesn't exist in LiteLLM yet (v1.80+ behavior)
user_info = {}
else:
# For any other status, use standard error handling
response.raise_for_status()
response_json = response.json()
user_info = response_json.get('user_info') or {}
logger.info(
f'creating_litellm_user: {self.user_id}; prev_max_budget: {user_info.get("max_budget")}; prev_metadata: {user_info.get("metadata")}'
)
@@ -276,7 +342,7 @@ class SaasSettingsStore(SettingsStore):
# Create the new litellm user
response = await self._create_user_in_lite_llm(
client, email, max_budget, spend
client, email, max_budget, spend, llm_model_to_use
)
if not response.is_success:
logger.warning(
@@ -285,7 +351,7 @@ class SaasSettingsStore(SettingsStore):
)
# Litellm insists on unique email addresses - it is possible the email address was registered with a different user.
response = await self._create_user_in_lite_llm(
client, None, max_budget, spend
client, None, max_budget, spend, llm_model_to_use
)
# User failed to create in litellm - this is an unforseen error state...
@@ -311,11 +377,17 @@ class SaasSettingsStore(SettingsStore):
extra={'user_id': self.user_id},
)
if has_custom:
settings.llm_model = settings.llm_model or get_default_litellm_model()
settings.llm_base_url = settings.llm_base_url or LITE_LLM_API_URL
settings.llm_api_key = settings.llm_api_key or SecretStr(key)
else:
settings.llm_model = get_default_litellm_model()
settings.llm_base_url = LITE_LLM_API_URL
settings.llm_api_key = SecretStr(key)
settings.agent = 'CodeActAgent'
# Use the model corresponding to the current user settings version
settings.llm_model = get_default_litellm_model()
settings.llm_api_key = SecretStr(key)
settings.llm_base_url = LITE_LLM_API_URL
return settings
@classmethod
@@ -398,7 +470,12 @@ class SaasSettingsStore(SettingsStore):
)
async def _create_user_in_lite_llm(
self, client: httpx.AsyncClient, email: str | None, max_budget: int, spend: int
self,
client: httpx.AsyncClient,
email: str | None,
max_budget: int,
spend: int,
llm_model: str,
):
response = await client.post(
f'{LITE_LLM_API_URL}/user/new',
@@ -413,7 +490,7 @@ class SaasSettingsStore(SettingsStore):
'send_invite_email': False,
'metadata': {
'version': CURRENT_USER_SETTINGS_VERSION,
'model': get_default_litellm_model(),
'model': llm_model,
},
'key_alias': f'OpenHands Cloud - user {self.user_id}',
},
+33 -144
View File
@@ -1,9 +1,15 @@
import asyncio
from typing import cast
from uuid import uuid4
from integrations.gitlab.webhook_installation import (
BreakLoopException,
install_webhook_on_resource,
verify_webhook_conditions,
)
from integrations.types import GitLabResourceType
from integrations.utils import GITLAB_WEBHOOK_URL
from sqlalchemy import text
from storage.database import a_session_maker
from storage.gitlab_webhook import GitlabWebhook, WebhookStatus
from storage.gitlab_webhook_store import GitlabWebhookStore
@@ -12,20 +18,6 @@ from openhands.integrations.gitlab.gitlab_service import GitLabServiceImpl
from openhands.integrations.service_types import GitService
CHUNK_SIZE = 100
WEBHOOK_NAME = 'OpenHands Resolver'
SCOPES: list[str] = [
'note_events',
'merge_requests_events',
'confidential_issues_events',
'issues_events',
'confidential_note_events',
'job_events',
'pipeline_events',
]
class BreakLoopException(Exception):
pass
class VerifyWebhookStatus:
@@ -41,77 +33,6 @@ class VerifyWebhookStatus:
if status == WebhookStatus.RATE_LIMITED:
raise BreakLoopException()
async def check_if_resource_exists(
self,
gitlab_service: type[GitService],
resource_type: GitLabResourceType,
resource_id: str,
webhook_store: GitlabWebhookStore,
webhook: GitlabWebhook,
):
"""
Check if the GitLab resource still exists
"""
from integrations.gitlab.gitlab_service import SaaSGitLabService
gitlab_service = cast(type[SaaSGitLabService], gitlab_service)
does_resource_exist, status = await gitlab_service.check_resource_exists(
resource_type, resource_id
)
logger.info(
'Does resource exists',
extra={
'does_resource_exist': does_resource_exist,
'status': status,
'resource_id': resource_id,
'resource_type': resource_type,
},
)
self.determine_if_rate_limited(status)
if not does_resource_exist and status != WebhookStatus.RATE_LIMITED:
await webhook_store.delete_webhook(webhook)
raise BreakLoopException()
async def check_if_user_has_admin_acccess_to_resource(
self,
gitlab_service: type[GitService],
resource_type: GitLabResourceType,
resource_id: str,
webhook_store: GitlabWebhookStore,
webhook: GitlabWebhook,
):
"""
Check is user still has permission to resource
"""
from integrations.gitlab.gitlab_service import SaaSGitLabService
gitlab_service = cast(type[SaaSGitLabService], gitlab_service)
(
is_user_admin_of_resource,
status,
) = await gitlab_service.check_user_has_admin_access_to_resource(
resource_type, resource_id
)
logger.info(
'Is user admin',
extra={
'is_user_admin': is_user_admin_of_resource,
'status': status,
'resource_id': resource_id,
'resource_type': resource_type,
},
)
self.determine_if_rate_limited(status)
if not is_user_admin_of_resource:
await webhook_store.delete_webhook(webhook)
raise BreakLoopException()
async def check_if_webhook_already_exists_on_resource(
self,
gitlab_service: type[GitService],
@@ -160,23 +81,8 @@ class VerifyWebhookStatus:
webhook_store: GitlabWebhookStore,
webhook: GitlabWebhook,
):
await self.check_if_resource_exists(
gitlab_service=gitlab_service,
resource_type=resource_type,
resource_id=resource_id,
webhook_store=webhook_store,
webhook=webhook,
)
await self.check_if_user_has_admin_acccess_to_resource(
gitlab_service=gitlab_service,
resource_type=resource_type,
resource_id=resource_id,
webhook_store=webhook_store,
webhook=webhook,
)
await self.check_if_webhook_already_exists_on_resource(
# Use the standalone function
await verify_webhook_conditions(
gitlab_service=gitlab_service,
resource_type=resource_type,
resource_id=resource_id,
@@ -195,51 +101,15 @@ class VerifyWebhookStatus:
"""
Install webhook on resource
"""
from integrations.gitlab.gitlab_service import SaaSGitLabService
gitlab_service = cast(type[SaaSGitLabService], gitlab_service)
webhook_secret = f'{webhook.user_id}-{str(uuid4())}'
webhook_uuid = f'{str(uuid4())}'
webhook_id, status = await gitlab_service.install_webhook(
# Use the standalone function
await install_webhook_on_resource(
gitlab_service=gitlab_service,
resource_type=resource_type,
resource_id=resource_id,
webhook_name=WEBHOOK_NAME,
webhook_url=GITLAB_WEBHOOK_URL,
webhook_secret=webhook_secret,
webhook_uuid=webhook_uuid,
scopes=SCOPES,
webhook_store=webhook_store,
webhook=webhook,
)
logger.info(
'Creating new webhook',
extra={
'webhook_id': webhook_id,
'status': status,
'resource_id': resource_id,
'resource_type': resource_type,
},
)
self.determine_if_rate_limited(status)
if webhook_id:
await webhook_store.update_webhook(
webhook=webhook,
update_fields={
'webhook_secret': webhook_secret,
'webhook_exists': True, # webhook was created
'webhook_url': GITLAB_WEBHOOK_URL,
'scopes': SCOPES,
'webhook_uuid': webhook_uuid, # required to identify which webhook installation is sending payload
},
)
logger.info(
f'Installed webhook for {webhook.user_id} on {resource_type}:{resource_id}'
)
async def install_webhooks(self):
"""
Periodically check the conditions for installing a webhook on resource as valid
@@ -258,6 +128,25 @@ class VerifyWebhookStatus:
from integrations.gitlab.gitlab_service import SaaSGitLabService
# Check if the table exists before proceeding
# This handles cases where the CronJob runs before database migrations complete
async with a_session_maker() as session:
query = text("""
SELECT EXISTS (
SELECT FROM information_schema.tables
WHERE table_name = 'gitlab_webhook'
)
""")
result = await session.execute(query)
table_exists = result.scalar() or False
if not table_exists:
logger.info(
'gitlab_webhook table does not exist yet, '
'waiting for database migrations to complete'
)
return
# Get an instance of the webhook store
webhook_store = await GitlabWebhookStore.get_instance()
@@ -0,0 +1,204 @@
"""Unit tests for SaaSGitLabService."""
from unittest.mock import patch
import pytest
from integrations.gitlab.gitlab_service import SaaSGitLabService
@pytest.fixture
def gitlab_service():
"""Create a SaaSGitLabService instance for testing."""
return SaaSGitLabService(external_auth_id='test_user_id')
class TestGetUserResourcesWithAdminAccess:
"""Test cases for get_user_resources_with_admin_access method."""
@pytest.mark.asyncio
async def test_get_resources_single_page_projects_and_groups(self, gitlab_service):
"""Test fetching resources when all data fits in a single page."""
# Arrange
mock_projects = [
{'id': 1, 'name': 'Project 1'},
{'id': 2, 'name': 'Project 2'},
]
mock_groups = [
{'id': 10, 'name': 'Group 1'},
]
with patch.object(gitlab_service, '_make_request') as mock_request:
# First call for projects, second for groups
mock_request.side_effect = [
(mock_projects, {'Link': ''}), # No next page
(mock_groups, {'Link': ''}), # No next page
]
# Act
(
projects,
groups,
) = await gitlab_service.get_user_resources_with_admin_access()
# Assert
assert len(projects) == 2
assert len(groups) == 1
assert projects[0]['id'] == 1
assert projects[1]['id'] == 2
assert groups[0]['id'] == 10
assert mock_request.call_count == 2
@pytest.mark.asyncio
async def test_get_resources_multiple_pages_projects(self, gitlab_service):
"""Test fetching projects across multiple pages."""
# Arrange
page1_projects = [{'id': i, 'name': f'Project {i}'} for i in range(1, 101)]
page2_projects = [{'id': i, 'name': f'Project {i}'} for i in range(101, 151)]
with patch.object(gitlab_service, '_make_request') as mock_request:
mock_request.side_effect = [
(page1_projects, {'Link': '<url>; rel="next"'}), # Has next page
(page2_projects, {'Link': ''}), # Last page
([], {'Link': ''}), # Groups (empty)
]
# Act
(
projects,
groups,
) = await gitlab_service.get_user_resources_with_admin_access()
# Assert
assert len(projects) == 150
assert len(groups) == 0
assert mock_request.call_count == 3
@pytest.mark.asyncio
async def test_get_resources_multiple_pages_groups(self, gitlab_service):
"""Test fetching groups across multiple pages."""
# Arrange
page1_groups = [{'id': i, 'name': f'Group {i}'} for i in range(1, 101)]
page2_groups = [{'id': i, 'name': f'Group {i}'} for i in range(101, 151)]
with patch.object(gitlab_service, '_make_request') as mock_request:
mock_request.side_effect = [
([], {'Link': ''}), # Projects (empty)
(page1_groups, {'Link': '<url>; rel="next"'}), # Has next page
(page2_groups, {'Link': ''}), # Last page
]
# Act
(
projects,
groups,
) = await gitlab_service.get_user_resources_with_admin_access()
# Assert
assert len(projects) == 0
assert len(groups) == 150
assert mock_request.call_count == 3
@pytest.mark.asyncio
async def test_get_resources_empty_response(self, gitlab_service):
"""Test when user has no projects or groups with admin access."""
# Arrange
with patch.object(gitlab_service, '_make_request') as mock_request:
mock_request.side_effect = [
([], {'Link': ''}), # No projects
([], {'Link': ''}), # No groups
]
# Act
(
projects,
groups,
) = await gitlab_service.get_user_resources_with_admin_access()
# Assert
assert len(projects) == 0
assert len(groups) == 0
assert mock_request.call_count == 2
@pytest.mark.asyncio
async def test_get_resources_uses_correct_params_for_projects(self, gitlab_service):
"""Test that projects API is called with correct parameters."""
# Arrange
with patch.object(gitlab_service, '_make_request') as mock_request:
mock_request.side_effect = [
([], {'Link': ''}), # Projects
([], {'Link': ''}), # Groups
]
# Act
await gitlab_service.get_user_resources_with_admin_access()
# Assert
# Check first call (projects)
first_call = mock_request.call_args_list[0]
assert 'projects' in first_call[0][0]
assert first_call[0][1]['membership'] == 1
assert first_call[0][1]['min_access_level'] == 40
assert first_call[0][1]['per_page'] == '100'
@pytest.mark.asyncio
async def test_get_resources_uses_correct_params_for_groups(self, gitlab_service):
"""Test that groups API is called with correct parameters."""
# Arrange
with patch.object(gitlab_service, '_make_request') as mock_request:
mock_request.side_effect = [
([], {'Link': ''}), # Projects
([], {'Link': ''}), # Groups
]
# Act
await gitlab_service.get_user_resources_with_admin_access()
# Assert
# Check second call (groups)
second_call = mock_request.call_args_list[1]
assert 'groups' in second_call[0][0]
assert second_call[0][1]['min_access_level'] == 40
assert second_call[0][1]['top_level_only'] == 'true'
assert second_call[0][1]['per_page'] == '100'
@pytest.mark.asyncio
async def test_get_resources_handles_api_error_gracefully(self, gitlab_service):
"""Test that API errors are handled gracefully and don't crash."""
# Arrange
with patch.object(gitlab_service, '_make_request') as mock_request:
# First call succeeds, second call fails
mock_request.side_effect = [
([{'id': 1, 'name': 'Project 1'}], {'Link': ''}),
Exception('API Error'),
]
# Act
(
projects,
groups,
) = await gitlab_service.get_user_resources_with_admin_access()
# Assert
# Should return what was fetched before the error
assert len(projects) == 1
assert len(groups) == 0
@pytest.mark.asyncio
async def test_get_resources_stops_on_empty_response(self, gitlab_service):
"""Test that pagination stops when API returns empty response."""
# Arrange
with patch.object(gitlab_service, '_make_request') as mock_request:
mock_request.side_effect = [
(None, {'Link': ''}), # Empty response stops pagination
([], {'Link': ''}), # Groups
]
# Act
(
projects,
groups,
) = await gitlab_service.get_user_resources_with_admin_access()
# Assert
assert len(projects) == 0
assert mock_request.call_count == 2 # Should not continue pagination
@@ -18,7 +18,11 @@ from integrations.jira.jira_view import (
from integrations.models import Message, SourceType
from openhands.integrations.service_types import ProviderType, Repository
from openhands.server.types import LLMAuthenticationError, MissingSettingsError
from openhands.server.types import (
LLMAuthenticationError,
MissingSettingsError,
SessionExpiredError,
)
class TestJiraManagerInit:
@@ -732,6 +736,32 @@ class TestStartJob:
call_args = jira_manager.send_message.call_args[0]
assert 'valid LLM API key' in call_args[0].message
@pytest.mark.asyncio
async def test_start_job_session_expired_error(
self, jira_manager, sample_jira_workspace
):
"""Test job start with session expired error."""
mock_view = MagicMock(spec=JiraNewConversationView)
mock_view.jira_user = MagicMock()
mock_view.jira_user.keycloak_user_id = 'test_user'
mock_view.job_context = MagicMock()
mock_view.job_context.issue_key = 'PROJ-123'
mock_view.jira_workspace = sample_jira_workspace
mock_view.create_or_update_conversation = AsyncMock(
side_effect=SessionExpiredError('Session expired')
)
jira_manager.send_message = AsyncMock()
jira_manager.token_manager.decrypt_text.return_value = 'decrypted_key'
await jira_manager.start_job(mock_view)
# Should send error message about session expired
jira_manager.send_message.assert_called_once()
call_args = jira_manager.send_message.call_args[0]
assert 'session has expired' in call_args[0].message
assert 'login again' in call_args[0].message
@pytest.mark.asyncio
async def test_start_job_unexpected_error(
self, jira_manager, sample_jira_workspace
@@ -18,7 +18,11 @@ from integrations.jira_dc.jira_dc_view import (
from integrations.models import Message, SourceType
from openhands.integrations.service_types import ProviderType, Repository
from openhands.server.types import LLMAuthenticationError, MissingSettingsError
from openhands.server.types import (
LLMAuthenticationError,
MissingSettingsError,
SessionExpiredError,
)
class TestJiraDcManagerInit:
@@ -761,6 +765,32 @@ class TestStartJob:
call_args = jira_dc_manager.send_message.call_args[0]
assert 'valid LLM API key' in call_args[0].message
@pytest.mark.asyncio
async def test_start_job_session_expired_error(
self, jira_dc_manager, sample_jira_dc_workspace
):
"""Test job start with session expired error."""
mock_view = MagicMock(spec=JiraDcNewConversationView)
mock_view.jira_dc_user = MagicMock()
mock_view.jira_dc_user.keycloak_user_id = 'test_user'
mock_view.job_context = MagicMock()
mock_view.job_context.issue_key = 'PROJ-123'
mock_view.jira_dc_workspace = sample_jira_dc_workspace
mock_view.create_or_update_conversation = AsyncMock(
side_effect=SessionExpiredError('Session expired')
)
jira_dc_manager.send_message = AsyncMock()
jira_dc_manager.token_manager.decrypt_text.return_value = 'decrypted_key'
await jira_dc_manager.start_job(mock_view)
# Should send error message about session expired
jira_dc_manager.send_message.assert_called_once()
call_args = jira_dc_manager.send_message.call_args[0]
assert 'session has expired' in call_args[0].message
assert 'login again' in call_args[0].message
@pytest.mark.asyncio
async def test_start_job_unexpected_error(
self, jira_dc_manager, sample_jira_dc_workspace
@@ -18,7 +18,11 @@ from integrations.linear.linear_view import (
from integrations.models import Message, SourceType
from openhands.integrations.service_types import ProviderType, Repository
from openhands.server.types import LLMAuthenticationError, MissingSettingsError
from openhands.server.types import (
LLMAuthenticationError,
MissingSettingsError,
SessionExpiredError,
)
class TestLinearManagerInit:
@@ -826,6 +830,33 @@ class TestStartJob:
call_args = linear_manager.send_message.call_args[0]
assert 'valid LLM API key' in call_args[0].message
@pytest.mark.asyncio
async def test_start_job_session_expired_error(
self, linear_manager, sample_linear_workspace
):
"""Test job start with session expired error."""
mock_view = MagicMock(spec=LinearNewConversationView)
mock_view.linear_user = MagicMock()
mock_view.linear_user.keycloak_user_id = 'test_user'
mock_view.job_context = MagicMock()
mock_view.job_context.issue_key = 'TEST-123'
mock_view.job_context.issue_id = 'issue_id'
mock_view.linear_workspace = sample_linear_workspace
mock_view.create_or_update_conversation = AsyncMock(
side_effect=SessionExpiredError('Session expired')
)
linear_manager.send_message = AsyncMock()
linear_manager.token_manager.decrypt_text.return_value = 'decrypted_key'
await linear_manager.start_job(mock_view)
# Should send error message about session expired
linear_manager.send_message.assert_called_once()
call_args = linear_manager.send_message.call_args[0]
assert 'session has expired' in call_args[0].message
assert 'login again' in call_args[0].message
@pytest.mark.asyncio
async def test_start_job_unexpected_error(
self, linear_manager, sample_linear_workspace
@@ -1,7 +1,14 @@
"""Tests for enterprise integrations utils module."""
from unittest.mock import patch
import pytest
from integrations.utils import get_summary_for_agent_state
from integrations.utils import (
HOST_URL,
append_conversation_footer,
get_session_expired_message,
get_summary_for_agent_state,
)
from openhands.core.schema.agent import AgentState
from openhands.events.observation.agent import AgentStateChangedObservation
@@ -157,3 +164,200 @@ class TestGetSummaryForAgentState:
assert 'try again later' in result.lower()
# RATE_LIMITED doesn't include conversation link in response
assert self.conversation_link not in result
class TestGetSessionExpiredMessage:
"""Test cases for get_session_expired_message function."""
def test_message_with_username_contains_at_prefix(self):
"""Test that the message contains the username with @ prefix."""
result = get_session_expired_message('testuser')
assert '@testuser' in result
def test_message_with_username_contains_session_expired_text(self):
"""Test that the message contains session expired text."""
result = get_session_expired_message('testuser')
assert 'session has expired' in result
def test_message_with_username_contains_login_instruction(self):
"""Test that the message contains login instruction."""
result = get_session_expired_message('testuser')
assert 'login again' in result
def test_message_with_username_contains_host_url(self):
"""Test that the message contains the OpenHands Cloud URL."""
result = get_session_expired_message('testuser')
assert HOST_URL in result
assert 'OpenHands Cloud' in result
def test_different_usernames(self):
"""Test that different usernames produce different messages."""
result1 = get_session_expired_message('user1')
result2 = get_session_expired_message('user2')
assert '@user1' in result1
assert '@user2' in result2
assert '@user1' not in result2
assert '@user2' not in result1
def test_message_without_username_contains_session_expired_text(self):
"""Test that the message without username contains session expired text."""
result = get_session_expired_message()
assert 'session has expired' in result
def test_message_without_username_contains_login_instruction(self):
"""Test that the message without username contains login instruction."""
result = get_session_expired_message()
assert 'login again' in result
def test_message_without_username_contains_host_url(self):
"""Test that the message without username contains the OpenHands Cloud URL."""
result = get_session_expired_message()
assert HOST_URL in result
assert 'OpenHands Cloud' in result
def test_message_without_username_does_not_contain_at_prefix(self):
"""Test that the message without username does not contain @ prefix."""
result = get_session_expired_message()
assert not result.startswith('@')
assert 'Your session' in result
def test_message_with_none_username(self):
"""Test that passing None explicitly works the same as no argument."""
result = get_session_expired_message(None)
assert not result.startswith('@')
assert 'Your session' in result
class TestAppendConversationFooter:
"""Test cases for append_conversation_footer function."""
@patch(
'integrations.utils.CONVERSATION_URL', 'https://example.com/conversations/{}'
)
def test_appends_footer_with_markdown_link(self):
"""Test that footer is appended with correct markdown link format."""
# Arrange
message = 'This is a test message'
conversation_id = 'test-conv-123'
# Act
result = append_conversation_footer(message, conversation_id)
# Assert
assert result.startswith(message)
assert (
'[View full conversation](https://example.com/conversations/test-conv-123)'
in result
)
assert result.endswith(
'[View full conversation](https://example.com/conversations/test-conv-123)'
)
@patch(
'integrations.utils.CONVERSATION_URL', 'https://example.com/conversations/{}'
)
def test_footer_does_not_contain_html_tags(self):
"""Test that footer does not contain HTML tags like <sub>."""
# Arrange
message = 'Test message'
conversation_id = 'test-conv-456'
# Act
result = append_conversation_footer(message, conversation_id)
# Assert
assert '<sub>' not in result
assert '</sub>' not in result
@patch(
'integrations.utils.CONVERSATION_URL', 'https://example.com/conversations/{}'
)
def test_footer_format_with_newlines(self):
"""Test that footer is properly separated with newlines."""
# Arrange
message = 'Original message content'
conversation_id = 'test-conv-789'
# Act
result = append_conversation_footer(message, conversation_id)
# Assert
assert (
result
== 'Original message content\n\n[View full conversation](https://example.com/conversations/test-conv-789)'
)
@patch(
'integrations.utils.CONVERSATION_URL', 'https://example.com/conversations/{}'
)
def test_empty_message_still_appends_footer(self):
"""Test that footer is appended even when message is empty."""
# Arrange
message = ''
conversation_id = 'empty-msg-conv'
# Act
result = append_conversation_footer(message, conversation_id)
# Assert
assert result.startswith('\n\n')
assert (
'[View full conversation](https://example.com/conversations/empty-msg-conv)'
in result
)
@patch(
'integrations.utils.CONVERSATION_URL', 'https://example.com/conversations/{}'
)
def test_conversation_id_with_special_characters(self):
"""Test that footer handles conversation IDs with special characters."""
# Arrange
message = 'Test message'
conversation_id = 'conv-123_abc-456'
# Act
result = append_conversation_footer(message, conversation_id)
# Assert
expected_url = 'https://example.com/conversations/conv-123_abc-456'
assert expected_url in result
assert '[View full conversation]' in result
@patch(
'integrations.utils.CONVERSATION_URL', 'https://example.com/conversations/{}'
)
def test_multiline_message_preserves_content(self):
"""Test that multiline messages are preserved correctly."""
# Arrange
message = 'Line 1\nLine 2\nLine 3'
conversation_id = 'multiline-conv'
# Act
result = append_conversation_footer(message, conversation_id)
# Assert
assert result.startswith('Line 1\nLine 2\nLine 3')
assert '\n\n[View full conversation]' in result
assert message in result
@patch(
'integrations.utils.CONVERSATION_URL', 'https://example.com/conversations/{}'
)
def test_footer_contains_only_markdown_syntax(self):
"""Test that footer uses only markdown syntax, not HTML."""
# Arrange
message = 'Test message'
conversation_id = 'markdown-test'
# Act
result = append_conversation_footer(message, conversation_id)
# Assert
footer_part = result[len(message) :]
# Should only contain markdown link syntax: [text](url)
assert footer_part.startswith('\n\n[')
assert '](' in footer_part
assert footer_part.endswith(')')
# Should not contain any HTML tags (specifically <sub> tags that were removed)
assert '<sub>' not in footer_part
assert '</sub>' not in footer_part
@@ -0,0 +1,330 @@
"""Unit tests for API keys routes, focusing on BYOR key validation and retrieval."""
from unittest.mock import AsyncMock, MagicMock, patch
import httpx
import pytest
from fastapi import HTTPException
from server.routes.api_keys import (
get_llm_api_key_for_byor,
verify_byor_key_in_litellm,
)
class TestVerifyByorKeyInLitellm:
"""Test the verify_byor_key_in_litellm function."""
@pytest.mark.asyncio
@patch('server.routes.api_keys.LITE_LLM_API_URL', 'https://litellm.example.com')
@patch('server.routes.api_keys.httpx.AsyncClient')
async def test_verify_valid_key_returns_true(self, mock_client_class):
"""Test that a valid key (200 response) returns True."""
# Arrange
byor_key = 'sk-valid-key-123'
user_id = 'user-123'
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.is_success = True
mock_client = AsyncMock()
mock_client.__aenter__.return_value = mock_client
mock_client.__aexit__.return_value = None
mock_client.get.return_value = mock_response
mock_client_class.return_value = mock_client
# Act
result = await verify_byor_key_in_litellm(byor_key, user_id)
# Assert
assert result is True
mock_client.get.assert_called_once_with(
'https://litellm.example.com/v1/models',
headers={'Authorization': f'Bearer {byor_key}'},
)
@pytest.mark.asyncio
@patch('server.routes.api_keys.LITE_LLM_API_URL', 'https://litellm.example.com')
@patch('server.routes.api_keys.httpx.AsyncClient')
async def test_verify_invalid_key_401_returns_false(self, mock_client_class):
"""Test that an invalid key (401 response) returns False."""
# Arrange
byor_key = 'sk-invalid-key-123'
user_id = 'user-123'
mock_response = MagicMock()
mock_response.status_code = 401
mock_client = AsyncMock()
mock_client.__aenter__.return_value = mock_client
mock_client.__aexit__.return_value = None
mock_client.get.return_value = mock_response
mock_client_class.return_value = mock_client
# Act
result = await verify_byor_key_in_litellm(byor_key, user_id)
# Assert
assert result is False
@pytest.mark.asyncio
@patch('server.routes.api_keys.LITE_LLM_API_URL', 'https://litellm.example.com')
@patch('server.routes.api_keys.httpx.AsyncClient')
async def test_verify_invalid_key_403_returns_false(self, mock_client_class):
"""Test that an invalid key (403 response) returns False."""
# Arrange
byor_key = 'sk-forbidden-key-123'
user_id = 'user-123'
mock_response = MagicMock()
mock_response.status_code = 403
mock_client = AsyncMock()
mock_client.__aenter__.return_value = mock_client
mock_client.__aexit__.return_value = None
mock_client.get.return_value = mock_response
mock_client_class.return_value = mock_client
# Act
result = await verify_byor_key_in_litellm(byor_key, user_id)
# Assert
assert result is False
@pytest.mark.asyncio
@patch('server.routes.api_keys.LITE_LLM_API_URL', 'https://litellm.example.com')
@patch('server.routes.api_keys.httpx.AsyncClient')
async def test_verify_server_error_returns_false(self, mock_client_class):
"""Test that a server error (500) returns False to ensure key validity."""
# Arrange
byor_key = 'sk-key-123'
user_id = 'user-123'
mock_response = MagicMock()
mock_response.status_code = 500
mock_response.is_success = False
mock_client = AsyncMock()
mock_client.__aenter__.return_value = mock_client
mock_client.__aexit__.return_value = None
mock_client.get.return_value = mock_response
mock_client_class.return_value = mock_client
# Act
result = await verify_byor_key_in_litellm(byor_key, user_id)
# Assert
assert result is False
@pytest.mark.asyncio
@patch('server.routes.api_keys.LITE_LLM_API_URL', 'https://litellm.example.com')
@patch('server.routes.api_keys.httpx.AsyncClient')
async def test_verify_timeout_returns_false(self, mock_client_class):
"""Test that a timeout returns False to ensure key validity."""
# Arrange
byor_key = 'sk-key-123'
user_id = 'user-123'
mock_client = AsyncMock()
mock_client.__aenter__.return_value = mock_client
mock_client.__aexit__.return_value = None
mock_client.get.side_effect = httpx.TimeoutException('Request timed out')
mock_client_class.return_value = mock_client
# Act
result = await verify_byor_key_in_litellm(byor_key, user_id)
# Assert
assert result is False
@pytest.mark.asyncio
@patch('server.routes.api_keys.LITE_LLM_API_URL', 'https://litellm.example.com')
@patch('server.routes.api_keys.httpx.AsyncClient')
async def test_verify_network_error_returns_false(self, mock_client_class):
"""Test that a network error returns False to ensure key validity."""
# Arrange
byor_key = 'sk-key-123'
user_id = 'user-123'
mock_client = AsyncMock()
mock_client.__aenter__.return_value = mock_client
mock_client.__aexit__.return_value = None
mock_client.get.side_effect = httpx.NetworkError('Network error')
mock_client_class.return_value = mock_client
# Act
result = await verify_byor_key_in_litellm(byor_key, user_id)
# Assert
assert result is False
@pytest.mark.asyncio
@patch('server.routes.api_keys.LITE_LLM_API_URL', None)
async def test_verify_missing_api_url_returns_false(self):
"""Test that missing LITE_LLM_API_URL returns False."""
# Arrange
byor_key = 'sk-key-123'
user_id = 'user-123'
# Act
result = await verify_byor_key_in_litellm(byor_key, user_id)
# Assert
assert result is False
@pytest.mark.asyncio
@patch('server.routes.api_keys.LITE_LLM_API_URL', 'https://litellm.example.com')
async def test_verify_empty_key_returns_false(self):
"""Test that empty key returns False."""
# Arrange
byor_key = ''
user_id = 'user-123'
# Act
result = await verify_byor_key_in_litellm(byor_key, user_id)
# Assert
assert result is False
class TestGetLlmApiKeyForByor:
"""Test the get_llm_api_key_for_byor endpoint."""
@pytest.mark.asyncio
@patch('server.routes.api_keys.store_byor_key_in_db')
@patch('server.routes.api_keys.generate_byor_key')
@patch('server.routes.api_keys.get_byor_key_from_db')
async def test_no_key_in_database_generates_new(
self, mock_get_key, mock_generate_key, mock_store_key
):
"""Test that when no key exists in database, a new one is generated."""
# Arrange
user_id = 'user-123'
new_key = 'sk-new-generated-key'
mock_get_key.return_value = None
mock_generate_key.return_value = new_key
mock_store_key.return_value = None
# Act
result = await get_llm_api_key_for_byor(user_id=user_id)
# Assert
assert result == {'key': new_key}
mock_get_key.assert_called_once_with(user_id)
mock_generate_key.assert_called_once_with(user_id)
mock_store_key.assert_called_once_with(user_id, new_key)
@pytest.mark.asyncio
@patch('server.routes.api_keys.verify_byor_key_in_litellm')
@patch('server.routes.api_keys.get_byor_key_from_db')
async def test_valid_key_in_database_returns_key(
self, mock_get_key, mock_verify_key
):
"""Test that when a valid key exists in database, it is returned."""
# Arrange
user_id = 'user-123'
existing_key = 'sk-existing-valid-key'
mock_get_key.return_value = existing_key
mock_verify_key.return_value = True
# Act
result = await get_llm_api_key_for_byor(user_id=user_id)
# Assert
assert result == {'key': existing_key}
mock_get_key.assert_called_once_with(user_id)
mock_verify_key.assert_called_once_with(existing_key, user_id)
@pytest.mark.asyncio
@patch('server.routes.api_keys.store_byor_key_in_db')
@patch('server.routes.api_keys.generate_byor_key')
@patch('server.routes.api_keys.delete_byor_key_from_litellm')
@patch('server.routes.api_keys.verify_byor_key_in_litellm')
@patch('server.routes.api_keys.get_byor_key_from_db')
async def test_invalid_key_in_database_regenerates(
self,
mock_get_key,
mock_verify_key,
mock_delete_key,
mock_generate_key,
mock_store_key,
):
"""Test that when an invalid key exists in database, it is regenerated."""
# Arrange
user_id = 'user-123'
invalid_key = 'sk-invalid-key'
new_key = 'sk-new-generated-key'
mock_get_key.return_value = invalid_key
mock_verify_key.return_value = False
mock_delete_key.return_value = True
mock_generate_key.return_value = new_key
mock_store_key.return_value = None
# Act
result = await get_llm_api_key_for_byor(user_id=user_id)
# Assert
assert result == {'key': new_key}
mock_get_key.assert_called_once_with(user_id)
mock_verify_key.assert_called_once_with(invalid_key, user_id)
mock_delete_key.assert_called_once_with(user_id, invalid_key)
mock_generate_key.assert_called_once_with(user_id)
mock_store_key.assert_called_once_with(user_id, new_key)
@pytest.mark.asyncio
@patch('server.routes.api_keys.store_byor_key_in_db')
@patch('server.routes.api_keys.generate_byor_key')
@patch('server.routes.api_keys.delete_byor_key_from_litellm')
@patch('server.routes.api_keys.verify_byor_key_in_litellm')
@patch('server.routes.api_keys.get_byor_key_from_db')
async def test_invalid_key_deletion_failure_still_regenerates(
self,
mock_get_key,
mock_verify_key,
mock_delete_key,
mock_generate_key,
mock_store_key,
):
"""Test that even if deletion fails, regeneration still proceeds."""
# Arrange
user_id = 'user-123'
invalid_key = 'sk-invalid-key'
new_key = 'sk-new-generated-key'
mock_get_key.return_value = invalid_key
mock_verify_key.return_value = False
mock_delete_key.return_value = False # Deletion fails
mock_generate_key.return_value = new_key
mock_store_key.return_value = None
# Act
result = await get_llm_api_key_for_byor(user_id=user_id)
# Assert
assert result == {'key': new_key}
mock_delete_key.assert_called_once_with(user_id, invalid_key)
mock_generate_key.assert_called_once_with(user_id)
mock_store_key.assert_called_once_with(user_id, new_key)
@pytest.mark.asyncio
@patch('server.routes.api_keys.generate_byor_key')
@patch('server.routes.api_keys.get_byor_key_from_db')
async def test_key_generation_failure_raises_exception(
self, mock_get_key, mock_generate_key
):
"""Test that when key generation fails, an HTTPException is raised."""
# Arrange
user_id = 'user-123'
mock_get_key.return_value = None
mock_generate_key.return_value = None
# Act & Assert
with pytest.raises(HTTPException) as exc_info:
await get_llm_api_key_for_byor(user_id=user_id)
assert exc_info.value.status_code == 500
assert 'Failed to generate new BYOR LLM API key' in exc_info.value.detail
@pytest.mark.asyncio
@patch('server.routes.api_keys.get_byor_key_from_db')
async def test_database_error_raises_exception(self, mock_get_key):
"""Test that database errors are properly handled."""
# Arrange
user_id = 'user-123'
mock_get_key.side_effect = Exception('Database connection error')
# Act & Assert
with pytest.raises(HTTPException) as exc_info:
await get_llm_api_key_for_byor(user_id=user_id)
assert exc_info.value.status_code == 500
assert 'Failed to retrieve BYOR LLM API key' in exc_info.value.detail
@@ -0,0 +1,361 @@
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import HTTPException, Request, status
from fastapi.responses import JSONResponse, RedirectResponse
from pydantic import SecretStr
from server.auth.saas_user_auth import SaasUserAuth
from server.routes.email import (
ResendEmailVerificationRequest,
resend_email_verification,
verified_email,
verify_email,
)
@pytest.fixture
def mock_request():
"""Create a mock request object."""
request = MagicMock(spec=Request)
request.url = MagicMock()
request.url.hostname = 'localhost'
request.url.netloc = 'localhost:8000'
request.url.path = '/api/email/verified'
request.base_url = 'http://localhost:8000/'
request.headers = {}
request.cookies = {}
request.query_params = MagicMock()
return request
@pytest.fixture
def mock_user_auth():
"""Create a mock SaasUserAuth object."""
auth = MagicMock(spec=SaasUserAuth)
auth.access_token = SecretStr('test_access_token')
auth.refresh_token = SecretStr('test_refresh_token')
auth.email = 'test@example.com'
auth.email_verified = False
auth.accepted_tos = True
auth.refresh = AsyncMock()
return auth
@pytest.mark.asyncio
async def test_verify_email_default_behavior(mock_request):
"""Test verify_email with default is_auth_flow=False."""
# Arrange
user_id = 'test_user_id'
mock_keycloak_admin = AsyncMock()
mock_keycloak_admin.a_send_verify_email = AsyncMock()
# Act
with patch(
'server.routes.email.get_keycloak_admin', return_value=mock_keycloak_admin
):
await verify_email(request=mock_request, user_id=user_id)
# Assert
mock_keycloak_admin.a_send_verify_email.assert_called_once()
call_args = mock_keycloak_admin.a_send_verify_email.call_args
assert call_args.kwargs['user_id'] == user_id
assert (
call_args.kwargs['redirect_uri'] == 'http://localhost:8000/api/email/verified'
)
assert 'client_id' in call_args.kwargs
@pytest.mark.asyncio
async def test_verify_email_with_auth_flow(mock_request):
"""Test verify_email with is_auth_flow=True."""
# Arrange
user_id = 'test_user_id'
mock_keycloak_admin = AsyncMock()
mock_keycloak_admin.a_send_verify_email = AsyncMock()
# Act
with patch(
'server.routes.email.get_keycloak_admin', return_value=mock_keycloak_admin
):
await verify_email(request=mock_request, user_id=user_id, is_auth_flow=True)
# Assert
mock_keycloak_admin.a_send_verify_email.assert_called_once()
call_args = mock_keycloak_admin.a_send_verify_email.call_args
assert call_args.kwargs['user_id'] == user_id
assert (
call_args.kwargs['redirect_uri'] == 'http://localhost:8000?email_verified=true'
)
assert 'client_id' in call_args.kwargs
@pytest.mark.asyncio
async def test_verify_email_https_scheme(mock_request):
"""Test verify_email uses https scheme for non-localhost hosts."""
# Arrange
user_id = 'test_user_id'
mock_request.url.hostname = 'example.com'
mock_request.url.netloc = 'example.com'
mock_keycloak_admin = AsyncMock()
mock_keycloak_admin.a_send_verify_email = AsyncMock()
# Act
with patch(
'server.routes.email.get_keycloak_admin', return_value=mock_keycloak_admin
):
await verify_email(request=mock_request, user_id=user_id, is_auth_flow=True)
# Assert
call_args = mock_keycloak_admin.a_send_verify_email.call_args
assert call_args.kwargs['redirect_uri'].startswith('https://')
@pytest.mark.asyncio
async def test_verified_email_default_redirect(mock_request, mock_user_auth):
"""Test verified_email redirects to /settings/user by default."""
# Arrange
mock_request.query_params.get.return_value = None
# Act
with (
patch('server.routes.email.get_user_auth', return_value=mock_user_auth),
patch('server.routes.email.set_response_cookie') as mock_set_cookie,
):
result = await verified_email(mock_request)
# Assert
assert isinstance(result, RedirectResponse)
assert result.status_code == 302
assert result.headers['location'] == 'http://localhost:8000/settings/user'
mock_user_auth.refresh.assert_called_once()
mock_set_cookie.assert_called_once()
assert mock_user_auth.email_verified is True
@pytest.mark.asyncio
async def test_verified_email_https_scheme(mock_request, mock_user_auth):
"""Test verified_email uses https scheme for non-localhost hosts."""
# Arrange
mock_request.url.hostname = 'example.com'
mock_request.url.netloc = 'example.com'
mock_request.query_params.get.return_value = None
# Act
with (
patch('server.routes.email.get_user_auth', return_value=mock_user_auth),
patch('server.routes.email.set_response_cookie') as mock_set_cookie,
):
result = await verified_email(mock_request)
# Assert
assert isinstance(result, RedirectResponse)
assert result.headers['location'].startswith('https://')
mock_set_cookie.assert_called_once()
# Verify secure flag is True for https
call_kwargs = mock_set_cookie.call_args.kwargs
assert call_kwargs['secure'] is True
@pytest.mark.asyncio
async def test_resend_email_verification_with_user_id_from_body_succeeds(mock_request):
"""Test resend_email_verification succeeds when user_id is provided in body."""
# Arrange
user_id = 'test_user_id'
body = ResendEmailVerificationRequest(user_id=user_id, is_auth_flow=False)
mock_keycloak_admin = AsyncMock()
mock_keycloak_admin.a_send_verify_email = AsyncMock()
with (
patch('server.routes.email.check_rate_limit_by_user_id') as mock_rate_limit,
patch(
'server.routes.email.get_keycloak_admin', return_value=mock_keycloak_admin
),
patch('server.routes.email.logger') as mock_logger,
):
mock_rate_limit.return_value = None # Rate limit check passes
# Act
result = await resend_email_verification(request=mock_request, body=body)
# Assert
assert isinstance(result, JSONResponse)
assert result.status_code == status.HTTP_200_OK
assert 'message' in result.body.decode()
mock_rate_limit.assert_called_once_with(
request=mock_request,
key_prefix='email_resend',
user_id=user_id,
user_rate_limit_seconds=30,
ip_rate_limit_seconds=60,
)
mock_keycloak_admin.a_send_verify_email.assert_called_once()
# Logger is called multiple times (verify_email and resend_email_verification)
# Check that the resend message was logged
assert any(
'Resending verification email for' in str(call)
for call in mock_logger.info.call_args_list
)
@pytest.mark.asyncio
async def test_resend_email_verification_with_user_id_from_auth_succeeds(mock_request):
"""Test resend_email_verification succeeds when user_id comes from authentication."""
# Arrange
user_id = 'test_user_id'
mock_keycloak_admin = AsyncMock()
mock_keycloak_admin.a_send_verify_email = AsyncMock()
with (
patch(
'server.routes.email.get_user_id', return_value=user_id
) as mock_get_user_id,
patch('server.routes.email.check_rate_limit_by_user_id') as mock_rate_limit,
patch(
'server.routes.email.get_keycloak_admin', return_value=mock_keycloak_admin
),
):
mock_rate_limit.return_value = None # Rate limit check passes
# Act
result = await resend_email_verification(request=mock_request, body=None)
# Assert
assert isinstance(result, JSONResponse)
assert result.status_code == status.HTTP_200_OK
mock_get_user_id.assert_called_once_with(mock_request)
mock_rate_limit.assert_called_once_with(
request=mock_request,
key_prefix='email_resend',
user_id=user_id,
user_rate_limit_seconds=30,
ip_rate_limit_seconds=60,
)
@pytest.mark.asyncio
async def test_resend_email_verification_without_user_id_returns_400(mock_request):
"""Test resend_email_verification returns 400 when user_id is not available."""
# Arrange
with patch(
'server.routes.email.get_user_id', side_effect=Exception('Not authenticated')
):
# Act & Assert
with pytest.raises(HTTPException) as exc_info:
await resend_email_verification(request=mock_request, body=None)
assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST
assert 'user_id is required' in exc_info.value.detail
@pytest.mark.asyncio
async def test_resend_email_verification_rate_limit_exceeded_returns_429(mock_request):
"""Test resend_email_verification returns 429 when rate limit is exceeded."""
# Arrange
user_id = 'test_user_id'
body = ResendEmailVerificationRequest(user_id=user_id)
with (
patch('server.routes.email.check_rate_limit_by_user_id') as mock_rate_limit,
):
mock_rate_limit.side_effect = HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail='Too many requests. Please wait 2 minutes before trying again.',
)
# Act & Assert
with pytest.raises(HTTPException) as exc_info:
await resend_email_verification(request=mock_request, body=body)
assert exc_info.value.status_code == status.HTTP_429_TOO_MANY_REQUESTS
assert 'Too many requests' in exc_info.value.detail
mock_rate_limit.assert_called_once()
@pytest.mark.asyncio
async def test_resend_email_verification_with_is_auth_flow_true(mock_request):
"""Test resend_email_verification passes is_auth_flow to verify_email."""
# Arrange
user_id = 'test_user_id'
body = ResendEmailVerificationRequest(user_id=user_id, is_auth_flow=True)
mock_keycloak_admin = AsyncMock()
mock_keycloak_admin.a_send_verify_email = AsyncMock()
with (
patch('server.routes.email.check_rate_limit_by_user_id') as mock_rate_limit,
patch(
'server.routes.email.get_keycloak_admin', return_value=mock_keycloak_admin
),
):
mock_rate_limit.return_value = None
# Act
await resend_email_verification(request=mock_request, body=body)
# Assert
mock_keycloak_admin.a_send_verify_email.assert_called_once()
call_args = mock_keycloak_admin.a_send_verify_email.call_args
# Verify that verify_email was called with is_auth_flow=True
# We check this indirectly by verifying the redirect_uri
assert 'email_verified=true' in call_args.kwargs['redirect_uri']
@pytest.mark.asyncio
async def test_resend_email_verification_with_is_auth_flow_false(mock_request):
"""Test resend_email_verification uses default is_auth_flow=False when not specified."""
# Arrange
user_id = 'test_user_id'
body = ResendEmailVerificationRequest(user_id=user_id, is_auth_flow=False)
mock_keycloak_admin = AsyncMock()
mock_keycloak_admin.a_send_verify_email = AsyncMock()
with (
patch('server.routes.email.check_rate_limit_by_user_id') as mock_rate_limit,
patch(
'server.routes.email.get_keycloak_admin', return_value=mock_keycloak_admin
),
):
mock_rate_limit.return_value = None
# Act
await resend_email_verification(request=mock_request, body=body)
# Assert
mock_keycloak_admin.a_send_verify_email.assert_called_once()
call_args = mock_keycloak_admin.a_send_verify_email.call_args
# Verify that verify_email was called with is_auth_flow=False
assert '/api/email/verified' in call_args.kwargs['redirect_uri']
@pytest.mark.asyncio
async def test_resend_email_verification_body_none_uses_auth(mock_request):
"""Test resend_email_verification uses auth when body is None."""
# Arrange
user_id = 'test_user_id'
mock_keycloak_admin = AsyncMock()
mock_keycloak_admin.a_send_verify_email = AsyncMock()
with (
patch(
'server.routes.email.get_user_id', return_value=user_id
) as mock_get_user_id,
patch('server.routes.email.check_rate_limit_by_user_id') as mock_rate_limit,
patch(
'server.routes.email.get_keycloak_admin', return_value=mock_keycloak_admin
),
):
mock_rate_limit.return_value = None
# Act
result = await resend_email_verification(request=mock_request, body=None)
# Assert
assert isinstance(result, JSONResponse)
assert result.status_code == status.HTTP_200_OK
mock_get_user_id.assert_called_once()
mock_rate_limit.assert_called_once_with(
request=mock_request,
key_prefix='email_resend',
user_id=user_id,
user_rate_limit_seconds=30,
ip_rate_limit_seconds=60,
)
@@ -0,0 +1,502 @@
"""Unit tests for GitLab integration routes."""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import HTTPException, status
from integrations.gitlab.gitlab_service import SaaSGitLabService
from integrations.gitlab.webhook_installation import BreakLoopException
from integrations.types import GitLabResourceType
from server.routes.integration.gitlab import (
ReinstallWebhookRequest,
ResourceIdentifier,
get_gitlab_resources,
reinstall_gitlab_webhook,
)
from storage.gitlab_webhook import GitlabWebhook
@pytest.fixture
def mock_gitlab_service():
"""Create a mock SaaSGitLabService instance."""
service = MagicMock(spec=SaaSGitLabService)
service.get_user_resources_with_admin_access = AsyncMock(
return_value=(
[
{
'id': 1,
'name': 'Test Project',
'path_with_namespace': 'user/test-project',
'namespace': {'kind': 'user'},
},
{
'id': 2,
'name': 'Group Project',
'path_with_namespace': 'group/group-project',
'namespace': {'kind': 'group'},
},
],
[
{
'id': 10,
'name': 'Test Group',
'full_path': 'test-group',
},
],
)
)
service.check_webhook_exists_on_resource = AsyncMock(return_value=(True, None))
service.check_user_has_admin_access_to_resource = AsyncMock(
return_value=(True, None)
)
return service
@pytest.fixture
def mock_webhook():
"""Create a mock webhook object."""
webhook = MagicMock(spec=GitlabWebhook)
webhook.webhook_uuid = 'test-uuid'
webhook.last_synced = None
return webhook
class TestGetGitLabResources:
"""Test cases for get_gitlab_resources endpoint."""
@pytest.mark.asyncio
@patch('server.routes.integration.gitlab.GitLabServiceImpl')
@patch('server.routes.integration.gitlab.webhook_store')
@patch('server.routes.integration.gitlab.isinstance')
async def test_get_resources_success(
self,
mock_isinstance,
mock_webhook_store,
mock_gitlab_service_impl,
mock_gitlab_service,
):
"""Test successfully retrieving GitLab resources with webhook status."""
# Arrange
user_id = 'test_user_id'
mock_gitlab_service_impl.return_value = mock_gitlab_service
mock_isinstance.return_value = True
mock_webhook_store.get_webhooks_by_resources = AsyncMock(
return_value=({}, {}) # Empty maps for simplicity
)
# Act
response = await get_gitlab_resources(user_id=user_id)
# Assert
assert len(response.resources) == 2 # 1 project (filtered) + 1 group
assert response.resources[0].type == 'project'
assert response.resources[0].id == '1'
assert response.resources[0].name == 'Test Project'
assert response.resources[1].type == 'group'
assert response.resources[1].id == '10'
mock_gitlab_service.get_user_resources_with_admin_access.assert_called_once()
mock_webhook_store.get_webhooks_by_resources.assert_called_once()
@pytest.mark.asyncio
@patch('server.routes.integration.gitlab.GitLabServiceImpl')
@patch('server.routes.integration.gitlab.webhook_store')
@patch('server.routes.integration.gitlab.isinstance')
async def test_get_resources_filters_nested_projects(
self,
mock_isinstance,
mock_webhook_store,
mock_gitlab_service_impl,
mock_gitlab_service,
):
"""Test that projects nested under groups are filtered out."""
# Arrange
user_id = 'test_user_id'
mock_gitlab_service_impl.return_value = mock_gitlab_service
mock_isinstance.return_value = True
mock_webhook_store.get_webhooks_by_resources = AsyncMock(return_value=({}, {}))
# Act
response = await get_gitlab_resources(user_id=user_id)
# Assert
# Should only include the user project, not the group project
project_resources = [r for r in response.resources if r.type == 'project']
assert len(project_resources) == 1
assert project_resources[0].id == '1'
assert project_resources[0].name == 'Test Project'
@pytest.mark.asyncio
@patch('server.routes.integration.gitlab.GitLabServiceImpl')
@patch('server.routes.integration.gitlab.webhook_store')
@patch('server.routes.integration.gitlab.isinstance')
async def test_get_resources_includes_webhook_metadata(
self,
mock_isinstance,
mock_webhook_store,
mock_gitlab_service_impl,
mock_gitlab_service,
mock_webhook,
):
"""Test that webhook metadata is included in the response."""
# Arrange
user_id = 'test_user_id'
mock_gitlab_service_impl.return_value = mock_gitlab_service
mock_isinstance.return_value = True
mock_webhook_store.get_webhooks_by_resources = AsyncMock(
return_value=({'1': mock_webhook}, {'10': mock_webhook})
)
# Act
response = await get_gitlab_resources(user_id=user_id)
# Assert
assert response.resources[0].webhook_uuid == 'test-uuid'
assert response.resources[1].webhook_uuid == 'test-uuid'
@pytest.mark.asyncio
@patch('server.routes.integration.gitlab.GitLabServiceImpl')
async def test_get_resources_non_saas_service(
self, mock_gitlab_service_impl, mock_gitlab_service
):
"""Test that non-SaaS GitLab service raises an error."""
# Arrange
user_id = 'test_user_id'
non_saas_service = AsyncMock()
mock_gitlab_service_impl.return_value = non_saas_service
# Act & Assert
with pytest.raises(HTTPException) as exc_info:
await get_gitlab_resources(user_id=user_id)
assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST
assert 'Only SaaS GitLab service is supported' in exc_info.value.detail
@pytest.mark.asyncio
@patch('server.routes.integration.gitlab.GitLabServiceImpl')
@patch('server.routes.integration.gitlab.webhook_store')
@patch('server.routes.integration.gitlab.isinstance')
async def test_get_resources_parallel_api_calls(
self,
mock_isinstance,
mock_webhook_store,
mock_gitlab_service_impl,
mock_gitlab_service,
):
"""Test that webhook status checks are made in parallel."""
# Arrange
user_id = 'test_user_id'
mock_gitlab_service_impl.return_value = mock_gitlab_service
mock_isinstance.return_value = True
mock_webhook_store.get_webhooks_by_resources = AsyncMock(return_value=({}, {}))
call_count = 0
async def track_calls(*args, **kwargs):
nonlocal call_count
call_count += 1
return (True, None)
mock_gitlab_service.check_webhook_exists_on_resource = AsyncMock(
side_effect=track_calls
)
# Act
await get_gitlab_resources(user_id=user_id)
# Assert
# Should be called for each resource (1 project + 1 group)
assert call_count == 2
class TestReinstallGitLabWebhook:
"""Test cases for reinstall_gitlab_webhook endpoint."""
@pytest.mark.asyncio
@patch('server.routes.integration.gitlab.install_webhook_on_resource')
@patch('server.routes.integration.gitlab.verify_webhook_conditions')
@patch('server.routes.integration.gitlab.GitLabServiceImpl')
@patch('server.routes.integration.gitlab.webhook_store')
@patch('server.routes.integration.gitlab.isinstance')
async def test_reinstall_webhook_success_existing_webhook(
self,
mock_isinstance,
mock_webhook_store,
mock_gitlab_service_impl,
mock_verify_conditions,
mock_install_webhook,
mock_gitlab_service,
mock_webhook,
):
"""Test successful webhook reinstallation when webhook record exists."""
# Arrange
user_id = 'test_user_id'
resource_id = 'project-123'
resource_type = GitLabResourceType.PROJECT
mock_gitlab_service_impl.return_value = mock_gitlab_service
mock_isinstance.return_value = True
mock_webhook_store.reset_webhook_for_reinstallation_by_resource = AsyncMock(
return_value=True
)
mock_webhook_store.get_webhook_by_resource_only = AsyncMock(
return_value=mock_webhook
)
mock_verify_conditions.return_value = None
mock_install_webhook.return_value = ('webhook-id-123', None)
body = ReinstallWebhookRequest(
resource=ResourceIdentifier(type=resource_type, id=resource_id)
)
# Act
result = await reinstall_gitlab_webhook(body=body, user_id=user_id)
# Assert
assert result.success is True
assert result.resource_id == resource_id
assert result.resource_type == resource_type.value
assert result.error is None
mock_gitlab_service.check_user_has_admin_access_to_resource.assert_called_once_with(
resource_type, resource_id
)
mock_webhook_store.reset_webhook_for_reinstallation_by_resource.assert_called_once_with(
resource_type, resource_id, user_id
)
mock_verify_conditions.assert_called_once()
mock_install_webhook.assert_called_once()
@pytest.mark.asyncio
@patch('server.routes.integration.gitlab.install_webhook_on_resource')
@patch('server.routes.integration.gitlab.verify_webhook_conditions')
@patch('server.routes.integration.gitlab.GitLabServiceImpl')
@patch('server.routes.integration.gitlab.webhook_store')
@patch('server.routes.integration.gitlab.isinstance')
async def test_reinstall_webhook_success_new_webhook_record(
self,
mock_isinstance,
mock_webhook_store,
mock_gitlab_service_impl,
mock_verify_conditions,
mock_install_webhook,
mock_gitlab_service,
):
"""Test successful webhook reinstallation when webhook record doesn't exist."""
# Arrange
user_id = 'test_user_id'
resource_id = 'project-456'
resource_type = GitLabResourceType.PROJECT
mock_gitlab_service_impl.return_value = mock_gitlab_service
mock_isinstance.return_value = True
mock_webhook_store.reset_webhook_for_reinstallation_by_resource = (
AsyncMock(return_value=False) # No existing webhook to reset
)
mock_webhook_store.get_webhook_by_resource_only = AsyncMock(
side_effect=[
None,
MagicMock(),
] # First call returns None, second returns new webhook
)
mock_webhook_store.store_webhooks = AsyncMock()
mock_verify_conditions.return_value = None
mock_install_webhook.return_value = ('webhook-id-456', None)
body = ReinstallWebhookRequest(
resource=ResourceIdentifier(type=resource_type, id=resource_id)
)
# Act
result = await reinstall_gitlab_webhook(body=body, user_id=user_id)
# Assert
assert result.success is True
mock_webhook_store.store_webhooks.assert_called_once()
# Should fetch webhook twice: once to check, once after creating
assert mock_webhook_store.get_webhook_by_resource_only.call_count == 2
@pytest.mark.asyncio
@patch('server.routes.integration.gitlab.GitLabServiceImpl')
@patch('server.routes.integration.gitlab.isinstance')
async def test_reinstall_webhook_no_admin_access(
self, mock_isinstance, mock_gitlab_service_impl, mock_gitlab_service
):
"""Test reinstallation when user doesn't have admin access."""
# Arrange
user_id = 'test_user_id'
resource_id = 'project-789'
resource_type = GitLabResourceType.PROJECT
mock_gitlab_service_impl.return_value = mock_gitlab_service
mock_isinstance.return_value = True
mock_gitlab_service.check_user_has_admin_access_to_resource = AsyncMock(
return_value=(False, None)
)
body = ReinstallWebhookRequest(
resource=ResourceIdentifier(type=resource_type, id=resource_id)
)
# Act & Assert
with pytest.raises(HTTPException) as exc_info:
await reinstall_gitlab_webhook(body=body, user_id=user_id)
assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN
assert 'does not have admin access' in exc_info.value.detail
@pytest.mark.asyncio
@patch('server.routes.integration.gitlab.GitLabServiceImpl')
async def test_reinstall_webhook_non_saas_service(self, mock_gitlab_service_impl):
"""Test reinstallation with non-SaaS GitLab service."""
# Arrange
user_id = 'test_user_id'
resource_id = 'project-999'
resource_type = GitLabResourceType.PROJECT
non_saas_service = AsyncMock()
mock_gitlab_service_impl.return_value = non_saas_service
body = ReinstallWebhookRequest(
resource=ResourceIdentifier(type=resource_type, id=resource_id)
)
# Act & Assert
with pytest.raises(HTTPException) as exc_info:
await reinstall_gitlab_webhook(body=body, user_id=user_id)
assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST
assert 'Only SaaS GitLab service is supported' in exc_info.value.detail
@pytest.mark.asyncio
@patch('server.routes.integration.gitlab.install_webhook_on_resource')
@patch('server.routes.integration.gitlab.verify_webhook_conditions')
@patch('server.routes.integration.gitlab.GitLabServiceImpl')
@patch('server.routes.integration.gitlab.webhook_store')
@patch('server.routes.integration.gitlab.isinstance')
async def test_reinstall_webhook_conditions_not_met(
self,
mock_isinstance,
mock_webhook_store,
mock_gitlab_service_impl,
mock_verify_conditions,
mock_install_webhook,
mock_gitlab_service,
mock_webhook,
):
"""Test reinstallation when webhook conditions are not met."""
# Arrange
user_id = 'test_user_id'
resource_id = 'project-111'
resource_type = GitLabResourceType.PROJECT
mock_gitlab_service_impl.return_value = mock_gitlab_service
mock_isinstance.return_value = True
mock_webhook_store.reset_webhook_for_reinstallation_by_resource = AsyncMock(
return_value=True
)
mock_webhook_store.get_webhook_by_resource_only = AsyncMock(
return_value=mock_webhook
)
mock_verify_conditions.side_effect = BreakLoopException()
body = ReinstallWebhookRequest(
resource=ResourceIdentifier(type=resource_type, id=resource_id)
)
# Act & Assert
with pytest.raises(HTTPException) as exc_info:
await reinstall_gitlab_webhook(body=body, user_id=user_id)
assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST
assert 'conditions not met' in exc_info.value.detail.lower()
mock_install_webhook.assert_not_called()
@pytest.mark.asyncio
@patch('server.routes.integration.gitlab.install_webhook_on_resource')
@patch('server.routes.integration.gitlab.verify_webhook_conditions')
@patch('server.routes.integration.gitlab.GitLabServiceImpl')
@patch('server.routes.integration.gitlab.webhook_store')
@patch('server.routes.integration.gitlab.isinstance')
async def test_reinstall_webhook_installation_fails(
self,
mock_isinstance,
mock_webhook_store,
mock_gitlab_service_impl,
mock_verify_conditions,
mock_install_webhook,
mock_gitlab_service,
mock_webhook,
):
"""Test reinstallation when webhook installation fails."""
# Arrange
user_id = 'test_user_id'
resource_id = 'project-222'
resource_type = GitLabResourceType.PROJECT
mock_gitlab_service_impl.return_value = mock_gitlab_service
mock_isinstance.return_value = True
mock_webhook_store.reset_webhook_for_reinstallation_by_resource = AsyncMock(
return_value=True
)
mock_webhook_store.get_webhook_by_resource_only = AsyncMock(
return_value=mock_webhook
)
mock_verify_conditions.return_value = None
mock_install_webhook.return_value = (None, None) # Installation failed
body = ReinstallWebhookRequest(
resource=ResourceIdentifier(type=resource_type, id=resource_id)
)
# Act & Assert
with pytest.raises(HTTPException) as exc_info:
await reinstall_gitlab_webhook(body=body, user_id=user_id)
assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
assert 'Failed to install webhook' in exc_info.value.detail
@pytest.mark.asyncio
@patch('server.routes.integration.gitlab.install_webhook_on_resource')
@patch('server.routes.integration.gitlab.verify_webhook_conditions')
@patch('server.routes.integration.gitlab.GitLabServiceImpl')
@patch('server.routes.integration.gitlab.webhook_store')
@patch('server.routes.integration.gitlab.isinstance')
async def test_reinstall_webhook_group_resource(
self,
mock_isinstance,
mock_webhook_store,
mock_gitlab_service_impl,
mock_verify_conditions,
mock_install_webhook,
mock_gitlab_service,
mock_webhook,
):
"""Test reinstallation for a group resource."""
# Arrange
user_id = 'test_user_id'
resource_id = 'group-333'
resource_type = GitLabResourceType.GROUP
mock_gitlab_service_impl.return_value = mock_gitlab_service
mock_isinstance.return_value = True
mock_webhook_store.reset_webhook_for_reinstallation_by_resource = AsyncMock(
return_value=True
)
mock_webhook_store.get_webhook_by_resource_only = AsyncMock(
return_value=mock_webhook
)
mock_verify_conditions.return_value = None
mock_install_webhook.return_value = ('webhook-id-group', None)
body = ReinstallWebhookRequest(
resource=ResourceIdentifier(type=resource_type, id=resource_id)
)
# Act
result = await reinstall_gitlab_webhook(body=body, user_id=user_id)
# Assert
assert result.success is True
assert result.resource_id == resource_id
assert result.resource_type == resource_type.value
mock_webhook_store.reset_webhook_for_reinstallation_by_resource.assert_called_once_with(
resource_type, resource_id, user_id
)
@@ -699,12 +699,11 @@ class TestProcessBatchOperationsBackground:
# Should not raise exceptions
await _process_batch_operations_background(batch_ops, 'test-api-key')
# Should log the error
mock_logger.error.assert_called_once_with(
'error_processing_batch_operation',
extra={
'path': 'invalid-path',
'method': 'BatchMethod.POST',
'error': mock_logger.error.call_args[1]['extra']['error'],
},
)
# Should log the error with exception type and message in the log message
mock_logger.error.assert_called_once()
call_args = mock_logger.error.call_args
log_message = call_args[0][0]
assert log_message.startswith('error_processing_batch_operation:')
assert call_args[1]['extra']['path'] == 'invalid-path'
assert call_args[1]['extra']['method'] == 'BatchMethod.POST'
assert call_args[1]['exc_info'] is True
@@ -0,0 +1,290 @@
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import HTTPException, Request, status
from server.utils.rate_limit_utils import (
RATE_LIMIT_IP_SECONDS,
RATE_LIMIT_USER_SECONDS,
check_rate_limit_by_user_id,
)
@pytest.fixture
def mock_request():
"""Create a mock request object."""
request = MagicMock(spec=Request)
request.client = MagicMock()
request.client.host = '192.168.1.1'
return request
@pytest.fixture
def mock_redis():
"""Create a mock Redis client."""
redis = AsyncMock()
redis.set = AsyncMock(return_value=True) # First call succeeds (key doesn't exist)
return redis
@pytest.mark.asyncio
async def test_rate_limit_by_user_id_first_request_succeeds(mock_request, mock_redis):
"""Test that first request with user_id succeeds and sets rate limit key."""
# Arrange
user_id = 'test_user_id'
key_prefix = 'email_resend'
with (
patch('server.utils.rate_limit_utils.sio') as mock_sio,
patch('server.utils.rate_limit_utils.logger') as mock_logger,
):
mock_sio.manager.redis = mock_redis
# Act
await check_rate_limit_by_user_id(
request=mock_request, key_prefix=key_prefix, user_id=user_id
)
# Assert
mock_redis.set.assert_called_once_with(
f'{key_prefix}:{user_id}', 1, nx=True, ex=RATE_LIMIT_USER_SECONDS
)
mock_logger.warning.assert_not_called()
mock_logger.info.assert_not_called()
@pytest.mark.asyncio
async def test_rate_limit_by_user_id_second_request_within_window_fails(
mock_request, mock_redis
):
"""Test that second request with same user_id within rate limit window fails."""
# Arrange
user_id = 'test_user_id'
key_prefix = 'email_resend'
mock_redis.set = AsyncMock(return_value=False) # Key already exists
with (
patch('server.utils.rate_limit_utils.sio') as mock_sio,
patch('server.utils.rate_limit_utils.logger') as mock_logger,
):
mock_sio.manager.redis = mock_redis
# Act & Assert
with pytest.raises(HTTPException) as exc_info:
await check_rate_limit_by_user_id(
request=mock_request, key_prefix=key_prefix, user_id=user_id
)
assert exc_info.value.status_code == status.HTTP_429_TOO_MANY_REQUESTS
assert 'Too many requests' in exc_info.value.detail
assert f'{RATE_LIMIT_USER_SECONDS // 60} minutes' in exc_info.value.detail
mock_logger.info.assert_called_once()
@pytest.mark.asyncio
async def test_rate_limit_by_ip_when_user_id_is_none(mock_request, mock_redis):
"""Test that rate limiting falls back to IP address when user_id is None."""
# Arrange
key_prefix = 'email_resend'
with (
patch('server.utils.rate_limit_utils.sio') as mock_sio,
patch('server.utils.rate_limit_utils.logger') as mock_logger,
):
mock_sio.manager.redis = mock_redis
# Act
await check_rate_limit_by_user_id(
request=mock_request, key_prefix=key_prefix, user_id=None
)
# Assert
mock_redis.set.assert_called_once_with(
f'{key_prefix}:ip:{mock_request.client.host}',
1,
nx=True,
ex=RATE_LIMIT_IP_SECONDS,
)
mock_logger.warning.assert_not_called()
@pytest.mark.asyncio
async def test_rate_limit_by_ip_second_request_within_window_fails(
mock_request, mock_redis
):
"""Test that second request from same IP within rate limit window fails."""
# Arrange
key_prefix = 'email_resend'
mock_redis.set = AsyncMock(return_value=False) # Key already exists
with (
patch('server.utils.rate_limit_utils.sio') as mock_sio,
):
mock_sio.manager.redis = mock_redis
# Act & Assert
with pytest.raises(HTTPException) as exc_info:
await check_rate_limit_by_user_id(
request=mock_request, key_prefix=key_prefix, user_id=None
)
assert exc_info.value.status_code == status.HTTP_429_TOO_MANY_REQUESTS
assert f'{RATE_LIMIT_IP_SECONDS // 60} minutes' in exc_info.value.detail
@pytest.mark.asyncio
async def test_rate_limit_redis_unavailable_fails_open(mock_request):
"""Test that rate limiting fails open when Redis is unavailable."""
# Arrange
key_prefix = 'email_resend'
user_id = 'test_user_id'
with (
patch('server.utils.rate_limit_utils.sio') as mock_sio,
patch('server.utils.rate_limit_utils.logger') as mock_logger,
):
mock_sio.manager.redis = None # Redis unavailable
# Act
await check_rate_limit_by_user_id(
request=mock_request, key_prefix=key_prefix, user_id=user_id
)
# Assert
mock_logger.warning.assert_called_once_with(
'Redis unavailable for rate limiting, allowing request'
)
@pytest.mark.asyncio
async def test_rate_limit_redis_exception_fails_open(mock_request, mock_redis):
"""Test that rate limiting fails open when Redis raises an exception."""
# Arrange
key_prefix = 'email_resend'
user_id = 'test_user_id'
mock_redis.set = AsyncMock(side_effect=Exception('Redis connection error'))
with (
patch('server.utils.rate_limit_utils.sio') as mock_sio,
patch('server.utils.rate_limit_utils.logger') as mock_logger,
):
mock_sio.manager.redis = mock_redis
# Act
await check_rate_limit_by_user_id(
request=mock_request, key_prefix=key_prefix, user_id=user_id
)
# Assert
mock_logger.warning.assert_called_once()
assert 'Error checking rate limit' in str(mock_logger.warning.call_args[0][0])
@pytest.mark.asyncio
async def test_rate_limit_custom_key_prefix(mock_request, mock_redis):
"""Test that different key prefixes create different rate limit keys."""
# Arrange
user_id = 'test_user_id'
key_prefix = 'password_reset'
with patch('server.utils.rate_limit_utils.sio') as mock_sio:
mock_sio.manager.redis = mock_redis
# Act
await check_rate_limit_by_user_id(
request=mock_request, key_prefix=key_prefix, user_id=user_id
)
# Assert
mock_redis.set.assert_called_once_with(
f'{key_prefix}:{user_id}', 1, nx=True, ex=RATE_LIMIT_USER_SECONDS
)
@pytest.mark.asyncio
async def test_rate_limit_custom_rate_limit_seconds(mock_request, mock_redis):
"""Test that custom rate limit seconds are used correctly."""
# Arrange
user_id = 'test_user_id'
key_prefix = 'email_resend'
custom_user_seconds = 60
custom_ip_seconds = 180
with patch('server.utils.rate_limit_utils.sio') as mock_sio:
mock_sio.manager.redis = mock_redis
# Act
await check_rate_limit_by_user_id(
request=mock_request,
key_prefix=key_prefix,
user_id=user_id,
user_rate_limit_seconds=custom_user_seconds,
ip_rate_limit_seconds=custom_ip_seconds,
)
# Assert
mock_redis.set.assert_called_once_with(
f'{key_prefix}:{user_id}', 1, nx=True, ex=custom_user_seconds
)
@pytest.mark.asyncio
async def test_rate_limit_ip_with_unknown_client(mock_request, mock_redis):
"""Test that rate limiting handles missing client host gracefully."""
# Arrange
key_prefix = 'email_resend'
mock_request.client = None # No client information
with patch('server.utils.rate_limit_utils.sio') as mock_sio:
mock_sio.manager.redis = mock_redis
# Act
await check_rate_limit_by_user_id(
request=mock_request, key_prefix=key_prefix, user_id=None
)
# Assert
mock_redis.set.assert_called_once_with(
f'{key_prefix}:ip:unknown', 1, nx=True, ex=RATE_LIMIT_IP_SECONDS
)
@pytest.mark.asyncio
async def test_rate_limit_different_users_have_separate_limits(
mock_request, mock_redis
):
"""Test that different user_ids have separate rate limit keys."""
# Arrange
key_prefix = 'email_resend'
user_id_1 = 'user_1'
user_id_2 = 'user_2'
with patch('server.utils.rate_limit_utils.sio') as mock_sio:
mock_sio.manager.redis = mock_redis
# Act
await check_rate_limit_by_user_id(
request=mock_request, key_prefix=key_prefix, user_id=user_id_1
)
await check_rate_limit_by_user_id(
request=mock_request, key_prefix=key_prefix, user_id=user_id_2
)
# Assert
assert mock_redis.set.call_count == 2
# Extract call arguments properly
call_args_list = [
(call[0][0], call[0][1], call[1]['nx'], call[1]['ex'])
for call in mock_redis.set.call_args_list
]
assert (
f'{key_prefix}:{user_id_1}',
1,
True,
RATE_LIMIT_USER_SECONDS,
) in call_args_list
assert (
f'{key_prefix}:{user_id_2}',
1,
True,
RATE_LIMIT_USER_SECONDS,
) in call_args_list
@@ -0,0 +1,388 @@
"""Unit tests for GitlabWebhookStore."""
import pytest
from integrations.types import GitLabResourceType
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.pool import StaticPool
from storage.base import Base
from storage.gitlab_webhook import GitlabWebhook
from storage.gitlab_webhook_store import GitlabWebhookStore
@pytest.fixture
async def async_engine():
"""Create an async SQLite engine for testing."""
engine = create_async_engine(
'sqlite+aiosqlite:///:memory:',
poolclass=StaticPool,
connect_args={'check_same_thread': False},
echo=False,
)
# Create all tables
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
yield engine
await engine.dispose()
@pytest.fixture
async def async_session_maker(async_engine):
"""Create an async session maker for testing."""
return async_sessionmaker(async_engine, class_=AsyncSession, expire_on_commit=False)
@pytest.fixture
async def webhook_store(async_session_maker):
"""Create a GitlabWebhookStore instance for testing."""
return GitlabWebhookStore(a_session_maker=async_session_maker)
@pytest.fixture
async def sample_webhooks(async_session_maker):
"""Create sample webhook records for testing."""
async with async_session_maker() as session:
# Create webhooks for user_1
webhook1 = GitlabWebhook(
project_id='project-1',
group_id=None,
user_id='user_1',
webhook_exists=True,
webhook_url='https://example.com/webhook',
webhook_secret='secret-1',
webhook_uuid='uuid-1',
)
webhook2 = GitlabWebhook(
project_id='project-2',
group_id=None,
user_id='user_1',
webhook_exists=True,
webhook_url='https://example.com/webhook',
webhook_secret='secret-2',
webhook_uuid='uuid-2',
)
webhook3 = GitlabWebhook(
project_id=None,
group_id='group-1',
user_id='user_1',
webhook_exists=False, # Already marked for reinstallation
webhook_url='https://example.com/webhook',
webhook_secret='secret-3',
webhook_uuid='uuid-3',
)
# Create webhook for user_2
webhook4 = GitlabWebhook(
project_id='project-3',
group_id=None,
user_id='user_2',
webhook_exists=True,
webhook_url='https://example.com/webhook',
webhook_secret='secret-4',
webhook_uuid='uuid-4',
)
session.add_all([webhook1, webhook2, webhook3, webhook4])
await session.commit()
# Refresh to get IDs (outside of begin() context)
await session.refresh(webhook1)
await session.refresh(webhook2)
await session.refresh(webhook3)
await session.refresh(webhook4)
return [webhook1, webhook2, webhook3, webhook4]
class TestGetWebhookByResourceOnly:
"""Test cases for get_webhook_by_resource_only method."""
@pytest.mark.asyncio
async def test_get_project_webhook_by_resource_only(
self, webhook_store, async_session_maker, sample_webhooks
):
"""Test getting a project webhook by resource ID without user_id filter."""
# Arrange
resource_type = GitLabResourceType.PROJECT
resource_id = 'project-1'
# Act
webhook = await webhook_store.get_webhook_by_resource_only(
resource_type, resource_id
)
# Assert
assert webhook is not None
assert webhook.project_id == resource_id
assert webhook.user_id == 'user_1'
@pytest.mark.asyncio
async def test_get_group_webhook_by_resource_only(
self, webhook_store, async_session_maker, sample_webhooks
):
"""Test getting a group webhook by resource ID without user_id filter."""
# Arrange
resource_type = GitLabResourceType.GROUP
resource_id = 'group-1'
# Act
webhook = await webhook_store.get_webhook_by_resource_only(
resource_type, resource_id
)
# Assert
assert webhook is not None
assert webhook.group_id == resource_id
assert webhook.user_id == 'user_1'
@pytest.mark.asyncio
async def test_get_webhook_by_resource_only_not_found(
self, webhook_store, async_session_maker
):
"""Test getting a webhook that doesn't exist."""
# Arrange
resource_type = GitLabResourceType.PROJECT
resource_id = 'non-existent-project'
# Act
webhook = await webhook_store.get_webhook_by_resource_only(
resource_type, resource_id
)
# Assert
assert webhook is None
@pytest.mark.asyncio
async def test_get_webhook_by_resource_only_organization_wide(
self, webhook_store, async_session_maker, sample_webhooks
):
"""Test that webhook lookup works regardless of which user originally created it."""
# Arrange
resource_type = GitLabResourceType.PROJECT
resource_id = 'project-3' # Created by user_2
# Act
webhook = await webhook_store.get_webhook_by_resource_only(
resource_type, resource_id
)
# Assert
assert webhook is not None
assert webhook.project_id == resource_id
# Should find webhook even though it was created by a different user
assert webhook.user_id == 'user_2'
class TestResetWebhookForReinstallationByResource:
"""Test cases for reset_webhook_for_reinstallation_by_resource method."""
@pytest.mark.asyncio
async def test_reset_project_webhook_by_resource(
self, webhook_store, async_session_maker, sample_webhooks
):
"""Test resetting a project webhook by resource without user_id filter."""
# Arrange
resource_type = GitLabResourceType.PROJECT
resource_id = 'project-1'
updating_user_id = 'user_2' # Different user can reset it
# Act
result = await webhook_store.reset_webhook_for_reinstallation_by_resource(
resource_type, resource_id, updating_user_id
)
# Assert
assert result is True
# Verify webhook was reset
async with async_session_maker() as session:
result_query = await session.execute(
select(GitlabWebhook).where(GitlabWebhook.project_id == resource_id)
)
webhook = result_query.scalars().first()
assert webhook.webhook_exists is False
assert webhook.webhook_uuid is None
assert (
webhook.user_id == updating_user_id
) # Updated to track who modified it
@pytest.mark.asyncio
async def test_reset_group_webhook_by_resource(
self, webhook_store, async_session_maker, sample_webhooks
):
"""Test resetting a group webhook by resource without user_id filter."""
# Arrange
resource_type = GitLabResourceType.GROUP
resource_id = 'group-1'
updating_user_id = 'user_2'
# Act
result = await webhook_store.reset_webhook_for_reinstallation_by_resource(
resource_type, resource_id, updating_user_id
)
# Assert
assert result is True
# Verify webhook was reset
async with async_session_maker() as session:
result_query = await session.execute(
select(GitlabWebhook).where(GitlabWebhook.group_id == resource_id)
)
webhook = result_query.scalars().first()
assert webhook.webhook_exists is False
assert webhook.webhook_uuid is None
assert webhook.user_id == updating_user_id
@pytest.mark.asyncio
async def test_reset_webhook_by_resource_not_found(
self, webhook_store, async_session_maker
):
"""Test resetting a webhook that doesn't exist."""
# Arrange
resource_type = GitLabResourceType.PROJECT
resource_id = 'non-existent-project'
updating_user_id = 'user_1'
# Act
result = await webhook_store.reset_webhook_for_reinstallation_by_resource(
resource_type, resource_id, updating_user_id
)
# Assert
assert result is False
@pytest.mark.asyncio
async def test_reset_webhook_by_resource_organization_wide(
self, webhook_store, async_session_maker, sample_webhooks
):
"""Test that any user can reset a webhook regardless of original creator."""
# Arrange
resource_type = GitLabResourceType.PROJECT
resource_id = 'project-3' # Created by user_2
updating_user_id = 'user_1' # Different user resetting it
# Act
result = await webhook_store.reset_webhook_for_reinstallation_by_resource(
resource_type, resource_id, updating_user_id
)
# Assert
assert result is True
# Verify webhook was reset and user_id updated
async with async_session_maker() as session:
result_query = await session.execute(
select(GitlabWebhook).where(GitlabWebhook.project_id == resource_id)
)
webhook = result_query.scalars().first()
assert webhook.webhook_exists is False
assert webhook.user_id == updating_user_id
class TestGetWebhooksByResources:
"""Test cases for get_webhooks_by_resources method."""
@pytest.mark.asyncio
async def test_get_webhooks_by_resources_projects_only(
self, webhook_store, async_session_maker, sample_webhooks
):
"""Test bulk fetching webhooks for multiple projects."""
# Arrange
project_ids = ['project-1', 'project-2', 'project-3']
group_ids: list[str] = []
# Act
project_map, group_map = await webhook_store.get_webhooks_by_resources(
project_ids, group_ids
)
# Assert
assert len(project_map) == 3
assert 'project-1' in project_map
assert 'project-2' in project_map
assert 'project-3' in project_map
assert len(group_map) == 0
@pytest.mark.asyncio
async def test_get_webhooks_by_resources_groups_only(
self, webhook_store, async_session_maker, sample_webhooks
):
"""Test bulk fetching webhooks for multiple groups."""
# Arrange
project_ids: list[str] = []
group_ids = ['group-1']
# Act
project_map, group_map = await webhook_store.get_webhooks_by_resources(
project_ids, group_ids
)
# Assert
assert len(project_map) == 0
assert len(group_map) == 1
assert 'group-1' in group_map
@pytest.mark.asyncio
async def test_get_webhooks_by_resources_mixed(
self, webhook_store, async_session_maker, sample_webhooks
):
"""Test bulk fetching webhooks for both projects and groups."""
# Arrange
project_ids = ['project-1', 'project-2']
group_ids = ['group-1']
# Act
project_map, group_map = await webhook_store.get_webhooks_by_resources(
project_ids, group_ids
)
# Assert
assert len(project_map) == 2
assert len(group_map) == 1
assert 'project-1' in project_map
assert 'project-2' in project_map
assert 'group-1' in group_map
@pytest.mark.asyncio
async def test_get_webhooks_by_resources_empty_lists(
self, webhook_store, async_session_maker
):
"""Test bulk fetching with empty ID lists."""
# Arrange
project_ids: list[str] = []
group_ids: list[str] = []
# Act
project_map, group_map = await webhook_store.get_webhooks_by_resources(
project_ids, group_ids
)
# Assert
assert len(project_map) == 0
assert len(group_map) == 0
@pytest.mark.asyncio
async def test_get_webhooks_by_resources_partial_matches(
self, webhook_store, async_session_maker, sample_webhooks
):
"""Test bulk fetching when some IDs don't exist."""
# Arrange
project_ids = ['project-1', 'non-existent-project']
group_ids = ['group-1', 'non-existent-group']
# Act
project_map, group_map = await webhook_store.get_webhooks_by_resources(
project_ids, group_ids
)
# Assert
assert len(project_map) == 1
assert 'project-1' in project_map
assert 'non-existent-project' not in project_map
assert len(group_map) == 1
assert 'group-1' in group_map
assert 'non-existent-group' not in group_map
@@ -0,0 +1,438 @@
"""Unit tests for install_gitlab_webhooks module."""
from unittest.mock import AsyncMock, MagicMock
import pytest
from integrations.gitlab.webhook_installation import (
BreakLoopException,
install_webhook_on_resource,
verify_webhook_conditions,
)
from integrations.types import GitLabResourceType
from integrations.utils import GITLAB_WEBHOOK_URL
from storage.gitlab_webhook import GitlabWebhook, WebhookStatus
@pytest.fixture
def mock_gitlab_service():
"""Create a mock GitLab service."""
service = MagicMock()
service.check_resource_exists = AsyncMock(return_value=(True, None))
service.check_user_has_admin_access_to_resource = AsyncMock(
return_value=(True, None)
)
service.check_webhook_exists_on_resource = AsyncMock(return_value=(False, None))
service.install_webhook = AsyncMock(return_value=('webhook-id-123', None))
return service
@pytest.fixture
def mock_webhook_store():
"""Create a mock webhook store."""
store = MagicMock()
store.delete_webhook = AsyncMock()
store.update_webhook = AsyncMock()
return store
@pytest.fixture
def sample_webhook():
"""Create a sample webhook object."""
webhook = MagicMock(spec=GitlabWebhook)
webhook.user_id = 'test_user_id'
webhook.webhook_exists = False
webhook.webhook_uuid = None
return webhook
class TestVerifyWebhookConditions:
"""Test cases for verify_webhook_conditions function."""
@pytest.mark.asyncio
async def test_verify_conditions_all_pass(
self, mock_gitlab_service, mock_webhook_store, sample_webhook
):
"""Test when all conditions are met for webhook installation."""
# Arrange
resource_type = GitLabResourceType.PROJECT
resource_id = 'project-123'
# Act
# Should not raise any exception
await verify_webhook_conditions(
gitlab_service=mock_gitlab_service,
resource_type=resource_type,
resource_id=resource_id,
webhook_store=mock_webhook_store,
webhook=sample_webhook,
)
# Assert
mock_gitlab_service.check_resource_exists.assert_called_once_with(
resource_type, resource_id
)
mock_gitlab_service.check_user_has_admin_access_to_resource.assert_called_once_with(
resource_type, resource_id
)
mock_gitlab_service.check_webhook_exists_on_resource.assert_called_once_with(
resource_type, resource_id, GITLAB_WEBHOOK_URL
)
mock_webhook_store.delete_webhook.assert_not_called()
@pytest.mark.asyncio
async def test_verify_conditions_resource_does_not_exist(
self, mock_gitlab_service, mock_webhook_store, sample_webhook
):
"""Test when resource does not exist."""
# Arrange
resource_type = GitLabResourceType.PROJECT
resource_id = 'project-999'
mock_gitlab_service.check_resource_exists = AsyncMock(
return_value=(False, None)
)
# Act & Assert
with pytest.raises(BreakLoopException):
await verify_webhook_conditions(
gitlab_service=mock_gitlab_service,
resource_type=resource_type,
resource_id=resource_id,
webhook_store=mock_webhook_store,
webhook=sample_webhook,
)
# Assert webhook is deleted
mock_webhook_store.delete_webhook.assert_called_once_with(sample_webhook)
@pytest.mark.asyncio
async def test_verify_conditions_rate_limited_on_resource_check(
self, mock_gitlab_service, mock_webhook_store, sample_webhook
):
"""Test when rate limited during resource existence check."""
# Arrange
resource_type = GitLabResourceType.PROJECT
resource_id = 'project-123'
mock_gitlab_service.check_resource_exists = AsyncMock(
return_value=(False, WebhookStatus.RATE_LIMITED)
)
# Act & Assert
with pytest.raises(BreakLoopException):
await verify_webhook_conditions(
gitlab_service=mock_gitlab_service,
resource_type=resource_type,
resource_id=resource_id,
webhook_store=mock_webhook_store,
webhook=sample_webhook,
)
# Should not delete webhook on rate limit
mock_webhook_store.delete_webhook.assert_not_called()
@pytest.mark.asyncio
async def test_verify_conditions_user_no_admin_access(
self, mock_gitlab_service, mock_webhook_store, sample_webhook
):
"""Test when user does not have admin access."""
# Arrange
resource_type = GitLabResourceType.GROUP
resource_id = 'group-456'
mock_gitlab_service.check_user_has_admin_access_to_resource = AsyncMock(
return_value=(False, None)
)
# Act & Assert
with pytest.raises(BreakLoopException):
await verify_webhook_conditions(
gitlab_service=mock_gitlab_service,
resource_type=resource_type,
resource_id=resource_id,
webhook_store=mock_webhook_store,
webhook=sample_webhook,
)
# Assert webhook is deleted
mock_webhook_store.delete_webhook.assert_called_once_with(sample_webhook)
@pytest.mark.asyncio
async def test_verify_conditions_rate_limited_on_admin_check(
self, mock_gitlab_service, mock_webhook_store, sample_webhook
):
"""Test when rate limited during admin access check."""
# Arrange
resource_type = GitLabResourceType.PROJECT
resource_id = 'project-123'
mock_gitlab_service.check_user_has_admin_access_to_resource = AsyncMock(
return_value=(False, WebhookStatus.RATE_LIMITED)
)
# Act & Assert
with pytest.raises(BreakLoopException):
await verify_webhook_conditions(
gitlab_service=mock_gitlab_service,
resource_type=resource_type,
resource_id=resource_id,
webhook_store=mock_webhook_store,
webhook=sample_webhook,
)
# Should not delete webhook on rate limit
mock_webhook_store.delete_webhook.assert_not_called()
@pytest.mark.asyncio
async def test_verify_conditions_webhook_already_exists(
self, mock_gitlab_service, mock_webhook_store, sample_webhook
):
"""Test when webhook already exists on resource."""
# Arrange
resource_type = GitLabResourceType.PROJECT
resource_id = 'project-123'
mock_gitlab_service.check_webhook_exists_on_resource = AsyncMock(
return_value=(True, None)
)
# Act & Assert
with pytest.raises(BreakLoopException):
await verify_webhook_conditions(
gitlab_service=mock_gitlab_service,
resource_type=resource_type,
resource_id=resource_id,
webhook_store=mock_webhook_store,
webhook=sample_webhook,
)
@pytest.mark.asyncio
async def test_verify_conditions_rate_limited_on_webhook_check(
self, mock_gitlab_service, mock_webhook_store, sample_webhook
):
"""Test when rate limited during webhook existence check."""
# Arrange
resource_type = GitLabResourceType.PROJECT
resource_id = 'project-123'
mock_gitlab_service.check_webhook_exists_on_resource = AsyncMock(
return_value=(False, WebhookStatus.RATE_LIMITED)
)
# Act & Assert
with pytest.raises(BreakLoopException):
await verify_webhook_conditions(
gitlab_service=mock_gitlab_service,
resource_type=resource_type,
resource_id=resource_id,
webhook_store=mock_webhook_store,
webhook=sample_webhook,
)
@pytest.mark.asyncio
async def test_verify_conditions_updates_webhook_status_mismatch(
self, mock_gitlab_service, mock_webhook_store, sample_webhook
):
"""Test that webhook status is updated when database and API don't match."""
# Arrange
resource_type = GitLabResourceType.PROJECT
resource_id = 'project-123'
sample_webhook.webhook_exists = True # DB says exists
mock_gitlab_service.check_webhook_exists_on_resource = AsyncMock(
return_value=(False, None) # API says doesn't exist
)
# Act
# Should not raise BreakLoopException when webhook doesn't exist (allows installation)
await verify_webhook_conditions(
gitlab_service=mock_gitlab_service,
resource_type=resource_type,
resource_id=resource_id,
webhook_store=mock_webhook_store,
webhook=sample_webhook,
)
# Assert webhook status was updated to match API
mock_webhook_store.update_webhook.assert_called_once_with(
sample_webhook, {'webhook_exists': False}
)
class TestInstallWebhookOnResource:
"""Test cases for install_webhook_on_resource function."""
@pytest.mark.asyncio
async def test_install_webhook_success(
self, mock_gitlab_service, mock_webhook_store, sample_webhook
):
"""Test successful webhook installation."""
# Arrange
resource_type = GitLabResourceType.PROJECT
resource_id = 'project-123'
# Act
webhook_id, status = await install_webhook_on_resource(
gitlab_service=mock_gitlab_service,
resource_type=resource_type,
resource_id=resource_id,
webhook_store=mock_webhook_store,
webhook=sample_webhook,
)
# Assert
assert webhook_id == 'webhook-id-123'
assert status is None
mock_gitlab_service.install_webhook.assert_called_once()
mock_webhook_store.update_webhook.assert_called_once()
# Verify update_webhook was called with correct fields (using keyword arguments)
call_args = mock_webhook_store.update_webhook.call_args
assert call_args[1]['webhook'] == sample_webhook
update_fields = call_args[1]['update_fields']
assert update_fields['webhook_exists'] is True
assert update_fields['webhook_url'] == GITLAB_WEBHOOK_URL
assert 'webhook_secret' in update_fields
assert 'webhook_uuid' in update_fields
assert 'scopes' in update_fields
@pytest.mark.asyncio
async def test_install_webhook_group_resource(
self, mock_gitlab_service, mock_webhook_store, sample_webhook
):
"""Test webhook installation for a group resource."""
# Arrange
resource_type = GitLabResourceType.GROUP
resource_id = 'group-456'
# Act
webhook_id, status = await install_webhook_on_resource(
gitlab_service=mock_gitlab_service,
resource_type=resource_type,
resource_id=resource_id,
webhook_store=mock_webhook_store,
webhook=sample_webhook,
)
# Assert
assert webhook_id == 'webhook-id-123'
# Verify install_webhook was called with GROUP type
call_args = mock_gitlab_service.install_webhook.call_args
assert call_args[1]['resource_type'] == resource_type
assert call_args[1]['resource_id'] == resource_id
@pytest.mark.asyncio
async def test_install_webhook_rate_limited(
self, mock_gitlab_service, mock_webhook_store, sample_webhook
):
"""Test when installation is rate limited."""
# Arrange
resource_type = GitLabResourceType.PROJECT
resource_id = 'project-123'
mock_gitlab_service.install_webhook = AsyncMock(
return_value=(None, WebhookStatus.RATE_LIMITED)
)
# Act & Assert
with pytest.raises(BreakLoopException):
await install_webhook_on_resource(
gitlab_service=mock_gitlab_service,
resource_type=resource_type,
resource_id=resource_id,
webhook_store=mock_webhook_store,
webhook=sample_webhook,
)
# Should not update webhook on rate limit
mock_webhook_store.update_webhook.assert_not_called()
@pytest.mark.asyncio
async def test_install_webhook_installation_fails(
self, mock_gitlab_service, mock_webhook_store, sample_webhook
):
"""Test when webhook installation fails."""
# Arrange
resource_type = GitLabResourceType.PROJECT
resource_id = 'project-123'
mock_gitlab_service.install_webhook = AsyncMock(return_value=(None, None))
# Act
webhook_id, status = await install_webhook_on_resource(
gitlab_service=mock_gitlab_service,
resource_type=resource_type,
resource_id=resource_id,
webhook_store=mock_webhook_store,
webhook=sample_webhook,
)
# Assert
assert webhook_id is None
assert status is None
# Should not update webhook when installation fails
mock_webhook_store.update_webhook.assert_not_called()
@pytest.mark.asyncio
async def test_install_webhook_generates_unique_secrets(
self, mock_gitlab_service, mock_webhook_store, sample_webhook
):
"""Test that unique webhook secrets and UUIDs are generated."""
# Arrange
resource_type = GitLabResourceType.PROJECT
resource_id = 'project-123'
# Act - First call
webhook_id1, _ = await install_webhook_on_resource(
gitlab_service=mock_gitlab_service,
resource_type=resource_type,
resource_id=resource_id,
webhook_store=mock_webhook_store,
webhook=sample_webhook,
)
# Capture first call's values before resetting
call1_secret = mock_webhook_store.update_webhook.call_args_list[0][1][
'update_fields'
]['webhook_secret']
call1_uuid = mock_webhook_store.update_webhook.call_args_list[0][1][
'update_fields'
]['webhook_uuid']
# Reset mocks and call again
mock_gitlab_service.install_webhook.reset_mock()
mock_webhook_store.update_webhook.reset_mock()
# Act - Second call
webhook_id2, _ = await install_webhook_on_resource(
gitlab_service=mock_gitlab_service,
resource_type=resource_type,
resource_id=resource_id,
webhook_store=mock_webhook_store,
webhook=sample_webhook,
)
# Capture second call's values
call2_secret = mock_webhook_store.update_webhook.call_args_list[0][1][
'update_fields'
]['webhook_secret']
call2_uuid = mock_webhook_store.update_webhook.call_args_list[0][1][
'update_fields'
]['webhook_uuid']
# Assert - Secrets and UUIDs should be different
assert call1_secret != call2_secret
assert call1_uuid != call2_uuid
@pytest.mark.asyncio
async def test_install_webhook_uses_correct_webhook_name_and_url(
self, mock_gitlab_service, mock_webhook_store, sample_webhook
):
"""Test that correct webhook name and URL are used."""
# Arrange
resource_type = GitLabResourceType.PROJECT
resource_id = 'project-123'
# Act
await install_webhook_on_resource(
gitlab_service=mock_gitlab_service,
resource_type=resource_type,
resource_id=resource_id,
webhook_store=mock_webhook_store,
webhook=sample_webhook,
)
# Assert
call_args = mock_gitlab_service.install_webhook.call_args
assert call_args[1]['webhook_name'] == 'OpenHands Resolver'
assert call_args[1]['webhook_url'] == GITLAB_WEBHOOK_URL
@@ -234,3 +234,53 @@ async def test_middleware_with_other_auth_error(middleware, mock_request):
assert 'set-cookie' in result.headers
# Logger should be called for non-NoCredentialsError
mock_logger.warning.assert_called_once()
@pytest.mark.asyncio
async def test_middleware_ignores_email_resend_path(
middleware, mock_request, mock_response
):
"""Test middleware ignores /api/email/resend path and doesn't require authentication."""
# Arrange
mock_request.cookies = {}
mock_request.url = MagicMock()
mock_request.url.hostname = 'localhost'
mock_request.url.path = '/api/email/resend'
mock_call_next = AsyncMock(return_value=mock_response)
# Act
result = await middleware(mock_request, mock_call_next)
# Assert
assert result == mock_response
mock_call_next.assert_called_once_with(mock_request)
# Should not raise NoCredentialsError even without auth cookie
@pytest.mark.asyncio
async def test_middleware_ignores_email_resend_path_no_tos_check(
middleware, mock_request, mock_response
):
"""Test middleware doesn't check TOS for /api/email/resend path."""
# Arrange
mock_request.cookies = {'keycloak_auth': 'test_cookie'}
mock_request.url = MagicMock()
mock_request.url.hostname = 'localhost'
mock_request.url.path = '/api/email/resend'
mock_call_next = AsyncMock(return_value=mock_response)
with (
patch('server.middleware.jwt.decode') as mock_decode,
patch('server.middleware.config') as mock_config,
):
# Even with accepted_tos=False, should not raise TosNotAcceptedError
mock_decode.return_value = {'accepted_tos': False}
mock_config.jwt_secret.get_secret_value.return_value = 'test_secret'
# Act
result = await middleware(mock_request, mock_call_next)
# Assert
assert result == mock_response
mock_call_next.assert_called_once_with(mock_request)
# Should not raise TosNotAcceptedError for this path
+496
View File
@@ -136,6 +136,7 @@ async def test_keycloak_callback_user_not_allowed(mock_request):
'sub': 'test_user_id',
'preferred_username': 'test_user',
'identity_provider': 'github',
'email_verified': True,
}
)
mock_token_manager.store_idp_tokens = AsyncMock()
@@ -184,6 +185,7 @@ async def test_keycloak_callback_success_with_valid_offline_token(mock_request):
'sub': 'test_user_id',
'preferred_username': 'test_user',
'identity_provider': 'github',
'email_verified': True,
}
)
mock_token_manager.store_idp_tokens = AsyncMock()
@@ -214,6 +216,84 @@ async def test_keycloak_callback_success_with_valid_offline_token(mock_request):
mock_posthog.set.assert_called_once()
@pytest.mark.asyncio
async def test_keycloak_callback_email_not_verified(mock_request):
"""Test keycloak_callback when email is not verified."""
# Arrange
mock_verify_email = AsyncMock()
with (
patch('server.routes.auth.token_manager') as mock_token_manager,
patch('server.routes.auth.user_verifier') as mock_verifier,
patch('server.routes.email.verify_email', mock_verify_email),
):
mock_token_manager.get_keycloak_tokens = AsyncMock(
return_value=('test_access_token', 'test_refresh_token')
)
mock_token_manager.get_user_info = AsyncMock(
return_value={
'sub': 'test_user_id',
'preferred_username': 'test_user',
'identity_provider': 'github',
'email_verified': False,
}
)
mock_token_manager.store_idp_tokens = AsyncMock()
mock_verifier.is_active.return_value = False
# Act
result = await keycloak_callback(
code='test_code', state='test_state', request=mock_request
)
# Assert
assert isinstance(result, RedirectResponse)
assert result.status_code == 302
assert 'email_verification_required=true' in result.headers['location']
assert 'user_id=test_user_id' in result.headers['location']
mock_verify_email.assert_called_once_with(
request=mock_request, user_id='test_user_id', is_auth_flow=True
)
@pytest.mark.asyncio
async def test_keycloak_callback_email_not_verified_missing_field(mock_request):
"""Test keycloak_callback when email_verified field is missing (defaults to False)."""
# Arrange
mock_verify_email = AsyncMock()
with (
patch('server.routes.auth.token_manager') as mock_token_manager,
patch('server.routes.auth.user_verifier') as mock_verifier,
patch('server.routes.email.verify_email', mock_verify_email),
):
mock_token_manager.get_keycloak_tokens = AsyncMock(
return_value=('test_access_token', 'test_refresh_token')
)
mock_token_manager.get_user_info = AsyncMock(
return_value={
'sub': 'test_user_id',
'preferred_username': 'test_user',
'identity_provider': 'github',
# email_verified field is missing
}
)
mock_token_manager.store_idp_tokens = AsyncMock()
mock_verifier.is_active.return_value = False
# Act
result = await keycloak_callback(
code='test_code', state='test_state', request=mock_request
)
# Assert
assert isinstance(result, RedirectResponse)
assert result.status_code == 302
assert 'email_verification_required=true' in result.headers['location']
assert 'user_id=test_user_id' in result.headers['location']
mock_verify_email.assert_called_once_with(
request=mock_request, user_id='test_user_id', is_auth_flow=True
)
@pytest.mark.asyncio
async def test_keycloak_callback_success_without_offline_token(mock_request):
"""Test successful keycloak_callback without valid offline token."""
@@ -248,6 +328,7 @@ async def test_keycloak_callback_success_without_offline_token(mock_request):
'sub': 'test_user_id',
'preferred_username': 'test_user',
'identity_provider': 'github',
'email_verified': True,
}
)
mock_token_manager.store_idp_tokens = AsyncMock()
@@ -442,3 +523,418 @@ async def test_logout_without_refresh_token():
mock_token_manager.logout.assert_not_called()
assert 'set-cookie' in result.headers
@pytest.mark.asyncio
async def test_keycloak_callback_blocked_email_domain(mock_request):
"""Test keycloak_callback when email domain is blocked."""
# Arrange
with (
patch('server.routes.auth.token_manager') as mock_token_manager,
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
):
mock_token_manager.get_keycloak_tokens = AsyncMock(
return_value=('test_access_token', 'test_refresh_token')
)
mock_token_manager.get_user_info = AsyncMock(
return_value={
'sub': 'test_user_id',
'preferred_username': 'test_user',
'email': 'user@colsch.us',
'identity_provider': 'github',
}
)
mock_token_manager.disable_keycloak_user = AsyncMock()
mock_domain_blocker.is_active.return_value = True
mock_domain_blocker.is_domain_blocked.return_value = True
# Act
result = await keycloak_callback(
code='test_code', state='test_state', request=mock_request
)
# Assert
assert isinstance(result, JSONResponse)
assert result.status_code == status.HTTP_401_UNAUTHORIZED
assert 'error' in result.body.decode()
assert 'email domain is not allowed' in result.body.decode()
mock_domain_blocker.is_domain_blocked.assert_called_once_with('user@colsch.us')
mock_token_manager.disable_keycloak_user.assert_called_once_with(
'test_user_id', 'user@colsch.us'
)
@pytest.mark.asyncio
async def test_keycloak_callback_allowed_email_domain(mock_request):
"""Test keycloak_callback when email domain is not blocked."""
# Arrange
with (
patch('server.routes.auth.token_manager') as mock_token_manager,
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
patch('server.routes.auth.user_verifier') as mock_verifier,
patch('server.routes.auth.session_maker') as mock_session_maker,
):
mock_session = MagicMock()
mock_session_maker.return_value.__enter__.return_value = mock_session
mock_query = MagicMock()
mock_session.query.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_user_settings = MagicMock()
mock_user_settings.accepted_tos = '2025-01-01'
mock_query.first.return_value = mock_user_settings
mock_token_manager.get_keycloak_tokens = AsyncMock(
return_value=('test_access_token', 'test_refresh_token')
)
mock_token_manager.get_user_info = AsyncMock(
return_value={
'sub': 'test_user_id',
'preferred_username': 'test_user',
'email': 'user@example.com',
'identity_provider': 'github',
'email_verified': True,
}
)
mock_token_manager.store_idp_tokens = AsyncMock()
mock_token_manager.validate_offline_token = AsyncMock(return_value=True)
mock_domain_blocker.is_active.return_value = True
mock_domain_blocker.is_domain_blocked.return_value = False
mock_verifier.is_active.return_value = True
mock_verifier.is_user_allowed.return_value = True
# Act
result = await keycloak_callback(
code='test_code', state='test_state', request=mock_request
)
# Assert
assert isinstance(result, RedirectResponse)
mock_domain_blocker.is_domain_blocked.assert_called_once_with(
'user@example.com'
)
mock_token_manager.disable_keycloak_user.assert_not_called()
@pytest.mark.asyncio
async def test_keycloak_callback_domain_blocking_inactive(mock_request):
"""Test keycloak_callback when domain blocking is not active."""
# Arrange
with (
patch('server.routes.auth.token_manager') as mock_token_manager,
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
patch('server.routes.auth.user_verifier') as mock_verifier,
patch('server.routes.auth.session_maker') as mock_session_maker,
):
mock_session = MagicMock()
mock_session_maker.return_value.__enter__.return_value = mock_session
mock_query = MagicMock()
mock_session.query.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_user_settings = MagicMock()
mock_user_settings.accepted_tos = '2025-01-01'
mock_query.first.return_value = mock_user_settings
mock_token_manager.get_keycloak_tokens = AsyncMock(
return_value=('test_access_token', 'test_refresh_token')
)
mock_token_manager.get_user_info = AsyncMock(
return_value={
'sub': 'test_user_id',
'preferred_username': 'test_user',
'email': 'user@colsch.us',
'identity_provider': 'github',
'email_verified': True,
}
)
mock_token_manager.store_idp_tokens = AsyncMock()
mock_token_manager.validate_offline_token = AsyncMock(return_value=True)
mock_domain_blocker.is_active.return_value = False
mock_verifier.is_active.return_value = True
mock_verifier.is_user_allowed.return_value = True
# Act
result = await keycloak_callback(
code='test_code', state='test_state', request=mock_request
)
# Assert
assert isinstance(result, RedirectResponse)
mock_domain_blocker.is_domain_blocked.assert_not_called()
mock_token_manager.disable_keycloak_user.assert_not_called()
@pytest.mark.asyncio
async def test_keycloak_callback_missing_email(mock_request):
"""Test keycloak_callback when user info does not contain email."""
# Arrange
with (
patch('server.routes.auth.token_manager') as mock_token_manager,
patch('server.routes.auth.domain_blocker') as mock_domain_blocker,
patch('server.routes.auth.user_verifier') as mock_verifier,
patch('server.routes.auth.session_maker') as mock_session_maker,
):
mock_session = MagicMock()
mock_session_maker.return_value.__enter__.return_value = mock_session
mock_query = MagicMock()
mock_session.query.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_user_settings = MagicMock()
mock_user_settings.accepted_tos = '2025-01-01'
mock_query.first.return_value = mock_user_settings
mock_token_manager.get_keycloak_tokens = AsyncMock(
return_value=('test_access_token', 'test_refresh_token')
)
mock_token_manager.get_user_info = AsyncMock(
return_value={
'sub': 'test_user_id',
'preferred_username': 'test_user',
'identity_provider': 'github',
'email_verified': True,
# No email field
}
)
mock_token_manager.store_idp_tokens = AsyncMock()
mock_token_manager.validate_offline_token = AsyncMock(return_value=True)
mock_domain_blocker.is_active.return_value = True
mock_verifier.is_active.return_value = True
mock_verifier.is_user_allowed.return_value = True
# Act
result = await keycloak_callback(
code='test_code', state='test_state', request=mock_request
)
# Assert
assert isinstance(result, RedirectResponse)
mock_domain_blocker.is_domain_blocked.assert_not_called()
mock_token_manager.disable_keycloak_user.assert_not_called()
@pytest.mark.asyncio
async def test_keycloak_callback_duplicate_email_detected(mock_request):
"""Test keycloak_callback when duplicate email is detected."""
with (
patch('server.routes.auth.token_manager') as mock_token_manager,
):
# Arrange
mock_token_manager.get_keycloak_tokens = AsyncMock(
return_value=('test_access_token', 'test_refresh_token')
)
mock_token_manager.get_user_info = AsyncMock(
return_value={
'sub': 'test_user_id',
'preferred_username': 'test_user',
'email': 'joe+test@example.com',
'identity_provider': 'github',
}
)
mock_token_manager.check_duplicate_base_email = AsyncMock(return_value=True)
mock_token_manager.delete_keycloak_user = AsyncMock(return_value=True)
# Act
result = await keycloak_callback(
code='test_code', state='test_state', request=mock_request
)
# Assert
assert isinstance(result, RedirectResponse)
assert result.status_code == 302
assert 'duplicated_email=true' in result.headers['location']
mock_token_manager.check_duplicate_base_email.assert_called_once_with(
'joe+test@example.com', 'test_user_id'
)
mock_token_manager.delete_keycloak_user.assert_called_once_with('test_user_id')
@pytest.mark.asyncio
async def test_keycloak_callback_duplicate_email_deletion_fails(mock_request):
"""Test keycloak_callback when duplicate is detected but deletion fails."""
with (
patch('server.routes.auth.token_manager') as mock_token_manager,
):
# Arrange
mock_token_manager.get_keycloak_tokens = AsyncMock(
return_value=('test_access_token', 'test_refresh_token')
)
mock_token_manager.get_user_info = AsyncMock(
return_value={
'sub': 'test_user_id',
'preferred_username': 'test_user',
'email': 'joe+test@example.com',
'identity_provider': 'github',
}
)
mock_token_manager.check_duplicate_base_email = AsyncMock(return_value=True)
mock_token_manager.delete_keycloak_user = AsyncMock(return_value=False)
# Act
result = await keycloak_callback(
code='test_code', state='test_state', request=mock_request
)
# Assert
assert isinstance(result, RedirectResponse)
assert result.status_code == 302
assert 'duplicated_email=true' in result.headers['location']
mock_token_manager.delete_keycloak_user.assert_called_once_with('test_user_id')
@pytest.mark.asyncio
async def test_keycloak_callback_duplicate_check_exception(mock_request):
"""Test keycloak_callback when duplicate check raises exception."""
with (
patch('server.routes.auth.token_manager') as mock_token_manager,
patch('server.routes.auth.user_verifier') as mock_verifier,
patch('server.routes.auth.session_maker') as mock_session_maker,
):
# Arrange
mock_session = MagicMock()
mock_session_maker.return_value.__enter__.return_value = mock_session
mock_query = MagicMock()
mock_session.query.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_user_settings = MagicMock()
mock_user_settings.accepted_tos = '2025-01-01'
mock_query.first.return_value = mock_user_settings
mock_token_manager.get_keycloak_tokens = AsyncMock(
return_value=('test_access_token', 'test_refresh_token')
)
mock_token_manager.get_user_info = AsyncMock(
return_value={
'sub': 'test_user_id',
'preferred_username': 'test_user',
'email': 'joe+test@example.com',
'identity_provider': 'github',
'email_verified': True,
}
)
mock_token_manager.check_duplicate_base_email = AsyncMock(
side_effect=Exception('Check failed')
)
mock_token_manager.store_idp_tokens = AsyncMock()
mock_token_manager.validate_offline_token = AsyncMock(return_value=True)
mock_verifier.is_active.return_value = True
mock_verifier.is_user_allowed.return_value = True
# Act
result = await keycloak_callback(
code='test_code', state='test_state', request=mock_request
)
# Assert
# Should proceed with normal flow despite exception (fail open)
assert isinstance(result, RedirectResponse)
assert result.status_code == 302
@pytest.mark.asyncio
async def test_keycloak_callback_no_duplicate_email(mock_request):
"""Test keycloak_callback when no duplicate email is found."""
with (
patch('server.routes.auth.token_manager') as mock_token_manager,
patch('server.routes.auth.user_verifier') as mock_verifier,
patch('server.routes.auth.session_maker') as mock_session_maker,
):
# Arrange
mock_session = MagicMock()
mock_session_maker.return_value.__enter__.return_value = mock_session
mock_query = MagicMock()
mock_session.query.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_user_settings = MagicMock()
mock_user_settings.accepted_tos = '2025-01-01'
mock_query.first.return_value = mock_user_settings
mock_token_manager.get_keycloak_tokens = AsyncMock(
return_value=('test_access_token', 'test_refresh_token')
)
mock_token_manager.get_user_info = AsyncMock(
return_value={
'sub': 'test_user_id',
'preferred_username': 'test_user',
'email': 'joe+test@example.com',
'identity_provider': 'github',
'email_verified': True,
}
)
mock_token_manager.check_duplicate_base_email = AsyncMock(return_value=False)
mock_token_manager.store_idp_tokens = AsyncMock()
mock_token_manager.validate_offline_token = AsyncMock(return_value=True)
mock_verifier.is_active.return_value = True
mock_verifier.is_user_allowed.return_value = True
# Act
result = await keycloak_callback(
code='test_code', state='test_state', request=mock_request
)
# Assert
assert isinstance(result, RedirectResponse)
assert result.status_code == 302
mock_token_manager.check_duplicate_base_email.assert_called_once_with(
'joe+test@example.com', 'test_user_id'
)
# Should not delete user when no duplicate found
mock_token_manager.delete_keycloak_user.assert_not_called()
@pytest.mark.asyncio
async def test_keycloak_callback_no_email_in_user_info(mock_request):
"""Test keycloak_callback when email is not in user_info."""
with (
patch('server.routes.auth.token_manager') as mock_token_manager,
patch('server.routes.auth.user_verifier') as mock_verifier,
patch('server.routes.auth.session_maker') as mock_session_maker,
):
# Arrange
mock_session = MagicMock()
mock_session_maker.return_value.__enter__.return_value = mock_session
mock_query = MagicMock()
mock_session.query.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_user_settings = MagicMock()
mock_user_settings.accepted_tos = '2025-01-01'
mock_query.first.return_value = mock_user_settings
mock_token_manager.get_keycloak_tokens = AsyncMock(
return_value=('test_access_token', 'test_refresh_token')
)
mock_token_manager.get_user_info = AsyncMock(
return_value={
'sub': 'test_user_id',
'preferred_username': 'test_user',
# No email field
'identity_provider': 'github',
'email_verified': True,
}
)
mock_token_manager.store_idp_tokens = AsyncMock()
mock_token_manager.validate_offline_token = AsyncMock(return_value=True)
mock_verifier.is_active.return_value = True
mock_verifier.is_user_allowed.return_value = True
# Act
result = await keycloak_callback(
code='test_code', state='test_state', request=mock_request
)
# Assert
assert isinstance(result, RedirectResponse)
assert result.status_code == 302
# Should not check for duplicate when email is missing
mock_token_manager.check_duplicate_base_email.assert_not_called()
+6 -7
View File
@@ -4,7 +4,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import stripe
from fastapi import HTTPException, Request, status
from httpx import HTTPStatusError, Response
from httpx import Response
from integrations.stripe_service import has_payment_method
from server.routes.billing import (
CreateBillingSessionResponse,
@@ -78,8 +78,6 @@ def mock_subscription_request():
@pytest.mark.asyncio
async def test_get_credits_lite_llm_error():
mock_request = Request(scope={'type': 'http', 'state': {'user_id': 'mock_user'}})
mock_response = Response(
status_code=500, json={'error': 'Internal Server Error'}, request=MagicMock()
)
@@ -88,11 +86,12 @@ async def test_get_credits_lite_llm_error():
with patch('integrations.stripe_service.STRIPE_API_KEY', 'mock_key'):
with patch('httpx.AsyncClient', return_value=mock_client):
with pytest.raises(HTTPStatusError) as exc_info:
await get_credits(mock_request)
with pytest.raises(HTTPException) as exc_info:
await get_credits('mock_user')
assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
assert (
exc_info.value.response.status_code
== status.HTTP_500_INTERNAL_SERVER_ERROR
exc_info.value.detail
== 'Failed to retrieve credit balance from billing service'
)
@@ -0,0 +1,493 @@
"""Unit tests for DomainBlocker class."""
import pytest
from server.auth.domain_blocker import DomainBlocker
@pytest.fixture
def domain_blocker():
"""Create a DomainBlocker instance for testing."""
return DomainBlocker()
@pytest.mark.parametrize(
'blocked_domains,expected',
[
(['colsch.us', 'other-domain.com'], True),
(['example.com'], True),
([], False),
],
)
def test_is_active(domain_blocker, blocked_domains, expected):
"""Test that is_active returns correct value based on blocked domains configuration."""
# Arrange
domain_blocker.blocked_domains = blocked_domains
# Act
result = domain_blocker.is_active()
# Assert
assert result == expected
@pytest.mark.parametrize(
'email,expected_domain',
[
('user@example.com', 'example.com'),
('test@colsch.us', 'colsch.us'),
('user.name@other-domain.com', 'other-domain.com'),
('USER@EXAMPLE.COM', 'example.com'), # Case insensitive
('user@EXAMPLE.COM', 'example.com'),
(' user@example.com ', 'example.com'), # Whitespace handling
],
)
def test_extract_domain_valid_emails(domain_blocker, email, expected_domain):
"""Test that _extract_domain correctly extracts and normalizes domains from valid emails."""
# Act
result = domain_blocker._extract_domain(email)
# Assert
assert result == expected_domain
@pytest.mark.parametrize(
'email,expected',
[
(None, None),
('', None),
('invalid-email', None),
('user@', None), # Empty domain after @
('no-at-sign', None),
],
)
def test_extract_domain_invalid_emails(domain_blocker, email, expected):
"""Test that _extract_domain returns None for invalid email formats."""
# Act
result = domain_blocker._extract_domain(email)
# Assert
assert result == expected
def test_is_domain_blocked_when_inactive(domain_blocker):
"""Test that is_domain_blocked returns False when blocking is not active."""
# Arrange
domain_blocker.blocked_domains = []
# Act
result = domain_blocker.is_domain_blocked('user@colsch.us')
# Assert
assert result is False
def test_is_domain_blocked_with_none_email(domain_blocker):
"""Test that is_domain_blocked returns False when email is None."""
# Arrange
domain_blocker.blocked_domains = ['colsch.us']
# Act
result = domain_blocker.is_domain_blocked(None)
# Assert
assert result is False
def test_is_domain_blocked_with_empty_email(domain_blocker):
"""Test that is_domain_blocked returns False when email is empty."""
# Arrange
domain_blocker.blocked_domains = ['colsch.us']
# Act
result = domain_blocker.is_domain_blocked('')
# Assert
assert result is False
def test_is_domain_blocked_with_invalid_email(domain_blocker):
"""Test that is_domain_blocked returns False when email format is invalid."""
# Arrange
domain_blocker.blocked_domains = ['colsch.us']
# Act
result = domain_blocker.is_domain_blocked('invalid-email')
# Assert
assert result is False
def test_is_domain_blocked_domain_not_blocked(domain_blocker):
"""Test that is_domain_blocked returns False when domain is not in blocked list."""
# Arrange
domain_blocker.blocked_domains = ['colsch.us', 'other-domain.com']
# Act
result = domain_blocker.is_domain_blocked('user@example.com')
# Assert
assert result is False
def test_is_domain_blocked_domain_blocked(domain_blocker):
"""Test that is_domain_blocked returns True when domain is in blocked list."""
# Arrange
domain_blocker.blocked_domains = ['colsch.us', 'other-domain.com']
# Act
result = domain_blocker.is_domain_blocked('user@colsch.us')
# Assert
assert result is True
def test_is_domain_blocked_case_insensitive(domain_blocker):
"""Test that is_domain_blocked performs case-insensitive domain matching."""
# Arrange
domain_blocker.blocked_domains = ['colsch.us']
# Act
result = domain_blocker.is_domain_blocked('user@COLSCH.US')
# Assert
assert result is True
def test_is_domain_blocked_multiple_blocked_domains(domain_blocker):
"""Test that is_domain_blocked correctly checks against multiple blocked domains."""
# Arrange
domain_blocker.blocked_domains = ['colsch.us', 'other-domain.com', 'blocked.org']
# Act
result1 = domain_blocker.is_domain_blocked('user@other-domain.com')
result2 = domain_blocker.is_domain_blocked('user@blocked.org')
result3 = domain_blocker.is_domain_blocked('user@allowed.com')
# Assert
assert result1 is True
assert result2 is True
assert result3 is False
def test_is_domain_blocked_with_whitespace(domain_blocker):
"""Test that is_domain_blocked handles emails with whitespace correctly."""
# Arrange
domain_blocker.blocked_domains = ['colsch.us']
# Act
result = domain_blocker.is_domain_blocked(' user@colsch.us ')
# Assert
assert result is True
# ============================================================================
# TLD Blocking Tests (patterns starting with '.')
# ============================================================================
def test_is_domain_blocked_tld_pattern_blocks_matching_domain(domain_blocker):
"""Test that TLD pattern blocks domains ending with that TLD."""
# Arrange
domain_blocker.blocked_domains = ['.us']
# Act
result = domain_blocker.is_domain_blocked('user@company.us')
# Assert
assert result is True
def test_is_domain_blocked_tld_pattern_blocks_subdomain_with_tld(domain_blocker):
"""Test that TLD pattern blocks subdomains with that TLD."""
# Arrange
domain_blocker.blocked_domains = ['.us']
# Act
result = domain_blocker.is_domain_blocked('user@subdomain.company.us')
# Assert
assert result is True
def test_is_domain_blocked_tld_pattern_does_not_block_different_tld(domain_blocker):
"""Test that TLD pattern does not block domains with different TLD."""
# Arrange
domain_blocker.blocked_domains = ['.us']
# Act
result = domain_blocker.is_domain_blocked('user@company.com')
# Assert
assert result is False
def test_is_domain_blocked_tld_pattern_does_not_block_substring_match(
domain_blocker,
):
"""Test that TLD pattern does not block domains that contain but don't end with the TLD."""
# Arrange
domain_blocker.blocked_domains = ['.us']
# Act
result = domain_blocker.is_domain_blocked('user@focus.com')
# Assert
assert result is False
def test_is_domain_blocked_tld_pattern_case_insensitive(domain_blocker):
"""Test that TLD pattern matching is case-insensitive."""
# Arrange
domain_blocker.blocked_domains = ['.us']
# Act
result = domain_blocker.is_domain_blocked('user@COMPANY.US')
# Assert
assert result is True
def test_is_domain_blocked_multiple_tld_patterns(domain_blocker):
"""Test blocking with multiple TLD patterns."""
# Arrange
domain_blocker.blocked_domains = ['.us', '.vn', '.com']
# Act
result_us = domain_blocker.is_domain_blocked('user@test.us')
result_vn = domain_blocker.is_domain_blocked('user@test.vn')
result_com = domain_blocker.is_domain_blocked('user@test.com')
result_org = domain_blocker.is_domain_blocked('user@test.org')
# Assert
assert result_us is True
assert result_vn is True
assert result_com is True
assert result_org is False
def test_is_domain_blocked_tld_pattern_with_multi_level_tld(domain_blocker):
"""Test that TLD pattern works with multi-level TLDs like .co.uk."""
# Arrange
domain_blocker.blocked_domains = ['.co.uk']
# Act
result_match = domain_blocker.is_domain_blocked('user@example.co.uk')
result_subdomain = domain_blocker.is_domain_blocked('user@api.example.co.uk')
result_no_match = domain_blocker.is_domain_blocked('user@example.uk')
# Assert
assert result_match is True
assert result_subdomain is True
assert result_no_match is False
# ============================================================================
# Subdomain Blocking Tests (domain patterns now block subdomains)
# ============================================================================
def test_is_domain_blocked_domain_pattern_blocks_exact_match(domain_blocker):
"""Test that domain pattern blocks exact domain match."""
# Arrange
domain_blocker.blocked_domains = ['example.com']
# Act
result = domain_blocker.is_domain_blocked('user@example.com')
# Assert
assert result is True
def test_is_domain_blocked_domain_pattern_blocks_subdomain(domain_blocker):
"""Test that domain pattern blocks subdomains of that domain."""
# Arrange
domain_blocker.blocked_domains = ['example.com']
# Act
result = domain_blocker.is_domain_blocked('user@subdomain.example.com')
# Assert
assert result is True
def test_is_domain_blocked_domain_pattern_blocks_multi_level_subdomain(
domain_blocker,
):
"""Test that domain pattern blocks multi-level subdomains."""
# Arrange
domain_blocker.blocked_domains = ['example.com']
# Act
result = domain_blocker.is_domain_blocked('user@api.v2.example.com')
# Assert
assert result is True
def test_is_domain_blocked_domain_pattern_does_not_block_similar_domain(
domain_blocker,
):
"""Test that domain pattern does not block domains that contain but don't match the pattern."""
# Arrange
domain_blocker.blocked_domains = ['example.com']
# Act
result = domain_blocker.is_domain_blocked('user@notexample.com')
# Assert
assert result is False
def test_is_domain_blocked_domain_pattern_does_not_block_different_tld(
domain_blocker,
):
"""Test that domain pattern does not block same domain with different TLD."""
# Arrange
domain_blocker.blocked_domains = ['example.com']
# Act
result = domain_blocker.is_domain_blocked('user@example.org')
# Assert
assert result is False
def test_is_domain_blocked_subdomain_pattern_blocks_exact_and_nested(domain_blocker):
"""Test that blocking a subdomain also blocks its nested subdomains."""
# Arrange
domain_blocker.blocked_domains = ['api.example.com']
# Act
result_exact = domain_blocker.is_domain_blocked('user@api.example.com')
result_nested = domain_blocker.is_domain_blocked('user@v1.api.example.com')
result_parent = domain_blocker.is_domain_blocked('user@example.com')
# Assert
assert result_exact is True
assert result_nested is True
assert result_parent is False
# ============================================================================
# Mixed Pattern Tests (TLD + domain patterns together)
# ============================================================================
def test_is_domain_blocked_mixed_patterns_tld_and_domain(domain_blocker):
"""Test blocking with both TLD and domain patterns."""
# Arrange
domain_blocker.blocked_domains = ['.us', 'openhands.dev']
# Act
result_tld = domain_blocker.is_domain_blocked('user@company.us')
result_domain = domain_blocker.is_domain_blocked('user@openhands.dev')
result_subdomain = domain_blocker.is_domain_blocked('user@api.openhands.dev')
result_allowed = domain_blocker.is_domain_blocked('user@example.com')
# Assert
assert result_tld is True
assert result_domain is True
assert result_subdomain is True
assert result_allowed is False
def test_is_domain_blocked_overlapping_patterns(domain_blocker):
"""Test that overlapping patterns (TLD and specific domain) both work."""
# Arrange
domain_blocker.blocked_domains = ['.us', 'test.us']
# Act
result_specific = domain_blocker.is_domain_blocked('user@test.us')
result_other_us = domain_blocker.is_domain_blocked('user@other.us')
# Assert
assert result_specific is True
assert result_other_us is True
def test_is_domain_blocked_complex_multi_pattern_scenario(domain_blocker):
"""Test complex scenario with multiple TLD and domain patterns."""
# Arrange
domain_blocker.blocked_domains = [
'.us',
'.vn',
'test.com',
'openhands.dev',
]
# Act & Assert
# TLD patterns
assert domain_blocker.is_domain_blocked('user@anything.us') is True
assert domain_blocker.is_domain_blocked('user@company.vn') is True
# Domain patterns (exact)
assert domain_blocker.is_domain_blocked('user@test.com') is True
assert domain_blocker.is_domain_blocked('user@openhands.dev') is True
# Domain patterns (subdomains)
assert domain_blocker.is_domain_blocked('user@api.test.com') is True
assert domain_blocker.is_domain_blocked('user@staging.openhands.dev') is True
# Not blocked
assert domain_blocker.is_domain_blocked('user@allowed.com') is False
assert domain_blocker.is_domain_blocked('user@example.org') is False
# ============================================================================
# Edge Case Tests
# ============================================================================
def test_is_domain_blocked_domain_with_hyphens(domain_blocker):
"""Test that domain patterns work with hyphenated domains."""
# Arrange
domain_blocker.blocked_domains = ['my-company.com']
# Act
result_exact = domain_blocker.is_domain_blocked('user@my-company.com')
result_subdomain = domain_blocker.is_domain_blocked('user@api.my-company.com')
# Assert
assert result_exact is True
assert result_subdomain is True
def test_is_domain_blocked_domain_with_numbers(domain_blocker):
"""Test that domain patterns work with numeric domains."""
# Arrange
domain_blocker.blocked_domains = ['test123.com']
# Act
result_exact = domain_blocker.is_domain_blocked('user@test123.com')
result_subdomain = domain_blocker.is_domain_blocked('user@api.test123.com')
# Assert
assert result_exact is True
assert result_subdomain is True
def test_is_domain_blocked_short_tld(domain_blocker):
"""Test that short TLD patterns work correctly."""
# Arrange
domain_blocker.blocked_domains = ['.io']
# Act
result = domain_blocker.is_domain_blocked('user@company.io')
# Assert
assert result is True
def test_is_domain_blocked_very_long_subdomain_chain(domain_blocker):
"""Test that blocking works with very long subdomain chains."""
# Arrange
domain_blocker.blocked_domains = ['example.com']
# Act
result = domain_blocker.is_domain_blocked(
'user@level4.level3.level2.level1.example.com'
)
# Assert
assert result is True
@@ -0,0 +1,294 @@
"""Tests for email validation utilities."""
import re
from server.auth.email_validation import (
extract_base_email,
get_base_email_regex_pattern,
has_plus_modifier,
matches_base_email,
)
class TestExtractBaseEmail:
"""Test cases for extract_base_email function."""
def test_extract_base_email_with_plus_modifier(self):
"""Test extracting base email from email with + modifier."""
# Arrange
email = 'joe+test@example.com'
# Act
result = extract_base_email(email)
# Assert
assert result == 'joe@example.com'
def test_extract_base_email_without_plus_modifier(self):
"""Test that email without + modifier is returned as-is."""
# Arrange
email = 'joe@example.com'
# Act
result = extract_base_email(email)
# Assert
assert result == 'joe@example.com'
def test_extract_base_email_multiple_plus_signs(self):
"""Test extracting base email when multiple + signs exist."""
# Arrange
email = 'joe+openhands+test@example.com'
# Act
result = extract_base_email(email)
# Assert
assert result == 'joe@example.com'
def test_extract_base_email_invalid_no_at_symbol(self):
"""Test that invalid email without @ returns None."""
# Arrange
email = 'invalid-email'
# Act
result = extract_base_email(email)
# Assert
assert result is None
def test_extract_base_email_empty_string(self):
"""Test that empty string returns None."""
# Arrange
email = ''
# Act
result = extract_base_email(email)
# Assert
assert result is None
def test_extract_base_email_none(self):
"""Test that None input returns None."""
# Arrange
email = None
# Act
result = extract_base_email(email)
# Assert
assert result is None
class TestHasPlusModifier:
"""Test cases for has_plus_modifier function."""
def test_has_plus_modifier_true(self):
"""Test detecting + modifier in email."""
# Arrange
email = 'joe+test@example.com'
# Act
result = has_plus_modifier(email)
# Assert
assert result is True
def test_has_plus_modifier_false(self):
"""Test that email without + modifier returns False."""
# Arrange
email = 'joe@example.com'
# Act
result = has_plus_modifier(email)
# Assert
assert result is False
def test_has_plus_modifier_invalid_no_at_symbol(self):
"""Test that invalid email without @ returns False."""
# Arrange
email = 'invalid-email'
# Act
result = has_plus_modifier(email)
# Assert
assert result is False
def test_has_plus_modifier_empty_string(self):
"""Test that empty string returns False."""
# Arrange
email = ''
# Act
result = has_plus_modifier(email)
# Assert
assert result is False
class TestMatchesBaseEmail:
"""Test cases for matches_base_email function."""
def test_matches_base_email_exact_match(self):
"""Test that exact base email matches."""
# Arrange
email = 'joe@example.com'
base_email = 'joe@example.com'
# Act
result = matches_base_email(email, base_email)
# Assert
assert result is True
def test_matches_base_email_with_plus_variant(self):
"""Test that email with + variant matches base email."""
# Arrange
email = 'joe+test@example.com'
base_email = 'joe@example.com'
# Act
result = matches_base_email(email, base_email)
# Assert
assert result is True
def test_matches_base_email_different_base(self):
"""Test that different base emails do not match."""
# Arrange
email = 'jane@example.com'
base_email = 'joe@example.com'
# Act
result = matches_base_email(email, base_email)
# Assert
assert result is False
def test_matches_base_email_different_domain(self):
"""Test that same local part but different domain does not match."""
# Arrange
email = 'joe@other.com'
base_email = 'joe@example.com'
# Act
result = matches_base_email(email, base_email)
# Assert
assert result is False
def test_matches_base_email_case_insensitive(self):
"""Test that matching is case-insensitive."""
# Arrange
email = 'JOE+TEST@EXAMPLE.COM'
base_email = 'joe@example.com'
# Act
result = matches_base_email(email, base_email)
# Assert
assert result is True
def test_matches_base_email_empty_strings(self):
"""Test that empty strings return False."""
# Arrange
email = ''
base_email = 'joe@example.com'
# Act
result = matches_base_email(email, base_email)
# Assert
assert result is False
class TestGetBaseEmailRegexPattern:
"""Test cases for get_base_email_regex_pattern function."""
def test_get_base_email_regex_pattern_valid(self):
"""Test generating valid regex pattern for base email."""
# Arrange
base_email = 'joe@example.com'
# Act
pattern = get_base_email_regex_pattern(base_email)
# Assert
assert pattern is not None
assert isinstance(pattern, re.Pattern)
assert pattern.match('joe@example.com') is not None
assert pattern.match('joe+test@example.com') is not None
assert pattern.match('joe+openhands@example.com') is not None
def test_get_base_email_regex_pattern_matches_plus_variant(self):
"""Test that regex pattern matches + variant."""
# Arrange
base_email = 'joe@example.com'
pattern = get_base_email_regex_pattern(base_email)
# Act
match = pattern.match('joe+test@example.com')
# Assert
assert match is not None
def test_get_base_email_regex_pattern_rejects_different_base(self):
"""Test that regex pattern rejects different base email."""
# Arrange
base_email = 'joe@example.com'
pattern = get_base_email_regex_pattern(base_email)
# Act
match = pattern.match('jane@example.com')
# Assert
assert match is None
def test_get_base_email_regex_pattern_rejects_different_domain(self):
"""Test that regex pattern rejects different domain."""
# Arrange
base_email = 'joe@example.com'
pattern = get_base_email_regex_pattern(base_email)
# Act
match = pattern.match('joe@other.com')
# Assert
assert match is None
def test_get_base_email_regex_pattern_case_insensitive(self):
"""Test that regex pattern is case-insensitive."""
# Arrange
base_email = 'joe@example.com'
pattern = get_base_email_regex_pattern(base_email)
# Act
match = pattern.match('JOE+TEST@EXAMPLE.COM')
# Assert
assert match is not None
def test_get_base_email_regex_pattern_special_characters(self):
"""Test that regex pattern handles special characters in email."""
# Arrange
base_email = 'user.name+tag@example-site.com'
pattern = get_base_email_regex_pattern(base_email)
# Act
match = pattern.match('user.name+test@example-site.com')
# Assert
assert match is not None
def test_get_base_email_regex_pattern_invalid_base_email(self):
"""Test that invalid base email returns None."""
# Arrange
base_email = 'invalid-email'
# Act
pattern = get_base_email_regex_pattern(base_email)
# Assert
assert pattern is None
@@ -0,0 +1,437 @@
"""
TDD Tests for SaasNestedConversationManager token refresh functionality.
This module tests the token refresh logic that prevents stale tokens from being
sent to nested runtimes after Runtime.__init__() refreshes them.
Test Coverage:
- Token refresh with IDP user ID (GitLab webhook flow)
- Token refresh with Keycloak user ID (Web UI flow)
- Error handling and fallback behavior
- Settings immutability handling
"""
from types import MappingProxyType
from unittest.mock import AsyncMock, Mock, patch
import pytest
from pydantic import SecretStr
from enterprise.server.saas_nested_conversation_manager import (
SaasNestedConversationManager,
)
from openhands.integrations.provider import ProviderToken, ProviderType
from openhands.server.session.conversation_init_data import ConversationInitData
from openhands.storage.data_models.settings import Settings
class TestRefreshProviderTokensAfterRuntimeInit:
"""Test suite for _refresh_provider_tokens_after_runtime_init method."""
@pytest.fixture
def conversation_manager(self):
"""Create a minimal SaasNestedConversationManager instance for testing."""
# Arrange: Create mock dependencies
mock_sio = Mock()
mock_config = Mock()
mock_config.max_concurrent_conversations = 5
mock_server_config = Mock()
mock_file_store = Mock()
# Create manager instance
manager = SaasNestedConversationManager(
sio=mock_sio,
config=mock_config,
server_config=mock_server_config,
file_store=mock_file_store,
event_retrieval=Mock(),
)
return manager
@pytest.fixture
def gitlab_provider_token_with_user_id(self):
"""Create a GitLab ProviderToken with IDP user ID (webhook flow)."""
return ProviderToken(
token=SecretStr('old_token_abc123'),
user_id='32546706', # GitLab user ID
host=None,
)
@pytest.fixture
def gitlab_provider_token_without_user_id(self):
"""Create a GitLab ProviderToken without IDP user ID (web UI flow)."""
return ProviderToken(
token=SecretStr('old_token_xyz789'),
user_id=None,
host=None,
)
@pytest.fixture
def conversation_init_data_with_user_id(self, gitlab_provider_token_with_user_id):
"""Create ConversationInitData with provider token containing user_id."""
return ConversationInitData(
git_provider_tokens=MappingProxyType(
{ProviderType.GITLAB: gitlab_provider_token_with_user_id}
)
)
@pytest.fixture
def conversation_init_data_without_user_id(
self, gitlab_provider_token_without_user_id
):
"""Create ConversationInitData with provider token without user_id."""
return ConversationInitData(
git_provider_tokens=MappingProxyType(
{ProviderType.GITLAB: gitlab_provider_token_without_user_id}
)
)
@pytest.mark.asyncio
async def test_returns_original_settings_when_not_conversation_init_data(
self, conversation_manager
):
"""
Test: Returns original settings when not ConversationInitData.
Arrange: Create a Settings object (not ConversationInitData)
Act: Call _refresh_provider_tokens_after_runtime_init
Assert: Returns the same settings object unchanged
"""
# Arrange
settings = Settings()
sid = 'test_session_123'
# Act
result = await conversation_manager._refresh_provider_tokens_after_runtime_init(
settings, sid
)
# Assert
assert result is settings
@pytest.mark.asyncio
async def test_returns_original_settings_when_no_provider_tokens(
self, conversation_manager
):
"""
Test: Returns original settings when no provider tokens present.
Arrange: Create ConversationInitData without git_provider_tokens
Act: Call _refresh_provider_tokens_after_runtime_init
Assert: Returns the same settings object unchanged
"""
# Arrange
settings = ConversationInitData(git_provider_tokens=None)
sid = 'test_session_456'
# Act
result = await conversation_manager._refresh_provider_tokens_after_runtime_init(
settings, sid
)
# Assert
assert result is settings
@pytest.mark.asyncio
async def test_refreshes_token_with_idp_user_id(
self, conversation_manager, conversation_init_data_with_user_id
):
"""
Test: Refreshes token using IDP user ID (GitLab webhook flow).
Arrange: ConversationInitData with GitLab token containing user_id
Act: Call _refresh_provider_tokens_after_runtime_init with mocked TokenManager
Assert: Token is refreshed using get_idp_token_from_idp_user_id
"""
# Arrange
sid = 'test_session_789'
fresh_token = 'fresh_token_def456'
with patch(
'enterprise.server.saas_nested_conversation_manager.TokenManager'
) as mock_token_manager_class:
mock_token_manager = AsyncMock()
mock_token_manager.get_idp_token_from_idp_user_id = AsyncMock(
return_value=fresh_token
)
mock_token_manager_class.return_value = mock_token_manager
# Act
result = (
await conversation_manager._refresh_provider_tokens_after_runtime_init(
conversation_init_data_with_user_id, sid
)
)
# Assert
mock_token_manager.get_idp_token_from_idp_user_id.assert_called_once_with(
'32546706', ProviderType.GITLAB
)
assert (
result.git_provider_tokens[ProviderType.GITLAB].token.get_secret_value()
== fresh_token
)
assert result.git_provider_tokens[ProviderType.GITLAB].user_id == '32546706'
@pytest.mark.asyncio
async def test_refreshes_token_with_keycloak_user_id(
self, conversation_manager, conversation_init_data_without_user_id
):
"""
Test: Refreshes token using Keycloak user ID (Web UI flow).
Arrange: ConversationInitData without IDP user_id, but with Keycloak user_id
Act: Call _refresh_provider_tokens_after_runtime_init with mocked TokenManager
Assert: Token is refreshed using load_offline_token + get_idp_token_from_offline_token
"""
# Arrange
sid = 'test_session_101'
keycloak_user_id = 'keycloak_user_abc'
offline_token = 'offline_token_xyz'
fresh_token = 'fresh_token_ghi789'
with patch(
'enterprise.server.saas_nested_conversation_manager.TokenManager'
) as mock_token_manager_class:
mock_token_manager = AsyncMock()
mock_token_manager.load_offline_token = AsyncMock(
return_value=offline_token
)
mock_token_manager.get_idp_token_from_offline_token = AsyncMock(
return_value=fresh_token
)
mock_token_manager_class.return_value = mock_token_manager
# Act
result = (
await conversation_manager._refresh_provider_tokens_after_runtime_init(
conversation_init_data_without_user_id, sid, keycloak_user_id
)
)
# Assert
mock_token_manager.load_offline_token.assert_called_once_with(
keycloak_user_id
)
mock_token_manager.get_idp_token_from_offline_token.assert_called_once_with(
offline_token, ProviderType.GITLAB
)
assert (
result.git_provider_tokens[ProviderType.GITLAB].token.get_secret_value()
== fresh_token
)
assert result.git_provider_tokens[ProviderType.GITLAB].user_id is None
@pytest.mark.asyncio
async def test_keeps_original_token_when_refresh_fails(
self, conversation_manager, conversation_init_data_with_user_id
):
"""
Test: Keeps original token when refresh fails (error handling).
Arrange: ConversationInitData with token, TokenManager raises exception
Act: Call _refresh_provider_tokens_after_runtime_init
Assert: Original token is preserved, no exception raised
"""
# Arrange
sid = 'test_session_error'
original_token = conversation_init_data_with_user_id.git_provider_tokens[
ProviderType.GITLAB
].token.get_secret_value()
with patch(
'enterprise.server.saas_nested_conversation_manager.TokenManager'
) as mock_token_manager_class:
mock_token_manager = AsyncMock()
mock_token_manager.get_idp_token_from_idp_user_id = AsyncMock(
side_effect=Exception('Token refresh failed')
)
mock_token_manager_class.return_value = mock_token_manager
# Act
result = (
await conversation_manager._refresh_provider_tokens_after_runtime_init(
conversation_init_data_with_user_id, sid
)
)
# Assert
assert (
result.git_provider_tokens[ProviderType.GITLAB].token.get_secret_value()
== original_token
)
@pytest.mark.asyncio
async def test_keeps_original_token_when_no_fresh_token_available(
self, conversation_manager, conversation_init_data_with_user_id
):
"""
Test: Keeps original token when no fresh token is available.
Arrange: ConversationInitData with token, TokenManager returns None
Act: Call _refresh_provider_tokens_after_runtime_init
Assert: Original token is preserved
"""
# Arrange
sid = 'test_session_no_fresh'
original_token = conversation_init_data_with_user_id.git_provider_tokens[
ProviderType.GITLAB
].token.get_secret_value()
with patch(
'enterprise.server.saas_nested_conversation_manager.TokenManager'
) as mock_token_manager_class:
mock_token_manager = AsyncMock()
mock_token_manager.get_idp_token_from_idp_user_id = AsyncMock(
return_value=None
)
mock_token_manager_class.return_value = mock_token_manager
# Act
result = (
await conversation_manager._refresh_provider_tokens_after_runtime_init(
conversation_init_data_with_user_id, sid
)
)
# Assert
assert (
result.git_provider_tokens[ProviderType.GITLAB].token.get_secret_value()
== original_token
)
@pytest.mark.asyncio
async def test_creates_new_settings_object_preserving_immutability(
self, conversation_manager, conversation_init_data_with_user_id
):
"""
Test: Creates new settings object (respects Pydantic frozen fields).
Arrange: ConversationInitData with frozen git_provider_tokens field
Act: Call _refresh_provider_tokens_after_runtime_init
Assert: Returns a new ConversationInitData object, not the same instance
"""
# Arrange
sid = 'test_session_immutable'
fresh_token = 'fresh_token_new'
with patch(
'enterprise.server.saas_nested_conversation_manager.TokenManager'
) as mock_token_manager_class:
mock_token_manager = AsyncMock()
mock_token_manager.get_idp_token_from_idp_user_id = AsyncMock(
return_value=fresh_token
)
mock_token_manager_class.return_value = mock_token_manager
# Act
result = (
await conversation_manager._refresh_provider_tokens_after_runtime_init(
conversation_init_data_with_user_id, sid
)
)
# Assert
assert result is not conversation_init_data_with_user_id
assert isinstance(result, ConversationInitData)
@pytest.mark.asyncio
async def test_handles_multiple_providers(self, conversation_manager):
"""
Test: Handles multiple provider tokens correctly.
Arrange: ConversationInitData with both GitLab and GitHub tokens
Act: Call _refresh_provider_tokens_after_runtime_init
Assert: Both tokens are refreshed independently
"""
# Arrange
sid = 'test_session_multi'
gitlab_token = ProviderToken(
token=SecretStr('old_gitlab_token'), user_id='gitlab_user_123', host=None
)
github_token = ProviderToken(
token=SecretStr('old_github_token'), user_id='github_user_456', host=None
)
settings = ConversationInitData(
git_provider_tokens=MappingProxyType(
{ProviderType.GITLAB: gitlab_token, ProviderType.GITHUB: github_token}
)
)
fresh_gitlab_token = 'fresh_gitlab_token'
fresh_github_token = 'fresh_github_token'
with patch(
'enterprise.server.saas_nested_conversation_manager.TokenManager'
) as mock_token_manager_class:
mock_token_manager = AsyncMock()
async def mock_get_token(user_id, provider_type):
if provider_type == ProviderType.GITLAB:
return fresh_gitlab_token
elif provider_type == ProviderType.GITHUB:
return fresh_github_token
return None
mock_token_manager.get_idp_token_from_idp_user_id = AsyncMock(
side_effect=mock_get_token
)
mock_token_manager_class.return_value = mock_token_manager
# Act
result = (
await conversation_manager._refresh_provider_tokens_after_runtime_init(
settings, sid
)
)
# Assert
assert (
result.git_provider_tokens[ProviderType.GITLAB].token.get_secret_value()
== fresh_gitlab_token
)
assert (
result.git_provider_tokens[ProviderType.GITHUB].token.get_secret_value()
== fresh_github_token
)
assert mock_token_manager.get_idp_token_from_idp_user_id.call_count == 2
@pytest.mark.asyncio
async def test_preserves_token_host_field(self, conversation_manager):
"""
Test: Preserves the host field from original token.
Arrange: ProviderToken with custom host value
Act: Call _refresh_provider_tokens_after_runtime_init
Assert: Host field is preserved in the refreshed token
"""
# Arrange
sid = 'test_session_host'
custom_host = 'gitlab.example.com'
token_with_host = ProviderToken(
token=SecretStr('old_token'), user_id='user_789', host=custom_host
)
settings = ConversationInitData(
git_provider_tokens=MappingProxyType({ProviderType.GITLAB: token_with_host})
)
fresh_token = 'fresh_token_with_host'
with patch(
'enterprise.server.saas_nested_conversation_manager.TokenManager'
) as mock_token_manager_class:
mock_token_manager = AsyncMock()
mock_token_manager.get_idp_token_from_idp_user_id = AsyncMock(
return_value=fresh_token
)
mock_token_manager_class.return_value = mock_token_manager
# Act
result = (
await conversation_manager._refresh_provider_tokens_after_runtime_init(
settings, sid
)
)
# Assert
assert result.git_provider_tokens[ProviderType.GITLAB].host == custom_host
@@ -1,11 +1,13 @@
from unittest.mock import AsyncMock, MagicMock, patch
import httpx
import pytest
from pydantic import SecretStr
from server.constants import (
CURRENT_USER_SETTINGS_VERSION,
LITE_LLM_API_URL,
LITE_LLM_TEAM_ID,
get_default_litellm_model,
)
from storage.saas_settings_store import SaasSettingsStore
from storage.user_settings import UserSettings
@@ -334,6 +336,80 @@ async def test_update_settings_with_litellm_default_error(settings_store):
assert settings is None
@pytest.mark.asyncio
@pytest.mark.parametrize(
'status_code,user_info_response,should_succeed',
[
# 200 OK with user info - existing user (v1.79.x and v1.80+ behavior)
(200, {'user_info': {'max_budget': 10, 'spend': 5}}, True),
# 200 OK with empty user info - new user (v1.79.x behavior)
(200, {'user_info': None}, True),
# 404 Not Found - new user (v1.80+ behavior)
(404, None, True),
# 500 Internal Server Error - should fail
(500, None, False),
],
)
async def test_update_settings_with_litellm_default_handles_user_info_responses(
settings_store, session_maker, status_code, user_info_response, should_succeed
):
"""Test that various LiteLLM user/info responses are handled correctly.
LiteLLM API behavior changed between versions:
- v1.79.x and earlier: GET /user/info always succeeds with empty user_info
- v1.80.x and later: GET /user/info returns 404 for non-existent users
"""
mock_get_response = MagicMock()
mock_get_response.status_code = status_code
if user_info_response is not None:
mock_get_response.json = MagicMock(return_value=user_info_response)
mock_get_response.raise_for_status = MagicMock()
else:
mock_get_response.raise_for_status = MagicMock(
side_effect=httpx.HTTPStatusError(
'Error', request=MagicMock(), response=mock_get_response
)
if status_code >= 500
else None
)
# Mock successful responses for POST operations (delete and create)
mock_post_response = MagicMock()
mock_post_response.is_success = True
mock_post_response.json = MagicMock(return_value={'key': 'new_user_api_key'})
with (
patch('storage.saas_settings_store.LITE_LLM_API_KEY', 'test_key'),
patch('storage.saas_settings_store.LITE_LLM_API_URL', 'http://test.url'),
patch('storage.saas_settings_store.LITE_LLM_TEAM_ID', 'test_team'),
patch(
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
AsyncMock(return_value={'email': 'testuser@example.com'}),
),
patch('httpx.AsyncClient') as mock_client,
patch('storage.saas_settings_store.session_maker', session_maker),
):
# Set up the mock client
mock_client.return_value.__aenter__.return_value.get.return_value = (
mock_get_response
)
mock_client.return_value.__aenter__.return_value.post.return_value = (
mock_post_response
)
settings = Settings()
if should_succeed:
settings = await settings_store.update_settings_with_litellm_default(
settings
)
assert settings is not None
assert settings.llm_api_key is not None
assert settings.llm_api_key.get_secret_value() == 'new_user_api_key'
else:
with pytest.raises(httpx.HTTPStatusError):
await settings_store.update_settings_with_litellm_default(settings)
@pytest.mark.asyncio
async def test_update_settings_with_litellm_retry_on_duplicate_email(
settings_store, mock_litellm_api, session_maker
@@ -393,10 +469,11 @@ async def test_create_user_in_lite_llm(settings_store):
mock_response = AsyncMock()
mock_response.is_success = True
mock_client.post.return_value = mock_response
test_model = 'custom-model/test-model'
# Test with email
await settings_store._create_user_in_lite_llm(
mock_client, 'test@example.com', 50, 10
mock_client, 'test@example.com', 50, 10, test_model
)
# Get the actual call arguments
@@ -412,11 +489,11 @@ async def test_create_user_in_lite_llm(settings_store):
assert call_args['json']['auto_create_key'] is True
assert call_args['json']['send_invite_email'] is False
assert call_args['json']['metadata']['version'] == CURRENT_USER_SETTINGS_VERSION
assert 'model' in call_args['json']['metadata']
assert call_args['json']['metadata']['model'] == test_model
# Test with None email
mock_client.post.reset_mock()
await settings_store._create_user_in_lite_llm(mock_client, None, 25, 15)
await settings_store._create_user_in_lite_llm(mock_client, None, 25, 15, test_model)
# Get the actual call arguments
call_args = mock_client.post.call_args[1]
@@ -431,12 +508,12 @@ async def test_create_user_in_lite_llm(settings_store):
assert call_args['json']['auto_create_key'] is True
assert call_args['json']['send_invite_email'] is False
assert call_args['json']['metadata']['version'] == CURRENT_USER_SETTINGS_VERSION
assert 'model' in call_args['json']['metadata']
assert call_args['json']['metadata']['model'] == test_model
# Verify response is returned correctly
assert (
await settings_store._create_user_in_lite_llm(
mock_client, 'email@test.com', 30, 7
mock_client, 'email@test.com', 30, 7, test_model
)
== mock_response
)
@@ -464,3 +541,808 @@ async def test_encryption(settings_store):
# But we should be able to decrypt it when loading
loaded_settings = await settings_store.load()
assert loaded_settings.llm_api_key.get_secret_value() == 'secret_key'
@pytest.mark.asyncio
async def test_update_settings_with_litellm_default_preserves_custom_model(
settings_store, mock_litellm_api, session_maker
):
# Arrange: User has a custom LLM model set
custom_model = 'anthropic/claude-3-5-sonnet-20241022'
settings = Settings(llm_model=custom_model)
with (
patch(
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
AsyncMock(return_value={'email': 'user@example.com'}),
),
patch('storage.saas_settings_store.session_maker', session_maker),
):
# Act: Update settings with LiteLLM defaults
updated_settings = await settings_store.update_settings_with_litellm_default(
settings
)
# Assert: Custom model is preserved
assert updated_settings is not None
assert updated_settings.llm_model == custom_model
assert updated_settings.agent == 'CodeActAgent'
assert updated_settings.llm_api_key is not None
# Assert: LiteLLM metadata contains user's custom model
call_args = mock_litellm_api.return_value.__aenter__.return_value.post.call_args[1]
assert call_args['json']['metadata']['model'] == custom_model
@pytest.mark.asyncio
async def test_update_settings_with_litellm_default_uses_default_when_no_model(
settings_store, mock_litellm_api, session_maker
):
# Arrange: User has no model set (new user scenario)
settings = Settings()
with (
patch(
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
AsyncMock(return_value={'email': 'newuser@example.com'}),
),
patch('storage.saas_settings_store.session_maker', session_maker),
):
# Act: Update settings with LiteLLM defaults
updated_settings = await settings_store.update_settings_with_litellm_default(
settings
)
# Assert: Default model is assigned
assert updated_settings is not None
expected_default = get_default_litellm_model()
assert updated_settings.llm_model == expected_default
assert updated_settings.agent == 'CodeActAgent'
# Assert: LiteLLM metadata contains default model
call_args = mock_litellm_api.return_value.__aenter__.return_value.post.call_args[1]
assert call_args['json']['metadata']['model'] == expected_default
@pytest.mark.asyncio
async def test_update_settings_with_litellm_default_handles_empty_string_model(
settings_store, mock_litellm_api, session_maker
):
# Arrange: User has empty string as model (edge case)
settings = Settings(llm_model='')
with (
patch(
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
AsyncMock(return_value={'email': 'user@example.com'}),
),
patch('storage.saas_settings_store.session_maker', session_maker),
):
# Act: Update settings with LiteLLM defaults
updated_settings = await settings_store.update_settings_with_litellm_default(
settings
)
# Assert: Default model is used (empty string treated as no model)
assert updated_settings is not None
expected_default = get_default_litellm_model()
assert updated_settings.llm_model == expected_default
@pytest.mark.asyncio
async def test_update_settings_with_litellm_default_handles_whitespace_model(
settings_store, mock_litellm_api, session_maker
):
# Arrange: User has whitespace-only model (edge case)
settings = Settings(llm_model=' ')
with (
patch(
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
AsyncMock(return_value={'email': 'user@example.com'}),
),
patch('storage.saas_settings_store.session_maker', session_maker),
):
# Act: Update settings with LiteLLM defaults
updated_settings = await settings_store.update_settings_with_litellm_default(
settings
)
# Assert: Default model is used (whitespace treated as no model)
assert updated_settings is not None
expected_default = get_default_litellm_model()
assert updated_settings.llm_model == expected_default
@pytest.mark.asyncio
async def test_update_settings_with_litellm_default_preserves_custom_api_key(
settings_store, mock_litellm_api, session_maker
):
# Arrange: User has a custom API key and custom model (so has_custom=True)
custom_api_key = 'sk-custom-user-api-key-12345'
custom_model = 'gpt-4'
settings = Settings(llm_model=custom_model, llm_api_key=SecretStr(custom_api_key))
with (
patch(
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
AsyncMock(return_value={'email': 'user@example.com'}),
),
patch('storage.saas_settings_store.session_maker', session_maker),
):
# Act: Update settings with LiteLLM defaults
updated_settings = await settings_store.update_settings_with_litellm_default(
settings
)
# Assert: Custom API key is preserved when user has custom settings
assert updated_settings is not None
assert updated_settings.llm_api_key.get_secret_value() == custom_api_key
assert updated_settings.llm_api_key.get_secret_value() != 'test_api_key'
@pytest.mark.asyncio
async def test_update_settings_with_litellm_default_preserves_custom_base_url(
settings_store, mock_litellm_api, session_maker
):
# Arrange: User has a custom base URL
custom_base_url = 'https://api.custom-llm-provider.com/v1'
settings = Settings(llm_base_url=custom_base_url)
with (
patch(
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
AsyncMock(return_value={'email': 'user@example.com'}),
),
patch('storage.saas_settings_store.session_maker', session_maker),
):
# Act: Update settings with LiteLLM defaults
updated_settings = await settings_store.update_settings_with_litellm_default(
settings
)
# Assert: Custom base URL is preserved
assert updated_settings is not None
assert updated_settings.llm_base_url == custom_base_url
assert updated_settings.llm_base_url != LITE_LLM_API_URL
@pytest.mark.asyncio
async def test_update_settings_with_litellm_default_preserves_custom_api_key_and_base_url(
settings_store, mock_litellm_api, session_maker
):
# Arrange: User has both custom API key and base URL
custom_api_key = 'sk-custom-user-api-key-67890'
custom_base_url = 'https://api.another-llm-provider.com/v1'
custom_model = 'openai/gpt-4'
settings = Settings(
llm_model=custom_model,
llm_api_key=SecretStr(custom_api_key),
llm_base_url=custom_base_url,
)
with (
patch(
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
AsyncMock(return_value={'email': 'user@example.com'}),
),
patch('storage.saas_settings_store.session_maker', session_maker),
):
# Act: Update settings with LiteLLM defaults
updated_settings = await settings_store.update_settings_with_litellm_default(
settings
)
# Assert: All custom settings are preserved
assert updated_settings is not None
assert updated_settings.llm_model == custom_model
assert updated_settings.llm_api_key.get_secret_value() == custom_api_key
assert updated_settings.llm_base_url == custom_base_url
@pytest.mark.asyncio
async def test_update_settings_with_litellm_default_uses_default_api_key_when_none(
settings_store, mock_litellm_api, session_maker
):
# Arrange: User has no API key set
settings = Settings(llm_api_key=None)
with (
patch(
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
AsyncMock(return_value={'email': 'user@example.com'}),
),
patch('storage.saas_settings_store.session_maker', session_maker),
):
# Act: Update settings with LiteLLM defaults
updated_settings = await settings_store.update_settings_with_litellm_default(
settings
)
# Assert: Default LiteLLM API key is assigned
assert updated_settings is not None
assert updated_settings.llm_api_key is not None
assert updated_settings.llm_api_key.get_secret_value() == 'test_api_key'
@pytest.mark.asyncio
async def test_update_settings_with_litellm_default_uses_default_base_url_when_none(
settings_store, mock_litellm_api, session_maker
):
# Arrange: User has no base URL set
settings = Settings(llm_base_url=None)
with (
patch(
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
AsyncMock(return_value={'email': 'user@example.com'}),
),
patch('storage.saas_settings_store.session_maker', session_maker),
patch('storage.saas_settings_store.LITE_LLM_API_URL', 'http://test.url'),
):
# Act: Update settings with LiteLLM defaults
updated_settings = await settings_store.update_settings_with_litellm_default(
settings
)
# Assert: Default LiteLLM base URL is assigned (using mocked value)
assert updated_settings is not None
assert updated_settings.llm_base_url == 'http://test.url'
@pytest.mark.asyncio
async def test_update_settings_with_litellm_default_handles_empty_api_key(
settings_store, mock_litellm_api, session_maker
):
# Arrange: User has empty string as API key (edge case)
settings = Settings(llm_api_key=SecretStr(''))
with (
patch(
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
AsyncMock(return_value={'email': 'user@example.com'}),
),
patch('storage.saas_settings_store.session_maker', session_maker),
):
# Act: Update settings with LiteLLM defaults
updated_settings = await settings_store.update_settings_with_litellm_default(
settings
)
# Assert: Default API key is used (empty string treated as no key)
assert updated_settings is not None
assert updated_settings.llm_api_key.get_secret_value() == 'test_api_key'
@pytest.mark.asyncio
async def test_update_settings_with_litellm_default_handles_empty_base_url(
settings_store, mock_litellm_api, session_maker
):
# Arrange: User has empty string as base URL (edge case)
settings = Settings(llm_base_url='')
with (
patch(
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
AsyncMock(return_value={'email': 'user@example.com'}),
),
patch('storage.saas_settings_store.session_maker', session_maker),
patch('storage.saas_settings_store.LITE_LLM_API_URL', 'http://test.url'),
):
# Act: Update settings with LiteLLM defaults
updated_settings = await settings_store.update_settings_with_litellm_default(
settings
)
# Assert: Default base URL is used (empty string treated as no URL)
assert updated_settings is not None
assert updated_settings.llm_base_url == 'http://test.url'
@pytest.mark.asyncio
async def test_update_settings_with_litellm_default_handles_whitespace_api_key(
settings_store, mock_litellm_api, session_maker
):
# Arrange: User has whitespace-only API key (edge case)
settings = Settings(llm_api_key=SecretStr(' '))
with (
patch(
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
AsyncMock(return_value={'email': 'user@example.com'}),
),
patch('storage.saas_settings_store.session_maker', session_maker),
):
# Act: Update settings with LiteLLM defaults
updated_settings = await settings_store.update_settings_with_litellm_default(
settings
)
# Assert: Default API key is used (whitespace treated as no key)
assert updated_settings is not None
assert updated_settings.llm_api_key.get_secret_value() == 'test_api_key'
@pytest.mark.asyncio
async def test_update_settings_with_litellm_default_handles_whitespace_base_url(
settings_store, mock_litellm_api, session_maker
):
# Arrange: User has whitespace-only base URL (edge case)
settings = Settings(llm_base_url=' ')
with (
patch(
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
AsyncMock(return_value={'email': 'user@example.com'}),
),
patch('storage.saas_settings_store.session_maker', session_maker),
patch('storage.saas_settings_store.LITE_LLM_API_URL', 'http://test.url'),
):
# Act: Update settings with LiteLLM defaults
updated_settings = await settings_store.update_settings_with_litellm_default(
settings
)
# Assert: Default base URL is used (whitespace treated as no URL)
assert updated_settings is not None
assert updated_settings.llm_base_url == 'http://test.url'
# Tests for version migration and helper methods
@pytest.mark.asyncio
async def test_has_custom_settings_with_custom_base_url(settings_store):
# Arrange: User with custom base URL (BYOR)
with patch('storage.saas_settings_store.LITE_LLM_API_URL', 'http://default.url'):
settings = Settings(llm_base_url='http://custom.url')
# Act: Check if has custom settings
has_custom = settings_store._has_custom_settings(settings, None)
# Assert: Custom base URL detected
assert has_custom is True
@pytest.mark.asyncio
async def test_has_custom_settings_with_default_base_url(settings_store):
# Arrange: User with default base URL
with patch('storage.saas_settings_store.LITE_LLM_API_URL', 'http://default.url'):
settings = Settings(llm_base_url='http://default.url')
# Act: Check if has custom settings
has_custom = settings_store._has_custom_settings(settings, None)
# Assert: No custom settings (no model set)
assert has_custom is False
@pytest.mark.asyncio
async def test_has_custom_settings_with_no_model(settings_store):
# Arrange: User with no model set
with patch('storage.saas_settings_store.LITE_LLM_API_URL', 'http://default.url'):
settings = Settings(llm_model=None, llm_base_url='http://default.url')
# Act: Check if has custom settings
has_custom = settings_store._has_custom_settings(settings, None)
# Assert: No custom settings (using defaults)
assert has_custom is False
@pytest.mark.asyncio
async def test_has_custom_settings_with_empty_model(settings_store):
# Arrange: User with empty model
with patch('storage.saas_settings_store.LITE_LLM_API_URL', 'http://default.url'):
settings = Settings(llm_model='', llm_base_url='http://default.url')
# Act: Check if has custom settings
has_custom = settings_store._has_custom_settings(settings, None)
# Assert: No custom settings (empty treated as no model)
assert has_custom is False
@pytest.mark.asyncio
async def test_has_custom_settings_with_whitespace_model(settings_store):
# Arrange: User with whitespace-only model
with patch('storage.saas_settings_store.LITE_LLM_API_URL', 'http://default.url'):
settings = Settings(llm_model=' ', llm_base_url='http://default.url')
# Act: Check if has custom settings
has_custom = settings_store._has_custom_settings(settings, None)
# Assert: No custom settings (whitespace treated as no model)
assert has_custom is False
@pytest.mark.asyncio
async def test_has_custom_settings_with_custom_model(settings_store):
# Arrange: User with custom model
with patch('storage.saas_settings_store.LITE_LLM_API_URL', 'http://default.url'):
settings = Settings(llm_model='gpt-4', llm_base_url='http://default.url')
# Act: Check if has custom settings
has_custom = settings_store._has_custom_settings(settings, None)
# Assert: Custom model detected
assert has_custom is True
@pytest.mark.asyncio
async def test_has_custom_settings_matches_old_default_model(settings_store):
# Arrange: User with old version and model matching old default
with (
patch('storage.saas_settings_store.LITE_LLM_API_URL', 'http://default.url'),
patch('server.constants.CURRENT_USER_SETTINGS_VERSION', 5),
patch(
'server.constants.USER_SETTINGS_VERSION_TO_MODEL',
{1: 'claude-3-5-sonnet-20241022'},
),
):
settings = Settings(
llm_model='litellm_proxy/prod/claude-3-5-sonnet-20241022',
llm_base_url='http://default.url',
)
# Act: Check if has custom settings
has_custom = settings_store._has_custom_settings(settings, 1)
# Assert: Matches old default, so not custom
assert has_custom is False
@pytest.mark.asyncio
async def test_has_custom_settings_matches_old_default_by_base_name(settings_store):
# Arrange: User with old version and model matching old default by base name
with (
patch('storage.saas_settings_store.LITE_LLM_API_URL', 'http://default.url'),
patch('server.constants.CURRENT_USER_SETTINGS_VERSION', 5),
patch(
'server.constants.USER_SETTINGS_VERSION_TO_MODEL',
{1: 'claude-3-5-sonnet-20241022'},
),
):
settings = Settings(
llm_model='anthropic/claude-3-5-sonnet-20241022',
llm_base_url='http://default.url',
)
# Act: Check if has custom settings
has_custom = settings_store._has_custom_settings(settings, 1)
# Assert: Matches old default by base name, so not custom
assert has_custom is False
@pytest.mark.asyncio
async def test_has_custom_settings_with_old_version_but_custom_model(settings_store):
# Arrange: User with old version but custom model
with (
patch('storage.saas_settings_store.LITE_LLM_API_URL', 'http://default.url'),
patch('server.constants.CURRENT_USER_SETTINGS_VERSION', 5),
patch(
'server.constants.USER_SETTINGS_VERSION_TO_MODEL',
{1: 'claude-3-5-sonnet-20241022'},
),
):
settings = Settings(llm_model='gpt-4', llm_base_url='http://default.url')
# Act: Check if has custom settings
has_custom = settings_store._has_custom_settings(settings, 1)
# Assert: Custom model detected
assert has_custom is True
@pytest.mark.asyncio
async def test_has_custom_settings_with_current_version(settings_store):
# Arrange: User with current version
with (
patch('storage.saas_settings_store.LITE_LLM_API_URL', 'http://default.url'),
patch('server.constants.CURRENT_USER_SETTINGS_VERSION', 5),
patch(
'server.constants.USER_SETTINGS_VERSION_TO_MODEL',
{1: 'claude-3-5-sonnet-20241022', 5: 'claude-opus-4-5-20251101'},
),
):
settings = Settings(
llm_model='claude-3-5-sonnet-20241022', llm_base_url='http://default.url'
)
# Act: Check if has custom settings
has_custom = settings_store._has_custom_settings(settings, 5)
# Assert: Current version, so model is custom (not old default)
assert has_custom is True
@pytest.mark.asyncio
async def test_has_custom_settings_with_none_version(settings_store):
# Arrange: User with no version
with patch('storage.saas_settings_store.LITE_LLM_API_URL', 'http://default.url'):
settings = Settings(
llm_model='claude-3-5-sonnet-20241022', llm_base_url='http://default.url'
)
# Act: Check if has custom settings
has_custom = settings_store._has_custom_settings(settings, None)
# Assert: No version, so model is custom
assert has_custom is True
@pytest.mark.asyncio
async def test_has_custom_settings_with_invalid_version(settings_store):
# Arrange: User with invalid version
with (
patch('storage.saas_settings_store.LITE_LLM_API_URL', 'http://default.url'),
patch('server.constants.CURRENT_USER_SETTINGS_VERSION', 5),
patch(
'server.constants.USER_SETTINGS_VERSION_TO_MODEL',
{1: 'claude-3-5-sonnet-20241022'},
),
):
settings = Settings(
llm_model='claude-3-5-sonnet-20241022', llm_base_url='http://default.url'
)
# Act: Check if has custom settings
has_custom = settings_store._has_custom_settings(settings, 99)
# Assert: Invalid version, so model is custom
assert has_custom is True
@pytest.mark.asyncio
async def test_has_custom_settings_normalizes_whitespace(settings_store):
# Arrange: Settings with whitespace in values
with patch('storage.saas_settings_store.LITE_LLM_API_URL', 'http://default.url'):
settings = Settings(
llm_model=' claude-3-5-sonnet-20241022 ',
llm_base_url=' http://default.url ',
)
# Act: Check if has custom settings
has_custom = settings_store._has_custom_settings(settings, None)
# Assert: Whitespace is normalized, custom model detected
assert has_custom is True
@pytest.mark.asyncio
async def test_update_settings_upgrades_user_from_old_defaults(
settings_store, mock_litellm_api, session_maker
):
# Arrange: User with old version using old defaults
old_version = 1
old_model = 'litellm_proxy/prod/claude-3-5-sonnet-20241022'
settings = Settings(llm_model=old_model, llm_base_url=LITE_LLM_API_URL)
# Use a consistent test URL
test_base_url = 'http://test.url'
with (
patch('storage.saas_settings_store.session_maker', session_maker),
patch(
'server.constants.USER_SETTINGS_VERSION_TO_MODEL',
{1: 'claude-3-5-sonnet-20241022', 5: 'claude-opus-4-5-20251101'},
),
patch(
'storage.saas_settings_store.USER_SETTINGS_VERSION_TO_MODEL',
{1: 'claude-3-5-sonnet-20241022', 5: 'claude-opus-4-5-20251101'},
),
patch('server.constants.CURRENT_USER_SETTINGS_VERSION', 5),
patch('storage.saas_settings_store.CURRENT_USER_SETTINGS_VERSION', 5),
patch('storage.saas_settings_store.LITE_LLM_API_URL', test_base_url),
patch(
'storage.saas_settings_store.get_default_litellm_model',
return_value='litellm_proxy/prod/claude-opus-4-5-20251101',
),
patch(
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
AsyncMock(return_value={'email': 'user@example.com'}),
),
):
# Create existing user settings with old version
with session_maker() as session:
existing_settings = UserSettings(
keycloak_user_id=settings_store.user_id,
user_version=old_version,
llm_model=old_model,
llm_base_url=test_base_url,
)
session.add(existing_settings)
session.commit()
# Update settings to use test_base_url
# Set user_version to match the database so _has_custom_settings can detect old defaults
settings = Settings(
llm_model=old_model, llm_base_url=test_base_url, user_version=old_version
)
# Act: Update settings
updated_settings = await settings_store.update_settings_with_litellm_default(
settings
)
# Assert: Settings upgraded to new defaults
assert updated_settings is not None
assert (
updated_settings.llm_model == 'litellm_proxy/prod/claude-opus-4-5-20251101'
)
assert updated_settings.llm_base_url == test_base_url
@pytest.mark.asyncio
async def test_update_settings_preserves_custom_settings_during_upgrade(
settings_store, mock_litellm_api, session_maker
):
# Arrange: User with old version but custom settings
old_version = 1
custom_model = 'gpt-4'
custom_base_url = 'http://custom.url'
settings = Settings(llm_model=custom_model, llm_base_url=custom_base_url)
with (
patch('storage.saas_settings_store.session_maker', session_maker),
patch(
'server.constants.USER_SETTINGS_VERSION_TO_MODEL',
{1: 'claude-3-5-sonnet-20241022'},
),
patch(
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
AsyncMock(return_value={'email': 'user@example.com'}),
),
):
# Create existing user settings with old version
with session_maker() as session:
existing_settings = UserSettings(
keycloak_user_id=settings_store.user_id,
user_version=old_version,
llm_model=custom_model,
llm_base_url=custom_base_url,
)
session.add(existing_settings)
session.commit()
# Act: Update settings
updated_settings = await settings_store.update_settings_with_litellm_default(
settings
)
# Assert: Custom settings preserved
assert updated_settings is not None
assert updated_settings.llm_model == custom_model
assert updated_settings.llm_base_url == custom_base_url
@pytest.mark.asyncio
async def test_update_settings_migrates_billing_margin_v3_to_v4(
settings_store, mock_litellm_api, session_maker
):
# Arrange: User with version 3 and billing margin
old_version = 3
billing_margin = 2.0
max_budget = 10.0
spend = 5.0
settings = Settings()
mock_get_response = AsyncMock()
mock_get_response.is_success = True
mock_get_response.json = MagicMock(
return_value={'user_info': {'max_budget': max_budget, 'spend': spend}}
)
mock_post_response = AsyncMock()
mock_post_response.is_success = True
mock_post_response.json = MagicMock(return_value={'key': 'test_api_key'})
with (
patch('storage.saas_settings_store.session_maker', session_maker),
patch(
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
AsyncMock(return_value={'email': 'user@example.com'}),
),
patch('httpx.AsyncClient') as mock_client,
):
mock_client.return_value.__aenter__.return_value.get.return_value = (
mock_get_response
)
mock_client.return_value.__aenter__.return_value.post.return_value = (
mock_post_response
)
# Create existing user settings with version 3 and billing margin
with session_maker() as session:
existing_settings = UserSettings(
keycloak_user_id=settings_store.user_id,
user_version=old_version,
billing_margin=billing_margin,
)
session.add(existing_settings)
session.commit()
# Act: Update settings
updated_settings = await settings_store.update_settings_with_litellm_default(
settings
)
# Assert: Settings updated
assert updated_settings is not None
# Assert: Billing margin applied to budget
call_args = mock_client.return_value.__aenter__.return_value.post.call_args[1]
assert call_args['json']['max_budget'] == max_budget * billing_margin
assert call_args['json']['spend'] == spend * billing_margin
# Assert: Billing margin reset to 1.0
with session_maker() as session:
updated_user_settings = (
session.query(UserSettings)
.filter(UserSettings.keycloak_user_id == settings_store.user_id)
.first()
)
assert updated_user_settings.billing_margin == 1.0
@pytest.mark.asyncio
async def test_update_settings_skips_billing_margin_migration_when_already_v4(
settings_store, mock_litellm_api, session_maker
):
# Arrange: User with version 4
version = 4
billing_margin = 2.0
max_budget = 10.0
spend = 5.0
settings = Settings()
mock_get_response = AsyncMock()
mock_get_response.is_success = True
mock_get_response.json = MagicMock(
return_value={'user_info': {'max_budget': max_budget, 'spend': spend}}
)
mock_post_response = AsyncMock()
mock_post_response.is_success = True
mock_post_response.json = MagicMock(return_value={'key': 'test_api_key'})
with (
patch('storage.saas_settings_store.session_maker', session_maker),
patch(
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
AsyncMock(return_value={'email': 'user@example.com'}),
),
patch('httpx.AsyncClient') as mock_client,
):
mock_client.return_value.__aenter__.return_value.get.return_value = (
mock_get_response
)
mock_client.return_value.__aenter__.return_value.post.return_value = (
mock_post_response
)
# Create existing user settings with version 4
with session_maker() as session:
existing_settings = UserSettings(
keycloak_user_id=settings_store.user_id,
user_version=version,
billing_margin=billing_margin,
)
session.add(existing_settings)
session.commit()
# Act: Update settings
updated_settings = await settings_store.update_settings_with_litellm_default(
settings
)
# Assert: Settings updated
assert updated_settings is not None
# Assert: Billing margin NOT applied (version >= 4)
call_args = mock_client.return_value.__aenter__.return_value.post.call_args[1]
assert call_args['json']['max_budget'] == max_budget
assert call_args['json']['spend'] == spend
+100 -1
View File
@@ -5,7 +5,12 @@ import jwt
import pytest
from fastapi import Request
from pydantic import SecretStr
from server.auth.auth_error import BearerTokenError, CookieError, NoCredentialsError
from server.auth.auth_error import (
AuthError,
BearerTokenError,
CookieError,
NoCredentialsError,
)
from server.auth.saas_user_auth import (
SaasUserAuth,
get_api_key_from_header,
@@ -647,3 +652,97 @@ def test_get_api_key_from_header_bearer_with_empty_token():
# Assert that empty string from Bearer is returned (current behavior)
# This tests the current implementation behavior
assert api_key == ''
@pytest.mark.asyncio
async def test_saas_user_auth_from_signed_token_blocked_domain(mock_config):
"""Test that saas_user_auth_from_signed_token raises AuthError when email domain is blocked."""
# Arrange
access_payload = {
'sub': 'test_user_id',
'exp': int(time.time()) + 3600,
'email': 'user@colsch.us',
'email_verified': True,
}
access_token = jwt.encode(access_payload, 'access_secret', algorithm='HS256')
token_payload = {
'access_token': access_token,
'refresh_token': 'test_refresh_token',
}
signed_token = jwt.encode(token_payload, 'test_secret', algorithm='HS256')
with patch('server.auth.saas_user_auth.domain_blocker') as mock_domain_blocker:
mock_domain_blocker.is_active.return_value = True
mock_domain_blocker.is_domain_blocked.return_value = True
# Act & Assert
with pytest.raises(AuthError) as exc_info:
await saas_user_auth_from_signed_token(signed_token)
assert 'email domain is not allowed' in str(exc_info.value)
mock_domain_blocker.is_domain_blocked.assert_called_once_with('user@colsch.us')
@pytest.mark.asyncio
async def test_saas_user_auth_from_signed_token_allowed_domain(mock_config):
"""Test that saas_user_auth_from_signed_token succeeds when email domain is not blocked."""
# Arrange
access_payload = {
'sub': 'test_user_id',
'exp': int(time.time()) + 3600,
'email': 'user@example.com',
'email_verified': True,
}
access_token = jwt.encode(access_payload, 'access_secret', algorithm='HS256')
token_payload = {
'access_token': access_token,
'refresh_token': 'test_refresh_token',
}
signed_token = jwt.encode(token_payload, 'test_secret', algorithm='HS256')
with patch('server.auth.saas_user_auth.domain_blocker') as mock_domain_blocker:
mock_domain_blocker.is_active.return_value = True
mock_domain_blocker.is_domain_blocked.return_value = False
# Act
result = await saas_user_auth_from_signed_token(signed_token)
# Assert
assert isinstance(result, SaasUserAuth)
assert result.user_id == 'test_user_id'
assert result.email == 'user@example.com'
mock_domain_blocker.is_domain_blocked.assert_called_once_with(
'user@example.com'
)
@pytest.mark.asyncio
async def test_saas_user_auth_from_signed_token_domain_blocking_inactive(mock_config):
"""Test that saas_user_auth_from_signed_token succeeds when domain blocking is not active."""
# Arrange
access_payload = {
'sub': 'test_user_id',
'exp': int(time.time()) + 3600,
'email': 'user@colsch.us',
'email_verified': True,
}
access_token = jwt.encode(access_payload, 'access_secret', algorithm='HS256')
token_payload = {
'access_token': access_token,
'refresh_token': 'test_refresh_token',
}
signed_token = jwt.encode(token_payload, 'test_secret', algorithm='HS256')
with patch('server.auth.saas_user_auth.domain_blocker') as mock_domain_blocker:
mock_domain_blocker.is_active.return_value = False
# Act
result = await saas_user_auth_from_signed_token(signed_token)
# Assert
assert isinstance(result, SaasUserAuth)
assert result.user_id == 'test_user_id'
mock_domain_blocker.is_domain_blocked.assert_not_called()
@@ -0,0 +1 @@
"""Tests for sharing package."""
@@ -0,0 +1,91 @@
"""Tests for public conversation models."""
from datetime import datetime
from uuid import uuid4
from server.sharing.shared_conversation_models import (
SharedConversation,
SharedConversationPage,
SharedConversationSortOrder,
)
def test_public_conversation_creation():
"""Test that SharedConversation can be created with all required fields."""
conversation_id = uuid4()
now = datetime.utcnow()
conversation = SharedConversation(
id=conversation_id,
created_by_user_id='test_user',
sandbox_id='test_sandbox',
title='Test Conversation',
created_at=now,
updated_at=now,
selected_repository=None,
parent_conversation_id=None,
)
assert conversation.id == conversation_id
assert conversation.title == 'Test Conversation'
assert conversation.created_by_user_id == 'test_user'
assert conversation.sandbox_id == 'test_sandbox'
def test_public_conversation_page_creation():
"""Test that SharedConversationPage can be created."""
conversation_id = uuid4()
now = datetime.utcnow()
conversation = SharedConversation(
id=conversation_id,
created_by_user_id='test_user',
sandbox_id='test_sandbox',
title='Test Conversation',
created_at=now,
updated_at=now,
selected_repository=None,
parent_conversation_id=None,
)
page = SharedConversationPage(
items=[conversation],
next_page_id='next_page',
)
assert len(page.items) == 1
assert page.items[0].id == conversation_id
assert page.next_page_id == 'next_page'
def test_public_conversation_sort_order_enum():
"""Test that SharedConversationSortOrder enum has expected values."""
assert hasattr(SharedConversationSortOrder, 'CREATED_AT')
assert hasattr(SharedConversationSortOrder, 'CREATED_AT_DESC')
assert hasattr(SharedConversationSortOrder, 'UPDATED_AT')
assert hasattr(SharedConversationSortOrder, 'UPDATED_AT_DESC')
assert hasattr(SharedConversationSortOrder, 'TITLE')
assert hasattr(SharedConversationSortOrder, 'TITLE_DESC')
def test_public_conversation_optional_fields():
"""Test that SharedConversation works with optional fields."""
conversation_id = uuid4()
parent_id = uuid4()
now = datetime.utcnow()
conversation = SharedConversation(
id=conversation_id,
created_by_user_id='test_user',
sandbox_id='test_sandbox',
title='Test Conversation',
created_at=now,
updated_at=now,
selected_repository='owner/repo',
parent_conversation_id=parent_id,
llm_model='gpt-4',
)
assert conversation.selected_repository == 'owner/repo'
assert conversation.parent_conversation_id == parent_id
assert conversation.llm_model == 'gpt-4'
@@ -0,0 +1,430 @@
"""Tests for SharedConversationInfoService."""
from datetime import UTC, datetime
from typing import AsyncGenerator
from uuid import uuid4
import pytest
from server.sharing.shared_conversation_models import (
SharedConversationSortOrder,
)
from server.sharing.sql_shared_conversation_info_service import (
SQLSharedConversationInfoService,
)
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.pool import StaticPool
from openhands.app_server.app_conversation.app_conversation_models import (
AppConversationInfo,
)
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
SQLAppConversationInfoService,
)
from openhands.app_server.user.specifiy_user_context import SpecifyUserContext
from openhands.app_server.utils.sql_utils import Base
from openhands.integrations.provider import ProviderType
from openhands.sdk.llm import MetricsSnapshot
from openhands.sdk.llm.utils.metrics import TokenUsage
from openhands.storage.data_models.conversation_metadata import ConversationTrigger
@pytest.fixture
async def async_engine():
"""Create an async SQLite engine for testing."""
engine = create_async_engine(
'sqlite+aiosqlite:///:memory:',
poolclass=StaticPool,
connect_args={'check_same_thread': False},
echo=False,
)
# Create all tables
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
yield engine
await engine.dispose()
@pytest.fixture
async def async_session(async_engine) -> AsyncGenerator[AsyncSession, None]:
"""Create an async session for testing."""
async_session_maker = async_sessionmaker(
async_engine, class_=AsyncSession, expire_on_commit=False
)
async with async_session_maker() as db_session:
yield db_session
@pytest.fixture
async def shared_conversation_info_service(async_session):
"""Create a SharedConversationInfoService for testing."""
return SQLSharedConversationInfoService(db_session=async_session)
@pytest.fixture
async def app_conversation_service(async_session):
"""Create an AppConversationInfoService for creating test data."""
return SQLAppConversationInfoService(
db_session=async_session, user_context=SpecifyUserContext(user_id=None)
)
@pytest.fixture
def sample_conversation_info():
"""Create a sample conversation info for testing."""
return AppConversationInfo(
id=uuid4(),
created_by_user_id='test_user',
sandbox_id='test_sandbox',
selected_repository='test/repo',
selected_branch='main',
git_provider=ProviderType.GITHUB,
title='Test Conversation',
trigger=ConversationTrigger.GUI,
pr_number=[123],
llm_model='gpt-4',
metrics=MetricsSnapshot(
accumulated_cost=1.5,
max_budget_per_task=10.0,
accumulated_token_usage=TokenUsage(
prompt_tokens=100,
completion_tokens=50,
cache_read_tokens=0,
cache_write_tokens=0,
context_window=4096,
per_turn_token=150,
),
),
parent_conversation_id=None,
sub_conversation_ids=[],
created_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
public=True, # Make it public for testing
)
@pytest.fixture
def sample_private_conversation_info():
"""Create a sample private conversation info for testing."""
return AppConversationInfo(
id=uuid4(),
created_by_user_id='test_user',
sandbox_id='test_sandbox_private',
selected_repository='test/private_repo',
selected_branch='main',
git_provider=ProviderType.GITHUB,
title='Private Conversation',
trigger=ConversationTrigger.GUI,
pr_number=[124],
llm_model='gpt-4',
metrics=MetricsSnapshot(
accumulated_cost=2.0,
max_budget_per_task=10.0,
accumulated_token_usage=TokenUsage(
prompt_tokens=200,
completion_tokens=100,
cache_read_tokens=0,
cache_write_tokens=0,
context_window=4096,
per_turn_token=300,
),
),
parent_conversation_id=None,
sub_conversation_ids=[],
created_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
public=False, # Make it private
)
class TestSharedConversationInfoService:
"""Test cases for SharedConversationInfoService."""
@pytest.mark.asyncio
@pytest.mark.asyncio
async def test_get_shared_conversation_info_returns_public_conversation(
self,
shared_conversation_info_service,
app_conversation_service,
sample_conversation_info,
):
"""Test that get_shared_conversation_info returns a public conversation."""
# Create a public conversation
await app_conversation_service.save_app_conversation_info(
sample_conversation_info
)
# Retrieve it via public service
result = await shared_conversation_info_service.get_shared_conversation_info(
sample_conversation_info.id
)
assert result is not None
assert result.id == sample_conversation_info.id
assert result.title == sample_conversation_info.title
assert result.created_by_user_id == sample_conversation_info.created_by_user_id
@pytest.mark.asyncio
async def test_get_shared_conversation_info_returns_none_for_private_conversation(
self,
shared_conversation_info_service,
app_conversation_service,
sample_private_conversation_info,
):
"""Test that get_shared_conversation_info returns None for private conversations."""
# Create a private conversation
await app_conversation_service.save_app_conversation_info(
sample_private_conversation_info
)
# Try to retrieve it via public service
result = await shared_conversation_info_service.get_shared_conversation_info(
sample_private_conversation_info.id
)
assert result is None
@pytest.mark.asyncio
async def test_get_shared_conversation_info_returns_none_for_nonexistent_conversation(
self, shared_conversation_info_service
):
"""Test that get_shared_conversation_info returns None for nonexistent conversations."""
nonexistent_id = uuid4()
result = await shared_conversation_info_service.get_shared_conversation_info(
nonexistent_id
)
assert result is None
@pytest.mark.asyncio
async def test_search_shared_conversation_info_returns_only_public_conversations(
self,
shared_conversation_info_service,
app_conversation_service,
sample_conversation_info,
sample_private_conversation_info,
):
"""Test that search only returns public conversations."""
# Create both public and private conversations
await app_conversation_service.save_app_conversation_info(
sample_conversation_info
)
await app_conversation_service.save_app_conversation_info(
sample_private_conversation_info
)
# Search for all conversations
result = (
await shared_conversation_info_service.search_shared_conversation_info()
)
# Should only return the public conversation
assert len(result.items) == 1
assert result.items[0].id == sample_conversation_info.id
assert result.items[0].title == sample_conversation_info.title
@pytest.mark.asyncio
async def test_search_shared_conversation_info_with_title_filter(
self,
shared_conversation_info_service,
app_conversation_service,
sample_conversation_info,
):
"""Test searching with title filter."""
# Create a public conversation
await app_conversation_service.save_app_conversation_info(
sample_conversation_info
)
# Search with matching title
result = await shared_conversation_info_service.search_shared_conversation_info(
title__contains='Test'
)
assert len(result.items) == 1
# Search with non-matching title
result = await shared_conversation_info_service.search_shared_conversation_info(
title__contains='NonExistent'
)
assert len(result.items) == 0
@pytest.mark.asyncio
async def test_search_shared_conversation_info_with_sort_order(
self,
shared_conversation_info_service,
app_conversation_service,
):
"""Test searching with different sort orders."""
# Create multiple public conversations with different titles and timestamps
conv1 = AppConversationInfo(
id=uuid4(),
created_by_user_id='test_user',
sandbox_id='test_sandbox_1',
title='A First Conversation',
created_at=datetime(2023, 1, 1, tzinfo=UTC),
updated_at=datetime(2023, 1, 1, tzinfo=UTC),
public=True,
metrics=MetricsSnapshot(
accumulated_cost=0.0,
max_budget_per_task=10.0,
accumulated_token_usage=TokenUsage(),
),
)
conv2 = AppConversationInfo(
id=uuid4(),
created_by_user_id='test_user',
sandbox_id='test_sandbox_2',
title='B Second Conversation',
created_at=datetime(2023, 1, 2, tzinfo=UTC),
updated_at=datetime(2023, 1, 2, tzinfo=UTC),
public=True,
metrics=MetricsSnapshot(
accumulated_cost=0.0,
max_budget_per_task=10.0,
accumulated_token_usage=TokenUsage(),
),
)
await app_conversation_service.save_app_conversation_info(conv1)
await app_conversation_service.save_app_conversation_info(conv2)
# Test sort by title ascending
result = await shared_conversation_info_service.search_shared_conversation_info(
sort_order=SharedConversationSortOrder.TITLE
)
assert len(result.items) == 2
assert result.items[0].title == 'A First Conversation'
assert result.items[1].title == 'B Second Conversation'
# Test sort by title descending
result = await shared_conversation_info_service.search_shared_conversation_info(
sort_order=SharedConversationSortOrder.TITLE_DESC
)
assert len(result.items) == 2
assert result.items[0].title == 'B Second Conversation'
assert result.items[1].title == 'A First Conversation'
# Test sort by created_at ascending
result = await shared_conversation_info_service.search_shared_conversation_info(
sort_order=SharedConversationSortOrder.CREATED_AT
)
assert len(result.items) == 2
assert result.items[0].id == conv1.id
assert result.items[1].id == conv2.id
# Test sort by created_at descending (default)
result = await shared_conversation_info_service.search_shared_conversation_info(
sort_order=SharedConversationSortOrder.CREATED_AT_DESC
)
assert len(result.items) == 2
assert result.items[0].id == conv2.id
assert result.items[1].id == conv1.id
@pytest.mark.asyncio
async def test_count_shared_conversation_info(
self,
shared_conversation_info_service,
app_conversation_service,
sample_conversation_info,
sample_private_conversation_info,
):
"""Test counting public conversations."""
# Initially should be 0
count = await shared_conversation_info_service.count_shared_conversation_info()
assert count == 0
# Create a public conversation
await app_conversation_service.save_app_conversation_info(
sample_conversation_info
)
count = await shared_conversation_info_service.count_shared_conversation_info()
assert count == 1
# Create a private conversation - count should remain 1
await app_conversation_service.save_app_conversation_info(
sample_private_conversation_info
)
count = await shared_conversation_info_service.count_shared_conversation_info()
assert count == 1
@pytest.mark.asyncio
async def test_batch_get_shared_conversation_info(
self,
shared_conversation_info_service,
app_conversation_service,
sample_conversation_info,
sample_private_conversation_info,
):
"""Test batch getting public conversations."""
# Create both public and private conversations
await app_conversation_service.save_app_conversation_info(
sample_conversation_info
)
await app_conversation_service.save_app_conversation_info(
sample_private_conversation_info
)
# Batch get both conversations
result = (
await shared_conversation_info_service.batch_get_shared_conversation_info(
[sample_conversation_info.id, sample_private_conversation_info.id]
)
)
# Should return the public one and None for the private one
assert len(result) == 2
assert result[0] is not None
assert result[0].id == sample_conversation_info.id
assert result[1] is None
@pytest.mark.asyncio
async def test_search_with_pagination(
self,
shared_conversation_info_service,
app_conversation_service,
):
"""Test search with pagination."""
# Create multiple public conversations
conversations = []
for i in range(5):
conv = AppConversationInfo(
id=uuid4(),
created_by_user_id='test_user',
sandbox_id=f'test_sandbox_{i}',
title=f'Conversation {i}',
created_at=datetime(2023, 1, i + 1, tzinfo=UTC),
updated_at=datetime(2023, 1, i + 1, tzinfo=UTC),
public=True,
metrics=MetricsSnapshot(
accumulated_cost=0.0,
max_budget_per_task=10.0,
accumulated_token_usage=TokenUsage(),
),
)
conversations.append(conv)
await app_conversation_service.save_app_conversation_info(conv)
# Get first page with limit 2
result = await shared_conversation_info_service.search_shared_conversation_info(
limit=2, sort_order=SharedConversationSortOrder.CREATED_AT
)
assert len(result.items) == 2
assert result.next_page_id is not None
# Get next page
result2 = (
await shared_conversation_info_service.search_shared_conversation_info(
limit=2,
page_id=result.next_page_id,
sort_order=SharedConversationSortOrder.CREATED_AT,
)
)
assert len(result2.items) == 2
assert result2.next_page_id is not None
# Verify no overlap between pages
page1_ids = {item.id for item in result.items}
page2_ids = {item.id for item in result2.items}
assert page1_ids.isdisjoint(page2_ids)
@@ -0,0 +1,365 @@
"""Tests for SharedEventService."""
from datetime import UTC, datetime
from unittest.mock import AsyncMock
from uuid import uuid4
import pytest
from server.sharing.filesystem_shared_event_service import (
SharedEventServiceImpl,
)
from server.sharing.shared_conversation_info_service import (
SharedConversationInfoService,
)
from server.sharing.shared_conversation_models import SharedConversation
from openhands.agent_server.models import EventPage, EventSortOrder
from openhands.app_server.event.event_service import EventService
from openhands.sdk.llm import MetricsSnapshot
from openhands.sdk.llm.utils.metrics import TokenUsage
@pytest.fixture
def mock_shared_conversation_info_service():
"""Create a mock SharedConversationInfoService."""
return AsyncMock(spec=SharedConversationInfoService)
@pytest.fixture
def mock_event_service():
"""Create a mock EventService."""
return AsyncMock(spec=EventService)
@pytest.fixture
def shared_event_service(mock_shared_conversation_info_service, mock_event_service):
"""Create a SharedEventService for testing."""
return SharedEventServiceImpl(
shared_conversation_info_service=mock_shared_conversation_info_service,
event_service=mock_event_service,
)
@pytest.fixture
def sample_public_conversation():
"""Create a sample public conversation."""
return SharedConversation(
id=uuid4(),
created_by_user_id='test_user',
sandbox_id='test_sandbox',
title='Test Public Conversation',
created_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
metrics=MetricsSnapshot(
accumulated_cost=0.0,
max_budget_per_task=10.0,
accumulated_token_usage=TokenUsage(),
),
)
@pytest.fixture
def sample_event():
"""Create a sample event."""
# For testing purposes, we'll just use a mock that the EventPage can accept
# The actual event creation is complex and not the focus of these tests
return None
class TestSharedEventService:
"""Test cases for SharedEventService."""
async def test_get_shared_event_returns_event_for_public_conversation(
self,
shared_event_service,
mock_shared_conversation_info_service,
mock_event_service,
sample_public_conversation,
sample_event,
):
"""Test that get_shared_event returns an event for a public conversation."""
conversation_id = sample_public_conversation.id
event_id = 'test_event_id'
# Mock the public conversation service to return a public conversation
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = sample_public_conversation
# Mock the event service to return an event
mock_event_service.get_event.return_value = sample_event
# Call the method
result = await shared_event_service.get_shared_event(conversation_id, event_id)
# Verify the result
assert result == sample_event
mock_shared_conversation_info_service.get_shared_conversation_info.assert_called_once_with(
conversation_id
)
mock_event_service.get_event.assert_called_once_with(event_id)
async def test_get_shared_event_returns_none_for_private_conversation(
self,
shared_event_service,
mock_shared_conversation_info_service,
mock_event_service,
):
"""Test that get_shared_event returns None for a private conversation."""
conversation_id = uuid4()
event_id = 'test_event_id'
# Mock the public conversation service to return None (private conversation)
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = None
# Call the method
result = await shared_event_service.get_shared_event(conversation_id, event_id)
# Verify the result
assert result is None
mock_shared_conversation_info_service.get_shared_conversation_info.assert_called_once_with(
conversation_id
)
# Event service should not be called
mock_event_service.get_event.assert_not_called()
async def test_search_shared_events_returns_events_for_public_conversation(
self,
shared_event_service,
mock_shared_conversation_info_service,
mock_event_service,
sample_public_conversation,
sample_event,
):
"""Test that search_shared_events returns events for a public conversation."""
conversation_id = sample_public_conversation.id
# Mock the public conversation service to return a public conversation
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = sample_public_conversation
# Mock the event service to return events
mock_event_page = EventPage(items=[], next_page_id=None)
mock_event_service.search_events.return_value = mock_event_page
# Call the method
result = await shared_event_service.search_shared_events(
conversation_id=conversation_id,
kind__eq='ActionEvent',
limit=10,
)
# Verify the result
assert result == mock_event_page
assert len(result.items) == 0 # Empty list as we mocked
mock_shared_conversation_info_service.get_shared_conversation_info.assert_called_once_with(
conversation_id
)
mock_event_service.search_events.assert_called_once_with(
conversation_id__eq=conversation_id,
kind__eq='ActionEvent',
timestamp__gte=None,
timestamp__lt=None,
sort_order=EventSortOrder.TIMESTAMP,
page_id=None,
limit=10,
)
async def test_search_shared_events_returns_empty_for_private_conversation(
self,
shared_event_service,
mock_shared_conversation_info_service,
mock_event_service,
):
"""Test that search_shared_events returns empty page for a private conversation."""
conversation_id = uuid4()
# Mock the public conversation service to return None (private conversation)
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = None
# Call the method
result = await shared_event_service.search_shared_events(
conversation_id=conversation_id,
limit=10,
)
# Verify the result
assert isinstance(result, EventPage)
assert len(result.items) == 0
assert result.next_page_id is None
mock_shared_conversation_info_service.get_shared_conversation_info.assert_called_once_with(
conversation_id
)
# Event service should not be called
mock_event_service.search_events.assert_not_called()
async def test_count_shared_events_returns_count_for_public_conversation(
self,
shared_event_service,
mock_shared_conversation_info_service,
mock_event_service,
sample_public_conversation,
):
"""Test that count_shared_events returns count for a public conversation."""
conversation_id = sample_public_conversation.id
# Mock the public conversation service to return a public conversation
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = sample_public_conversation
# Mock the event service to return a count
mock_event_service.count_events.return_value = 5
# Call the method
result = await shared_event_service.count_shared_events(
conversation_id=conversation_id,
kind__eq='ActionEvent',
)
# Verify the result
assert result == 5
mock_shared_conversation_info_service.get_shared_conversation_info.assert_called_once_with(
conversation_id
)
mock_event_service.count_events.assert_called_once_with(
conversation_id__eq=conversation_id,
kind__eq='ActionEvent',
timestamp__gte=None,
timestamp__lt=None,
sort_order=EventSortOrder.TIMESTAMP,
)
async def test_count_shared_events_returns_zero_for_private_conversation(
self,
shared_event_service,
mock_shared_conversation_info_service,
mock_event_service,
):
"""Test that count_shared_events returns 0 for a private conversation."""
conversation_id = uuid4()
# Mock the public conversation service to return None (private conversation)
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = None
# Call the method
result = await shared_event_service.count_shared_events(
conversation_id=conversation_id,
)
# Verify the result
assert result == 0
mock_shared_conversation_info_service.get_shared_conversation_info.assert_called_once_with(
conversation_id
)
# Event service should not be called
mock_event_service.count_events.assert_not_called()
async def test_batch_get_shared_events_returns_events_for_public_conversation(
self,
shared_event_service,
mock_shared_conversation_info_service,
mock_event_service,
sample_public_conversation,
sample_event,
):
"""Test that batch_get_shared_events returns events for a public conversation."""
conversation_id = sample_public_conversation.id
event_ids = ['event1', 'event2']
# Mock the public conversation service to return a public conversation
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = sample_public_conversation
# Mock the event service to return events
mock_event_service.get_event.side_effect = [sample_event, None]
# Call the method
result = await shared_event_service.batch_get_shared_events(
conversation_id, event_ids
)
# Verify the result
assert len(result) == 2
assert result[0] == sample_event
assert result[1] is None
# Verify that get_shared_conversation_info was called for each event
assert (
mock_shared_conversation_info_service.get_shared_conversation_info.call_count
== 2
)
# Verify that get_event was called for each event
assert mock_event_service.get_event.call_count == 2
async def test_batch_get_shared_events_returns_none_for_private_conversation(
self,
shared_event_service,
mock_shared_conversation_info_service,
mock_event_service,
):
"""Test that batch_get_shared_events returns None for a private conversation."""
conversation_id = uuid4()
event_ids = ['event1', 'event2']
# Mock the public conversation service to return None (private conversation)
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = None
# Call the method
result = await shared_event_service.batch_get_shared_events(
conversation_id, event_ids
)
# Verify the result
assert len(result) == 2
assert result[0] is None
assert result[1] is None
# Verify that get_shared_conversation_info was called for each event
assert (
mock_shared_conversation_info_service.get_shared_conversation_info.call_count
== 2
)
# Event service should not be called
mock_event_service.get_event.assert_not_called()
async def test_search_shared_events_with_all_parameters(
self,
shared_event_service,
mock_shared_conversation_info_service,
mock_event_service,
sample_public_conversation,
):
"""Test search_shared_events with all parameters."""
conversation_id = sample_public_conversation.id
timestamp_gte = datetime(2023, 1, 1, tzinfo=UTC)
timestamp_lt = datetime(2023, 12, 31, tzinfo=UTC)
# Mock the public conversation service to return a public conversation
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = sample_public_conversation
# Mock the event service to return events
mock_event_page = EventPage(items=[], next_page_id='next_page')
mock_event_service.search_events.return_value = mock_event_page
# Call the method with all parameters
result = await shared_event_service.search_shared_events(
conversation_id=conversation_id,
kind__eq='ObservationEvent',
timestamp__gte=timestamp_gte,
timestamp__lt=timestamp_lt,
sort_order=EventSortOrder.TIMESTAMP_DESC,
page_id='current_page',
limit=50,
)
# Verify the result
assert result == mock_event_page
mock_event_service.search_events.assert_called_once_with(
conversation_id__eq=conversation_id,
kind__eq='ObservationEvent',
timestamp__gte=timestamp_gte,
timestamp__lt=timestamp_lt,
sort_order=EventSortOrder.TIMESTAMP_DESC,
page_id='current_page',
limit=50,
)
+427 -1
View File
@@ -1,6 +1,8 @@
from unittest.mock import MagicMock
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from keycloak.exceptions import KeycloakConnectionError, KeycloakError
from server.auth.token_manager import TokenManager
from sqlalchemy.orm import Session
from storage.offline_token_store import OfflineTokenStore
from storage.stored_offline_token import StoredOfflineToken
@@ -32,6 +34,14 @@ def token_store(mock_session_maker, mock_config):
return OfflineTokenStore('test_user_id', mock_session_maker, mock_config)
@pytest.fixture
def token_manager():
with patch('server.config.get_config') as mock_get_config:
mock_config = mock_get_config.return_value
mock_config.jwt_secret.get_secret_value.return_value = 'test_secret'
return TokenManager(external=False)
@pytest.mark.asyncio
async def test_store_token_new_record(token_store, mock_session):
# Setup
@@ -109,3 +119,419 @@ async def test_get_instance(mock_config):
assert isinstance(result, OfflineTokenStore)
assert result.user_id == test_user_id
assert result.config == mock_config
class TestCheckDuplicateBaseEmail:
"""Test cases for check_duplicate_base_email method."""
@pytest.mark.asyncio
async def test_check_duplicate_base_email_no_plus_modifier(self, token_manager):
"""Test that emails without + modifier are still checked for duplicates."""
# Arrange
email = 'joe@example.com'
current_user_id = 'user123'
with (
patch.object(
token_manager, '_query_users_by_wildcard_pattern'
) as mock_query,
patch.object(token_manager, '_find_duplicate_in_users') as mock_find,
):
mock_find.return_value = False
mock_query.return_value = {}
# Act
result = await token_manager.check_duplicate_base_email(
email, current_user_id
)
# Assert
assert result is False
mock_query.assert_called_once()
mock_find.assert_called_once()
@pytest.mark.asyncio
async def test_check_duplicate_base_email_empty_email(self, token_manager):
"""Test that empty email returns False."""
# Arrange
email = ''
current_user_id = 'user123'
# Act
result = await token_manager.check_duplicate_base_email(email, current_user_id)
# Assert
assert result is False
@pytest.mark.asyncio
async def test_check_duplicate_base_email_invalid_email(self, token_manager):
"""Test that invalid email returns False."""
# Arrange
email = 'invalid-email'
current_user_id = 'user123'
# Act
result = await token_manager.check_duplicate_base_email(email, current_user_id)
# Assert
assert result is False
@pytest.mark.asyncio
async def test_check_duplicate_base_email_duplicate_found(self, token_manager):
"""Test that duplicate email is detected when found."""
# Arrange
email = 'joe+test@example.com'
current_user_id = 'user123'
existing_user = {
'id': 'existing_user_id',
'email': 'joe@example.com',
}
with (
patch.object(
token_manager, '_query_users_by_wildcard_pattern'
) as mock_query,
patch.object(token_manager, '_find_duplicate_in_users') as mock_find,
):
mock_find.return_value = True
mock_query.return_value = {'existing_user_id': existing_user}
# Act
result = await token_manager.check_duplicate_base_email(
email, current_user_id
)
# Assert
assert result is True
mock_query.assert_called_once()
mock_find.assert_called_once()
@pytest.mark.asyncio
async def test_check_duplicate_base_email_no_duplicate(self, token_manager):
"""Test that no duplicate is found when none exists."""
# Arrange
email = 'joe+test@example.com'
current_user_id = 'user123'
with (
patch.object(
token_manager, '_query_users_by_wildcard_pattern'
) as mock_query,
patch.object(token_manager, '_find_duplicate_in_users') as mock_find,
):
mock_find.return_value = False
mock_query.return_value = {}
# Act
result = await token_manager.check_duplicate_base_email(
email, current_user_id
)
# Assert
assert result is False
@pytest.mark.asyncio
async def test_check_duplicate_base_email_keycloak_connection_error(
self, token_manager
):
"""Test that KeycloakConnectionError triggers retry and raises RetryError."""
# Arrange
email = 'joe+test@example.com'
current_user_id = 'user123'
with patch.object(
token_manager, '_query_users_by_wildcard_pattern'
) as mock_query:
mock_query.side_effect = KeycloakConnectionError('Connection failed')
# Act & Assert
# KeycloakConnectionError is re-raised, which triggers retry decorator
# After retries exhaust (2 attempts), it raises RetryError
from tenacity import RetryError
with pytest.raises(RetryError):
await token_manager.check_duplicate_base_email(email, current_user_id)
@pytest.mark.asyncio
async def test_check_duplicate_base_email_general_exception(self, token_manager):
"""Test that general exceptions are handled gracefully."""
# Arrange
email = 'joe+test@example.com'
current_user_id = 'user123'
with patch.object(
token_manager, '_query_users_by_wildcard_pattern'
) as mock_query:
mock_query.side_effect = Exception('Unexpected error')
# Act
result = await token_manager.check_duplicate_base_email(
email, current_user_id
)
# Assert
assert result is False
class TestQueryUsersByWildcardPattern:
"""Test cases for _query_users_by_wildcard_pattern method."""
@pytest.mark.asyncio
async def test_query_users_by_wildcard_pattern_success_with_search(
self, token_manager
):
"""Test successful query using search parameter."""
# Arrange
local_part = 'joe'
domain = 'example.com'
mock_users = [
{'id': 'user1', 'email': 'joe@example.com'},
{'id': 'user2', 'email': 'joe+test@example.com'},
]
with patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin:
mock_admin = MagicMock()
mock_admin.a_get_users = AsyncMock(return_value=mock_users)
mock_get_admin.return_value = mock_admin
# Act
result = await token_manager._query_users_by_wildcard_pattern(
local_part, domain
)
# Assert
assert len(result) == 2
assert 'user1' in result
assert 'user2' in result
mock_admin.a_get_users.assert_called_once_with(
{'search': 'joe*@example.com'}
)
@pytest.mark.asyncio
async def test_query_users_by_wildcard_pattern_fallback_to_q(self, token_manager):
"""Test fallback to q parameter when search fails."""
# Arrange
local_part = 'joe'
domain = 'example.com'
mock_users = [{'id': 'user1', 'email': 'joe@example.com'}]
with patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin:
mock_admin = MagicMock()
# First call fails, second succeeds
mock_admin.a_get_users = AsyncMock(
side_effect=[Exception('Search failed'), mock_users]
)
mock_get_admin.return_value = mock_admin
# Act
result = await token_manager._query_users_by_wildcard_pattern(
local_part, domain
)
# Assert
assert len(result) == 1
assert 'user1' in result
assert mock_admin.a_get_users.call_count == 2
@pytest.mark.asyncio
async def test_query_users_by_wildcard_pattern_empty_result(self, token_manager):
"""Test query returns empty dict when no users found."""
# Arrange
local_part = 'joe'
domain = 'example.com'
with patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin:
mock_admin = MagicMock()
mock_admin.a_get_users = AsyncMock(return_value=[])
mock_get_admin.return_value = mock_admin
# Act
result = await token_manager._query_users_by_wildcard_pattern(
local_part, domain
)
# Assert
assert result == {}
class TestFindDuplicateInUsers:
"""Test cases for _find_duplicate_in_users method."""
def test_find_duplicate_in_users_with_regex_match(self, token_manager):
"""Test finding duplicate using regex pattern."""
# Arrange
users = {
'user1': {'id': 'user1', 'email': 'joe@example.com'},
'user2': {'id': 'user2', 'email': 'joe+test@example.com'},
}
base_email = 'joe@example.com'
current_user_id = 'user3'
# Act
result = token_manager._find_duplicate_in_users(
users, base_email, current_user_id
)
# Assert
assert result is True
def test_find_duplicate_in_users_fallback_to_simple_matching(self, token_manager):
"""Test fallback to simple matching when regex pattern is None."""
# Arrange
users = {
'user1': {'id': 'user1', 'email': 'joe@example.com'},
}
base_email = 'invalid-email' # Will cause regex pattern to be None
current_user_id = 'user2'
with patch(
'server.auth.token_manager.get_base_email_regex_pattern', return_value=None
):
# Act
result = token_manager._find_duplicate_in_users(
users, base_email, current_user_id
)
# Assert
# Should use fallback matching, but invalid base_email won't match
assert result is False
def test_find_duplicate_in_users_excludes_current_user(self, token_manager):
"""Test that current user is excluded from duplicate check."""
# Arrange
users = {
'user1': {'id': 'user1', 'email': 'joe@example.com'},
}
base_email = 'joe@example.com'
current_user_id = 'user1' # Same as user in users dict
# Act
result = token_manager._find_duplicate_in_users(
users, base_email, current_user_id
)
# Assert
assert result is False
def test_find_duplicate_in_users_no_match(self, token_manager):
"""Test that no duplicate is found when emails don't match."""
# Arrange
users = {
'user1': {'id': 'user1', 'email': 'jane@example.com'},
}
base_email = 'joe@example.com'
current_user_id = 'user2'
# Act
result = token_manager._find_duplicate_in_users(
users, base_email, current_user_id
)
# Assert
assert result is False
def test_find_duplicate_in_users_empty_dict(self, token_manager):
"""Test that empty users dict returns False."""
# Arrange
users: dict[str, dict] = {}
base_email = 'joe@example.com'
current_user_id = 'user1'
# Act
result = token_manager._find_duplicate_in_users(
users, base_email, current_user_id
)
# Assert
assert result is False
class TestDeleteKeycloakUser:
"""Test cases for delete_keycloak_user method."""
@pytest.mark.asyncio
async def test_delete_keycloak_user_success(self, token_manager):
"""Test successful deletion of Keycloak user."""
# Arrange
user_id = 'test_user_id'
with (
patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin,
patch('asyncio.to_thread') as mock_to_thread,
):
mock_admin = MagicMock()
mock_admin.delete_user = MagicMock()
mock_get_admin.return_value = mock_admin
mock_to_thread.return_value = None
# Act
result = await token_manager.delete_keycloak_user(user_id)
# Assert
assert result is True
mock_to_thread.assert_called_once_with(mock_admin.delete_user, user_id)
@pytest.mark.asyncio
async def test_delete_keycloak_user_connection_error(self, token_manager):
"""Test handling of KeycloakConnectionError triggers retry and raises RetryError."""
# Arrange
user_id = 'test_user_id'
with (
patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin,
patch('asyncio.to_thread') as mock_to_thread,
):
mock_admin = MagicMock()
mock_admin.delete_user = MagicMock()
mock_get_admin.return_value = mock_admin
mock_to_thread.side_effect = KeycloakConnectionError('Connection failed')
# Act & Assert
# KeycloakConnectionError triggers retry decorator
# After retries exhaust (2 attempts), it raises RetryError
from tenacity import RetryError
with pytest.raises(RetryError):
await token_manager.delete_keycloak_user(user_id)
@pytest.mark.asyncio
async def test_delete_keycloak_user_keycloak_error(self, token_manager):
"""Test handling of KeycloakError (e.g., user not found)."""
# Arrange
user_id = 'test_user_id'
with (
patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin,
patch('asyncio.to_thread') as mock_to_thread,
):
mock_admin = MagicMock()
mock_admin.delete_user = MagicMock()
mock_get_admin.return_value = mock_admin
mock_to_thread.side_effect = KeycloakError('User not found')
# Act
result = await token_manager.delete_keycloak_user(user_id)
# Assert
assert result is False
@pytest.mark.asyncio
async def test_delete_keycloak_user_general_exception(self, token_manager):
"""Test handling of general exceptions."""
# Arrange
user_id = 'test_user_id'
with (
patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin,
patch('asyncio.to_thread') as mock_to_thread,
):
mock_admin = MagicMock()
mock_admin.delete_user = MagicMock()
mock_get_admin.return_value = mock_admin
mock_to_thread.side_effect = Exception('Unexpected error')
# Act
result = await token_manager.delete_keycloak_user(user_id)
# Assert
assert result is False
@@ -1,4 +1,4 @@
from unittest.mock import AsyncMock, patch
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from server.auth.token_manager import TokenManager, create_encryption_utility
@@ -246,3 +246,103 @@ async def test_refresh(token_manager):
mock_keycloak.return_value.a_refresh_token.assert_called_once_with(
'test_refresh_token'
)
@pytest.mark.asyncio
async def test_disable_keycloak_user_success(token_manager):
"""Test successful disabling of a Keycloak user account."""
# Arrange
user_id = 'test_user_id'
email = 'user@colsch.us'
mock_user = {
'id': user_id,
'username': 'testuser',
'email': email,
'emailVerified': True,
}
with patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin:
mock_admin = MagicMock()
mock_admin.a_get_user = AsyncMock(return_value=mock_user)
mock_admin.a_update_user = AsyncMock()
mock_get_admin.return_value = mock_admin
# Act
await token_manager.disable_keycloak_user(user_id, email)
# Assert
mock_admin.a_get_user.assert_called_once_with(user_id)
mock_admin.a_update_user.assert_called_once_with(
user_id=user_id,
payload={
'enabled': False,
'username': 'testuser',
'email': email,
'emailVerified': True,
},
)
@pytest.mark.asyncio
async def test_disable_keycloak_user_without_email(token_manager):
"""Test disabling Keycloak user without providing email."""
# Arrange
user_id = 'test_user_id'
mock_user = {
'id': user_id,
'username': 'testuser',
'email': 'user@example.com',
'emailVerified': False,
}
with patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin:
mock_admin = MagicMock()
mock_admin.a_get_user = AsyncMock(return_value=mock_user)
mock_admin.a_update_user = AsyncMock()
mock_get_admin.return_value = mock_admin
# Act
await token_manager.disable_keycloak_user(user_id)
# Assert
mock_admin.a_get_user.assert_called_once_with(user_id)
mock_admin.a_update_user.assert_called_once()
@pytest.mark.asyncio
async def test_disable_keycloak_user_not_found(token_manager):
"""Test disabling Keycloak user when user is not found."""
# Arrange
user_id = 'nonexistent_user_id'
email = 'user@colsch.us'
with patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin:
mock_admin = MagicMock()
mock_admin.a_get_user = AsyncMock(return_value=None)
mock_get_admin.return_value = mock_admin
# Act
await token_manager.disable_keycloak_user(user_id, email)
# Assert
mock_admin.a_get_user.assert_called_once_with(user_id)
mock_admin.a_update_user.assert_not_called()
@pytest.mark.asyncio
async def test_disable_keycloak_user_exception_handling(token_manager):
"""Test that disable_keycloak_user handles exceptions gracefully without raising."""
# Arrange
user_id = 'test_user_id'
email = 'user@colsch.us'
with patch('server.auth.token_manager.get_keycloak_admin') as mock_get_admin:
mock_admin = MagicMock()
mock_admin.a_get_user = AsyncMock(side_effect=Exception('Connection error'))
mock_get_admin.return_value = mock_admin
# Act & Assert - should not raise exception
await token_manager.disable_keycloak_user(user_id, email)
# Verify the method was called
mock_admin.a_get_user.assert_called_once_with(user_id)
+25
View File
@@ -0,0 +1,25 @@
# BFCL (Berkeley Function-Calling Leaderboard) Evaluation
This directory contains the evaluation scripts for BFCL.
## Setup
You may need to clone the official BFCL repository or install the evaluation package if available.
```bash
# Example setup (adjust as needed)
# git clone https://github.com/ShishirPatil/gorilla.git
# cd gorilla/berkeley-function-call-leaderboard
# pip install -r requirements.txt
```
## Running Evaluation
To run the evaluation, you need to provide the path to the BFCL dataset:
```bash
python evaluation/benchmarks/bfcl/run_infer.py \
--agent-cls CodeActAgent \
--llm-config <your_llm_config> \
--dataset-path /path/to/bfcl_dataset.json
```
+196
View File
@@ -0,0 +1,196 @@
import asyncio
import os
import pandas as pd # type: ignore
# Assuming bfcl-eval is installed or we use a similar local structure
# The user mentioned: "Integrate bfcl-eval package for official metrics"
from evaluation.utils.shared import (
EvalMetadata,
EvalOutput,
codeact_user_response,
compatibility_for_eval_history_pairs,
get_default_sandbox_config_for_eval,
get_metrics,
get_openhands_config_for_eval,
make_metadata,
prepare_dataset,
reset_logger_for_multiprocessing,
run_evaluation,
)
from openhands.controller.state.state import State
from openhands.core.config import (
OpenHandsConfig,
get_evaluation_parser,
get_llm_config_arg,
)
from openhands.core.logger import openhands_logger as logger
from openhands.core.main import create_runtime, run_controller
from openhands.events.action import MessageAction
from openhands.utils.async_utils import call_async_from_sync
AGENT_CLS_TO_FAKE_USER_RESPONSE_FN = {
'CodeActAgent': codeact_user_response,
}
AGENT_CLS_TO_INST_SUFFIX = {
'CodeActAgent': 'When you think you have completed the request, please finish the interaction using the "finish" tool.\n'
}
def get_config(
metadata: EvalMetadata,
) -> OpenHandsConfig:
sandbox_config = get_default_sandbox_config_for_eval()
sandbox_config.base_container_image = 'python:3.12-bookworm'
config = get_openhands_config_for_eval(
metadata=metadata,
runtime='docker',
sandbox_config=sandbox_config,
)
config.set_llm_config(metadata.llm_config)
agent_config = config.get_agent_config(metadata.agent_class)
agent_config.enable_prompt_extensions = False
return config
def process_instance(
instance: pd.Series,
metadata: EvalMetadata,
reset_logger: bool = True,
) -> EvalOutput:
config = get_config(metadata)
instance_id = str(instance['id']).replace(
'/', '_'
) # BFCL IDs might contain slashes
if reset_logger:
log_dir = os.path.join(metadata.eval_output_dir, 'infer_logs')
reset_logger_for_multiprocessing(logger, instance_id, log_dir)
else:
logger.info(f'Starting evaluation for instance {instance_id}.')
# Prepare instruction
# BFCL usually has a question/prompt and associated functions
question = instance['question']
# We might need to format it with available tools?
# For now, let's assume the agent can handle raw text or we format it.
instruction = f'Question: {question}\n'
# instruction += f"Available Functions: {instance['function']}\n"
instruction += 'IMPORTANT: You should ONLY interact with the environment provided to you AND NEVER ASK FOR HUMAN HELP.\n'
instruction += AGENT_CLS_TO_INST_SUFFIX.get(metadata.agent_class, '')
runtime = create_runtime(config)
call_async_from_sync(runtime.connect)
state: State | None = asyncio.run(
run_controller(
config=config,
initial_user_action=MessageAction(content=instruction),
runtime=runtime,
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN.get(
metadata.agent_class
),
)
)
if state is None:
raise ValueError('State should not be None.')
metrics = get_metrics(state)
histories = compatibility_for_eval_history_pairs(state.history)
last_agent_message = state.get_last_agent_message()
model_answer_raw = last_agent_message.content if last_agent_message else ''
output = EvalOutput(
instance_id=instance_id,
metadata=metadata,
history=histories,
metrics=metrics,
error=state.last_error if state and state.last_error else None,
test_result={
'generated_text': model_answer_raw,
# We will use bfcl-eval to score offline/post-hoc usually,
# or we can try to score here if the package allows easy single-instance scoring.
},
)
return output
if __name__ == '__main__':
parser = get_evaluation_parser()
parser.add_argument(
'--dataset-path',
type=str,
help='Path to the BFCL dataset (json/jsonl)',
)
args, _ = parser.parse_known_args()
llm_config = None
if args.llm_config:
llm_config = get_llm_config_arg(args.llm_config)
if llm_config is None:
raise ValueError(f'Could not find LLM config: --llm_config {args.llm_config}')
llm_config.modify_params = False
# Load dataset
if args.dataset_path:
if args.dataset_path.endswith('.json'):
dataset_df = pd.read_json(args.dataset_path)
elif args.dataset_path.endswith('.jsonl'):
dataset_df = pd.read_json(args.dataset_path, lines=True)
else:
raise ValueError('Dataset must be .json or .jsonl')
else:
# Try to load from huggingface or default location?
# For now require path or create dummy
logger.warning('No dataset path provided, creating dummy dataset.')
dataset_df = pd.DataFrame(
[
{
'id': 'test-0',
'question': 'What is the weather in San Francisco?',
'function': [
{
'name': 'get_weather',
'parameters': {'location': 'San Francisco'},
}
],
}
]
)
if 'instance_id' not in dataset_df.columns:
if 'id' in dataset_df.columns:
dataset_df['instance_id'] = dataset_df['id']
else:
dataset_df['instance_id'] = dataset_df.index.astype(str)
metadata = make_metadata(
llm_config=llm_config,
dataset_name='bfcl',
agent_class=args.agent_cls,
max_iterations=args.max_iterations,
eval_note=args.eval_note,
eval_output_dir=args.eval_output_dir,
data_split=args.data_split,
)
output_file = os.path.join(metadata.eval_output_dir, 'output.jsonl')
dataset = prepare_dataset(
dataset_df, output_file=output_file, eval_n_limit=args.eval_n_limit
)
run_evaluation(
dataset=dataset,
metadata=metadata,
output_file=output_file,
num_workers=args.eval_num_workers,
process_instance_func=process_instance,
)
+22
View File
@@ -0,0 +1,22 @@
# Tau-Bench Evaluation
This directory contains the evaluation scripts for Tau-Bench.
## Setup
First, make sure you have installed the `tau-bench` package:
```bash
pip install tau-bench
```
## Running Evaluation
To run the evaluation, use the following command:
```bash
python evaluation/benchmarks/tau_bench/run_infer.py \
--agent-cls CodeActAgent \
--llm-config <your_llm_config> \
--env retail
```
@@ -0,0 +1,221 @@
import asyncio
import os
from typing import Any
import pandas as pd # type: ignore
try:
from tau_bench.agents.base import Agent as TauAgent # type: ignore
from tau_bench.envs import get_env # type: ignore
from tau_bench.types import EnvInfo # type: ignore
except ImportError:
TauAgent = Any
get_env = Any
EnvInfo = Any
from evaluation.utils.shared import (
EvalMetadata,
EvalOutput,
codeact_user_response,
compatibility_for_eval_history_pairs,
get_default_sandbox_config_for_eval,
get_metrics,
get_openhands_config_for_eval,
make_metadata,
prepare_dataset,
reset_logger_for_multiprocessing,
run_evaluation,
)
from openhands.controller.state.state import State
from openhands.core.config import (
OpenHandsConfig,
get_evaluation_parser,
get_llm_config_arg,
)
from openhands.core.logger import openhands_logger as logger
from openhands.core.main import create_runtime, run_controller
from openhands.events.action import MessageAction
from openhands.utils.async_utils import call_async_from_sync
AGENT_CLS_TO_FAKE_USER_RESPONSE_FN = {
'CodeActAgent': codeact_user_response,
}
AGENT_CLS_TO_INST_SUFFIX = {
'CodeActAgent': 'When you think you have completed the request, please finish the interaction using the "finish" tool.\n'
}
def get_config(
metadata: EvalMetadata,
) -> OpenHandsConfig:
sandbox_config = get_default_sandbox_config_for_eval()
sandbox_config.base_container_image = 'python:3.12-bookworm'
config = get_openhands_config_for_eval(
metadata=metadata,
runtime='docker',
sandbox_config=sandbox_config,
)
config.set_llm_config(metadata.llm_config)
agent_config = config.get_agent_config(metadata.agent_class)
agent_config.enable_prompt_extensions = False
return config
def process_instance(
instance: pd.Series,
metadata: EvalMetadata,
reset_logger: bool = True,
) -> EvalOutput:
config = get_config(metadata)
instance_id = str(instance['instance_id'])
# Setup the logger properly
if reset_logger:
log_dir = os.path.join(metadata.eval_output_dir, 'infer_logs')
reset_logger_for_multiprocessing(logger, instance_id, log_dir)
else:
logger.info(f'Starting evaluation for instance {instance_id}.')
# Initialize Tau-Bench environment
instance['env']
instance['task_index']
# Initialize runtime
runtime = create_runtime(config)
call_async_from_sync(runtime.connect)
# Note: We need to figure out how to bridge Tau-Bench environment with OpenHands agent.
# OpenHands agents expect to interact with a runtime (shell/browser).
# Tau-Bench environments provide a python interface.
# For now, we will assume we can run python code in the runtime to interact with Tau-Bench,
# OR we adapt the agent to call Tau-Bench API.
# Given OpenHands agents are general purpose, we probably want to expose Tau-Bench tools
# as Python functions available in the runtime, or standard tools.
# Let's inspect how Tau-Bench works. It seems it requires `tau-bench` package.
# The user request mentioned: "Integrate sierra-research/tau-bench package for dataset and evaluation"
# Since I don't have the package installed yet, I will write the skeleton and then install/mock it.
instruction = instance['instruction']
instruction += 'IMPORTANT: You should ONLY interact with the environment provided to you AND NEVER ASK FOR HUMAN HELP.\n'
instruction += AGENT_CLS_TO_INST_SUFFIX.get(metadata.agent_class, '')
state: State | None = asyncio.run(
run_controller(
config=config,
initial_user_action=MessageAction(content=instruction),
runtime=runtime,
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN.get(
metadata.agent_class
),
)
)
if state is None:
raise ValueError('State should not be None.')
metrics = get_metrics(state)
histories = compatibility_for_eval_history_pairs(state.history)
# Retrieve result from the state or runtime if possible
# For Tau-Bench, we typically need to check if the goal was achieved in the env.
# Placeholder for actual score calculation
score = 0.0
output = EvalOutput(
instance_id=instance_id,
metadata=metadata,
history=histories,
metrics=metrics,
error=state.last_error if state and state.last_error else None,
test_result={
'score': score,
},
)
return output
if __name__ == '__main__':
parser = get_evaluation_parser()
parser.add_argument(
'--env',
type=str,
default='retail',
help='Tau-Bench environment name (retail, airline)',
)
args, _ = parser.parse_known_args()
llm_config = None
if args.llm_config:
llm_config = get_llm_config_arg(args.llm_config)
if llm_config is None:
raise ValueError(f'Could not find LLM config: --llm_config {args.llm_config}')
llm_config.modify_params = False
# Load dataset
# We need to load tasks from Tau-Bench
# Since we can't import tau_bench yet, we might fail here.
# But I will write the import and let the user/system install it.
try:
from tau_bench.envs import get_env # type: ignore
except ImportError:
logger.error(
'Tau-Bench not installed. Please install it via `pip install tau-bench`'
)
# For now, we create a dummy dataset to allow syntax checking
dataset_df = pd.DataFrame(
[
{
'instance_id': '0',
'env': 'retail',
'task_index': 0,
'instruction': 'Test instruction',
}
]
)
else:
# Load tasks from the environment
env = get_env(args.env)
tasks = env.get_tasks()
data = []
for i, task in enumerate(tasks):
data.append(
{
'instance_id': f'{args.env}_{i}',
'env': args.env,
'task_index': i,
'instruction': task.instruction,
'ground_truth': task.actions, # Or whatever ground truth it provides
}
)
dataset_df = pd.DataFrame(data)
metadata = make_metadata(
llm_config=llm_config,
dataset_name=f'tau-bench-{args.env}',
agent_class=args.agent_cls,
max_iterations=args.max_iterations,
eval_note=args.eval_note,
eval_output_dir=args.eval_output_dir,
data_split=args.data_split,
)
output_file = os.path.join(metadata.eval_output_dir, 'output.jsonl')
dataset = prepare_dataset(
dataset_df, output_file=output_file, eval_n_limit=args.eval_n_limit
)
run_evaluation(
dataset=dataset,
metadata=metadata,
output_file=output_file,
num_workers=args.eval_num_workers,
process_instance_func=process_instance,
)
+146
View File
@@ -0,0 +1,146 @@
# Mock Service Worker (MSW) Guide
## Overview
[Mock Service Worker (MSW)](https://mswjs.io/) is an API mocking library that intercepts outgoing network requests at the network level. Unlike traditional mocking that patches `fetch` or `axios`, MSW uses a Service Worker in the browser and direct request interception in Node.js—making mocks transparent to your application code.
We use MSW in this project for:
- **Testing**: Write reliable unit and integration tests without real network calls
- **Development**: Run the frontend with mocked APIs when the backend isn't available or when working on features with pending backend APIs
The same mock handlers work in both environments, so you write them once and reuse everywhere.
## Relevant Files
- `src/mocks/handlers.ts` - Main handler registry that combines all domain handlers
- `src/mocks/*-handlers.ts` - Domain-specific handlers (auth, billing, conversation, etc.)
- `src/mocks/browser.ts` - Browser setup for development mode
- `src/mocks/node.ts` - Node.js setup for tests
- `vitest.setup.ts` - Global test setup with MSW lifecycle hooks
## Development Workflow
### Running with Mocked APIs
```sh
# Run with API mocking enabled
npm run dev:mock
# Run with API mocking + SaaS mode simulation
npm run dev:mock:saas
```
These commands set `VITE_MOCK_API=true` which activates the MSW Service Worker to intercept requests.
> [!NOTE]
> **OSS vs SaaS Mode**
>
> OpenHands runs in two modes:
> - **OSS mode**: For local/self-hosted deployments where users provide their own LLM API keys and configure git providers manually
> - **SaaS mode**: For the cloud offering with billing, managed API keys, and OAuth-based GitHub integration
>
> Use `dev:mock:saas` when working on SaaS-specific features like billing, API key management, or subscription flows.
## Writing Tests
### Service Layer Mocking (Recommended)
For most tests, mock at the service layer using `vi.spyOn`. This approach is explicit, test-scoped, and makes the scenario being tested clear.
```typescript
import { vi } from "vitest";
import SettingsService from "#/api/settings-service/settings-service.api";
const getSettingsSpy = vi.spyOn(SettingsService, "getSettings");
getSettingsSpy.mockResolvedValue({
llm_model: "openai/gpt-4o",
llm_api_key_set: true,
// ... other settings
});
```
Use `mockResolvedValue` for success scenarios and `mockRejectedValue` for error scenarios:
```typescript
getSettingsSpy.mockRejectedValue(new Error("Failed to fetch settings"));
```
### Network Layer Mocking (Advanced)
For tests that need actual network-level behavior (WebSockets, testing retry logic, etc.), use `server.use()` to override handlers per test.
> [!IMPORTANT]
> **Reuse the global server instance** - Don't create new `setupServer()` calls in individual tests. The project already has a global MSW server configured in `vitest.setup.ts` that handles lifecycle (`server.listen()`, `server.resetHandlers()`, `server.close()`). Use `server.use()` to add runtime handlers for specific test scenarios.
```typescript
import { http, HttpResponse } from "msw";
import { server } from "#/mocks/node";
it("should handle server errors", async () => {
server.use(
http.get("/api/my-endpoint", () => {
return new HttpResponse(null, { status: 500 });
}),
);
// ... test code
});
```
For WebSocket testing, see `__tests__/helpers/msw-websocket-setup.ts` for utilities.
## Adding New API Mocks
When adding new API endpoints, create mocks in both places to maintain 1:1 similarity with the backend:
### 1. Add to `src/mocks/` (for development)
Create or update a domain-specific handler file:
```typescript
// src/mocks/my-feature-handlers.ts
import { http, HttpResponse } from "msw";
export const MY_FEATURE_HANDLERS = [
http.get("/api/my-feature", () => {
return HttpResponse.json({
data: "mock response",
});
}),
];
```
Register in `handlers.ts`:
```typescript
import { MY_FEATURE_HANDLERS } from "./my-feature-handlers";
export const handlers = [
// ... existing handlers
...MY_FEATURE_HANDLERS,
];
```
### 2. Mock in tests for specific scenarios
In your test files, spy on the service method to control responses per test case:
```typescript
import { vi } from "vitest";
import MyFeatureService from "#/api/my-feature-service.api";
const spy = vi.spyOn(MyFeatureService, "getData");
spy.mockResolvedValue({ data: "test-specific response" });
```
See `__tests__/routes/llm-settings.test.tsx` for a real-world example of service layer mocking.
> [!TIP]
> For guidance on creating service APIs, see `src/api/README.md`.
## Best Practices
- **Keep mocks close to real API contracts** - Update mocks when backend changes
- **Use service layer mocking for most tests** - It's simpler and more explicit
- **Reserve network layer mocking for integration tests** - WebSockets, retry logic, etc.
- **Export mock data from handler files** - Reuse in tests (e.g., `MOCK_DEFAULT_USER_SETTINGS`)
@@ -0,0 +1,18 @@
import { test, expect, vi } from "vitest";
import axios from "axios";
import V1GitService from "../../src/api/git-service/v1-git-service.api";
vi.mock("axios");
test("getGitChanges throws when response is not an array (dead runtime returns HTML)", async () => {
const htmlResponse = "<!DOCTYPE html><html>...</html>";
vi.mocked(axios.get).mockResolvedValue({ data: htmlResponse });
await expect(
V1GitService.getGitChanges(
"http://localhost:3000/api/conversations/123",
"test-api-key",
"/workspace",
),
).rejects.toThrow("Invalid response from runtime");
});
@@ -0,0 +1,68 @@
import { describe, it, expect, vi } from "vitest";
import { render, screen } from "@testing-library/react";
import React from "react";
import { adaptSystemMessage } from "#/utils/system-message-adapter";
import { EventState } from "#/stores/use-event-store";
import { SystemMessageModal } from "#/components/features/conversation-panel/system-message-modal";
import { ToolsContextMenu } from "#/components/features/controls/tools-context-menu";
const v1Event: EventState["events"] = [
{
id: "v1-id",
timestamp: "2025-12-30T12:00:00Z",
source: "agent",
system_prompt: {
type: "text",
text: "v1 prompt",
},
tools: [
{
type: "function",
function: {
name: "bash",
description: "Execute bash",
parameters: {},
},
},
],
},
];
const adaptedResult = adaptSystemMessage(v1Event);
vi.mock("#/hooks/query/use-active-conversation", () => ({
useActiveConversation: () => ({ data: { conversation_version: "V1" } }),
}));
vi.mock("#/hooks/use-user-providers", () => ({
useUserProviders: () => ({ providers: ["test"] }),
}));
describe("SystemMessage UI Rendering", () => {
it("should render the 'Show Agent Tools' button in the context menu", () => {
render(
<ToolsContextMenu
onClose={() => {}}
onShowSkills={() => {}}
onShowAgentTools={() => {}}
/>,
);
expect(screen.getByTestId("show-agent-tools-button")).toBeInTheDocument();
});
it("should display the adapted v1 system prompt content correctly", () => {
render(
<SystemMessageModal
isOpen
onClose={() => {}}
systemMessage={adaptedResult}
/>,
);
const messageElement = screen.getByText("v1 prompt");
expect(messageElement).toBeDefined();
expect(messageElement).toBeVisible();
});
});
@@ -1,6 +1,7 @@
import { render, screen } from "@testing-library/react";
import { it, describe, expect, vi, beforeEach, afterEach } from "vitest";
import userEvent from "@testing-library/user-event";
import { MemoryRouter } from "react-router";
import { AuthModal } from "#/components/features/waitlist/auth-modal";
// Mock the useAuthUrl hook
@@ -27,11 +28,13 @@ describe("AuthModal", () => {
it("should render the GitHub and GitLab buttons", () => {
render(
<AuthModal
githubAuthUrl="mock-url"
appMode="saas"
providersConfigured={["github", "gitlab"]}
/>,
<MemoryRouter>
<AuthModal
githubAuthUrl="mock-url"
appMode="saas"
providersConfigured={["github", "gitlab"]}
/>
</MemoryRouter>,
);
const githubButton = screen.getByRole("button", {
@@ -49,11 +52,13 @@ describe("AuthModal", () => {
const user = userEvent.setup();
const mockUrl = "https://github.com/login/oauth/authorize";
render(
<AuthModal
githubAuthUrl={mockUrl}
appMode="saas"
providersConfigured={["github"]}
/>,
<MemoryRouter>
<AuthModal
githubAuthUrl={mockUrl}
appMode="saas"
providersConfigured={["github"]}
/>
</MemoryRouter>,
);
const githubButton = screen.getByRole("button", {
@@ -65,10 +70,14 @@ describe("AuthModal", () => {
});
it("should render Terms of Service and Privacy Policy text with correct links", () => {
render(<AuthModal githubAuthUrl="mock-url" appMode="saas" />);
render(
<MemoryRouter>
<AuthModal githubAuthUrl="mock-url" appMode="saas" />
</MemoryRouter>,
);
// Find the terms of service section using data-testid
const termsSection = screen.getByTestId("auth-modal-terms-of-service");
const termsSection = screen.getByTestId("terms-and-privacy-notice");
expect(termsSection).toBeInTheDocument();
// Check that all text content is present in the paragraph
@@ -105,8 +114,44 @@ describe("AuthModal", () => {
expect(termsSection).toContainElement(privacyLink);
});
it("should display email verified message when emailVerified prop is true", () => {
render(
<MemoryRouter>
<AuthModal
githubAuthUrl="mock-url"
appMode="saas"
emailVerified={true}
/>
</MemoryRouter>,
);
expect(
screen.getByText("AUTH$EMAIL_VERIFIED_PLEASE_LOGIN"),
).toBeInTheDocument();
});
it("should not display email verified message when emailVerified prop is false", () => {
render(
<MemoryRouter>
<AuthModal
githubAuthUrl="mock-url"
appMode="saas"
emailVerified={false}
/>
</MemoryRouter>,
);
expect(
screen.queryByText("AUTH$EMAIL_VERIFIED_PLEASE_LOGIN"),
).not.toBeInTheDocument();
});
it("should open Terms of Service link in new tab", () => {
render(<AuthModal githubAuthUrl="mock-url" appMode="saas" />);
render(
<MemoryRouter>
<AuthModal githubAuthUrl="mock-url" appMode="saas" />
</MemoryRouter>,
);
const tosLink = screen.getByRole("link", {
name: "COMMON$TERMS_OF_SERVICE",
@@ -115,11 +160,58 @@ describe("AuthModal", () => {
});
it("should open Privacy Policy link in new tab", () => {
render(<AuthModal githubAuthUrl="mock-url" appMode="saas" />);
render(
<MemoryRouter>
<AuthModal githubAuthUrl="mock-url" appMode="saas" />
</MemoryRouter>,
);
const privacyLink = screen.getByRole("link", {
name: "COMMON$PRIVACY_POLICY",
});
expect(privacyLink).toHaveAttribute("target", "_blank");
});
describe("Duplicate email error message", () => {
const renderAuthModalWithRouter = (initialEntries: string[]) => {
const hasDuplicatedEmail = initialEntries.includes(
"/?duplicated_email=true",
);
return render(
<MemoryRouter initialEntries={initialEntries}>
<AuthModal
githubAuthUrl="mock-url"
appMode="saas"
providersConfigured={["github"]}
hasDuplicatedEmail={hasDuplicatedEmail}
/>
</MemoryRouter>,
);
};
it("should display error message when duplicated_email query parameter is true", () => {
// Arrange
const initialEntries = ["/?duplicated_email=true"];
// Act
renderAuthModalWithRouter(initialEntries);
// Assert
const errorMessage = screen.getByText("AUTH$DUPLICATE_EMAIL_ERROR");
expect(errorMessage).toBeInTheDocument();
});
it("should not display error message when duplicated_email query parameter is missing", () => {
// Arrange
const initialEntries = ["/"];
// Act
renderAuthModalWithRouter(initialEntries);
// Assert
const errorMessage = screen.queryByText("AUTH$DUPLICATE_EMAIL_ERROR");
expect(errorMessage).not.toBeInTheDocument();
});
});
});
@@ -0,0 +1,30 @@
import React from "react";
import { describe, it, expect, vi } from "vitest";
import { screen } from "@testing-library/react";
import { renderWithProviders } from "test-utils";
import { ConfirmDeleteModal } from "#/components/features/conversation-panel/confirm-delete-modal";
vi.mock("react-i18next", async (importOriginal) => ({
...(await importOriginal<typeof import("react-i18next")>()),
Trans: ({
values,
components,
}: {
values: { title: string };
components: { title: React.ReactElement };
}) => React.cloneElement(components.title, {}, values.title),
}));
describe("ConfirmDeleteModal", () => {
it("should display the conversation title", () => {
renderWithProviders(
<ConfirmDeleteModal
onConfirm={vi.fn()}
onCancel={vi.fn()}
conversationTitle="My Test Conversation"
/>,
);
expect(screen.getByText(/My Test Conversation/)).toBeInTheDocument();
});
});
@@ -14,6 +14,8 @@ import { renderWithProviders } from "test-utils";
import { formatTimeDelta } from "#/utils/format-time-delta";
import { ConversationCard } from "#/components/features/conversation-panel/conversation-card/conversation-card";
import { clickOnEditButton } from "./utils";
import { ConversationCardActions } from "#/components/features/conversation-panel/conversation-card/conversation-card-actions";
import { ConversationStatus } from "#/types/conversation-status";
// We'll use the actual i18next implementation but override the translation function
@@ -431,4 +433,34 @@ describe("ConversationCard", () => {
expect(screen.queryByTestId("ellipsis-button")).not.toBeInTheDocument();
});
const statusTable: [ConversationStatus, boolean][] = [
["RUNNING", true],
["STARTING", true],
["STOPPED", false],
["ARCHIVED", false],
["ERROR", false],
];
it.each(statusTable)(
"should toggle stop button visibility correctly for status",
(status, shouldShow) => {
renderWithProviders(
<ConversationCardActions
contextMenuOpen={true}
onContextMenuToggle={vi.fn()}
onStop={vi.fn()}
conversationStatus={status}
/>,
);
const stopButton = screen.queryByTestId("stop-button");
if (shouldShow) {
expect(stopButton).toBeInTheDocument();
} else {
expect(stopButton).not.toBeInTheDocument();
}
},
);
});
@@ -16,6 +16,13 @@ vi.mock("#/hooks/mutation/use-unified-stop-conversation", () => ({
}),
}));
// Mock toast handlers to prevent unhandled rejection errors
vi.mock("#/utils/custom-toast-handlers", () => ({
displaySuccessToast: vi.fn(),
displayErrorToast: vi.fn(),
TOAST_OPTIONS: {},
}));
describe("ConversationPanel", () => {
const onCloseMock = vi.fn();
const RouterStub = createRoutesStub([
@@ -23,6 +30,11 @@ describe("ConversationPanel", () => {
Component: () => <ConversationPanel onClose={onCloseMock} />,
path: "/",
},
{
// Add route to prevent "No routes matched location" warning
Component: () => null,
path: "/conversations/:conversationId",
},
]);
const renderConversationPanel = () => renderWithProviders(<RouterStub />);
@@ -629,12 +641,6 @@ describe("ConversationPanel", () => {
);
updateConversationSpy.mockResolvedValue(true);
// Mock the toast function
const mockToast = vi.fn();
vi.mock("#/utils/custom-toast-handlers", () => ({
displaySuccessToast: mockToast,
}));
renderConversationPanel();
const cards = await screen.findAllByTestId("conversation-card");
@@ -762,10 +768,6 @@ describe("ConversationPanel", () => {
);
updateConversationSpy.mockRejectedValue(new Error("API Error"));
vi.mock("#/utils/custom-toast-handlers", () => ({
displayErrorToast: vi.fn(),
}));
renderConversationPanel();
const cards = await screen.findAllByTestId("conversation-card");
@@ -883,4 +885,83 @@ describe("ConversationPanel", () => {
title: "Special @#$%^&*()_+ Characters",
});
});
it("should close delete modal when clicking backdrop", async () => {
const user = userEvent.setup();
renderConversationPanel();
const cards = await screen.findAllByTestId("conversation-card");
// Open context menu and click delete
const ellipsisButton = within(cards[0]).getByTestId("ellipsis-button");
await user.click(ellipsisButton);
const deleteButton = within(cards[0]).getByTestId("delete-button");
await user.click(deleteButton);
// Modal should be visible
expect(
screen.getByRole("button", { name: /confirm/i }),
).toBeInTheDocument();
// Click the backdrop (the dark overlay behind the modal)
const backdrop = document.querySelector(".bg-black.opacity-60");
expect(backdrop).toBeInTheDocument();
await user.click(backdrop!);
// Modal should be closed
expect(
screen.queryByRole("button", { name: /confirm/i }),
).not.toBeInTheDocument();
});
it("should close stop modal when clicking backdrop", async () => {
const user = userEvent.setup();
// Create mock data with a RUNNING conversation
const mockRunningConversations: Conversation[] = [
{
conversation_id: "1",
title: "Running Conversation",
selected_repository: null,
git_provider: null,
selected_branch: null,
last_updated_at: "2021-10-01T12:00:00Z",
created_at: "2021-10-01T12:00:00Z",
status: "RUNNING" as const,
runtime_status: null,
url: null,
session_api_key: null,
},
];
vi.spyOn(ConversationService, "getUserConversations").mockResolvedValue({
results: mockRunningConversations,
next_page_id: null,
});
renderConversationPanel();
const cards = await screen.findAllByTestId("conversation-card");
// Open context menu and click stop
const ellipsisButton = within(cards[0]).getByTestId("ellipsis-button");
await user.click(ellipsisButton);
const stopButton = within(cards[0]).getByTestId("stop-button");
await user.click(stopButton);
// Modal should be visible
expect(
screen.getByRole("button", { name: /confirm/i }),
).toBeInTheDocument();
// Click the backdrop
const backdrop = document.querySelector(".bg-black.opacity-60");
expect(backdrop).toBeInTheDocument();
await user.click(backdrop!);
// Modal should be closed
expect(
screen.queryByRole("button", { name: /confirm/i }),
).not.toBeInTheDocument();
});
});

Some files were not shown because too many files have changed in this diff Show More