feat(backend): add api for switching between orgs (#12799)

This commit is contained in:
Hiep Le
2026-02-10 14:22:52 +07:00
committed by GitHub
parent bef9b80b9d
commit aa0b2d0b74
6 changed files with 457 additions and 0 deletions

View File

@@ -603,3 +603,78 @@ async def remove_org_member(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail='Failed to remove member',
)
@org_router.post(
'/{org_id}/switch', response_model=OrgResponse, status_code=status.HTTP_200_OK
)
async def switch_org(
org_id: UUID,
user_id: str = Depends(get_user_id),
) -> OrgResponse:
"""Switch to a different organization.
This endpoint allows authenticated users to switch their current active
organization. The user must be a member of the target organization.
Args:
org_id: Organization ID to switch to (UUID)
user_id: Authenticated user ID (injected by dependency)
Returns:
OrgResponse: The organization details that was switched to
Raises:
HTTPException: 422 if org_id is not a valid UUID (handled by FastAPI)
HTTPException: 403 if user is not a member of the organization
HTTPException: 404 if organization not found
HTTPException: 500 if switch fails
"""
logger.info(
'Switching organization',
extra={
'user_id': user_id,
'org_id': str(org_id),
},
)
try:
# Use service layer to switch organization with membership validation
org = await OrgService.switch_org(
user_id=user_id,
org_id=org_id,
)
# Retrieve credits from LiteLLM for the new current org
credits = await OrgService.get_org_credits(user_id, org.id)
return OrgResponse.from_org(org, credits=credits, user_id=user_id)
except OrgNotFoundError as e:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=str(e),
)
except OrgAuthorizationError as e:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=str(e),
)
except OrgDatabaseError as e:
logger.error(
'Database operation failed during organization switch',
extra={'user_id': user_id, 'org_id': str(org_id), 'error': str(e)},
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail='Failed to switch organization',
)
except Exception as e:
logger.exception(
'Unexpected error switching organization',
extra={'user_id': user_id, 'org_id': str(org_id), 'error': str(e)},
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail='An unexpected error occurred',
)

View File

@@ -865,3 +865,71 @@ class OrgService:
return False
return org.byor_export_enabled
@staticmethod
async def switch_org(user_id: str, org_id: UUID) -> Org:
"""
Switch user's current organization to the specified organization.
This method:
1. Validates that the organization exists
2. Validates that the user is a member of the organization
3. Updates the user's current_org_id
Args:
user_id: User ID (string that will be converted to UUID)
org_id: Organization ID to switch to
Returns:
Org: The organization that was switched to
Raises:
OrgNotFoundError: If organization doesn't exist
OrgAuthorizationError: If user is not a member of the organization
OrgDatabaseError: If database update fails
"""
logger.info(
'Switching user organization',
extra={'user_id': user_id, 'org_id': str(org_id)},
)
# Step 1: Check if organization exists
org = OrgStore.get_org_by_id(org_id)
if not org:
raise OrgNotFoundError(str(org_id))
# Step 2: Validate user is a member of the organization
if not OrgService.is_org_member(user_id, org_id):
logger.warning(
'User attempted to switch to organization they are not a member of',
extra={'user_id': user_id, 'org_id': str(org_id)},
)
raise OrgAuthorizationError(
'User must be a member of the organization to switch to it'
)
# Step 3: Update user's current_org_id
try:
updated_user = UserStore.update_current_org(user_id, org_id)
if not updated_user:
raise OrgDatabaseError('User not found')
logger.info(
'Successfully switched user organization',
extra={
'user_id': user_id,
'org_id': str(org_id),
'org_name': org.name,
},
)
return org
except OrgDatabaseError:
raise
except Exception as e:
logger.error(
'Failed to switch user organization',
extra={'user_id': user_id, 'org_id': str(org_id), 'error': str(e)},
)
raise OrgDatabaseError(f'Failed to switch organization: {str(e)}')

View File

