Add org_id use in queries

This commit is contained in:
Chuck Butkus
2025-11-04 15:19:41 -05:00
parent d61b47a134
commit 69186bc6c8
6 changed files with 57 additions and 31 deletions

View File

@@ -9,6 +9,7 @@ from sqlalchemy import update
from sqlalchemy.orm import sessionmaker
from storage.api_key import ApiKey
from storage.database import session_maker
from storage.user_store import UserStore
from openhands.core.logger import openhands_logger as logger
@@ -36,10 +37,15 @@ class ApiKeyStore:
The generated API key
"""
api_key = self.generate_api_key()
user = UserStore.get_user_by_id(user_id)
org_id = user.current_org_id
with self.session_maker() as session:
key_record = ApiKey(
key=api_key, user_id=user_id, name=name, expires_at=expires_at
key=api_key,
user_id=user_id,
org_id=org_id,
name=name,
expires_at=expires_at,
)
session.add(key_record)
session.commit()
@@ -99,8 +105,15 @@ class ApiKeyStore:
def list_api_keys(self, user_id: str) -> list[dict]:
"""List all API keys for a user."""
user = UserStore.get_user_by_id(user_id)
org_id = user.current_org_id
with self.session_maker() as session:
keys = session.query(ApiKey).filter(ApiKey.user_id == user_id).all()
keys = (
session.query(ApiKey)
.filter(ApiKey.user_id == user_id)
.filter(ApiKey.org_id == org_id)
.all()
)
return [
{
@@ -115,9 +128,14 @@ class ApiKeyStore:
]
def retrieve_mcp_api_key(self, user_id: str) -> str | None:
user = UserStore.get_user_by_id(user_id)
org_id = user.current_org_id
with self.session_maker() as session:
keys: list[ApiKey] = (
session.query(ApiKey).filter(ApiKey.user_id == user_id).all()
session.query(ApiKey)
.filter(ApiKey.user_id == user_id)
.filter(ApiKey.org_id == org_id)
.all()
)
for key in keys:
if key.name == 'MCP_API_KEY':

View File

@@ -2,7 +2,6 @@
Store class for managing organizations.
"""
import uuid
from typing import Optional
from uuid import UUID
@@ -45,7 +44,7 @@ class OrgStore:
user = (
session.query(User)
.options(joinedload(User.org_members))
.filter(User.id == uuid.UUID(keycloak_user_id))
.filter(User.id == UUID(keycloak_user_id))
.first()
)
if not user:

View File

