mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-04-29 03:00:45 -04:00
Add org_id use in queries
This commit is contained in:
@@ -9,6 +9,7 @@ from sqlalchemy import update
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from storage.api_key import ApiKey
|
||||
from storage.database import session_maker
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
@@ -36,10 +37,15 @@ class ApiKeyStore:
|
||||
The generated API key
|
||||
"""
|
||||
api_key = self.generate_api_key()
|
||||
|
||||
user = UserStore.get_user_by_id(user_id)
|
||||
org_id = user.current_org_id
|
||||
with self.session_maker() as session:
|
||||
key_record = ApiKey(
|
||||
key=api_key, user_id=user_id, name=name, expires_at=expires_at
|
||||
key=api_key,
|
||||
user_id=user_id,
|
||||
org_id=org_id,
|
||||
name=name,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
session.add(key_record)
|
||||
session.commit()
|
||||
@@ -99,8 +105,15 @@ class ApiKeyStore:
|
||||
|
||||
def list_api_keys(self, user_id: str) -> list[dict]:
|
||||
"""List all API keys for a user."""
|
||||
user = UserStore.get_user_by_id(user_id)
|
||||
org_id = user.current_org_id
|
||||
with self.session_maker() as session:
|
||||
keys = session.query(ApiKey).filter(ApiKey.user_id == user_id).all()
|
||||
keys = (
|
||||
session.query(ApiKey)
|
||||
.filter(ApiKey.user_id == user_id)
|
||||
.filter(ApiKey.org_id == org_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
return [
|
||||
{
|
||||
@@ -115,9 +128,14 @@ class ApiKeyStore:
|
||||
]
|
||||
|
||||
def retrieve_mcp_api_key(self, user_id: str) -> str | None:
|
||||
user = UserStore.get_user_by_id(user_id)
|
||||
org_id = user.current_org_id
|
||||
with self.session_maker() as session:
|
||||
keys: list[ApiKey] = (
|
||||
session.query(ApiKey).filter(ApiKey.user_id == user_id).all()
|
||||
session.query(ApiKey)
|
||||
.filter(ApiKey.user_id == user_id)
|
||||
.filter(ApiKey.org_id == org_id)
|
||||
.all()
|
||||
)
|
||||
for key in keys:
|
||||
if key.name == 'MCP_API_KEY':
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
Store class for managing organizations.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
@@ -45,7 +44,7 @@ class OrgStore:
|
||||
user = (
|
||||
session.query(User)
|
||||
.options(joinedload(User.org_members))
|
||||
.filter(User.id == uuid.UUID(keycloak_user_id))
|
||||
.filter(User.id == UUID(keycloak_user_id))
|
||||
.first()
|
||||
)
|
||||
if not user:
|
||||
|
||||
@@ -8,6 +8,7 @@ from fastapi import Request
|
||||
from sqlalchemy import func, select
|
||||
from storage.stored_conversation_metadata import StoredConversationMetadata
|
||||
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
|
||||
from storage.user import User
|
||||
|
||||
from openhands.app_server.app_conversation.app_conversation_info_service import (
|
||||
AppConversationInfoService,
|
||||
@@ -273,6 +274,10 @@ class SaasSQLAppConversationInfoService(SQLAppConversationInfoService):
|
||||
if user_id_str:
|
||||
# Convert string user_id to UUID
|
||||
user_id_uuid = UUID(user_id_str)
|
||||
user_query = select(User).where(User.id == user_id_uuid)
|
||||
result = await self.db_session.execute(user_query)
|
||||
user = result.scalar_one_or_none()
|
||||
assert user
|
||||
|
||||
# Check if SAAS metadata already exists
|
||||
saas_query = select(StoredConversationMetadataSaas).where(
|
||||
@@ -282,7 +287,7 @@ class SaasSQLAppConversationInfoService(SQLAppConversationInfoService):
|
||||
existing_saas_metadata = result.scalar_one_or_none()
|
||||
assert existing_saas_metadata is None or (
|
||||
existing_saas_metadata.user_id == user_id_uuid
|
||||
and existing_saas_metadata.org_id == user_id_uuid
|
||||
and existing_saas_metadata.org_id == user.current_org_id
|
||||
)
|
||||
|
||||
if not existing_saas_metadata:
|
||||
@@ -291,7 +296,7 @@ class SaasSQLAppConversationInfoService(SQLAppConversationInfoService):
|
||||
saas_metadata = StoredConversationMetadataSaas(
|
||||
conversation_id=str(info.id),
|
||||
user_id=user_id_uuid,
|
||||
org_id=user_id_uuid, # Set org_id to user_id as it will not be specified
|
||||
org_id=user.current_org_id,
|
||||
)
|
||||
self.db_session.add(saas_metadata)
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ from cryptography.fernet import Fernet
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from storage.database import session_maker
|
||||
from storage.stored_custom_secrets import StoredCustomSecrets
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
@@ -24,12 +25,15 @@ class SaasSecretsStore(SecretsStore):
|
||||
async def load(self) -> Secrets | None:
|
||||
if not self.user_id:
|
||||
return None
|
||||
user = UserStore.get_user_by_id(self.user_id)
|
||||
org_id = user.current_org_id
|
||||
|
||||
with self.session_maker() as session:
|
||||
# Fetch all secrets for the given user ID
|
||||
settings = (
|
||||
session.query(StoredCustomSecrets)
|
||||
.filter(StoredCustomSecrets.keycloak_user_id == self.user_id)
|
||||
.filter(StoredCustomSecrets.org_id == org_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
@@ -48,6 +52,8 @@ class SaasSecretsStore(SecretsStore):
|
||||
return Secrets(custom_secrets=kwargs) # type: ignore[arg-type]
|
||||
|
||||
async def store(self, item: Secrets):
|
||||
user = UserStore.get_user_by_id(self.user_id)
|
||||
org_id = user.current_org_id
|
||||
with self.session_maker() as session:
|
||||
# Incoming secrets are always the most updated ones
|
||||
# Delete all existing records and override with incoming ones
|
||||
@@ -76,6 +82,7 @@ class SaasSecretsStore(SecretsStore):
|
||||
for secret_name, secret_value, description in secret_tuples:
|
||||
new_secret = StoredCustomSecrets(
|
||||
keycloak_user_id=self.user_id,
|
||||
org_id=org_id,
|
||||
secret_name=secret_name,
|
||||
secret_value=secret_value,
|
||||
description=description,
|
||||
|
||||
@@ -27,7 +27,7 @@ class UserStore:
|
||||
|
||||
@staticmethod
|
||||
async def create_user(
|
||||
keycloak_user_id: str,
|
||||
user_id: str,
|
||||
user_info: dict,
|
||||
role_id: Optional[int] = None,
|
||||
) -> User | None:
|
||||
@@ -35,15 +35,15 @@ class UserStore:
|
||||
with session_maker() as session:
|
||||
# create personal org
|
||||
org = Org(
|
||||
id=uuid.UUID(keycloak_user_id),
|
||||
name=f'user_{keycloak_user_id}_org',
|
||||
id=uuid.UUID(user_id),
|
||||
name=f'user_{user_id}_org',
|
||||
contact_name=user_info['preferred_username'],
|
||||
contact_email=user_info['email'],
|
||||
)
|
||||
session.add(org)
|
||||
|
||||
settings = await UserStore.create_default_settings(
|
||||
org_id=str(org.id), keycloak_user_id=keycloak_user_id
|
||||
org_id=str(org.id), user_id=user_id
|
||||
)
|
||||
|
||||
if not settings:
|
||||
@@ -56,7 +56,7 @@ class UserStore:
|
||||
|
||||
user_kwargs = UserStore.get_kwargs_from_settings(settings)
|
||||
user = User(
|
||||
id=uuid.UUID(keycloak_user_id),
|
||||
id=uuid.UUID(user_id),
|
||||
current_org_id=org.id,
|
||||
role_id=role_id,
|
||||
**user_kwargs,
|
||||
@@ -80,17 +80,17 @@ class UserStore:
|
||||
|
||||
@staticmethod
|
||||
async def migrate_user(
|
||||
keycloak_user_id: str,
|
||||
user_id: str,
|
||||
user_settings: UserSettings,
|
||||
user_info: dict,
|
||||
) -> User:
|
||||
if not keycloak_user_id or not user_settings:
|
||||
if not user_id or not user_settings:
|
||||
return None
|
||||
|
||||
# Check if user is already migrated to prevent double migration
|
||||
if user_settings.migration_status is True:
|
||||
logger.warning(f'User {keycloak_user_id} already migrated, skipping')
|
||||
return UserStore.get_user_by_id(keycloak_user_id)
|
||||
logger.warning(f'User {user_id} already migrated, skipping')
|
||||
return UserStore.get_user_by_id(user_id)
|
||||
kwargs = decrypt_model(
|
||||
[
|
||||
'llm_api_key',
|
||||
@@ -104,18 +104,18 @@ class UserStore:
|
||||
with session_maker() as session:
|
||||
# create personal org
|
||||
org = Org(
|
||||
id=uuid.UUID(keycloak_user_id),
|
||||
name=f'user_{keycloak_user_id}_org',
|
||||
id=uuid.UUID(user_id),
|
||||
name=f'user_{user_id}_org',
|
||||
contact_name=user_info['preferred_username'],
|
||||
contact_email=user_info['email'],
|
||||
)
|
||||
session.add(org)
|
||||
|
||||
await LiteLlmManager.migrate_entries(
|
||||
str(org.id), keycloak_user_id, decrypted_user_settings
|
||||
str(org.id), user_id, decrypted_user_settings
|
||||
)
|
||||
|
||||
await migrate_customer(session, keycloak_user_id, org)
|
||||
await migrate_customer(session, user_id, org)
|
||||
|
||||
org_kwargs = OrgStore.get_kwargs_from_settings(decrypted_user_settings)
|
||||
org_kwargs.pop('id', None)
|
||||
@@ -126,7 +126,7 @@ class UserStore:
|
||||
user_kwargs = UserStore.get_kwargs_from_settings(decrypted_user_settings)
|
||||
user_kwargs.pop('id', None)
|
||||
user = User(
|
||||
id=uuid.UUID(keycloak_user_id),
|
||||
id=uuid.UUID(user_id),
|
||||
current_org_id=org.id,
|
||||
role_id=None,
|
||||
**user_kwargs,
|
||||
@@ -178,13 +178,13 @@ class UserStore:
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
def get_user_by_id(keycloak_user_id: str) -> Optional[User]:
|
||||
def get_user_by_id(user_id: str) -> Optional[User]:
|
||||
"""Get user by Keycloak user ID."""
|
||||
with session_maker() as session:
|
||||
return (
|
||||
session.query(User)
|
||||
.options(joinedload(User.org_members))
|
||||
.filter(User.id == uuid.UUID(keycloak_user_id))
|
||||
.filter(User.id == uuid.UUID(user_id))
|
||||
.first()
|
||||
)
|
||||
|
||||
@@ -195,12 +195,10 @@ class UserStore:
|
||||
return session.query(User).all()
|
||||
|
||||
@staticmethod
|
||||
async def create_default_settings(
|
||||
org_id: str, keycloak_user_id: str
|
||||
) -> Optional[Settings]:
|
||||
async def create_default_settings(org_id: str, user_id: str) -> Optional[Settings]:
|
||||
logger.info(
|
||||
'UserStore:create_default_settings:start',
|
||||
extra={'org_id': org_id, 'user_id': keycloak_user_id},
|
||||
extra={'org_id': org_id, 'user_id': user_id},
|
||||
)
|
||||
# You must log in before you get default settings
|
||||
if not org_id:
|
||||
@@ -208,9 +206,7 @@ class UserStore:
|
||||
|
||||
settings = Settings(language='en', enable_proactive_conversation_starters=True)
|
||||
|
||||
settings = await LiteLlmManager.create_entries(
|
||||
org_id, keycloak_user_id, settings
|
||||
)
|
||||
settings = await LiteLlmManager.create_entries(org_id, user_id, settings)
|
||||
if not settings:
|
||||
logger.info(
|
||||
'UserStore:create_default_settings:litellm_create_failed',
|
||||
|
||||
@@ -228,6 +228,7 @@ class TestSaasSQLAppConversationInfoService:
|
||||
|
||||
saas_metadata = MagicMock(spec=StoredConversationMetadataSaas)
|
||||
saas_metadata.user_id = UUID('a1111111-1111-1111-1111-111111111111')
|
||||
saas_metadata.org_id = UUID('a1111111-1111-1111-1111-111111111111')
|
||||
|
||||
# Test the _to_info_with_user_id method
|
||||
result = saas_service_user1._to_info_with_user_id(
|
||||
|
||||
Reference in New Issue
Block a user