Compare commits

...

240 Commits

Author SHA1 Message Date
openhands
c7c2029a0e Fix StoredConversationMetadata import issue causing AttributeError
- Replace problematic lazy import pattern with direct imports
- Fix 'NoneType' object has no attribute 'conversation_id' error
- Update imports in saas_app_conversation_info_injector.py and related files
- Simplify test file by removing patch workaround
- All imports now work correctly and linting passes

Co-authored-by: openhands <openhands@all-hands.dev>
2025-12-10 06:44:49 +00:00
Chuck Butkus
997371aed7 Another fix 2025-12-10 01:27:46 -05:00
Chuck Butkus
a1cb0d75af Revert "Try one more thing"
This reverts commit 0c7b4573c9.
2025-12-10 01:02:13 -05:00
Chuck Butkus
0c7b4573c9 Try one more thing 2025-12-10 00:45:51 -05:00
Chuck Butkus
64e4ef1b15 test fixes 2025-12-10 00:25:35 -05:00
Chuck Butkus
b34c89c0f8 Lint fixes 2025-12-09 23:24:12 -05:00
openhands
d5734a8d0c Fix test failures in enterprise/tests/unit/server/test_event_webhook.py
- Fixed session_maker mocking by directly patching the module-level variable
- Updated all failing tests to properly mock the database session
- Fixed TestUpdateConversationMetadata tests to use correct session_maker
- Fixed TestOnWrite::test_on_write_metadata_success to use correct session_maker
- Fixed TestProcessBatchOperationsBackground tests to use correct session_maker
- All 33 tests in test_event_webhook.py now pass

The main issue was that session_maker is imported directly from storage.database
at module import time, so patching 'storage.database.session_maker' wasn't
effective. Instead, we now directly patch the module-level variable in the
conversation_callback_utils module.

Co-authored-by: openhands <openhands@all-hands.dev>
2025-12-10 04:20:51 +00:00
Chuck Butkus
e760c182dc Lint fixes 2025-12-09 22:33:41 -05:00
openhands
71009298af Fix enterprise test failures by mocking StoredConversationMetadata lazy imports
- Fixed NoneType errors in conversation store, SQL app conversation info service,
  conversation callback processor, and event webhook tests
- Added proper mocking of StoredConversationMetadata lazy import to use actual
  OpenHands core class instead of None
- Fixed UserStore.get_user_by_id mocking in conversation store tests
- All previously failing tests now pass (23 tests verified)

Co-authored-by: openhands <openhands@all-hands.dev>
2025-12-10 03:27:09 +00:00
Chuck Butkus
48f08cab0e Lint fixes 2025-12-09 22:06:22 -05:00
openhands
475e96c314 Fix circular import error between enterprise and core modules
- Move ApiKeyStore import to lazy loading in enterprise/server/mcp/mcp_config.py
- Implement lazy import mechanism in enterprise/storage/stored_conversation_metadata.py using __getattr__
- Move UserContext import to TYPE_CHECKING block in openhands/app_server/app_conversation/sql_app_conversation_info_service.py

This resolves the circular import chain:
user_context → user_models → provider → events → stream → io → json → llm → config → mcp_config → server.mcp.mcp_config → storage → stored_conversation_metadata → sql_app_conversation_info_service → user_context

Co-authored-by: openhands <openhands@all-hands.dev>
2025-12-10 02:41:27 +00:00
Chuck Butkus
c5dda5d0d7 Fix tests 2025-12-09 21:12:13 -05:00
openhands
63086831cb Fix circular import in openhands.events.serialization.event
The circular import was caused by openhands.events.serialization.event
importing openhands.llm.metrics at module level, which eventually led
back to openhands.events through the config system.

Changes:
- Remove module-level import of openhands.llm.metrics classes
- Add lazy import in event_from_dict function where metrics are used
- Preserve all existing functionality while breaking the import cycle

This fixes the second circular import in the chain:
events.serialization.event → llm.metrics → config → storage → events

Co-authored-by: openhands <openhands@all-hands.dev>
2025-12-10 02:05:49 +00:00
openhands
0d163bf1ce Fix circular import in openhands.events.event
The circular import was caused by openhands.events.event importing
openhands.llm.metrics at module level, which eventually led back to
openhands.events.event through the config system.

Changes:
- Move Metrics import to TYPE_CHECKING block for type annotations
- Add lazy import in llm_metrics property getter for runtime usage
- Use forward references in type annotations
- Preserve all existing functionality while breaking the import cycle

Fixes the ImportError: cannot import name 'Event' from partially
initialized module 'openhands.events.event' error.

Co-authored-by: openhands <openhands@all-hands.dev>
2025-12-10 01:41:23 +00:00
openhands
6649be08a7 Fix circular import issue by extracting RecallType to standalone module
- Created new module openhands/events/recall_type.py with RecallType enum
- Removed RecallType from openhands/events/event.py to break circular dependency
- Updated all import statements across 13 files to use new module path
- Resolves circular import chain: sync/enrich_user_interaction_data.py ->
  integrations.github.data_collector -> ... -> openhands.events.event ->
  openhands.llm.metrics -> ... -> storage.conversation_callback ->
  openhands.events.observation.agent -> openhands.events.event (circular)

The RecallType enum now has minimal dependencies and can be imported
without triggering the heavy dependency chain that caused the circular import.

Co-authored-by: openhands <openhands@all-hands.dev>
2025-12-10 00:46:31 +00:00
Chuck Butkus
ee0c1a1c2f Revert "Fix circular reference"
This reverts commit a2d61e0eb6.
2025-12-09 19:34:14 -05:00
Chuck Butkus
d3c002aee5 Fix unit tests 2025-12-09 19:11:45 -05:00
Chuck Butkus
a2d61e0eb6 Fix circular reference 2025-12-09 19:00:48 -05:00
Chuck Butkus
fab75ab33d Fix circular reference 2025-12-09 18:50:33 -05:00
Chuck Butkus
a8c4fc5318 Fix SQL migration to work with both SQLLite and Postgres 2025-12-09 18:30:58 -05:00
Chuck Butkus
1647a2466f Update to what is on main branch 2025-12-09 17:11:01 -05:00
Chuck Butkus
fe5b4bb34c Refactor to internal method 2025-12-09 15:21:48 -05:00
openhands
fb0bfd3684 Fix stripe_service tests to handle call_sync_from_async usage
- Updated test database schema to include all required tables (user, org, org_member, role, stripe_customer)
- Fixed test fixtures to use unified Base and create proper table relationships
- Updated test mocking to properly handle call_sync_from_async calls in find_customer_id_by_user_id and find_or_create_customer_by_user_id methods
- All tests now pass successfully after the stripe_service.py changes

Co-authored-by: openhands <openhands@all-hands.dev>
2025-12-09 19:49:47 +00:00
Chuck Butkus
9f5c2327ec Fix merge and some cleanup 2025-12-09 13:16:37 -05:00
chuckbutkus
1864cf9b7a Merge branch 'main' into migrate-org-db-litellm-from-deploy 2025-12-09 12:26:32 -05:00
Chuck Butkus
9ecf2c7e85 Add owner role 2025-12-09 11:59:42 -05:00
Chuck Butkus
e8d89d9a55 Fix OSS unit tests 2025-12-08 22:25:06 -05:00
Chuck Butkus
d33c405ed5 Fix unit tests 2025-12-08 22:15:28 -05:00
Chuck Butkus
3db4d3210d More lint 2025-12-08 21:43:19 -05:00
Chuck Butkus
5ca5bbf3f0 Fix some unit tests 2025-12-08 21:37:23 -05:00
Chuck Butkus
b97a4fdee9 More lint fixes 2025-12-08 21:13:32 -05:00
Chuck Butkus
36a135b942 More lint fixes 2025-12-08 21:09:06 -05:00
Chuck Butkus
7dff779fce Lint fixes 2025-12-08 20:56:27 -05:00
chuckbutkus
f40954f39e Merge branch 'main' into migrate-org-db-litellm-from-deploy 2025-12-08 15:59:44 -05:00
Chuck Butkus
00797cd8a1 Add v1_enabled field 2025-12-08 14:12:46 -05:00
Chuck Butkus
be5cd4c818 Fix migration 2025-12-08 13:54:29 -05:00
chuckbutkus
297140e727 Merge branch 'main' into migrate-org-db-litellm-from-deploy 2025-12-08 13:38:55 -05:00
chuckbutkus
7e5942c2c1 Merge branch 'main' into migrate-org-db-litellm-from-deploy 2025-12-05 03:17:45 -05:00
Chuck Butkus
1d3ed8f6fa Merge branch 'main' into migrate-org-db-litellm-from-deploy 2025-12-04 14:35:48 -05:00
Chuck Butkus
1aec00e92a Merge fixes 2025-12-04 13:53:46 -05:00
Chuck Butkus
517a8c3d9b Merge branch 'main' into migrate-org-db-litellm-from-deploy 2025-12-03 23:01:30 -05:00
Chuck Butkus
036ef85e9d Add metadata on create user 2025-12-03 22:53:07 -05:00
Chuck Butkus
44ef2012df Cleanup 2025-12-03 22:53:07 -05:00
Chuck Butkus
cd765937f5 Change to update user and keys in LiteLLM 2025-12-03 22:53:07 -05:00
chuckbutkus
dec0f411db Merge branch 'main' into migrate-org-db-litellm-from-deploy 2025-11-23 17:50:54 -05:00
chuckbutkus
93edf56824 Merge branch 'main' into migrate-org-db-litellm-from-deploy 2025-11-20 21:58:01 -05:00
Chuck Butkus
77db0cda60 Fix unit tests 2025-11-19 22:29:38 -05:00
Chuck Butkus
d2ff260e39 Migrate byor key 2025-11-19 22:04:20 -05:00
Chuck Butkus
3c59371cbf Fix DB migration 2025-11-19 20:49:50 -05:00
chuckbutkus
8d4095e20e Merge branch 'main' into migrate-org-db-litellm-from-deploy 2025-11-19 20:46:10 -05:00
Chuck Butkus
869677c107 Fix unit tests 2025-11-19 20:45:47 -05:00
Chuck Butkus
e3aad64ee6 Lint fixes 2025-11-19 18:54:46 -05:00
Chuck Butkus
0422ac7ffd Handle old keys 2025-11-19 16:50:32 -05:00
Chuck Butkus
a8f7ff5142 Fix sync from async calls 2025-11-19 16:13:40 -05:00
Chuck Butkus
016761471a Revert "Fix async routine to handle being in a loop already"
This reverts commit 6e61f0617a.
2025-11-19 15:59:50 -05:00
Chuck Butkus
6e61f0617a Fix async routine to handle being in a loop already 2025-11-19 01:26:39 -05:00
Chuck Butkus
a456be6d7b Fix migrations 2025-11-19 00:47:49 -05:00
Chuck Butkus
a89d66f934 Merge main into branch 2025-11-18 14:24:00 -05:00
Chuck Butkus
ff170ecee8 Fix migration 2025-11-18 14:07:06 -05:00
Chuck Butkus
96e27a8997 Fix unit tests 2025-11-18 02:49:02 -05:00
Chuck Butkus
4da310848c Fix unit tests 2025-11-18 02:37:35 -05:00
Chuck Butkus
d79a9b0764 Fix circular reference 2025-11-18 02:15:00 -05:00
Chuck Butkus
80336b71d6 Fix test and migration 2025-11-18 01:58:42 -05:00
Chuck Butkus
a11fbda85e Actually update the key 2025-11-18 01:46:39 -05:00
Chuck Butkus
2b73238a45 Fix setting migration 2025-11-18 00:24:24 -05:00
Chuck Butkus
a8988a9564 Update encryption 2025-11-17 22:00:32 -05:00
Chuck Butkus
6d5dc76536 Update encryption and merge changes 2025-11-17 13:54:16 -05:00
chuckbutkus
104e21f501 Merge branch 'main' into migrate-org-db-litellm-from-deploy 2025-11-17 13:22:55 -05:00
Chuck Butkus
373d7e7708 Fix count method 2025-11-12 23:59:44 -05:00
Chuck Butkus
b9533a2811 Lint fixes 2025-11-12 23:31:32 -05:00
Chuck Butkus
7b8951a761 Fix lint errors 2025-11-12 23:20:17 -05:00
Chuck Butkus
cbe234d5be Merge branch 'main' into migrate-org-db-litellm-from-deploy 2025-11-12 23:08:34 -05:00
Chuck Butkus
e392d1e7b3 Review fixes 2025-11-12 22:35:53 -05:00
Chuck Butkus
16fc633b90 Fix defaults 2025-11-12 21:51:18 -05:00
Chuck Butkus
fb418448b8 Fix count query 2025-11-12 16:33:26 -05:00
Chuck Butkus
8e3c6756ad Misc fixes 2025-11-12 14:38:34 -05:00
Chuck Butkus
61b8b06ec8 FIx OSS and SAAS conversation_metadata deletes 2025-11-12 14:23:09 -05:00
Chuck Butkus
3cdc3d5df0 Fix to keep user_id until we are done migrating users 2025-11-12 13:45:29 -05:00
Chuck Butkus
179e7dfaf1 Add migration to remove fields from sqllite DB 2025-11-11 13:57:33 -05:00
chuckbutkus
de21bb5740 Merge branch 'main' into migrate-org-db-litellm-from-deploy 2025-11-11 13:02:23 -05:00
chuckbutkus
f9e99b337e Merge branch 'main' into migrate-org-db-litellm-from-deploy 2025-11-11 12:43:22 -05:00
Chuck Butkus
49d65992fd Reapply "Another reference fix"
This reverts commit 8e94924aba.
2025-11-10 21:39:40 -05:00
Chuck Butkus
ee62a86ad8 Reapply "More reference fixes"
This reverts commit 7646cabc53.
2025-11-10 21:39:10 -05:00
Chuck Butkus
0c7d5d4dcd Reapply "More reference fixes"
This reverts commit 85d867e9af.
2025-11-10 21:38:53 -05:00
Chuck Butkus
18cb38e535 Reapply "More reference fixes"
This reverts commit a0707d5fa2.
2025-11-10 21:38:36 -05:00
Chuck Butkus
37bf855027 Reapply "More circular references"
This reverts commit 65fc2d2d50.
2025-11-10 21:38:24 -05:00
Chuck Butkus
5894b48c3d Reapply "Fix circular reference"
This reverts commit bfa4c51ca0.
2025-11-10 21:38:00 -05:00
Chuck Butkus
7fd9704d66 Revert "Fix circular reference in provider.py"
This reverts commit bbc525260c.
2025-11-10 21:37:47 -05:00
Chuck Butkus
c2d6bd8623 Revert "Another circular reference fix"
This reverts commit 139d46feff.
2025-11-10 21:37:33 -05:00
Chuck Butkus
139d46feff Another circular reference fix 2025-11-10 21:18:52 -05:00
chuckbutkus
45b28cb4ae Merge branch 'main' into migrate-org-db-litellm-from-deploy 2025-11-10 20:52:38 -05:00
Chuck Butkus
bbc525260c Fix circular reference in provider.py 2025-11-10 20:52:18 -05:00
Chuck Butkus
bfa4c51ca0 Revert "Fix circular reference"
This reverts commit 9f8ca567af.
2025-11-10 15:09:25 -05:00
Chuck Butkus
65fc2d2d50 Revert "More circular references"
This reverts commit 64b7ca3faf.
2025-11-10 15:09:04 -05:00
Chuck Butkus
a0707d5fa2 Revert "More reference fixes"
This reverts commit eead092e91.
2025-11-10 15:08:45 -05:00
Chuck Butkus
85d867e9af Revert "More reference fixes"
This reverts commit bb2012b768.
2025-11-10 15:08:12 -05:00
Chuck Butkus
7646cabc53 Revert "More reference fixes"
This reverts commit 26540e8be1.
2025-11-10 15:07:57 -05:00
Chuck Butkus
8e94924aba Revert "Another reference fix"
This reverts commit 26d137c2c3.
2025-11-10 15:07:29 -05:00
chuckbutkus
fb9aa6f76c Merge branch 'main' into migrate-org-db-litellm-from-deploy 2025-11-10 15:05:50 -05:00
Chuck Butkus
591d32d98a Better circular ref fix and remove extraneous code 2025-11-10 14:56:41 -05:00
chuckbutkus
8491c38797 Merge branch 'main' into migrate-org-db-litellm-from-deploy 2025-11-07 15:28:32 -05:00
Chuck Butkus
d66ced3acc More fixes 2025-11-07 15:24:16 -05:00
Chuck Butkus
de91bc86a5 Fix DB migration 2025-11-07 00:26:32 -05:00
Chuck Butkus
26d137c2c3 Another reference fix 2025-11-07 00:21:12 -05:00
Chuck Butkus
26540e8be1 More reference fixes 2025-11-07 00:07:00 -05:00
Chuck Butkus
bb2012b768 More reference fixes 2025-11-06 23:47:13 -05:00
Chuck Butkus
eead092e91 More reference fixes 2025-11-06 23:25:00 -05:00
Chuck Butkus
64b7ca3faf More circular references 2025-11-06 22:55:33 -05:00
Chuck Butkus
9f8ca567af Fix circular reference 2025-11-06 22:39:35 -05:00
chuckbutkus
617ea40d00 Merge branch 'main' into migrate-org-db-litellm-from-deploy 2025-11-06 20:45:08 -05:00
chuckbutkus
943ab53efa Merge branch 'main' into migrate-org-db-litellm-from-deploy 2025-11-06 12:33:21 -05:00
Chuck Butkus
2422b1df97 Merge branch 'main' into migrate-org-db-litellm-from-deploy 2025-11-06 00:51:59 -05:00
Chuck Butkus
9ec47a803f More migration fixes 2025-11-06 00:25:54 -05:00
Chuck Butkus
021d319db9 Fix migration queries 2025-11-05 16:59:21 -05:00
Chuck Butkus
e82c8d12c2 Make sure to migrate all tables 2025-11-05 13:13:06 -05:00
chuckbutkus
081db2b6b4 Merge branch 'main' into migrate-org-db-litellm-from-deploy 2025-11-05 12:48:40 -05:00
chuckbutkus
ac30a73947 Merge branch 'main' into migrate-org-db-litellm-from-deploy 2025-11-04 23:09:51 -05:00
Chuck Butkus
2f80c468ff Lint fixes 2025-11-04 22:58:43 -05:00
openhands
78b05bf008 Fix test_user_isolation by adding mock for User query
- Added mock for User query in save_app_conversation_info() method
- Mock returns a User object with user_id and org_id the same as user_id_uuid
- Handles both UUID formats (with and without dashes) from SQLAlchemy compilation
- Allows other database queries to pass through normally
- Fixes AssertionError in test_user_isolation test

Co-authored-by: openhands <openhands@all-hands.dev>
2025-11-05 03:54:39 +00:00
openhands
2fd9cbf8f2 Fix tests for classes updated with org_id usage
- Updated ApiKeyStore tests to mock UserStore.get_user_by_id calls
- Added mock_user fixture with current_org_id for org filtering
- Updated SaasSecretsStore tests to mock UserStore calls and handle org_id filtering
- Added tests for retrieve_mcp_api_key functionality
- All tests now properly handle the new org_id-based filtering introduced in commit 69186bc6c

Co-authored-by: openhands <openhands@all-hands.dev>
2025-11-04 21:07:06 +00:00
chuckbutkus
dce575fa2d Merge branch 'main' into migrate-org-db-litellm-from-deploy 2025-11-04 15:49:21 -05:00
Chuck Butkus
69186bc6c8 Add org_id use in queries 2025-11-04 15:19:41 -05:00
Chuck Butkus
d61b47a134 Fix lint errors 2025-11-04 14:21:21 -05:00
chuckbutkus
22a3564939 Merge branch 'main' into migrate-org-db-litellm-from-deploy 2025-11-04 14:15:59 -05:00
Chuck Butkus
61e607fb37 Fix save app metadata 2025-11-04 14:15:05 -05:00
Chuck Butkus
544a7b08cd Fix lint errors 2025-11-04 00:57:06 -05:00
openhands
99691a6103 Add comprehensive unit tests for LiteLlmManager class
- Created test_lite_llm_manager.py with 24 test cases covering all methods
- Tests include create_entries, migrate_entries, and update_team_and_users_budget
- Comprehensive coverage of private HTTP client methods (_create_team, _get_team, etc.)
- Tests for public wrapper methods with HTTP client injection
- Error handling scenarios including missing configuration and API failures
- Mock-based testing to avoid external dependencies
- All tests passing with proper fixtures and async support

Co-authored-by: openhands <openhands@all-hands.dev>
2025-11-04 05:52:58 +00:00
Chuck Butkus
4a22138fff Fix lint errors 2025-11-04 00:41:56 -05:00
Chuck Butkus
92fb3507c9 Remove litellm tests from saas_setttings 2025-11-04 00:39:27 -05:00
Chuck Butkus
73d06b2919 Fix lint error 2025-11-03 22:20:08 -05:00
chuckbutkus
459f999175 Merge branch 'main' into migrate-org-db-litellm-from-deploy 2025-11-03 22:17:16 -05:00
Chuck Butkus
e9fe3dcb3b Migration and test fixes 2025-11-03 22:13:33 -05:00
Chuck Butkus
c998a4da68 Fix migration 2025-11-03 21:21:37 -05:00
chuckbutkus
ddf45d9b1d Merge branch 'main' into migrate-org-db-litellm-from-deploy 2025-11-03 15:02:51 -05:00
chuckbutkus
8f62a97a26 Merge branch 'main' into migrate-org-db-litellm-from-deploy 2025-11-03 13:40:42 -05:00
chuckbutkus
ee66151692 Merge branch 'main' into migrate-org-db-litellm-from-deploy 2025-11-03 13:15:13 -05:00
chuckbutkus
e6dc590ef1 Merge branch 'main' into migrate-org-db-litellm-from-deploy 2025-11-03 12:49:21 -05:00
Chuck Butkus
36e2e5942a Add back user isolation test 2025-11-03 12:43:22 -05:00
chuckbutkus
a6096d0b46 Merge branch 'main' into migrate-org-db-litellm-from-deploy 2025-11-03 11:46:33 -05:00
openhands
f107e21d26 Create tests for SaasSQLAppConversationInfoService and move user isolation test
- Created new test file enterprise/tests/unit/storage/test_saas_sql_app_conversation_info_service.py
- Added comprehensive test suite for SaasSQLAppConversationInfoService with 4 tests:
  * test_service_initialization: Verifies proper service initialization
  * test_user_context_isolation: Tests user context isolation between different service instances
  * test_secure_select_includes_user_filtering: Validates _secure_select method functionality
  * test_to_info_with_user_id_functionality: Tests user_id override from SAAS metadata
- Moved test_user_isolation from TestSQLAppConversationInfoService to new SAAS test class
- Fixed UUID string conversion issues in SaasSQLAppConversationInfoService
- Updated all user_id handling to properly convert string to UUID for database operations
- All tests pass: 4 new SAAS tests + 17 existing original tests

Co-authored-by: openhands <openhands@all-hands.dev>
2025-10-31 18:16:59 +00:00
chuckbutkus
516591c012 Merge branch 'main' into migrate-org-db-litellm-from-deploy 2025-10-29 23:21:40 -04:00
Chuck Butkus
9efb67a3bd Add more user_id handling to convo info 2025-10-29 17:46:10 -04:00
Chuck Butkus
c5ef7a5944 Update to secure_select 2025-10-29 16:10:30 -04:00
openhands
20366ba973 feat: Enable enterprise SQLAppConversationInfoService override in SAAS mode
- Add SaasAppConversationInfoServiceInjector to properly inject enterprise service
- Modify base config to use enterprise injector when OPENHANDS_CONFIG_CLS contains 'saas'
- Ensure OPENHANDS_CONFIG_CLS is set in saas_server.py for proper SAAS mode detection
- Clean up stored_conversation_metadata.py imports and exports

This ensures that when launching the enterprise server with uvicorn saas_server:app,
the overridden _secure_select() method with user-based filtering is used instead
of the base OSS implementation.

Co-authored-by: openhands <openhands@all-hands.dev>
2025-10-29 19:16:34 +00:00
Chuck Butkus
df03a56888 Add user_id check on enterprise 2025-10-29 14:55:50 -04:00
chuckbutkus
d202c90f5f Merge branch 'main' into migrate-org-db-litellm-from-deploy 2025-10-29 14:17:47 -04:00
Chuck Butkus
7addb78158 Fix another test 2025-10-29 00:41:06 -04:00
Chuck Butkus
8afa6cf51b Lint fixes 2025-10-28 23:12:08 -04:00
Chuck Butkus
1289688b64 Fix unit tests 2025-10-28 23:10:43 -04:00
Chuck Butkus
e349d37b8c Update to latest poetry version 2025-10-28 20:15:48 -04:00
Chuck Butkus
6fec7b729d Merge branch 'main' into migrate-org-db-litellm-from-deploy 2025-10-28 17:06:24 -04:00
chuckbutkus
cd05434d7f Merge branch 'main' into migrate-org-db-litellm-from-deploy 2025-10-28 14:54:17 -04:00
Chuck Butkus
9e7b74ea32 Update 2025-10-28 14:43:26 -04:00
openhands
4646439108 Separate SaaS-specific fields from StoredConversationMetadata
- Create new ConversationMetadataSaas model with conversation_id, user_id, org_id
- Remove github_user_id, user_id, org_id from StoredConversationMetadata
- Update all enterprise clients to use ConversationMetadataSaas for user/org lookups
- Add database migration to create new table and migrate existing data
- Maintain backward compatibility in OpenHands core components

Co-authored-by: openhands <openhands@all-hands.dev>
2025-10-27 23:46:27 +00:00
rohitvinodmalhotra@gmail.com
f89e41ac30 fix migration 2025-10-27 13:44:28 -04:00
rohitvinodmalhotra@gmail.com
9b0029c5bb Merge branch 'main' into migrate-org-db-litellm-from-deploy 2025-10-27 13:42:50 -04:00
rohitvinodmalhotra@gmail.com
3f247952fa Merge branch 'main' into migrate-org-db-litellm-from-deploy 2025-10-27 13:41:35 -04:00
rohitvinodmalhotra@gmail.com
dc360c8a5c fix extraneous change 2025-10-27 11:00:13 -04:00
Rohit Malhotra
5f06aad131 Merge branch 'main' into migrate-org-db-litellm-from-deploy 2025-10-24 13:04:28 -04:00
rohitvinodmalhotra@gmail.com
26ca1cf2d7 fix lint 2025-10-24 13:03:29 -04:00
rohitvinodmalhotra@gmail.com
75c9a09ad1 fix lint 2025-10-24 13:01:32 -04:00
rohitvinodmalhotra@gmail.com
139a5f7caf Update test_billing.py 2025-10-24 13:00:55 -04:00
rohitvinodmalhotra@gmail.com
4caa72d080 fix tests 2025-10-24 12:53:28 -04:00
rohitvinodmalhotra@gmail.com
2f2a1c5c58 fix tests 2025-10-24 12:42:09 -04:00
rohitvinodmalhotra@gmail.com
37e0f7fd6e Update test_conversation_callback_processor.py 2025-10-24 12:37:42 -04:00
rohitvinodmalhotra@gmail.com
b012176c9c fix tests 2025-10-24 12:29:27 -04:00
rohitvinodmalhotra@gmail.com
a5e1a9fd99 fix tests 2025-10-24 12:20:22 -04:00
rohitvinodmalhotra@gmail.com
0b0d77bcdf fix tests 2025-10-24 12:13:10 -04:00
rohitvinodmalhotra@gmail.com
3791a76216 fix failing tests 2025-10-24 12:06:17 -04:00
rohitvinodmalhotra@gmail.com
b921f06e2b fix tests 2025-10-24 11:49:07 -04:00
rohitvinodmalhotra@gmail.com
07b8391605 rm user version update 2025-10-24 11:29:53 -04:00
rohitvinodmalhotra@gmail.com
2ec03b8c55 Update test_org_store.py 2025-10-24 11:25:50 -04:00
rohitvinodmalhotra@gmail.com
8beb9b4638 fix test 2025-10-23 11:42:28 -04:00
openhands
b40f55a328 Add all SQLAlchemy storage models to enterprise/storage/__init__.py
- Added all 36 SQLAlchemy models that inherit from Base
- Added relevant enum classes (BillingSessionType, SubscriptionAccessStatus, etc.)
- Fixed missing comma in __all__ list
- Organized imports alphabetically for better maintainability
- Included StoredConversationMetadata alias from openhands core

This ensures all storage models are properly exposed through the storage module.
2025-10-23 15:36:14 +00:00
rohitvinodmalhotra@gmail.com
4e0d553380 add init for storage models for sqlalchemy registration during unit tests 2025-10-23 10:39:05 -04:00
rohitvinodmalhotra@gmail.com
42c40d75b1 Merge branch 'main' into migrate-org-db-litellm-from-deploy 2025-10-23 10:16:51 -04:00
rohitvinodmalhotra@gmail.com
6e30c62078 simplify 2025-10-23 09:42:11 -04:00
rohitvinodmalhotra@gmail.com
f29161b7f3 rm org migration 2025-10-22 16:47:45 -04:00
rohitvinodmalhotra@gmail.com
7d084db6d7 var for personal workspace version 2025-10-22 16:04:28 -04:00
rohitvinodmalhotra@gmail.com
0ab08e93a6 Merge branch 'main' into migrate-org-db-litellm-from-deploy 2025-10-22 10:54:47 -04:00
openhands
d3586bf820 Fix enterprise unit tests: Update User model attribute references
- Changed keycloak_user_id to id in User object instantiations
- Updated test assertions to use user.id instead of user.keycloak_user_id
- Fixed UUID generation for User.id fields
- Updated query filters to use User.id instead of User.keycloak_user_id
- Added missing uuid imports where needed

Files modified:
- enterprise/tests/unit/test_user_store.py: Fixed 3 test functions
- enterprise/tests/unit/test_org_store.py: Fixed 1 test function
- enterprise/tests/unit/test_org_member_store.py: Fixed 6 test functions
- enterprise/tests/unit/test_models.py: Fixed user creation and query
- enterprise/tests/unit/test_auth_routes.py: Fixed mock object attributes

These changes align the tests with the updated User model schema where
keycloak_user_id has been replaced with a UUID id field.
2025-10-22 14:51:01 +00:00
rohitvinodmalhotra@gmail.com
e3dbb00d4e fix typing 2025-10-22 10:35:50 -04:00
rohitvinodmalhotra@gmail.com
e11b2008f3 fix tests 2025-10-22 10:28:49 -04:00
Rohit Malhotra
a02b5a6c0e Merge branch 'main' into migrate-org-db-litellm-from-deploy 2025-10-22 09:43:24 -04:00
rohitvinodmalhotra@gmail.com
3b3b05dc33 fix comparasion 2025-10-22 09:42:06 -04:00
rohitvinodmalhotra@gmail.com
7d6392f793 rm enterprise local 2025-10-22 09:37:10 -04:00
Rohit Malhotra
ec3c33afac Merge branch 'main' into migrate-org-db-litellm-from-deploy 2025-10-22 09:29:44 -04:00
rohitvinodmalhotra@gmail.com
eb847de7ec Merge branch 'migrate-org-db-litellm-from-deploy' of https://github.com/All-Hands-AI/OpenHands into migrate-org-db-litellm-from-deploy 2025-10-21 16:06:05 -04:00
rohitvinodmalhotra@gmail.com
c3e91baa53 Merge branch 'main' into migrate-org-db-litellm-from-deploy 2025-10-21 16:05:31 -04:00
openhands
d2003c83fb Add downgrade for migration_status column in user_settings
- Drop migration_status column in downgrade() function
- Ensures proper migration rollback capability
2025-10-21 19:59:36 +00:00
openhands
7c0a939d96 Add migration_status boolean to UserSettings for migration tracking
- Add migration_status column to user_settings table in migration script
- Update UserSettings model with migration_status boolean field (default False)
- Add migration check in UserStore to prevent double migration
- Mark migrated records as True instead of hard deletion
- Filter non-migrated records in SaasSettingsStore

This ensures safe migration from user_settings to org-based structure
without data loss and prevents duplicate migrations.
2025-10-21 19:57:41 +00:00
openhands
f45b86a396 Rename OrgUser to OrgMember across enterprise directory
- Renamed database table from org_user to org_member in migration 077
- Renamed OrgUser class to OrgMember in storage model
- Renamed OrgUserStore class to OrgMemberStore
- Updated all import statements and references across the codebase
- Updated relationship references in related models (User, Org, Role)
- Updated foreign key constraint names (ou_* -> om_*)
- Updated method names (get_org_user -> get_org_member, get_org_users -> get_org_members)
- Updated test files to use new naming conventions
- Renamed files: org_user.py -> org_member.py, org_user_store.py -> org_member_store.py

Co-authored-by: openhands <openhands@all-hands.dev>
2025-10-21 17:41:16 +00:00
openhands
d7bf698d1e Remove org_id and relationship from GitlabWebhook table
- Remove org_id column and ForeignKey constraint from GitlabWebhook model
- Remove org relationship from GitlabWebhook model
- Remove gitlab_webhooks relationship from Org model
- Remove gitlab_webhook table modifications from migration 077
- Clean up imports in gitlab_webhook.py (removed ForeignKey, UUID, relationship imports)

Co-authored-by: openhands <openhands@all-hands.dev>
2025-10-21 15:52:50 +00:00
openhands
d655049934 Remove org_id and relationship from UserRepositoryMap table
- Remove org_id column and ForeignKey constraint from UserRepositoryMap model
- Remove org relationship from UserRepositoryMap model
- Remove user_repos relationship from Org model
- Remove user-repos table modifications from migration 077
- Clean up imports in user_repo_map.py (removed ForeignKey, UUID, relationship)

This decouples the user-repos table from the org system as requested.
2025-10-21 15:47:56 +00:00
openhands
6357b46001 Fix SQLAlchemy relationship error between Org and StoredConversationMetadata
- Add ForeignKey import to StoredConversationMetadata model
- Add ForeignKey('org.id') constraint to org_id column
- Uncomment org relationship with back_populates='conversation_metadata'
- Ensures bidirectional relationship works properly with migration 077

Fixes: sqlalchemy.exc.NoForeignKeysError: Could not determine join condition between parent/child tables on relationship Org.conversation_metadata

Co-authored-by: openhands <openhands@all-hands.dev>
2025-10-21 15:42:56 +00:00
Chuck Butkus
186f4423e0 Make org_id nullable for now 2025-10-20 20:33:21 -04:00
Chuck Butkus
baf323a26c Remove exception swallowing 2025-10-17 00:43:54 -04:00
Chuck Butkus
cc7eef9fc0 Fix lint errors 2025-10-17 00:43:54 -04:00
openhands
c9a2a6c17f Fix database schema issues for tests
- Make org_id column nullable to match migration
- Comment out org relationship for tests to avoid foreign key constraint errors
- Add note about org_id column in test file

This resolves SQLAlchemy foreign key constraint errors in unit tests
where the org table doesn't exist in the test environment.
2025-10-17 03:13:22 +00:00
Chuck Butkus
2a857a676f Missed file 2025-10-16 22:19:41 -04:00
Chuck Butkus
cf7096e80d Use same ID for user and personal org to simplify migration 2025-10-16 22:18:42 -04:00
chuckbutkus
cfd27b1dce Merge branch 'main' into migrate-org-db-litellm-from-deploy 2025-10-16 20:39:23 -04:00
rohitvinodmalhotra@gmail.com
c36b628879 Update slack_view.py 2025-10-16 17:53:25 -04:00
rohitvinodmalhotra@gmail.com
a34cc6b7e7 Merge branch 'migrate-org-db-litellm-from-deploy' of https://github.com/All-Hands-AI/OpenHands into migrate-org-db-litellm-from-deploy 2025-10-16 17:52:21 -04:00
rohitvinodmalhotra@gmail.com
d70006717e fix slack 2025-10-16 17:52:13 -04:00
openhands
bf57a3ac6d Fix SQLAlchemy relationship error by adding missing org_id foreign key
- Add org_id column to StoredConversationMetadata model
- Import PostgreSQL UUID type to avoid naming conflicts
- Resolves 'Could not determine join condition' error in org relationships
- Ensures consistency with migration 077_create_org_tables.py

Co-authored-by: openhands <openhands@all-hands.dev>
2025-10-16 21:33:17 +00:00
rohitvinodmalhotra@gmail.com
ffc77fe229 fix migrations 2025-10-16 16:27:55 -04:00
rohitvinodmalhotra@gmail.com
82082fcee3 fix import 2025-10-16 16:25:35 -04:00
chuckbutkus
8d1f8c24f3 Merge branch 'main' into migrate-org-db-litellm-from-deploy 2025-10-15 22:10:26 -04:00
chuckbutkus
0369bc77dd Merge branch 'main' into migrate-org-db-litellm-from-deploy 2025-10-15 20:01:34 -04:00
chuckbutkus
1ef111d954 Merge branch 'main' into migrate-org-db-litellm-from-deploy 2025-10-15 19:47:13 -04:00
rohitvinodmalhotra@gmail.com
69db41aa1d fix org relation 2025-10-15 09:42:24 -04:00
rohitvinodmalhotra@gmail.com
a7118ddda6 fix auth route 2025-10-14 22:37:22 -04:00
rohitvinodmalhotra@gmail.com
86494cdd90 fix tests 2025-10-14 22:21:52 -04:00
rohitvinodmalhotra@gmail.com
101aa68424 rm stored settings ref 2025-10-14 22:17:36 -04:00
rohitvinodmalhotra@gmail.com
47b225d76d Merge branch 'main' into migrate-org-db-litellm-from-deploy 2025-10-14 20:01:01 -04:00
Chuck Butkus
06758d352a Some Lint fixes 2025-10-14 01:35:45 -04:00
Chuck Butkus
6dc6f9514e Update migration and loading settings 2025-10-14 00:49:02 -04:00
Chuck Butkus
08519c2e44 Field changes to org DB structure 2025-10-13 23:16:29 -04:00
Rohit Malhotra
cc1e4b8c4a Merge branch 'main' into migrate-org-db-litellm-from-deploy 2025-10-13 12:11:21 -04:00
rohitvinodmalhotra@gmail.com
0d6ff3ac50 add todo 2025-10-13 11:55:23 -04:00
rohitvinodmalhotra@gmail.com
b15ffa29a5 fix broken migration 2025-10-13 11:54:23 -04:00
rohitvinodmalhotra@gmail.com
5f2ce8e18a Revert "Update agent_chat.py"
This reverts commit 8f90374f49.
2025-10-13 09:31:36 -04:00
rohitvinodmalhotra@gmail.com
8f90374f49 Update agent_chat.py 2025-10-13 09:30:42 -04:00
Chuck Butkus
4c38beb456 Fix user_settings imports 2025-10-13 00:52:43 -04:00
Chuck Butkus
02f009e6b5 Fix running enterprise server locally 2025-10-13 00:51:55 -04:00
rohitvinodmalhotra@gmail.com
fed53185ac fix imports 2025-10-12 21:11:41 -04:00
rohitvinodmalhotra@gmail.com
5cdebc3ed5 rm oh scratch files 2025-10-12 21:07:19 -04:00
rohitvinodmalhotra@gmail.com
947fc2f616 Merge branch 'main' into migrate-org-db-litellm-from-deploy 2025-10-12 21:06:05 -04:00
rohitvinodmalhotra@gmail.com
939242fc22 fix changes 2025-10-12 21:04:20 -04:00
rohitvinodmalhotra@gmail.com
f787f6a089 fix copied changes 2025-10-09 23:35:56 -04:00
rohitvinodmalhotra@gmail.com
f687bcccf7 fix copied changes 2025-10-09 23:34:02 -04:00
rohitvinodmalhotra@gmail.com
ba06aa3c0c fix copied changes 2025-10-09 23:32:50 -04:00
rohitvinodmalhotra@gmail.com
36f516b337 fix copied changes 2025-10-09 23:17:24 -04:00
rohitvinodmalhotra@gmail.com
3d4805f4b1 fix imports 2025-10-09 23:12:09 -04:00
rohitvinodmalhotra@gmail.com
bf178fcc0e revert copied change 2025-10-09 23:10:20 -04:00
openhands
7c41d6f30f Complete migration with corrected import paths and additional files
- Update all import paths from 'openhands.enterprise.*' to 'enterprise.*'
  to match OpenHands repo structure (deploy repo used openhands.enterprise)
- Add comprehensive documentation files (migration guides, structure docs)
- Add example usage files for organizational features
- Add complete test suite for organizational models and stores
- Update all server routes, auth components, integrations, and storage files
- Ensure all cross-references use correct enterprise.* import structure

This completes the migration of organizational database structure from
deploy repo PR #1413 with all import paths corrected for OpenHands repo.
2025-10-07 04:07:38 +00:00
openhands
7906b38ded Fix import path in server config
Update import from 'server.auth.constants' to 'enterprise.server.auth.constants'
to match the new enterprise directory structure.
2025-10-07 03:35:56 +00:00
openhands
d74b0e3fc6 Migrate additional storage files required by tests
- Add conversation_work.py for conversation work tracking
- Add feedback.py for user feedback storage
- Add github_app_installation.py for GitHub app installations
- Add maintenance_task.py for maintenance task processing
- Add stored_offline_token.py for offline token storage
- Update all imports to use enterprise.storage structure

These files are required by the test suite conftest.py for proper
database table creation during testing.
2025-10-07 03:33:09 +00:00
openhands
07b6ce5ed0 Migrate organizational database structure from deploy repo
- Add organizational models: Org, User, Role, OrgUser with proper relationships
- Add corresponding store classes for database operations
- Add encryption utilities for sensitive data handling
- Add LiteLLM manager for organizational LLM configuration
- Add comprehensive migration file for organizational tables
- Update constants with ORG_SETTINGS_VERSION and version mapping
- Fix import paths to use enterprise structure
- Add org_id columns to existing tables for multi-tenancy support

Migrated from deploy repo PR #1413 'Org db litellm' (98 commits)
Resolves conflicts and updates paths for OpenHands repository structure

Co-authored-by: openhands <openhands@all-hands.dev>
2025-10-07 03:09:46 +00:00
95 changed files with 10800 additions and 7047 deletions

View File

@@ -2,7 +2,7 @@ BACKEND_HOST ?= "127.0.0.1"
BACKEND_PORT = 3000
BACKEND_HOST_PORT = "$(BACKEND_HOST):$(BACKEND_PORT)"
FRONTEND_PORT = 3001
OPENHANDS_PATH ?= "../../OpenHands"
OPENHANDS_PATH ?= ".."
OPENHANDS := $(OPENHANDS_PATH)
OPENHANDS_FRONTEND_PATH = $(OPENHANDS)/frontend/build

View File

@@ -23,9 +23,9 @@ from server.auth.constants import GITHUB_APP_CLIENT_ID, GITHUB_APP_PRIVATE_KEY
from server.auth.token_manager import TokenManager
from server.config import get_config
from storage.database import session_maker
from storage.org_store import OrgStore
from storage.proactive_conversation_store import ProactiveConversationStore
from storage.saas_secrets_store import SaasSecretsStore
from storage.saas_settings_store import SaasSettingsStore
from openhands.agent_server.models import SendMessageRequest
from openhands.app_server.app_conversation.app_conversation_models import (
@@ -72,19 +72,17 @@ async def get_user_proactive_conversation_setting(user_id: str | None) -> bool:
if not user_id:
return False
config = get_config()
settings_store = SaasSettingsStore(
user_id=user_id, session_maker=session_maker, config=config
)
settings = await call_sync_from_async(
settings_store.get_user_settings_by_keycloak_id, user_id
)
if not settings or settings.enable_proactive_conversation_starters is None:
# Check global setting first - if disabled globally, return False
if not ENABLE_PROACTIVE_CONVERSATION_STARTERS:
return False
return settings.enable_proactive_conversation_starters
def _get_setting():
org = OrgStore.get_current_org_from_keycloak_user_id(user_id)
if not org:
return False
return bool(org.enable_proactive_conversation_starters)
return await call_sync_from_async(_get_setting)
async def get_user_v1_enabled_setting(user_id: str) -> bool:
@@ -96,19 +94,14 @@ async def get_user_v1_enabled_setting(user_id: str) -> bool:
Returns:
True if V1 conversations are enabled for this user, False otherwise
"""
config = get_config()
settings_store = SaasSettingsStore(
user_id=user_id, session_maker=session_maker, config=config
org = await call_sync_from_async(
OrgStore.get_current_org_from_keycloak_user_id, user_id
)
settings = await call_sync_from_async(
settings_store.get_user_settings_by_keycloak_id, user_id
)
if not settings or settings.v1_enabled is None:
if not org or org.v1_enabled is None:
return False
return settings.v1_enabled
return org.v1_enabled
# =================================================
@@ -166,6 +159,7 @@ class GithubIssue(ResolverViewInterface):
issue_body=self.description,
previous_comments=self.previous_comments,
)
return user_instructions, conversation_instructions
async def _get_user_secrets(self):
@@ -199,6 +193,7 @@ class GithubIssue(ResolverViewInterface):
conversation_trigger=ConversationTrigger.RESOLVER,
git_provider=ProviderType.GITHUB,
)
self.conversation_id = conversation_metadata.conversation_id
return conversation_metadata
@@ -345,7 +340,6 @@ class GithubIssueComment(GithubIssue):
conversation_instructions_template = jinja_env.get_template(
'issue_conversation_instructions.j2'
)
conversation_instructions = conversation_instructions_template.render(
issue_number=self.issue_number,
issue_title=self.title,
@@ -382,8 +376,7 @@ class GithubPRComment(GithubIssueComment):
return user_instructions, conversation_instructions
async def initialize_new_conversation(self) -> ConversationMetadata:
# FIXME: Handle if initialize_conversation returns None
conversation_metadata: ConversationMetadata = await initialize_conversation( # type: ignore[assignment]
conversation_metadata: ConversationMetadata = await initialize_conversation(
user_id=self.user_info.keycloak_user_id,
conversation_id=None,
selected_repository=self.full_repo_name,
@@ -429,7 +422,6 @@ class GithubInlinePRComment(GithubPRComment):
conversation_instructions_template = jinja_env.get_template(
'pr_update_conversation_instructions.j2'
)
conversation_instructions = conversation_instructions_template.render(
pr_number=self.issue_number,
pr_title=self.title,

View File

@@ -167,6 +167,7 @@ class SlackNewConversationView(SlackViewInterface):
'channel_id': self.channel_id,
'conversation_id': self.conversation_id,
'keycloak_user_id': user_info.keycloak_user_id,
'org_id': user_info.org_id,
'parent_id': self.thread_ts or self.message_ts,
},
)
@@ -174,6 +175,7 @@ class SlackNewConversationView(SlackViewInterface):
conversation_id=self.conversation_id,
channel_id=self.channel_id,
keycloak_user_id=user_info.keycloak_user_id,
org_id=user_info.org_id,
parent_id=self.thread_ts
or self.message_ts, # conversations can start in a thread reply as well; we should always references the parent's (root level msg's) message ID
)
@@ -304,10 +306,10 @@ class SlackUpdateExistingConversationView(SlackNewConversationView):
if not agent_state or agent_state == AgentState.LOADING:
raise StartingConvoException('Conversation is still starting')
user_msg, _ = self._get_instructions(jinja)
user_msg_action = MessageAction(content=user_msg)
instructions, _ = self._get_instructions(jinja)
user_msg = MessageAction(content=instructions)
await conversation_manager.send_event_to_conversation(
self.conversation_id, event_to_dict(user_msg_action)
self.conversation_id, event_to_dict(user_msg)
)
return self.conversation_id

View File

@@ -1,19 +1,24 @@
from uuid import UUID
import stripe
from server.auth.token_manager import TokenManager
from server.constants import STRIPE_API_KEY
from server.logger import logger
from sqlalchemy.orm import Session
from storage.database import session_maker
from storage.org import Org
from storage.org_store import OrgStore
from storage.stripe_customer import StripeCustomer
from openhands.utils.async_utils import call_sync_from_async
stripe.api_key = STRIPE_API_KEY
async def find_customer_id_by_user_id(user_id: str) -> str | None:
# First search our own DB...
async def find_customer_id_by_org_id(org_id: UUID) -> str | None:
with session_maker() as session:
stripe_customer = (
session.query(StripeCustomer)
.filter(StripeCustomer.keycloak_user_id == user_id)
.filter(StripeCustomer.org_id == org_id)
.first()
)
if stripe_customer:
@@ -21,46 +26,76 @@ async def find_customer_id_by_user_id(user_id: str) -> str | None:
# If that fails, fallback to stripe
search_result = await stripe.Customer.search_async(
query=f"metadata['user_id']:'{user_id}'",
query=f"metadata['org_id']:'{str(org_id)}'",
)
data = search_result.data
if not data:
logger.info('no_customer_for_user_id', extra={'user_id': user_id})
logger.info(
'no_customer_for_org_id',
extra={'org_id': str(org_id)},
)
return None
return data[0].id # type: ignore [attr-defined]
async def find_or_create_customer(user_id: str) -> str:
customer_id = await find_customer_id_by_user_id(user_id)
if customer_id:
return customer_id
logger.info('creating_customer', extra={'user_id': user_id})
async def find_customer_id_by_user_id(user_id: str) -> str | None:
# First search our own DB...
org = await call_sync_from_async(
OrgStore.get_current_org_from_keycloak_user_id, user_id
)
if not org:
logger.warning(f'Org not found for user {user_id}')
return None
customer_id = await find_customer_id_by_org_id(org.id)
return customer_id
# Get the user info from keycloak
token_manager = TokenManager()
user_info = await token_manager.get_user_info_from_user_id(user_id) or {}
async def find_or_create_customer_by_user_id(user_id: str) -> dict | None:
# Get the current org for the user
org = await call_sync_from_async(
OrgStore.get_current_org_from_keycloak_user_id, user_id
)
if not org:
logger.warning(f'Org not found for user {user_id}')
return None
customer_id = await find_customer_id_by_org_id(org.id)
if customer_id:
return {'customer_id': customer_id, 'org_id': str(org.id)}
logger.info(
'creating_customer',
extra={'user_id': user_id, 'org_id': str(org.id)},
)
# Create the customer in stripe
customer = await stripe.Customer.create_async(
email=str(user_info.get('email', '')),
metadata={'user_id': user_id},
email=org.contact_email,
metadata={'org_id': str(org.id)},
)
# Save the stripe customer in the local db
with session_maker() as session:
session.add(
StripeCustomer(keycloak_user_id=user_id, stripe_customer_id=customer.id)
StripeCustomer(
keycloak_user_id=user_id,
org_id=org.id,
stripe_customer_id=customer.id,
)
)
session.commit()
logger.info(
'created_customer',
extra={'user_id': user_id, 'stripe_customer_id': customer.id},
extra={
'user_id': user_id,
'org_id': str(org.id),
'stripe_customer_id': customer.id,
},
)
return customer.id
return {'customer_id': customer.id, 'org_id': str(org.id)}
async def has_payment_method(user_id: str) -> bool:
async def has_payment_method_by_user_id(user_id: str) -> bool:
customer_id = await find_customer_id_by_user_id(user_id)
if customer_id is None:
return False
@@ -71,3 +106,28 @@ async def has_payment_method(user_id: str) -> bool:
f'has_payment_method:{user_id}:{customer_id}:{bool(payment_methods.data)}'
)
return bool(payment_methods.data)
async def migrate_customer(session: Session, user_id: str, org: Org):
stripe_customer = (
session.query(StripeCustomer)
.filter(StripeCustomer.keycloak_user_id == user_id)
.first()
)
if stripe_customer is None:
return
stripe_customer.org_id = org.id
customer = await stripe.Customer.modify_async(
id=stripe_customer.stripe_customer_id,
email=org.contact_email,
metadata={'user_id': '', 'org_id': str(org.id)},
)
logger.info(
'migrated_customer',
extra={
'user_id': user_id,
'org_id': str(org.id),
'stripe_customer_id': customer.id,
},
)

View File

@@ -20,6 +20,8 @@ down_revision = '059'
branch_labels = None
depends_on = None
# TODO: decide whether to modify this for orgs or users
def upgrade():
"""
@@ -28,8 +30,10 @@ def upgrade():
This replaces the functionality of the removed admin maintenance endpoint.
"""
# Import here to avoid circular imports
from server.constants import CURRENT_USER_SETTINGS_VERSION
# Hardcoded value to prevent migration failures when constant is removed from codebase
# This migration has already run in production, so we use the value that was current at the time
CURRENT_USER_SETTINGS_VERSION = 4
# Create a connection and bind it to a session
connection = op.get_bind()

View File

@@ -0,0 +1,272 @@
"""create org tables from pgerd schema
Revision ID: 084
Revises: 083
Create Date: 2025-01-07 00:00:00.000000
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = '084'
down_revision: Union[str, None] = '083'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Remove current settings table
op.execute('DROP TABLE IF EXISTS settings')
# Add already_migrated column to user_settings table
op.add_column(
'user_settings',
sa.Column(
'already_migrated',
sa.Boolean,
nullable=True,
server_default=sa.text('false'),
),
)
# Create role table
op.create_table(
'role',
sa.Column('id', sa.Integer, sa.Identity(), primary_key=True),
sa.Column('name', sa.String, nullable=False),
sa.Column('rank', sa.Integer, nullable=False),
sa.UniqueConstraint('name', name='role_name_unique'),
)
# 1. Create default roles
op.execute(
sa.text("""
INSERT INTO role (name, rank) VALUES ('owner', 10), ('admin', 20), ('user', 1000)
ON CONFLICT (name) DO NOTHING;
""")
)
# Create org table with settings fields
op.create_table(
'org',
sa.Column(
'id',
postgresql.UUID(as_uuid=True),
primary_key=True,
),
sa.Column('name', sa.String, nullable=False),
sa.Column('contact_name', sa.String, nullable=True),
sa.Column('contact_email', sa.String, nullable=True),
sa.Column('conversation_expiration', sa.Integer, nullable=True),
# Settings fields moved to org table
sa.Column('agent', sa.String, nullable=True),
sa.Column('default_max_iterations', sa.Integer, nullable=True),
sa.Column('security_analyzer', sa.String, nullable=True),
sa.Column(
'confirmation_mode',
sa.Boolean,
nullable=True,
server_default=sa.text('false'),
),
sa.Column('default_llm_model', sa.String, nullable=True),
sa.Column('_default_llm_api_key_for_byor', sa.String, nullable=True),
sa.Column('default_llm_base_url', sa.String, nullable=True),
sa.Column('remote_runtime_resource_factor', sa.Integer, nullable=True),
sa.Column(
'enable_default_condenser',
sa.Boolean,
nullable=False,
server_default=sa.text('true'),
),
sa.Column('billing_margin', sa.Float, nullable=True),
sa.Column(
'enable_proactive_conversation_starters',
sa.Boolean,
nullable=False,
server_default=sa.text('true'),
),
sa.Column('sandbox_base_container_image', sa.String, nullable=True),
sa.Column('sandbox_runtime_container_image', sa.String, nullable=True),
sa.Column(
'org_version', sa.Integer, nullable=False, server_default=sa.text('0')
),
sa.Column('mcp_config', sa.JSON, nullable=True),
sa.Column('_search_api_key', sa.String, nullable=True),
sa.Column('_sandbox_api_key', sa.String, nullable=True),
sa.Column('max_budget_per_task', sa.Float, nullable=True),
sa.Column(
'enable_solvability_analysis',
sa.Boolean,
nullable=True,
server_default=sa.text('false'),
),
sa.Column('v1_enabled', sa.Boolean, nullable=True),
sa.UniqueConstraint('name', name='org_name_unique'),
)
# Create user table with user-specific settings fields
op.create_table(
'user',
sa.Column(
'id',
postgresql.UUID(as_uuid=True),
primary_key=True,
),
sa.Column('current_org_id', postgresql.UUID(as_uuid=True), nullable=False),
sa.Column('role_id', sa.Integer, nullable=True),
sa.Column('accepted_tos', sa.DateTime, nullable=True),
sa.Column(
'enable_sound_notifications',
sa.Boolean,
nullable=True,
server_default=sa.text('false'),
),
sa.Column('language', sa.String, nullable=True),
sa.Column('user_consents_to_analytics', sa.Boolean, nullable=True),
sa.Column('email', sa.String, nullable=True),
sa.Column('email_verified', sa.Boolean, nullable=True),
sa.ForeignKeyConstraint(
['current_org_id'], ['org.id'], name='current_org_fkey'
),
sa.ForeignKeyConstraint(['role_id'], ['role.id'], name='user_role_fkey'),
)
# Create org_member table (junction table for many-to-many relationship)
op.create_table(
'org_member',
sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=False),
sa.Column('user_id', postgresql.UUID(as_uuid=True), nullable=False),
sa.Column('role_id', sa.Integer, nullable=False),
sa.Column('_llm_api_key', sa.String, nullable=False),
sa.Column('max_iterations', sa.Integer, nullable=True),
sa.Column('llm_model', sa.String, nullable=True),
sa.Column('_llm_api_key_for_byor', sa.String, nullable=True),
sa.Column('llm_base_url', sa.String, nullable=True),
sa.Column('status', sa.String, nullable=True),
sa.ForeignKeyConstraint(['org_id'], ['org.id'], name='om_org_fkey'),
sa.ForeignKeyConstraint(['user_id'], ['user.id'], name='om_user_fkey'),
sa.ForeignKeyConstraint(['role_id'], ['role.id'], name='om_role_fkey'),
sa.PrimaryKeyConstraint('org_id', 'user_id'),
)
# Add org_id column to existing tables
# billing_sessions
op.add_column(
'billing_sessions',
sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=True),
)
op.create_foreign_key(
'billing_sessions_org_fkey', 'billing_sessions', 'org', ['org_id'], ['id']
)
# Create conversation_metadata_saas table
op.create_table(
'conversation_metadata_saas',
sa.Column('conversation_id', sa.String(), nullable=False),
sa.Column('user_id', postgresql.UUID(as_uuid=True), nullable=False),
sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=False),
sa.ForeignKeyConstraint(
['user_id'], ['user.id'], name='conversation_metadata_saas_user_fkey'
),
sa.ForeignKeyConstraint(
['org_id'], ['org.id'], name='conversation_metadata_saas_org_fkey'
),
sa.PrimaryKeyConstraint('conversation_id'),
)
# custom_secrets
op.add_column(
'custom_secrets',
sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=True),
)
op.create_foreign_key(
'custom_secrets_org_fkey', 'custom_secrets', 'org', ['org_id'], ['id']
)
# api_keys
op.add_column(
'api_keys', sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=True)
)
op.create_foreign_key('api_keys_org_fkey', 'api_keys', 'org', ['org_id'], ['id'])
# slack_conversation
op.add_column(
'slack_conversation',
sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=True),
)
op.create_foreign_key(
'slack_conversation_org_fkey', 'slack_conversation', 'org', ['org_id'], ['id']
)
# slack_users
op.add_column(
'slack_users', sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=True)
)
op.create_foreign_key(
'slack_users_org_fkey', 'slack_users', 'org', ['org_id'], ['id']
)
# stripe_customers
op.alter_column(
'stripe_customers',
'keycloak_user_id',
existing_type=sa.String(),
nullable=True,
)
op.add_column(
'stripe_customers',
sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=True),
)
op.create_foreign_key(
'stripe_customers_org_fkey', 'stripe_customers', 'org', ['org_id'], ['id']
)
def downgrade() -> None:
# Drop already_migrated column from user_settings table
op.drop_column('user_settings', 'already_migrated')
# Drop foreign keys and columns added to existing tables
op.drop_constraint(
'stripe_customers_org_fkey', 'stripe_customers', type_='foreignkey'
)
op.drop_column('stripe_customers', 'org_id')
op.alter_column(
'stripe_customers',
'keycloak_user_id',
existing_type=sa.String(),
nullable=False,
)
op.drop_constraint('slack_users_org_fkey', 'slack_users', type_='foreignkey')
op.drop_column('slack_users', 'org_id')
op.drop_constraint(
'slack_conversation_org_fkey', 'slack_conversation', type_='foreignkey'
)
op.drop_column('slack_conversation', 'org_id')
op.drop_constraint('api_keys_org_fkey', 'api_keys', type_='foreignkey')
op.drop_column('api_keys', 'org_id')
op.drop_constraint('custom_secrets_org_fkey', 'custom_secrets', type_='foreignkey')
op.drop_column('custom_secrets', 'org_id')
# Drop conversation_metadata_saas table
op.drop_table('conversation_metadata_saas')
op.drop_constraint(
'billing_sessions_org_fkey', 'billing_sessions', type_='foreignkey'
)
op.drop_column('billing_sessions', 'org_id')
# Drop tables in reverse order due to foreign key constraints
op.drop_table('org_member')
op.drop_table('user')
op.drop_table('org')
op.drop_table('role')

9753
enterprise/poetry.lock generated

File diff suppressed because one or more lines are too long

View File

@@ -4,6 +4,10 @@ from dotenv import load_dotenv
load_dotenv()
# Ensure SAAS configuration is used
if not os.getenv('OPENHANDS_CONFIG_CLS'):
os.environ['OPENHANDS_CONFIG_CLS'] = 'server.config.SaaSServerConfig'
import socketio # noqa: E402
from fastapi import Request, status # noqa: E402
from fastapi.middleware.cors import CORSMiddleware # noqa: E402

View File

@@ -102,7 +102,6 @@ class SaasUserAuth(UserAuth):
return settings
settings_store = await self.get_user_settings_store()
settings = await settings_store.load()
# If load() returned None, should settings be created?
if settings:
settings.email = self.email
settings.email_verified = self.email_verified

View File

@@ -9,7 +9,7 @@ from server.logger import logger
from server.utils.conversation_callback_utils import invoke_conversation_callbacks
from storage.database import session_maker
from storage.saas_settings_store import SaasSettingsStore
from storage.stored_conversation_metadata import StoredConversationMetadata
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
from openhands.core.config import LLMConfig
from openhands.core.config.openhands_config import OpenHandsConfig
@@ -525,16 +525,18 @@ class ClusteredConversationManager(StandaloneConversationManager):
)
# Look up the user_id from the database
with session_maker() as session:
conversation_metadata = (
session.query(StoredConversationMetadata)
conversation_metadata_saas = (
session.query(StoredConversationMetadataSaas)
.filter(
StoredConversationMetadata.conversation_id
StoredConversationMetadataSaas.conversation_id
== conversation_id
)
.first()
)
user_id = (
conversation_metadata.user_id if conversation_metadata else None
str(conversation_metadata_saas.user_id)
if conversation_metadata_saas
else None
)
# Handle the stopped conversation asynchronously
asyncio.create_task(

View File

@@ -19,8 +19,8 @@ IS_LOCAL_ENV = bool(HOST == 'localhost')
DEFAULT_BILLING_MARGIN = float(os.environ.get('DEFAULT_BILLING_MARGIN', '1.0'))
# Map of user settings versions to their corresponding default LLM models
# This ensures that CURRENT_USER_SETTINGS_VERSION and LITELLM_DEFAULT_MODEL stay in sync
USER_SETTINGS_VERSION_TO_MODEL = {
# This ensures that PERSONAL_WORKSPACE_VERSION_TO_MODEL and LITELLM_DEFAULT_MODEL stay in sync
PERSONAL_WORKSPACE_VERSION_TO_MODEL = {
1: 'claude-3-5-sonnet-20241022',
2: 'claude-3-7-sonnet-20250219',
3: 'claude-sonnet-4-20250514',
@@ -30,29 +30,17 @@ USER_SETTINGS_VERSION_TO_MODEL = {
LITELLM_DEFAULT_MODEL = os.getenv('LITELLM_DEFAULT_MODEL')
# Current user settings version - this should be the latest key in USER_SETTINGS_VERSION_TO_MODEL
CURRENT_USER_SETTINGS_VERSION = max(USER_SETTINGS_VERSION_TO_MODEL.keys())
ORG_SETTINGS_VERSION = max(PERSONAL_WORKSPACE_VERSION_TO_MODEL.keys())
PERSONAL_WORKSPACE_VERSION = max(PERSONAL_WORKSPACE_VERSION_TO_MODEL.keys())
LITE_LLM_API_URL = os.environ.get(
'LITE_LLM_API_URL', 'https://llm-proxy.app.all-hands.dev'
)
LITE_LLM_TEAM_ID = os.environ.get('LITE_LLM_TEAM_ID', None)
LITE_LLM_API_KEY = os.environ.get('LITE_LLM_API_KEY', None)
SUBSCRIPTION_PRICE_DATA = {
'MONTHLY_SUBSCRIPTION': {
'unit_amount': 2000,
'currency': 'usd',
'product_data': {
'name': 'OpenHands Monthly',
'tax_code': 'txcd_10000000',
},
'tax_behavior': 'exclusive',
'recurring': {'interval': 'month', 'interval_count': 1},
},
}
DEFAULT_INITIAL_BUDGET = float(os.environ.get('DEFAULT_INITIAL_BUDGET', '10'))
STRIPE_API_KEY = os.environ.get('STRIPE_API_KEY', None)
STRIPE_WEBHOOK_SECRET = os.environ.get('STRIPE_WEBHOOK_SECRET', None)
REQUIRE_PAYMENT = os.environ.get('REQUIRE_PAYMENT', '0') in ('1', 'true')
SLACK_CLIENT_ID = os.environ.get('SLACK_CLIENT_ID', None)
@@ -102,5 +90,5 @@ def get_default_litellm_model():
"""
if LITELLM_DEFAULT_MODEL:
return LITELLM_DEFAULT_MODEL
model = USER_SETTINGS_VERSION_TO_MODEL[CURRENT_USER_SETTINGS_VERSION]
model = PERSONAL_WORKSPACE_VERSION_TO_MODEL[PERSONAL_WORKSPACE_VERSION]
return build_litellm_proxy_model_path(model)

View File

@@ -44,11 +44,13 @@ class MyProcessor(MaintenanceTaskProcessor):
### UserVersionUpgradeProcessor
Located in `user_version_upgrade_processor.py`, this processor:
- Handles up to 100 user IDs per task
- Upgrades users with `user_version < CURRENT_USER_SETTINGS_VERSION`
- Upgrades users with `user_version < ORG_SETTINGS_VERSION`
- Uses `SaasSettingsStore.create_default_settings()` for upgrades
**Usage:**
```python
from server.maintenance_task_processor.user_version_upgrade_processor import UserVersionUpgradeProcessor
@@ -144,22 +146,26 @@ task = create_maintenance_task(
## Best Practices
### Processor Design
- Keep tasks short-running (under 1 minute)
- Handle errors gracefully and return meaningful error information
- Use batch processing for large datasets
- Include progress information in the return dict
### Error Handling
- Always wrap your processor logic in try-catch blocks
- Return structured error information
- Log important events for debugging
### Performance
- Limit batch sizes to avoid long-running tasks
- Use database sessions efficiently
- Consider memory usage for large datasets
### Testing
- Create unit tests for your processors
- Test error conditions
- Verify the processor serialization/deserialization works correctly
@@ -167,6 +173,7 @@ task = create_maintenance_task(
## Database Patterns
The maintenance task system follows the repository's established patterns:
- Uses `session_maker()` for database operations
- Wraps sync database operations in `call_sync_from_async` for async routes
- Follows proper SQLAlchemy query patterns
@@ -174,15 +181,18 @@ The maintenance task system follows the repository's established patterns:
## Integration with Existing Systems
### User Management
- Integrates with the existing `UserSettings` model
- Uses the current user versioning system (`CURRENT_USER_SETTINGS_VERSION`)
- Uses the current user versioning system (`ORG_SETTINGS_VERSION`)
- Maintains compatibility with existing user management workflows
### Authentication
- Admin endpoints use the existing SaaS authentication system
- Requires users to have `admin = True` in their UserSettings
### Monitoring
- Tasks are logged with structured information
- Status updates are tracked in the database
- Error information is preserved for debugging
@@ -206,6 +216,7 @@ The maintenance task system follows the repository's established patterns:
## Future Enhancements
Potential improvements that could be added:
- Task dependencies and scheduling
- Retry mechanisms for failed tasks
- Real-time progress updates

View File

@@ -1,155 +0,0 @@
from __future__ import annotations
from typing import List
from server.constants import CURRENT_USER_SETTINGS_VERSION
from server.logger import logger
from storage.database import session_maker
from storage.maintenance_task import MaintenanceTask, MaintenanceTaskProcessor
from storage.saas_settings_store import SaasSettingsStore
from storage.user_settings import UserSettings
from openhands.core.config import load_openhands_config
class UserVersionUpgradeProcessor(MaintenanceTaskProcessor):
"""
Processor for upgrading user settings to the current version.
This processor takes a list of user IDs and upgrades any users
whose user_version is less than CURRENT_USER_SETTINGS_VERSION.
"""
user_ids: List[str]
async def __call__(self, task: MaintenanceTask) -> dict:
"""
Process user version upgrades for the specified user IDs.
Args:
task: The maintenance task being processed
Returns:
dict: Results containing successful and failed user IDs
"""
logger.info(
'user_version_upgrade_processor:start',
extra={
'task_id': task.id,
'user_count': len(self.user_ids),
'current_version': CURRENT_USER_SETTINGS_VERSION,
},
)
if len(self.user_ids) > 100:
raise ValueError(
f'Too many user IDs: {len(self.user_ids)}. Maximum is 100.'
)
config = load_openhands_config()
# Track results
successful_upgrades = []
failed_upgrades = []
users_already_current = []
# Find users that need upgrading
with session_maker() as session:
users_to_upgrade = (
session.query(UserSettings)
.filter(
UserSettings.keycloak_user_id.in_(self.user_ids),
UserSettings.user_version < CURRENT_USER_SETTINGS_VERSION,
)
.all()
)
# Track users that are already current
users_needing_upgrade_ids = {u.keycloak_user_id for u in users_to_upgrade}
users_already_current = [
uid for uid in self.user_ids if uid not in users_needing_upgrade_ids
]
logger.info(
'user_version_upgrade_processor:found_users',
extra={
'task_id': task.id,
'users_to_upgrade': len(users_to_upgrade),
'users_already_current': len(users_already_current),
'total_requested': len(self.user_ids),
},
)
# Process each user that needs upgrading
for user_settings in users_to_upgrade:
user_id = user_settings.keycloak_user_id
old_version = user_settings.user_version
try:
logger.info(
'user_version_upgrade_processor:upgrading_user',
extra={
'task_id': task.id,
'user_id': user_id,
'old_version': old_version,
'new_version': CURRENT_USER_SETTINGS_VERSION,
},
)
# Create SaasSettingsStore instance and upgrade
settings_store = await SaasSettingsStore.get_instance(config, user_id)
await settings_store.create_default_settings(user_settings)
successful_upgrades.append(
{
'user_id': user_id,
'old_version': old_version,
'new_version': CURRENT_USER_SETTINGS_VERSION,
}
)
logger.info(
'user_version_upgrade_processor:user_upgraded',
extra={
'task_id': task.id,
'user_id': user_id,
'old_version': old_version,
'new_version': CURRENT_USER_SETTINGS_VERSION,
},
)
except Exception as e:
failed_upgrades.append(
{'user_id': user_id, 'old_version': old_version, 'error': str(e)}
)
logger.error(
'user_version_upgrade_processor:user_upgrade_failed',
extra={
'task_id': task.id,
'user_id': user_id,
'old_version': old_version,
'error': str(e),
},
)
# Create result summary
result = {
'total_users': len(self.user_ids),
'users_already_current': users_already_current,
'successful_upgrades': successful_upgrades,
'failed_upgrades': failed_upgrades,
'summary': (
f'Processed {len(self.user_ids)} users: '
f'{len(successful_upgrades)} upgraded, '
f'{len(users_already_current)} already current, '
f'{len(failed_upgrades)} errors'
),
}
logger.info(
'user_version_upgrade_processor:completed',
extra={'task_id': task.id, 'result': result},
)
return result

View File

@@ -1,7 +1,5 @@
from typing import TYPE_CHECKING
from storage.api_key_store import ApiKeyStore
if TYPE_CHECKING:
from openhands.core.config.openhands_config import OpenHandsConfig
@@ -36,6 +34,7 @@ class SaaSOpenHandsMCPConfig(OpenHandsMCPConfig):
Returns:
A tuple containing the default SSE server configuration and a list of MCP stdio server configurations
"""
from storage.api_key_store import ApiKeyStore
api_key_store = ApiKeyStore.get_instance()
if user_id:

View File

@@ -1,109 +1,97 @@
from datetime import UTC, datetime
import httpx
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel, field_validator
from server.config import get_config
from server.constants import LITE_LLM_API_KEY, LITE_LLM_API_URL
from storage.api_key_store import ApiKeyStore
from storage.database import session_maker
from storage.saas_settings_store import SaasSettingsStore
from storage.lite_llm_manager import LiteLlmManager
from storage.org_member import OrgMember
from storage.org_member_store import OrgMemberStore
from storage.org_store import OrgStore
from storage.user_store import UserStore
from openhands.core.logger import openhands_logger as logger
from openhands.server.user_auth import get_user_id
from openhands.utils.async_utils import call_sync_from_async
from openhands.utils.http_session import httpx_verify_option
# Helper functions for BYOR API key management
async def get_byor_key_from_db(user_id: str) -> str | None:
"""Get the BYOR key from the database for a user."""
config = get_config()
settings_store = SaasSettingsStore(
user_id=user_id, session_maker=session_maker, config=config
)
user_db_settings = await call_sync_from_async(
settings_store.get_user_settings_by_keycloak_id, user_id
)
if user_db_settings and user_db_settings.llm_api_key_for_byor:
return user_db_settings.llm_api_key_for_byor
return None
def _get_byor_key():
user = UserStore.get_user_by_id(user_id)
if not user:
return None
current_org_id = user.current_org_id
current_org_member: OrgMember = None
for org_member in user.org_members:
if org_member.org_id == current_org_id:
current_org_member = org_member
break
if not current_org_member:
return None
if current_org_member.llm_api_key_for_byor:
return current_org_member.llm_api_key_for_byor.get_secret_value()
org = OrgStore.get_org_by_id(current_org_id)
if not org:
return None
return (
org.default_llm_api_key_for_byor.get_secret_value()
if org.default_llm_api_key_for_byor
else None
)
return await call_sync_from_async(_get_byor_key)
async def store_byor_key_in_db(user_id: str, key: str) -> None:
"""Store the BYOR key in the database for a user."""
config = get_config()
settings_store = SaasSettingsStore(
user_id=user_id, session_maker=session_maker, config=config
)
def _update_user_settings():
with session_maker() as session:
user_db_settings = settings_store.get_user_settings_by_keycloak_id(
user_id, session
)
if user_db_settings:
user_db_settings.llm_api_key_for_byor = key
session.commit()
logger.info(
'Successfully stored BYOR key in user settings',
extra={'user_id': user_id},
)
else:
logger.warning(
'User settings not found when trying to store BYOR key',
extra={'user_id': user_id},
)
user = UserStore.get_user_by_id(user_id)
if not user:
return None
current_org_id = user.current_org_id
current_org_member: OrgMember = None
for org_member in user.org_members:
if org_member.org_id == current_org_id:
current_org_member = org_member
break
if not current_org_member:
return None
current_org_member.llm_api_key_for_byor = key
OrgMemberStore.update_org_member(current_org_member)
await call_sync_from_async(_update_user_settings)
async def generate_byor_key(user_id: str) -> str | None:
"""Generate a new BYOR key for a user."""
if not (LITE_LLM_API_KEY and LITE_LLM_API_URL):
logger.warning(
'LiteLLM API configuration not found', extra={'user_id': user_id}
)
return None
try:
async with httpx.AsyncClient(
verify=httpx_verify_option(),
headers={
'x-goog-api-key': LITE_LLM_API_KEY,
},
) as client:
response = await client.post(
f'{LITE_LLM_API_URL}/key/generate',
json={
key = await LiteLlmManager.generate_key(
user_id, user_id, f'BYOR Key - user {user_id}', {'type': 'byor'}
)
if key:
logger.info(
'Successfully generated new BYOR key',
extra={
'user_id': user_id,
'metadata': {'type': 'byor'},
'key_alias': f'BYOR Key - user {user_id}',
'key_length': len(key) if key else 0,
'key_prefix': key[:10] + '...' if key and len(key) > 10 else key,
},
)
response.raise_for_status()
response_json = response.json()
key = response_json.get('key')
if key:
logger.info(
'Successfully generated new BYOR key',
extra={
'user_id': user_id,
'key_length': len(key) if key else 0,
'key_prefix': key[:10] + '...'
if key and len(key) > 10
else key,
},
)
return key
else:
logger.error(
'Failed to generate BYOR LLM API key - no key in response',
extra={'user_id': user_id, 'response_json': response_json},
)
return None
return key
else:
logger.error(
'Failed to generate BYOR LLM API key - no key in response',
extra={'user_id': user_id},
)
return None
except Exception as e:
logger.exception(
'Error generating BYOR key',
@@ -114,30 +102,14 @@ async def generate_byor_key(user_id: str) -> str | None:
async def delete_byor_key_from_litellm(user_id: str, byor_key: str) -> bool:
"""Delete the BYOR key from LiteLLM using the key directly."""
if not (LITE_LLM_API_KEY and LITE_LLM_API_URL):
logger.warning(
'LiteLLM API configuration not found', extra={'user_id': user_id}
)
return False
try:
async with httpx.AsyncClient(
verify=httpx_verify_option(),
headers={
'x-goog-api-key': LITE_LLM_API_KEY,
},
) as client:
# Delete the key directly using the key value
delete_url = f'{LITE_LLM_API_URL}/key/delete'
delete_payload = {'keys': [byor_key]}
delete_response = await client.post(delete_url, json=delete_payload)
delete_response.raise_for_status()
logger.info(
'Successfully deleted BYOR key from LiteLLM',
extra={'user_id': user_id},
)
return True
await LiteLlmManager.delete_key(byor_key)
logger.info(
'Successfully deleted BYOR key from LiteLLM',
extra={'user_id': user_id},
)
return True
except Exception as e:
logger.exception(
'Error deleting BYOR key from LiteLLM',
@@ -315,15 +287,6 @@ async def refresh_llm_api_key_for_byor(user_id: str = Depends(get_user_id)):
logger.info('Starting BYOR LLM API key refresh', extra={'user_id': user_id})
try:
if not (LITE_LLM_API_KEY and LITE_LLM_API_URL):
logger.warning(
'LiteLLM API configuration not found', extra={'user_id': user_id}
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail='LiteLLM API configuration not found',
)
# Get the existing BYOR key from the database
existing_byor_key = await get_byor_key_from_db(user_id)

View File

@@ -1,3 +1,4 @@
import uuid
import warnings
from datetime import datetime, timezone
from typing import Annotated, Literal, Optional
@@ -17,12 +18,12 @@ from server.auth.constants import (
from server.auth.gitlab_sync import schedule_gitlab_repo_sync
from server.auth.saas_user_auth import SaasUserAuth
from server.auth.token_manager import TokenManager
from server.config import get_config, sign_token
from server.config import sign_token
from server.constants import IS_FEATURE_ENV
from server.routes.event_webhook import _get_session_api_key, _get_user_id
from storage.database import session_maker
from storage.saas_settings_store import SaasSettingsStore
from storage.user_settings import UserSettings
from storage.user import User
from storage.user_store import UserStore
from openhands.core.logger import openhands_logger as logger
from openhands.integrations.provider import ProviderHandler
@@ -31,6 +32,7 @@ from openhands.server.services.conversation_service import create_provider_token
from openhands.server.shared import config
from openhands.server.user_auth import get_access_token
from openhands.server.user_auth.user_auth import get_user_auth
from openhands.utils.async_utils import call_sync_from_async
with warnings.catch_warnings():
warnings.simplefilter('ignore')
@@ -82,7 +84,8 @@ def get_cookie_domain(request: Request) -> str | None:
# for now just use the full hostname except for staging stacks.
return (
None
if (request.url.hostname or '').endswith('staging.all-hand.dev')
if not request.url.hostname
or request.url.hostname.endswith('staging.all-hands.dev')
else request.url.hostname
)
@@ -146,6 +149,21 @@ async def keycloak_callback(
)
user_id = user_info['sub']
user = await call_sync_from_async(UserStore.get_user_by_id, user_id)
if not user:
user = await UserStore.create_user(user_id, user_info)
if not user:
logger.error(f'Failed to authenticate user {user_info["preferred_username"]}')
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={
'error': f'Failed to authenticate user {user_info["preferred_username"]}'
},
)
logger.info(f'Logging in user {str(user.id)} in org {user.current_org_id}')
# default to github IDP for now.
# TODO: remove default once Keycloak is updated universally with the new attribute.
idp: str = user_info.get('identity_provider', ProviderType.GITHUB.value)
@@ -220,15 +238,7 @@ async def keycloak_callback(
f'&state={state}'
)
config = get_config()
settings_store = SaasSettingsStore(
user_id=user_id, session_maker=session_maker, config=config
)
user_settings = settings_store.get_user_settings_by_keycloak_id(user_id)
has_accepted_tos = (
user_settings is not None and user_settings.accepted_tos is not None
)
has_accepted_tos = user.accepted_tos is not None
# If the user hasn't accepted the TOS, redirect to the TOS page
if not has_accepted_tos:
encoded_redirect_url = quote(redirect_url, safe='')
@@ -346,28 +356,20 @@ async def accept_tos(request: Request):
redirect_url = body.get('redirect_url', str(request.base_url))
# Update user settings with TOS acceptance
accepted_tos: datetime = datetime.now(timezone.utc)
with session_maker() as session:
user_settings = (
session.query(UserSettings)
.filter(UserSettings.keycloak_user_id == user_id)
.first()
)
if user_settings:
user_settings.accepted_tos = datetime.now(timezone.utc)
session.merge(user_settings)
else:
# Create user settings if they don't exist
user_settings = UserSettings(
keycloak_user_id=user_id,
accepted_tos=datetime.now(timezone.utc),
user_version=0, # This will trigger a migration to the latest version on next load
user = session.query(User).filter(User.id == uuid.UUID(user_id)).first()
if not user:
session.rollback()
logger.error('User for {user_id} not found.')
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={'error': 'User does not exist'},
)
session.add(user_settings)
user.accepted_tos = accepted_tos
session.commit()
logger.info(f'User {user_id} accepted TOS')
logger.info(f'User {user_id} accepted TOS')
response = JSONResponse(
status_code=status.HTTP_200_OK, content={'redirect_url': redirect_url}

View File

@@ -2,32 +2,23 @@
import typing
from datetime import UTC, datetime
from decimal import Decimal
from enum import Enum
import httpx
import stripe
from dateutil.relativedelta import relativedelta # type: ignore
from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi.responses import JSONResponse, RedirectResponse
from fastapi.responses import RedirectResponse
from integrations import stripe_service
from pydantic import BaseModel
from server.config import get_config
from server.constants import (
LITE_LLM_API_KEY,
LITE_LLM_API_URL,
STRIPE_API_KEY,
STRIPE_WEBHOOK_SECRET,
SUBSCRIPTION_PRICE_DATA,
get_default_litellm_model,
)
from server.logger import logger
from storage.billing_session import BillingSession
from storage.database import session_maker
from storage.saas_settings_store import SaasSettingsStore
from storage.subscription_access import SubscriptionAccess
from storage.lite_llm_manager import LiteLlmManager
from storage.user_store import UserStore
from openhands.server.user_auth import get_user_id
from openhands.utils.http_session import httpx_verify_option
from openhands.utils.async_utils import call_sync_from_async
stripe.api_key = STRIPE_API_KEY
billing_router = APIRouter(prefix='/api/billing')
@@ -64,23 +55,10 @@ def validate_saas_environment(request: Request) -> None:
)
class BillingSessionType(Enum):
DIRECT_PAYMENT = 'DIRECT_PAYMENT'
MONTHLY_SUBSCRIPTION = 'MONTHLY_SUBSCRIPTION'
class GetCreditsResponse(BaseModel):
credits: Decimal | None = None
class SubscriptionAccessResponse(BaseModel):
start_at: datetime
end_at: datetime
created_at: datetime
cancelled_at: datetime | None = None
stripe_subscription_id: str | None = None
class CreateCheckoutSessionRequest(BaseModel):
amount: int
@@ -111,117 +89,23 @@ def calculate_credits(user_info: LiteLlmUserInfo) -> float:
async def get_credits(user_id: str = Depends(get_user_id)) -> GetCreditsResponse:
if not stripe_service.STRIPE_API_KEY:
return GetCreditsResponse()
async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
user_json = await _get_litellm_user(client, user_id)
credits = calculate_credits(user_json['user_info'])
user = await call_sync_from_async(UserStore.get_user_by_id, user_id)
user_team_info = await LiteLlmManager.get_user_team_info(
user_id, str(user.current_org_id)
)
# Update to use calculate_credits
spend = user_team_info.get('spend', 0)
max_budget = (user_team_info.get('litellm_budget_table') or {}).get('max_budget', 0)
credits = max(max_budget - spend, 0)
return GetCreditsResponse(credits=Decimal('{:.2f}'.format(credits)))
# Endpoint to retrieve user's current subscription access
@billing_router.get('/subscription-access')
async def get_subscription_access(
user_id: str = Depends(get_user_id),
) -> SubscriptionAccessResponse | None:
"""Get details of the currently valid subscription for the user."""
with session_maker() as session:
now = datetime.now(UTC)
subscription_access = (
session.query(SubscriptionAccess)
.filter(SubscriptionAccess.status == 'ACTIVE')
.filter(SubscriptionAccess.user_id == user_id)
.filter(SubscriptionAccess.start_at <= now)
.filter(SubscriptionAccess.end_at >= now)
.first()
)
if not subscription_access:
return None
return SubscriptionAccessResponse(
start_at=subscription_access.start_at,
end_at=subscription_access.end_at,
created_at=subscription_access.created_at,
cancelled_at=subscription_access.cancelled_at,
stripe_subscription_id=subscription_access.stripe_subscription_id,
)
# Endpoint to check if a user has entered a payment method into stripe
@billing_router.post('/has-payment-method')
async def has_payment_method(user_id: str = Depends(get_user_id)) -> bool:
if not user_id:
raise HTTPException(status.HTTP_401_UNAUTHORIZED)
return await stripe_service.has_payment_method(user_id)
# Endpoint to cancel user's subscription
@billing_router.post('/cancel-subscription')
async def cancel_subscription(user_id: str = Depends(get_user_id)) -> JSONResponse:
"""Cancel user's active subscription at the end of the current billing period."""
if not user_id:
raise HTTPException(status.HTTP_401_UNAUTHORIZED)
with session_maker() as session:
# Find the user's active subscription
now = datetime.now(UTC)
subscription_access = (
session.query(SubscriptionAccess)
.filter(SubscriptionAccess.status == 'ACTIVE')
.filter(SubscriptionAccess.user_id == user_id)
.filter(SubscriptionAccess.start_at <= now)
.filter(SubscriptionAccess.end_at >= now)
.filter(SubscriptionAccess.cancelled_at.is_(None)) # Not already cancelled
.first()
)
if not subscription_access:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail='No active subscription found',
)
if not subscription_access.stripe_subscription_id:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail='Cannot cancel subscription: missing Stripe subscription ID',
)
try:
# Cancel the subscription in Stripe at period end
await stripe.Subscription.modify_async(
subscription_access.stripe_subscription_id, cancel_at_period_end=True
)
# Update local database
subscription_access.cancelled_at = datetime.now(UTC)
session.merge(subscription_access)
session.commit()
logger.info(
'subscription_cancelled',
extra={
'user_id': user_id,
'stripe_subscription_id': subscription_access.stripe_subscription_id,
'subscription_access_id': subscription_access.id,
'end_at': subscription_access.end_at,
},
)
return JSONResponse(
{'status': 'success', 'message': 'Subscription cancelled successfully'}
)
except stripe.StripeError as e:
logger.error(
'stripe_cancellation_failed',
extra={
'user_id': user_id,
'stripe_subscription_id': subscription_access.stripe_subscription_id,
'error': str(e),
},
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f'Failed to cancel subscription: {str(e)}',
)
return await stripe_service.has_payment_method_by_user_id(user_id)
# Endpoint to create a new setup intent in stripe
@@ -230,16 +114,15 @@ async def create_customer_setup_session(
request: Request, user_id: str = Depends(get_user_id)
) -> CreateBillingSessionResponse:
validate_saas_environment(request)
customer_id = await stripe_service.find_or_create_customer(user_id)
customer_info = await stripe_service.find_or_create_customer_by_user_id(user_id)
checkout_session = await stripe.checkout.Session.create_async(
customer=customer_id,
customer=customer_info['customer_id'],
mode='setup',
payment_method_types=['card'],
success_url=f'{request.base_url}?free_credits=success',
cancel_url=f'{request.base_url}',
)
return CreateBillingSessionResponse(redirect_url=checkout_session.url) # type: ignore[arg-type]
return CreateBillingSessionResponse(redirect_url=checkout_session.url)
# Endpoint to create a new Stripe checkout session for credit purchase
@@ -251,9 +134,9 @@ async def create_checkout_session(
) -> CreateBillingSessionResponse:
validate_saas_environment(request)
customer_id = await stripe_service.find_or_create_customer(user_id)
customer_info = await stripe_service.find_or_create_customer_by_user_id(user_id)
checkout_session = await stripe.checkout.Session.create_async(
customer=customer_id,
customer=customer_info['customer_id'],
line_items=[
{
'price_data': {
@@ -266,7 +149,7 @@ async def create_checkout_session(
'tax_behavior': 'exclusive',
},
'quantity': 1,
}
},
],
mode='payment',
payment_method_types=['card'],
@@ -279,8 +162,9 @@ async def create_checkout_session(
logger.info(
'created_stripe_checkout_session',
extra={
'stripe_customer_id': customer_id,
'stripe_customer_id': customer_info['customer_id'],
'user_id': user_id,
'org_id': customer_info['org_id'],
'amount': body.amount,
'checkout_session_id': checkout_session.id,
},
@@ -289,105 +173,14 @@ async def create_checkout_session(
billing_session = BillingSession(
id=checkout_session.id,
user_id=user_id,
org_id=customer_info['org_id'],
price=body.amount,
price_code='NA',
billing_session_type=BillingSessionType.DIRECT_PAYMENT.value,
)
session.add(billing_session)
session.commit()
return CreateBillingSessionResponse(redirect_url=checkout_session.url) # type: ignore[arg-type]
@billing_router.post('/subscription-checkout-session')
async def create_subscription_checkout_session(
request: Request,
billing_session_type: BillingSessionType = BillingSessionType.MONTHLY_SUBSCRIPTION,
user_id: str = Depends(get_user_id),
) -> CreateBillingSessionResponse:
validate_saas_environment(request)
# Prevent duplicate subscriptions for the same user
with session_maker() as session:
now = datetime.now(UTC)
existing_active_subscription = (
session.query(SubscriptionAccess)
.filter(SubscriptionAccess.status == 'ACTIVE')
.filter(SubscriptionAccess.user_id == user_id)
.filter(SubscriptionAccess.start_at <= now)
.filter(SubscriptionAccess.end_at >= now)
.filter(SubscriptionAccess.cancelled_at.is_(None)) # Not cancelled
.first()
)
if existing_active_subscription:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail='Cannot create subscription: User already has an active subscription that has not been cancelled',
)
customer_id = await stripe_service.find_or_create_customer(user_id)
subscription_price_data = SUBSCRIPTION_PRICE_DATA[billing_session_type.value]
checkout_session = await stripe.checkout.Session.create_async(
customer=customer_id,
line_items=[
{
'price_data': subscription_price_data,
'quantity': 1,
}
],
mode='subscription',
payment_method_types=['card'],
saved_payment_method_options={
'payment_method_save': 'enabled',
},
success_url=f'{request.base_url}api/billing/success?session_id={{CHECKOUT_SESSION_ID}}',
cancel_url=f'{request.base_url}api/billing/cancel?session_id={{CHECKOUT_SESSION_ID}}',
subscription_data={
'metadata': {
'user_id': user_id,
'billing_session_type': billing_session_type.value,
}
},
)
logger.info(
'created_stripe_subscription_checkout_session',
extra={
'stripe_customer_id': customer_id,
'user_id': user_id,
'checkout_session_id': checkout_session.id,
'billing_session_type': billing_session_type.value,
},
)
with session_maker() as session:
billing_session = BillingSession(
id=checkout_session.id,
user_id=user_id,
price=subscription_price_data['unit_amount'],
price_code='NA',
billing_session_type=billing_session_type.value,
)
session.add(billing_session)
session.commit()
return CreateBillingSessionResponse(
redirect_url=typing.cast(str, checkout_session.url)
)
@billing_router.get('/create-subscription-checkout-session')
async def create_subscription_checkout_session_via_get(
request: Request,
billing_session_type: BillingSessionType = BillingSessionType.MONTHLY_SUBSCRIPTION,
user_id: str = Depends(get_user_id),
) -> RedirectResponse:
"""Create a subscription checkout session using a GET request (For easier copy / paste to URL bar)."""
validate_saas_environment(request)
response = await create_subscription_checkout_session(
request, billing_session_type, user_id
)
return RedirectResponse(response.redirect_url)
return CreateBillingSessionResponse(redirect_url=checkout_session.url)
# Callback endpoint for successful Stripe payments - updates user credits and billing session status
@@ -409,15 +202,6 @@ async def success_callback(session_id: str, request: Request):
)
raise HTTPException(status.HTTP_400_BAD_REQUEST)
# Any non direct payment (Subscription) is processed in the invoice_payment.paid by the webhook
if (
billing_session.billing_session_type
!= BillingSessionType.DIRECT_PAYMENT.value
):
return RedirectResponse(
f'{request.base_url}settings?checkout=success', status_code=302
)
stripe_session = stripe.checkout.Session.retrieve(session_id)
if stripe_session.status != 'complete':
# Hopefully this never happens - we get a redirect from stripe where the payment is not yet complete
@@ -431,31 +215,39 @@ async def success_callback(session_id: str, request: Request):
)
raise HTTPException(status.HTTP_400_BAD_REQUEST)
async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
# Update max budget in litellm
user_json = await _get_litellm_user(client, billing_session.user_id)
amount_subtotal = stripe_session.amount_subtotal or 0
add_credits = amount_subtotal / 100
new_max_budget = (
(user_json.get('user_info') or {}).get('max_budget') or 0
) + add_credits
await _upsert_litellm_user(client, billing_session.user_id, new_max_budget)
user = await call_sync_from_async(
UserStore.get_user_by_id, billing_session.user_id
)
user_team_info = await LiteLlmManager.get_user_team_info(
billing_session.user_id, str(user.current_org_id)
)
amount_subtotal = stripe_session.amount_subtotal or 0
add_credits = amount_subtotal / 100
max_budget = (user_team_info.get('litellm_budget_table') or {}).get(
'max_budget', 0
)
new_max_budget = max_budget + add_credits
# Store transaction status
billing_session.status = 'completed'
billing_session.price = amount_subtotal
billing_session.updated_at = datetime.now(UTC)
session.merge(billing_session)
logger.info(
'stripe_checkout_success',
extra={
'amount_subtotal': stripe_session.amount_subtotal,
'user_id': billing_session.user_id,
'checkout_session_id': billing_session.id,
'stripe_customer_id': stripe_session.customer,
},
)
session.commit()
await LiteLlmManager.update_team_and_users_budget(
str(user.current_org_id), new_max_budget
)
# Store transaction status
billing_session.status = 'completed'
billing_session.price = add_credits
billing_session.updated_at = datetime.now(UTC)
session.merge(billing_session)
logger.info(
'stripe_checkout_success',
extra={
'amount_subtotal': stripe_session.amount_subtotal,
'user_id': billing_session.user_id,
'org_id': str(user.current_org_id),
'checkout_session_id': billing_session.id,
'stripe_customer_id': stripe_session.customer,
},
)
session.commit()
return RedirectResponse(
f'{request.base_url}settings/billing?checkout=success', status_code=302
@@ -485,206 +277,6 @@ async def cancel_callback(session_id: str, request: Request):
session.merge(billing_session)
session.commit()
# Redirect credit purchases to billing screen, subscriptions to LLM settings
if (
billing_session.billing_session_type
== BillingSessionType.DIRECT_PAYMENT.value
):
return RedirectResponse(
f'{request.base_url}settings/billing?checkout=cancel',
status_code=302,
)
else:
return RedirectResponse(
f'{request.base_url}settings?checkout=cancel', status_code=302
)
# If no billing session found, default to LLM settings (subscription flow)
return RedirectResponse(
f'{request.base_url}settings?checkout=cancel', status_code=302
f'{request.base_url}settings/billing?checkout=cancel', status_code=302
)
@billing_router.post('/stripe-webhook')
async def stripe_webhook(request: Request) -> JSONResponse:
"""Endpoint for stripe webhooks."""
payload = await request.body()
sig_header = request.headers.get('stripe-signature')
try:
event = stripe.Webhook.construct_event(
payload, sig_header, STRIPE_WEBHOOK_SECRET
)
except ValueError as e:
# Invalid payload
raise HTTPException(status_code=400, detail=f'Invalid payload: {e}')
except stripe.SignatureVerificationError as e:
# Invalid signature
raise HTTPException(status_code=400, detail=f'Invalid signature: {e}')
# Handle the event
logger.info('stripe_webhook_event', extra={'event': event})
event_type = event['type']
if event_type == 'invoice.paid':
invoice = event['data']['object']
amount_paid = invoice.amount_paid
metadata = invoice.parent.subscription_details.metadata # type: ignore
billing_session_type = metadata.billing_session_type
assert (
amount_paid == SUBSCRIPTION_PRICE_DATA[billing_session_type]['unit_amount']
)
user_id = metadata.user_id
start_at = datetime.now(UTC)
if billing_session_type == BillingSessionType.MONTHLY_SUBSCRIPTION.value:
end_at = start_at + relativedelta(months=1)
else:
raise ValueError(f'unknown_billing_session_type:{billing_session_type}')
with session_maker() as session:
subscription_access = SubscriptionAccess(
status='ACTIVE',
user_id=user_id,
start_at=start_at,
end_at=end_at,
amount_paid=amount_paid,
stripe_invoice_payment_id=invoice.payment_intent,
stripe_subscription_id=invoice.subscription, # Store Stripe subscription ID
)
session.add(subscription_access)
session.commit()
elif event_type == 'customer.subscription.updated':
subscription = event['data']['object']
subscription_id = subscription['id']
# Handle subscription cancellation
if subscription.get('cancel_at_period_end') is True:
with session_maker() as session:
subscription_access = (
session.query(SubscriptionAccess)
.filter(
SubscriptionAccess.stripe_subscription_id == subscription_id
)
.filter(SubscriptionAccess.status == 'ACTIVE')
.first()
)
if subscription_access and not subscription_access.cancelled_at:
subscription_access.cancelled_at = datetime.now(UTC)
session.merge(subscription_access)
session.commit()
logger.info(
'subscription_cancelled_via_webhook',
extra={
'stripe_subscription_id': subscription_id,
'user_id': subscription_access.user_id,
'subscription_access_id': subscription_access.id,
},
)
elif event_type == 'customer.subscription.deleted':
subscription = event['data']['object']
subscription_id = subscription['id']
with session_maker() as session:
subscription_access = (
session.query(SubscriptionAccess)
.filter(SubscriptionAccess.stripe_subscription_id == subscription_id)
.filter(SubscriptionAccess.status == 'ACTIVE')
.first()
)
if subscription_access:
subscription_access.status = 'DISABLED'
subscription_access.updated_at = datetime.now(UTC)
session.merge(subscription_access)
session.commit()
# Reset user settings to free tier defaults
reset_user_to_free_tier_settings(subscription_access.user_id)
logger.info(
'subscription_expired_reset_to_free_tier',
extra={
'stripe_subscription_id': subscription_id,
'user_id': subscription_access.user_id,
'subscription_access_id': subscription_access.id,
},
)
else:
logger.info('stripe_webhook_unhandled_event_type', extra={'type': event_type})
return JSONResponse({'status': 'success'})
def reset_user_to_free_tier_settings(user_id: str) -> None:
"""Reset user settings to free tier defaults when subscription ends."""
config = get_config()
settings_store = SaasSettingsStore(
user_id=user_id, session_maker=session_maker, config=config
)
with session_maker() as session:
user_settings = settings_store.get_user_settings_by_keycloak_id(
user_id, session
)
if user_settings:
user_settings.llm_model = get_default_litellm_model()
user_settings.llm_api_key = None
user_settings.llm_api_key_for_byor = None
user_settings.llm_base_url = LITE_LLM_API_URL
user_settings.max_budget_per_task = None
user_settings.confirmation_mode = False
user_settings.enable_solvability_analysis = False
user_settings.security_analyzer = 'llm'
user_settings.agent = 'CodeActAgent'
user_settings.language = 'en'
user_settings.enable_default_condenser = True
user_settings.enable_sound_notifications = False
user_settings.enable_proactive_conversation_starters = True
user_settings.user_consents_to_analytics = False
session.merge(user_settings)
session.commit()
logger.info(
'user_settings_reset_to_free_tier',
extra={
'user_id': user_id,
'reset_timestamp': datetime.now(UTC).isoformat(),
},
)
async def _get_litellm_user(client: httpx.AsyncClient, user_id: str) -> dict:
"""Get a user from litellm with the id matching that given.
If no such user exists, returns a dummy user in the format:
`{'user_id': '<USER_ID>', 'user_info': {'spend': 0}, 'keys': [], 'teams': []}`
"""
response = await client.get(
f'{LITE_LLM_API_URL}/user/info?user_id={user_id}',
headers={
'x-goog-api-key': LITE_LLM_API_KEY,
},
)
response.raise_for_status()
return response.json()
async def _upsert_litellm_user(
client: httpx.AsyncClient, user_id: str, max_budget: float
):
"""Insert / Update a user in litellm."""
response = await client.post(
f'{LITE_LLM_API_URL}/user/update',
headers={
'x-goog-api-key': LITE_LLM_API_KEY,
},
json={
'user_id': user_id,
'max_budget': max_budget,
},
)
response.raise_for_status()

View File

@@ -6,7 +6,7 @@ from threading import Thread
from fastapi import APIRouter, FastAPI
from sqlalchemy import func, select
from storage.database import a_session_maker, engine, session_maker
from storage.user_settings import UserSettings
from storage.user import User
from openhands.core.logger import openhands_logger as logger
from openhands.utils.async_utils import wait_all
@@ -127,7 +127,7 @@ def _db_check(delay: int):
delay: Number of seconds to hold the database connection
"""
with session_maker() as session:
num_users = session.query(UserSettings).count()
num_users = session.query(User).count()
time.sleep(delay)
logger.info(
'check',
@@ -155,7 +155,7 @@ async def _a_db_check(delay: int):
delay: Number of seconds to hold the database connection
"""
async with a_session_maker() as a_session:
stmt = select(func.count(UserSettings.id))
stmt = select(func.count(User.id))
num_users = await a_session.execute(stmt)
await asyncio.sleep(delay)
logger.info(f'a_num_users:{num_users.scalar_one()}')

View File

@@ -21,7 +21,7 @@ from server.utils.conversation_callback_utils import (
update_conversation_stats,
)
from storage.database import session_maker
from storage.stored_conversation_metadata import StoredConversationMetadata
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
from openhands.server.shared import conversation_manager
@@ -226,12 +226,12 @@ def _parse_conversation_id_and_subpath(path: str) -> Tuple[str, str]:
def _get_user_id(conversation_id: str) -> str:
with session_maker() as session:
conversation_metadata = (
session.query(StoredConversationMetadata)
.filter(StoredConversationMetadata.conversation_id == conversation_id)
conversation_metadata_saas = (
session.query(StoredConversationMetadataSaas)
.filter(StoredConversationMetadataSaas.conversation_id == conversation_id)
.first()
)
return conversation_metadata.user_id
return str(conversation_metadata_saas.user_id)
async def _get_session_api_key(user_id: str, conversation_id: str) -> str | None:

View File

@@ -5,7 +5,7 @@ from pydantic import BaseModel, Field
from sqlalchemy.future import select
from storage.database import session_maker
from storage.feedback import ConversationFeedback
from storage.stored_conversation_metadata import StoredConversationMetadata
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
from openhands.events.event_store import EventStore
from openhands.server.shared import file_store
@@ -33,10 +33,10 @@ async def get_event_ids(conversation_id: str, user_id: str) -> List[int]:
def _verify_conversation():
with session_maker() as session:
metadata = (
session.query(StoredConversationMetadata)
session.query(StoredConversationMetadataSaas)
.filter(
StoredConversationMetadata.conversation_id == conversation_id,
StoredConversationMetadata.user_id == user_id,
StoredConversationMetadataSaas.conversation_id == conversation_id,
StoredConversationMetadataSaas.user_id == user_id,
)
.first()
)

View File

@@ -15,7 +15,6 @@ from integrations.slack.slack_manager import SlackManager
from integrations.utils import (
HOST_URL,
)
from pydantic import SecretStr
from server.auth.constants import (
KEYCLOAK_CLIENT_ID,
KEYCLOAK_REALM_NAME,
@@ -35,9 +34,11 @@ from slack_sdk.web.async_client import AsyncWebClient
from storage.database import session_maker
from storage.slack_team_store import SlackTeamStore
from storage.slack_user import SlackUser
from storage.user_store import UserStore
from openhands.integrations.service_types import ProviderType
from openhands.server.shared import config, sio
from openhands.utils.async_utils import call_sync_from_async
signature_verifier = SignatureVerifier(signing_secret=SLACK_SIGNING_SECRET)
slack_router = APIRouter(prefix='/slack')
@@ -79,6 +80,14 @@ async def install_callback(
status_code=400,
)
if not config.jwt_secret:
logger.error('slack_install_callback_error JWT not configured.')
return _html_response(
title='Error',
description=html.escape('JWT not configured'),
status_code=500,
)
try:
client = AsyncWebClient() # no prepared token needed for this
# Complete the installation by calling oauth.v2.access API method
@@ -94,16 +103,17 @@ async def install_callback(
# Create a state variable for keycloak oauth
payload = {}
jwt_secret: SecretStr = config.jwt_secret # type: ignore[assignment]
if state:
payload = jwt.decode(
state, jwt_secret.get_secret_value(), algorithms=['HS256']
state, config.jwt_secret.get_secret_value(), algorithms=['HS256']
)
payload['slack_user_id'] = authed_user.get('id')
payload['bot_access_token'] = bot_access_token
payload['team_id'] = team_id
state = jwt.encode(payload, jwt_secret.get_secret_value(), algorithm='HS256')
state = jwt.encode(
payload, config.jwt_secret.get_secret_value(), algorithm='HS256'
)
# Redirect into keycloak
scope = quote('openid email profile offline_access')
@@ -149,9 +159,16 @@ async def keycloak_callback(
status_code=400,
)
jwt_secret: SecretStr = config.jwt_secret # type: ignore[assignment]
if not config.jwt_secret:
logger.error('problem_retrieving_keycloak_tokens JWT not configured.')
return _html_response(
title='Error',
description=html.escape('JWT not configured'),
status_code=500,
)
payload: dict[str, str] = jwt.decode(
state, jwt_secret.get_secret_value(), algorithms=['HS256']
state, config.jwt_secret.get_secret_value(), algorithms=['HS256']
)
slack_user_id = payload['slack_user_id']
bot_access_token = payload['bot_access_token']
@@ -180,6 +197,13 @@ async def keycloak_callback(
user_info = await token_manager.get_user_info(keycloak_access_token)
keycloak_user_id = user_info['sub']
user = await call_sync_from_async(UserStore.get_user_by_id, keycloak_user_id)
if not user:
return _html_response(
title='Failed to authenticate.',
description=f'Please re-login into <a href="{HOST_URL}" style="color:#ecedee;text-decoration:underline;">OpenHands Cloud</a>. Then try <a href="https://docs.all-hands.dev/usage/cloud/slack-installation" style="color:#ecedee;text-decoration:underline;">installing the OpenHands Slack App</a> again',
status_code=400,
)
# These tokens are offline access tokens - store them!
await token_manager.store_offline_token(keycloak_user_id, keycloak_refresh_token)
@@ -211,6 +235,7 @@ async def keycloak_callback(
slack_display_name = slack_user_info.data['user']['profile']['display_name']
slack_user = SlackUser(
keycloak_user_id=keycloak_user_id,
org_id=user.current_org_id,
slack_user_id=slack_user_id,
slack_display_name=slack_display_name,
)
@@ -305,7 +330,7 @@ async def on_form_interaction(request: Request, background_tasks: BackgroundTask
body = await request.body()
form = await request.form()
payload = json.loads(form.get('payload')) # type: ignore[arg-type]
payload = json.loads(form.get('payload'))
logger.info('slack_on_form_interaction', extra={'payload': payload})

View File

@@ -20,7 +20,10 @@ from server.utils.conversation_callback_utils import (
from sqlalchemy import orm
from storage.api_key_store import ApiKeyStore
from storage.database import session_maker
from storage.stored_conversation_metadata import StoredConversationMetadata
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
StoredConversationMetadata,
)
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
from openhands.controller.agent import Agent
from openhands.core.config import LLMConfig, OpenHandsConfig
@@ -530,16 +533,18 @@ class SaasNestedConversationManager(ConversationManager):
"""
with session_maker() as session:
conversation_metadata = (
session.query(StoredConversationMetadata)
.filter(StoredConversationMetadata.conversation_id == conversation_id)
conversation_metadata_saas = (
session.query(StoredConversationMetadataSaas)
.filter(
StoredConversationMetadataSaas.conversation_id == conversation_id
)
.first()
)
if not conversation_metadata:
if not conversation_metadata_saas:
raise ValueError(f'No conversation found {conversation_id}')
return conversation_metadata.user_id
return str(conversation_metadata_saas.user_id)
async def _get_runtime_status_from_nested_runtime(
self, session_api_key: Any | None, nested_url: str, conversation_id: str
@@ -868,9 +873,17 @@ class SaasNestedConversationManager(ConversationManager):
with session_maker() as session:
# Only include conversations updated in the past week
one_week_ago = datetime.now(UTC) - timedelta(days=7)
query = session.query(StoredConversationMetadata.conversation_id).filter(
StoredConversationMetadata.user_id == user_id,
StoredConversationMetadata.last_updated_at >= one_week_ago,
query = (
session.query(StoredConversationMetadata.conversation_id)
.join(
StoredConversationMetadataSaas,
StoredConversationMetadata.conversation_id
== StoredConversationMetadataSaas.conversation_id,
)
.filter(
StoredConversationMetadataSaas.user_id == user_id,
StoredConversationMetadata.last_updated_at >= one_week_ago,
)
)
user_conversation_ids = set(query)
return user_conversation_ids
@@ -944,11 +957,16 @@ class SaasNestedConversationManager(ConversationManager):
.filter(StoredConversationMetadata.conversation_id == conversation_id)
.first()
)
if conversation_metadata is None:
conversation_metadata_saas = (
session.query(StoredConversationMetadataSaas)
.filter(StoredConversationMetadataSaas.conversation_id == conversation_id)
.first()
)
if conversation_metadata is None or conversation_metadata_saas is None:
# Conversation is running in different server
return
user_id = conversation_metadata.user_id
user_id = conversation_metadata_saas.user_id
# Get the id of the next event which is not present
events_dir = get_conversation_events_dir(

View File

@@ -11,7 +11,6 @@ from storage.conversation_callback import (
)
from storage.conversation_work import ConversationWork
from storage.database import session_maker
from storage.stored_conversation_metadata import StoredConversationMetadata
from openhands.core.config import load_openhands_config
from openhands.core.schema.agent import AgentState
@@ -126,6 +125,12 @@ def update_conversation_metadata(conversation_id: str, content: dict):
conversation_id: The conversation ID to update
content: The metadata content to update
"""
# Local import fixes the lazy-loading problem
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
StoredConversationMetadata,
)
logger.debug(
'update_conversation_metadata',
extra={

View File

@@ -0,0 +1,87 @@
from storage.api_key import ApiKey
from storage.auth_tokens import AuthTokens
from storage.billing_session import BillingSession
from storage.billing_session_type import BillingSessionType
from storage.conversation_callback import CallbackStatus, ConversationCallback
from storage.conversation_work import ConversationWork
from storage.experiment_assignment import ExperimentAssignment
from storage.feedback import ConversationFeedback, Feedback
from storage.github_app_installation import GithubAppInstallation
from storage.gitlab_webhook import GitlabWebhook, WebhookStatus
from storage.jira_conversation import JiraConversation
from storage.jira_dc_conversation import JiraDcConversation
from storage.jira_dc_user import JiraDcUser
from storage.jira_dc_workspace import JiraDcWorkspace
from storage.jira_user import JiraUser
from storage.jira_workspace import JiraWorkspace
from storage.linear_conversation import LinearConversation
from storage.linear_user import LinearUser
from storage.linear_workspace import LinearWorkspace
from storage.maintenance_task import MaintenanceTask, MaintenanceTaskStatus
from storage.openhands_pr import OpenhandsPR
from storage.org import Org
from storage.org_member import OrgMember
from storage.proactive_convos import ProactiveConversation
from storage.role import Role
from storage.slack_conversation import SlackConversation
from storage.slack_team import SlackTeam
from storage.slack_user import SlackUser
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
StoredConversationMetadata,
)
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
from storage.stored_custom_secrets import StoredCustomSecrets
from storage.stored_offline_token import StoredOfflineToken
from storage.stored_repository import StoredRepository
from storage.stripe_customer import StripeCustomer
from storage.subscription_access import SubscriptionAccess
from storage.subscription_access_status import SubscriptionAccessStatus
from storage.user import User
from storage.user_repo_map import UserRepositoryMap
from storage.user_settings import UserSettings
__all__ = [
'ApiKey',
'AuthTokens',
'BillingSession',
'BillingSessionType',
'CallbackStatus',
'ConversationCallback',
'ConversationFeedback',
'StoredConversationMetadataSaas',
'ConversationWork',
'ExperimentAssignment',
'Feedback',
'GithubAppInstallation',
'GitlabWebhook',
'JiraConversation',
'JiraDcConversation',
'JiraDcUser',
'JiraDcWorkspace',
'JiraUser',
'JiraWorkspace',
'LinearConversation',
'LinearUser',
'LinearWorkspace',
'MaintenanceTask',
'MaintenanceTaskStatus',
'OpenhandsPR',
'Org',
'OrgMember',
'ProactiveConversation',
'Role',
'SlackConversation',
'SlackTeam',
'SlackUser',
'StoredConversationMetadata',
'StoredOfflineToken',
'StoredRepository',
'StoredCustomSecrets',
'StripeCustomer',
'SubscriptionAccess',
'SubscriptionAccessStatus',
'User',
'UserRepositoryMap',
'UserSettings',
'WebhookStatus',
]

View File

@@ -1,4 +1,6 @@
from sqlalchemy import Column, DateTime, Integer, String, text
from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, text
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import relationship
from storage.base import Base
@@ -11,9 +13,13 @@ class ApiKey(Base):
id = Column(Integer, primary_key=True, autoincrement=True)
key = Column(String(255), nullable=False, unique=True, index=True)
user_id = Column(String(255), nullable=False, index=True)
org_id = Column(UUID(as_uuid=True), ForeignKey('org.id'), nullable=True)
name = Column(String(255), nullable=True)
created_at = Column(
DateTime, server_default=text('CURRENT_TIMESTAMP'), nullable=False
)
last_used_at = Column(DateTime, nullable=True)
expires_at = Column(DateTime, nullable=True)
# Relationships
org = relationship('Org', back_populates='api_keys')

View File

@@ -9,6 +9,7 @@ from sqlalchemy import update
from sqlalchemy.orm import sessionmaker
from storage.api_key import ApiKey
from storage.database import session_maker
from storage.user_store import UserStore
from openhands.core.logger import openhands_logger as logger
@@ -36,10 +37,15 @@ class ApiKeyStore:
The generated API key
"""
api_key = self.generate_api_key()
user = UserStore.get_user_by_id(user_id)
org_id = user.current_org_id
with self.session_maker() as session:
key_record = ApiKey(
key=api_key, user_id=user_id, name=name, expires_at=expires_at
key=api_key,
user_id=user_id,
org_id=org_id,
name=name,
expires_at=expires_at,
)
session.add(key_record)
session.commit()
@@ -99,8 +105,15 @@ class ApiKeyStore:
def list_api_keys(self, user_id: str) -> list[dict]:
"""List all API keys for a user."""
user = UserStore.get_user_by_id(user_id)
org_id = user.current_org_id
with self.session_maker() as session:
keys = session.query(ApiKey).filter(ApiKey.user_id == user_id).all()
keys = (
session.query(ApiKey)
.filter(ApiKey.user_id == user_id)
.filter(ApiKey.org_id == org_id)
.all()
)
return [
{
@@ -115,9 +128,14 @@ class ApiKeyStore:
]
def retrieve_mcp_api_key(self, user_id: str) -> str | None:
user = UserStore.get_user_by_id(user_id)
org_id = user.current_org_id
with self.session_maker() as session:
keys: list[ApiKey] = (
session.query(ApiKey).filter(ApiKey.user_id == user_id).all()
session.query(ApiKey)
.filter(ApiKey.user_id == user_id)
.filter(ApiKey.org_id == org_id)
.all()
)
for key in keys:
if key.name == 'MCP_API_KEY':

View File

@@ -1,6 +1,8 @@
from datetime import UTC, datetime
from sqlalchemy import DECIMAL, Column, DateTime, Enum, String
from sqlalchemy import DECIMAL, Column, DateTime, Enum, ForeignKey, String
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import relationship
from storage.base import Base
@@ -11,9 +13,9 @@ class BillingSession(Base): # type: ignore
"""
__tablename__ = 'billing_sessions'
id = Column(String, primary_key=True)
user_id = Column(String, nullable=False)
org_id = Column(UUID(as_uuid=True), ForeignKey('org.id'), nullable=True)
status = Column(
Enum(
'in_progress',
@@ -24,15 +26,6 @@ class BillingSession(Base): # type: ignore
),
default='in_progress',
)
billing_session_type = Column(
Enum(
'DIRECT_PAYMENT',
'MONTHLY_SUBSCRIPTION',
name='billing_session_type_enum',
),
nullable=False,
default='DIRECT_PAYMENT',
)
price = Column(DECIMAL(19, 4), nullable=False)
price_code = Column(String, nullable=False)
created_at = Column(
@@ -43,3 +36,6 @@ class BillingSession(Base): # type: ignore
DateTime(timezone=True),
default=lambda: datetime.now(UTC), # type: ignore[attr-defined]
)
# Relationships
org = relationship('Org', back_populates='billing_sessions')

View File

@@ -1,5 +1,6 @@
import asyncio
import os
import sys
from sqlalchemy import create_engine
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
@@ -7,6 +8,9 @@ from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import NullPool
from sqlalchemy.util import await_only
# Check if we're running in a test environment
IS_TESTING = 'pytest' in sys.modules
DB_HOST = os.environ.get('DB_HOST', 'localhost') # for non-GCP environments
DB_PORT = os.environ.get('DB_PORT', '5432') # for non-GCP environments
DB_USER = os.environ.get('DB_USER', 'postgres')

View File

@@ -0,0 +1,114 @@
import binascii
import hashlib
from base64 import b64decode, b64encode
from cryptography.fernet import Fernet, InvalidToken
from pydantic import SecretStr
from server.config import get_config
_jwt_service = None
_fernet = None
def encrypt_model(encrypt_keys: list, model_instance) -> dict:
return encrypt_kwargs(encrypt_keys, model_to_kwargs(model_instance))
def decrypt_model(decrypt_keys: list, model_instance) -> dict:
return decrypt_kwargs(decrypt_keys, model_to_kwargs(model_instance))
def encrypt_kwargs(encrypt_keys: list, kwargs: dict) -> dict:
for key, value in kwargs.items():
if value is None:
continue
if isinstance(value, dict):
encrypt_kwargs(encrypt_keys, value)
continue
if key in encrypt_keys:
value = encrypt_value(value)
kwargs[key] = value
return kwargs
def decrypt_kwargs(encrypt_keys: list, kwargs: dict) -> dict:
for key, value in kwargs.items():
try:
if value is None:
continue
if key in encrypt_keys:
value = decrypt_value(value)
kwargs[key] = value
except binascii.Error:
pass # Key is in legacy format...
return kwargs
def encrypt_value(value: str | SecretStr) -> str:
return get_jwt_service().create_jwe_token(
{'v': value.get_secret_value() if isinstance(value, SecretStr) else value}
)
def decrypt_value(value: str | SecretStr) -> str:
token = get_jwt_service().decrypt_jwe_token(
value.get_secret_value() if isinstance(value, SecretStr) else value
)
return token['v']
def get_jwt_service():
from openhands.app_server.config import get_global_config
global _jwt_service
if _jwt_service is None:
jwt_service_injector = get_global_config().jwt
assert jwt_service_injector is not None
_jwt_service = jwt_service_injector.get_jwt_service()
return _jwt_service
def decrypt_legacy_model(decrypt_keys: list, model_instance) -> dict:
return decrypt_legacy_kwargs(decrypt_keys, model_to_kwargs(model_instance))
def decrypt_legacy_kwargs(encrypt_keys: list, kwargs: dict) -> dict:
for key, value in kwargs.items():
try:
if value is None:
continue
if key in encrypt_keys:
value = decrypt_legacy_value(value)
kwargs[key] = value
except binascii.Error:
pass # Key is in legacy format...
except InvalidToken:
pass # Key not encrypted...
return kwargs
def decrypt_legacy_value(value: str | SecretStr) -> str:
if isinstance(value, SecretStr):
return (
get_fernet().decrypt(b64decode(value.get_secret_value().encode())).decode()
)
else:
return get_fernet().decrypt(b64decode(value.encode())).decode()
def get_fernet():
global _fernet
if _fernet is None:
jwt_secret = get_config().jwt_secret.get_secret_value()
fernet_key = b64encode(hashlib.sha256(jwt_secret.encode()).digest())
_fernet = Fernet(fernet_key)
return _fernet
def model_to_kwargs(model_instance):
return {
column.name: getattr(model_instance, column.name)
for column in model_instance.__table__.columns
}

View File

@@ -1,7 +1,16 @@
import sys
from enum import IntEnum
from sqlalchemy import ARRAY, Boolean, Column, DateTime, Integer, String, Text, text
from sqlalchemy import (
ARRAY,
Boolean,
Column,
DateTime,
Integer,
String,
Text,
text,
)
from storage.base import Base

View File

@@ -0,0 +1,674 @@
"""
Store class for managing organizational settings.
"""
import functools
import os
from typing import Any, Awaitable, Callable
import httpx
from pydantic import SecretStr
from server.auth.token_manager import TokenManager
from server.constants import (
DEFAULT_INITIAL_BUDGET,
LITE_LLM_API_KEY,
LITE_LLM_API_URL,
LITE_LLM_TEAM_ID,
ORG_SETTINGS_VERSION,
get_default_litellm_model,
)
from server.logger import logger
from storage.user_settings import UserSettings
from openhands.server.settings import Settings
from openhands.utils.async_utils import call_sync_from_async
class LiteLlmManager:
"""Manage LiteLLM interactions."""
@staticmethod
async def create_entries(
org_id: str,
keycloak_user_id: str,
oss_settings: Settings,
) -> Settings | None:
logger.info(
'SettingsStore:update_settings_with_litellm_default:start',
extra={'org_id': org_id, 'user_id': keycloak_user_id},
)
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
logger.warning('LiteLLM API configuration not found')
return None
local_deploy = os.environ.get('LOCAL_DEPLOYMENT', None)
key = LITE_LLM_API_KEY
if not local_deploy:
# Get user info to add to litellm
token_manager = TokenManager()
keycloak_user_info = (
await token_manager.get_user_info_from_user_id(keycloak_user_id) or {}
)
async with httpx.AsyncClient(
headers={
'x-goog-api-key': LITE_LLM_API_KEY,
}
) as client:
await LiteLlmManager._create_team(
client, keycloak_user_id, org_id, DEFAULT_INITIAL_BUDGET
)
await LiteLlmManager._create_user(
client, keycloak_user_info.get('email'), keycloak_user_id
)
await LiteLlmManager._add_user_to_team(
client, keycloak_user_id, org_id, DEFAULT_INITIAL_BUDGET
)
key = await LiteLlmManager._generate_key(
client,
keycloak_user_id,
org_id,
f'OpenHands Cloud - user {keycloak_user_id}',
None,
)
oss_settings.agent = 'CodeActAgent'
# Use the model corresponding to the current user settings version
oss_settings.llm_model = get_default_litellm_model()
oss_settings.llm_api_key = SecretStr(key)
oss_settings.llm_base_url = LITE_LLM_API_URL
return oss_settings
@staticmethod
async def migrate_entries(
org_id: str,
keycloak_user_id: str,
user_settings: UserSettings,
) -> UserSettings | None:
logger.info(
'SettingsStore:umigrate_lite_llm_entries:start',
extra={'org_id': org_id, 'user_id': keycloak_user_id},
)
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
logger.warning('LiteLLM API configuration not found')
return None
local_deploy = os.environ.get('LOCAL_DEPLOYMENT', None)
if not local_deploy:
# Get user info to add to litellm
async with httpx.AsyncClient(
headers={
'x-goog-api-key': LITE_LLM_API_KEY,
}
) as client:
user_json = await LiteLlmManager._get_user(client, keycloak_user_id)
if not user_json:
return None
user_info = user_json['user_info']
max_budget = user_info.get('max_budget', 0.0)
if not max_budget:
# if max_budget is None, then we've already migrated the User
return None
spend = user_info.get('spend', 0.0)
credits = max(max_budget - spend, 0.0)
await LiteLlmManager._create_team(
client, keycloak_user_id, org_id, credits
)
await LiteLlmManager._update_user(
client, keycloak_user_id, max_budget=1000000000.0
)
await LiteLlmManager._add_user_to_team(
client, keycloak_user_id, org_id, credits
)
if user_settings.llm_api_key:
await LiteLlmManager._update_key(
client,
keycloak_user_id,
user_settings.llm_api_key,
team_id=org_id,
)
if user_settings.llm_api_key_for_byor:
await LiteLlmManager._update_key(
client,
keycloak_user_id,
user_settings.llm_api_key_for_byor,
team_id=org_id,
)
user_settings.agent = 'CodeActAgent'
# Use the model corresponding to the current user settings version
user_settings.llm_model = get_default_litellm_model()
user_settings.llm_base_url = LITE_LLM_API_URL
return user_settings
@staticmethod
async def update_team_and_users_budget(
team_id: str,
max_budget: float,
):
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
logger.warning('LiteLLM API configuration not found')
return
async with httpx.AsyncClient(
headers={
'x-goog-api-key': LITE_LLM_API_KEY,
}
) as client:
await LiteLlmManager._update_team(client, team_id, None, max_budget)
team_info = await LiteLlmManager._get_team(client, team_id)
if not team_info:
return None
for membership in team_info.get('team_memberships', []):
user_id = membership.get('user_id')
if not user_id:
continue
await LiteLlmManager._update_user_in_team(
client, user_id, team_id, max_budget
)
@staticmethod
async def _create_team(
client: httpx.AsyncClient,
team_alias: str,
team_id: str,
max_budget: float,
):
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
logger.warning('LiteLLM API configuration not found')
return
response = await client.post(
f'{LITE_LLM_API_URL}/team/new',
json={
'team_id': team_id,
'team_alias': team_alias,
'models': [],
'max_budget': max_budget,
'spend': 0,
'metadata': {
'version': ORG_SETTINGS_VERSION,
'model': get_default_litellm_model(),
},
},
)
# Team failed to create in litellm - this is an unforseen error state...
if not response.is_success:
if (
response.status_code == 400
and 'already exists. Please use a different team id' in response.text
):
# team already exists, so update, then return
await LiteLlmManager._update_team(
client, team_id, team_alias, max_budget
)
return
logger.error(
'error_creating_litellm_team',
extra={
'status_code': response.status_code,
'text': response.text,
'team_id': team_id,
'max_budget': max_budget,
},
)
response.raise_for_status()
@staticmethod
async def _get_team(client: httpx.AsyncClient, team_id: str) -> dict | None:
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
logger.warning('LiteLLM API configuration not found')
return None
"""Get a team from litellm with the id matching that given."""
response = await client.get(
f'{LITE_LLM_API_URL}/team/info?team_id={team_id}',
)
response.raise_for_status()
return response.json()
@staticmethod
async def _update_team(
client: httpx.AsyncClient,
team_id: str,
team_alias: str | None,
max_budget: float | None,
):
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
logger.warning('LiteLLM API configuration not found')
return
json_data: dict[str, Any] = {
'team_id': team_id,
'metadata': {
'version': ORG_SETTINGS_VERSION,
'model': get_default_litellm_model(),
},
}
if max_budget is not None:
json_data['max_budget'] = max_budget
if team_alias is not None:
json_data['team_alias'] = team_alias
response = await client.post(
f'{LITE_LLM_API_URL}/team/update',
json=json_data,
)
# Team failed to update in litellm - this is an unforseen error state...
if not response.is_success:
logger.error(
'error_updating_litellm_team',
extra={
'status_code': response.status_code,
'text': response.text,
'team_id': [team_id],
'max_budget': max_budget,
},
)
response.raise_for_status()
@staticmethod
async def _create_user(
client: httpx.AsyncClient,
email: str | None,
keycloak_user_id: str,
):
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
logger.warning('LiteLLM API configuration not found')
return
response = await client.post(
f'{LITE_LLM_API_URL}/user/new',
json={
'user_email': email,
'models': [],
'user_id': keycloak_user_id,
'teams': [LITE_LLM_TEAM_ID],
'auto_create_key': False,
'send_invite_email': False,
'metadata': {
'version': ORG_SETTINGS_VERSION,
'model': get_default_litellm_model(),
},
},
)
if not response.is_success:
logger.warning(
'duplicate_user_email',
extra={
'user_id': keycloak_user_id,
'email': email,
},
)
# Litellm insists on unique email addresses - it is possible the email address was registered with a different user.
response = await client.post(
f'{LITE_LLM_API_URL}/user/new',
json={
'user_email': None,
'models': [],
'user_id': keycloak_user_id,
'teams': [LITE_LLM_TEAM_ID],
'auto_create_key': False,
'send_invite_email': False,
'metadata': {
'version': ORG_SETTINGS_VERSION,
'model': get_default_litellm_model(),
},
},
)
# User failed to create in litellm - this is an unforseen error state...
if not response.is_success:
if response.status_code == 400 and 'already exists' in response.text:
# user already exists, just return
return
logger.error(
'error_creating_litellm_user',
extra={
'status_code': response.status_code,
'text': response.text,
'user_id': [keycloak_user_id],
'email': None,
},
)
response.raise_for_status()
@staticmethod
async def _get_user(client: httpx.AsyncClient, user_id: str) -> dict | None:
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
logger.warning('LiteLLM API configuration not found')
return None
"""Get a user from litellm with the id matching that given."""
response = await client.get(
f'{LITE_LLM_API_URL}/user/info?user_id={user_id}',
)
response.raise_for_status()
return response.json()
@staticmethod
async def _update_user(
client: httpx.AsyncClient,
keycloak_user_id: str,
**kwargs,
):
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
logger.warning('LiteLLM API configuration not found')
return
payload = {
'user_id': keycloak_user_id,
}
payload.update(kwargs)
response = await client.post(
f'{LITE_LLM_API_URL}/user/update',
json=payload,
)
if not response.is_success:
logger.error(
'error_updating_litellm_user',
extra={
'status_code': response.status_code,
'text': response.text,
'user_id': keycloak_user_id,
},
)
response.raise_for_status()
@staticmethod
async def _update_key(
client: httpx.AsyncClient,
keycloak_user_id: str,
key: str,
**kwargs,
):
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
logger.warning('LiteLLM API configuration not found')
return
payload = {
'key': key,
}
payload.update(kwargs)
response = await client.post(
f'{LITE_LLM_API_URL}/key/update',
json=payload,
)
if not response.is_success:
logger.error(
'error_updating_litellm_key',
extra={
'status_code': response.status_code,
'text': response.text,
'user_id': keycloak_user_id,
},
)
response.raise_for_status()
@staticmethod
async def _delete_user(
client: httpx.AsyncClient,
keycloak_user_id: str,
):
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
logger.warning('LiteLLM API configuration not found')
return
response = await client.post(
f'{LITE_LLM_API_URL}/user/delete', json={'user_ids': [keycloak_user_id]}
)
if not response.is_success:
logger.error(
'error_deleting_litellm_user',
extra={
'status_code': response.status_code,
'text': response.text,
'user_id': [keycloak_user_id],
},
)
response.raise_for_status()
@staticmethod
async def _add_user_to_team(
client: httpx.AsyncClient,
keycloak_user_id: str,
team_id: str,
max_budget: float,
):
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
logger.warning('LiteLLM API configuration not found')
return
response = await client.post(
f'{LITE_LLM_API_URL}/team/member_add',
json={
'team_id': team_id,
'member': {'user_id': keycloak_user_id, 'role': 'user'},
'max_budget_in_team': max_budget,
},
)
# Failed to add user to team - this is an unforseen error state...
if not response.is_success:
logger.error(
'error_adding_litellm_user_to_team',
extra={
'status_code': response.status_code,
'text': response.text,
'user_id': [keycloak_user_id],
'team_id': [team_id],
'max_budget': max_budget,
},
)
response.raise_for_status()
@staticmethod
async def _get_user_team_info(
client: httpx.AsyncClient,
keycloak_user_id: str,
team_id: str,
) -> dict | None:
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
logger.warning('LiteLLM API configuration not found')
return None
team_info = await LiteLlmManager._get_team(client, team_id)
if not team_info:
return None
# Filter team_memberships based on team_id and keycloak_user_id
user_membership = next(
(
membership
for membership in team_info.get('team_memberships', [])
if membership.get('user_id') == keycloak_user_id
and membership.get('team_id') == team_id
),
None,
)
return user_membership
@staticmethod
async def _update_user_in_team(
client: httpx.AsyncClient,
keycloak_user_id: str,
team_id: str,
max_budget: float,
):
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
logger.warning('LiteLLM API configuration not found')
return
response = await client.post(
f'{LITE_LLM_API_URL}/team/member_update',
json={
'team_id': team_id,
'user_id': keycloak_user_id,
'max_budget_in_team': max_budget,
},
)
# Failed to update user in team - this is an unforseen error state...
if not response.is_success:
logger.error(
'error_updating_litellm_user_in_team',
extra={
'status_code': response.status_code,
'text': response.text,
'user_id': [keycloak_user_id],
'team_id': [team_id],
'max_budget': max_budget,
},
)
response.raise_for_status()
@staticmethod
async def _generate_key(
client: httpx.AsyncClient,
keycloak_user_id: str,
team_id: str | None,
key_alias: str | None,
metadata: dict | None,
) -> str | None:
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
logger.warning('LiteLLM API configuration not found')
return None
json_data: dict[str, Any] = {
'user_id': keycloak_user_id,
'models': [],
}
if team_id is not None:
json_data['team_id'] = team_id
if key_alias is not None:
json_data['key_alias'] = key_alias
if metadata is not None:
json_data['metadata'] = metadata
response = await client.post(
f'{LITE_LLM_API_URL}/key/generate',
json=json_data,
)
# Failed to generate user key for team - this is an unforseen error state...
if not response.is_success:
logger.error(
'error_generate_user_team_key',
extra={
'status_code': response.status_code,
'text': response.text,
'user_id': keycloak_user_id,
'team_id': team_id,
'key_alias': key_alias,
},
)
response.raise_for_status()
response_json = response.json()
key = response_json['key']
logger.info(
'LiteLlmManager:_lite_llm_generate_user_team_key:key_created',
extra={
'user_id': keycloak_user_id,
'team_id': team_id,
'key_alias': key_alias,
},
)
return key
@staticmethod
async def _get_key_info(
client: httpx.AsyncClient,
org_id: str,
keycloak_user_id: str,
) -> dict | None:
from storage.user_store import UserStore
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
logger.warning('LiteLLM API configuration not found')
return None
user = await call_sync_from_async(UserStore.get_user_by_id, keycloak_user_id)
if not user:
return {}
org_member = None
for om in user.org_members:
if om.org_id == org_id:
org_member = om
break
if not org_member or not org_member.llm_api_key:
return {}
response = await client.get(
f'{LITE_LLM_API_URL}/key/info?key={org_member.llm_api_key}'
)
response.raise_for_status()
response_json = response.json()
key_info = response_json.get('info')
if not key_info:
return {}
return {
'key_max_budget': key_info.get('max_budget'),
'key_spend': key_info.get('spend'),
}
@staticmethod
async def _delete_key(
client: httpx.AsyncClient,
key_id: str,
):
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
logger.warning('LiteLLM API configuration not found')
return
response = await client.post(
f'{LITE_LLM_API_URL}/key/delete',
json={
'keys': [key_id],
},
)
# Failed to key...
if not response.is_success:
if response.status_code == 404:
# key doesn't exist, just return
return
logger.error(
'error_deleting_key',
extra={
'status_code': response.status_code,
'text': response.text,
},
)
response.raise_for_status()
logger.info(
'LiteLlmManager:_delete_key:key_deleted',
)
@staticmethod
def with_http_client(
internal_fn: Callable[..., Awaitable[Any]],
) -> Callable[..., Awaitable[Any]]:
@functools.wraps(internal_fn)
async def wrapper(*args, **kwargs):
async with httpx.AsyncClient(
headers={'x-goog-api-key': LITE_LLM_API_KEY}
) as client:
return await internal_fn(client, *args, **kwargs)
return wrapper
# Public methods with injected client
create_team = staticmethod(with_http_client(_create_team))
get_team = staticmethod(with_http_client(_get_team))
update_team = staticmethod(with_http_client(_update_team))
create_user = staticmethod(with_http_client(_create_user))
get_user = staticmethod(with_http_client(_get_user))
update_user = staticmethod(with_http_client(_update_user))
delete_user = staticmethod(with_http_client(_delete_user))
add_user_to_team = staticmethod(with_http_client(_add_user_to_team))
get_user_team_info = staticmethod(with_http_client(_get_user_team_info))
update_user_in_team = staticmethod(with_http_client(_update_user_in_team))
generate_key = staticmethod(with_http_client(_generate_key))
get_key_info = staticmethod(with_http_client(_get_key_info))
delete_key = staticmethod(with_http_client(_delete_key))

117
enterprise/storage/org.py Normal file
View File

@@ -0,0 +1,117 @@
"""
SQLAlchemy model for Organization.
"""
from uuid import uuid4
from pydantic import SecretStr
from server.constants import DEFAULT_BILLING_MARGIN
from sqlalchemy import JSON, UUID, Boolean, Column, Float, Integer, String
from sqlalchemy.orm import relationship
from storage.base import Base
from storage.encrypt_utils import decrypt_value, encrypt_value
class Org(Base): # type: ignore
"""Organization model."""
__tablename__ = 'org'
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid4)
name = Column(String, nullable=False, unique=True)
contact_name = Column(String, nullable=True)
contact_email = Column(String, nullable=True)
agent = Column(String, nullable=True)
default_max_iterations = Column(Integer, nullable=True)
security_analyzer = Column(String, nullable=True)
confirmation_mode = Column(Boolean, nullable=True, default=False)
default_llm_model = Column(String, nullable=True)
# encrypted column, don't set directly, set without the underscore
_default_llm_api_key_for_byor = Column(String, nullable=True)
default_llm_base_url = Column(String, nullable=True)
remote_runtime_resource_factor = Column(Integer, nullable=True)
enable_default_condenser = Column(Boolean, nullable=False, default=True)
billing_margin = Column(Float, nullable=True, default=DEFAULT_BILLING_MARGIN)
enable_proactive_conversation_starters = Column(
Boolean, nullable=False, default=True
)
sandbox_base_container_image = Column(String, nullable=True)
sandbox_runtime_container_image = Column(String, nullable=True)
org_version = Column(Integer, nullable=False, default=0)
mcp_config = Column(JSON, nullable=True)
# encrypted column, don't set directly, set without the underscore
_search_api_key = Column(String, nullable=True)
# encrypted column, don't set directly, set without the underscore
_sandbox_api_key = Column(String, nullable=True)
max_budget_per_task = Column(Float, nullable=True)
enable_solvability_analysis = Column(Boolean, nullable=True, default=False)
v1_enabled = Column(Boolean, nullable=True)
conversation_expiration = Column(Integer, nullable=True)
# Relationships
org_members = relationship('OrgMember', back_populates='org')
current_users = relationship('User', back_populates='current_org')
billing_sessions = relationship('BillingSession', back_populates='org')
stored_conversation_metadata_saas = relationship(
'StoredConversationMetadataSaas', back_populates='org'
)
user_secrets = relationship('StoredCustomSecrets', back_populates='org')
api_keys = relationship('ApiKey', back_populates='org')
slack_conversations = relationship('SlackConversation', back_populates='org')
slack_users = relationship('SlackUser', back_populates='org')
stripe_customers = relationship('StripeCustomer', back_populates='org')
def __init__(self, **kwargs):
# Handle known SQLAlchemy columns directly
for key in list(kwargs):
if hasattr(self.__class__, key):
setattr(self, key, kwargs.pop(key))
# Handle custom property-style fields
if 'default_llm_api_key_for_byor' in kwargs:
self.default_llm_api_key_for_byor = kwargs.pop(
'default_llm_api_key_for_byor'
)
if 'search_api_key' in kwargs:
self.search_api_key = kwargs.pop('search_api_key')
if 'sandbox_api_key' in kwargs:
self.sandbox_api_key = kwargs.pop('sandbox_api_key')
if kwargs:
raise TypeError(f'Unexpected keyword arguments: {list(kwargs.keys())}')
@property
def default_llm_api_key_for_byor(self) -> SecretStr | None:
if self._default_llm_api_key_for_byor:
decrypted = decrypt_value(self._default_llm_api_key_for_byor)
return SecretStr(decrypted)
return None
@default_llm_api_key_for_byor.setter
def default_llm_api_key_for_byor(self, value: str | SecretStr | None):
raw = value.get_secret_value() if isinstance(value, SecretStr) else value
self._default_llm_api_key_for_byor = encrypt_value(raw) if raw else None
@property
def search_api_key(self) -> SecretStr | None:
if self._search_api_key:
decrypted = decrypt_value(self._search_api_key)
return SecretStr(decrypted)
return None
@search_api_key.setter
def search_api_key(self, value: str | SecretStr | None):
raw = value.get_secret_value() if isinstance(value, SecretStr) else value
self._search_api_key = encrypt_value(raw) if raw else None
@property
def sandbox_api_key(self) -> SecretStr | None:
if self._sandbox_api_key:
decrypted = decrypt_value(self._sandbox_api_key)
return SecretStr(decrypted)
return None
@sandbox_api_key.setter
def sandbox_api_key(self, value: str | SecretStr | None):
raw = value.get_secret_value() if isinstance(value, SecretStr) else value
self._sandbox_api_key = encrypt_value(raw) if raw else None

View File

@@ -0,0 +1,67 @@
"""
SQLAlchemy model for Organization-Member relationship.
"""
from pydantic import SecretStr
from sqlalchemy import UUID, Column, ForeignKey, Integer, String
from sqlalchemy.orm import relationship
from storage.base import Base
from storage.encrypt_utils import decrypt_value, encrypt_value
class OrgMember(Base): # type: ignore
"""Junction table for organization-member relationships with roles."""
__tablename__ = 'org_member'
org_id = Column(UUID(as_uuid=True), ForeignKey('org.id'), primary_key=True)
user_id = Column(UUID(as_uuid=True), ForeignKey('user.id'), primary_key=True)
role_id = Column(Integer, ForeignKey('role.id'), nullable=False)
_llm_api_key = Column(String, nullable=False)
max_iterations = Column(Integer, nullable=True)
llm_model = Column(String, nullable=True)
_llm_api_key_for_byor = Column(String, nullable=True)
llm_base_url = Column(String, nullable=True)
status = Column(String, nullable=True)
# Relationships
org = relationship('Org', back_populates='org_members')
user = relationship('User', back_populates='org_members')
role = relationship('Role', back_populates='org_members')
def __init__(self, **kwargs):
# Handle known SQLAlchemy columns directly
for key in list(kwargs):
if hasattr(self.__class__, key):
setattr(self, key, kwargs.pop(key))
# Handle custom property-style fields
if 'llm_api_key' in kwargs:
self.llm_api_key = kwargs.pop('llm_api_key')
if 'llm_api_key_for_byor' in kwargs:
self.llm_api_key_for_byor = kwargs.pop('llm_api_key_for_byor')
if kwargs:
raise TypeError(f'Unexpected keyword arguments: {list(kwargs.keys())}')
@property
def llm_api_key(self) -> SecretStr:
decrypted = decrypt_value(self._llm_api_key)
return SecretStr(decrypted)
@llm_api_key.setter
def llm_api_key(self, value: str | SecretStr):
raw = value.get_secret_value() if isinstance(value, SecretStr) else value
self._llm_api_key = encrypt_value(raw)
@property
def llm_api_key_for_byor(self) -> SecretStr | None:
if self._llm_api_key_for_byor:
decrypted = decrypt_value(self._llm_api_key_for_byor)
return SecretStr(decrypted)
return None
@llm_api_key_for_byor.setter
def llm_api_key_for_byor(self, value: str | SecretStr | None):
raw = value.get_secret_value() if isinstance(value, SecretStr) else value
self._llm_api_key_for_byor = encrypt_value(raw) if raw else None

View File

@@ -0,0 +1,125 @@
"""
Store class for managing organization-member relationships.
"""
from typing import Optional
from uuid import UUID
from storage.database import session_maker
from storage.org_member import OrgMember
from storage.user_settings import UserSettings
from openhands.storage.data_models.settings import Settings
class OrgMemberStore:
"""Store for managing organization-member relationships."""
@staticmethod
def add_user_to_org(
org_id: UUID,
user_id: UUID,
role_id: int,
llm_api_key: str,
status: Optional[str] = None,
) -> OrgMember:
"""Add a user to an organization with a specific role."""
with session_maker() as session:
org_member = OrgMember(
org_id=org_id,
user_id=user_id,
role_id=role_id,
llm_api_key=llm_api_key,
status=status,
)
session.add(org_member)
session.commit()
session.refresh(org_member)
return org_member
@staticmethod
def get_org_member(org_id: UUID, user_id: int) -> Optional[OrgMember]:
"""Get organization-user relationship."""
with session_maker() as session:
return (
session.query(OrgMember)
.filter(OrgMember.org_id == org_id, OrgMember.user_id == user_id)
.first()
)
@staticmethod
def get_user_orgs(user_id: int) -> list[OrgMember]:
"""Get all organizations for a user."""
with session_maker() as session:
return session.query(OrgMember).filter(OrgMember.user_id == user_id).all()
@staticmethod
def get_org_members(org_id: UUID) -> list[OrgMember]:
"""Get all users in an organization."""
with session_maker() as session:
return session.query(OrgMember).filter(OrgMember.org_id == org_id).all()
@staticmethod
def update_org_member(org_member: OrgMember) -> None:
"""Update an organization-member relationship."""
with session_maker() as session:
session.merge(org_member)
session.commit()
@staticmethod
def update_user_role_in_org(
org_id: UUID, user_id: int, role_id: int, status: Optional[str] = None
) -> Optional[OrgMember]:
"""Update user's role in an organization."""
with session_maker() as session:
org_member = (
session.query(OrgMember)
.filter(OrgMember.org_id == org_id, OrgMember.user_id == user_id)
.first()
)
if not org_member:
return None
org_member.role_id = role_id
if status is not None:
org_member.status = status
session.commit()
session.refresh(org_member)
return org_member
@staticmethod
def remove_user_from_org(org_id: UUID, user_id: int) -> bool:
"""Remove a user from an organization."""
with session_maker() as session:
org_member = (
session.query(OrgMember)
.filter(OrgMember.org_id == org_id, OrgMember.user_id == user_id)
.first()
)
if not org_member:
return False
session.delete(org_member)
session.commit()
return True
@staticmethod
def get_kwargs_from_settings(settings: Settings):
kwargs = {
normalized: getattr(settings, normalized)
for c in OrgMember.__table__.columns
if (normalized := c.name.lstrip('_')) and hasattr(settings, normalized)
}
return kwargs
@staticmethod
def get_kwargs_from_user_settings(user_settings: UserSettings):
kwargs = {
normalized: getattr(user_settings, normalized)
for c in OrgMember.__table__.columns
if (normalized := c.name.lstrip('_')) and hasattr(user_settings, normalized)
}
return kwargs

View File

@@ -0,0 +1,139 @@
"""
Store class for managing organizations.
"""
from typing import Optional
from uuid import UUID
from server.constants import ORG_SETTINGS_VERSION, get_default_litellm_model
from sqlalchemy.orm import joinedload
from storage.database import session_maker
from storage.org import Org
from storage.user import User
from storage.user_settings import UserSettings
from openhands.core.logger import openhands_logger as logger
from openhands.storage.data_models.settings import Settings
class OrgStore:
"""Store for managing organizations."""
@staticmethod
def create_org(
kwargs: dict,
) -> Org:
"""Create a new organization."""
with session_maker() as session:
org = Org(**kwargs)
org.org_version = ORG_SETTINGS_VERSION
org.default_llm_model = get_default_litellm_model()
session.add(org)
session.commit()
session.refresh(org)
return org
@staticmethod
def get_org_by_id(org_id: UUID) -> Org | None:
"""Get organization by ID."""
with session_maker() as session:
return session.query(Org).filter(Org.id == org_id).first()
@staticmethod
def get_current_org_from_keycloak_user_id(keycloak_user_id: str) -> Org | None:
with session_maker() as session:
user = (
session.query(User)
.options(joinedload(User.org_members))
.filter(User.id == UUID(keycloak_user_id))
.first()
)
if not user:
logger.warning(f'User not found for ID {keycloak_user_id}')
return None
org_id = user.current_org_id
org = session.query(Org).filter(Org.id == org_id).first()
if not org:
logger.warning(
f'Org not found for ID {org_id} as the current org for user {keycloak_user_id}'
)
return None
return org
@staticmethod
def get_org_by_name(name: str) -> Org | None:
"""Get organization by name."""
with session_maker() as session:
return session.query(Org).filter(Org.name == name).first()
@staticmethod
def list_orgs() -> list[Org]:
"""List all organizations."""
with session_maker() as session:
orgs = session.query(Org).all()
return orgs
@staticmethod
def update_org(
org_id: UUID,
kwargs: dict,
) -> Optional[Org]:
"""Update organization details."""
with session_maker() as session:
org = session.query(Org).filter(Org.id == org_id).first()
if not org:
return None
if 'id' in kwargs:
kwargs.pop('id')
for key, value in kwargs.items():
if hasattr(org, key):
setattr(org, key, value)
session.commit()
session.refresh(org)
return org
@staticmethod
def get_kwargs_from_settings(settings: Settings):
kwargs = {}
for c in Org.__table__.columns:
# Normalize for lookup
normalized = (
c.name.removeprefix('_default_').removeprefix('default_').lstrip('_')
)
if not hasattr(settings, normalized):
continue
# ---- FIX: Output key should drop *only* leading "_" but preserve "default" ----
key = c.name
if key.startswith('_'):
key = key[1:] # remove only the very first leading underscore
kwargs[key] = getattr(settings, normalized)
return kwargs
@staticmethod
def get_kwargs_from_user_settings(user_settings: UserSettings):
kwargs = {}
for c in Org.__table__.columns:
# Normalize for lookup
normalized = (
c.name.removeprefix('_default_').removeprefix('default_').lstrip('_')
)
if not hasattr(user_settings, normalized):
continue
# ---- FIX: Output key should drop *only* leading "_" but preserve "default" ----
key = c.name
if key.startswith('_'):
key = key[1:] # remove only the very first leading underscore
kwargs[key] = getattr(user_settings, normalized)
return kwargs

View File

@@ -0,0 +1,21 @@
"""
SQLAlchemy model for Role.
"""
from sqlalchemy import Column, Identity, Integer, String
from sqlalchemy.orm import relationship
from storage.base import Base
class Role(Base): # type: ignore
"""Role model for user permissions."""
__tablename__ = 'role'
id = Column(Integer, Identity(), primary_key=True)
name = Column(String, nullable=False, unique=True)
rank = Column(Integer, nullable=False)
# Relationships
users = relationship('User', back_populates='role')
org_members = relationship('OrgMember', back_populates='role')

View File

@@ -0,0 +1,40 @@
"""
Store class for managing roles.
"""
from typing import List, Optional
from storage.database import session_maker
from storage.role import Role
class RoleStore:
"""Store for managing roles."""
@staticmethod
def create_role(name: str, rank: int) -> Role:
"""Create a new role."""
with session_maker() as session:
role = Role(name=name, rank=rank)
session.add(role)
session.commit()
session.refresh(role)
return role
@staticmethod
def get_role_by_id(role_id: int) -> Optional[Role]:
"""Get role by ID."""
with session_maker() as session:
return session.query(Role).filter(Role.id == role_id).first()
@staticmethod
def get_role_by_name(name: str) -> Optional[Role]:
"""Get role by name."""
with session_maker() as session:
return session.query(Role).filter(Role.name == name).first()
@staticmethod
def list_roles() -> List[Role]:
"""List all roles."""
with session_maker() as session:
return session.query(Role).order_by(Role.rank).all()

View File

@@ -0,0 +1,350 @@
"""Enterprise injector for SQLAppConversationInfoService with SAAS filtering."""
from datetime import datetime
from typing import AsyncGenerator
from uuid import UUID
from fastapi import Request
from sqlalchemy import func, select
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
from storage.user import User
from openhands.app_server.app_conversation.app_conversation_info_service import (
AppConversationInfoService,
AppConversationInfoServiceInjector,
)
from openhands.app_server.app_conversation.app_conversation_models import (
AppConversationInfo,
AppConversationInfoPage,
AppConversationSortOrder,
)
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
SQLAppConversationInfoService,
StoredConversationMetadata,
)
from openhands.app_server.services.injector import InjectorState
class SaasSQLAppConversationInfoService(SQLAppConversationInfoService):
"""Extended SQLAppConversationInfoService with user-based filtering and SAAS metadata handling."""
async def _secure_select(self):
query = (
select(StoredConversationMetadata)
.join(
StoredConversationMetadataSaas,
StoredConversationMetadata.conversation_id
== StoredConversationMetadataSaas.conversation_id,
)
.where(StoredConversationMetadata.conversation_version == 'V1')
)
user_id_str = await self.user_context.get_user_id()
if user_id_str:
user_id_uuid = UUID(user_id_str)
query = query.where(StoredConversationMetadataSaas.user_id == user_id_uuid)
return query
async def _secure_select_with_saas_metadata(self):
"""Select query that includes SAAS metadata for retrieving user_id."""
query = (
select(StoredConversationMetadata, StoredConversationMetadataSaas)
.join(
StoredConversationMetadataSaas,
StoredConversationMetadata.conversation_id
== StoredConversationMetadataSaas.conversation_id,
)
.where(StoredConversationMetadata.conversation_version == 'V1')
)
user_id_str = await self.user_context.get_user_id()
if user_id_str:
user_id_uuid = UUID(user_id_str)
query = query.where(StoredConversationMetadataSaas.user_id == user_id_uuid)
return query
async def search_app_conversation_info(
self,
title__contains: str | None = None,
created_at__gte: datetime | None = None,
created_at__lt: datetime | None = None,
updated_at__gte: datetime | None = None,
updated_at__lt: datetime | None = None,
sort_order: AppConversationSortOrder = AppConversationSortOrder.CREATED_AT_DESC,
page_id: str | None = None,
limit: int = 100,
include_sub_conversations: bool = False,
) -> AppConversationInfoPage:
"""Search for conversations with user_id from SAAS metadata."""
query = await self._secure_select_with_saas_metadata()
# Conditionally exclude sub-conversations based on the parameter
if not include_sub_conversations:
# Exclude sub-conversations (only include top-level conversations)
query = query.where(
StoredConversationMetadata.parent_conversation_id.is_(None)
)
query = self._apply_filters_with_saas_metadata(
query=query,
title__contains=title__contains,
created_at__gte=created_at__gte,
created_at__lt=created_at__lt,
updated_at__gte=updated_at__gte,
updated_at__lt=updated_at__lt,
)
# Add sort order
if sort_order == AppConversationSortOrder.CREATED_AT:
query = query.order_by(StoredConversationMetadata.created_at)
elif sort_order == AppConversationSortOrder.CREATED_AT_DESC:
query = query.order_by(StoredConversationMetadata.created_at.desc())
elif sort_order == AppConversationSortOrder.UPDATED_AT:
query = query.order_by(StoredConversationMetadata.last_updated_at)
elif sort_order == AppConversationSortOrder.UPDATED_AT_DESC:
query = query.order_by(StoredConversationMetadata.last_updated_at.desc())
elif sort_order == AppConversationSortOrder.TITLE:
query = query.order_by(StoredConversationMetadata.title)
elif sort_order == AppConversationSortOrder.TITLE_DESC:
query = query.order_by(StoredConversationMetadata.title.desc())
# Apply pagination
if page_id is not None:
try:
offset = int(page_id)
query = query.offset(offset)
except ValueError:
# If page_id is not a valid integer, start from beginning
offset = 0
else:
offset = 0
# Apply limit and get one extra to check if there are more results
query = query.limit(limit + 1)
result = await self.db_session.execute(query)
rows = result.all()
# Check if there are more results
has_more = len(rows) > limit
if has_more:
rows = rows[:limit]
items = [
self._to_info_with_user_id(stored_metadata, saas_metadata)
for stored_metadata, saas_metadata in rows
]
# Calculate next page ID
next_page_id = None
if has_more:
next_page_id = str(offset + limit)
return AppConversationInfoPage(items=items, next_page_id=next_page_id)
async def count_app_conversation_info(
self,
title__contains: str | None = None,
created_at__gte: datetime | None = None,
created_at__lt: datetime | None = None,
updated_at__gte: datetime | None = None,
updated_at__lt: datetime | None = None,
) -> int:
"""Count conversations matching the given filters with SAAS metadata."""
query = (
select(func.count(StoredConversationMetadata.conversation_id))
.select_from(
StoredConversationMetadata.join(
StoredConversationMetadataSaas,
StoredConversationMetadata.conversation_id
== StoredConversationMetadataSaas.conversation_id,
)
)
.where(StoredConversationMetadata.conversation_version == 'V1')
)
# Apply user filtering
user_id_str = await self.user_context.get_user_id()
if user_id_str:
user_id_uuid = UUID(user_id_str)
query = query.where(StoredConversationMetadataSaas.user_id == user_id_uuid)
query = self._apply_filters_with_saas_metadata(
query=query,
title__contains=title__contains,
created_at__gte=created_at__gte,
created_at__lt=created_at__lt,
updated_at__gte=updated_at__gte,
updated_at__lt=updated_at__lt,
)
result = await self.db_session.execute(query)
count = result.scalar()
return count or 0
def _apply_filters_with_saas_metadata(
self,
query,
title__contains: str | None = None,
created_at__gte: datetime | None = None,
created_at__lt: datetime | None = None,
updated_at__gte: datetime | None = None,
updated_at__lt: datetime | None = None,
):
"""Apply filters to query that includes SAAS metadata."""
# Apply the same filters as the base class
conditions = []
if title__contains is not None:
conditions.append(
StoredConversationMetadata.title.like(f'%{title__contains}%')
)
if created_at__gte is not None:
conditions.append(StoredConversationMetadata.created_at >= created_at__gte)
if created_at__lt is not None:
conditions.append(StoredConversationMetadata.created_at < created_at__lt)
if updated_at__gte is not None:
conditions.append(
StoredConversationMetadata.last_updated_at >= updated_at__gte
)
if updated_at__lt is not None:
conditions.append(
StoredConversationMetadata.last_updated_at < updated_at__lt
)
if conditions:
query = query.where(*conditions)
return query
async def get_app_conversation_info(
self, conversation_id: UUID
) -> AppConversationInfo | None:
"""Get conversation info with user_id from SAAS metadata."""
query = await self._secure_select_with_saas_metadata()
query = query.where(
StoredConversationMetadata.conversation_id == str(conversation_id)
)
result_set = await self.db_session.execute(query)
result = result_set.first()
if result:
stored_metadata, saas_metadata = result
return self._to_info_with_user_id(stored_metadata, saas_metadata)
return None
async def batch_get_app_conversation_info(
self, conversation_ids: list[UUID]
) -> list[AppConversationInfo | None]:
"""Batch get conversation info with user_id from SAAS metadata."""
conversation_id_strs = [
str(conversation_id) for conversation_id in conversation_ids
]
query = await self._secure_select_with_saas_metadata()
query = query.where(
StoredConversationMetadata.conversation_id.in_(conversation_id_strs)
)
result = await self.db_session.execute(query)
rows = result.all()
# Create a mapping of conversation_id to (metadata, saas_metadata)
info_by_id = {}
for stored_metadata, saas_metadata in rows:
info_by_id[stored_metadata.conversation_id] = (
stored_metadata,
saas_metadata,
)
results: list[AppConversationInfo | None] = []
for conversation_id in conversation_id_strs:
if conversation_id in info_by_id:
stored_metadata, saas_metadata = info_by_id[conversation_id]
results.append(
self._to_info_with_user_id(stored_metadata, saas_metadata)
)
else:
results.append(None)
return results
async def save_app_conversation_info(
self, info: AppConversationInfo
) -> AppConversationInfo:
"""Save conversation info and create/update SAAS metadata with user_id and org_id."""
# Save the base conversation metadata
await super().save_app_conversation_info(info)
# Get current user_id for SAAS metadata
user_id_str = await self.user_context.get_user_id()
if user_id_str:
# Convert string user_id to UUID
user_id_uuid = UUID(user_id_str)
user_query = select(User).where(User.id == user_id_uuid)
result = await self.db_session.execute(user_query)
user = result.scalar_one_or_none()
assert user
# Check if SAAS metadata already exists
saas_query = select(StoredConversationMetadataSaas).where(
StoredConversationMetadataSaas.conversation_id == str(info.id)
)
result = await self.db_session.execute(saas_query)
existing_saas_metadata = result.scalar_one_or_none()
assert existing_saas_metadata is None or (
existing_saas_metadata.user_id == user_id_uuid
and existing_saas_metadata.org_id == user.current_org_id
)
if not existing_saas_metadata:
# Create new SAAS metadata
# Set org_id to user_id as specified in requirements
saas_metadata = StoredConversationMetadataSaas(
conversation_id=str(info.id),
user_id=user_id_uuid,
org_id=user.current_org_id,
)
self.db_session.add(saas_metadata)
await self.db_session.commit()
return info
def _to_info_with_user_id(
self,
stored: StoredConversationMetadata,
saas_metadata: StoredConversationMetadataSaas,
) -> AppConversationInfo:
"""Convert stored metadata to AppConversationInfo with user_id from SAAS metadata."""
# Use the base _to_info method to get the basic info
info = self._to_info(stored)
# Override the created_by_user_id with the user_id from SAAS metadata
info.created_by_user_id = (
str(saas_metadata.user_id) if saas_metadata.user_id else None
)
return info
class SaasAppConversationInfoServiceInjector(AppConversationInfoServiceInjector):
"""Enterprise injector for SQLAppConversationInfoService with SAAS filtering."""
async def inject(
self, state: InjectorState, request: Request | None = None
) -> AsyncGenerator[AppConversationInfoService, None]:
from openhands.app_server.config import (
get_db_session,
get_user_context,
)
async with (
get_user_context(state, request) as user_context,
get_db_session(state, request) as db_session,
):
service = SaasSQLAppConversationInfoService(
db_session=db_session, user_context=user_context
)
yield service

View File

@@ -4,10 +4,15 @@ import dataclasses
import logging
from dataclasses import dataclass
from datetime import UTC
from uuid import UUID
from sqlalchemy.orm import sessionmaker
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
StoredConversationMetadata,
)
from storage.database import session_maker
from storage.stored_conversation_metadata import StoredConversationMetadata
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
from storage.user_store import UserStore
from openhands.core.config.openhands_config import OpenHandsConfig
from openhands.integrations.provider import ProviderType
@@ -29,20 +34,37 @@ logger = logging.getLogger(__name__)
class SaasConversationStore(ConversationStore):
user_id: str
session_maker: sessionmaker
org_id: UUID | None = None # will be fetched automatically
def __init__(self, user_id: str, session_maker: sessionmaker):
self.user_id = user_id
self.session_maker = session_maker
user = UserStore.get_user_by_id(user_id)
self.org_id = user.current_org_id if user else None
def _select_by_id(self, session, conversation_id: str):
return (
# Join StoredConversationMetadata with ConversationMetadataSaas to filter by user/org
query = (
session.query(StoredConversationMetadata)
.filter(StoredConversationMetadata.user_id == self.user_id)
.join(
StoredConversationMetadataSaas,
StoredConversationMetadata.conversation_id
== StoredConversationMetadataSaas.conversation_id,
)
.filter(StoredConversationMetadataSaas.user_id == UUID(self.user_id))
.filter(StoredConversationMetadata.conversation_id == conversation_id)
.filter(StoredConversationMetadata.conversation_version == 'V0')
)
if self.org_id is not None:
query = query.filter(StoredConversationMetadataSaas.org_id == self.org_id)
return query
def _to_external_model(self, conversation_metadata: StoredConversationMetadata):
kwargs = {
c.name: getattr(conversation_metadata, c.name)
for c in StoredConversationMetadata.__table__.columns
if c.name != 'github_user_id' # Skip github_user_id field
}
# TODO: I'm not sure why the timezone is not set on the dates coming back out of the db
kwargs['created_at'] = kwargs['created_at'].replace(tzinfo=UTC)
@@ -53,6 +75,8 @@ class SaasConversationStore(ConversationStore):
# Convert string to ProviderType enum
kwargs['git_provider'] = ProviderType(kwargs['git_provider'])
kwargs['user_id'] = self.user_id
# Remove V1 attributes
kwargs.pop('max_budget_per_task', None)
kwargs.pop('cache_read_tokens', None)
@@ -66,7 +90,10 @@ class SaasConversationStore(ConversationStore):
async def save_metadata(self, metadata: ConversationMetadata):
kwargs = dataclasses.asdict(metadata)
kwargs['user_id'] = self.user_id
# Remove user_id and org_id from kwargs since they're no longer in StoredConversationMetadata
kwargs.pop('user_id', None)
kwargs.pop('org_id', None)
# Convert ProviderType enum to string for storage
if kwargs.get('git_provider') is not None:
@@ -80,7 +107,41 @@ class SaasConversationStore(ConversationStore):
def _save_metadata():
with self.session_maker() as session:
# Save the main conversation metadata
session.merge(stored_metadata)
# Create or update the SaaS metadata record
saas_metadata = (
session.query(StoredConversationMetadataSaas)
.filter(
StoredConversationMetadataSaas.conversation_id
== stored_metadata.conversation_id
)
.first()
)
if not saas_metadata:
saas_metadata = StoredConversationMetadataSaas(
conversation_id=stored_metadata.conversation_id,
user_id=UUID(self.user_id),
org_id=self.org_id,
)
session.add(saas_metadata)
else:
# Validate
expected_user_id = UUID(self.user_id)
expected_org_id = self.org_id
if saas_metadata.user_id != expected_user_id:
raise ValueError(
f'Existing user_id ({saas_metadata.user_id}) does not match expected value ({expected_user_id}).'
)
if expected_org_id and saas_metadata.org_id != expected_org_id:
raise ValueError(
f'Existing org_id ({saas_metadata.org_id}) does not match expected value ({expected_org_id}).'
)
session.commit()
await call_sync_from_async(_save_metadata)
@@ -100,8 +161,29 @@ class SaasConversationStore(ConversationStore):
async def delete_metadata(self, conversation_id: str) -> None:
def _delete_metadata():
with self.session_maker() as session:
self._select_by_id(session, conversation_id).delete()
session.commit()
saas_record = (
session.query(StoredConversationMetadataSaas)
.filter(
StoredConversationMetadataSaas.conversation_id
== conversation_id,
StoredConversationMetadataSaas.user_id == UUID(self.user_id),
StoredConversationMetadataSaas.org_id == self.org_id,
)
.first()
)
if saas_record:
# Delete both records, but only if the SaaS one exists
session.query(StoredConversationMetadata).filter(
StoredConversationMetadata.conversation_id == conversation_id,
).delete()
session.delete(saas_record)
session.commit()
else:
# No SaaS record found → skip deleting main metadata
session.rollback()
await call_sync_from_async(_delete_metadata)
@@ -124,7 +206,15 @@ class SaasConversationStore(ConversationStore):
with self.session_maker() as session:
conversations = (
session.query(StoredConversationMetadata)
.filter(StoredConversationMetadata.user_id == self.user_id)
.join(
StoredConversationMetadataSaas,
StoredConversationMetadata.conversation_id
== StoredConversationMetadataSaas.conversation_id,
)
.filter(
StoredConversationMetadataSaas.user_id == UUID(self.user_id)
)
.filter(StoredConversationMetadataSaas.org_id == self.org_id)
.filter(StoredConversationMetadata.conversation_version == 'V0')
.order_by(StoredConversationMetadata.created_at.desc())
.offset(offset)

View File

@@ -8,11 +8,13 @@ from cryptography.fernet import Fernet
from sqlalchemy.orm import sessionmaker
from storage.database import session_maker
from storage.stored_custom_secrets import StoredCustomSecrets
from storage.user_store import UserStore
from openhands.core.config.openhands_config import OpenHandsConfig
from openhands.core.logger import openhands_logger as logger
from openhands.storage.data_models.secrets import Secrets
from openhands.storage.secrets.secrets_store import SecretsStore
from openhands.utils.async_utils import call_sync_from_async
@dataclass
@@ -24,14 +26,17 @@ class SaasSecretsStore(SecretsStore):
async def load(self) -> Secrets | None:
if not self.user_id:
return None
user = await call_sync_from_async(UserStore.get_user_by_id, self.user_id)
org_id = user.current_org_id if user else None
with self.session_maker() as session:
# Fetch all secrets for the given user ID
settings = (
session.query(StoredCustomSecrets)
.filter(StoredCustomSecrets.keycloak_user_id == self.user_id)
.all()
query = session.query(StoredCustomSecrets).filter(
StoredCustomSecrets.keycloak_user_id == self.user_id
)
if org_id is not None:
query = query.filter(StoredCustomSecrets.org_id == org_id)
settings = query.all()
if not settings:
return Secrets()
@@ -48,6 +53,8 @@ class SaasSecretsStore(SecretsStore):
return Secrets(custom_secrets=kwargs) # type: ignore[arg-type]
async def store(self, item: Secrets):
user = await call_sync_from_async(UserStore.get_user_by_id, self.user_id)
org_id = user.current_org_id
with self.session_maker() as session:
# Incoming secrets are always the most updated ones
# Delete all existing records and override with incoming ones
@@ -76,6 +83,7 @@ class SaasSecretsStore(SecretsStore):
for secret_name, secret_value, description in secret_tuples:
new_secret = StoredCustomSecrets(
keycloak_user_id=self.user_id,
org_id=org_id,
secret_name=secret_name,
secret_value=secret_value,
description=description,

View File

@@ -2,45 +2,37 @@ from __future__ import annotations
import binascii
import hashlib
import json
import os
import uuid
from base64 import b64decode, b64encode
from dataclasses import dataclass
import httpx
from cryptography.fernet import Fernet
from integrations import stripe_service
from pydantic import SecretStr
from server.auth.token_manager import TokenManager
from server.constants import (
CURRENT_USER_SETTINGS_VERSION,
DEFAULT_INITIAL_BUDGET,
LITE_LLM_API_KEY,
LITE_LLM_API_URL,
LITE_LLM_TEAM_ID,
REQUIRE_PAYMENT,
get_default_litellm_model,
)
from server.logger import logger
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import joinedload, sessionmaker
from storage.database import session_maker
from storage.lite_llm_manager import LiteLlmManager
from storage.org import Org
from storage.org_member import OrgMember
from storage.org_store import OrgStore
from storage.user import User
from storage.user_settings import UserSettings
from storage.user_store import UserStore
from openhands.core.config.openhands_config import OpenHandsConfig
from openhands.server.settings import Settings
from openhands.storage import get_file_store
from openhands.storage.settings.settings_store import SettingsStore
from openhands.storage.settings.settings_store import SettingsStore as OssSettingsStore
from openhands.utils.async_utils import call_sync_from_async
from openhands.utils.http_session import httpx_verify_option
@dataclass
class SaasSettingsStore(SettingsStore):
class SaasSettingsStore(OssSettingsStore):
user_id: str
session_maker: sessionmaker
config: OpenHandsConfig
ENCRYPT_VALUES = ['llm_api_key', 'llm_api_key_for_byor', 'search_api_key']
def get_user_settings_by_keycloak_id(
def _get_user_settings_by_keycloak_id(
self, keycloak_user_id: str, session=None
) -> UserSettings | None:
"""
@@ -76,246 +68,104 @@ class SaasSettingsStore(SettingsStore):
return _get_settings()
async def load(self) -> Settings | None:
if not self.user_id:
user = await call_sync_from_async(UserStore.get_user_by_id, self.user_id)
if not user:
logger.error(f'User not found for ID {self.user_id}')
return None
with self.session_maker() as session:
settings = self.get_user_settings_by_keycloak_id(self.user_id, session)
if not settings or settings.user_version != CURRENT_USER_SETTINGS_VERSION:
logger.info(
'saas_settings_store:load:triggering_migration',
extra={'user_id': self.user_id},
org_id = user.current_org_id
org_member: OrgMember = None
for om in user.org_members:
if om.org_id == org_id:
org_member = om
break
if not org_member or not org_member.llm_api_key:
return None
org = OrgStore.get_org_by_id(org_id)
if not org:
logger.error(
f'Org not found for ID {org_id} as the current org for user {self.user_id}'
)
return None
kwargs = {
**{
normalized: getattr(org, c.name)
for c in Org.__table__.columns
if (
normalized := c.name.removeprefix('_default_')
.removeprefix('default_')
.lstrip('_')
)
return await self.create_default_settings(settings)
kwargs = {
c.name: getattr(settings, c.name)
for c in UserSettings.__table__.columns
if c.name in Settings.model_fields
}
self._decrypt_kwargs(kwargs)
settings = Settings(**kwargs)
return settings
in Settings.model_fields
},
**{
normalized: getattr(user, c.name)
for c in User.__table__.columns
if (normalized := c.name.lstrip('_')) in Settings.model_fields
},
}
kwargs['llm_api_key'] = org_member.llm_api_key
if org_member.max_iterations:
kwargs['max_iterations'] = org_member.max_iterations
if org_member.llm_model:
kwargs['llm_model'] = org_member.llm_model
if org_member.llm_api_key_for_byor:
kwargs['llm_api_key_for_byor'] = org_member.llm_api_key_for_byor
if org_member.llm_base_url:
kwargs['llm_base_url'] = org_member.llm_base_url
settings = Settings(**kwargs)
return settings
async def store(self, item: Settings):
# Check if provider is OpenHands and generate API key if needed
if item and self._is_openhands_provider(item):
await self._ensure_openhands_api_key(item)
# Call the static store method from SettingsStore
with self.session_maker() as session:
existing = None
kwargs = {}
if item:
kwargs = item.model_dump(context={'expose_secrets': True})
self._encrypt_kwargs(kwargs)
# First check if we have an existing entry in the new table
existing = self.get_user_settings_by_keycloak_id(self.user_id, session)
kwargs = {
key: value
for key, value in kwargs.items()
if key in UserSettings.__table__.columns
}
if existing:
# Update existing entry
for key, value in kwargs.items():
setattr(existing, key, value)
existing.user_version = CURRENT_USER_SETTINGS_VERSION
session.merge(existing)
else:
kwargs['keycloak_user_id'] = self.user_id
kwargs['user_version'] = CURRENT_USER_SETTINGS_VERSION
kwargs.pop('secrets_store', None) # Don't save secrets_store to db
settings = UserSettings(**kwargs)
session.add(settings)
session.commit()
async def create_default_settings(self, user_settings: UserSettings | None):
logger.info(
'saas_settings_store:create_default_settings:start',
extra={'user_id': self.user_id},
)
# You must log in before you get default settings
if not self.user_id:
return None
# Only users that have specified a payment method get default settings
if REQUIRE_PAYMENT and not await stripe_service.has_payment_method(
self.user_id
):
logger.info(
'saas_settings_store:create_default_settings:no_payment',
extra={'user_id': self.user_id},
)
return None
settings: Settings | None = None
if user_settings is None:
settings = Settings(
language='en',
enable_proactive_conversation_starters=True,
)
elif isinstance(user_settings, UserSettings):
# Convert UserSettings (SQLAlchemy model) to Settings (Pydantic model)
kwargs = {
c.name: getattr(user_settings, c.name)
for c in UserSettings.__table__.columns
if c.name in Settings.model_fields
}
self._decrypt_kwargs(kwargs)
settings = Settings(**kwargs)
if settings:
settings = await self.update_settings_with_litellm_default(settings)
if settings is None:
logger.info(
'saas_settings_store:create_default_settings:litellm_update_failed',
extra={'user_id': self.user_id},
)
return None
await self.store(settings)
return settings
async def load_legacy_file_store_settings(self, github_user_id: str):
if not github_user_id:
return None
file_store = get_file_store(self.config.file_store, self.config.file_store_path)
path = f'users/github/{github_user_id}/settings.json'
try:
json_str = await call_sync_from_async(file_store.read, path)
logger.info(
'saas_settings_store:load_legacy_file_store_settings:found',
extra={'github_user_id': github_user_id},
)
kwargs = json.loads(json_str)
self._decrypt_kwargs(kwargs)
settings = Settings(**kwargs)
return settings
except FileNotFoundError:
return None
except Exception as e:
logger.error(
'saas_settings_store:load_legacy_file_store_settings:error',
extra={'github_user_id': github_user_id, 'error': str(e)},
)
return None
async def update_settings_with_litellm_default(
self, settings: Settings
) -> Settings | None:
logger.info(
'saas_settings_store:update_settings_with_litellm_default:start',
extra={'user_id': self.user_id},
)
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
return None
local_deploy = os.environ.get('LOCAL_DEPLOYMENT', None)
key = LITE_LLM_API_KEY
if not local_deploy:
# Get user info to add to litellm
token_manager = TokenManager()
keycloak_user_info = (
await token_manager.get_user_info_from_user_id(self.user_id) or {}
)
async with httpx.AsyncClient(
verify=httpx_verify_option(),
headers={
'x-goog-api-key': LITE_LLM_API_KEY,
},
) as client:
# Get the previous max budget to prevent accidental loss
# In Litellm a get always succeeds, regardless of whether the user actually exists
response = await client.get(
f'{LITE_LLM_API_URL}/user/info?user_id={self.user_id}'
)
response.raise_for_status()
response_json = response.json()
user_info = response_json.get('user_info') or {}
logger.info(
f'creating_litellm_user: {self.user_id}; prev_max_budget: {user_info.get("max_budget")}; prev_metadata: {user_info.get("metadata")}'
)
max_budget = user_info.get('max_budget') or DEFAULT_INITIAL_BUDGET
spend = user_info.get('spend') or 0
if not item:
return None
kwargs = item.model_dump(context={'expose_secrets': True})
user = (
session.query(User)
.options(joinedload(User.org_members))
.filter(User.id == uuid.UUID(self.user_id))
).first()
if not user:
# Check if we need to migrate from user_settings
user_settings = None
with session_maker() as session:
user_settings = self.get_user_settings_by_keycloak_id(
user_settings = self._get_user_settings_by_keycloak_id(
self.user_id, session
)
# In upgrade to V4, we no longer use billing margin, but instead apply this directly
# in litellm. The default billing marign was 2 before this (hence the magic numbers below)
if (
user_settings
and user_settings.user_version < 4
and user_settings.billing_margin
and user_settings.billing_margin != 1.0
):
billing_margin = user_settings.billing_margin
logger.info(
'user_settings_v4_budget_upgrade',
extra={
'max_budget': max_budget,
'billing_margin': billing_margin,
'spend': spend,
},
)
max_budget *= billing_margin
spend *= billing_margin
user_settings.billing_margin = 1.0
session.commit()
email = keycloak_user_info.get('email')
# We explicitly delete here to guard against odd inherited settings on upgrade.
# We don't care if this fails with a 404
await client.post(
f'{LITE_LLM_API_URL}/user/delete', json={'user_ids': [self.user_id]}
)
# Create the new litellm user
response = await self._create_user_in_lite_llm(
client, email, max_budget, spend
)
if not response.is_success:
logger.warning(
'duplicate_user_email',
extra={'user_id': self.user_id, 'email': email},
)
# Litellm insists on unique email addresses - it is possible the email address was registered with a different user.
response = await self._create_user_in_lite_llm(
client, None, max_budget, spend
)
# User failed to create in litellm - this is an unforseen error state...
if not response.is_success:
logger.error(
'error_creating_litellm_user',
extra={
'status_code': response.status_code,
'text': response.text,
'user_id': [self.user_id],
'email': email,
'max_budget': max_budget,
'spend': spend,
},
)
if user_settings:
user = await UserStore.migrate_user(self.user_id, user_settings)
else:
logger.error(f'User not found for ID {self.user_id}')
return None
response_json = response.json()
key = response_json['key']
logger.info(
'saas_settings_store:update_settings_with_litellm_default:user_created',
extra={'user_id': self.user_id},
org_id = user.current_org_id
# Check if provider is OpenHands and generate API key if needed
if self._is_openhands_provider(item):
await self._ensure_openhands_api_key(item, str(org_id))
org_member = None
for om in user.org_members:
if om.org_id == org_id:
org_member = om
break
if not org_member or not org_member.llm_api_key:
return None
org = session.query(Org).filter(Org.id == org_id).first()
if not org:
logger.error(
f'Org not found for ID {org_id} as the current org for user {self.user_id}'
)
return None
settings.agent = 'CodeActAgent'
# Use the model corresponding to the current user settings version
settings.llm_model = get_default_litellm_model()
settings.llm_api_key = SecretStr(key)
settings.llm_base_url = LITE_LLM_API_URL
return settings
for model in (user, org, org_member):
for key, value in kwargs.items():
if hasattr(model, key):
setattr(model, key, value)
session.commit()
@classmethod
async def get_instance(
@@ -326,6 +176,9 @@ class SaasSettingsStore(SettingsStore):
logger.debug(f'saas_settings_store.get_instance::{user_id}')
return SaasSettingsStore(user_id, session_maker, config)
def _should_encrypt(self, key):
return key in self.ENCRYPT_VALUES
def _decrypt_kwargs(self, kwargs: dict):
fernet = self._fernet()
for key, value in kwargs.items():
@@ -369,21 +222,24 @@ class SaasSettingsStore(SettingsStore):
fernet_key = b64encode(hashlib.sha256(jwt_secret.encode()).digest())
return Fernet(fernet_key)
def _should_encrypt(self, key: str) -> bool:
return key in ('llm_api_key', 'llm_api_key_for_byor', 'search_api_key')
def _is_openhands_provider(self, item: Settings) -> bool:
"""Check if the settings use the OpenHands provider."""
return bool(item.llm_model and item.llm_model.startswith('openhands/'))
async def _ensure_openhands_api_key(self, item: Settings) -> None:
async def _ensure_openhands_api_key(self, item: Settings, org_id: str) -> None:
"""Generate and set the OpenHands API key for the given settings.
First checks if an existing key with the OpenHands alias exists,
and reuses it if found. Otherwise, generates a new key.
"""
# Generate new key if none exists
generated_key = await self._generate_openhands_key()
generated_key = await LiteLlmManager.generate_key(
self.user_id,
org_id,
f'Openhands Provider Key - user {self.user_id}',
{'type': 'openhands'},
)
if generated_key:
item.llm_api_key = SecretStr(generated_key)
logger.info(
@@ -395,78 +251,3 @@ class SaasSettingsStore(SettingsStore):
'saas_settings_store:store:failed_to_generate_openhands_key',
extra={'user_id': self.user_id},
)
async def _create_user_in_lite_llm(
self, client: httpx.AsyncClient, email: str | None, max_budget: int, spend: int
):
response = await client.post(
f'{LITE_LLM_API_URL}/user/new',
json={
'user_email': email,
'models': [],
'max_budget': max_budget,
'spend': spend,
'user_id': str(self.user_id),
'teams': [LITE_LLM_TEAM_ID],
'auto_create_key': True,
'send_invite_email': False,
'metadata': {
'version': CURRENT_USER_SETTINGS_VERSION,
'model': get_default_litellm_model(),
},
'key_alias': f'OpenHands Cloud - user {self.user_id}',
},
)
return response
async def _generate_openhands_key(self) -> str | None:
"""Generate a new OpenHands provider key for a user."""
if not (LITE_LLM_API_KEY and LITE_LLM_API_URL):
logger.warning(
'saas_settings_store:_generate_openhands_key:litellm_config_not_found',
extra={'user_id': self.user_id},
)
return None
try:
async with httpx.AsyncClient(
verify=httpx_verify_option(),
headers={
'x-goog-api-key': LITE_LLM_API_KEY,
},
) as client:
response = await client.post(
f'{LITE_LLM_API_URL}/key/generate',
json={
'user_id': self.user_id,
'metadata': {'type': 'openhands'},
},
)
response.raise_for_status()
response_json = response.json()
key = response_json.get('key')
if key:
logger.info(
'saas_settings_store:_generate_openhands_key:success',
extra={
'user_id': self.user_id,
'key_length': len(key) if key else 0,
'key_prefix': (
key[:10] + '...' if key and len(key) > 10 else key
),
},
)
return key
else:
logger.error(
'saas_settings_store:_generate_openhands_key:no_key_in_response',
extra={'user_id': self.user_id, 'response_json': response_json},
)
return None
except Exception as e:
logger.exception(
'saas_settings_store:_generate_openhands_key:error',
extra={'user_id': self.user_id, 'error': str(e)},
)
return None

View File

@@ -1,4 +1,6 @@
from sqlalchemy import Column, Identity, Integer, String
from sqlalchemy import Column, ForeignKey, Identity, Integer, String
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import relationship
from storage.base import Base
@@ -8,4 +10,8 @@ class SlackConversation(Base): # type: ignore
conversation_id = Column(String, nullable=False, index=True)
channel_id = Column(String, nullable=False)
keycloak_user_id = Column(String, nullable=False)
org_id = Column(UUID(as_uuid=True), ForeignKey('org.id'), nullable=True)
parent_id = Column(String, nullable=True, index=True)
# Relationships
org = relationship('Org', back_populates='slack_conversations')

View File

@@ -1,4 +1,6 @@
from sqlalchemy import Column, DateTime, Identity, Integer, String, text
from sqlalchemy import Column, DateTime, ForeignKey, Identity, Integer, String, text
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import relationship
from storage.base import Base
@@ -6,6 +8,7 @@ class SlackUser(Base): # type: ignore
__tablename__ = 'slack_users'
id = Column(Integer, Identity(), primary_key=True)
keycloak_user_id = Column(String, nullable=False, index=True)
org_id = Column(UUID(as_uuid=True), ForeignKey('org.id'), nullable=True)
slack_user_id = Column(String, nullable=False, index=True)
slack_display_name = Column(String, nullable=False)
created_at = Column(
@@ -13,3 +16,6 @@ class SlackUser(Base): # type: ignore
server_default=text('CURRENT_TIMESTAMP'),
nullable=False,
)
# Relationships
org = relationship('Org', back_populates='slack_users')

View File

@@ -1,8 +1,22 @@
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
StoredConversationMetadata as _StoredConversationMetadata,
)
def _get_stored_conversation_metadata():
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
StoredConversationMetadata as _StoredConversationMetadata,
)
StoredConversationMetadata = _StoredConversationMetadata
return _StoredConversationMetadata
# Lazy import to avoid circular dependency
StoredConversationMetadata = None
def __getattr__(name):
global StoredConversationMetadata
if name == 'StoredConversationMetadata':
if StoredConversationMetadata is None:
StoredConversationMetadata = _get_stored_conversation_metadata()
return StoredConversationMetadata
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
__all__ = ['StoredConversationMetadata']

View File

@@ -0,0 +1,28 @@
"""
SQLAlchemy model for ConversationMetadataSaas.
This model stores the SaaS-specific metadata for conversations,
containing only the conversation_id, user_id, and org_id.
"""
from sqlalchemy import UUID as SQL_UUID
from sqlalchemy import Column, ForeignKey, String
from sqlalchemy.orm import relationship
from storage.base import Base
class StoredConversationMetadataSaas(Base): # type: ignore
"""SaaS conversation metadata model containing user and org associations."""
__tablename__ = 'conversation_metadata_saas'
conversation_id = Column(String, primary_key=True)
user_id = Column(SQL_UUID(as_uuid=True), ForeignKey('user.id'), nullable=False)
org_id = Column(SQL_UUID(as_uuid=True), ForeignKey('org.id'), nullable=False)
# Relationships
user = relationship('User', back_populates='stored_conversation_metadata_saas')
org = relationship('Org', back_populates='stored_conversation_metadata_saas')
__all__ = ['StoredConversationMetadataSaas']

View File

@@ -1,4 +1,6 @@
from sqlalchemy import Column, Identity, Integer, String
from sqlalchemy import Column, ForeignKey, Identity, Integer, String
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import relationship
from storage.base import Base
@@ -6,6 +8,10 @@ class StoredCustomSecrets(Base): # type: ignore
__tablename__ = 'custom_secrets'
id = Column(Integer, Identity(), primary_key=True)
keycloak_user_id = Column(String, nullable=True, index=True)
org_id = Column(UUID(as_uuid=True), ForeignKey('org.id'), nullable=True)
secret_name = Column(String, nullable=False)
secret_value = Column(String, nullable=False)
description = Column(String, nullable=True)
# Relationships
org = relationship('Org', back_populates='user_secrets')

View File

@@ -1,4 +1,6 @@
from sqlalchemy import Column, DateTime, Integer, String, text
from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, text
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import relationship
from storage.base import Base
@@ -13,6 +15,7 @@ class StripeCustomer(Base): # type: ignore
__tablename__ = 'stripe_customers'
id = Column(Integer, primary_key=True, autoincrement=True)
keycloak_user_id = Column(String, nullable=False)
org_id = Column(UUID(as_uuid=True), ForeignKey('org.id'), nullable=True)
stripe_customer_id = Column(String, nullable=False)
created_at = Column(
DateTime, server_default=text('CURRENT_TIMESTAMP'), nullable=False
@@ -23,3 +26,6 @@ class StripeCustomer(Base): # type: ignore
onupdate=text('CURRENT_TIMESTAMP'),
nullable=False,
)
# Relationships
org = relationship('Org', back_populates='stripe_customers')

View File

@@ -0,0 +1,41 @@
"""
SQLAlchemy model for User.
"""
from uuid import uuid4
from sqlalchemy import (
UUID,
Boolean,
Column,
DateTime,
ForeignKey,
Integer,
String,
)
from sqlalchemy.orm import relationship
from storage.base import Base
class User(Base): # type: ignore
"""User model with organizational relationships."""
__tablename__ = 'user'
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid4)
current_org_id = Column(UUID(as_uuid=True), ForeignKey('org.id'), nullable=False)
role_id = Column(Integer, ForeignKey('role.id'), nullable=True)
accepted_tos = Column(DateTime, nullable=True)
enable_sound_notifications = Column(Boolean, nullable=True)
language = Column(String, nullable=True)
user_consents_to_analytics = Column(Boolean, nullable=True)
email = Column(String, nullable=True)
email_verified = Column(Boolean, nullable=True)
# Relationships
role = relationship('Role', back_populates='users')
org_members = relationship('OrgMember', back_populates='user')
current_org = relationship('Org', back_populates='current_users')
stored_conversation_metadata_saas = relationship(
'StoredConversationMetadataSaas', back_populates='user'
)

View File

@@ -39,3 +39,6 @@ class UserSettings(Base): # type: ignore
git_user_name = Column(String, nullable=True)
git_user_email = Column(String, nullable=True)
v1_enabled = Column(Boolean, nullable=True)
already_migrated = Column(
Boolean, nullable=True, default=False
) # False = not migrated, True = migrated

View File

@@ -0,0 +1,332 @@
"""
Store class for managing users.
"""
import uuid
from typing import Optional
from server.logger import logger
from sqlalchemy import text
from sqlalchemy.orm import joinedload
from storage.database import session_maker
from storage.encrypt_utils import decrypt_legacy_model
from storage.org import Org
from storage.org_member import OrgMember
from storage.role_store import RoleStore
from storage.user import User
from storage.user_settings import UserSettings
from openhands.utils.async_utils import GENERAL_TIMEOUT, call_async_from_sync
class UserStore:
"""Store for managing users."""
@staticmethod
async def create_user(
user_id: str,
user_info: dict,
role_id: Optional[int] = None,
) -> User | None:
"""Create a new user."""
with session_maker() as session:
# create personal org
org = Org(
id=uuid.UUID(user_id),
name=f'user_{user_id}_org',
contact_name=user_info['preferred_username'],
contact_email=user_info['email'],
)
session.add(org)
settings = await UserStore.create_default_settings(
org_id=str(org.id), user_id=user_id
)
if not settings:
return None
from storage.org_store import OrgStore
org_kwargs = OrgStore.get_kwargs_from_settings(settings)
for key, value in org_kwargs.items():
if hasattr(org, key):
setattr(org, key, value)
user_kwargs = UserStore.get_kwargs_from_settings(settings)
user = User(
id=uuid.UUID(user_id),
current_org_id=org.id,
role_id=role_id,
**user_kwargs,
)
session.add(user)
role = RoleStore.get_role_by_name('owner')
from storage.org_member_store import OrgMemberStore
org_member_kwargs = OrgMemberStore.get_kwargs_from_settings(settings)
org_member = OrgMember(
org_id=org.id,
user_id=user.id,
role_id=role.id, # owner of your own org.
status='active',
**org_member_kwargs,
)
session.add(org_member)
session.commit()
session.refresh(user)
user.org_members # load org_members
return user
@staticmethod
async def migrate_user(
user_id: str,
user_settings: UserSettings,
user_info: dict,
) -> User:
if not user_id or not user_settings:
return None
kwargs = decrypt_legacy_model(
[
'llm_api_key',
'llm_api_key_for_byor',
'search_api_key',
'sandbox_api_key',
],
user_settings,
)
decrypted_user_settings = UserSettings(**kwargs)
with session_maker() as session:
# create personal org
org = Org(
id=uuid.UUID(user_id),
name=f'user_{user_id}_org',
contact_name=user_info['username'],
contact_email=user_info['email'],
)
session.add(org)
from storage.lite_llm_manager import LiteLlmManager
await LiteLlmManager.migrate_entries(
str(org.id),
user_id,
decrypted_user_settings,
)
# avoids circular reference. This migrate method is temprorary until all users are migrated.
from integrations.stripe_service import migrate_customer
await migrate_customer(session, user_id, org)
from storage.org_store import OrgStore
org_kwargs = OrgStore.get_kwargs_from_user_settings(decrypted_user_settings)
org_kwargs.pop('id', None)
for key, value in org_kwargs.items():
if hasattr(org, key):
setattr(org, key, value)
user_kwargs = UserStore.get_kwargs_from_user_settings(
decrypted_user_settings
)
user_kwargs.pop('id', None)
user = User(
id=uuid.UUID(user_id),
current_org_id=org.id,
role_id=None,
**user_kwargs,
)
session.add(user)
role = RoleStore.get_role_by_name('owner')
from storage.org_member_store import OrgMemberStore
org_member_kwargs = OrgMemberStore.get_kwargs_from_user_settings(
decrypted_user_settings
)
org_member = OrgMember(
org_id=org.id,
user_id=user.id,
role_id=role.id, # owner of your own org.
status='active',
**org_member_kwargs,
)
session.add(org_member)
# Mark the old user_settings as migrated instead of deleting
user_settings.already_migrated = True
session.merge(user_settings)
session.flush()
# need to migrate conversation metadata
session.execute(
text("""
INSERT INTO conversation_metadata_saas (conversation_id, user_id, org_id)
SELECT
conversation_id,
:user_id,
:user_id
FROM conversation_metadata
WHERE user_id = :user_id
"""),
{'user_id': user_id},
)
# Update org_id for tables that had org_id added
user_uuid = uuid.UUID(user_id)
# Update stripe_customers
session.execute(
text(
'UPDATE stripe_customers SET org_id = :org_id WHERE keycloak_user_id = :user_id'
),
{'org_id': user_uuid, 'user_id': user_uuid},
)
# Update slack_users
session.execute(
text(
'UPDATE slack_users SET org_id = :org_id WHERE keycloak_user_id = :user_id'
),
{'org_id': user_uuid, 'user_id': user_uuid},
)
# Update slack_conversation
session.execute(
text(
'UPDATE slack_conversation SET org_id = :org_id WHERE keycloak_user_id = :user_id'
),
{'org_id': user_uuid, 'user_id': user_uuid},
)
# Update api_keys
session.execute(
text('UPDATE api_keys SET org_id = :org_id WHERE user_id = :user_id'),
{'org_id': user_uuid, 'user_id': user_uuid},
)
# Update custom_secrets
session.execute(
text(
'UPDATE custom_secrets SET org_id = :org_id WHERE keycloak_user_id = :user_id'
),
{'org_id': user_uuid, 'user_id': user_uuid},
)
# Update billing_sessions
session.execute(
text(
'UPDATE billing_sessions SET org_id = :org_id WHERE user_id = :user_id'
),
{'org_id': user_uuid, 'user_id': user_uuid},
)
session.commit()
session.refresh(user)
user.org_members # load org_members
return user
@staticmethod
def get_user_by_id(user_id: str) -> Optional[User]:
"""Get user by Keycloak user ID."""
with session_maker() as session:
user = (
session.query(User)
.options(joinedload(User.org_members))
.filter(User.id == uuid.UUID(user_id))
.first()
)
if user:
return user
# Check if we need to migrate from user_settings
user_settings = (
session.query(UserSettings)
.filter(
UserSettings.keycloak_user_id == user_id,
UserSettings.already_migrated.is_(False),
)
.first()
)
if user_settings:
from server.auth.token_manager import TokenManager
token_manager = TokenManager()
user_info = call_async_from_sync(
token_manager.get_user_info_from_user_id,
GENERAL_TIMEOUT,
user_id,
)
user = call_async_from_sync(
UserStore.migrate_user,
GENERAL_TIMEOUT,
user_id,
user_settings,
user_info,
)
return user
else:
return None
@staticmethod
def list_users() -> list[User]:
"""List all users."""
with session_maker() as session:
return session.query(User).all()
# Prevent circular imports
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from openhands.storage.data_models.settings import Settings
@staticmethod
async def create_default_settings(
org_id: str, user_id: str
) -> Optional['Settings']:
logger.info(
'UserStore:create_default_settings:start',
extra={'org_id': org_id, 'user_id': user_id},
)
# You must log in before you get default settings
if not org_id:
return None
from openhands.storage.data_models.settings import Settings
settings = Settings(language='en', enable_proactive_conversation_starters=True)
from storage.lite_llm_manager import LiteLlmManager
settings = await LiteLlmManager.create_entries(org_id, user_id, settings)
if not settings:
logger.info(
'UserStore:create_default_settings:litellm_create_failed',
extra={'org_id': org_id},
)
return None
return settings
@staticmethod
def get_kwargs_from_settings(settings: 'Settings'):
kwargs = {
normalized: getattr(settings, normalized)
for c in User.__table__.columns
if (normalized := c.name.lstrip('_')) and hasattr(settings, normalized)
}
return kwargs
@staticmethod
def get_kwargs_from_user_settings(user_settings: UserSettings):
kwargs = {
normalized: getattr(user_settings, normalized)
for c in User.__table__.columns
if (normalized := c.name.lstrip('_')) and hasattr(user_settings, normalized)
}
return kwargs

View File

@@ -3,7 +3,7 @@ from typing import cast
from uuid import uuid4
from integrations.types import GitLabResourceType
from integrations.utils import GITLAB_WEBHOOK_URL
from server.constants import WEB_HOST
from storage.gitlab_webhook import GitlabWebhook, WebhookStatus
from storage.gitlab_webhook_store import GitlabWebhookStore
@@ -11,6 +11,7 @@ from openhands.core.logger import openhands_logger as logger
from openhands.integrations.gitlab.gitlab_service import GitLabServiceImpl
from openhands.integrations.service_types import GitService
GITLAB_WEBHOOK_URL = f'https://{WEB_HOST}/integration/gitlab/events'
CHUNK_SIZE = 100
WEBHOOK_NAME = 'OpenHands Resolver'
SCOPES: list[str] = [

View File

@@ -1,10 +1,9 @@
import uuid
from datetime import datetime
from uuid import UUID
import pytest
from server.constants import CURRENT_USER_SETTINGS_VERSION
from server.maintenance_task_processor.user_version_upgrade_processor import (
UserVersionUpgradeProcessor,
)
from server.constants import ORG_SETTINGS_VERSION
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from storage.base import Base
@@ -14,11 +13,20 @@ from storage.billing_session import BillingSession
from storage.conversation_work import ConversationWork
from storage.feedback import Feedback
from storage.github_app_installation import GithubAppInstallation
from storage.maintenance_task import MaintenanceTask, MaintenanceTaskStatus
from storage.stored_conversation_metadata import StoredConversationMetadata
from storage.org import Org
from storage.org_member import OrgMember
from storage.role import Role
from storage.stored_conversation_metadata_saas import (
StoredConversationMetadataSaas,
)
from storage.stored_offline_token import StoredOfflineToken
from storage.stripe_customer import StripeCustomer
from storage.user_settings import UserSettings
from storage.user import User
# Import the actual StoredConversationMetadata from OpenHands core
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
StoredConversationMetadata,
)
@pytest.fixture
@@ -67,7 +75,6 @@ def add_minimal_fixtures(session_maker):
session.add(
StoredConversationMetadata(
conversation_id='mock-conversation-id',
user_id='mock-user-id',
created_at=datetime.fromisoformat('2025-03-07'),
last_updated_at=datetime.fromisoformat('2025-03-08'),
accumulated_cost=5.25,
@@ -76,6 +83,13 @@ def add_minimal_fixtures(session_maker):
total_tokens=750,
)
)
session.add(
StoredConversationMetadataSaas(
conversation_id='mock-conversation-id',
user_id=UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'),
org_id=UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'),
)
)
session.add(
StoredOfflineToken(
user_id='mock-user-id',
@@ -84,7 +98,38 @@ def add_minimal_fixtures(session_maker):
updated_at=datetime.fromisoformat('2025-03-08'),
)
)
session.add(
Org(
id=uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'),
name='mock-org',
org_version=ORG_SETTINGS_VERSION,
enable_default_condenser=True,
enable_proactive_conversation_starters=True,
)
)
session.add(
Role(
id=1,
name='admin',
rank=1,
)
)
session.add(
User(
id=uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'),
current_org_id=uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'),
user_consents_to_analytics=True,
)
)
session.add(
OrgMember(
org_id=uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'),
user_id=uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'),
role_id=1,
llm_api_key='mock-api-key',
status='active',
)
)
session.add(
StripeCustomer(
keycloak_user_id='mock-user-id',
@@ -93,13 +138,6 @@ def add_minimal_fixtures(session_maker):
updated_at=datetime.fromisoformat('2025-03-10'),
)
)
session.add(
UserSettings(
keycloak_user_id='mock-user-id',
user_consents_to_analytics=True,
user_version=CURRENT_USER_SETTINGS_VERSION,
)
)
session.add(
ConversationWork(
conversation_id='mock-conversation-id',
@@ -108,17 +146,6 @@ def add_minimal_fixtures(session_maker):
updated_at=datetime.fromisoformat('2025-03-08'),
)
)
maintenance_task = MaintenanceTask(
status=MaintenanceTaskStatus.PENDING,
)
maintenance_task.set_processor(
UserVersionUpgradeProcessor(
user_ids=['mock-user-id'],
created_at=datetime.fromisoformat('2025-03-07'),
updated_at=datetime.fromisoformat('2025-03-08'),
)
)
session.add(maintenance_task)
session.commit()

View File

@@ -6,22 +6,32 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import BackgroundTasks, HTTPException, Request, status
from server.routes.event_webhook import (
BatchMethod,
BatchOperation,
_get_session_api_key,
_get_user_id,
_parse_conversation_id_and_subpath,
_process_batch_operations_background,
on_batch_write,
on_delete,
on_write,
# Import the actual StoredConversationMetadata from OpenHands core
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
StoredConversationMetadata,
)
from server.utils.conversation_callback_utils import (
process_event,
update_conversation_metadata,
)
from storage.stored_conversation_metadata import StoredConversationMetadata
# Mock the lazy import to return the actual class
with patch(
'storage.stored_conversation_metadata.StoredConversationMetadata',
StoredConversationMetadata,
):
from server.routes.event_webhook import (
BatchMethod,
BatchOperation,
_get_session_api_key,
_get_user_id,
_parse_conversation_id_and_subpath,
_process_batch_operations_background,
on_batch_write,
on_delete,
on_write,
)
from server.utils.conversation_callback_utils import (
process_event,
update_conversation_metadata,
)
from openhands.events.observation.agent import AgentStateChangedObservation
@@ -82,7 +92,7 @@ class TestGetUserId:
session_maker_with_minimal_fixtures,
):
user_id = _get_user_id('mock-conversation-id')
assert user_id == 'mock-user-id'
assert user_id == '5594c7b6-f959-4b81-92e9-b09c206f5081'
def test_get_user_id_conversation_not_found(self, session_maker):
"""Test getting user ID when conversation doesn't exist."""
@@ -105,10 +115,12 @@ class TestGetSessionApiKey:
return_value=[mock_agent_loop_info]
)
api_key = await _get_session_api_key('user-123', 'conv-456')
api_key = await _get_session_api_key(
'5594c7b6-f959-4b81-92e9-b09c206f5081', 'conv-456'
)
assert api_key == 'test-api-key'
mock_manager.get_agent_loop_info.assert_called_once_with(
'user-123', filter_to_sids={'conv-456'}
'5594c7b6-f959-4b81-92e9-b09c206f5081', filter_to_sids={'conv-456'}
)
@pytest.mark.asyncio
@@ -118,7 +130,9 @@ class TestGetSessionApiKey:
mock_manager.get_agent_loop_info = AsyncMock(return_value=[])
with pytest.raises(IndexError):
await _get_session_api_key('user-123', 'conv-456')
await _get_session_api_key(
'5594c7b6-f959-4b81-92e9-b09c206f5081', 'conv-456'
)
class TestProcessEvent:
@@ -142,10 +156,15 @@ class TestProcessEvent:
mock_event = MagicMock()
mock_event_from_dict.return_value = mock_event
await process_event('user-123', 'conv-456', 'events/event-1.json', content)
await process_event(
'5594c7b6-f959-4b81-92e9-b09c206f5081',
'conv-456',
'events/event-1.json',
content,
)
mock_file_store.write.assert_called_once_with(
'users/user-123/conversations/conv-456/events/event-1.json',
'users/5594c7b6-f959-4b81-92e9-b09c206f5081/conversations/conv-456/events/event-1.json',
json.dumps(content),
)
mock_event_from_dict.assert_called_once_with(content)
@@ -177,14 +196,19 @@ class TestProcessEvent:
)
mock_event_from_dict.return_value = mock_event
await process_event('user-123', 'conv-456', 'events/event-1.json', content)
await process_event(
'5594c7b6-f959-4b81-92e9-b09c206f5081',
'conv-456',
'events/event-1.json',
content,
)
mock_file_store.write.assert_called_once()
mock_event_from_dict.assert_called_once_with(content)
mock_invoke_callbacks.assert_called_once_with('conv-456', mock_event)
mock_update_working_seconds.assert_called_once()
mock_event_store_class.assert_called_once_with(
'conv-456', mock_file_store, 'user-123'
'conv-456', mock_file_store, '5594c7b6-f959-4b81-92e9-b09c206f5081'
)
@pytest.mark.asyncio
@@ -212,7 +236,12 @@ class TestProcessEvent:
mock_event.agent_state = 'running' # Set RUNNING state to skip the update
mock_event_from_dict.return_value = mock_event
await process_event('user-123', 'conv-456', 'events/event-1.json', content)
await process_event(
'5594c7b6-f959-4b81-92e9-b09c206f5081',
'conv-456',
'events/event-1.json',
content,
)
mock_file_store.write.assert_called_once()
mock_event_from_dict.assert_called_once_with(content)
@@ -236,10 +265,13 @@ class TestUpdateConversationMetadata:
'total_tokens': 1500,
}
with patch(
'server.utils.conversation_callback_utils.session_maker',
session_maker_with_minimal_fixtures,
):
# Import the module and patch the session_maker at the module level
import server.utils.conversation_callback_utils as callback_utils
original_session_maker = callback_utils.session_maker
try:
callback_utils.session_maker = session_maker_with_minimal_fixtures
update_conversation_metadata('mock-conversation-id', content)
# Verify the conversation was updated
@@ -257,6 +289,9 @@ class TestUpdateConversationMetadata:
assert conversation.completion_tokens == 500
assert conversation.total_tokens == 1500
assert isinstance(conversation.last_updated_at, datetime)
finally:
# Restore the original session_maker
callback_utils.session_maker = original_session_maker
def test_update_conversation_metadata_partial_fields(
self, session_maker_with_minimal_fixtures
@@ -264,10 +299,13 @@ class TestUpdateConversationMetadata:
"""Test updating conversation metadata with only some fields."""
content = {'accumulated_cost': 15.75, 'prompt_tokens': 2000}
with patch(
'server.utils.conversation_callback_utils.session_maker',
session_maker_with_minimal_fixtures,
):
# Import the module and patch the session_maker at the module level
import server.utils.conversation_callback_utils as callback_utils
original_session_maker = callback_utils.session_maker
try:
callback_utils.session_maker = session_maker_with_minimal_fixtures
update_conversation_metadata('mock-conversation-id', content)
# Verify only specified fields were updated, others remain unchanged
@@ -285,6 +323,9 @@ class TestUpdateConversationMetadata:
# These should remain as original values from fixtures
assert conversation.completion_tokens == 250
assert conversation.total_tokens == 750
finally:
# Restore the original session_maker
callback_utils.session_maker = original_session_maker
def test_update_conversation_metadata_empty_content(
self, session_maker_with_minimal_fixtures
@@ -292,10 +333,13 @@ class TestUpdateConversationMetadata:
"""Test updating conversation metadata with empty content."""
content: dict[str, float] = {}
with patch(
'server.utils.conversation_callback_utils.session_maker',
session_maker_with_minimal_fixtures,
):
# Import the module and patch the session_maker at the module level
import server.utils.conversation_callback_utils as callback_utils
original_session_maker = callback_utils.session_maker
try:
callback_utils.session_maker = session_maker_with_minimal_fixtures
update_conversation_metadata('mock-conversation-id', content)
# Verify only last_updated_at was changed
@@ -314,6 +358,9 @@ class TestUpdateConversationMetadata:
assert conversation.completion_tokens == 250
assert conversation.total_tokens == 750
assert isinstance(conversation.last_updated_at, datetime)
finally:
# Restore the original session_maker
callback_utils.session_maker = original_session_maker
class TestOnDelete:
@@ -344,24 +391,31 @@ class TestOnWrite:
content = {'accumulated_cost': 20.0}
mock_request.json.return_value = content
with patch(
'server.routes.event_webhook.session_maker',
session_maker_with_minimal_fixtures,
), patch(
'server.utils.conversation_callback_utils.session_maker',
session_maker_with_minimal_fixtures,
), patch(
'server.routes.event_webhook._get_session_api_key'
) as mock_get_api_key:
mock_get_api_key.return_value = 'correct-api-key'
# Import the module and patch the session_maker at the module level
import server.utils.conversation_callback_utils as callback_utils
result = await on_write(
'sessions/mock-conversation-id/metadata.json',
mock_request,
'correct-api-key',
)
original_session_maker = callback_utils.session_maker
assert result.status_code == status.HTTP_200_OK
try:
with patch(
'server.routes.event_webhook.session_maker',
session_maker_with_minimal_fixtures,
), patch(
'server.routes.event_webhook._get_session_api_key'
) as mock_get_api_key:
mock_get_api_key.return_value = 'correct-api-key'
callback_utils.session_maker = session_maker_with_minimal_fixtures
result = await on_write(
'sessions/mock-conversation-id/metadata.json',
mock_request,
'correct-api-key',
)
assert result.status_code == status.HTTP_200_OK
finally:
# Restore the original session_maker
callback_utils.session_maker = original_session_maker
@pytest.mark.asyncio
async def test_on_write_events_success(
@@ -569,31 +623,38 @@ class TestProcessBatchOperationsBackground:
)
]
with patch(
'server.routes.event_webhook.session_maker',
session_maker_with_minimal_fixtures,
), patch(
'server.routes.event_webhook._get_session_api_key'
) as mock_get_api_key, patch(
'server.utils.conversation_callback_utils.session_maker',
session_maker_with_minimal_fixtures,
):
mock_get_api_key.return_value = 'correct-api-key'
# Import the module and patch the session_maker at the module level
import server.utils.conversation_callback_utils as callback_utils
# Should not raise any exceptions
await _process_batch_operations_background(batch_ops, 'correct-api-key')
original_session_maker = callback_utils.session_maker
# Verify the conversation metadata was updated
with session_maker_with_minimal_fixtures() as session:
conversation = (
session.query(StoredConversationMetadata)
.filter(
StoredConversationMetadata.conversation_id
== 'mock-conversation-id'
try:
with patch(
'server.routes.event_webhook.session_maker',
session_maker_with_minimal_fixtures,
), patch(
'server.routes.event_webhook._get_session_api_key'
) as mock_get_api_key:
mock_get_api_key.return_value = 'correct-api-key'
callback_utils.session_maker = session_maker_with_minimal_fixtures
# Should not raise any exceptions
await _process_batch_operations_background(batch_ops, 'correct-api-key')
# Verify the conversation metadata was updated
with session_maker_with_minimal_fixtures() as session:
conversation = (
session.query(StoredConversationMetadata)
.filter(
StoredConversationMetadata.conversation_id
== 'mock-conversation-id'
)
.first()
)
.first()
)
assert conversation.accumulated_cost == 15.0
assert conversation.accumulated_cost == 15.0
finally:
# Restore the original session_maker
callback_utils.session_maker = original_session_maker
@pytest.mark.asyncio
async def test_process_batch_operations_events_success(
@@ -644,20 +705,27 @@ class TestProcessBatchOperationsBackground:
),
]
with patch(
'server.routes.event_webhook.session_maker',
session_maker_with_minimal_fixtures,
), patch(
'server.routes.event_webhook._get_session_api_key'
) as mock_get_api_key, patch(
'server.utils.conversation_callback_utils.session_maker',
session_maker_with_minimal_fixtures,
):
# First call succeeds, second fails
mock_get_api_key.side_effect = ['correct-api-key', 'wrong-api-key']
# Import the module and patch the session_maker at the module level
import server.utils.conversation_callback_utils as callback_utils
# Should not raise exceptions, just log errors
await _process_batch_operations_background(batch_ops, 'correct-api-key')
original_session_maker = callback_utils.session_maker
try:
with patch(
'server.routes.event_webhook.session_maker',
session_maker_with_minimal_fixtures,
), patch(
'server.routes.event_webhook._get_session_api_key'
) as mock_get_api_key:
# First call succeeds, second fails
mock_get_api_key.side_effect = ['correct-api-key', 'wrong-api-key']
callback_utils.session_maker = session_maker_with_minimal_fixtures
# Should not raise exceptions, just log errors
await _process_batch_operations_background(batch_ops, 'correct-api-key')
finally:
# Restore the original session_maker
callback_utils.session_maker = original_session_maker
@pytest.mark.asyncio
async def test_process_batch_operations_invalid_method_skipped(

View File

@@ -0,0 +1,371 @@
"""Tests for SaasSQLAppConversationInfoService.
This module tests the SAAS implementation of SQLAppConversationInfoService,
focusing on user isolation, SAAS metadata handling, and multi-tenant functionality.
"""
from datetime import datetime, timezone
from typing import AsyncGenerator
from unittest.mock import AsyncMock, MagicMock, patch
from uuid import UUID, uuid4
import pytest
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.pool import StaticPool
# Import the SAAS service
from enterprise.storage.saas_app_conversation_info_injector import (
SaasSQLAppConversationInfoService,
)
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
StoredConversationMetadata,
)
from openhands.app_server.app_conversation.app_conversation_models import (
AppConversationInfo,
)
from openhands.app_server.user.specifiy_user_context import SpecifyUserContext
from openhands.app_server.utils.sql_utils import Base
from openhands.integrations.service_types import ProviderType
from openhands.storage.data_models.conversation_metadata import ConversationTrigger
@pytest.fixture
async def async_engine():
"""Create an async SQLite engine for testing."""
engine = create_async_engine(
'sqlite+aiosqlite:///:memory:',
poolclass=StaticPool,
connect_args={'check_same_thread': False},
echo=False,
)
# Create all tables
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
yield engine
await engine.dispose()
@pytest.fixture
async def async_session(async_engine) -> AsyncGenerator[AsyncSession, None]:
"""Create an async session for testing."""
async_session_maker = async_sessionmaker(
async_engine, class_=AsyncSession, expire_on_commit=False
)
async with async_session_maker() as db_session:
yield db_session
@pytest.fixture
def service(async_session) -> SaasSQLAppConversationInfoService:
"""Create a SQLAppConversationInfoService instance for testing."""
return SaasSQLAppConversationInfoService(
db_session=async_session, user_context=SpecifyUserContext(user_id=None)
)
@pytest.fixture
def service_with_user(async_session) -> SaasSQLAppConversationInfoService:
"""Create a SQLAppConversationInfoService instance with a user_id for testing."""
return SaasSQLAppConversationInfoService(
db_session=async_session,
user_context=SpecifyUserContext(user_id='a1111111-1111-1111-1111-111111111111'),
)
@pytest.fixture
def sample_conversation_info() -> AppConversationInfo:
"""Create a sample AppConversationInfo for testing."""
return AppConversationInfo(
id=uuid4(),
created_by_user_id='a1111111-1111-1111-1111-111111111111',
sandbox_id='sandbox_123',
selected_repository='https://github.com/test/repo',
selected_branch='main',
git_provider=ProviderType.GITHUB,
title='Test Conversation',
trigger=ConversationTrigger.GUI,
pr_number=[123, 456],
llm_model='gpt-4',
metrics=None,
created_at=datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc),
updated_at=datetime(2024, 1, 1, 12, 30, 0, tzinfo=timezone.utc),
)
@pytest.fixture
def multiple_conversation_infos() -> list[AppConversationInfo]:
"""Create multiple AppConversationInfo instances for testing."""
base_time = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
return [
AppConversationInfo(
id=uuid4(),
created_by_user_id=None,
sandbox_id=f'sandbox_{i}',
selected_repository=f'https://github.com/test/repo{i}',
selected_branch='main',
git_provider=ProviderType.GITHUB,
title=f'Test Conversation {i}',
trigger=ConversationTrigger.GUI,
pr_number=[i * 100],
llm_model='gpt-4',
metrics=None,
created_at=base_time.replace(hour=12 + i),
updated_at=base_time.replace(hour=12 + i, minute=30),
)
for i in range(1, 6) # Create 5 conversations
]
@pytest.fixture
def mock_db_session():
"""Create a mock database session."""
return AsyncMock()
@pytest.fixture
def user1_context():
"""Create user context for user1."""
return SpecifyUserContext(user_id='a1111111-1111-1111-1111-111111111111')
@pytest.fixture
def user2_context():
"""Create user context for user2."""
return SpecifyUserContext(user_id='b2222222-2222-2222-2222-222222222222')
@pytest.fixture
def saas_service_user1(mock_db_session, user1_context):
"""Create a SaasSQLAppConversationInfoService instance for user1."""
return SaasSQLAppConversationInfoService(
db_session=mock_db_session, user_context=user1_context
)
@pytest.fixture
def saas_service_user2(mock_db_session, user2_context):
"""Create a SaasSQLAppConversationInfoService instance for user2."""
return SaasSQLAppConversationInfoService(
db_session=mock_db_session, user_context=user2_context
)
class TestSaasSQLAppConversationInfoService:
"""Test suite for SaasSQLAppConversationInfoService."""
def test_service_initialization(
self,
saas_service_user1: SaasSQLAppConversationInfoService,
user1_context: SpecifyUserContext,
):
"""Test that the SAAS service is properly initialized."""
assert saas_service_user1.user_context == user1_context
assert saas_service_user1.db_session is not None
@pytest.mark.asyncio
async def test_user_context_isolation(
self,
saas_service_user1: SaasSQLAppConversationInfoService,
saas_service_user2: SaasSQLAppConversationInfoService,
):
"""Test that different service instances have different user contexts."""
user1_id = await saas_service_user1.user_context.get_user_id()
user2_id = await saas_service_user2.user_context.get_user_id()
assert user1_id == 'a1111111-1111-1111-1111-111111111111'
assert user2_id == 'b2222222-2222-2222-2222-222222222222'
assert user1_id != user2_id
@pytest.mark.asyncio
async def test_secure_select_includes_user_filtering(
self,
saas_service_user1: SaasSQLAppConversationInfoService,
):
"""Test that _secure_select method includes user filtering."""
# This test verifies that the _secure_select method exists and can be called
# The actual SQL generation is tested implicitly through integration
query = await saas_service_user1._secure_select()
assert query is not None
@pytest.mark.asyncio
async def test_to_info_with_user_id_functionality(
self,
saas_service_user1: SaasSQLAppConversationInfoService,
):
"""Test that _to_info_with_user_id properly sets user_id from SAAS metadata."""
from storage.stored_conversation_metadata_saas import (
StoredConversationMetadataSaas,
)
# Create mock metadata objects
stored_metadata = MagicMock(spec=StoredConversationMetadata)
stored_metadata.conversation_id = '12345678-1234-5678-1234-567812345678'
stored_metadata.parent_conversation_id = None
stored_metadata.title = 'Test Conversation'
stored_metadata.sandbox_id = 'test-sandbox'
stored_metadata.selected_repository = None
stored_metadata.selected_branch = None
stored_metadata.git_provider = None
stored_metadata.trigger = None
stored_metadata.pr_number = []
stored_metadata.llm_model = None
from datetime import datetime, timezone
stored_metadata.created_at = datetime.now(timezone.utc)
stored_metadata.last_updated_at = datetime.now(timezone.utc)
stored_metadata.accumulated_cost = 0.0
stored_metadata.prompt_tokens = 0
stored_metadata.completion_tokens = 0
stored_metadata.total_tokens = 0
stored_metadata.max_budget_per_task = None
stored_metadata.cache_read_tokens = 0
stored_metadata.cache_write_tokens = 0
stored_metadata.reasoning_tokens = 0
stored_metadata.context_window = 0
stored_metadata.per_turn_token = 0
saas_metadata = MagicMock(spec=StoredConversationMetadataSaas)
saas_metadata.user_id = UUID('a1111111-1111-1111-1111-111111111111')
saas_metadata.org_id = UUID('a1111111-1111-1111-1111-111111111111')
# Test the _to_info_with_user_id method
result = saas_service_user1._to_info_with_user_id(
stored_metadata, saas_metadata
)
# Verify that the user_id from SAAS metadata is used
assert result.created_by_user_id == 'a1111111-1111-1111-1111-111111111111'
assert result.title == 'Test Conversation'
assert result.sandbox_id == 'test-sandbox'
@pytest.mark.asyncio
async def test_user_isolation(
self,
async_session: AsyncSession,
multiple_conversation_infos: list[AppConversationInfo],
):
"""Test that user isolation works correctly."""
from unittest.mock import MagicMock
from storage.user import User
# Mock the database session execute method to return mock users
# This mock intercepts User queries and returns a mock user object
# with user_id and org_id the same as the user_id_uuid from the query
original_execute = async_session.execute
async def mock_execute(query):
query_str = str(query)
# Check if this is a User query
if '"user"' in query_str.lower() and '"user".id' in query_str.lower():
# Extract the UUID from the query parameters
# The query will have bound parameters, we need to get the UUID value
if hasattr(query, 'compile'):
try:
compiled = query.compile(compile_kwargs={'literal_binds': True})
query_with_params = str(compiled)
# Extract UUID from the query string
import re
# Try both formats: with dashes and without dashes
uuid_pattern_with_dashes = r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}'
uuid_pattern_without_dashes = r'[a-f0-9]{32}'
uuid_match = re.search(
uuid_pattern_with_dashes, query_with_params
)
if not uuid_match:
uuid_match = re.search(
uuid_pattern_without_dashes, query_with_params
)
if uuid_match:
user_id_str = uuid_match.group(0)
# If the UUID doesn't have dashes, add them
if len(user_id_str) == 32 and '-' not in user_id_str:
# Convert from 'a1111111111111111111111111111111' to 'a1111111-1111-1111-1111-111111111111'
user_id_str = f'{user_id_str[:8]}-{user_id_str[8:12]}-{user_id_str[12:16]}-{user_id_str[16:20]}-{user_id_str[20:]}'
user_id_uuid = UUID(user_id_str)
# Create a mock user with user_id and org_id the same as user_id_uuid
mock_user = MagicMock(spec=User)
mock_user.id = user_id_uuid
mock_user.current_org_id = user_id_uuid
# Create a mock result
mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = mock_user
return mock_result
except Exception:
# If there's any error in parsing, fall back to original execute
pass
# For all other queries, use the original execute method
return await original_execute(query)
# Apply the mock
async_session.execute = mock_execute
# Create services for different users
user1_service = SaasSQLAppConversationInfoService(
db_session=async_session,
user_context=SpecifyUserContext(
user_id='a1111111-1111-1111-1111-111111111111'
),
)
user2_service = SaasSQLAppConversationInfoService(
db_session=async_session,
user_context=SpecifyUserContext(
user_id='b2222222-2222-2222-2222-222222222222'
),
)
# Create conversations for different users
user1_info = AppConversationInfo(
id=uuid4(),
created_by_user_id='a1111111-1111-1111-1111-111111111111',
sandbox_id='sandbox_user1',
title='User 1 Conversation',
)
user2_info = AppConversationInfo(
id=uuid4(),
created_by_user_id='b2222222-2222-2222-2222-222222222222',
sandbox_id='sandbox_user2',
title='User 2 Conversation',
)
# Save conversations
await user1_service.save_app_conversation_info(user1_info)
await user2_service.save_app_conversation_info(user2_info)
# User 1 should only see their conversation
user1_page = await user1_service.search_app_conversation_info()
assert len(user1_page.items) == 1
assert (
user1_page.items[0].created_by_user_id
== 'a1111111-1111-1111-1111-111111111111'
)
# User 2 should only see their conversation
user2_page = await user2_service.search_app_conversation_info()
assert len(user2_page.items) == 1
assert (
user2_page.items[0].created_by_user_id
== 'b2222222-2222-2222-2222-222222222222'
)
# User 1 should not be able to get user 2's conversation
user2_from_user1 = await user1_service.get_app_conversation_info(user2_info.id)
assert user2_from_user1 is None
# User 2 should not be able to get user 1's conversation
user1_from_user2 = await user2_service.get_app_conversation_info(user1_info.id)
assert user1_from_user2 is None

View File

@@ -1,5 +1,5 @@
from datetime import UTC, datetime, timedelta
from unittest.mock import MagicMock
from unittest.mock import MagicMock, patch
import pytest
from storage.api_key_store import ApiKeyStore
@@ -19,6 +19,14 @@ def mock_session_maker(mock_session):
return session_maker
@pytest.fixture
def mock_user():
"""Mock user with org_id."""
user = MagicMock()
user.current_org_id = 'test-org-123'
return user
@pytest.fixture
def api_key_store(mock_session_maker):
return ApiKeyStore(mock_session_maker)
@@ -31,11 +39,13 @@ def test_generate_api_key(api_key_store):
assert len(key) == 32
def test_create_api_key(api_key_store, mock_session):
@patch('storage.api_key_store.UserStore.get_user_by_id')
def test_create_api_key(mock_get_user, api_key_store, mock_session, mock_user):
"""Test creating an API key."""
# Setup
user_id = 'test-user-123'
name = 'Test Key'
mock_get_user.return_value = mock_user
api_key_store.generate_api_key = MagicMock(return_value='test-api-key')
# Execute
@@ -43,10 +53,15 @@ def test_create_api_key(api_key_store, mock_session):
# Verify
assert result == 'test-api-key'
mock_get_user.assert_called_once_with(user_id)
mock_session.add.assert_called_once()
mock_session.commit.assert_called_once()
api_key_store.generate_api_key.assert_called_once()
# Verify the ApiKey was created with the correct org_id
added_api_key = mock_session.add.call_args[0][0]
assert added_api_key.org_id == mock_user.current_org_id
def test_validate_api_key_valid(api_key_store, mock_session):
"""Test validating a valid API key."""
@@ -158,10 +173,12 @@ def test_delete_api_key_by_id(api_key_store, mock_session):
mock_session.commit.assert_called_once()
def test_list_api_keys(api_key_store, mock_session):
@patch('storage.api_key_store.UserStore.get_user_by_id')
def test_list_api_keys(mock_get_user, api_key_store, mock_session, mock_user):
"""Test listing API keys for a user."""
# Setup
user_id = 'test-user-123'
mock_get_user.return_value = mock_user
now = datetime.now(UTC)
mock_key1 = MagicMock()
mock_key1.id = 1
@@ -177,15 +194,17 @@ def test_list_api_keys(api_key_store, mock_session):
mock_key2.last_used_at = None
mock_key2.expires_at = None
mock_session.query.return_value.filter.return_value.all.return_value = [
mock_key1,
mock_key2,
]
# Mock the chained query calls for filtering by user_id and org_id
mock_query = mock_session.query.return_value
mock_filter_user = mock_query.filter.return_value
mock_filter_org = mock_filter_user.filter.return_value
mock_filter_org.all.return_value = [mock_key1, mock_key2]
# Execute
result = api_key_store.list_api_keys(user_id)
# Verify
mock_get_user.assert_called_once_with(user_id)
assert len(result) == 2
assert result[0]['id'] == 1
assert result[0]['name'] == 'Key 1'
@@ -198,3 +217,59 @@ def test_list_api_keys(api_key_store, mock_session):
assert result[1]['created_at'] == now
assert result[1]['last_used_at'] is None
assert result[1]['expires_at'] is None
@patch('storage.api_key_store.UserStore.get_user_by_id')
def test_retrieve_mcp_api_key(mock_get_user, api_key_store, mock_session, mock_user):
"""Test retrieving MCP API key for a user."""
# Setup
user_id = 'test-user-123'
mock_get_user.return_value = mock_user
mock_mcp_key = MagicMock()
mock_mcp_key.name = 'MCP_API_KEY'
mock_mcp_key.key = 'mcp-test-key'
mock_other_key = MagicMock()
mock_other_key.name = 'Other Key'
mock_other_key.key = 'other-test-key'
# Mock the chained query calls for filtering by user_id and org_id
mock_query = mock_session.query.return_value
mock_filter_user = mock_query.filter.return_value
mock_filter_org = mock_filter_user.filter.return_value
mock_filter_org.all.return_value = [mock_other_key, mock_mcp_key]
# Execute
result = api_key_store.retrieve_mcp_api_key(user_id)
# Verify
mock_get_user.assert_called_once_with(user_id)
assert result == 'mcp-test-key'
@patch('storage.api_key_store.UserStore.get_user_by_id')
def test_retrieve_mcp_api_key_not_found(
mock_get_user, api_key_store, mock_session, mock_user
):
"""Test retrieving MCP API key when none exists."""
# Setup
user_id = 'test-user-123'
mock_get_user.return_value = mock_user
mock_other_key = MagicMock()
mock_other_key.name = 'Other Key'
mock_other_key.key = 'other-test-key'
# Mock the chained query calls for filtering by user_id and org_id
mock_query = mock_session.query.return_value
mock_filter_user = mock_query.filter.return_value
mock_filter_org = mock_filter_user.filter.return_value
mock_filter_org.all.return_value = [mock_other_key]
# Execute
result = api_key_store.retrieve_mcp_api_key(user_id)
# Verify
mock_get_user.assert_called_once_with(user_id)
assert result is None

View File

@@ -127,6 +127,7 @@ async def test_keycloak_callback_user_not_allowed(mock_request):
with (
patch('server.routes.auth.token_manager') as mock_token_manager,
patch('server.routes.auth.user_verifier') as mock_verifier,
patch('server.routes.auth.UserStore') as mock_user_store,
):
mock_token_manager.get_keycloak_tokens = AsyncMock(
return_value=('test_access_token', 'test_refresh_token')
@@ -140,6 +141,15 @@ async def test_keycloak_callback_user_not_allowed(mock_request):
)
mock_token_manager.store_idp_tokens = AsyncMock()
# Mock the user creation
mock_user = MagicMock()
mock_user.id = 'test_user_id'
mock_user.current_org_id = 'test_org_id'
mock_user.accepted_tos = None
mock_user_store.get_user_by_id.return_value = mock_user
mock_user_store.create_user = AsyncMock(return_value=mock_user)
mock_user_store.migrate_user = AsyncMock(return_value=mock_user)
mock_verifier.is_active.return_value = True
mock_verifier.is_user_allowed.return_value = False
@@ -161,20 +171,19 @@ async def test_keycloak_callback_success_with_valid_offline_token(mock_request):
patch('server.routes.auth.token_manager') as mock_token_manager,
patch('server.routes.auth.user_verifier') as mock_verifier,
patch('server.routes.auth.set_response_cookie') as mock_set_cookie,
patch('server.routes.auth.session_maker') as mock_session_maker,
patch('server.routes.auth.UserStore') as mock_user_store,
patch('server.routes.auth.posthog') as mock_posthog,
):
# Mock the session and query results
mock_session = MagicMock()
mock_session_maker.return_value.__enter__.return_value = mock_session
mock_query = MagicMock()
mock_session.query.return_value = mock_query
mock_query.filter.return_value = mock_query
# Mock user with accepted_tos
mock_user = MagicMock()
mock_user.id = 'test_user_id'
mock_user.current_org_id = 'test_org_id'
mock_user.accepted_tos = '2025-01-01'
# Mock user settings with accepted_tos
mock_user_settings = MagicMock()
mock_user_settings.accepted_tos = '2025-01-01'
mock_query.first.return_value = mock_user_settings
# Setup UserStore mocks
mock_user_store.get_user_by_id.return_value = mock_user
mock_user_store.create_user = AsyncMock(return_value=mock_user)
mock_user_store.migrate_user = AsyncMock(return_value=mock_user)
mock_token_manager.get_keycloak_tokens = AsyncMock(
return_value=('test_access_token', 'test_refresh_token')
@@ -226,20 +235,20 @@ async def test_keycloak_callback_success_without_offline_token(mock_request):
),
patch('server.routes.auth.KEYCLOAK_REALM_NAME', 'test-realm'),
patch('server.routes.auth.KEYCLOAK_CLIENT_ID', 'test-client'),
patch('server.routes.auth.session_maker') as mock_session_maker,
patch('server.routes.auth.UserStore') as mock_user_store,
patch('server.routes.auth.posthog') as mock_posthog,
):
# Mock the session and query results
mock_session = MagicMock()
mock_session_maker.return_value.__enter__.return_value = mock_session
mock_query = MagicMock()
mock_session.query.return_value = mock_query
mock_query.filter.return_value = mock_query
# Mock user with accepted_tos
mock_user = MagicMock()
mock_user.id = 'test_user_id'
mock_user.current_org_id = 'test_org_id'
mock_user.accepted_tos = '2025-01-01'
# Setup UserStore mocks
mock_user_store.get_user_by_id.return_value = mock_user
mock_user_store.create_user = AsyncMock(return_value=mock_user)
mock_user_store.migrate_user = AsyncMock(return_value=mock_user)
# Mock user settings with accepted_tos
mock_user_settings = MagicMock()
mock_user_settings.accepted_tos = '2025-01-01'
mock_query.first.return_value = mock_user_settings
mock_token_manager.get_keycloak_tokens = AsyncMock(
return_value=('test_access_token', 'test_refresh_token')
)

View File

@@ -1,26 +1,26 @@
import uuid
from decimal import Decimal
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import stripe
from fastapi import HTTPException, Request, status
from httpx import HTTPStatusError, Response
from integrations.stripe_service import has_payment_method
from httpx import Response
from server.routes import billing
from server.routes.billing import (
CreateBillingSessionResponse,
CreateCheckoutSessionRequest,
GetCreditsResponse,
cancel_callback,
cancel_subscription,
create_checkout_session,
create_subscription_checkout_session,
create_customer_setup_session,
get_credits,
has_payment_method,
success_callback,
)
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from starlette.datastructures import URL
from storage.billing_session_type import BillingSessionType
from storage.stripe_customer import Base as StripeCustomerBase
@@ -78,29 +78,31 @@ def mock_subscription_request():
@pytest.mark.asyncio
async def test_get_credits_lite_llm_error():
mock_request = Request(scope={'type': 'http', 'state': {'user_id': 'mock_user'}})
mock_response = Response(
status_code=500, json={'error': 'Internal Server Error'}, request=MagicMock()
)
mock_client = AsyncMock()
mock_client.__aenter__.return_value.get.return_value = mock_response
with patch('integrations.stripe_service.STRIPE_API_KEY', 'mock_key'):
with patch('httpx.AsyncClient', return_value=mock_client):
with pytest.raises(HTTPStatusError) as exc_info:
await get_credits(mock_request)
assert (
exc_info.value.response.status_code
== status.HTTP_500_INTERNAL_SERVER_ERROR
)
with (
patch('integrations.stripe_service.STRIPE_API_KEY', 'mock_key'),
patch(
'storage.user_store.UserStore.get_user_by_id',
return_value=MagicMock(current_org_id='mock_org_id'),
),
patch(
'storage.lite_llm_manager.LiteLlmManager.get_user_team_info',
side_effect=Exception('LiteLLM API Error'),
),
):
with pytest.raises(Exception, match='LiteLLM API Error'):
await get_credits('mock_user')
@pytest.mark.asyncio
async def test_get_credits_success():
mock_response = Response(
status_code=200,
json={'user_info': {'max_budget': 100.00, 'spend': 25.50}},
json={
'user_info': {
'spend': 25.50,
'litellm_budget_table': {'max_budget': 100.00},
}
},
request=MagicMock(),
)
mock_client = AsyncMock()
@@ -109,24 +111,22 @@ async def test_get_credits_success():
with (
patch('integrations.stripe_service.STRIPE_API_KEY', 'mock_key'),
patch('httpx.AsyncClient', return_value=mock_client),
patch(
'storage.user_store.UserStore.get_user_by_id',
return_value=MagicMock(current_org_id='mock_org_id'),
),
patch(
'storage.lite_llm_manager.LiteLlmManager.get_user_team_info',
return_value={
'spend': 25.50,
'litellm_budget_table': {'max_budget': 100.00},
},
),
):
with patch('server.routes.billing.session_maker') as mock_session_maker:
mock_db_session = MagicMock()
mock_db_session.query.return_value.filter.return_value.first.return_value = MagicMock(
billing_margin=4
)
mock_session_maker.return_value.__enter__.return_value = mock_db_session
result = await get_credits('mock_user')
result = await get_credits('mock_user')
assert isinstance(result, GetCreditsResponse)
assert result.credits == Decimal(
'74.50'
) # 100.00 - 25.50 = 74.50 (no billing margin applied)
mock_client.__aenter__.return_value.get.assert_called_once_with(
'https://llm-proxy.app.all-hands.dev/user/info?user_id=mock_user',
headers={'x-goog-api-key': None},
)
assert isinstance(result, GetCreditsResponse)
assert result.credits == Decimal('74.50') # 100.00 - 25.50 = 74.50
@pytest.mark.asyncio
@@ -139,6 +139,9 @@ async def test_create_checkout_session_stripe_error(
id='mock-customer', metadata={'user_id': 'mock-user'}
)
mock_customer_create = AsyncMock(return_value=mock_customer)
mock_org = MagicMock()
mock_org.id = uuid.uuid4()
mock_org.contact_email = 'testy@tester.com'
with (
pytest.raises(Exception, match='Stripe API Error'),
patch('stripe.Customer.create_async', mock_customer_create),
@@ -150,6 +153,10 @@ async def test_create_checkout_session_stripe_error(
AsyncMock(side_effect=Exception('Stripe API Error')),
),
patch('integrations.stripe_service.session_maker', session_maker),
patch(
'storage.org_store.OrgStore.get_current_org_from_keycloak_user_id',
return_value=mock_org,
),
patch(
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
AsyncMock(return_value={'email': 'testy@tester.com'}),
@@ -175,6 +182,10 @@ async def test_create_checkout_session_success(session_maker, mock_checkout_requ
id='mock-customer', metadata={'user_id': 'mock-user'}
)
mock_customer_create = AsyncMock(return_value=mock_customer)
mock_org = MagicMock()
mock_org_id = uuid.uuid4()
mock_org.id = mock_org_id
mock_org.contact_email = 'testy@tester.com'
with (
patch('stripe.Customer.create_async', mock_customer_create),
patch(
@@ -183,6 +194,10 @@ async def test_create_checkout_session_success(session_maker, mock_checkout_requ
patch('stripe.checkout.Session.create_async', mock_create),
patch('server.routes.billing.session_maker') as mock_session_maker,
patch('integrations.stripe_service.session_maker', session_maker),
patch(
'storage.org_store.OrgStore.get_current_org_from_keycloak_user_id',
return_value=mock_org,
),
patch(
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
AsyncMock(return_value={'email': 'testy@tester.com'}),
@@ -254,7 +269,6 @@ async def test_success_callback_stripe_incomplete():
mock_billing_session = MagicMock()
mock_billing_session.status = 'in_progress'
mock_billing_session.user_id = 'mock_user'
mock_billing_session.billing_session_type = BillingSessionType.DIRECT_PAYMENT.value
with (
patch('server.routes.billing.session_maker') as mock_session_maker,
@@ -282,44 +296,33 @@ async def test_success_callback_success():
mock_billing_session = MagicMock()
mock_billing_session.status = 'in_progress'
mock_billing_session.user_id = 'mock_user'
mock_billing_session.billing_session_type = BillingSessionType.DIRECT_PAYMENT.value
mock_lite_llm_response = Response(
status_code=200,
json={'user_info': {'max_budget': 100.00, 'spend': 25.50}},
request=MagicMock(),
)
mock_lite_llm_update_response = Response(
status_code=200, json={}, request=MagicMock()
)
with (
patch('server.routes.billing.session_maker') as mock_session_maker,
patch('stripe.checkout.Session.retrieve') as mock_stripe_retrieve,
patch('httpx.AsyncClient') as mock_client,
patch(
'storage.user_store.UserStore.get_user_by_id',
return_value=MagicMock(current_org_id='mock_org_id'),
),
patch(
'storage.lite_llm_manager.LiteLlmManager.get_user_team_info',
return_value={
'spend': 25.50,
'litellm_budget_table': {'max_budget': 100.00},
},
),
patch(
'storage.lite_llm_manager.LiteLlmManager.update_team_and_users_budget'
) as mock_update_budget,
):
mock_db_session = MagicMock()
mock_db_session.query.return_value.filter.return_value.filter.return_value.first.return_value = mock_billing_session
mock_user_settings = MagicMock(billing_margin=None)
mock_db_session.query.return_value.filter.return_value.first.return_value = (
mock_user_settings
)
mock_session_maker.return_value.__enter__.return_value = mock_db_session
mock_stripe_retrieve.return_value = MagicMock(
status='complete',
amount_subtotal=2500,
status='complete', amount_subtotal=2500, customer='mock_customer_id'
) # $25.00 in cents
mock_client_instance = AsyncMock()
mock_client_instance.__aenter__.return_value.get.return_value = (
mock_lite_llm_response
)
mock_client_instance.__aenter__.return_value.post.return_value = (
mock_lite_llm_update_response
)
mock_client.return_value = mock_client_instance
response = await success_callback('test_session_id', mock_request)
assert response.status_code == 302
@@ -329,18 +332,14 @@ async def test_success_callback_success():
)
# Verify LiteLLM API calls
mock_client_instance.__aenter__.return_value.get.assert_called_once()
mock_client_instance.__aenter__.return_value.post.assert_called_once_with(
'https://llm-proxy.app.all-hands.dev/user/update',
headers={'x-goog-api-key': None},
json={
'user_id': 'mock_user',
'max_budget': 125,
}, # 100 + (25.00 from Stripe)
mock_update_budget.assert_called_once_with(
'mock_org_id',
125.0, # 100 + (25.00 from Stripe)
)
# Verify database updates
assert mock_billing_session.status == 'completed'
assert mock_billing_session.price == 25.0
mock_db_session.merge.assert_called_once()
mock_db_session.commit.assert_called_once()
@@ -354,27 +353,27 @@ async def test_success_callback_lite_llm_error():
mock_billing_session = MagicMock()
mock_billing_session.status = 'in_progress'
mock_billing_session.user_id = 'mock_user'
mock_billing_session.billing_session_type = BillingSessionType.DIRECT_PAYMENT.value
with (
patch('server.routes.billing.session_maker') as mock_session_maker,
patch('stripe.checkout.Session.retrieve') as mock_stripe_retrieve,
patch('httpx.AsyncClient') as mock_client,
patch(
'storage.user_store.UserStore.get_user_by_id',
return_value=MagicMock(current_org_id='mock_org_id'),
),
patch(
'storage.lite_llm_manager.LiteLlmManager.get_user_team_info',
side_effect=Exception('LiteLLM API Error'),
),
):
mock_db_session = MagicMock()
mock_db_session.query.return_value.filter.return_value.filter.return_value.first.return_value = mock_billing_session
mock_session_maker.return_value.__enter__.return_value = mock_db_session
mock_stripe_retrieve.return_value = MagicMock(
status='complete', amount_total=2500
status='complete', amount_subtotal=2500
)
mock_client_instance = AsyncMock()
mock_client_instance.__aenter__.return_value.get.side_effect = Exception(
'LiteLLM API Error'
)
mock_client.return_value = mock_client_instance
with pytest.raises(Exception, match='LiteLLM API Error'):
await success_callback('test_session_id', mock_request)
@@ -398,7 +397,8 @@ async def test_cancel_callback_session_not_found():
response = await cancel_callback('test_session_id', mock_request)
assert response.status_code == 302
assert (
response.headers['location'] == 'http://test.com/settings?checkout=cancel'
response.headers['location']
== 'http://test.com/settings/billing?checkout=cancel'
)
# Verify no database updates occurred
@@ -424,7 +424,8 @@ async def test_cancel_callback_success():
assert response.status_code == 302
assert (
response.headers['location'] == 'http://test.com/settings?checkout=cancel'
response.headers['location']
== 'http://test.com/settings/billing?checkout=cancel'
)
# Verify database updates
@@ -436,314 +437,67 @@ async def test_cancel_callback_success():
@pytest.mark.asyncio
async def test_has_payment_method_with_payment_method():
"""Test has_payment_method returns True when user has a payment method."""
with (
patch('integrations.stripe_service.session_maker') as mock_session_maker,
patch(
'stripe.Customer.list_payment_methods_async',
AsyncMock(return_value=MagicMock(data=[MagicMock()])),
) as mock_list_payment_methods,
):
# Setup mock session
mock_session = MagicMock()
mock_session_maker.return_value.__enter__.return_value = mock_session
mock_session.query.return_value.filter.return_value.first.return_value = (
MagicMock(stripe_customer_id='cus_test123')
)
mock_has_payment_method = AsyncMock(return_value=True)
with patch(
'server.routes.billing.stripe_service.has_payment_method_by_user_id',
mock_has_payment_method,
):
result = await has_payment_method('mock_user')
assert result is True
mock_list_payment_methods.assert_called_once_with('cus_test123')
mock_has_payment_method.assert_called_once_with('mock_user')
@pytest.mark.asyncio
async def test_has_payment_method_without_payment_method():
"""Test has_payment_method returns False when user has no payment method."""
with (
patch('integrations.stripe_service.session_maker') as mock_session_maker,
patch(
'stripe.Customer.list_payment_methods_async',
AsyncMock(return_value=MagicMock(data=[])),
) as mock_list_payment_methods,
mock_has_payment_method = AsyncMock(return_value=False)
with patch(
'server.routes.billing.stripe_service.has_payment_method_by_user_id',
mock_has_payment_method,
):
# Setup mock session
mock_session = MagicMock()
mock_session_maker.return_value.__enter__.return_value = mock_session
mock_session.query.return_value.filter.return_value.first.return_value = (
MagicMock(stripe_customer_id='cus_test123')
)
mock_has_payment_method.return_value = False
result = await has_payment_method('mock_user')
assert result is False
mock_list_payment_methods.assert_called_once_with('cus_test123')
mock_has_payment_method.assert_called_once_with('mock_user')
@pytest.mark.asyncio
async def test_cancel_subscription_success():
"""Test successful subscription cancellation."""
from datetime import UTC, datetime
from storage.subscription_access import SubscriptionAccess
# Mock active subscription
mock_subscription_access = SubscriptionAccess(
id=1,
status='ACTIVE',
user_id='test_user',
start_at=datetime.now(UTC),
end_at=datetime.now(UTC),
amount_paid=2000,
stripe_invoice_payment_id='pi_test',
stripe_subscription_id='sub_test123',
cancelled_at=None,
async def test_create_customer_setup_session_success():
"""Test successful creation of customer setup session."""
mock_request = Request(
scope={
'type': 'http',
'path': '/api/billing/create-customer-setup-session',
'server': ('test.com', 80),
'headers': [],
}
)
mock_request._base_url = URL('http://test.com/')
# Mock Stripe subscription response
mock_stripe_subscription = MagicMock()
mock_stripe_subscription.cancel_at_period_end = True
mock_customer_info = {'customer_id': 'mock-customer-id', 'org_id': 'mock-org-id'}
mock_session = MagicMock()
mock_session.url = 'https://checkout.stripe.com/test-session'
mock_create = AsyncMock(return_value=mock_session)
with (
patch('server.routes.billing.session_maker') as mock_session_maker,
patch(
'stripe.Subscription.modify_async',
AsyncMock(return_value=mock_stripe_subscription),
) as mock_stripe_modify,
):
# Setup mock session
mock_session = MagicMock()
mock_session_maker.return_value.__enter__.return_value = mock_session
mock_session.query.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = mock_subscription_access
# Call the function
result = await cancel_subscription('test_user')
# Verify Stripe API was called
mock_stripe_modify.assert_called_once_with(
'sub_test123', cancel_at_period_end=True
)
# Verify database was updated
assert mock_subscription_access.cancelled_at is not None
mock_session.merge.assert_called_once_with(mock_subscription_access)
mock_session.commit.assert_called_once()
# Verify response
assert result.status_code == 200
@pytest.mark.asyncio
async def test_cancel_subscription_no_active_subscription():
"""Test cancellation when no active subscription exists."""
with (
patch('server.routes.billing.session_maker') as mock_session_maker,
):
# Setup mock session with no subscription found
mock_session = MagicMock()
mock_session_maker.return_value.__enter__.return_value = mock_session
mock_session.query.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = None
# Call the function and expect HTTPException
with pytest.raises(HTTPException) as exc_info:
await cancel_subscription('test_user')
assert exc_info.value.status_code == 404
assert 'No active subscription found' in str(exc_info.value.detail)
@pytest.mark.asyncio
async def test_cancel_subscription_missing_stripe_id():
"""Test cancellation when subscription has no Stripe ID."""
from datetime import UTC, datetime
from storage.subscription_access import SubscriptionAccess
# Mock subscription without Stripe ID
mock_subscription_access = SubscriptionAccess(
id=1,
status='ACTIVE',
user_id='test_user',
start_at=datetime.now(UTC),
end_at=datetime.now(UTC),
amount_paid=2000,
stripe_invoice_payment_id='pi_test',
stripe_subscription_id=None, # Missing Stripe ID
cancelled_at=None,
)
with (
patch('server.routes.billing.session_maker') as mock_session_maker,
):
# Setup mock session
mock_session = MagicMock()
mock_session_maker.return_value.__enter__.return_value = mock_session
mock_session.query.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = mock_subscription_access
# Call the function and expect HTTPException
with pytest.raises(HTTPException) as exc_info:
await cancel_subscription('test_user')
assert exc_info.value.status_code == 400
assert 'missing Stripe subscription ID' in str(exc_info.value.detail)
@pytest.mark.asyncio
async def test_cancel_subscription_stripe_error():
"""Test cancellation when Stripe API fails."""
from datetime import UTC, datetime
from storage.subscription_access import SubscriptionAccess
# Mock active subscription
mock_subscription_access = SubscriptionAccess(
id=1,
status='ACTIVE',
user_id='test_user',
start_at=datetime.now(UTC),
end_at=datetime.now(UTC),
amount_paid=2000,
stripe_invoice_payment_id='pi_test',
stripe_subscription_id='sub_test123',
cancelled_at=None,
)
with (
patch('server.routes.billing.session_maker') as mock_session_maker,
patch(
'stripe.Subscription.modify_async',
AsyncMock(side_effect=stripe.StripeError('API Error')),
'integrations.stripe_service.find_or_create_customer_by_user_id',
AsyncMock(return_value=mock_customer_info),
),
):
# Setup mock session
mock_session = MagicMock()
mock_session_maker.return_value.__enter__.return_value = mock_session
mock_session.query.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = mock_subscription_access
# Call the function and expect HTTPException
with pytest.raises(HTTPException) as exc_info:
await cancel_subscription('test_user')
assert exc_info.value.status_code == 500
assert 'Failed to cancel subscription' in str(exc_info.value.detail)
@pytest.mark.asyncio
async def test_create_subscription_checkout_session_duplicate_prevention(
mock_subscription_request,
):
"""Test that creating a subscription when user already has active subscription raises error."""
from datetime import UTC, datetime
from storage.subscription_access import SubscriptionAccess
# Mock active subscription
mock_subscription_access = SubscriptionAccess(
id=1,
status='ACTIVE',
user_id='test_user',
start_at=datetime.now(UTC),
end_at=datetime.now(UTC),
amount_paid=2000,
stripe_invoice_payment_id='pi_test',
stripe_subscription_id='sub_test123',
cancelled_at=None,
)
with (
patch('server.routes.billing.session_maker') as mock_session_maker,
patch('stripe.checkout.Session.create_async', mock_create),
patch('server.routes.billing.validate_saas_environment'),
):
# Setup mock session to return existing active subscription
mock_session = MagicMock()
mock_session_maker.return_value.__enter__.return_value = mock_session
mock_session.query.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = mock_subscription_access
result = await create_customer_setup_session(mock_request, 'mock_user')
# Call the function and expect HTTPException
with pytest.raises(HTTPException) as exc_info:
await create_subscription_checkout_session(
mock_subscription_request, user_id='test_user'
)
assert exc_info.value.status_code == 400
assert (
'user already has an active subscription'
in str(exc_info.value.detail).lower()
)
@pytest.mark.asyncio
async def test_create_subscription_checkout_session_allows_after_cancellation(
mock_subscription_request,
):
"""Test that creating a subscription is allowed when previous subscription was cancelled."""
mock_session_obj = MagicMock()
mock_session_obj.url = 'https://checkout.stripe.com/test-session'
mock_session_obj.id = 'test_session_id'
with (
patch('server.routes.billing.session_maker') as mock_session_maker,
patch(
'integrations.stripe_service.find_or_create_customer',
AsyncMock(return_value='cus_test123'),
),
patch(
'stripe.checkout.Session.create_async',
AsyncMock(return_value=mock_session_obj),
),
patch(
'server.routes.billing.SUBSCRIPTION_PRICE_DATA',
{'MONTHLY_SUBSCRIPTION': {'unit_amount': 2000}},
),
patch('server.routes.billing.validate_saas_environment'),
):
# Setup mock session - the query should return None because cancelled subscriptions are filtered out
mock_session = MagicMock()
mock_session_maker.return_value.__enter__.return_value = mock_session
mock_session.query.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = None
# Should succeed
result = await create_subscription_checkout_session(
mock_subscription_request, user_id='test_user'
)
assert isinstance(result, CreateBillingSessionResponse)
assert isinstance(result, billing.CreateBillingSessionResponse)
assert result.redirect_url == 'https://checkout.stripe.com/test-session'
@pytest.mark.asyncio
async def test_create_subscription_checkout_session_success_no_existing(
mock_subscription_request,
):
"""Test successful subscription creation when no existing subscription."""
mock_session_obj = MagicMock()
mock_session_obj.url = 'https://checkout.stripe.com/test-session'
mock_session_obj.id = 'test_session_id'
with (
patch('server.routes.billing.session_maker') as mock_session_maker,
patch(
'integrations.stripe_service.find_or_create_customer',
AsyncMock(return_value='cus_test123'),
),
patch(
'stripe.checkout.Session.create_async',
AsyncMock(return_value=mock_session_obj),
),
patch(
'server.routes.billing.SUBSCRIPTION_PRICE_DATA',
{'MONTHLY_SUBSCRIPTION': {'unit_amount': 2000}},
),
patch('server.routes.billing.validate_saas_environment'),
):
# Setup mock session to return no existing subscription
mock_session = MagicMock()
mock_session_maker.return_value.__enter__.return_value = mock_session
mock_session.query.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.filter.return_value.first.return_value = None
# Should succeed
result = await create_subscription_checkout_session(
mock_subscription_request, user_id='test_user'
# Verify Stripe session creation parameters
mock_create.assert_called_once_with(
customer='mock-customer-id',
mode='setup',
payment_method_types=['card'],
success_url='http://test.com/?free_credits=success',
cancel_url='http://test.com/',
)
assert isinstance(result, CreateBillingSessionResponse)
assert result.redirect_url == 'https://checkout.stripe.com/test-session'

View File

@@ -3,14 +3,29 @@ Tests for ConversationCallbackProcessor and ConversationCallback models.
"""
import json
from unittest.mock import patch
from uuid import UUID
import pytest
from storage.conversation_callback import (
CallbackStatus,
ConversationCallback,
ConversationCallbackProcessor,
# Import the actual StoredConversationMetadata from OpenHands core
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
StoredConversationMetadata,
)
# Mock the lazy import to return the actual class
with patch(
'storage.stored_conversation_metadata.StoredConversationMetadata',
StoredConversationMetadata,
):
from storage.conversation_callback import (
CallbackStatus,
ConversationCallback,
ConversationCallbackProcessor,
)
from storage.stored_conversation_metadata_saas import (
StoredConversationMetadataSaas,
)
from storage.stored_conversation_metadata import StoredConversationMetadata
from openhands.events.observation.agent import AgentStateChangedObservation
@@ -80,15 +95,22 @@ class TestConversationCallback:
"""Create a test conversation metadata record."""
with session_maker() as session:
metadata = StoredConversationMetadata(
conversation_id='test_conversation_123', user_id='test_user_456'
conversation_id='test_conversation_123'
)
metadata_saas = StoredConversationMetadataSaas(
conversation_id='test_conversation_123',
user_id=UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'),
org_id=UUID('5594c7b6-f959-4b81-92e9-b09c206f5081'),
)
session.add(metadata)
session.add(metadata_saas)
session.commit()
session.refresh(metadata)
yield metadata
# Cleanup
session.delete(metadata)
session.delete(metadata_saas)
session.commit()
def test_callback_creation(self, conversation_metadata, session_maker):

View File

@@ -0,0 +1,650 @@
"""
Unit tests for LiteLlmManager class.
"""
import os
from unittest.mock import AsyncMock, MagicMock, patch
import httpx
import pytest
from pydantic import SecretStr
from server.constants import (
get_default_litellm_model,
)
from storage.lite_llm_manager import LiteLlmManager
from storage.user_settings import UserSettings
from openhands.server.settings import Settings
class TestLiteLlmManager:
"""Test cases for LiteLlmManager class."""
@pytest.fixture
def mock_settings(self):
"""Create a mock Settings object."""
settings = Settings()
settings.agent = 'TestAgent'
settings.llm_model = 'test-model'
settings.llm_api_key = SecretStr('test-key')
settings.llm_base_url = 'http://test.com'
return settings
@pytest.fixture
def mock_user_settings(self):
"""Create a mock UserSettings object."""
user_settings = UserSettings()
user_settings.agent = 'TestAgent'
user_settings.llm_model = 'test-model'
user_settings.llm_api_key = SecretStr('test-key')
user_settings.llm_base_url = 'http://test.com'
return user_settings
@pytest.fixture
def mock_http_client(self):
"""Create a mock HTTP client."""
client = AsyncMock(spec=httpx.AsyncClient)
return client
@pytest.fixture
def mock_response(self):
"""Create a mock HTTP response."""
response = MagicMock()
response.is_success = True
response.status_code = 200
response.text = 'Success'
response.json.return_value = {'key': 'test-api-key'}
response.raise_for_status = MagicMock()
return response
@pytest.fixture
def mock_team_response(self):
"""Create a mock team response."""
response = MagicMock()
response.is_success = True
response.status_code = 200
response.json.return_value = {
'team_memberships': [
{
'user_id': 'test-user-id',
'team_id': 'test-org-id',
'max_budget': 100.0,
}
]
}
response.raise_for_status = MagicMock()
return response
@pytest.fixture
def mock_user_response(self):
"""Create a mock user response."""
response = MagicMock()
response.is_success = True
response.status_code = 200
response.json.return_value = {
'user_info': {
'max_budget': 50.0,
'spend': 10.0,
}
}
response.raise_for_status = MagicMock()
return response
@pytest.fixture
def mock_key_info_response(self):
"""Create a mock key info response."""
response = MagicMock()
response.is_success = True
response.status_code = 200
response.json.return_value = {
'info': {
'max_budget': 100.0,
'spend': 25.0,
}
}
response.raise_for_status = MagicMock()
return response
@pytest.mark.asyncio
async def test_create_entries_missing_config(self, mock_settings):
"""Test create_entries when LiteLLM config is missing."""
with patch.dict(os.environ, {'LITE_LLM_API_KEY': '', 'LITE_LLM_API_URL': ''}):
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', None):
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', None):
result = await LiteLlmManager.create_entries(
'test-org-id', 'test-user-id', mock_settings
)
assert result is None
@pytest.mark.asyncio
async def test_create_entries_local_deployment(self, mock_settings):
"""Test create_entries in local deployment mode."""
with patch.dict(os.environ, {'LOCAL_DEPLOYMENT': '1'}):
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
with patch(
'storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'
):
result = await LiteLlmManager.create_entries(
'test-org-id', 'test-user-id', mock_settings
)
assert result is not None
assert result.agent == 'CodeActAgent'
assert result.llm_model == get_default_litellm_model()
assert result.llm_api_key.get_secret_value() == 'test-key'
assert result.llm_base_url == 'http://test.com'
@pytest.mark.asyncio
async def test_create_entries_cloud_deployment(self, mock_settings, mock_response):
"""Test create_entries in cloud deployment mode."""
with patch.dict(os.environ, {'LOCAL_DEPLOYMENT': ''}):
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
with patch(
'storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'
):
with patch(
'storage.lite_llm_manager.TokenManager'
) as mock_token_manager:
mock_token_manager.return_value.get_user_info_from_user_id = (
AsyncMock(return_value={'email': 'test@example.com'})
)
with patch('httpx.AsyncClient') as mock_client_class:
mock_client = AsyncMock()
mock_client_class.return_value.__aenter__.return_value = (
mock_client
)
mock_client.post.return_value = mock_response
result = await LiteLlmManager.create_entries(
'test-org-id', 'test-user-id', mock_settings
)
assert result is not None
assert result.agent == 'CodeActAgent'
assert result.llm_model == get_default_litellm_model()
assert (
result.llm_api_key.get_secret_value() == 'test-api-key'
)
assert result.llm_base_url == 'http://test.com'
# Verify API calls were made
assert (
mock_client.post.call_count == 4
) # create_team, create_user, add_user_to_team, generate_key
@pytest.mark.asyncio
async def test_migrate_entries_missing_config(self, mock_user_settings):
"""Test migrate_entries when LiteLLM config is missing."""
with patch.dict(os.environ, {'LITE_LLM_API_KEY': '', 'LITE_LLM_API_URL': ''}):
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', None):
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', None):
result = await LiteLlmManager.migrate_entries(
'test-org-id',
'test-user-id',
mock_user_settings,
)
assert result is None
@pytest.mark.asyncio
async def test_migrate_entries_local_deployment(self, mock_user_settings):
"""Test migrate_entries in local deployment mode."""
with patch.dict(os.environ, {'LOCAL_DEPLOYMENT': '1'}):
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
with patch(
'storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'
):
result = await LiteLlmManager.migrate_entries(
'test-org-id',
'test-user-id',
mock_user_settings,
)
assert result is not None
assert result.agent == 'CodeActAgent'
assert result.llm_model == get_default_litellm_model()
assert result.llm_api_key.get_secret_value() == 'test-key'
assert result.llm_base_url == 'http://test.com'
@pytest.mark.asyncio
async def test_migrate_entries_no_user_found(self, mock_user_settings):
"""Test migrate_entries when user is not found."""
with patch.dict(os.environ, {'LOCAL_DEPLOYMENT': ''}):
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
with patch(
'storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'
):
with patch(
'storage.lite_llm_manager.TokenManager'
) as mock_token_manager:
mock_token_manager.return_value.get_user_info_from_user_id = (
AsyncMock(return_value={'email': 'test@example.com'})
)
# Mock the _get_user method directly to return None
with patch.object(
LiteLlmManager, '_get_user', new_callable=AsyncMock
) as mock_get_user:
mock_get_user.return_value = None
result = await LiteLlmManager.migrate_entries(
'test-org-id',
'test-user-id',
mock_user_settings,
)
assert result is None
@pytest.mark.asyncio
async def test_migrate_entries_already_migrated(
self, mock_user_settings, mock_user_response
):
"""Test migrate_entries when user is already migrated (no max_budget)."""
mock_user_response.json.return_value = {
'user_info': {
'max_budget': None, # Already migrated
'spend': 10.0,
}
}
with patch.dict(os.environ, {'LOCAL_DEPLOYMENT': ''}):
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
with patch(
'storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'
):
with patch(
'storage.lite_llm_manager.TokenManager'
) as mock_token_manager:
mock_token_manager.return_value.get_user_info_from_user_id = (
AsyncMock(return_value={'email': 'test@example.com'})
)
with patch('httpx.AsyncClient') as mock_client_class:
mock_client = AsyncMock()
mock_client_class.return_value.__aenter__.return_value = (
mock_client
)
mock_client.get.return_value = mock_user_response
result = await LiteLlmManager.migrate_entries(
'test-org-id',
'test-user-id',
mock_user_settings,
)
assert result is None
@pytest.mark.asyncio
async def test_migrate_entries_successful_migration(
self, mock_user_settings, mock_user_response, mock_response
):
"""Test successful migrate_entries operation."""
with patch.dict(os.environ, {'LOCAL_DEPLOYMENT': ''}):
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
with patch(
'storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'
):
with patch(
'storage.lite_llm_manager.TokenManager'
) as mock_token_manager:
mock_token_manager.return_value.get_user_info_from_user_id = (
AsyncMock(return_value={'email': 'test@example.com'})
)
with patch('httpx.AsyncClient') as mock_client_class:
mock_client = AsyncMock()
mock_client_class.return_value.__aenter__.return_value = (
mock_client
)
mock_client.get.return_value = mock_user_response
mock_client.post.return_value = mock_response
result = await LiteLlmManager.migrate_entries(
'test-org-id',
'test-user-id',
mock_user_settings,
)
assert result is not None
assert result.agent == 'CodeActAgent'
assert result.llm_model == get_default_litellm_model()
assert result.llm_api_key.get_secret_value() == 'test-key'
assert result.llm_base_url == 'http://test.com'
# Verify migration steps were called
assert (
mock_client.post.call_count == 4
) # create_team, update_user, add_user_to_team, update_key
@pytest.mark.asyncio
async def test_update_team_and_users_budget_missing_config(self):
"""Test update_team_and_users_budget when LiteLLM config is missing."""
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', None):
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', None):
# Should not raise an exception, just return early
await LiteLlmManager.update_team_and_users_budget('test-team-id', 100.0)
@pytest.mark.asyncio
async def test_update_team_and_users_budget_successful(
self, mock_team_response, mock_response
):
"""Test successful update_team_and_users_budget operation."""
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'):
with patch('httpx.AsyncClient') as mock_client_class:
mock_client = AsyncMock()
mock_client_class.return_value.__aenter__.return_value = mock_client
mock_client.post.return_value = mock_response
mock_client.get.return_value = mock_team_response
await LiteLlmManager.update_team_and_users_budget(
'test-team-id', 100.0
)
# Verify update_team and update_user_in_team were called
assert (
mock_client.post.call_count == 2
) # update_team, update_user_in_team
@pytest.mark.asyncio
async def test_create_team_success(self, mock_http_client, mock_response):
"""Test successful _create_team operation."""
mock_http_client.post.return_value = mock_response
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'):
await LiteLlmManager._create_team(
mock_http_client, 'test-alias', 'test-team-id', 100.0
)
mock_http_client.post.assert_called_once()
call_args = mock_http_client.post.call_args
assert 'http://test.com/team/new' in call_args[0]
assert call_args[1]['json']['team_id'] == 'test-team-id'
assert call_args[1]['json']['team_alias'] == 'test-alias'
assert call_args[1]['json']['max_budget'] == 100.0
@pytest.mark.asyncio
async def test_create_team_already_exists(self, mock_http_client):
"""Test _create_team when team already exists."""
error_response = MagicMock()
error_response.is_success = False
error_response.status_code = 400
error_response.text = 'Team already exists. Please use a different team id'
mock_http_client.post.return_value = error_response
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'):
with patch.object(
LiteLlmManager, '_update_team', new_callable=AsyncMock
) as mock_update:
await LiteLlmManager._create_team(
mock_http_client, 'test-alias', 'test-team-id', 100.0
)
mock_update.assert_called_once_with(
mock_http_client, 'test-team-id', 'test-alias', 100.0
)
@pytest.mark.asyncio
async def test_create_team_error(self, mock_http_client):
"""Test _create_team with unexpected error."""
error_response = MagicMock()
error_response.is_success = False
error_response.status_code = 500
error_response.text = 'Internal server error'
error_response.raise_for_status.side_effect = httpx.HTTPStatusError(
'Server error', request=MagicMock(), response=error_response
)
mock_http_client.post.return_value = error_response
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'):
with pytest.raises(httpx.HTTPStatusError):
await LiteLlmManager._create_team(
mock_http_client, 'test-alias', 'test-team-id', 100.0
)
@pytest.mark.asyncio
async def test_get_team_success(self, mock_http_client, mock_team_response):
"""Test successful _get_team operation."""
mock_http_client.get.return_value = mock_team_response
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'):
result = await LiteLlmManager._get_team(
mock_http_client, 'test-team-id'
)
assert result is not None
assert 'team_memberships' in result
mock_http_client.get.assert_called_once_with(
'http://test.com/team/info?team_id=test-team-id'
)
@pytest.mark.asyncio
async def test_create_user_success(self, mock_http_client, mock_response):
"""Test successful _create_user operation."""
mock_http_client.post.return_value = mock_response
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'):
await LiteLlmManager._create_user(
mock_http_client, 'test@example.com', 'test-user-id'
)
mock_http_client.post.assert_called_once()
call_args = mock_http_client.post.call_args
assert 'http://test.com/user/new' in call_args[0]
assert call_args[1]['json']['user_email'] == 'test@example.com'
assert call_args[1]['json']['user_id'] == 'test-user-id'
@pytest.mark.asyncio
async def test_create_user_duplicate_email(self, mock_http_client, mock_response):
"""Test _create_user with duplicate email handling."""
# First call fails with duplicate email
error_response = MagicMock()
error_response.is_success = False
error_response.status_code = 400
error_response.text = 'duplicate email'
# Second call succeeds
mock_http_client.post.side_effect = [error_response, mock_response]
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'):
await LiteLlmManager._create_user(
mock_http_client, 'test@example.com', 'test-user-id'
)
assert mock_http_client.post.call_count == 2
# Second call should have None email
second_call_args = mock_http_client.post.call_args_list[1]
assert second_call_args[1]['json']['user_email'] is None
@pytest.mark.asyncio
async def test_generate_key_success(self, mock_http_client, mock_response):
"""Test successful _generate_key operation."""
mock_http_client.post.return_value = mock_response
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'):
result = await LiteLlmManager._generate_key(
mock_http_client,
'test-user-id',
'test-team-id',
'test-alias',
{'test': 'metadata'},
)
assert result == 'test-api-key'
mock_http_client.post.assert_called_once()
call_args = mock_http_client.post.call_args
assert 'http://test.com/key/generate' in call_args[0]
assert call_args[1]['json']['user_id'] == 'test-user-id'
assert call_args[1]['json']['team_id'] == 'test-team-id'
assert call_args[1]['json']['key_alias'] == 'test-alias'
assert call_args[1]['json']['metadata'] == {'test': 'metadata'}
@pytest.mark.asyncio
async def test_get_key_info_success(self, mock_http_client, mock_key_info_response):
"""Test successful _get_key_info operation."""
mock_http_client.get.return_value = mock_key_info_response
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'):
with patch('storage.user_store.UserStore') as mock_user_store:
# Mock user with org member
mock_user = MagicMock()
mock_org_member = MagicMock()
mock_org_member.org_id = 'test-ord-id'
mock_org_member.llm_api_key = 'test-api-key'
mock_user.org_members = [mock_org_member]
mock_user_store.get_user_by_id.return_value = mock_user
result = await LiteLlmManager._get_key_info(
mock_http_client, 'test-ord-id', 'test-user-id'
)
assert result is not None
assert result['key_max_budget'] == 100.0
assert result['key_spend'] == 25.0
@pytest.mark.asyncio
async def test_get_key_info_no_user(self, mock_http_client):
"""Test _get_key_info when user is not found."""
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'):
with patch('storage.user_store.UserStore') as mock_user_store:
mock_user_store.get_user_by_id.return_value = None
result = await LiteLlmManager._get_key_info(
mock_http_client, 'test-ord-id', 'test-user-id'
)
assert result == {}
@pytest.mark.asyncio
async def test_delete_key_success(self, mock_http_client, mock_response):
"""Test successful _delete_key operation."""
mock_http_client.post.return_value = mock_response
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'):
await LiteLlmManager._delete_key(mock_http_client, 'test-key-id')
mock_http_client.post.assert_called_once()
call_args = mock_http_client.post.call_args
assert 'http://test.com/key/delete' in call_args[0]
assert call_args[1]['json']['keys'] == ['test-key-id']
@pytest.mark.asyncio
async def test_delete_key_not_found(self, mock_http_client):
"""Test _delete_key when key is not found (404 error)."""
error_response = MagicMock()
error_response.is_success = False
error_response.status_code = 404
error_response.text = 'Key not found'
mock_http_client.post.return_value = error_response
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'):
# Should not raise an exception for 404
await LiteLlmManager._delete_key(mock_http_client, 'test-key-id')
@pytest.mark.asyncio
async def test_with_http_client_decorator(self):
"""Test the with_http_client decorator functionality."""
# Create a mock internal function
async def mock_internal_fn(client, arg1, arg2, kwarg1=None):
return f'client={type(client).__name__}, arg1={arg1}, arg2={arg2}, kwarg1={kwarg1}'
# Apply the decorator
decorated_fn = LiteLlmManager.with_http_client(mock_internal_fn)
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'):
with patch('httpx.AsyncClient') as mock_client_class:
mock_client = AsyncMock()
mock_client_class.return_value.__aenter__.return_value = mock_client
result = await decorated_fn('test1', 'test2', kwarg1='test3')
# Verify the client was injected as the first argument
assert 'client=AsyncMock' in result
assert 'arg1=test1' in result
assert 'arg2=test2' in result
assert 'kwarg1=test3' in result
def test_public_methods_exist(self):
"""Test that all public wrapper methods exist and are properly decorated."""
public_methods = [
'create_team',
'get_team',
'update_team',
'create_user',
'get_user',
'update_user',
'delete_user',
'add_user_to_team',
'get_user_team_info',
'update_user_in_team',
'generate_key',
'get_key_info',
'delete_key',
]
for method_name in public_methods:
assert hasattr(LiteLlmManager, method_name)
method = getattr(LiteLlmManager, method_name)
assert callable(method)
# The methods are created by the with_http_client decorator, so they're functions
# We can verify they exist and are callable, which is the important part
@pytest.mark.asyncio
async def test_error_handling_missing_config_all_methods(self):
"""Test that all methods handle missing configuration gracefully."""
with patch('storage.lite_llm_manager.LITE_LLM_API_KEY', None):
with patch('storage.lite_llm_manager.LITE_LLM_API_URL', None):
mock_client = AsyncMock()
# Test all private methods that check for config
await LiteLlmManager._create_team(
mock_client, 'alias', 'team_id', 100.0
)
await LiteLlmManager._update_team(
mock_client, 'team_id', 'alias', 100.0
)
await LiteLlmManager._create_user(mock_client, 'email', 'user_id')
await LiteLlmManager._update_user(mock_client, 'user_id')
await LiteLlmManager._delete_user(mock_client, 'user_id')
await LiteLlmManager._add_user_to_team(
mock_client, 'user_id', 'team_id', 100.0
)
await LiteLlmManager._update_user_in_team(
mock_client, 'user_id', 'team_id', 100.0
)
await LiteLlmManager._delete_key(mock_client, 'key_id')
result1 = await LiteLlmManager._get_team(mock_client, 'team_id')
result2 = await LiteLlmManager._get_user(mock_client, 'user_id')
result3 = await LiteLlmManager._generate_key(
mock_client, 'user_id', 'team_id', 'alias', {}
)
result4 = await LiteLlmManager._get_user_team_info(
mock_client, 'user_id', 'team_id'
)
result5 = await LiteLlmManager._get_key_info(
mock_client, 'test-ord-id', 'user_id'
)
# Methods that return None when config is missing
assert result1 is None
assert result2 is None
assert result3 is None
assert result4 is None
assert result5 is None
# Verify no HTTP calls were made
mock_client.get.assert_not_called()
mock_client.post.assert_not_called()

View File

@@ -0,0 +1,70 @@
"""
Test that the models are correctly defined.
"""
from uuid import uuid4
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from storage.base import Base
from storage.org import Org
from storage.org_member import OrgMember
from storage.user import User
@pytest.fixture
def engine():
engine = create_engine('sqlite:///:memory:')
Base.metadata.create_all(engine)
return engine
@pytest.fixture
def session_maker(engine):
return sessionmaker(bind=engine)
def test_user_model(session_maker):
"""Test that the User model works correctly."""
with session_maker() as session:
# Create a test org
org = Org(name='test_org')
session.add(org)
session.flush()
# Create a test user
test_user_id = uuid4()
user = User(id=test_user_id, current_org_id=org.id, language='en')
session.add(user)
session.flush()
# Create org_member relationship
org_member = OrgMember(
org_id=org.id,
user_id=user.id,
role_id=1,
llm_api_key='test-api-key',
status='active',
)
session.add(org_member)
session.commit()
# Query the user
queried_user = session.query(User).filter(User.id == test_user_id).first()
assert queried_user is not None
assert queried_user.language == 'en'
# Query the org
queried_org = session.query(Org).filter(Org.id == org.id).first()
assert queried_org is not None
assert queried_org.name == 'test_org'
# Query the org_member relationship
queried_org_member = (
session.query(OrgMember)
.filter(OrgMember.org_id == org.id, OrgMember.user_id == user.id)
.first()
)
assert queried_org_member is not None
assert queried_org_member.llm_api_key.get_secret_value() == 'test-api-key'

View File

@@ -0,0 +1,253 @@
import uuid
from unittest.mock import patch
# Mock the database module before importing OrgMemberStore
with patch('storage.database.engine'), patch('storage.database.a_engine'):
from storage.org import Org
from storage.org_member import OrgMember
from storage.org_member_store import OrgMemberStore
from storage.role import Role
from storage.user import User
def test_get_org_members(session_maker):
# Test getting org_members by org ID
with session_maker() as session:
# Create test data
org = Org(name='test-org')
session.add(org)
session.flush()
user1 = User(id=uuid.uuid4(), current_org_id=org.id)
user2 = User(id=uuid.uuid4(), current_org_id=org.id)
role = Role(name='admin', rank=1)
session.add_all([user1, user2, role])
session.flush()
org_member1 = OrgMember(
org_id=org.id,
user_id=user1.id,
role_id=role.id,
llm_api_key='test-key-1',
status='active',
)
org_member2 = OrgMember(
org_id=org.id,
user_id=user2.id,
role_id=role.id,
llm_api_key='test-key-2',
status='active',
)
session.add_all([org_member1, org_member2])
session.commit()
org_id = org.id
# Test retrieval
with patch('storage.org_member_store.session_maker', session_maker):
org_members = OrgMemberStore.get_org_members(org_id)
assert len(org_members) == 2
api_keys = [om.llm_api_key.get_secret_value() for om in org_members]
assert 'test-key-1' in api_keys
assert 'test-key-2' in api_keys
def test_get_user_orgs(session_maker):
# Test getting org_members by user ID
with session_maker() as session:
# Create test data
org1 = Org(name='test-org-1')
org2 = Org(name='test-org-2')
session.add_all([org1, org2])
session.flush()
user = User(id=uuid.uuid4(), current_org_id=org1.id)
role = Role(name='admin', rank=1)
session.add_all([user, role])
session.flush()
org_member1 = OrgMember(
org_id=org1.id,
user_id=user.id,
role_id=role.id,
llm_api_key='test-key-1',
status='active',
)
org_member2 = OrgMember(
org_id=org2.id,
user_id=user.id,
role_id=role.id,
llm_api_key='test-key-2',
status='active',
)
session.add_all([org_member1, org_member2])
session.commit()
user_id = user.id
# Test retrieval
with patch('storage.org_member_store.session_maker', session_maker):
org_members = OrgMemberStore.get_user_orgs(user_id)
assert len(org_members) == 2
api_keys = [ou.llm_api_key.get_secret_value() for ou in org_members]
assert 'test-key-1' in api_keys
assert 'test-key-2' in api_keys
def test_get_org_member(session_maker):
# Test getting org_member by org and user ID
with session_maker() as session:
# Create test data
org = Org(name='test-org')
session.add(org)
session.flush()
user = User(id=uuid.uuid4(), current_org_id=org.id)
role = Role(name='admin', rank=1)
session.add_all([user, role])
session.flush()
org_member = OrgMember(
org_id=org.id,
user_id=user.id,
role_id=role.id,
llm_api_key='test-key',
status='active',
)
session.add(org_member)
session.commit()
org_id = org.id
user_id = user.id
# Test retrieval
with patch('storage.org_member_store.session_maker', session_maker):
retrieved_org_member = OrgMemberStore.get_org_member(org_id, user_id)
assert retrieved_org_member is not None
assert retrieved_org_member.org_id == org_id
assert retrieved_org_member.user_id == user_id
assert retrieved_org_member.llm_api_key.get_secret_value() == 'test-key'
def test_add_user_to_org(session_maker):
# Test adding a user to an org
with session_maker() as session:
# Create test data
org = Org(name='test-org')
session.add(org)
session.flush()
user = User(id=uuid.uuid4(), current_org_id=org.id)
role = Role(name='admin', rank=1)
session.add_all([user, role])
session.commit()
org_id = org.id
user_id = user.id
role_id = role.id
# Test creation
with patch('storage.org_member_store.session_maker', session_maker):
org_member = OrgMemberStore.add_user_to_org(
org_id=org_id,
user_id=user_id,
role_id=role_id,
llm_api_key='new-test-key',
status='active',
)
assert org_member is not None
assert org_member.org_id == org_id
assert org_member.user_id == user_id
assert org_member.role_id == role_id
assert org_member.llm_api_key.get_secret_value() == 'new-test-key'
assert org_member.status == 'active'
def test_update_user_role_in_org(session_maker):
# Test updating user role in org
with session_maker() as session:
# Create test data
org = Org(name='test-org')
session.add(org)
session.flush()
user = User(id=uuid.uuid4(), current_org_id=org.id)
role1 = Role(name='admin', rank=1)
role2 = Role(name='user', rank=2)
session.add_all([user, role1, role2])
session.flush()
org_member = OrgMember(
org_id=org.id,
user_id=user.id,
role_id=role1.id,
llm_api_key='test-key',
status='active',
)
session.add(org_member)
session.commit()
org_id = org.id
user_id = user.id
role2_id = role2.id
# Test update
with patch('storage.org_member_store.session_maker', session_maker):
updated_org_member = OrgMemberStore.update_user_role_in_org(
org_id=org_id, user_id=user_id, role_id=role2_id, status='inactive'
)
assert updated_org_member is not None
assert updated_org_member.role_id == role2_id
assert updated_org_member.status == 'inactive'
def test_update_user_role_in_org_not_found(session_maker):
# Test updating org_member that doesn't exist
from uuid import uuid4
with patch('storage.org_member_store.session_maker', session_maker):
updated_org_member = OrgMemberStore.update_user_role_in_org(
org_id=uuid4(), user_id=99999, role_id=1
)
assert updated_org_member is None
def test_remove_user_from_org(session_maker):
# Test removing a user from an org
with session_maker() as session:
# Create test data
org = Org(name='test-org')
session.add(org)
session.flush()
user = User(id=uuid.uuid4(), current_org_id=org.id)
role = Role(name='admin', rank=1)
session.add_all([user, role])
session.flush()
org_member = OrgMember(
org_id=org.id,
user_id=user.id,
role_id=role.id,
llm_api_key='test-key',
status='active',
)
session.add(org_member)
session.commit()
org_id = org.id
user_id = user.id
# Test removal
with patch('storage.org_member_store.session_maker', session_maker):
result = OrgMemberStore.remove_user_from_org(org_id, user_id)
assert result is True
# Verify it's removed
retrieved_org_member = OrgMemberStore.get_org_member(org_id, user_id)
assert retrieved_org_member is None
def test_remove_user_from_org_not_found(session_maker):
# Test removing user from org that doesn't exist
from uuid import uuid4
with patch('storage.org_member_store.session_maker', session_maker):
result = OrgMemberStore.remove_user_from_org(uuid4(), 99999)
assert result is False

View File

@@ -0,0 +1,197 @@
import uuid
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from pydantic import SecretStr
# Mock the database module before importing OrgStore
with patch('storage.database.engine'), patch('storage.database.a_engine'):
from storage.org import Org
from storage.org_store import OrgStore
from openhands.storage.data_models.settings import Settings
@pytest.fixture
def mock_litellm_api():
api_key_patch = patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test_key')
api_url_patch = patch(
'storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.url'
)
team_id_patch = patch('storage.lite_llm_manager.LITE_LLM_TEAM_ID', 'test_team')
client_patch = patch('httpx.AsyncClient')
with api_key_patch, api_url_patch, team_id_patch, client_patch as mock_client:
mock_response = AsyncMock()
mock_response.is_success = True
mock_response.json = MagicMock(return_value={'key': 'test_api_key'})
mock_client.return_value.__aenter__.return_value.post.return_value = (
mock_response
)
mock_client.return_value.__aenter__.return_value.get.return_value = (
mock_response
)
mock_client.return_value.__aenter__.return_value.patch.return_value = (
mock_response
)
yield mock_client
def test_get_org_by_id(session_maker, mock_litellm_api):
# Test getting org by ID
with session_maker() as session:
# Create a test org
org = Org(name='test-org')
session.add(org)
session.commit()
org_id = org.id
# Test retrieval
with (
patch('storage.org_store.session_maker', session_maker),
):
retrieved_org = OrgStore.get_org_by_id(org_id)
assert retrieved_org is not None
assert retrieved_org.id == org_id
assert retrieved_org.name == 'test-org'
def test_get_org_by_id_not_found(session_maker):
# Test getting org by ID when it doesn't exist
with patch('storage.org_store.session_maker', session_maker):
non_existent_id = uuid.uuid4()
retrieved_org = OrgStore.get_org_by_id(non_existent_id)
assert retrieved_org is None
def test_list_orgs(session_maker, mock_litellm_api):
# Test listing all orgs
with session_maker() as session:
# Create test orgs
org1 = Org(name='test-org-1')
org2 = Org(name='test-org-2')
session.add_all([org1, org2])
session.commit()
# Test listing
with (
patch('storage.org_store.session_maker', session_maker),
):
orgs = OrgStore.list_orgs()
assert len(orgs) >= 2
org_names = [org.name for org in orgs]
assert 'test-org-1' in org_names
assert 'test-org-2' in org_names
def test_update_org(session_maker, mock_litellm_api):
# Test updating org details
with session_maker() as session:
# Create a test org
org = Org(name='test-org', agent='CodeActAgent')
session.add(org)
session.commit()
org_id = org.id
# Test update
with (
patch('storage.org_store.session_maker', session_maker),
):
updated_org = OrgStore.update_org(
org_id=org_id, kwargs={'name': 'updated-org', 'agent': 'PlannerAgent'}
)
assert updated_org is not None
assert updated_org.name == 'updated-org'
assert updated_org.agent == 'PlannerAgent'
def test_update_org_not_found(session_maker):
# Test updating org that doesn't exist
with patch('storage.org_store.session_maker', session_maker):
from uuid import uuid4
updated_org = OrgStore.update_org(
org_id=uuid4(), kwargs={'name': 'updated-org'}
)
assert updated_org is None
def test_create_org(session_maker, mock_litellm_api):
# Test creating a new org
with (
patch('storage.org_store.session_maker', session_maker),
):
org = OrgStore.create_org(kwargs={'name': 'new-org', 'agent': 'CodeActAgent'})
assert org is not None
assert org.name == 'new-org'
assert org.agent == 'CodeActAgent'
assert org.id is not None
def test_get_org_by_name(session_maker, mock_litellm_api):
# Test getting org by name
with session_maker() as session:
# Create a test org
org = Org(name='test-org-by-name')
session.add(org)
session.commit()
# Test retrieval
with (
patch('storage.org_store.session_maker', session_maker),
):
retrieved_org = OrgStore.get_org_by_name('test-org-by-name')
assert retrieved_org is not None
assert retrieved_org.name == 'test-org-by-name'
def test_get_current_org_from_keycloak_user_id(session_maker, mock_litellm_api):
# Test getting current org from user ID
test_user_id = uuid.uuid4()
with session_maker() as session:
# Create test data
org = Org(name='test-org')
session.add(org)
session.flush()
from storage.user import User
user = User(id=test_user_id, current_org_id=org.id)
session.add(user)
session.commit()
# Test retrieval
with (
patch('storage.org_store.session_maker', session_maker),
):
retrieved_org = OrgStore.get_current_org_from_keycloak_user_id(
str(test_user_id)
)
assert retrieved_org is not None
assert retrieved_org.name == 'test-org'
def test_get_kwargs_from_settings():
# Test extracting org kwargs from settings
settings = Settings(
language='es',
agent='CodeActAgent',
llm_model='gpt-4',
llm_api_key=SecretStr('test-key'),
enable_sound_notifications=True,
)
kwargs = OrgStore.get_kwargs_from_settings(settings)
# Should only include fields that exist in Org model
assert 'agent' in kwargs
assert 'default_llm_model' in kwargs
assert kwargs['agent'] == 'CodeActAgent'
assert kwargs['default_llm_model'] == 'gpt-4'
# Should not include fields that don't exist in Org model
assert 'language' not in kwargs # language is not in Org model
assert 'llm_api_key' not in kwargs
assert 'llm_model' not in kwargs
assert 'enable_sound_notifications' not in kwargs

View File

@@ -1,32 +1,15 @@
from unittest.mock import MagicMock, patch
import pytest
from integrations.github.github_view import get_user_proactive_conversation_setting
from storage.user_settings import UserSettings
# Mock the database module before importing
with patch('storage.database.engine'), patch('storage.database.a_engine'):
from integrations.github.github_view import get_user_proactive_conversation_setting
from storage.org import Org
pytestmark = pytest.mark.asyncio
# Mock the call_sync_from_async function to return the result of the function directly
def mock_call_sync_from_async(func, *args, **kwargs):
return func(*args, **kwargs)
@pytest.fixture
def mock_session():
session = MagicMock()
query = MagicMock()
filter = MagicMock()
# Mock the context manager behavior
session.__enter__.return_value = session
session.query.return_value = query
query.filter.return_value = filter
return session, query, filter
async def test_get_user_proactive_conversation_setting_no_user_id():
"""Test that the function returns False when no user ID is provided."""
with patch(
@@ -42,75 +25,82 @@ async def test_get_user_proactive_conversation_setting_no_user_id():
assert await get_user_proactive_conversation_setting(None) is False
async def test_get_user_proactive_conversation_setting_user_not_found(mock_session):
async def test_get_user_proactive_conversation_setting_user_not_found():
"""Test that False is returned when the user is not found."""
session, query, filter = mock_session
filter.first.return_value = None
with patch('integrations.github.github_view.session_maker', return_value=session):
with patch(
'integrations.github.github_view.ENABLE_PROACTIVE_CONVERSATION_STARTERS',
True,
):
with patch(
'integrations.github.github_view.ENABLE_PROACTIVE_CONVERSATION_STARTERS',
True,
'storage.org_store.OrgStore.get_current_org_from_keycloak_user_id',
return_value=None,
):
with patch(
'integrations.github.github_view.call_sync_from_async',
side_effect=mock_call_sync_from_async,
):
assert await get_user_proactive_conversation_setting('user-id') is False
assert (
await get_user_proactive_conversation_setting(
'5594c7b6-f959-4b81-92e9-b09c206f5081'
)
is False
)
async def test_get_user_proactive_conversation_setting_user_setting_none(mock_session):
async def test_get_user_proactive_conversation_setting_user_setting_none():
"""Test that False is returned when the user setting is None."""
session, query, filter = mock_session
user_settings = MagicMock(spec=UserSettings)
user_settings.enable_proactive_conversation_starters = None
filter.first.return_value = user_settings
mock_org = MagicMock(spec=Org)
mock_org.enable_proactive_conversation_starters = None
with patch('integrations.github.github_view.session_maker', return_value=session):
with patch(
'integrations.github.github_view.ENABLE_PROACTIVE_CONVERSATION_STARTERS',
True,
):
with patch(
'integrations.github.github_view.ENABLE_PROACTIVE_CONVERSATION_STARTERS',
True,
'storage.org_store.OrgStore.get_current_org_from_keycloak_user_id',
return_value=mock_org,
):
with patch(
'integrations.github.github_view.call_sync_from_async',
side_effect=mock_call_sync_from_async,
):
assert await get_user_proactive_conversation_setting('user-id') is False
assert (
await get_user_proactive_conversation_setting(
'5594c7b6-f959-4b81-92e9-b09c206f5081'
)
is False
)
async def test_get_user_proactive_conversation_setting_user_setting_true(mock_session):
async def test_get_user_proactive_conversation_setting_user_setting_true():
"""Test that True is returned when the user setting is True and the global setting is True."""
session, query, filter = mock_session
user_settings = MagicMock(spec=UserSettings)
user_settings.enable_proactive_conversation_starters = True
filter.first.return_value = user_settings
mock_org = MagicMock(spec=Org)
mock_org.enable_proactive_conversation_starters = True
with patch('integrations.github.github_view.session_maker', return_value=session):
with patch(
'integrations.github.github_view.ENABLE_PROACTIVE_CONVERSATION_STARTERS',
True,
):
with patch(
'integrations.github.github_view.ENABLE_PROACTIVE_CONVERSATION_STARTERS',
True,
'storage.org_store.OrgStore.get_current_org_from_keycloak_user_id',
return_value=mock_org,
):
with patch(
'integrations.github.github_view.call_sync_from_async',
side_effect=mock_call_sync_from_async,
):
assert await get_user_proactive_conversation_setting('user-id') is True
assert (
await get_user_proactive_conversation_setting(
'5594c7b6-f959-4b81-92e9-b09c206f5081'
)
is True
)
async def test_get_user_proactive_conversation_setting_user_setting_false(mock_session):
async def test_get_user_proactive_conversation_setting_user_setting_false():
"""Test that False is returned when the user setting is False, regardless of global setting."""
session, query, filter = mock_session
user_settings = MagicMock(spec=UserSettings)
user_settings.enable_proactive_conversation_starters = False
filter.first.return_value = user_settings
mock_org = MagicMock(spec=Org)
mock_org.enable_proactive_conversation_starters = False
with patch('integrations.github.github_view.session_maker', return_value=session):
with patch(
'integrations.github.github_view.ENABLE_PROACTIVE_CONVERSATION_STARTERS',
True,
):
with patch(
'integrations.github.github_view.ENABLE_PROACTIVE_CONVERSATION_STARTERS',
True,
'storage.org_store.OrgStore.get_current_org_from_keycloak_user_id',
return_value=mock_org,
):
with patch(
'integrations.github.github_view.call_sync_from_async',
side_effect=mock_call_sync_from_async,
):
assert await get_user_proactive_conversation_setting('user-id') is False
assert (
await get_user_proactive_conversation_setting(
'5594c7b6-f959-4b81-92e9-b09c206f5081'
)
is False
)

View File

@@ -0,0 +1,83 @@
from unittest.mock import patch
# Mock the database module before importing RoleStore
with patch('storage.database.engine'), patch('storage.database.a_engine'):
from storage.role import Role
from storage.role_store import RoleStore
def test_get_role_by_id(session_maker):
# Test getting role by ID
with session_maker() as session:
# Create a test role
role = Role(name='admin', rank=1)
session.add(role)
session.commit()
role_id = role.id
# Test retrieval
with patch('storage.role_store.session_maker', session_maker):
retrieved_role = RoleStore.get_role_by_id(role_id)
assert retrieved_role is not None
assert retrieved_role.id == role_id
assert retrieved_role.name == 'admin'
def test_get_role_by_id_not_found(session_maker):
# Test getting role by ID when it doesn't exist
with patch('storage.role_store.session_maker', session_maker):
retrieved_role = RoleStore.get_role_by_id(99999)
assert retrieved_role is None
def test_get_role_by_name(session_maker):
# Test getting role by name
with session_maker() as session:
# Create a test role
role = Role(name='admin', rank=1)
session.add(role)
session.commit()
role_id = role.id
# Test retrieval
with patch('storage.role_store.session_maker', session_maker):
retrieved_role = RoleStore.get_role_by_name('admin')
assert retrieved_role is not None
assert retrieved_role.id == role_id
assert retrieved_role.name == 'admin'
def test_get_role_by_name_not_found(session_maker):
# Test getting role by name when it doesn't exist
with patch('storage.role_store.session_maker', session_maker):
retrieved_role = RoleStore.get_role_by_name('nonexistent')
assert retrieved_role is None
def test_list_roles(session_maker):
# Test listing all roles
with session_maker() as session:
# Create test roles
role1 = Role(name='admin', rank=1)
role2 = Role(name='user', rank=2)
session.add_all([role1, role2])
session.commit()
# Test listing
with patch('storage.role_store.session_maker', session_maker):
roles = RoleStore.list_roles()
assert len(roles) >= 2
role_names = [role.name for role in roles]
assert 'admin' in role_names
assert 'user' in role_names
def test_create_role(session_maker):
# Test creating a new role
with patch('storage.role_store.session_maker', session_maker):
role = RoleStore.create_role(name='moderator', rank=2)
assert role is not None
assert role.name == 'moderator'
assert role.rank == 2
assert role.id is not None

View File

@@ -1,11 +1,26 @@
from datetime import UTC, datetime
from unittest.mock import patch
from unittest.mock import MagicMock, patch
from uuid import UUID
import pytest
from storage.saas_conversation_store import SaasConversationStore
from openhands.storage.data_models.conversation_metadata import ConversationMetadata
# Mock the database module before importing
with patch('storage.database.engine'), patch('storage.database.a_engine'):
# Import the actual StoredConversationMetadata from OpenHands core
from openhands.app_server.app_conversation.sql_app_conversation_info_service import (
StoredConversationMetadata,
)
# Mock the lazy import to return the actual class
with patch(
'storage.stored_conversation_metadata.StoredConversationMetadata',
StoredConversationMetadata,
):
from storage.saas_conversation_store import SaasConversationStore
from storage.user import User
@pytest.fixture(autouse=True)
def mock_call_sync_from_async():
@@ -20,12 +35,25 @@ def mock_call_sync_from_async():
yield
@pytest.fixture(autouse=True)
def mock_user_store():
"""Mock UserStore.get_user_by_id to return a mock user"""
mock_user = MagicMock(spec=User)
mock_user.current_org_id = UUID('5594c7b6-f959-4b81-92e9-b09c206f5081')
with patch(
'storage.saas_conversation_store.UserStore.get_user_by_id',
return_value=mock_user,
):
yield
@pytest.mark.asyncio
async def test_save_and_get(session_maker):
store = SaasConversationStore('12345', session_maker)
store = SaasConversationStore('5594c7b6-f959-4b81-92e9-b09c206f5081', session_maker)
metadata = ConversationMetadata(
conversation_id='my-conversation-id',
user_id='12345',
user_id='5594c7b6-f959-4b81-92e9-b09c206f5081',
selected_repository='my-repo',
selected_branch=None,
created_at=datetime.now(UTC),
@@ -47,13 +75,13 @@ async def test_save_and_get(session_maker):
@pytest.mark.asyncio
async def test_search(session_maker):
store = SaasConversationStore('12345', session_maker)
store = SaasConversationStore('5594c7b6-f959-4b81-92e9-b09c206f5081', session_maker)
# Create test conversations with different timestamps
conversations = [
ConversationMetadata(
conversation_id=f'conv-{i}',
user_id='12345',
user_id='5594c7b6-f959-4b81-92e9-b09c206f5081',
selected_repository='repo',
selected_branch=None,
created_at=datetime(2024, 1, i + 1, tzinfo=UTC),
@@ -92,10 +120,10 @@ async def test_search(session_maker):
@pytest.mark.asyncio
async def test_delete_metadata(session_maker):
store = SaasConversationStore('12345', session_maker)
store = SaasConversationStore('5594c7b6-f959-4b81-92e9-b09c206f5081', session_maker)
metadata = ConversationMetadata(
conversation_id='to-delete',
user_id='12345',
user_id='5594c7b6-f959-4b81-92e9-b09c206f5081',
selected_repository='repo',
selected_branch=None,
created_at=datetime.now(UTC),
@@ -112,17 +140,17 @@ async def test_delete_metadata(session_maker):
@pytest.mark.asyncio
async def test_get_nonexistent_metadata(session_maker):
store = SaasConversationStore('12345', session_maker)
store = SaasConversationStore('5594c7b6-f959-4b81-92e9-b09c206f5081', session_maker)
with pytest.raises(FileNotFoundError):
await store.get_metadata('nonexistent-id')
@pytest.mark.asyncio
async def test_exists(session_maker):
store = SaasConversationStore('12345', session_maker)
store = SaasConversationStore('5594c7b6-f959-4b81-92e9-b09c206f5081', session_maker)
metadata = ConversationMetadata(
conversation_id='exists-test',
user_id='12345',
user_id='5594c7b6-f959-4b81-92e9-b09c206f5081',
selected_repository='repo',
selected_branch='test-branch',
created_at=datetime.now(UTC),

View File

@@ -1,6 +1,7 @@
from types import MappingProxyType
from typing import Any
from unittest.mock import MagicMock
from unittest.mock import MagicMock, patch
from uuid import UUID
import pytest
from pydantic import SecretStr
@@ -19,6 +20,14 @@ def mock_config():
return config
@pytest.fixture
def mock_user():
"""Mock user with org_id."""
user = MagicMock()
user.current_org_id = UUID('a1111111-1111-1111-1111-111111111111')
return user
@pytest.fixture
def secrets_store(session_maker, mock_config):
return SaasSecretsStore('user-id', session_maker, mock_config)
@@ -26,7 +35,11 @@ def secrets_store(session_maker, mock_config):
class TestSaasSecretsStore:
@pytest.mark.asyncio
async def test_store_and_load(self, secrets_store):
@patch('storage.saas_secrets_store.UserStore.get_user_by_id')
async def test_store_and_load(self, mock_get_user, secrets_store, mock_user):
# Setup mock
mock_get_user.return_value = mock_user
# Create a Secrets object with some test data
user_secrets = Secrets(
custom_secrets=MappingProxyType(
@@ -59,7 +72,10 @@ class TestSaasSecretsStore:
)
@pytest.mark.asyncio
async def test_encryption_decryption(self, secrets_store):
@patch('storage.saas_secrets_store.UserStore.get_user_by_id')
async def test_encryption_decryption(self, mock_get_user, secrets_store, mock_user):
# Setup mock
mock_get_user.return_value = mock_user
# Create a Secrets object with sensitive data
user_secrets = Secrets(
custom_secrets=MappingProxyType(
@@ -89,6 +105,7 @@ class TestSaasSecretsStore:
stored = (
session.query(StoredCustomSecrets)
.filter(StoredCustomSecrets.keycloak_user_id == 'user-id')
.filter(StoredCustomSecrets.org_id == mock_user.current_org_id)
.first()
)
@@ -152,7 +169,12 @@ class TestSaasSecretsStore:
assert await secrets_store.load() is None
@pytest.mark.asyncio
async def test_update_existing_secrets(self, secrets_store):
@patch('storage.saas_secrets_store.UserStore.get_user_by_id')
async def test_update_existing_secrets(
self, mock_get_user, secrets_store, mock_user
):
# Setup mock
mock_get_user.return_value = mock_user
# Create and store initial secrets
initial_secrets = Secrets(
custom_secrets=MappingProxyType(

View File

@@ -2,65 +2,17 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from pydantic import SecretStr
from server.constants import (
CURRENT_USER_SETTINGS_VERSION,
LITE_LLM_API_URL,
LITE_LLM_TEAM_ID,
)
from storage.saas_settings_store import SaasSettingsStore
from storage.user_settings import UserSettings
from openhands.core.config.openhands_config import OpenHandsConfig
from openhands.server.settings import Settings
@pytest.fixture
def mock_litellm_get_response():
mock_response = AsyncMock()
mock_response.is_success = True
mock_response.json = MagicMock(return_value={'user_info': {}})
return mock_response
@pytest.fixture
def mock_litellm_post_response():
mock_response = AsyncMock()
mock_response.is_success = True
mock_response.json = MagicMock(return_value={'key': 'test_api_key'})
return mock_response
@pytest.fixture
def mock_litellm_api(mock_litellm_get_response, mock_litellm_post_response):
api_key_patch = patch('storage.saas_settings_store.LITE_LLM_API_KEY', 'test_key')
api_url_patch = patch(
'storage.saas_settings_store.LITE_LLM_API_URL', 'http://test.url'
# Mock the database module before importing
with patch('storage.database.engine'), patch('storage.database.a_engine'):
from server.constants import (
LITE_LLM_API_URL,
)
team_id_patch = patch('storage.saas_settings_store.LITE_LLM_TEAM_ID', 'test_team')
client_patch = patch('httpx.AsyncClient')
with api_key_patch, api_url_patch, team_id_patch, client_patch as mock_client:
mock_client.return_value.__aenter__.return_value.get.return_value = (
mock_litellm_get_response
)
mock_client.return_value.__aenter__.return_value.post.return_value = (
mock_litellm_post_response
)
yield mock_client
@pytest.fixture
def mock_stripe():
search_patch = patch(
'stripe.Customer.search_async',
AsyncMock(return_value=MagicMock(id='mock-customer-id')),
)
payment_patch = patch(
'stripe.Customer.list_payment_methods_async',
AsyncMock(return_value=MagicMock(data=[{}])),
)
with search_patch, payment_patch:
yield
from storage.saas_settings_store import SaasSettingsStore
from storage.user_settings import UserSettings
@pytest.fixture
@@ -83,41 +35,42 @@ def mock_config():
@pytest.fixture
def settings_store(session_maker, mock_config):
store = SaasSettingsStore('user-id', session_maker, mock_config)
store = SaasSettingsStore(
'5594c7b6-f959-4b81-92e9-b09c206f5081', session_maker, mock_config
)
# Patch the store method directly to filter out email and email_verified
original_load = store.load
original_create_default = store.create_default_settings
original_update_litellm = store.update_settings_with_litellm_default
# Patch the load method to add email and email_verified
# Patch the load method to read from UserSettings table directly (for testing)
async def patched_load():
settings = await original_load()
if settings:
# Add email and email_verified fields to mimic SaasUserAuth behavior
with store.session_maker() as session:
user_settings = (
session.query(UserSettings)
.filter(UserSettings.keycloak_user_id == store.user_id)
.first()
)
if not user_settings:
# Return default settings
return Settings(
llm_api_key=SecretStr('test_api_key'),
llm_base_url='http://test.url',
agent='CodeActAgent',
language='en',
)
# Decrypt and convert to Settings
kwargs = {}
for column in UserSettings.__table__.columns:
if column.name != 'keycloak_user_id':
value = getattr(user_settings, column.name, None)
if value is not None:
kwargs[column.name] = value
store._decrypt_kwargs(kwargs)
settings = Settings(**kwargs)
settings.email = 'test@example.com'
settings.email_verified = True
return settings
return settings
# Patch the create_default_settings method to add email and email_verified
async def patched_create_default(settings):
settings = await original_create_default(settings)
if settings:
# Add email and email_verified fields to mimic SaasUserAuth behavior
settings.email = 'test@example.com'
settings.email_verified = True
return settings
# Patch the update_settings_with_litellm_default method
async def patched_update_litellm(settings):
updated_settings = await original_update_litellm(settings)
if updated_settings:
# Add email and email_verified fields to mimic SaasUserAuth behavior
updated_settings.email = 'test@example.com'
updated_settings.email_verified = True
return updated_settings
# Patch the store method to filter out email and email_verified
# Patch the store method to write to UserSettings table directly (for testing)
async def patched_store(item):
if item:
# Make a copy of the item without email and email_verified
@@ -146,11 +99,9 @@ def settings_store(session_maker, mock_config):
for key, value in item_dict.items():
if key in existing.__class__.__table__.columns:
setattr(existing, key, value)
existing.user_version = CURRENT_USER_SETTINGS_VERSION
session.merge(existing)
else:
item_dict['keycloak_user_id'] = store.user_id
item_dict['user_version'] = CURRENT_USER_SETTINGS_VERSION
settings = UserSettings(**item_dict)
session.add(settings)
session.commit()
@@ -158,8 +109,6 @@ def settings_store(session_maker, mock_config):
# Replace the methods with our patched versions
store.store = patched_store
store.load = patched_load
store.create_default_settings = patched_create_default
store.update_settings_with_litellm_default = patched_update_litellm
return store
@@ -197,17 +146,11 @@ async def test_store_and_load_keycloak_user(settings_store):
@pytest.mark.asyncio
async def test_load_returns_default_when_not_found(
settings_store, mock_litellm_api, mock_stripe, mock_github_user, session_maker
):
async def test_load_returns_default_when_not_found(settings_store, session_maker):
file_store = MagicMock()
file_store.read.side_effect = FileNotFoundError()
with (
patch(
'storage.saas_settings_store.get_file_store',
MagicMock(return_value=file_store),
),
patch('storage.saas_settings_store.session_maker', session_maker),
):
loaded_settings = await settings_store.load()
@@ -218,233 +161,9 @@ async def test_load_returns_default_when_not_found(
assert loaded_settings.llm_base_url == 'http://test.url'
@pytest.mark.asyncio
async def test_update_settings_with_litellm_default(
settings_store, mock_litellm_api, session_maker
):
settings = Settings()
with (
patch(
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
AsyncMock(return_value={'email': 'testy@tester.com'}),
),
patch('storage.saas_settings_store.session_maker', session_maker),
):
settings = await settings_store.update_settings_with_litellm_default(settings)
assert settings.agent == 'CodeActAgent'
assert settings.llm_api_key
assert settings.llm_api_key.get_secret_value() == 'test_api_key'
assert settings.llm_base_url == 'http://test.url'
# Get the actual call arguments
call_args = mock_litellm_api.return_value.__aenter__.return_value.post.call_args[1]
# Check that the URL and most of the JSON payload match what we expect
assert call_args['json']['user_email'] == 'testy@tester.com'
assert call_args['json']['models'] == []
assert call_args['json']['max_budget'] == 10.0
assert call_args['json']['user_id'] == 'user-id'
assert call_args['json']['teams'] == ['test_team']
assert call_args['json']['auto_create_key'] is True
assert call_args['json']['send_invite_email'] is False
assert call_args['json']['metadata']['version'] == CURRENT_USER_SETTINGS_VERSION
assert 'model' in call_args['json']['metadata']
@pytest.mark.asyncio
async def test_create_default_settings_no_user_id():
store = SaasSettingsStore('', MagicMock(), MagicMock())
settings = await store.create_default_settings(None)
assert settings is None
@pytest.mark.asyncio
async def test_create_default_settings_require_payment_enabled(
settings_store, mock_stripe
):
# Mock stripe_service.has_payment_method to return False
with (
patch('storage.saas_settings_store.REQUIRE_PAYMENT', True),
patch(
'stripe.Customer.list_payment_methods_async',
AsyncMock(return_value=MagicMock(data=[])),
),
patch(
'integrations.stripe_service.session_maker', settings_store.session_maker
),
):
settings = await settings_store.create_default_settings(None)
assert settings is None
@pytest.mark.asyncio
async def test_create_default_settings_require_payment_disabled(
settings_store, mock_stripe, mock_github_user, mock_litellm_api, session_maker
):
# Even without payment method, should get default settings when REQUIRE_PAYMENT is False
file_store = MagicMock()
file_store.read.side_effect = FileNotFoundError()
with (
patch('storage.saas_settings_store.REQUIRE_PAYMENT', False),
patch(
'stripe.Customer.list_payment_methods_async',
AsyncMock(return_value=MagicMock(data=[])),
),
patch(
'storage.saas_settings_store.get_file_store',
MagicMock(return_value=file_store),
),
patch('storage.saas_settings_store.session_maker', session_maker),
):
settings = await settings_store.create_default_settings(None)
assert settings is not None
assert settings.language == 'en'
@pytest.mark.asyncio
async def test_create_default_lite_llm_settings_no_api_config(settings_store):
with (
patch('storage.saas_settings_store.LITE_LLM_API_KEY', None),
patch('storage.saas_settings_store.LITE_LLM_API_URL', None),
):
settings = Settings()
settings = await settings_store.update_settings_with_litellm_default(settings)
@pytest.mark.asyncio
async def test_update_settings_with_litellm_default_error(settings_store):
with patch(
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
AsyncMock(return_value={'email': 'duplicate@example.com'}),
):
with patch('httpx.AsyncClient') as mock_client:
mock_client.return_value.__aenter__.return_value.get.return_value = (
AsyncMock(
json=MagicMock(
return_value={'user_info': {'max_budget': 10, 'spend': 5}}
)
)
)
mock_client.return_value.__aenter__.return_value.post.return_value.is_success = False
settings = Settings()
settings = await settings_store.update_settings_with_litellm_default(
settings
)
assert settings is None
@pytest.mark.asyncio
async def test_update_settings_with_litellm_retry_on_duplicate_email(
settings_store, mock_litellm_api, session_maker
):
# First response is a delete and succeeds
mock_delete_response = MagicMock()
mock_delete_response.is_success = True
mock_delete_response.status_code = 200
# Second response fails with duplicate email error
mock_error_response = MagicMock()
mock_error_response.is_success = False
mock_error_response.status_code = 400
mock_error_response.text = 'User with this email already exists'
# Thire response succeeds with no email
mock_success_response = MagicMock()
mock_success_response.is_success = True
mock_success_response.json = MagicMock(return_value={'key': 'new_test_api_key'})
# Set up mocks
post_mock = AsyncMock()
post_mock.side_effect = [
mock_delete_response,
mock_error_response,
mock_success_response,
]
mock_litellm_api.return_value.__aenter__.return_value.post = post_mock
with (
patch(
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
AsyncMock(return_value={'email': 'duplicate@example.com'}),
),
patch('storage.saas_settings_store.session_maker', session_maker),
):
settings = Settings()
settings = await settings_store.update_settings_with_litellm_default(settings)
assert settings is not None
assert settings.llm_api_key
assert settings.llm_api_key.get_secret_value() == 'new_test_api_key'
# Verify second call was with email
second_call_args = post_mock.call_args_list[1][1]
assert second_call_args['json']['user_email'] == 'duplicate@example.com'
# Verify third call was with None for email
third_call_args = post_mock.call_args_list[2][1]
assert third_call_args['json']['user_email'] is None
@pytest.mark.asyncio
async def test_create_user_in_lite_llm(settings_store):
# Test the _create_user_in_lite_llm method directly
mock_client = AsyncMock()
mock_response = AsyncMock()
mock_response.is_success = True
mock_client.post.return_value = mock_response
# Test with email
await settings_store._create_user_in_lite_llm(
mock_client, 'test@example.com', 50, 10
)
# Get the actual call arguments
call_args = mock_client.post.call_args[1]
# Check that the URL and most of the JSON payload match what we expect
assert call_args['json']['user_email'] == 'test@example.com'
assert call_args['json']['models'] == []
assert call_args['json']['max_budget'] == 50
assert call_args['json']['spend'] == 10
assert call_args['json']['user_id'] == 'user-id'
assert call_args['json']['teams'] == [LITE_LLM_TEAM_ID]
assert call_args['json']['auto_create_key'] is True
assert call_args['json']['send_invite_email'] is False
assert call_args['json']['metadata']['version'] == CURRENT_USER_SETTINGS_VERSION
assert 'model' in call_args['json']['metadata']
# Test with None email
mock_client.post.reset_mock()
await settings_store._create_user_in_lite_llm(mock_client, None, 25, 15)
# Get the actual call arguments
call_args = mock_client.post.call_args[1]
# Check that the URL and most of the JSON payload match what we expect
assert call_args['json']['user_email'] is None
assert call_args['json']['models'] == []
assert call_args['json']['max_budget'] == 25
assert call_args['json']['spend'] == 15
assert call_args['json']['user_id'] == str(settings_store.user_id)
assert call_args['json']['teams'] == [LITE_LLM_TEAM_ID]
assert call_args['json']['auto_create_key'] is True
assert call_args['json']['send_invite_email'] is False
assert call_args['json']['metadata']['version'] == CURRENT_USER_SETTINGS_VERSION
assert 'model' in call_args['json']['metadata']
# Verify response is returned correctly
assert (
await settings_store._create_user_in_lite_llm(
mock_client, 'email@test.com', 30, 7
)
== mock_response
)
@pytest.mark.asyncio
async def test_encryption(settings_store):
settings_store.user_id = 'mock-id' # GitHub user ID
settings_store.user_id = '5594c7b6-f959-4b81-92e9-b09c206f5081' # GitHub user ID
settings = Settings(
llm_api_key=SecretStr('secret_key'),
agent='smith',
@@ -456,7 +175,9 @@ async def test_encryption(settings_store):
with settings_store.session_maker() as session:
stored = (
session.query(UserSettings)
.filter(UserSettings.keycloak_user_id == 'mock-id')
.filter(
UserSettings.keycloak_user_id == '5594c7b6-f959-4b81-92e9-b09c206f5081'
)
.first()
)
# The stored key should be encrypted

View File

@@ -3,27 +3,30 @@ This test file verifies that the stripe_service functions properly use the datab
to store and retrieve customer IDs.
"""
import uuid
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import stripe
from integrations.stripe_service import (
find_customer_id_by_user_id,
find_or_create_customer,
find_or_create_customer_by_user_id,
)
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from storage.stripe_customer import Base as StripeCustomerBase
from storage.base import Base
from storage.org import Org
from storage.org_member import OrgMember
from storage.role import Role
from storage.stripe_customer import StripeCustomer
from storage.user_settings import Base as UserBase
from storage.user import User
@pytest.fixture
def engine():
engine = create_engine('sqlite:///:memory:')
UserBase.metadata.create_all(engine)
StripeCustomerBase.metadata.create_all(engine)
# Create all tables using the unified Base
Base.metadata.create_all(engine)
return engine
@@ -32,79 +35,158 @@ def session_maker(engine):
return sessionmaker(bind=engine)
@pytest.fixture
def test_org_and_user(session_maker):
"""Create a test org and user for use in tests."""
test_user_id = uuid.uuid4()
test_org_id = uuid.uuid4()
with session_maker() as session:
# Create role first
role = Role(name='test-role', rank=1)
session.add(role)
session.flush()
# Create org
org = Org(id=test_org_id, name='test-org', contact_email='testy@tester.com')
session.add(org)
session.flush()
# Create user with current_org_id
user = User(id=test_user_id, current_org_id=test_org_id, role_id=role.id)
session.add(user)
session.flush()
# Create org member relationship
org_member = OrgMember(
org_id=test_org_id,
user_id=test_user_id,
role_id=role.id,
llm_api_key='test-key',
)
session.add(org_member)
session.commit()
return test_user_id, test_org_id
@pytest.mark.asyncio
async def test_find_customer_id_by_user_id_checks_db_first(session_maker):
async def test_find_customer_id_by_user_id_checks_db_first(
session_maker, test_org_and_user
):
"""Test that find_customer_id_by_user_id checks the database first"""
test_user_id, test_org_id = test_org_and_user
# Set up the mock for the database query result
with session_maker() as session:
# Create stripe customer
session.add(
StripeCustomer(
keycloak_user_id='test-user-id',
keycloak_user_id=str(test_user_id),
org_id=test_org_id,
stripe_customer_id='cus_test123',
)
)
session.commit()
with patch('integrations.stripe_service.session_maker', session_maker):
# Create a mock org object to return from OrgStore
mock_org = MagicMock()
mock_org.id = test_org_id
with (
patch('integrations.stripe_service.session_maker', session_maker),
patch('storage.org_store.session_maker', session_maker),
patch('integrations.stripe_service.call_sync_from_async') as mock_call_sync,
):
# Mock the call_sync_from_async to return the org
mock_call_sync.return_value = mock_org
# Call the function
result = await find_customer_id_by_user_id('test-user-id')
result = await find_customer_id_by_user_id(str(test_user_id))
# Verify the result
assert result == 'cus_test123'
# Verify that call_sync_from_async was called with the correct function
mock_call_sync.assert_called_once()
@pytest.mark.asyncio
async def test_find_customer_id_by_user_id_falls_back_to_stripe(session_maker):
async def test_find_customer_id_by_user_id_falls_back_to_stripe(
session_maker, test_org_and_user
):
"""Test that find_customer_id_by_user_id falls back to Stripe if not found in the database"""
test_user_id, test_org_id = test_org_and_user
# Set up the mock for stripe.Customer.search_async
mock_customer = stripe.Customer(id='cus_test123')
mock_search = AsyncMock(return_value=MagicMock(data=[mock_customer]))
# Create a mock org object to return from OrgStore
mock_org = MagicMock()
mock_org.id = test_org_id
with (
patch('integrations.stripe_service.session_maker', session_maker),
patch('storage.org_store.session_maker', session_maker),
patch('stripe.Customer.search_async', mock_search),
patch('integrations.stripe_service.call_sync_from_async') as mock_call_sync,
):
# Mock the call_sync_from_async to return the org
mock_call_sync.return_value = mock_org
# Call the function
result = await find_customer_id_by_user_id('test-user-id')
result = await find_customer_id_by_user_id(str(test_user_id))
# Verify the result
assert result == 'cus_test123'
# Verify that Stripe was searched
# Verify that Stripe was searched with the org_id
mock_search.assert_called_once()
assert "metadata['user_id']:'test-user-id'" in mock_search.call_args[1]['query']
assert (
f"metadata['org_id']:'{str(test_org_id)}'" in mock_search.call_args[1]['query']
)
@pytest.mark.asyncio
async def test_create_customer_stores_id_in_db(session_maker):
async def test_create_customer_stores_id_in_db(session_maker, test_org_and_user):
"""Test that create_customer stores the customer ID in the database"""
# Set up the mock for stripe.Customer.search_async
test_user_id, test_org_id = test_org_and_user
# Set up the mock for stripe.Customer.search_async and create_async
mock_search = AsyncMock(return_value=MagicMock(data=[]))
mock_create_async = AsyncMock(return_value=stripe.Customer(id='cus_test123'))
# Create a mock org object to return from OrgStore
mock_org = MagicMock()
mock_org.id = test_org_id
mock_org.contact_email = 'testy@tester.com'
with (
patch('integrations.stripe_service.session_maker', session_maker),
patch('storage.org_store.session_maker', session_maker),
patch('stripe.Customer.search_async', mock_search),
patch('stripe.Customer.create_async', mock_create_async),
patch(
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
AsyncMock(return_value={'email': 'testy@tester.com'}),
),
patch('integrations.stripe_service.call_sync_from_async') as mock_call_sync,
):
# Mock the call_sync_from_async to return the org
mock_call_sync.return_value = mock_org
# Call the function
result = await find_or_create_customer('test-user-id')
result = await find_or_create_customer_by_user_id(str(test_user_id))
# Verify the result
assert result == 'cus_test123'
assert result == {'customer_id': 'cus_test123', 'org_id': str(test_org_id)}
# Verify that the stripe customer was stored in the db
with session_maker() as session:
customer = session.query(StripeCustomer).first()
assert customer.id > 0
assert customer.keycloak_user_id == 'test-user-id'
assert customer.keycloak_user_id == str(test_user_id)
assert customer.org_id == test_org_id
assert customer.stripe_customer_id == 'cus_test123'
assert customer.created_at is not None
assert customer.updated_at is not None

View File

@@ -0,0 +1,164 @@
import uuid
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from pydantic import SecretStr
# Mock the database module before importing UserStore
with patch('storage.database.engine'), patch('storage.database.a_engine'):
from storage.user import User
from storage.user_store import UserStore
from sqlalchemy.orm import configure_mappers
from openhands.storage.data_models.settings import Settings
@pytest.fixture(autouse=True, scope='session')
def load_all_models():
configure_mappers() # fail fast if anythings missing
yield
@pytest.fixture
def mock_litellm_api():
api_key_patch = patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test_key')
api_url_patch = patch(
'storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.url'
)
team_id_patch = patch('storage.lite_llm_manager.LITE_LLM_TEAM_ID', 'test_team')
client_patch = patch('httpx.AsyncClient')
with api_key_patch, api_url_patch, team_id_patch, client_patch as mock_client:
mock_response = AsyncMock()
mock_response.is_success = True
mock_response.json = MagicMock(return_value={'key': 'test_api_key'})
mock_client.return_value.__aenter__.return_value.post.return_value = (
mock_response
)
mock_client.return_value.__aenter__.return_value.get.return_value = (
mock_response
)
yield mock_client
@pytest.fixture
def mock_stripe():
search_patch = patch(
'stripe.Customer.search_async',
AsyncMock(return_value=MagicMock(id='mock-customer-id')),
)
payment_patch = patch(
'stripe.Customer.list_payment_methods_async',
AsyncMock(return_value=MagicMock(data=[{}])),
)
with search_patch, payment_patch:
yield
@pytest.mark.asyncio
async def test_create_default_settings_no_org_id():
# Test UserStore.create_default_settings with empty org_id
settings = await UserStore.create_default_settings('', 'test-user-id')
assert settings is None
@pytest.mark.asyncio
async def test_create_default_settings_require_org(session_maker, mock_stripe):
# Mock stripe_service.has_payment_method to return False
with (
patch(
'stripe.Customer.list_payment_methods_async',
AsyncMock(return_value=MagicMock(data=[])),
),
patch('integrations.stripe_service.session_maker', session_maker),
):
settings = await UserStore.create_default_settings(
'test-org-id', 'test-user-id'
)
assert settings is None
@pytest.mark.asyncio
async def test_create_default_settings_with_litellm(session_maker, mock_litellm_api):
# Test that UserStore.create_default_settings works with LiteLLM
with (
patch('integrations.stripe_service.session_maker', session_maker),
patch('storage.user_store.session_maker', session_maker),
patch('storage.org_store.session_maker', session_maker),
patch(
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
AsyncMock(return_value={'attributes': {'github_id': ['12345']}}),
),
):
settings = await UserStore.create_default_settings(
'test-org-id', 'test-user-id'
)
assert settings is not None
assert settings.llm_api_key.get_secret_value() == 'test_api_key'
assert settings.llm_base_url == 'http://test.url'
assert settings.agent == 'CodeActAgent'
@pytest.mark.skip(reason='Complex integration test with session isolation issues')
@pytest.mark.asyncio
async def test_create_user(session_maker, mock_litellm_api):
# Test creating a new user - skipped due to complex session isolation issues
pass
def test_get_user_by_id(session_maker):
# Test getting user by ID
test_org_id = uuid.uuid4()
test_user_id = '5594c7b6-f959-4b81-92e9-b09c206f5081'
with session_maker() as session:
# Create a test user
user = User(id=uuid.UUID(test_user_id), current_org_id=test_org_id)
session.add(user)
session.commit()
user_id = user.id
# Test retrieval
with patch('storage.user_store.session_maker', session_maker):
retrieved_user = UserStore.get_user_by_id(test_user_id)
assert retrieved_user is not None
assert retrieved_user.id == user_id
def test_list_users(session_maker):
# Test listing all users
test_org_id1 = uuid.uuid4()
test_org_id2 = uuid.uuid4()
test_user_id1 = uuid.uuid4()
test_user_id2 = uuid.uuid4()
with session_maker() as session:
# Create test users
user1 = User(id=test_user_id1, current_org_id=test_org_id1)
user2 = User(id=test_user_id2, current_org_id=test_org_id2)
session.add_all([user1, user2])
session.commit()
# Test listing
with patch('storage.user_store.session_maker', session_maker):
users = UserStore.list_users()
assert len(users) >= 2
user_ids = [user.id for user in users]
assert test_user_id1 in user_ids
assert test_user_id2 in user_ids
def test_get_kwargs_from_settings():
# Test extracting user kwargs from settings
settings = Settings(
language='es',
enable_sound_notifications=True,
llm_api_key=SecretStr('test-key'),
)
kwargs = UserStore.get_kwargs_from_settings(settings)
# Should only include fields that exist in User model
assert 'language' in kwargs
assert 'enable_sound_notifications' in kwargs
# Should not include fields that don't exist in User model
assert 'llm_api_key' not in kwargs

View File

@@ -21,11 +21,23 @@ import logging
import uuid
from dataclasses import dataclass
from datetime import UTC, datetime
from typing import AsyncGenerator
from typing import TYPE_CHECKING, AsyncGenerator
from uuid import UUID
if TYPE_CHECKING:
from openhands.app_server.user.user_context import UserContext
from fastapi import Request
from sqlalchemy import Column, DateTime, Float, Integer, Select, String, func, select
from sqlalchemy import (
Column,
DateTime,
Float,
Integer,
Select,
String,
func,
select,
)
from sqlalchemy.ext.asyncio import AsyncSession
from openhands.agent_server.utils import utc_now
@@ -39,7 +51,6 @@ from openhands.app_server.app_conversation.app_conversation_models import (
AppConversationSortOrder,
)
from openhands.app_server.services.injector import InjectorState
from openhands.app_server.user.user_context import UserContext
from openhands.app_server.utils.sql_utils import (
Base,
create_json_type_decorator,
@@ -59,8 +70,6 @@ class StoredConversationMetadata(Base): # type: ignore
conversation_id = Column(
String, primary_key=True, default=lambda: str(uuid.uuid4())
)
github_user_id = Column(String, nullable=True) # The GitHub user ID
user_id = Column(String, nullable=False) # The Keycloak User ID
selected_repository = Column(String, nullable=True)
selected_branch = Column(String, nullable=True)
git_provider = Column(
@@ -188,10 +197,9 @@ class SQLAppConversationInfoService(AppConversationInfoService):
updated_at__lt: datetime | None = None,
) -> int:
"""Count sandboxed conversations matching the given filters."""
query = select(func.count(StoredConversationMetadata.conversation_id))
user_id = await self.user_context.get_user_id()
if user_id:
query = query.where(StoredConversationMetadata.user_id == user_id)
query = select(func.count(StoredConversationMetadata.conversation_id)).where(
StoredConversationMetadata.conversation_version == 'V1'
)
query = self._apply_filters(
query=query,
@@ -308,22 +316,11 @@ class SQLAppConversationInfoService(AppConversationInfoService):
async def save_app_conversation_info(
self, info: AppConversationInfo
) -> AppConversationInfo:
user_id = await self.user_context.get_user_id()
if user_id:
query = select(StoredConversationMetadata).where(
StoredConversationMetadata.conversation_id == str(info.id)
)
result = await self.db_session.execute(query)
existing = result.scalar_one_or_none()
assert existing is None or existing.user_id == user_id
metrics = info.metrics or MetricsSnapshot()
usage = metrics.accumulated_token_usage or TokenUsage()
stored = StoredConversationMetadata(
conversation_id=str(info.id),
github_user_id=None, # TODO: Should we add this to the conversation info?
user_id=info.created_by_user_id or '',
selected_repository=info.selected_repository,
selected_branch=info.selected_branch,
git_provider=info.git_provider.value if info.git_provider else None,
@@ -331,7 +328,7 @@ class SQLAppConversationInfoService(AppConversationInfoService):
last_updated_at=info.updated_at,
created_at=info.created_at,
trigger=info.trigger.value if info.trigger else None,
pr_number=info.pr_number,
pr_number=info.pr_number or [],
# Cost and token metrics
accumulated_cost=metrics.accumulated_cost,
prompt_tokens=usage.prompt_tokens,
@@ -484,11 +481,6 @@ class SQLAppConversationInfoService(AppConversationInfoService):
query = select(StoredConversationMetadata).where(
StoredConversationMetadata.conversation_version == 'V1'
)
user_id = await self.user_context.get_user_id()
if user_id:
query = query.where(
StoredConversationMetadata.user_id == user_id,
)
return query
def _to_info(
@@ -523,7 +515,7 @@ class SQLAppConversationInfoService(AppConversationInfoService):
return AppConversationInfo(
id=UUID(stored.conversation_id),
created_by_user_id=stored.user_id if stored.user_id else None,
created_by_user_id=None, # User ID is now stored in ConversationMetadataSaas
sandbox_id=stored.sandbox_id,
selected_repository=stored.selected_repository,
selected_branch=stored.selected_branch,
@@ -567,13 +559,6 @@ class SQLAppConversationInfoService(AppConversationInfoService):
StoredConversationMetadata.conversation_id == str(conversation_id)
)
# Apply user security filtering - only allow deletion of conversations owned by the current user
user_id = await self.user_context.get_user_id()
if user_id:
delete_query = delete_query.where(
StoredConversationMetadata.user_id == user_id
)
# Execute the secure delete query
result = await self.db_session.execute(delete_query)

View File

@@ -0,0 +1,46 @@
"""Update conversation_metadata table to match StoredConversationMetadata dataclass
Revision ID: 004
Revises: 003
Create Date: 2025-11-11 00:00:00.000000
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = '004'
down_revision: Union[str, Sequence[str], None] = '003'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema."""
with op.batch_alter_table('conversation_metadata') as batch_op:
# Drop columns not in StoredConversationMetadata dataclass
batch_op.drop_column('github_user_id')
# Alter user_id to become nullable
batch_op.alter_column(
'user_id',
existing_type=sa.String(),
nullable=True,
)
def downgrade() -> None:
"""Downgrade schema."""
with op.batch_alter_table('conversation_metadata') as batch_op:
# Add back removed column
batch_op.add_column(sa.Column('github_user_id', sa.String(), nullable=True))
# Restore NOT NULL constraint
batch_op.alter_column(
'user_id',
existing_type=sa.String(),
nullable=False,
)

View File

@@ -186,7 +186,20 @@ def config_from_env() -> AppServerConfig:
config.sandbox_spec = DockerSandboxSpecServiceInjector()
if config.app_conversation_info is None:
config.app_conversation_info = SQLAppConversationInfoServiceInjector()
# Use enterprise injector if running in SAAS mode
if 'saas' in (os.getenv('OPENHANDS_CONFIG_CLS') or '').lower():
try:
# Import enterprise injector dynamically
from enterprise.storage.saas_app_conversation_info_injector import (
SaasAppConversationInfoServiceInjector,
)
config.app_conversation_info = SaasAppConversationInfoServiceInjector()
except ImportError:
# Fallback to OSS injector if enterprise module is not available
config.app_conversation_info = SQLAppConversationInfoServiceInjector()
else:
config.app_conversation_info = SQLAppConversationInfoServiceInjector()
if config.app_conversation_start_task is None:
config.app_conversation_start_task = (

View File

@@ -153,15 +153,15 @@ class JwtService:
# Add standard JWT claims
now = utc_now()
if expires_in is None:
expires_in = timedelta(hours=1)
jwt_payload = {
**payload,
'iat': int(now.timestamp()),
'exp': int((now + expires_in).timestamp()),
}
# Only add exp if expires_in is provided
if expires_in is not None:
jwt_payload['exp'] = int((now + expires_in).timestamp())
# Get the raw key for JWE encryption and derive a 256-bit key
secret_key = self._keys[key_id].key.get_secret_value()
key_bytes = secret_key.encode() if isinstance(secret_key, str) else secret_key

View File

@@ -1,4 +1,5 @@
from openhands.events.event import Event, EventSource, RecallType
from openhands.events.event import Event, EventSource
from openhands.events.recall_type import RecallType
from openhands.events.stream import EventStream, EventStreamSubscriber
__all__ = [

View File

@@ -3,7 +3,7 @@ from typing import Any
from openhands.core.schema import ActionType
from openhands.events.action.action import Action
from openhands.events.event import RecallType
from openhands.events.recall_type import RecallType
@dataclass

View File

@@ -1,9 +1,12 @@
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from typing import TYPE_CHECKING
from openhands.events.tool import ToolCallMetadata
from openhands.llm.metrics import Metrics
if TYPE_CHECKING:
from openhands.llm.metrics import Metrics
class EventSource(str, Enum):
@@ -22,16 +25,6 @@ class FileReadSource(str, Enum):
DEFAULT = 'default'
class RecallType(str, Enum):
"""The type of information that can be retrieved from microagents."""
WORKSPACE_CONTEXT = 'workspace_context'
"""Workspace context (repo instructions, runtime, etc.)"""
KNOWLEDGE = 'knowledge'
"""A knowledge microagent."""
@dataclass
class Event:
INVALID_ID = -1
@@ -97,14 +90,17 @@ class Event:
# optional metadata, LLM call cost of the edit
@property
def llm_metrics(self) -> Metrics | None:
def llm_metrics(self) -> 'Metrics | None':
if hasattr(self, '_llm_metrics'):
metrics = getattr(self, '_llm_metrics')
# Lazy import to avoid circular dependency
from openhands.llm.metrics import Metrics
return metrics if isinstance(metrics, Metrics) else None
return None
@llm_metrics.setter
def llm_metrics(self, value: Metrics) -> None:
def llm_metrics(self, value: 'Metrics') -> None:
self._llm_metrics = value
# optional field, metadata about the tool call, if the event has a tool call

View File

@@ -1,4 +1,3 @@
from openhands.events.event import RecallType
from openhands.events.observation.agent import (
AgentCondensationObservation,
AgentStateChangedObservation,
@@ -28,6 +27,7 @@ from openhands.events.observation.observation import Observation
from openhands.events.observation.reject import UserRejectObservation
from openhands.events.observation.success import SuccessObservation
from openhands.events.observation.task_tracking import TaskTrackingObservation
from openhands.events.recall_type import RecallType
__all__ = [
'Observation',

View File

@@ -1,8 +1,8 @@
from dataclasses import dataclass, field
from openhands.core.schema import ObservationType
from openhands.events.event import RecallType
from openhands.events.observation.observation import Observation
from openhands.events.recall_type import RecallType
@dataclass

View File

@@ -0,0 +1,11 @@
from enum import Enum
class RecallType(str, Enum):
"""The type of information that can be retrieved from microagents."""
WORKSPACE_CONTEXT = 'workspace_context'
"""Workspace context (repo instructions, runtime, etc.)"""
KNOWLEDGE = 'knowledge'
"""A knowledge microagent."""

View File

@@ -10,7 +10,6 @@ from openhands.events.serialization.action import action_from_dict
from openhands.events.serialization.observation import observation_from_dict
from openhands.events.serialization.utils import remove_fields
from openhands.events.tool import ToolCallMetadata
from openhands.llm.metrics import Cost, Metrics, ResponseLatency, TokenUsage
# TODO: move `content` into `extras`
TOP_KEYS = [
@@ -67,6 +66,14 @@ def event_from_dict(data: dict[str, Any]) -> 'Event':
if key == 'tool_call_metadata':
value = ToolCallMetadata(**value)
if key == 'llm_metrics':
# Lazy import to avoid circular dependency
from openhands.llm.metrics import (
Cost,
Metrics,
ResponseLatency,
TokenUsage,
)
metrics = Metrics()
if isinstance(value, dict):
metrics.accumulated_cost = value.get('accumulated_cost', 0.0)

View File

@@ -1,7 +1,6 @@
import copy
from typing import Any
from openhands.events.event import RecallType
from openhands.events.observation.agent import (
AgentCondensationObservation,
AgentStateChangedObservation,
@@ -32,6 +31,7 @@ from openhands.events.observation.observation import Observation
from openhands.events.observation.reject import UserRejectObservation
from openhands.events.observation.success import SuccessObservation
from openhands.events.observation.task_tracking import TaskTrackingObservation
from openhands.events.recall_type import RecallType
observations = (
NullObservation,

View File

@@ -22,7 +22,7 @@ from openhands.events.action import (
)
from openhands.events.action.mcp import MCPAction
from openhands.events.action.message import SystemMessageAction
from openhands.events.event import Event, RecallType
from openhands.events.event import Event
from openhands.events.observation import (
AgentCondensationObservation,
AgentDelegateObservation,
@@ -44,6 +44,7 @@ from openhands.events.observation.agent import (
from openhands.events.observation.error import ErrorObservation
from openhands.events.observation.mcp import MCPObservation
from openhands.events.observation.observation import Observation
from openhands.events.recall_type import RecallType
from openhands.events.serialization.event import truncate_content
from openhands.utils.prompt import (
ConversationInstructions,

View File

@@ -9,12 +9,13 @@ import openhands
from openhands.core.config.mcp_config import MCPConfig
from openhands.core.logger import openhands_logger as logger
from openhands.events.action.agent import RecallAction
from openhands.events.event import Event, EventSource, RecallType
from openhands.events.event import Event, EventSource
from openhands.events.observation.agent import (
MicroagentKnowledge,
RecallObservation,
)
from openhands.events.observation.empty import NullObservation
from openhands.events.recall_type import RecallType
from openhands.events.stream import EventStream, EventStreamSubscriber
from openhands.microagent import (
BaseMicroagent,

View File

@@ -151,7 +151,7 @@ class AgentSession:
await provider_handler.set_event_stream_secrets(self.event_stream)
if custom_secrets:
custom_secrets_handler.set_event_stream_secrets(self.event_stream)
self.event_stream.set_secrets(custom_secrets_handler.get_env_vars())
self.memory = await self._create_memory(
selected_repository=selected_repository,

View File

@@ -13,7 +13,6 @@ from pydantic import (
)
from pydantic.json import pydantic_encoder
from openhands.events.stream import EventStream
from openhands.integrations.provider import (
CUSTOM_SECRETS_TYPE,
PROVIDER_TOKEN_TYPE,
@@ -144,14 +143,6 @@ class Secrets(BaseModel):
return new_data
def set_event_stream_secrets(self, event_stream: EventStream) -> None:
"""This ensures that provider tokens and custom secrets masked from the event stream
Args:
event_stream: Agent session's event stream
"""
secrets = self.get_env_vars()
event_stream.set_secrets(secrets)
def get_env_vars(self) -> dict[str, str]:
secret_store = self.model_dump(context={'expose_secrets': True})
custom_secrets = secret_store.get('custom_secrets', {})

2
poetry.lock generated
View File

@@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 2.1.3 and should not be changed by hand.
# This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand.
[[package]]
name = "aiofiles"

View File

@@ -199,9 +199,8 @@ class TestJwtService:
# Check that standard JWT claims are added
assert 'iat' in decrypted_payload
assert 'exp' in decrypted_payload
assert 'exp' not in decrypted_payload
assert isinstance(decrypted_payload['iat'], int) # JWE uses timestamp integers
assert isinstance(decrypted_payload['exp'], int)
def test_jwe_token_round_trip_specific_key(self, jwt_service):
"""Test JWE token creation and decryption with specific key."""
@@ -420,7 +419,7 @@ class TestJwtService:
# Should still have standard claims
assert 'iat' in jwe_decrypted
assert 'exp' in jwe_decrypted
assert 'exp' not in jwe_decrypted
def test_unicode_and_special_characters(self, jwt_service):
"""Test JWS and JWE with unicode and special characters."""

View File

@@ -27,6 +27,8 @@ from openhands.sdk.llm import MetricsSnapshot
from openhands.sdk.llm.utils.metrics import TokenUsage
from openhands.storage.data_models.conversation_metadata import ConversationTrigger
# Note: org_id column exists but foreign key constraint is not enforced in tests
# Note: MetricsSnapshot from SDK is not available in test environment
# We'll use None for metrics field in tests since it's optional
@@ -106,7 +108,7 @@ def multiple_conversation_infos() -> list[AppConversationInfo]:
return [
AppConversationInfo(
id=uuid4(),
created_by_user_id='test_user_123',
created_by_user_id=None,
sandbox_id=f'sandbox_{i}',
selected_repository=f'https://github.com/test/repo{i}',
selected_branch='main',
@@ -151,10 +153,6 @@ class TestSQLAppConversationInfoService:
# Verify the retrieved info matches the original
assert retrieved_info is not None
assert retrieved_info.id == sample_conversation_info.id
assert (
retrieved_info.created_by_user_id
== sample_conversation_info.created_by_user_id
)
assert retrieved_info.sandbox_id == sample_conversation_info.sandbox_id
assert (
retrieved_info.selected_repository
@@ -206,7 +204,6 @@ class TestSQLAppConversationInfoService:
# Verify all fields
assert retrieved_info is not None
assert retrieved_info.id == original_info.id
assert retrieved_info.created_by_user_id == original_info.created_by_user_id
assert retrieved_info.sandbox_id == original_info.sandbox_id
assert retrieved_info.selected_repository == original_info.selected_repository
assert retrieved_info.selected_branch == original_info.selected_branch
@@ -235,7 +232,6 @@ class TestSQLAppConversationInfoService:
# Verify required fields
assert retrieved_info is not None
assert retrieved_info.id == minimal_info.id
assert retrieved_info.created_by_user_id == minimal_info.created_by_user_id
assert retrieved_info.sandbox_id == minimal_info.sandbox_id
# Verify optional fields are None or default values
@@ -486,58 +482,6 @@ class TestSQLAppConversationInfoService:
count = await service.count_app_conversation_info(title__contains='nonexistent')
assert count == 0
@pytest.mark.asyncio
async def test_user_isolation(
self,
async_session: AsyncSession,
multiple_conversation_infos: list[AppConversationInfo],
):
"""Test that user isolation works correctly."""
# Create services for different users
user1_service = SQLAppConversationInfoService(
db_session=async_session, user_context=SpecifyUserContext(user_id='user1')
)
user2_service = SQLAppConversationInfoService(
db_session=async_session, user_context=SpecifyUserContext(user_id='user2')
)
# Create conversations for different users
user1_info = AppConversationInfo(
id=uuid4(),
created_by_user_id='user1',
sandbox_id='sandbox_user1',
title='User 1 Conversation',
)
user2_info = AppConversationInfo(
id=uuid4(),
created_by_user_id='user2',
sandbox_id='sandbox_user2',
title='User 2 Conversation',
)
# Save conversations
await user1_service.save_app_conversation_info(user1_info)
await user2_service.save_app_conversation_info(user2_info)
# User 1 should only see their conversation
user1_page = await user1_service.search_app_conversation_info()
assert len(user1_page.items) == 1
assert user1_page.items[0].created_by_user_id == 'user1'
# User 2 should only see their conversation
user2_page = await user2_service.search_app_conversation_info()
assert len(user2_page.items) == 1
assert user2_page.items[0].created_by_user_id == 'user2'
# User 1 should not be able to get user 2's conversation
user2_from_user1 = await user1_service.get_app_conversation_info(user2_info.id)
assert user2_from_user1 is None
# User 2 should not be able to get user 1's conversation
user1_from_user2 = await user2_service.get_app_conversation_info(user1_info.id)
assert user1_from_user2 is None
@pytest.mark.asyncio
async def test_update_conversation_info(
self,
@@ -567,10 +511,6 @@ class TestSQLAppConversationInfoService:
assert retrieved_info.pr_number == [789]
# Verify other fields remain unchanged
assert (
retrieved_info.created_by_user_id
== sample_conversation_info.created_by_user_id
)
assert retrieved_info.sandbox_id == sample_conversation_info.sandbox_id
@pytest.mark.asyncio

View File

@@ -75,7 +75,6 @@ async def v1_conversation_metadata(async_session, service):
conversation_id = uuid4()
stored = StoredConversationMetadata(
conversation_id=str(conversation_id),
user_id='test_user_123',
sandbox_id='sandbox_123',
conversation_version='V1',
title='Test Conversation',
@@ -267,7 +266,6 @@ class TestUpdateConversationStatistics:
conversation_id = uuid4()
stored = StoredConversationMetadata(
conversation_id=str(conversation_id),
user_id='test_user_123',
sandbox_id='sandbox_123',
conversation_version='V0', # V0 conversation
title='V0 Conversation',

View File

@@ -25,13 +25,13 @@ from openhands.events import Event, EventSource, EventStream, EventStreamSubscri
from openhands.events.action import ChangeAgentStateAction, CmdRunAction, MessageAction
from openhands.events.action.agent import CondensationAction, RecallAction
from openhands.events.action.message import SystemMessageAction
from openhands.events.event import RecallType
from openhands.events.observation import (
AgentStateChangedObservation,
ErrorObservation,
)
from openhands.events.observation.agent import RecallObservation
from openhands.events.observation.empty import NullObservation
from openhands.events.recall_type import RecallType
from openhands.events.serialization import event_to_dict
from openhands.llm import LLM
from openhands.llm.llm_registry import LLMRegistry, RegistryEvent

View File

@@ -25,8 +25,9 @@ from openhands.events.action import (
from openhands.events.action.agent import RecallAction
from openhands.events.action.commands import CmdRunAction
from openhands.events.action.message import SystemMessageAction
from openhands.events.event import Event, RecallType
from openhands.events.event import Event
from openhands.events.observation.agent import RecallObservation
from openhands.events.recall_type import RecallType
from openhands.events.stream import EventStreamSubscriber
from openhands.llm.llm import LLM
from openhands.llm.llm_registry import LLMRegistry

View File

@@ -1,6 +1,5 @@
from openhands.core.schema.observation import ObservationType
from openhands.events.action.files import FileEditSource
from openhands.events.event import RecallType
from openhands.events.observation import (
CmdOutputMetadata,
CmdOutputObservation,
@@ -10,6 +9,7 @@ from openhands.events.observation import (
)
from openhands.events.observation.agent import MicroagentKnowledge
from openhands.events.observation.commands import MAX_CMD_OUTPUT_SIZE
from openhands.events.recall_type import RecallType
from openhands.events.serialization import (
event_from_dict,
event_to_dict,

View File

@@ -19,11 +19,11 @@ from openhands.events import EventSource
from openhands.events.action import CmdRunAction, MessageAction, RecallAction
from openhands.events.action.agent import CondensationAction
from openhands.events.action.message import SystemMessageAction
from openhands.events.event import RecallType
from openhands.events.observation import (
CmdOutputObservation,
RecallObservation,
)
from openhands.events.recall_type import RecallType
from openhands.memory.condenser.condenser import Condensation, View
from openhands.memory.condenser.impl.conversation_window_condenser import (
ConversationWindowCondenser,

View File

@@ -19,7 +19,6 @@ from openhands.events.event import (
EventSource,
FileEditSource,
FileReadSource,
RecallType,
)
from openhands.events.observation import CmdOutputObservation
from openhands.events.observation.agent import (
@@ -35,6 +34,7 @@ from openhands.events.observation.delegate import AgentDelegateObservation
from openhands.events.observation.error import ErrorObservation
from openhands.events.observation.files import FileEditObservation, FileReadObservation
from openhands.events.observation.reject import UserRejectObservation
from openhands.events.recall_type import RecallType
from openhands.events.tool import ToolCallMetadata
from openhands.memory.conversation_memory import ConversationMemory
from openhands.utils.prompt import PromptManager, RepositoryInfo, RuntimeInfo

View File

@@ -16,8 +16,8 @@ from openhands.events.action.message import MessageAction, SystemMessageAction
from openhands.events.event import EventSource
from openhands.events.observation.agent import (
RecallObservation,
RecallType,
)
from openhands.events.recall_type import RecallType
from openhands.events.serialization.observation import observation_from_dict
from openhands.events.stream import EventStream
from openhands.llm import LLM