From aa0b2d0b74c49c5ae3bc6d84bd15352c018f82d4 Mon Sep 17 00:00:00 2001 From: Hiep Le <69354317+hieptl@users.noreply.github.com> Date: Tue, 10 Feb 2026 14:22:52 +0700 Subject: [PATCH] feat(backend): add api for switching between orgs (#12799) --- enterprise/server/routes/orgs.py | 75 ++++++++++ enterprise/storage/org_service.py | 68 +++++++++ enterprise/storage/user_store.py | 27 ++++ .../tests/unit/server/routes/test_orgs.py | 133 ++++++++++++++++++ enterprise/tests/unit/test_org_service.py | 111 +++++++++++++++ enterprise/tests/unit/test_user_store.py | 43 ++++++ 6 files changed, 457 insertions(+) diff --git a/enterprise/server/routes/orgs.py b/enterprise/server/routes/orgs.py index 4d00cb9764..44b2fdc513 100644 --- a/enterprise/server/routes/orgs.py +++ b/enterprise/server/routes/orgs.py @@ -603,3 +603,78 @@ async def remove_org_member( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail='Failed to remove member', ) + + +@org_router.post( + '/{org_id}/switch', response_model=OrgResponse, status_code=status.HTTP_200_OK +) +async def switch_org( + org_id: UUID, + user_id: str = Depends(get_user_id), +) -> OrgResponse: + """Switch to a different organization. + + This endpoint allows authenticated users to switch their current active + organization. The user must be a member of the target organization. + + Args: + org_id: Organization ID to switch to (UUID) + user_id: Authenticated user ID (injected by dependency) + + Returns: + OrgResponse: The organization details that was switched to + + Raises: + HTTPException: 422 if org_id is not a valid UUID (handled by FastAPI) + HTTPException: 403 if user is not a member of the organization + HTTPException: 404 if organization not found + HTTPException: 500 if switch fails + """ + logger.info( + 'Switching organization', + extra={ + 'user_id': user_id, + 'org_id': str(org_id), + }, + ) + + try: + # Use service layer to switch organization with membership validation + org = await OrgService.switch_org( + user_id=user_id, + org_id=org_id, + ) + + # Retrieve credits from LiteLLM for the new current org + credits = await OrgService.get_org_credits(user_id, org.id) + + return OrgResponse.from_org(org, credits=credits, user_id=user_id) + + except OrgNotFoundError as e: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=str(e), + ) + except OrgAuthorizationError as e: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=str(e), + ) + except OrgDatabaseError as e: + logger.error( + 'Database operation failed during organization switch', + extra={'user_id': user_id, 'org_id': str(org_id), 'error': str(e)}, + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail='Failed to switch organization', + ) + except Exception as e: + logger.exception( + 'Unexpected error switching organization', + extra={'user_id': user_id, 'org_id': str(org_id), 'error': str(e)}, + ) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail='An unexpected error occurred', + ) diff --git a/enterprise/storage/org_service.py b/enterprise/storage/org_service.py index af25c05a91..1537607f94 100644 --- a/enterprise/storage/org_service.py +++ b/enterprise/storage/org_service.py @@ -865,3 +865,71 @@ class OrgService: return False return org.byor_export_enabled + + @staticmethod + async def switch_org(user_id: str, org_id: UUID) -> Org: + """ + Switch user's current organization to the specified organization. + + This method: + 1. Validates that the organization exists + 2. Validates that the user is a member of the organization + 3. Updates the user's current_org_id + + Args: + user_id: User ID (string that will be converted to UUID) + org_id: Organization ID to switch to + + Returns: + Org: The organization that was switched to + + Raises: + OrgNotFoundError: If organization doesn't exist + OrgAuthorizationError: If user is not a member of the organization + OrgDatabaseError: If database update fails + """ + logger.info( + 'Switching user organization', + extra={'user_id': user_id, 'org_id': str(org_id)}, + ) + + # Step 1: Check if organization exists + org = OrgStore.get_org_by_id(org_id) + if not org: + raise OrgNotFoundError(str(org_id)) + + # Step 2: Validate user is a member of the organization + if not OrgService.is_org_member(user_id, org_id): + logger.warning( + 'User attempted to switch to organization they are not a member of', + extra={'user_id': user_id, 'org_id': str(org_id)}, + ) + raise OrgAuthorizationError( + 'User must be a member of the organization to switch to it' + ) + + # Step 3: Update user's current_org_id + try: + updated_user = UserStore.update_current_org(user_id, org_id) + if not updated_user: + raise OrgDatabaseError('User not found') + + logger.info( + 'Successfully switched user organization', + extra={ + 'user_id': user_id, + 'org_id': str(org_id), + 'org_name': org.name, + }, + ) + + return org + + except OrgDatabaseError: + raise + except Exception as e: + logger.error( + 'Failed to switch user organization', + extra={'user_id': user_id, 'org_id': str(org_id), 'error': str(e)}, + ) + raise OrgDatabaseError(f'Failed to switch organization: {str(e)}') diff --git a/enterprise/storage/user_store.py b/enterprise/storage/user_store.py index fd909b11cc..651e98176d 100644 --- a/enterprise/storage/user_store.py +++ b/enterprise/storage/user_store.py @@ -5,6 +5,7 @@ Store class for managing users. import asyncio import uuid from typing import Optional +from uuid import UUID from server.auth.token_manager import TokenManager from server.constants import ( @@ -773,6 +774,32 @@ class UserStore: with session_maker() as session: return session.query(User).all() + @staticmethod + def update_current_org(user_id: str, org_id: UUID) -> Optional[User]: + """Update the user's current organization. + + Args: + user_id: The user's ID (Keycloak user ID) + org_id: The organization ID to set as current + + Returns: + User: The updated user object, or None if user not found + """ + with session_maker() as session: + user = ( + session.query(User) + .filter(User.id == uuid.UUID(user_id)) + .with_for_update() + .first() + ) + if not user: + return None + + user.current_org_id = org_id + session.commit() + session.refresh(user) + return user + @staticmethod async def backfill_contact_name(user_id: str, user_info: dict) -> None: """Update contact_name on the personal org if it still has a username-style value. diff --git a/enterprise/tests/unit/server/routes/test_orgs.py b/enterprise/tests/unit/server/routes/test_orgs.py index fad3dffc65..4364737109 100644 --- a/enterprise/tests/unit/server/routes/test_orgs.py +++ b/enterprise/tests/unit/server/routes/test_orgs.py @@ -2635,3 +2635,136 @@ class TestGetMeEndpoint: await get_me(org_id=test_org_id, user_id=test_user_id) assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR + + +@pytest.mark.asyncio +async def test_switch_org_success(mock_app_with_get_user_id): + """ + GIVEN: Valid org_id and authenticated user who is a member + WHEN: POST /api/organizations/{org_id}/switch is called + THEN: User's current org is switched and org details returned with 200 status + """ + # Arrange + org_id = uuid.uuid4() + mock_org = Org( + id=org_id, + name='Target Organization', + contact_name='John Doe', + contact_email='john@example.com', + org_version=5, + default_llm_model='claude-opus-4-5-20251101', + ) + + with ( + patch( + 'server.routes.orgs.OrgService.switch_org', + AsyncMock(return_value=mock_org), + ), + patch( + 'server.routes.orgs.OrgService.get_org_credits', + AsyncMock(return_value=100.0), + ), + ): + client = TestClient(mock_app_with_get_user_id) + + # Act + response = client.post(f'/api/organizations/{org_id}/switch') + + # Assert + assert response.status_code == status.HTTP_200_OK + response_data = response.json() + assert response_data['id'] == str(org_id) + assert response_data['name'] == 'Target Organization' + assert response_data['credits'] == 100.0 + + +@pytest.mark.asyncio +async def test_switch_org_not_member(mock_app_with_get_user_id): + """ + GIVEN: User is not a member of the target organization + WHEN: POST /api/organizations/{org_id}/switch is called + THEN: 403 Forbidden error is returned + """ + # Arrange + org_id = uuid.uuid4() + + with patch( + 'server.routes.orgs.OrgService.switch_org', + AsyncMock( + side_effect=OrgAuthorizationError( + 'User must be a member of the organization to switch to it' + ) + ), + ): + client = TestClient(mock_app_with_get_user_id) + + # Act + response = client.post(f'/api/organizations/{org_id}/switch') + + # Assert + assert response.status_code == status.HTTP_403_FORBIDDEN + assert 'member' in response.json()['detail'].lower() + + +@pytest.mark.asyncio +async def test_switch_org_not_found(mock_app_with_get_user_id): + """ + GIVEN: Organization does not exist + WHEN: POST /api/organizations/{org_id}/switch is called + THEN: 404 Not Found error is returned + """ + # Arrange + org_id = uuid.uuid4() + + with patch( + 'server.routes.orgs.OrgService.switch_org', + AsyncMock(side_effect=OrgNotFoundError(str(org_id))), + ): + client = TestClient(mock_app_with_get_user_id) + + # Act + response = client.post(f'/api/organizations/{org_id}/switch') + + # Assert + assert response.status_code == status.HTTP_404_NOT_FOUND + + +@pytest.mark.asyncio +async def test_switch_org_invalid_uuid(mock_app_with_get_user_id): + """ + GIVEN: Invalid UUID format for org_id + WHEN: POST /api/organizations/{org_id}/switch is called + THEN: 422 Unprocessable Entity error is returned + """ + # Arrange + client = TestClient(mock_app_with_get_user_id) + + # Act + response = client.post('/api/organizations/not-a-valid-uuid/switch') + + # Assert + assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + + +@pytest.mark.asyncio +async def test_switch_org_database_error(mock_app_with_get_user_id): + """ + GIVEN: Database operation fails during switch + WHEN: POST /api/organizations/{org_id}/switch is called + THEN: 500 Internal Server Error is returned + """ + # Arrange + org_id = uuid.uuid4() + + with patch( + 'server.routes.orgs.OrgService.switch_org', + AsyncMock(side_effect=OrgDatabaseError('Database connection failed')), + ): + client = TestClient(mock_app_with_get_user_id) + + # Act + response = client.post(f'/api/organizations/{org_id}/switch') + + # Assert + assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR + assert 'Failed to switch organization' in response.json()['detail'] diff --git a/enterprise/tests/unit/test_org_service.py b/enterprise/tests/unit/test_org_service.py index fdc8b598bb..8e0438ba1d 100644 --- a/enterprise/tests/unit/test_org_service.py +++ b/enterprise/tests/unit/test_org_service.py @@ -1801,3 +1801,114 @@ async def test_check_byor_export_enabled_returns_false_when_org_not_found(): # Assert assert result is False + + +@pytest.mark.asyncio +async def test_switch_org_success(): + """ + GIVEN: Valid org_id and user_id where user is a member + WHEN: switch_org is called + THEN: User's current_org_id is updated and org is returned + """ + # Arrange + org_id = uuid.uuid4() + user_id = str(uuid.uuid4()) + mock_org = Org( + id=org_id, + name='Target Organization', + contact_name='John Doe', + contact_email='john@example.com', + ) + mock_updated_user = User(id=uuid.UUID(user_id), current_org_id=org_id) + + with ( + patch('storage.org_service.OrgStore.get_org_by_id', return_value=mock_org), + patch('storage.org_service.OrgService.is_org_member', return_value=True), + patch( + 'storage.org_service.UserStore.update_current_org', + return_value=mock_updated_user, + ), + ): + # Act + result = await OrgService.switch_org(user_id, org_id) + + # Assert + assert result is not None + assert result.id == org_id + assert result.name == 'Target Organization' + + +@pytest.mark.asyncio +async def test_switch_org_org_not_found(): + """ + GIVEN: Organization does not exist + WHEN: switch_org is called + THEN: OrgNotFoundError is raised + """ + # Arrange + org_id = uuid.uuid4() + user_id = str(uuid.uuid4()) + + with patch('storage.org_service.OrgStore.get_org_by_id', return_value=None): + # Act & Assert + with pytest.raises(OrgNotFoundError) as exc_info: + await OrgService.switch_org(user_id, org_id) + + assert str(org_id) in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_switch_org_user_not_member(): + """ + GIVEN: User is not a member of the organization + WHEN: switch_org is called + THEN: OrgAuthorizationError is raised + """ + # Arrange + org_id = uuid.uuid4() + user_id = str(uuid.uuid4()) + mock_org = Org( + id=org_id, + name='Target Organization', + contact_name='John Doe', + contact_email='john@example.com', + ) + + with ( + patch('storage.org_service.OrgStore.get_org_by_id', return_value=mock_org), + patch('storage.org_service.OrgService.is_org_member', return_value=False), + ): + # Act & Assert + with pytest.raises(OrgAuthorizationError) as exc_info: + await OrgService.switch_org(user_id, org_id) + + assert 'member' in str(exc_info.value).lower() + + +@pytest.mark.asyncio +async def test_switch_org_user_not_found(): + """ + GIVEN: User does not exist in database + WHEN: switch_org is called + THEN: OrgDatabaseError is raised + """ + # Arrange + org_id = uuid.uuid4() + user_id = str(uuid.uuid4()) + mock_org = Org( + id=org_id, + name='Target Organization', + contact_name='John Doe', + contact_email='john@example.com', + ) + + with ( + patch('storage.org_service.OrgStore.get_org_by_id', return_value=mock_org), + patch('storage.org_service.OrgService.is_org_member', return_value=True), + patch('storage.org_service.UserStore.update_current_org', return_value=None), + ): + # Act & Assert + with pytest.raises(OrgDatabaseError) as exc_info: + await OrgService.switch_org(user_id, org_id) + + assert 'User not found' in str(exc_info.value) diff --git a/enterprise/tests/unit/test_user_store.py b/enterprise/tests/unit/test_user_store.py index 5aa863e849..48a875b0c4 100644 --- a/enterprise/tests/unit/test_user_store.py +++ b/enterprise/tests/unit/test_user_store.py @@ -522,3 +522,46 @@ async def test_backfill_contact_name_preserves_custom_value(session_maker): with session_maker() as session: org = session.query(Org).filter(Org.id == uuid.UUID(user_id)).first() assert org.contact_name == 'Custom Corp Name' + + +def test_update_current_org_success(session_maker): + """ + GIVEN: User exists in database + WHEN: update_current_org is called with new org_id + THEN: User's current_org_id is updated and user is returned + """ + # Arrange + user_id = str(uuid.uuid4()) + initial_org_id = uuid.uuid4() + new_org_id = uuid.uuid4() + + with session_maker() as session: + user = User(id=uuid.UUID(user_id), current_org_id=initial_org_id) + session.add(user) + session.commit() + + # Act + with patch('storage.user_store.session_maker', session_maker): + result = UserStore.update_current_org(user_id, new_org_id) + + # Assert + assert result is not None + assert result.current_org_id == new_org_id + + +def test_update_current_org_user_not_found(session_maker): + """ + GIVEN: User does not exist in database + WHEN: update_current_org is called + THEN: None is returned + """ + # Arrange + user_id = str(uuid.uuid4()) + org_id = uuid.uuid4() + + # Act + with patch('storage.user_store.session_maker', session_maker): + result = UserStore.update_current_org(user_id, org_id) + + # Assert + assert result is None