@@ -5,6 +5,7 @@ Store class for managing users.
import asyncio
import uuid
from typing import Optional
from uuid import UUID
from server.auth.token_manager import TokenManager
from server.constants import (
@@ -773,6 +774,32 @@ class UserStore:
with session_maker() as session:
return session.query(User).all()
@staticmethod
def update_current_org(user_id: str, org_id: UUID) -> Optional[User]:
"""Update the user's current organization.
Args:
user_id: The user's ID (Keycloak user ID)
org_id: The organization ID to set as current
Returns:
User: The updated user object, or None if user not found
"""
with session_maker() as session:
user = (
session.query(User)
.filter(User.id == uuid.UUID(user_id))
.with_for_update()
.first()
)
if not user:
return None
user.current_org_id = org_id
session.commit()
session.refresh(user)
return user
@staticmethod
async def backfill_contact_name(user_id: str, user_info: dict) -> None:
"""Update contact_name on the personal org if it still has a username-style value.

View File

@@ -2635,3 +2635,136 @@ class TestGetMeEndpoint:
await get_me(org_id=test_org_id, user_id=test_user_id)
assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
@pytest.mark.asyncio
async def test_switch_org_success(mock_app_with_get_user_id):
"""
GIVEN: Valid org_id and authenticated user who is a member
WHEN: POST /api/organizations/{org_id}/switch is called
THEN: User's current org is switched and org details returned with 200 status
"""
# Arrange
org_id = uuid.uuid4()
mock_org = Org(
id=org_id,
name='Target Organization',
contact_name='John Doe',
contact_email='john@example.com',
org_version=5,
default_llm_model='claude-opus-4-5-20251101',
)
with (
patch(
'server.routes.orgs.OrgService.switch_org',
AsyncMock(return_value=mock_org),
),
patch(
'server.routes.orgs.OrgService.get_org_credits',
AsyncMock(return_value=100.0),
),
):
client = TestClient(mock_app_with_get_user_id)
# Act
response = client.post(f'/api/organizations/{org_id}/switch')
# Assert
assert response.status_code == status.HTTP_200_OK
response_data = response.json()
assert response_data['id'] == str(org_id)
assert response_data['name'] == 'Target Organization'
assert response_data['credits'] == 100.0
@pytest.mark.asyncio
async def test_switch_org_not_member(mock_app_with_get_user_id):
"""
GIVEN: User is not a member of the target organization
WHEN: POST /api/organizations/{org_id}/switch is called
THEN: 403 Forbidden error is returned
"""
# Arrange
org_id = uuid.uuid4()
with patch(
'server.routes.orgs.OrgService.switch_org',
AsyncMock(
side_effect=OrgAuthorizationError(
'User must be a member of the organization to switch to it'
)
),
):
client = TestClient(mock_app_with_get_user_id)
# Act
response = client.post(f'/api/organizations/{org_id}/switch')
# Assert
assert response.status_code == status.HTTP_403_FORBIDDEN
assert 'member' in response.json()['detail'].lower()
@pytest.mark.asyncio
async def test_switch_org_not_found(mock_app_with_get_user_id):
"""
GIVEN: Organization does not exist
WHEN: POST /api/organizations/{org_id}/switch is called
THEN: 404 Not Found error is returned
"""
# Arrange
org_id = uuid.uuid4()
with patch(
'server.routes.orgs.OrgService.switch_org',
AsyncMock(side_effect=OrgNotFoundError(str(org_id))),
):
client = TestClient(mock_app_with_get_user_id)
# Act
response = client.post(f'/api/organizations/{org_id}/switch')
# Assert
assert response.status_code == status.HTTP_404_NOT_FOUND
@pytest.mark.asyncio
async def test_switch_org_invalid_uuid(mock_app_with_get_user_id):
"""
GIVEN: Invalid UUID format for org_id
WHEN: POST /api/organizations/{org_id}/switch is called
THEN: 422 Unprocessable Entity error is returned
"""
# Arrange
client = TestClient(mock_app_with_get_user_id)
# Act
response = client.post('/api/organizations/not-a-valid-uuid/switch')
# Assert
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
@pytest.mark.asyncio
async def test_switch_org_database_error(mock_app_with_get_user_id):
"""
GIVEN: Database operation fails during switch
WHEN: POST /api/organizations/{org_id}/switch is called
THEN: 500 Internal Server Error is returned
"""
# Arrange
org_id = uuid.uuid4()
with patch(
'server.routes.orgs.OrgService.switch_org',
AsyncMock(side_effect=OrgDatabaseError('Database connection failed')),
):
client = TestClient(mock_app_with_get_user_id)
# Act
response = client.post(f'/api/organizations/{org_id}/switch')
# Assert
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
assert 'Failed to switch organization' in response.json()['detail']

View File

@@ -1801,3 +1801,114 @@ async def test_check_byor_export_enabled_returns_false_when_org_not_found():
# Assert
assert result is False
@pytest.mark.asyncio
async def test_switch_org_success():
"""
GIVEN: Valid org_id and user_id where user is a member
WHEN: switch_org is called
THEN: User's current_org_id is updated and org is returned
"""
# Arrange
org_id = uuid.uuid4()
user_id = str(uuid.uuid4())
mock_org = Org(
id=org_id,
name='Target Organization',
contact_name='John Doe',
contact_email='john@example.com',
)
mock_updated_user = User(id=uuid.UUID(user_id), current_org_id=org_id)
with (
patch('storage.org_service.OrgStore.get_org_by_id', return_value=mock_org),
patch('storage.org_service.OrgService.is_org_member', return_value=True),
patch(
'storage.org_service.UserStore.update_current_org',
return_value=mock_updated_user,
),
):
# Act
result = await OrgService.switch_org(user_id, org_id)
# Assert
assert result is not None
assert result.id == org_id
assert result.name == 'Target Organization'
@pytest.mark.asyncio
async def test_switch_org_org_not_found():
"""
GIVEN: Organization does not exist
WHEN: switch_org is called
THEN: OrgNotFoundError is raised
"""
# Arrange
org_id = uuid.uuid4()
user_id = str(uuid.uuid4())
with patch('storage.org_service.OrgStore.get_org_by_id', return_value=None):
# Act & Assert
with pytest.raises(OrgNotFoundError) as exc_info:
await OrgService.switch_org(user_id, org_id)
assert str(org_id) in str(exc_info.value)
@pytest.mark.asyncio
async def test_switch_org_user_not_member():
"""
GIVEN: User is not a member of the organization
WHEN: switch_org is called
THEN: OrgAuthorizationError is raised
"""
# Arrange
org_id = uuid.uuid4()
user_id = str(uuid.uuid4())
mock_org = Org(
id=org_id,
name='Target Organization',
contact_name='John Doe',
contact_email='john@example.com',
)
with (
patch('storage.org_service.OrgStore.get_org_by_id', return_value=mock_org),
patch('storage.org_service.OrgService.is_org_member', return_value=False),
):
# Act & Assert
with pytest.raises(OrgAuthorizationError) as exc_info:
await OrgService.switch_org(user_id, org_id)
assert 'member' in str(exc_info.value).lower()
@pytest.mark.asyncio
async def test_switch_org_user_not_found():
"""
GIVEN: User does not exist in database
WHEN: switch_org is called
THEN: OrgDatabaseError is raised
"""
# Arrange
org_id = uuid.uuid4()
user_id = str(uuid.uuid4())
mock_org = Org(
id=org_id,
name='Target Organization',
contact_name='John Doe',
contact_email='john@example.com',
)
with (
patch('storage.org_service.OrgStore.get_org_by_id', return_value=mock_org),
patch('storage.org_service.OrgService.is_org_member', return_value=True),
patch('storage.org_service.UserStore.update_current_org', return_value=None),
):
# Act & Assert
with pytest.raises(OrgDatabaseError) as exc_info:
await OrgService.switch_org(user_id, org_id)
assert 'User not found' in str(exc_info.value)

View File

@@ -522,3 +522,46 @@ async def test_backfill_contact_name_preserves_custom_value(session_maker):
with session_maker() as session:
org = session.query(Org).filter(Org.id == uuid.UUID(user_id)).first()
assert org.contact_name == 'Custom Corp Name'
def test_update_current_org_success(session_maker):
"""
GIVEN: User exists in database
WHEN: update_current_org is called with new org_id
THEN: User's current_org_id is updated and user is returned
"""
# Arrange
user_id = str(uuid.uuid4())
initial_org_id = uuid.uuid4()
new_org_id = uuid.uuid4()
with session_maker() as session:
user = User(id=uuid.UUID(user_id), current_org_id=initial_org_id)
session.add(user)
session.commit()
# Act
with patch('storage.user_store.session_maker', session_maker):
result = UserStore.update_current_org(user_id, new_org_id)
# Assert
assert result is not None
assert result.current_org_id == new_org_id
def test_update_current_org_user_not_found(session_maker):
"""
GIVEN: User does not exist in database
WHEN: update_current_org is called
THEN: None is returned
"""
# Arrange
user_id = str(uuid.uuid4())
org_id = uuid.uuid4()
# Act
with patch('storage.user_store.session_maker', session_maker):
result = UserStore.update_current_org(user_id, org_id)
# Assert
assert result is None