mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-01-09 14:57:59 -05:00
fix: disable pro subscription upgrade on LLM page for self-hosted installs (#11479)
This commit is contained in:
@@ -31,6 +31,37 @@ stripe.api_key = STRIPE_API_KEY
|
||||
billing_router = APIRouter(prefix='/api/billing')
|
||||
|
||||
|
||||
# TODO: Add a new app_mode named "ON_PREM" to support self-hosted customers instead of doing this
|
||||
# and members should comment out the "validate_saas_environment" function if they are developing and testing locally.
|
||||
def is_all_hands_saas_environment(request: Request) -> bool:
|
||||
"""Check if the current domain is an All Hands SaaS environment.
|
||||
|
||||
Args:
|
||||
request: FastAPI Request object
|
||||
|
||||
Returns:
|
||||
True if the current domain contains "all-hands.dev" or "openhands.dev" postfix
|
||||
"""
|
||||
hostname = request.url.hostname or ''
|
||||
return hostname.endswith('all-hands.dev') or hostname.endswith('openhands.dev')
|
||||
|
||||
|
||||
def validate_saas_environment(request: Request) -> None:
|
||||
"""Validate that the request is coming from an All Hands SaaS environment.
|
||||
|
||||
Args:
|
||||
request: FastAPI Request object
|
||||
|
||||
Raises:
|
||||
HTTPException: If the request is not from an All Hands SaaS environment
|
||||
"""
|
||||
if not is_all_hands_saas_environment(request):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail='Checkout sessions are only available for All Hands SaaS environments',
|
||||
)
|
||||
|
||||
|
||||
class BillingSessionType(Enum):
|
||||
DIRECT_PAYMENT = 'DIRECT_PAYMENT'
|
||||
MONTHLY_SUBSCRIPTION = 'MONTHLY_SUBSCRIPTION'
|
||||
@@ -196,6 +227,8 @@ async def cancel_subscription(user_id: str = Depends(get_user_id)) -> JSONRespon
|
||||
async def create_customer_setup_session(
|
||||
request: Request, user_id: str = Depends(get_user_id)
|
||||
) -> CreateBillingSessionResponse:
|
||||
validate_saas_environment(request)
|
||||
|
||||
customer_id = await stripe_service.find_or_create_customer(user_id)
|
||||
checkout_session = await stripe.checkout.Session.create_async(
|
||||
customer=customer_id,
|
||||
@@ -214,6 +247,8 @@ async def create_checkout_session(
|
||||
request: Request,
|
||||
user_id: str = Depends(get_user_id),
|
||||
) -> CreateBillingSessionResponse:
|
||||
validate_saas_environment(request)
|
||||
|
||||
customer_id = await stripe_service.find_or_create_customer(user_id)
|
||||
checkout_session = await stripe.checkout.Session.create_async(
|
||||
customer=customer_id,
|
||||
@@ -268,6 +303,8 @@ async def create_subscription_checkout_session(
|
||||
billing_session_type: BillingSessionType = BillingSessionType.MONTHLY_SUBSCRIPTION,
|
||||
user_id: str = Depends(get_user_id),
|
||||
) -> CreateBillingSessionResponse:
|
||||
validate_saas_environment(request)
|
||||
|
||||
# Prevent duplicate subscriptions for the same user
|
||||
with session_maker() as session:
|
||||
now = datetime.now(UTC)
|
||||
@@ -343,6 +380,8 @@ async def create_subscription_checkout_session_via_get(
|
||||
user_id: str = Depends(get_user_id),
|
||||
) -> RedirectResponse:
|
||||
"""Create a subscription checkout session using a GET request (For easier copy / paste to URL bar)."""
|
||||
validate_saas_environment(request)
|
||||
|
||||
response = await create_subscription_checkout_session(
|
||||
request, billing_session_type, user_id
|
||||
)
|
||||
|
||||
@@ -36,6 +36,46 @@ def session_maker(engine):
|
||||
return sessionmaker(bind=engine)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_request():
|
||||
"""Create a mock request object with proper URL structure for testing."""
|
||||
return Request(
|
||||
scope={
|
||||
'type': 'http',
|
||||
'path': '/api/billing/test',
|
||||
'server': ('test.com', 80),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_checkout_request():
|
||||
"""Create a mock request object for checkout session tests."""
|
||||
request = Request(
|
||||
scope={
|
||||
'type': 'http',
|
||||
'path': '/api/billing/create-checkout-session',
|
||||
'server': ('test.com', 80),
|
||||
}
|
||||
)
|
||||
request._base_url = URL('http://test.com/')
|
||||
return request
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_subscription_request():
|
||||
"""Create a mock request object for subscription checkout session tests."""
|
||||
request = Request(
|
||||
scope={
|
||||
'type': 'http',
|
||||
'path': '/api/billing/subscription-checkout-session',
|
||||
'server': ('test.com', 80),
|
||||
}
|
||||
)
|
||||
request._base_url = URL('http://test.com/')
|
||||
return request
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_credits_lite_llm_error():
|
||||
mock_request = Request(scope={'type': 'http', 'state': {'user_id': 'mock_user'}})
|
||||
@@ -90,14 +130,10 @@ async def test_get_credits_success():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_checkout_session_stripe_error(session_maker):
|
||||
async def test_create_checkout_session_stripe_error(
|
||||
session_maker, mock_checkout_request
|
||||
):
|
||||
"""Test handling of Stripe API errors."""
|
||||
mock_request = Request(
|
||||
scope={
|
||||
'type': 'http',
|
||||
}
|
||||
)
|
||||
mock_request._base_url = URL('http://test.com/')
|
||||
|
||||
mock_customer = stripe.Customer(
|
||||
id='mock-customer', metadata={'user_id': 'mock-user'}
|
||||
@@ -118,17 +154,16 @@ async def test_create_checkout_session_stripe_error(session_maker):
|
||||
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
|
||||
AsyncMock(return_value={'email': 'testy@tester.com'}),
|
||||
),
|
||||
patch('server.routes.billing.validate_saas_environment'),
|
||||
):
|
||||
await create_checkout_session(
|
||||
CreateCheckoutSessionRequest(amount=25), mock_request, 'mock_user'
|
||||
CreateCheckoutSessionRequest(amount=25), mock_checkout_request, 'mock_user'
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_checkout_session_success(session_maker):
|
||||
async def test_create_checkout_session_success(session_maker, mock_checkout_request):
|
||||
"""Test successful creation of checkout session."""
|
||||
mock_request = Request(scope={'type': 'http'})
|
||||
mock_request._base_url = URL('http://test.com/')
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.url = 'https://checkout.stripe.com/test-session'
|
||||
@@ -152,12 +187,13 @@ async def test_create_checkout_session_success(session_maker):
|
||||
'server.auth.token_manager.TokenManager.get_user_info_from_user_id',
|
||||
AsyncMock(return_value={'email': 'testy@tester.com'}),
|
||||
),
|
||||
patch('server.routes.billing.validate_saas_environment'),
|
||||
):
|
||||
mock_db_session = MagicMock()
|
||||
mock_session_maker.return_value.__enter__.return_value = mock_db_session
|
||||
|
||||
result = await create_checkout_session(
|
||||
CreateCheckoutSessionRequest(amount=25), mock_request, 'mock_user'
|
||||
CreateCheckoutSessionRequest(amount=25), mock_checkout_request, 'mock_user'
|
||||
)
|
||||
|
||||
assert isinstance(result, CreateBillingSessionResponse)
|
||||
@@ -590,7 +626,9 @@ async def test_cancel_subscription_stripe_error():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_subscription_checkout_session_duplicate_prevention():
|
||||
async def test_create_subscription_checkout_session_duplicate_prevention(
|
||||
mock_subscription_request,
|
||||
):
|
||||
"""Test that creating a subscription when user already has active subscription raises error."""
|
||||
from datetime import UTC, datetime
|
||||
|
||||
@@ -609,11 +647,9 @@ async def test_create_subscription_checkout_session_duplicate_prevention():
|
||||
cancelled_at=None,
|
||||
)
|
||||
|
||||
mock_request = Request(scope={'type': 'http'})
|
||||
mock_request._base_url = URL('http://test.com/')
|
||||
|
||||
with (
|
||||
patch('server.routes.billing.session_maker') as mock_session_maker,
|
||||
patch('server.routes.billing.validate_saas_environment'),
|
||||
):
|
||||
# Setup mock session to return existing active subscription
|
||||
mock_session = MagicMock()
|
||||
@@ -623,7 +659,7 @@ async def test_create_subscription_checkout_session_duplicate_prevention():
|
||||
# Call the function and expect HTTPException
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await create_subscription_checkout_session(
|
||||
mock_request, user_id='test_user'
|
||||
mock_subscription_request, user_id='test_user'
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
@@ -634,10 +670,10 @@ async def test_create_subscription_checkout_session_duplicate_prevention():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_subscription_checkout_session_allows_after_cancellation():
|
||||
async def test_create_subscription_checkout_session_allows_after_cancellation(
|
||||
mock_subscription_request,
|
||||
):
|
||||
"""Test that creating a subscription is allowed when previous subscription was cancelled."""
|
||||
mock_request = Request(scope={'type': 'http'})
|
||||
mock_request._base_url = URL('http://test.com/')
|
||||
|
||||
mock_session_obj = MagicMock()
|
||||
mock_session_obj.url = 'https://checkout.stripe.com/test-session'
|
||||
@@ -657,6 +693,7 @@ async def test_create_subscription_checkout_session_allows_after_cancellation():
|
||||
'server.routes.billing.SUBSCRIPTION_PRICE_DATA',
|
||||
{'MONTHLY_SUBSCRIPTION': {'unit_amount': 2000}},
|
||||
),
|
||||
patch('server.routes.billing.validate_saas_environment'),
|
||||
):
|
||||
# Setup mock session - the query should return None because cancelled subscriptions are filtered out
|
||||
mock_session = MagicMock()
|
||||
@@ -665,7 +702,7 @@ async def test_create_subscription_checkout_session_allows_after_cancellation():
|
||||
|
||||
# Should succeed
|
||||
result = await create_subscription_checkout_session(
|
||||
mock_request, user_id='test_user'
|
||||
mock_subscription_request, user_id='test_user'
|
||||
)
|
||||
|
||||
assert isinstance(result, CreateBillingSessionResponse)
|
||||
@@ -673,10 +710,10 @@ async def test_create_subscription_checkout_session_allows_after_cancellation():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_subscription_checkout_session_success_no_existing():
|
||||
async def test_create_subscription_checkout_session_success_no_existing(
|
||||
mock_subscription_request,
|
||||
):
|
||||
"""Test successful subscription creation when no existing subscription."""
|
||||
mock_request = Request(scope={'type': 'http'})
|
||||
mock_request._base_url = URL('http://test.com/')
|
||||
|
||||
mock_session_obj = MagicMock()
|
||||
mock_session_obj.url = 'https://checkout.stripe.com/test-session'
|
||||
@@ -696,6 +733,7 @@ async def test_create_subscription_checkout_session_success_no_existing():
|
||||
'server.routes.billing.SUBSCRIPTION_PRICE_DATA',
|
||||
{'MONTHLY_SUBSCRIPTION': {'unit_amount': 2000}},
|
||||
),
|
||||
patch('server.routes.billing.validate_saas_environment'),
|
||||
):
|
||||
# Setup mock session to return no existing subscription
|
||||
mock_session = MagicMock()
|
||||
@@ -704,7 +742,7 @@ async def test_create_subscription_checkout_session_success_no_existing():
|
||||
|
||||
# Should succeed
|
||||
result = await create_subscription_checkout_session(
|
||||
mock_request, user_id='test_user'
|
||||
mock_subscription_request, user_id='test_user'
|
||||
)
|
||||
|
||||
assert isinstance(result, CreateBillingSessionResponse)
|
||||
|
||||
Reference in New Issue
Block a user