From 134c122026b4a3477f284284814479f9cc8d2c82 Mon Sep 17 00:00:00 2001 From: Hiep Le <69354317+hieptl@users.noreply.github.com> Date: Thu, 23 Oct 2025 01:11:04 +0700 Subject: [PATCH] fix: disable pro subscription upgrade on LLM page for self-hosted installs (#11479) --- enterprise/server/routes/billing.py | 39 ++++++++ enterprise/tests/unit/test_billing.py | 88 +++++++++++++------ .../__tests__/routes/llm-settings.test.tsx | 10 ++- .../use-is-all-hands-saas-environment.ts | 13 +++ frontend/src/routes/llm-settings.tsx | 7 +- 5 files changed, 130 insertions(+), 27 deletions(-) create mode 100644 frontend/src/hooks/use-is-all-hands-saas-environment.ts diff --git a/enterprise/server/routes/billing.py b/enterprise/server/routes/billing.py index 1d52b54ffb..b1a6fc96fb 100644 --- a/enterprise/server/routes/billing.py +++ b/enterprise/server/routes/billing.py @@ -31,6 +31,37 @@ stripe.api_key = STRIPE_API_KEY billing_router = APIRouter(prefix='/api/billing') +# TODO: Add a new app_mode named "ON_PREM" to support self-hosted customers instead of doing this +# and members should comment out the "validate_saas_environment" function if they are developing and testing locally. +def is_all_hands_saas_environment(request: Request) -> bool: + """Check if the current domain is an All Hands SaaS environment. + + Args: + request: FastAPI Request object + + Returns: + True if the current domain contains "all-hands.dev" or "openhands.dev" postfix + """ + hostname = request.url.hostname or '' + return hostname.endswith('all-hands.dev') or hostname.endswith('openhands.dev') + + +def validate_saas_environment(request: Request) -> None: + """Validate that the request is coming from an All Hands SaaS environment. + + Args: + request: FastAPI Request object + + Raises: + HTTPException: If the request is not from an All Hands SaaS environment + """ + if not is_all_hands_saas_environment(request): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail='Checkout sessions are only available for All Hands SaaS environments', + ) + + class BillingSessionType(Enum): DIRECT_PAYMENT = 'DIRECT_PAYMENT' MONTHLY_SUBSCRIPTION = 'MONTHLY_SUBSCRIPTION' @@ -196,6 +227,8 @@ async def cancel_subscription(user_id: str = Depends(get_user_id)) -> JSONRespon async def create_customer_setup_session( request: Request, user_id: str = Depends(get_user_id) ) -> CreateBillingSessionResponse: + validate_saas_environment(request) + customer_id = await stripe_service.find_or_create_customer(user_id) checkout_session = await stripe.checkout.Session.create_async( customer=customer_id, @@ -214,6 +247,8 @@ async def create_checkout_session( request: Request, user_id: str = Depends(get_user_id), ) -> CreateBillingSessionResponse: + validate_saas_environment(request) + customer_id = await stripe_service.find_or_create_customer(user_id) checkout_session = await stripe.checkout.Session.create_async( customer=customer_id, @@ -268,6 +303,8 @@ async def create_subscription_checkout_session( billing_session_type: BillingSessionType = BillingSessionType.MONTHLY_SUBSCRIPTION, user_id: str = Depends(get_user_id), ) -> CreateBillingSessionResponse: + validate_saas_environment(request) + # Prevent duplicate subscriptions for the same user with session_maker() as session: now = datetime.now(UTC) @@ -343,6 +380,8 @@ async def create_subscription_checkout_session_via_get( user_id: str = Depends(get_user_id), ) -> RedirectResponse: """Create a subscription checkout session using a GET request (For easier copy / paste to URL bar).""" + validate_saas_environment(request) + response = await create_subscription_checkout_session( request, billing_session_type, user_id ) diff --git a/enterprise/tests/unit/test_billing.py b/enterprise/tests/unit/test_billing.py index e35577e431..cc05af60e2 100644 --- a/enterprise/tests/unit/test_billing.py +++ b/enterprise/tests/unit/test_billing.py @@ -36,6 +36,46 @@ def session_maker(engine): return sessionmaker(bind=engine) +@pytest.fixture +def mock_request(): + """Create a mock request object with proper URL structure for testing.""" + return Request( + scope={ + 'type': 'http', + 'path': '/api/billing/test', + 'server': ('test.com', 80), + } + ) + + +@pytest.fixture +def mock_checkout_request(): + """Create a mock request object for checkout session tests.""" + request = Request( + scope={ + 'type': 'http', + 'path': '/api/billing/create-checkout-session', + 'server': ('test.com', 80), + } + ) + request._base_url = URL('http://test.com/') + return request + + +@pytest.fixture +def mock_subscription_request(): + """Create a mock request object for subscription checkout session tests.""" + request = Request( + scope={ + 'type': 'http', + 'path': '/api/billing/subscription-checkout-session', + 'server': ('test.com', 80), + } + ) + request._base_url = URL('http://test.com/') + return request + + @pytest.mark.asyncio async def test_get_credits_lite_llm_error(): mock_request = Request(scope={'type': 'http', 'state': {'user_id': 'mock_user'}}) @@ -90,14 +130,10 @@ async def test_get_credits_success(): @pytest.mark.asyncio -async def test_create_checkout_session_stripe_error(session_maker): +async def test_create_checkout_session_stripe_error( + session_maker, mock_checkout_request +): """Test handling of Stripe API errors.""" - mock_request = Request( - scope={ - 'type': 'http', - } - ) - mock_request._base_url = URL('http://test.com/') mock_customer = stripe.Customer( id='mock-customer', metadata={'user_id': 'mock-user'} @@ -118,17 +154,16 @@ async def test_create_checkout_session_stripe_error(session_maker): 'server.auth.token_manager.TokenManager.get_user_info_from_user_id', AsyncMock(return_value={'email': 'testy@tester.com'}), ), + patch('server.routes.billing.validate_saas_environment'), ): await create_checkout_session( - CreateCheckoutSessionRequest(amount=25), mock_request, 'mock_user' + CreateCheckoutSessionRequest(amount=25), mock_checkout_request, 'mock_user' ) @pytest.mark.asyncio -async def test_create_checkout_session_success(session_maker): +async def test_create_checkout_session_success(session_maker, mock_checkout_request): """Test successful creation of checkout session.""" - mock_request = Request(scope={'type': 'http'}) - mock_request._base_url = URL('http://test.com/') mock_session = MagicMock() mock_session.url = 'https://checkout.stripe.com/test-session' @@ -152,12 +187,13 @@ async def test_create_checkout_session_success(session_maker): 'server.auth.token_manager.TokenManager.get_user_info_from_user_id', AsyncMock(return_value={'email': 'testy@tester.com'}), ), + patch('server.routes.billing.validate_saas_environment'), ): mock_db_session = MagicMock() mock_session_maker.return_value.__enter__.return_value = mock_db_session result = await create_checkout_session( - CreateCheckoutSessionRequest(amount=25), mock_request, 'mock_user' + CreateCheckoutSessionRequest(amount=25), mock_checkout_request, 'mock_user' ) assert isinstance(result, CreateBillingSessionResponse) @@ -590,7 +626,9 @@ async def test_cancel_subscription_stripe_error(): @pytest.mark.asyncio -async def test_create_subscription_checkout_session_duplicate_prevention(): +async def test_create_subscription_checkout_session_duplicate_prevention( + mock_subscription_request, +): """Test that creating a subscription when user already has active subscription raises error.""" from datetime import UTC, datetime @@ -609,11 +647,9 @@ async def test_create_subscription_checkout_session_duplicate_prevention(): cancelled_at=None, ) - mock_request = Request(scope={'type': 'http'}) - mock_request._base_url = URL('http://test.com/') - with ( patch('server.routes.billing.session_maker') as mock_session_maker, + patch('server.routes.billing.validate_saas_environment'), ): # Setup mock session to return existing active subscription mock_session = MagicMock() @@ -623,7 +659,7 @@ async def test_create_subscription_checkout_session_duplicate_prevention(): # Call the function and expect HTTPException with pytest.raises(HTTPException) as exc_info: await create_subscription_checkout_session( - mock_request, user_id='test_user' + mock_subscription_request, user_id='test_user' ) assert exc_info.value.status_code == 400 @@ -634,10 +670,10 @@ async def test_create_subscription_checkout_session_duplicate_prevention(): @pytest.mark.asyncio -async def test_create_subscription_checkout_session_allows_after_cancellation(): +async def test_create_subscription_checkout_session_allows_after_cancellation( + mock_subscription_request, +): """Test that creating a subscription is allowed when previous subscription was cancelled.""" - mock_request = Request(scope={'type': 'http'}) - mock_request._base_url = URL('http://test.com/') mock_session_obj = MagicMock() mock_session_obj.url = 'https://checkout.stripe.com/test-session' @@ -657,6 +693,7 @@ async def test_create_subscription_checkout_session_allows_after_cancellation(): 'server.routes.billing.SUBSCRIPTION_PRICE_DATA', {'MONTHLY_SUBSCRIPTION': {'unit_amount': 2000}}, ), + patch('server.routes.billing.validate_saas_environment'), ): # Setup mock session - the query should return None because cancelled subscriptions are filtered out mock_session = MagicMock() @@ -665,7 +702,7 @@ async def test_create_subscription_checkout_session_allows_after_cancellation(): # Should succeed result = await create_subscription_checkout_session( - mock_request, user_id='test_user' + mock_subscription_request, user_id='test_user' ) assert isinstance(result, CreateBillingSessionResponse) @@ -673,10 +710,10 @@ async def test_create_subscription_checkout_session_allows_after_cancellation(): @pytest.mark.asyncio -async def test_create_subscription_checkout_session_success_no_existing(): +async def test_create_subscription_checkout_session_success_no_existing( + mock_subscription_request, +): """Test successful subscription creation when no existing subscription.""" - mock_request = Request(scope={'type': 'http'}) - mock_request._base_url = URL('http://test.com/') mock_session_obj = MagicMock() mock_session_obj.url = 'https://checkout.stripe.com/test-session' @@ -696,6 +733,7 @@ async def test_create_subscription_checkout_session_success_no_existing(): 'server.routes.billing.SUBSCRIPTION_PRICE_DATA', {'MONTHLY_SUBSCRIPTION': {'unit_amount': 2000}}, ), + patch('server.routes.billing.validate_saas_environment'), ): # Setup mock session to return no existing subscription mock_session = MagicMock() @@ -704,7 +742,7 @@ async def test_create_subscription_checkout_session_success_no_existing(): # Should succeed result = await create_subscription_checkout_session( - mock_request, user_id='test_user' + mock_subscription_request, user_id='test_user' ) assert isinstance(result, CreateBillingSessionResponse) diff --git a/frontend/__tests__/routes/llm-settings.test.tsx b/frontend/__tests__/routes/llm-settings.test.tsx index 4a52282efc..7459579b92 100644 --- a/frontend/__tests__/routes/llm-settings.test.tsx +++ b/frontend/__tests__/routes/llm-settings.test.tsx @@ -25,6 +25,12 @@ vi.mock("#/hooks/query/use-is-authed", () => ({ useIsAuthed: () => mockUseIsAuthed(), })); +// Mock useIsAllHandsSaaSEnvironment hook +const mockUseIsAllHandsSaaSEnvironment = vi.fn(); +vi.mock("#/hooks/use-is-all-hands-saas-environment", () => ({ + useIsAllHandsSaaSEnvironment: () => mockUseIsAllHandsSaaSEnvironment(), +})); + const renderLlmSettingsScreen = () => render(, { wrapper: ({ children }) => ( @@ -48,6 +54,9 @@ beforeEach(() => { // Default mock for useIsAuthed - returns authenticated by default mockUseIsAuthed.mockReturnValue({ data: true, isLoading: false }); + + // Default mock for useIsAllHandsSaaSEnvironment - returns true for SaaS environment + mockUseIsAllHandsSaaSEnvironment.mockReturnValue(true); }); describe("Content", () => { @@ -104,7 +113,6 @@ describe("Content", () => { expect(screen.getByTestId("set-indicator")).toBeInTheDocument(); }); }); - }); describe("Advanced form", () => { diff --git a/frontend/src/hooks/use-is-all-hands-saas-environment.ts b/frontend/src/hooks/use-is-all-hands-saas-environment.ts new file mode 100644 index 0000000000..68ae66bd5d --- /dev/null +++ b/frontend/src/hooks/use-is-all-hands-saas-environment.ts @@ -0,0 +1,13 @@ +import { useMemo } from "react"; + +/** + * Hook to check if the current domain is an All Hands SaaS environment + * @returns True if the current domain contains "all-hands.dev" or "openhands.dev" postfix + */ +export const useIsAllHandsSaaSEnvironment = (): boolean => + useMemo(() => { + const { hostname } = window.location; + return ( + hostname.endsWith("all-hands.dev") || hostname.endsWith("openhands.dev") + ); + }, []); diff --git a/frontend/src/routes/llm-settings.tsx b/frontend/src/routes/llm-settings.tsx index 810074ccee..b717febaac 100644 --- a/frontend/src/routes/llm-settings.tsx +++ b/frontend/src/routes/llm-settings.tsx @@ -33,6 +33,7 @@ import { UpgradeBannerWithBackdrop } from "#/components/features/settings/upgrad import { useCreateSubscriptionCheckoutSession } from "#/hooks/mutation/stripe/use-create-subscription-checkout-session"; import { useIsAuthed } from "#/hooks/query/use-is-authed"; import { cn } from "#/utils/utils"; +import { useIsAllHandsSaaSEnvironment } from "#/hooks/use-is-all-hands-saas-environment"; interface OpenHandsApiKeyHelpProps { testId: string; @@ -78,6 +79,7 @@ function LlmSettingsScreen() { const { data: isAuthed } = useIsAuthed(); const { mutate: createSubscriptionCheckoutSession } = useCreateSubscriptionCheckoutSession(); + const isAllHandsSaaSEnvironment = useIsAllHandsSaaSEnvironment(); const [view, setView] = React.useState<"basic" | "advanced">("basic"); @@ -441,8 +443,11 @@ function LlmSettingsScreen() { if (!settings || isFetching) return ; // Show upgrade banner and disable form in SaaS mode when user doesn't have an active subscription + // Exclude self-hosted enterprise customers (those not on all-hands.dev domains) const shouldShowUpgradeBanner = - config?.APP_MODE === "saas" && !subscriptionAccess; + config?.APP_MODE === "saas" && + !subscriptionAccess && + isAllHandsSaaSEnvironment; const formAction = (formData: FormData) => { // Prevent form submission for unsubscribed SaaS users