@@ -8,6 +8,7 @@ from fastapi import Request
from sqlalchemy import func, select
from storage.stored_conversation_metadata import StoredConversationMetadata
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
from storage.user import User
from openhands.app_server.app_conversation.app_conversation_info_service import (
AppConversationInfoService,
@@ -273,6 +274,10 @@ class SaasSQLAppConversationInfoService(SQLAppConversationInfoService):
if user_id_str:
# Convert string user_id to UUID
user_id_uuid = UUID(user_id_str)
user_query = select(User).where(User.id == user_id_uuid)
result = await self.db_session.execute(user_query)
user = result.scalar_one_or_none()
assert user
# Check if SAAS metadata already exists
saas_query = select(StoredConversationMetadataSaas).where(
@@ -282,7 +287,7 @@ class SaasSQLAppConversationInfoService(SQLAppConversationInfoService):
existing_saas_metadata = result.scalar_one_or_none()
assert existing_saas_metadata is None or (
existing_saas_metadata.user_id == user_id_uuid
and existing_saas_metadata.org_id == user_id_uuid
and existing_saas_metadata.org_id == user.current_org_id
)
if not existing_saas_metadata:
@@ -291,7 +296,7 @@ class SaasSQLAppConversationInfoService(SQLAppConversationInfoService):
saas_metadata = StoredConversationMetadataSaas(
conversation_id=str(info.id),
user_id=user_id_uuid,
org_id=user_id_uuid, # Set org_id to user_id as it will not be specified
org_id=user.current_org_id,
)
self.db_session.add(saas_metadata)

View File

@@ -8,6 +8,7 @@ from cryptography.fernet import Fernet
from sqlalchemy.orm import sessionmaker
from storage.database import session_maker
from storage.stored_custom_secrets import StoredCustomSecrets
from storage.user_store import UserStore
from openhands.core.config.openhands_config import OpenHandsConfig
from openhands.core.logger import openhands_logger as logger
@@ -24,12 +25,15 @@ class SaasSecretsStore(SecretsStore):
async def load(self) -> Secrets | None:
if not self.user_id:
return None
user = UserStore.get_user_by_id(self.user_id)
org_id = user.current_org_id
with self.session_maker() as session:
# Fetch all secrets for the given user ID
settings = (
session.query(StoredCustomSecrets)
.filter(StoredCustomSecrets.keycloak_user_id == self.user_id)
.filter(StoredCustomSecrets.org_id == org_id)
.all()
)
@@ -48,6 +52,8 @@ class SaasSecretsStore(SecretsStore):
return Secrets(custom_secrets=kwargs) # type: ignore[arg-type]
async def store(self, item: Secrets):
user = UserStore.get_user_by_id(self.user_id)
org_id = user.current_org_id
with self.session_maker() as session:
# Incoming secrets are always the most updated ones
# Delete all existing records and override with incoming ones
@@ -76,6 +82,7 @@ class SaasSecretsStore(SecretsStore):
for secret_name, secret_value, description in secret_tuples:
new_secret = StoredCustomSecrets(
keycloak_user_id=self.user_id,
org_id=org_id,
secret_name=secret_name,
secret_value=secret_value,
description=description,

View File

@@ -27,7 +27,7 @@ class UserStore:
@staticmethod
async def create_user(
keycloak_user_id: str,
user_id: str,
user_info: dict,
role_id: Optional[int] = None,
) -> User | None:
@@ -35,15 +35,15 @@ class UserStore:
with session_maker() as session:
# create personal org
org = Org(
id=uuid.UUID(keycloak_user_id),
name=f'user_{keycloak_user_id}_org',
id=uuid.UUID(user_id),
name=f'user_{user_id}_org',
contact_name=user_info['preferred_username'],
contact_email=user_info['email'],
)
session.add(org)
settings = await UserStore.create_default_settings(
org_id=str(org.id), keycloak_user_id=keycloak_user_id
org_id=str(org.id), user_id=user_id
)
if not settings:
@@ -56,7 +56,7 @@ class UserStore:
user_kwargs = UserStore.get_kwargs_from_settings(settings)
user = User(
id=uuid.UUID(keycloak_user_id),
id=uuid.UUID(user_id),
current_org_id=org.id,
role_id=role_id,
**user_kwargs,
@@ -80,17 +80,17 @@ class UserStore:
@staticmethod
async def migrate_user(
keycloak_user_id: str,
user_id: str,
user_settings: UserSettings,
user_info: dict,
) -> User:
if not keycloak_user_id or not user_settings:
if not user_id or not user_settings:
return None
# Check if user is already migrated to prevent double migration
if user_settings.migration_status is True:
logger.warning(f'User {keycloak_user_id} already migrated, skipping')
return UserStore.get_user_by_id(keycloak_user_id)
logger.warning(f'User {user_id} already migrated, skipping')
return UserStore.get_user_by_id(user_id)
kwargs = decrypt_model(
[
'llm_api_key',
@@ -104,18 +104,18 @@ class UserStore:
with session_maker() as session:
# create personal org
org = Org(
id=uuid.UUID(keycloak_user_id),
name=f'user_{keycloak_user_id}_org',
id=uuid.UUID(user_id),
name=f'user_{user_id}_org',
contact_name=user_info['preferred_username'],
contact_email=user_info['email'],
)
session.add(org)
await LiteLlmManager.migrate_entries(
str(org.id), keycloak_user_id, decrypted_user_settings
str(org.id), user_id, decrypted_user_settings
)
await migrate_customer(session, keycloak_user_id, org)
await migrate_customer(session, user_id, org)
org_kwargs = OrgStore.get_kwargs_from_settings(decrypted_user_settings)
org_kwargs.pop('id', None)
@@ -126,7 +126,7 @@ class UserStore:
user_kwargs = UserStore.get_kwargs_from_settings(decrypted_user_settings)
user_kwargs.pop('id', None)
user = User(
id=uuid.UUID(keycloak_user_id),
id=uuid.UUID(user_id),
current_org_id=org.id,
role_id=None,
**user_kwargs,
@@ -178,13 +178,13 @@ class UserStore:
return user
@staticmethod
def get_user_by_id(keycloak_user_id: str) -> Optional[User]:
def get_user_by_id(user_id: str) -> Optional[User]:
"""Get user by Keycloak user ID."""
with session_maker() as session:
return (
session.query(User)
.options(joinedload(User.org_members))
.filter(User.id == uuid.UUID(keycloak_user_id))
.filter(User.id == uuid.UUID(user_id))
.first()
)
@@ -195,12 +195,10 @@ class UserStore:
return session.query(User).all()
@staticmethod
async def create_default_settings(
org_id: str, keycloak_user_id: str
) -> Optional[Settings]:
async def create_default_settings(org_id: str, user_id: str) -> Optional[Settings]:
logger.info(
'UserStore:create_default_settings:start',
extra={'org_id': org_id, 'user_id': keycloak_user_id},
extra={'org_id': org_id, 'user_id': user_id},
)
# You must log in before you get default settings
if not org_id:
@@ -208,9 +206,7 @@ class UserStore:
settings = Settings(language='en', enable_proactive_conversation_starters=True)
settings = await LiteLlmManager.create_entries(
org_id, keycloak_user_id, settings
)
settings = await LiteLlmManager.create_entries(org_id, user_id, settings)
if not settings:
logger.info(
'UserStore:create_default_settings:litellm_create_failed',

View File

@@ -228,6 +228,7 @@ class TestSaasSQLAppConversationInfoService:
saas_metadata = MagicMock(spec=StoredConversationMetadataSaas)
saas_metadata.user_id = UUID('a1111111-1111-1111-1111-111111111111')
saas_metadata.org_id = UUID('a1111111-1111-1111-1111-111111111111')
# Test the _to_info_with_user_id method
result = saas_service_user1._to_info_with_user_id(