mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-04-29 03:00:45 -04:00
feat(backend): add api for switching between orgs (#12799)
This commit is contained in:
@@ -603,3 +603,78 @@ async def remove_org_member(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to remove member',
|
||||
)
|
||||
|
||||
|
||||
@org_router.post(
|
||||
'/{org_id}/switch', response_model=OrgResponse, status_code=status.HTTP_200_OK
|
||||
)
|
||||
async def switch_org(
|
||||
org_id: UUID,
|
||||
user_id: str = Depends(get_user_id),
|
||||
) -> OrgResponse:
|
||||
"""Switch to a different organization.
|
||||
|
||||
This endpoint allows authenticated users to switch their current active
|
||||
organization. The user must be a member of the target organization.
|
||||
|
||||
Args:
|
||||
org_id: Organization ID to switch to (UUID)
|
||||
user_id: Authenticated user ID (injected by dependency)
|
||||
|
||||
Returns:
|
||||
OrgResponse: The organization details that was switched to
|
||||
|
||||
Raises:
|
||||
HTTPException: 422 if org_id is not a valid UUID (handled by FastAPI)
|
||||
HTTPException: 403 if user is not a member of the organization
|
||||
HTTPException: 404 if organization not found
|
||||
HTTPException: 500 if switch fails
|
||||
"""
|
||||
logger.info(
|
||||
'Switching organization',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
# Use service layer to switch organization with membership validation
|
||||
org = await OrgService.switch_org(
|
||||
user_id=user_id,
|
||||
org_id=org_id,
|
||||
)
|
||||
|
||||
# Retrieve credits from LiteLLM for the new current org
|
||||
credits = await OrgService.get_org_credits(user_id, org.id)
|
||||
|
||||
return OrgResponse.from_org(org, credits=credits, user_id=user_id)
|
||||
|
||||
except OrgNotFoundError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=str(e),
|
||||
)
|
||||
except OrgAuthorizationError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=str(e),
|
||||
)
|
||||
except OrgDatabaseError as e:
|
||||
logger.error(
|
||||
'Database operation failed during organization switch',
|
||||
extra={'user_id': user_id, 'org_id': str(org_id), 'error': str(e)},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to switch organization',
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'Unexpected error switching organization',
|
||||
extra={'user_id': user_id, 'org_id': str(org_id), 'error': str(e)},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='An unexpected error occurred',
|
||||
)
|
||||
|
||||
@@ -865,3 +865,71 @@ class OrgService:
|
||||
return False
|
||||
|
||||
return org.byor_export_enabled
|
||||
|
||||
@staticmethod
|
||||
async def switch_org(user_id: str, org_id: UUID) -> Org:
|
||||
"""
|
||||
Switch user's current organization to the specified organization.
|
||||
|
||||
This method:
|
||||
1. Validates that the organization exists
|
||||
2. Validates that the user is a member of the organization
|
||||
3. Updates the user's current_org_id
|
||||
|
||||
Args:
|
||||
user_id: User ID (string that will be converted to UUID)
|
||||
org_id: Organization ID to switch to
|
||||
|
||||
Returns:
|
||||
Org: The organization that was switched to
|
||||
|
||||
Raises:
|
||||
OrgNotFoundError: If organization doesn't exist
|
||||
OrgAuthorizationError: If user is not a member of the organization
|
||||
OrgDatabaseError: If database update fails
|
||||
"""
|
||||
logger.info(
|
||||
'Switching user organization',
|
||||
extra={'user_id': user_id, 'org_id': str(org_id)},
|
||||
)
|
||||
|
||||
# Step 1: Check if organization exists
|
||||
org = OrgStore.get_org_by_id(org_id)
|
||||
if not org:
|
||||
raise OrgNotFoundError(str(org_id))
|
||||
|
||||
# Step 2: Validate user is a member of the organization
|
||||
if not OrgService.is_org_member(user_id, org_id):
|
||||
logger.warning(
|
||||
'User attempted to switch to organization they are not a member of',
|
||||
extra={'user_id': user_id, 'org_id': str(org_id)},
|
||||
)
|
||||
raise OrgAuthorizationError(
|
||||
'User must be a member of the organization to switch to it'
|
||||
)
|
||||
|
||||
# Step 3: Update user's current_org_id
|
||||
try:
|
||||
updated_user = UserStore.update_current_org(user_id, org_id)
|
||||
if not updated_user:
|
||||
raise OrgDatabaseError('User not found')
|
||||
|
||||
logger.info(
|
||||
'Successfully switched user organization',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
'org_name': org.name,
|
||||
},
|
||||
)
|
||||
|
||||
return org
|
||||
|
||||
except OrgDatabaseError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'Failed to switch user organization',
|
||||
extra={'user_id': user_id, 'org_id': str(org_id), 'error': str(e)},
|
||||
)
|
||||
raise OrgDatabaseError(f'Failed to switch organization: {str(e)}')
|
||||
|
||||
@@ -5,6 +5,7 @@ Store class for managing users.
|
||||
import asyncio
|
||||
import uuid
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.constants import (
|
||||
@@ -773,6 +774,32 @@ class UserStore:
|
||||
with session_maker() as session:
|
||||
return session.query(User).all()
|
||||
|
||||
@staticmethod
|
||||
def update_current_org(user_id: str, org_id: UUID) -> Optional[User]:
|
||||
"""Update the user's current organization.
|
||||
|
||||
Args:
|
||||
user_id: The user's ID (Keycloak user ID)
|
||||
org_id: The organization ID to set as current
|
||||
|
||||
Returns:
|
||||
User: The updated user object, or None if user not found
|
||||
"""
|
||||
with session_maker() as session:
|
||||
user = (
|
||||
session.query(User)
|
||||
.filter(User.id == uuid.UUID(user_id))
|
||||
.with_for_update()
|
||||
.first()
|
||||
)
|
||||
if not user:
|
||||
return None
|
||||
|
||||
user.current_org_id = org_id
|
||||
session.commit()
|
||||
session.refresh(user)
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
async def backfill_contact_name(user_id: str, user_info: dict) -> None:
|
||||
"""Update contact_name on the personal org if it still has a username-style value.
|
||||
|
||||
@@ -2635,3 +2635,136 @@ class TestGetMeEndpoint:
|
||||
await get_me(org_id=test_org_id, user_id=test_user_id)
|
||||
|
||||
assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_switch_org_success(mock_app_with_get_user_id):
|
||||
"""
|
||||
GIVEN: Valid org_id and authenticated user who is a member
|
||||
WHEN: POST /api/organizations/{org_id}/switch is called
|
||||
THEN: User's current org is switched and org details returned with 200 status
|
||||
"""
|
||||
# Arrange
|
||||
org_id = uuid.uuid4()
|
||||
mock_org = Org(
|
||||
id=org_id,
|
||||
name='Target Organization',
|
||||
contact_name='John Doe',
|
||||
contact_email='john@example.com',
|
||||
org_version=5,
|
||||
default_llm_model='claude-opus-4-5-20251101',
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.routes.orgs.OrgService.switch_org',
|
||||
AsyncMock(return_value=mock_org),
|
||||
),
|
||||
patch(
|
||||
'server.routes.orgs.OrgService.get_org_credits',
|
||||
AsyncMock(return_value=100.0),
|
||||
),
|
||||
):
|
||||
client = TestClient(mock_app_with_get_user_id)
|
||||
|
||||
# Act
|
||||
response = client.post(f'/api/organizations/{org_id}/switch')
|
||||
|
||||
# Assert
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
response_data = response.json()
|
||||
assert response_data['id'] == str(org_id)
|
||||
assert response_data['name'] == 'Target Organization'
|
||||
assert response_data['credits'] == 100.0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_switch_org_not_member(mock_app_with_get_user_id):
|
||||
"""
|
||||
GIVEN: User is not a member of the target organization
|
||||
WHEN: POST /api/organizations/{org_id}/switch is called
|
||||
THEN: 403 Forbidden error is returned
|
||||
"""
|
||||
# Arrange
|
||||
org_id = uuid.uuid4()
|
||||
|
||||
with patch(
|
||||
'server.routes.orgs.OrgService.switch_org',
|
||||
AsyncMock(
|
||||
side_effect=OrgAuthorizationError(
|
||||
'User must be a member of the organization to switch to it'
|
||||
)
|
||||
),
|
||||
):
|
||||
client = TestClient(mock_app_with_get_user_id)
|
||||
|
||||
# Act
|
||||
response = client.post(f'/api/organizations/{org_id}/switch')
|
||||
|
||||
# Assert
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
assert 'member' in response.json()['detail'].lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_switch_org_not_found(mock_app_with_get_user_id):
|
||||
"""
|
||||
GIVEN: Organization does not exist
|
||||
WHEN: POST /api/organizations/{org_id}/switch is called
|
||||
THEN: 404 Not Found error is returned
|
||||
"""
|
||||
# Arrange
|
||||
org_id = uuid.uuid4()
|
||||
|
||||
with patch(
|
||||
'server.routes.orgs.OrgService.switch_org',
|
||||
AsyncMock(side_effect=OrgNotFoundError(str(org_id))),
|
||||
):
|
||||
client = TestClient(mock_app_with_get_user_id)
|
||||
|
||||
# Act
|
||||
response = client.post(f'/api/organizations/{org_id}/switch')
|
||||
|
||||
# Assert
|
||||
assert response.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_switch_org_invalid_uuid(mock_app_with_get_user_id):
|
||||
"""
|
||||
GIVEN: Invalid UUID format for org_id
|
||||
WHEN: POST /api/organizations/{org_id}/switch is called
|
||||
THEN: 422 Unprocessable Entity error is returned
|
||||
"""
|
||||
# Arrange
|
||||
client = TestClient(mock_app_with_get_user_id)
|
||||
|
||||
# Act
|
||||
response = client.post('/api/organizations/not-a-valid-uuid/switch')
|
||||
|
||||
# Assert
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_switch_org_database_error(mock_app_with_get_user_id):
|
||||
"""
|
||||
GIVEN: Database operation fails during switch
|
||||
WHEN: POST /api/organizations/{org_id}/switch is called
|
||||
THEN: 500 Internal Server Error is returned
|
||||
"""
|
||||
# Arrange
|
||||
org_id = uuid.uuid4()
|
||||
|
||||
with patch(
|
||||
'server.routes.orgs.OrgService.switch_org',
|
||||
AsyncMock(side_effect=OrgDatabaseError('Database connection failed')),
|
||||
):
|
||||
client = TestClient(mock_app_with_get_user_id)
|
||||
|
||||
# Act
|
||||
response = client.post(f'/api/organizations/{org_id}/switch')
|
||||
|
||||
# Assert
|
||||
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
assert 'Failed to switch organization' in response.json()['detail']
|
||||
|
||||
@@ -1801,3 +1801,114 @@ async def test_check_byor_export_enabled_returns_false_when_org_not_found():
|
||||
|
||||
# Assert
|
||||
assert result is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_switch_org_success():
|
||||
"""
|
||||
GIVEN: Valid org_id and user_id where user is a member
|
||||
WHEN: switch_org is called
|
||||
THEN: User's current_org_id is updated and org is returned
|
||||
"""
|
||||
# Arrange
|
||||
org_id = uuid.uuid4()
|
||||
user_id = str(uuid.uuid4())
|
||||
mock_org = Org(
|
||||
id=org_id,
|
||||
name='Target Organization',
|
||||
contact_name='John Doe',
|
||||
contact_email='john@example.com',
|
||||
)
|
||||
mock_updated_user = User(id=uuid.UUID(user_id), current_org_id=org_id)
|
||||
|
||||
with (
|
||||
patch('storage.org_service.OrgStore.get_org_by_id', return_value=mock_org),
|
||||
patch('storage.org_service.OrgService.is_org_member', return_value=True),
|
||||
patch(
|
||||
'storage.org_service.UserStore.update_current_org',
|
||||
return_value=mock_updated_user,
|
||||
),
|
||||
):
|
||||
# Act
|
||||
result = await OrgService.switch_org(user_id, org_id)
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result.id == org_id
|
||||
assert result.name == 'Target Organization'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_switch_org_org_not_found():
|
||||
"""
|
||||
GIVEN: Organization does not exist
|
||||
WHEN: switch_org is called
|
||||
THEN: OrgNotFoundError is raised
|
||||
"""
|
||||
# Arrange
|
||||
org_id = uuid.uuid4()
|
||||
user_id = str(uuid.uuid4())
|
||||
|
||||
with patch('storage.org_service.OrgStore.get_org_by_id', return_value=None):
|
||||
# Act & Assert
|
||||
with pytest.raises(OrgNotFoundError) as exc_info:
|
||||
await OrgService.switch_org(user_id, org_id)
|
||||
|
||||
assert str(org_id) in str(exc_info.value)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_switch_org_user_not_member():
|
||||
"""
|
||||
GIVEN: User is not a member of the organization
|
||||
WHEN: switch_org is called
|
||||
THEN: OrgAuthorizationError is raised
|
||||
"""
|
||||
# Arrange
|
||||
org_id = uuid.uuid4()
|
||||
user_id = str(uuid.uuid4())
|
||||
mock_org = Org(
|
||||
id=org_id,
|
||||
name='Target Organization',
|
||||
contact_name='John Doe',
|
||||
contact_email='john@example.com',
|
||||
)
|
||||
|
||||
with (
|
||||
patch('storage.org_service.OrgStore.get_org_by_id', return_value=mock_org),
|
||||
patch('storage.org_service.OrgService.is_org_member', return_value=False),
|
||||
):
|
||||
# Act & Assert
|
||||
with pytest.raises(OrgAuthorizationError) as exc_info:
|
||||
await OrgService.switch_org(user_id, org_id)
|
||||
|
||||
assert 'member' in str(exc_info.value).lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_switch_org_user_not_found():
|
||||
"""
|
||||
GIVEN: User does not exist in database
|
||||
WHEN: switch_org is called
|
||||
THEN: OrgDatabaseError is raised
|
||||
"""
|
||||
# Arrange
|
||||
org_id = uuid.uuid4()
|
||||
user_id = str(uuid.uuid4())
|
||||
mock_org = Org(
|
||||
id=org_id,
|
||||
name='Target Organization',
|
||||
contact_name='John Doe',
|
||||
contact_email='john@example.com',
|
||||
)
|
||||
|
||||
with (
|
||||
patch('storage.org_service.OrgStore.get_org_by_id', return_value=mock_org),
|
||||
patch('storage.org_service.OrgService.is_org_member', return_value=True),
|
||||
patch('storage.org_service.UserStore.update_current_org', return_value=None),
|
||||
):
|
||||
# Act & Assert
|
||||
with pytest.raises(OrgDatabaseError) as exc_info:
|
||||
await OrgService.switch_org(user_id, org_id)
|
||||
|
||||
assert 'User not found' in str(exc_info.value)
|
||||
|
||||
@@ -522,3 +522,46 @@ async def test_backfill_contact_name_preserves_custom_value(session_maker):
|
||||
with session_maker() as session:
|
||||
org = session.query(Org).filter(Org.id == uuid.UUID(user_id)).first()
|
||||
assert org.contact_name == 'Custom Corp Name'
|
||||
|
||||
|
||||
def test_update_current_org_success(session_maker):
|
||||
"""
|
||||
GIVEN: User exists in database
|
||||
WHEN: update_current_org is called with new org_id
|
||||
THEN: User's current_org_id is updated and user is returned
|
||||
"""
|
||||
# Arrange
|
||||
user_id = str(uuid.uuid4())
|
||||
initial_org_id = uuid.uuid4()
|
||||
new_org_id = uuid.uuid4()
|
||||
|
||||
with session_maker() as session:
|
||||
user = User(id=uuid.UUID(user_id), current_org_id=initial_org_id)
|
||||
session.add(user)
|
||||
session.commit()
|
||||
|
||||
# Act
|
||||
with patch('storage.user_store.session_maker', session_maker):
|
||||
result = UserStore.update_current_org(user_id, new_org_id)
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result.current_org_id == new_org_id
|
||||
|
||||
|
||||
def test_update_current_org_user_not_found(session_maker):
|
||||
"""
|
||||
GIVEN: User does not exist in database
|
||||
WHEN: update_current_org is called
|
||||
THEN: None is returned
|
||||
"""
|
||||
# Arrange
|
||||
user_id = str(uuid.uuid4())
|
||||
org_id = uuid.uuid4()
|
||||
|
||||
# Act
|
||||
with patch('storage.user_store.session_maker', session_maker):
|
||||
result = UserStore.update_current_org(user_id, org_id)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
Reference in New Issue
Block a user