mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-04-29 03:00:45 -04:00
Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| ff5274f60b | |||
| d57e6c134c | |||
| d3cc121b08 |
@@ -234,8 +234,6 @@ yarn-error.log*
|
||||
|
||||
logs
|
||||
|
||||
ralph/
|
||||
|
||||
# agent
|
||||
.envrc
|
||||
/workspace
|
||||
|
||||
Generated
+29
-27
@@ -6190,14 +6190,14 @@ llama = ["llama-index (>=0.12.29,<0.13.0)", "llama-index-core (>=0.12.29,<0.13.0
|
||||
|
||||
[[package]]
|
||||
name = "openhands-agent-server"
|
||||
version = "1.13.0"
|
||||
version = "1.12.0"
|
||||
description = "OpenHands Agent Server - REST/WebSocket interface for OpenHands AI Agent"
|
||||
optional = false
|
||||
python-versions = ">=3.12"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "openhands_agent_server-1.13.0-py3-none-any.whl", hash = "sha256:88bb8bfb03ff0cc7a7d32ffabd108d0a284f4333f33a9de27ce158b6d828bc29"},
|
||||
{file = "openhands_agent_server-1.13.0.tar.gz", hash = "sha256:6f8b296c0f26a478d4eb49668a353e2b6997c39022c2bbcc36325f5f08887a7a"},
|
||||
{file = "openhands_agent_server-1.12.0-py3-none-any.whl", hash = "sha256:3bd62fef10092f1155af116a8a7417041d574eff9d4e4b6f7a24bfc432de2fad"},
|
||||
{file = "openhands_agent_server-1.12.0.tar.gz", hash = "sha256:7ea7ce579175f713ed68b68cde5d685ef694627ac7bbff40d2e22913f065c46d"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -6214,7 +6214,7 @@ wsproto = ">=1.2.0"
|
||||
|
||||
[[package]]
|
||||
name = "openhands-ai"
|
||||
version = "1.5.0"
|
||||
version = "1.4.0"
|
||||
description = "OpenHands: Code Less, Make More"
|
||||
optional = false
|
||||
python-versions = "^3.12,<3.14"
|
||||
@@ -6259,9 +6259,9 @@ memory-profiler = ">=0.61"
|
||||
numpy = "*"
|
||||
openai = "2.8"
|
||||
openhands-aci = "0.3.3"
|
||||
openhands-agent-server = "1.13"
|
||||
openhands-sdk = "1.13"
|
||||
openhands-tools = "1.13"
|
||||
openhands-agent-server = "1.12"
|
||||
openhands-sdk = "1.12"
|
||||
openhands-tools = "1.12"
|
||||
opentelemetry-api = ">=1.33.1"
|
||||
opentelemetry-exporter-otlp-proto-grpc = ">=1.33.1"
|
||||
pathspec = ">=0.12.1"
|
||||
@@ -6315,14 +6315,14 @@ url = ".."
|
||||
|
||||
[[package]]
|
||||
name = "openhands-sdk"
|
||||
version = "1.13.0"
|
||||
version = "1.12.0"
|
||||
description = "OpenHands SDK - Core functionality for building AI agents"
|
||||
optional = false
|
||||
python-versions = ">=3.12"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "openhands_sdk-1.13.0-py3-none-any.whl", hash = "sha256:ec83f9fa2934aae9c4ce1c0365a7037f7e17869affa44a40e71ba49d2bef7185"},
|
||||
{file = "openhands_sdk-1.13.0.tar.gz", hash = "sha256:fbb2a2dc4852ea23cc697a36fb3f95ca47cfef432b0d195c496de6f374caad9c"},
|
||||
{file = "openhands_sdk-1.12.0-py3-none-any.whl", hash = "sha256:857793f5c27fd63c0d4d37762550e6c504a03dd06116475c23adcc14bb5c4c02"},
|
||||
{file = "openhands_sdk-1.12.0.tar.gz", hash = "sha256:ac348e7134ea21e1ab453978962504aff8eb47e62df1fb7a503d769d55658ea9"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -6345,14 +6345,14 @@ boto3 = ["boto3 (>=1.35.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "openhands-tools"
|
||||
version = "1.13.0"
|
||||
version = "1.12.0"
|
||||
description = "OpenHands Tools - Runtime tools for AI agents"
|
||||
optional = false
|
||||
python-versions = ">=3.12"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "openhands_tools-1.13.0-py3-none-any.whl", hash = "sha256:87073b868e20f9c769497f480e0d15b14ca41314c3d1cb5076029f37408a1d68"},
|
||||
{file = "openhands_tools-1.13.0.tar.gz", hash = "sha256:e1181701efab5bc3133566e3b1640027824147438959cd8ce7430c941896704d"},
|
||||
{file = "openhands_tools-1.12.0-py3-none-any.whl", hash = "sha256:57207e9e30f9d7fe9121cd21b072580cfdc2a00831edeaf8e8d685d721bb9e33"},
|
||||
{file = "openhands_tools-1.12.0.tar.gz", hash = "sha256:f2b4d81d0b6771f5416f8b702db09a14999fa8e553073bcf38f344e29aae770c"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -11599,14 +11599,14 @@ diagrams = ["jinja2", "railroad-diagrams"]
|
||||
|
||||
[[package]]
|
||||
name = "pypdf"
|
||||
version = "6.8.0"
|
||||
version = "6.7.5"
|
||||
description = "A pure-python PDF library capable of splitting, merging, cropping, and transforming PDF files"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "pypdf-6.8.0-py3-none-any.whl", hash = "sha256:2a025080a8dd73f48123c89c57174a5ff3806c71763ee4e49572dc90454943c7"},
|
||||
{file = "pypdf-6.8.0.tar.gz", hash = "sha256:cb7eaeaa4133ce76f762184069a854e03f4d9a08568f0e0623f7ea810407833b"},
|
||||
{file = "pypdf-6.7.5-py3-none-any.whl", hash = "sha256:07ba7f1d6e6d9aa2a17f5452e320a84718d4ce863367f7ede2fd72280349ab13"},
|
||||
{file = "pypdf-6.7.5.tar.gz", hash = "sha256:40bb2e2e872078655f12b9b89e2f900888bb505e88a82150b64f9f34fa25651d"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
@@ -13771,22 +13771,24 @@ files = [
|
||||
|
||||
[[package]]
|
||||
name = "tornado"
|
||||
version = "6.5.5"
|
||||
version = "6.5.4"
|
||||
description = "Tornado is a Python web framework and asynchronous networking library, originally developed at FriendFeed."
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "tornado-6.5.5-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:487dc9cc380e29f58c7ab88f9e27cdeef04b2140862e5076a66fb6bb68bb1bfa"},
|
||||
{file = "tornado-6.5.5-cp39-abi3-macosx_10_9_x86_64.whl", hash = "sha256:65a7f1d46d4bb41df1ac99f5fcb685fb25c7e61613742d5108b010975a9a6521"},
|
||||
{file = "tornado-6.5.5-cp39-abi3-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:e74c92e8e65086b338fd56333fb9a68b9f6f2fe7ad532645a290a464bcf46be5"},
|
||||
{file = "tornado-6.5.5-cp39-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:435319e9e340276428bbdb4e7fa732c2d399386d1de5686cb331ec8eee754f07"},
|
||||
{file = "tornado-6.5.5-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:3f54aa540bdbfee7b9eb268ead60e7d199de5021facd276819c193c0fb28ea4e"},
|
||||
{file = "tornado-6.5.5-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:36abed1754faeb80fbd6e64db2758091e1320f6bba74a4cf8c09cd18ccce8aca"},
|
||||
{file = "tornado-6.5.5-cp39-abi3-win32.whl", hash = "sha256:dd3eafaaeec1c7f2f8fdcd5f964e8907ad788fe8a5a32c4426fbbdda621223b7"},
|
||||
{file = "tornado-6.5.5-cp39-abi3-win_amd64.whl", hash = "sha256:6443a794ba961a9f619b1ae926a2e900ac20c34483eea67be4ed8f1e58d3ef7b"},
|
||||
{file = "tornado-6.5.5-cp39-abi3-win_arm64.whl", hash = "sha256:2c9a876e094109333f888539ddb2de4361743e5d21eece20688e3e351e4990a6"},
|
||||
{file = "tornado-6.5.5.tar.gz", hash = "sha256:192b8f3ea91bd7f1f50c06955416ed76c6b72f96779b962f07f911b91e8d30e9"},
|
||||
{file = "tornado-6.5.4-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:d6241c1a16b1c9e4cc28148b1cda97dd1c6cb4fb7068ac1bedc610768dff0ba9"},
|
||||
{file = "tornado-6.5.4-cp39-abi3-macosx_10_9_x86_64.whl", hash = "sha256:2d50f63dda1d2cac3ae1fa23d254e16b5e38153758470e9956cbc3d813d40843"},
|
||||
{file = "tornado-6.5.4-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d1cf66105dc6acb5af613c054955b8137e34a03698aa53272dbda4afe252be17"},
|
||||
{file = "tornado-6.5.4-cp39-abi3-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:50ff0a58b0dc97939d29da29cd624da010e7f804746621c78d14b80238669335"},
|
||||
{file = "tornado-6.5.4-cp39-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e5fb5e04efa54cf0baabdd10061eb4148e0be137166146fff835745f59ab9f7f"},
|
||||
{file = "tornado-6.5.4-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:9c86b1643b33a4cd415f8d0fe53045f913bf07b4a3ef646b735a6a86047dda84"},
|
||||
{file = "tornado-6.5.4-cp39-abi3-musllinux_1_2_i686.whl", hash = "sha256:6eb82872335a53dd063a4f10917b3efd28270b56a33db69009606a0312660a6f"},
|
||||
{file = "tornado-6.5.4-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:6076d5dda368c9328ff41ab5d9dd3608e695e8225d1cd0fd1e006f05da3635a8"},
|
||||
{file = "tornado-6.5.4-cp39-abi3-win32.whl", hash = "sha256:1768110f2411d5cd281bac0a090f707223ce77fd110424361092859e089b38d1"},
|
||||
{file = "tornado-6.5.4-cp39-abi3-win_amd64.whl", hash = "sha256:fa07d31e0cd85c60713f2b995da613588aa03e1303d75705dca6af8babc18ddc"},
|
||||
{file = "tornado-6.5.4-cp39-abi3-win_arm64.whl", hash = "sha256:053e6e16701eb6cbe641f308f4c1a9541f91b6261991160391bfc342e8a551a1"},
|
||||
{file = "tornado-6.5.4.tar.gz", hash = "sha256:a22fa9047405d03260b483980635f0b041989d8bcc9a313f8fe18b411d84b1d7"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
@@ -0,0 +1,71 @@
|
||||
"""Entry point for the automation executor.
|
||||
|
||||
Usage: python -m run_automation_executor
|
||||
|
||||
This runs as a Kubernetes Deployment (long-running). It polls the automation_events
|
||||
inbox, matches events to automations, claims and executes runs, and monitors
|
||||
conversation completion.
|
||||
|
||||
Environment variables:
|
||||
OPENHANDS_API_URL Base URL for the V1 API (default: http://openhands-service:3000)
|
||||
MAX_CONCURRENT_RUNS Max concurrent runs per executor (default: 5)
|
||||
RUN_TIMEOUT_SECONDS Max time for a single run (default: 7200)
|
||||
POLL_INTERVAL_SECONDS Fallback poll interval (default: 30)
|
||||
HEARTBEAT_INTERVAL_SECONDS Heartbeat update interval (default: 60)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import signal
|
||||
import sys
|
||||
|
||||
logger = logging.getLogger('saas.automation.executor')
|
||||
|
||||
|
||||
def _setup_logging() -> None:
|
||||
"""Configure logging, deferring to enterprise logger if available."""
|
||||
try:
|
||||
from server.logger import setup_all_loggers
|
||||
|
||||
setup_all_loggers()
|
||||
except ImportError:
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s %(name)s %(levelname)s %(message)s',
|
||||
stream=sys.stdout,
|
||||
)
|
||||
|
||||
|
||||
def _install_signal_handlers(loop: asyncio.AbstractEventLoop) -> None:
|
||||
"""Install signal handlers for graceful shutdown."""
|
||||
from services.automation_executor import request_shutdown
|
||||
|
||||
def _handle_signal(signum: int, _frame: object) -> None:
|
||||
sig_name = signal.Signals(signum).name
|
||||
logger.info('Received %s, initiating graceful shutdown...', sig_name)
|
||||
request_shutdown()
|
||||
|
||||
for sig in (signal.SIGTERM, signal.SIGINT):
|
||||
signal.signal(sig, _handle_signal)
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
from services.automation_executor import executor_main
|
||||
|
||||
await executor_main()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
_setup_logging()
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
_install_signal_handlers(loop)
|
||||
|
||||
logger.info('Starting automation executor')
|
||||
try:
|
||||
loop.run_until_complete(main())
|
||||
except KeyboardInterrupt:
|
||||
logger.info('Interrupted by user')
|
||||
finally:
|
||||
loop.close()
|
||||
logger.info('Automation executor process exiting')
|
||||
@@ -77,9 +77,6 @@ PERMITTED_CORS_ORIGINS = [
|
||||
)
|
||||
]
|
||||
|
||||
# Controls whether new orgs/users default to V1 API (env: DEFAULT_V1_ENABLED)
|
||||
DEFAULT_V1_ENABLED = os.getenv('DEFAULT_V1_ENABLED', '1').lower() in ('1', 'true')
|
||||
|
||||
|
||||
def build_litellm_proxy_model_path(model_name: str) -> str:
|
||||
"""Build the LiteLLM proxy model path based on model name.
|
||||
|
||||
@@ -12,8 +12,11 @@ from server.auth.auth_error import (
|
||||
)
|
||||
from server.auth.gitlab_sync import schedule_gitlab_repo_sync
|
||||
from server.auth.saas_user_auth import SaasUserAuth, token_manager
|
||||
from server.routes.auth import set_response_cookie
|
||||
from server.utils.url_utils import get_cookie_domain, get_cookie_samesite
|
||||
from server.routes.auth import (
|
||||
get_cookie_domain,
|
||||
get_cookie_samesite,
|
||||
set_response_cookie,
|
||||
)
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.server.user_auth.user_auth import AuthType, UserAuth, get_user_auth
|
||||
@@ -90,8 +93,8 @@ class SetAuthCookieMiddleware:
|
||||
if keycloak_auth_cookie:
|
||||
response.delete_cookie(
|
||||
key='keycloak_auth',
|
||||
domain=get_cookie_domain(),
|
||||
samesite=get_cookie_samesite(),
|
||||
domain=get_cookie_domain(request),
|
||||
samesite=get_cookie_samesite(request),
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ import json
|
||||
import uuid
|
||||
import warnings
|
||||
from datetime import datetime, timezone
|
||||
from typing import Annotated, Optional, cast
|
||||
from typing import Annotated, Literal, Optional, cast
|
||||
from urllib.parse import quote, urlencode
|
||||
from uuid import UUID as parse_uuid
|
||||
|
||||
@@ -27,7 +27,7 @@ from server.auth.user.user_authorizer import (
|
||||
depends_user_authorizer,
|
||||
)
|
||||
from server.config import sign_token
|
||||
from server.constants import IS_FEATURE_ENV, IS_LOCAL_ENV
|
||||
from server.constants import IS_FEATURE_ENV
|
||||
from server.routes.event_webhook import _get_session_api_key, _get_user_id
|
||||
from server.services.org_invitation_service import (
|
||||
EmailMismatchError,
|
||||
@@ -37,12 +37,12 @@ from server.services.org_invitation_service import (
|
||||
UserAlreadyMemberError,
|
||||
)
|
||||
from server.utils.rate_limit_utils import check_rate_limit_by_user_id
|
||||
from server.utils.url_utils import get_cookie_domain, get_cookie_samesite, get_web_url
|
||||
from sqlalchemy import select
|
||||
from storage.database import a_session_maker
|
||||
from storage.user import User
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.app_server.config import get_global_config
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.provider import ProviderHandler
|
||||
from openhands.integrations.service_types import ProviderType, TokenResponse
|
||||
@@ -77,7 +77,7 @@ def set_response_cookie(
|
||||
signed_token = sign_token(cookie_data, config.jwt_secret.get_secret_value()) # type: ignore
|
||||
|
||||
# Set secure cookie with signed token
|
||||
domain = get_cookie_domain()
|
||||
domain = get_cookie_domain(request)
|
||||
if domain:
|
||||
response.set_cookie(
|
||||
key='keycloak_auth',
|
||||
@@ -85,7 +85,7 @@ def set_response_cookie(
|
||||
domain=domain,
|
||||
httponly=True,
|
||||
secure=secure,
|
||||
samesite=get_cookie_samesite(),
|
||||
samesite=get_cookie_samesite(request),
|
||||
)
|
||||
else:
|
||||
response.set_cookie(
|
||||
@@ -93,10 +93,30 @@ def set_response_cookie(
|
||||
value=signed_token,
|
||||
httponly=True,
|
||||
secure=secure,
|
||||
samesite=get_cookie_samesite(),
|
||||
samesite=get_cookie_samesite(request),
|
||||
)
|
||||
|
||||
|
||||
def get_cookie_domain(request: Request) -> str | None:
|
||||
# for now just use the full hostname except for staging stacks.
|
||||
return (
|
||||
None
|
||||
if not request.url.hostname
|
||||
or request.url.hostname.endswith('staging.all-hands.dev')
|
||||
else request.url.hostname
|
||||
)
|
||||
|
||||
|
||||
def get_cookie_samesite(request: Request) -> Literal['lax', 'strict']:
|
||||
# for localhost and feature/staging stacks we set it to 'lax' as the cookie domain won't allow 'strict'
|
||||
return (
|
||||
'lax'
|
||||
if request.url.hostname == 'localhost'
|
||||
or (request.url.hostname or '').endswith('staging.all-hands.dev')
|
||||
else 'strict'
|
||||
)
|
||||
|
||||
|
||||
def _extract_oauth_state(state: str | None) -> tuple[str, str | None, str | None]:
|
||||
"""Extract redirect URL, reCAPTCHA token, and invitation token from OAuth state.
|
||||
|
||||
@@ -120,6 +140,19 @@ def _extract_oauth_state(state: str | None) -> tuple[str, str | None, str | None
|
||||
return state, None, None
|
||||
|
||||
|
||||
# Keep alias for backward compatibility
|
||||
def _extract_recaptcha_state(state: str | None) -> tuple[str, str | None]:
|
||||
"""Extract redirect URL and reCAPTCHA token from OAuth state.
|
||||
|
||||
Deprecated: Use _extract_oauth_state instead.
|
||||
|
||||
Returns:
|
||||
Tuple of (redirect_url, recaptcha_token). Token may be None.
|
||||
"""
|
||||
redirect_url, recaptcha_token, _ = _extract_oauth_state(state)
|
||||
return redirect_url, recaptcha_token
|
||||
|
||||
|
||||
@oauth_router.get('/keycloak/callback')
|
||||
async def keycloak_callback(
|
||||
request: Request,
|
||||
@@ -150,7 +183,10 @@ async def keycloak_callback(
|
||||
detail='Missing code in request params',
|
||||
)
|
||||
|
||||
web_url = get_web_url(request)
|
||||
web_url = get_global_config().web_url
|
||||
if not web_url:
|
||||
scheme = 'http' if request.url.hostname == 'localhost' else 'https'
|
||||
web_url = f'{scheme}://{request.url.netloc}'
|
||||
redirect_uri = web_url + request.url.path
|
||||
|
||||
(
|
||||
@@ -277,9 +313,7 @@ async def keycloak_callback(
|
||||
else:
|
||||
raise
|
||||
|
||||
verification_redirect_url = (
|
||||
f'{web_url}/login?email_verification_required=true&user_id={user_id}'
|
||||
)
|
||||
verification_redirect_url = f'{request.base_url}login?email_verification_required=true&user_id={user_id}'
|
||||
if rate_limited:
|
||||
verification_redirect_url = f'{verification_redirect_url}&rate_limited=true'
|
||||
|
||||
@@ -440,7 +474,9 @@ async def keycloak_callback(
|
||||
# If the user hasn't accepted the TOS, redirect to the TOS page
|
||||
if not has_accepted_tos:
|
||||
encoded_redirect_url = quote(redirect_url, safe='')
|
||||
tos_redirect_url = f'{web_url}/accept-tos?redirect_url={encoded_redirect_url}'
|
||||
tos_redirect_url = (
|
||||
f'{request.base_url}accept-tos?redirect_url={encoded_redirect_url}'
|
||||
)
|
||||
if invitation_token:
|
||||
tos_redirect_url = f'{tos_redirect_url}&invitation_success=true'
|
||||
response = RedirectResponse(tos_redirect_url, status_code=302)
|
||||
@@ -472,9 +508,10 @@ async def keycloak_offline_callback(code: str, state: str, request: Request):
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content={'error': 'Missing code in request params'},
|
||||
)
|
||||
|
||||
web_url = get_web_url(request)
|
||||
redirect_uri = web_url + request.url.path
|
||||
scheme = 'https'
|
||||
if request.url.hostname == 'localhost':
|
||||
scheme = 'http'
|
||||
redirect_uri = f'{scheme}://{request.url.netloc}{request.url.path}'
|
||||
logger.debug(f'code: {code}, redirect_uri: {redirect_uri}')
|
||||
|
||||
(
|
||||
@@ -496,14 +533,15 @@ async def keycloak_offline_callback(code: str, state: str, request: Request):
|
||||
)
|
||||
|
||||
redirect_url, _, _ = _extract_oauth_state(state)
|
||||
return RedirectResponse(redirect_url if redirect_url else web_url, status_code=302)
|
||||
return RedirectResponse(
|
||||
redirect_url if redirect_url else request.base_url, status_code=302
|
||||
)
|
||||
|
||||
|
||||
@oauth_router.get('/github/callback')
|
||||
async def github_dummy_callback(request: Request):
|
||||
"""Callback for GitHub that just forwards the user to the app base URL."""
|
||||
web_url = get_web_url(request)
|
||||
return RedirectResponse(web_url, status_code=302)
|
||||
return RedirectResponse(request.base_url, status_code=302)
|
||||
|
||||
|
||||
@api_router.post('/authenticate')
|
||||
@@ -525,8 +563,8 @@ async def authenticate(request: Request):
|
||||
if keycloak_auth_cookie:
|
||||
response.delete_cookie(
|
||||
key='keycloak_auth',
|
||||
domain=get_cookie_domain(),
|
||||
samesite=get_cookie_samesite(),
|
||||
domain=get_cookie_domain(request),
|
||||
samesite=get_cookie_samesite(request),
|
||||
)
|
||||
|
||||
return response
|
||||
@@ -550,8 +588,7 @@ async def accept_tos(request: Request):
|
||||
|
||||
# Get redirect URL from request body
|
||||
body = await request.json()
|
||||
web_url = get_web_url(request)
|
||||
redirect_url = body.get('redirect_url', str(web_url))
|
||||
redirect_url = body.get('redirect_url', str(request.base_url))
|
||||
|
||||
# Update user settings with TOS acceptance
|
||||
accepted_tos: datetime = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
@@ -581,7 +618,7 @@ async def accept_tos(request: Request):
|
||||
response=response,
|
||||
keycloak_access_token=access_token.get_secret_value(),
|
||||
keycloak_refresh_token=refresh_token.get_secret_value(),
|
||||
secure=not IS_LOCAL_ENV,
|
||||
secure=False if request.url.hostname == 'localhost' else True,
|
||||
accepted_tos=True,
|
||||
)
|
||||
return response
|
||||
@@ -598,8 +635,8 @@ async def logout(request: Request):
|
||||
# Always delete the cookie regardless of what happens
|
||||
response.delete_cookie(
|
||||
key='keycloak_auth',
|
||||
domain=get_cookie_domain(),
|
||||
samesite=get_cookie_samesite(),
|
||||
domain=get_cookie_domain(request),
|
||||
samesite=get_cookie_samesite(request),
|
||||
)
|
||||
|
||||
# Try to properly logout from Keycloak, but don't fail if it doesn't work
|
||||
|
||||
@@ -11,8 +11,8 @@ from integrations import stripe_service
|
||||
from pydantic import BaseModel
|
||||
from server.constants import STRIPE_API_KEY
|
||||
from server.logger import logger
|
||||
from server.utils.url_utils import get_web_url
|
||||
from sqlalchemy import select
|
||||
from starlette.datastructures import URL
|
||||
from storage.billing_session import BillingSession
|
||||
from storage.database import a_session_maker
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
@@ -151,7 +151,7 @@ async def create_customer_setup_session(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail='Could not find or create customer for user',
|
||||
)
|
||||
base_url = get_web_url(request)
|
||||
base_url = _get_base_url(request)
|
||||
checkout_session = await stripe.checkout.Session.create_async(
|
||||
customer=customer_info['customer_id'],
|
||||
mode='setup',
|
||||
@@ -170,7 +170,7 @@ async def create_checkout_session(
|
||||
user_id: str = Depends(get_user_id),
|
||||
) -> CreateBillingSessionResponse:
|
||||
await validate_billing_enabled()
|
||||
base_url = get_web_url(request)
|
||||
base_url = _get_base_url(request)
|
||||
customer_info = await stripe_service.find_or_create_customer_by_user_id(user_id)
|
||||
if not customer_info:
|
||||
raise HTTPException(
|
||||
@@ -198,8 +198,8 @@ async def create_checkout_session(
|
||||
saved_payment_method_options={
|
||||
'payment_method_save': 'enabled',
|
||||
},
|
||||
success_url=f'{base_url}/api/billing/success?session_id={{CHECKOUT_SESSION_ID}}',
|
||||
cancel_url=f'{base_url}/api/billing/cancel?session_id={{CHECKOUT_SESSION_ID}}',
|
||||
success_url=f'{base_url}api/billing/success?session_id={{CHECKOUT_SESSION_ID}}',
|
||||
cancel_url=f'{base_url}api/billing/cancel?session_id={{CHECKOUT_SESSION_ID}}',
|
||||
)
|
||||
logger.info(
|
||||
'created_stripe_checkout_session',
|
||||
@@ -300,7 +300,7 @@ async def success_callback(session_id: str, request: Request):
|
||||
await session.commit()
|
||||
|
||||
return RedirectResponse(
|
||||
f'{get_web_url(request)}/settings/billing?checkout=success', status_code=302
|
||||
f'{_get_base_url(request)}settings/billing?checkout=success', status_code=302
|
||||
)
|
||||
|
||||
|
||||
@@ -325,9 +325,17 @@ async def cancel_callback(session_id: str, request: Request):
|
||||
)
|
||||
billing_session.status = 'cancelled'
|
||||
billing_session.updated_at = datetime.now(UTC)
|
||||
await session.merge(billing_session)
|
||||
session.merge(billing_session)
|
||||
await session.commit()
|
||||
|
||||
return RedirectResponse(
|
||||
f'{get_web_url(request)}/settings/billing?checkout=cancel', status_code=302
|
||||
f'{_get_base_url(request)}settings/billing?checkout=cancel', status_code=302
|
||||
)
|
||||
|
||||
|
||||
def _get_base_url(request: Request) -> URL:
|
||||
# Never send any part of the credit card process over a non secure connection
|
||||
base_url = request.base_url
|
||||
if base_url.hostname != 'localhost':
|
||||
base_url = base_url.replace(scheme='https')
|
||||
return base_url
|
||||
|
||||
@@ -7,10 +7,8 @@ from pydantic import BaseModel, field_validator
|
||||
from server.auth.constants import KEYCLOAK_CLIENT_ID
|
||||
from server.auth.keycloak_manager import get_keycloak_admin
|
||||
from server.auth.saas_user_auth import SaasUserAuth
|
||||
from server.constants import IS_LOCAL_ENV
|
||||
from server.routes.auth import set_response_cookie
|
||||
from server.utils.rate_limit_utils import check_rate_limit_by_user_id
|
||||
from server.utils.url_utils import get_web_url
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
@@ -89,7 +87,7 @@ async def update_email(
|
||||
response=response,
|
||||
keycloak_access_token=user_auth.access_token.get_secret_value(),
|
||||
keycloak_refresh_token=user_auth.refresh_token.get_secret_value(),
|
||||
secure=not IS_LOCAL_ENV,
|
||||
secure=False if request.url.hostname == 'localhost' else True,
|
||||
accepted_tos=user_auth.accepted_tos or False,
|
||||
)
|
||||
|
||||
@@ -158,8 +156,8 @@ async def verified_email(request: Request):
|
||||
await user_auth.refresh() # refresh so access token has updated email
|
||||
user_auth.email_verified = True
|
||||
await UserStore.update_user_email(user_id=user_auth.user_id, email_verified=True)
|
||||
|
||||
redirect_uri = f'{get_web_url(request)}/settings/user'
|
||||
scheme = 'http' if request.url.hostname == 'localhost' else 'https'
|
||||
redirect_uri = f'{scheme}://{request.url.netloc}/settings/user'
|
||||
response = RedirectResponse(redirect_uri, status_code=302)
|
||||
|
||||
# need to set auth cookie to the new tokens
|
||||
@@ -182,10 +180,11 @@ async def verified_email(request: Request):
|
||||
|
||||
async def verify_email(request: Request, user_id: str, is_auth_flow: bool = False):
|
||||
keycloak_admin = get_keycloak_admin()
|
||||
scheme = 'http' if request.url.hostname == 'localhost' else 'https'
|
||||
if is_auth_flow:
|
||||
redirect_uri = f'{get_web_url(request)}/login?email_verified=true'
|
||||
redirect_uri = f'{scheme}://{request.url.netloc}/login?email_verified=true'
|
||||
else:
|
||||
redirect_uri = f'{get_web_url(request)}/api/email/verified'
|
||||
redirect_uri = f'{scheme}://{request.url.netloc}/api/email/verified'
|
||||
logger.info(f'Redirect URI: {redirect_uri}')
|
||||
await keycloak_admin.a_send_verify_email(
|
||||
user_id=user_id,
|
||||
|
||||
@@ -6,7 +6,6 @@ from typing import Optional
|
||||
from fastapi import APIRouter, Depends, Form, HTTPException, Request, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
from server.utils.url_utils import get_web_url
|
||||
from storage.api_key_store import ApiKeyStore
|
||||
from storage.device_code_store import DeviceCodeStore
|
||||
|
||||
@@ -94,7 +93,7 @@ async def device_authorization(
|
||||
expires_in=DEVICE_CODE_EXPIRES_IN,
|
||||
)
|
||||
|
||||
base_url = get_web_url(http_request)
|
||||
base_url = str(http_request.base_url).rstrip('/')
|
||||
verification_uri = f'{base_url}/oauth/device/verify'
|
||||
verification_uri_complete = (
|
||||
f'{verification_uri}?user_code={device_code_entry.user_code}'
|
||||
|
||||
@@ -365,12 +365,14 @@ class OrgInvitationService:
|
||||
'Failed to set up organization access. Please try again.'
|
||||
)
|
||||
|
||||
# Step 4.5: Fetch organization to get its LLM settings
|
||||
org = await OrgStore.get_org_by_id(invitation.org_id)
|
||||
if not org:
|
||||
raise InvitationInvalidError('Organization not found')
|
||||
# Step 5: Add user to organization
|
||||
from storage.org_member_store import OrgMemberStore as OMS
|
||||
|
||||
org_member_kwargs = OMS.get_kwargs_from_settings(settings)
|
||||
# Don't override with org defaults - use invitation-specified role
|
||||
org_member_kwargs.pop('llm_model', None)
|
||||
org_member_kwargs.pop('llm_base_url', None)
|
||||
|
||||
# Step 5: Add user to organization with inherited org LLM settings
|
||||
# Get the llm_api_key as string (it's SecretStr | None in Settings)
|
||||
llm_api_key = (
|
||||
settings.llm_api_key.get_secret_value() if settings.llm_api_key else ''
|
||||
@@ -382,9 +384,6 @@ class OrgInvitationService:
|
||||
role_id=invitation.role_id,
|
||||
llm_api_key=llm_api_key,
|
||||
status='active',
|
||||
llm_model=org.default_llm_model,
|
||||
llm_base_url=org.default_llm_base_url,
|
||||
max_iterations=org.default_max_iterations,
|
||||
)
|
||||
|
||||
# Step 6: Mark invitation as accepted
|
||||
|
||||
@@ -1,171 +0,0 @@
|
||||
"""Implementation of SharedEventService for AWS S3.
|
||||
|
||||
This implementation provides read-only access to events from shared conversations:
|
||||
- Validates that the conversation is shared before returning events
|
||||
- Uses existing EventService for actual event retrieval
|
||||
- Uses SharedConversationInfoService for shared conversation validation
|
||||
|
||||
Uses role-based authentication (no credentials needed).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, AsyncGenerator
|
||||
from uuid import UUID
|
||||
|
||||
import boto3
|
||||
from fastapi import Request
|
||||
from pydantic import Field
|
||||
from server.sharing.shared_conversation_info_service import (
|
||||
SharedConversationInfoService,
|
||||
)
|
||||
from server.sharing.shared_event_service import (
|
||||
SharedEventService,
|
||||
SharedEventServiceInjector,
|
||||
)
|
||||
from server.sharing.sql_shared_conversation_info_service import (
|
||||
SQLSharedConversationInfoService,
|
||||
)
|
||||
|
||||
from openhands.agent_server.models import EventPage, EventSortOrder
|
||||
from openhands.app_server.event.aws_event_service import AwsEventService
|
||||
from openhands.app_server.event.event_service import EventService
|
||||
from openhands.app_server.event_callback.event_callback_models import EventKind
|
||||
from openhands.app_server.services.injector import InjectorState
|
||||
from openhands.sdk import Event
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AwsSharedEventService(SharedEventService):
|
||||
"""Implementation of SharedEventService for AWS S3 that validates shared access.
|
||||
|
||||
Uses role-based authentication (no credentials needed).
|
||||
"""
|
||||
|
||||
shared_conversation_info_service: SharedConversationInfoService
|
||||
s3_client: Any
|
||||
bucket_name: str
|
||||
|
||||
async def get_event_service(self, conversation_id: UUID) -> EventService | None:
|
||||
shared_conversation_info = (
|
||||
await self.shared_conversation_info_service.get_shared_conversation_info(
|
||||
conversation_id
|
||||
)
|
||||
)
|
||||
if shared_conversation_info is None:
|
||||
return None
|
||||
|
||||
return AwsEventService(
|
||||
s3_client=self.s3_client,
|
||||
bucket_name=self.bucket_name,
|
||||
prefix=Path('users'),
|
||||
user_id=shared_conversation_info.created_by_user_id,
|
||||
app_conversation_info_service=None,
|
||||
app_conversation_info_load_tasks={},
|
||||
)
|
||||
|
||||
async def get_shared_event(
|
||||
self, conversation_id: UUID, event_id: UUID
|
||||
) -> Event | None:
|
||||
"""Given a conversation_id and event_id, retrieve an event if the conversation is shared."""
|
||||
# First check if the conversation is shared
|
||||
event_service = await self.get_event_service(conversation_id)
|
||||
if event_service is None:
|
||||
return None
|
||||
|
||||
# If conversation is shared, get the event
|
||||
return await event_service.get_event(conversation_id, event_id)
|
||||
|
||||
async def search_shared_events(
|
||||
self,
|
||||
conversation_id: UUID,
|
||||
kind__eq: EventKind | None = None,
|
||||
timestamp__gte: datetime | None = None,
|
||||
timestamp__lt: datetime | None = None,
|
||||
sort_order: EventSortOrder = EventSortOrder.TIMESTAMP,
|
||||
page_id: str | None = None,
|
||||
limit: int = 100,
|
||||
) -> EventPage:
|
||||
"""Search events for a specific shared conversation."""
|
||||
# First check if the conversation is shared
|
||||
event_service = await self.get_event_service(conversation_id)
|
||||
if event_service is None:
|
||||
# Return empty page if conversation is not shared
|
||||
return EventPage(items=[], next_page_id=None)
|
||||
|
||||
# If conversation is shared, search events for this conversation
|
||||
return await event_service.search_events(
|
||||
conversation_id=conversation_id,
|
||||
kind__eq=kind__eq,
|
||||
timestamp__gte=timestamp__gte,
|
||||
timestamp__lt=timestamp__lt,
|
||||
sort_order=sort_order,
|
||||
page_id=page_id,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
async def count_shared_events(
|
||||
self,
|
||||
conversation_id: UUID,
|
||||
kind__eq: EventKind | None = None,
|
||||
timestamp__gte: datetime | None = None,
|
||||
timestamp__lt: datetime | None = None,
|
||||
) -> int:
|
||||
"""Count events for a specific shared conversation."""
|
||||
# First check if the conversation is shared
|
||||
event_service = await self.get_event_service(conversation_id)
|
||||
if event_service is None:
|
||||
# Return empty page if conversation is not shared
|
||||
return 0
|
||||
|
||||
# If conversation is shared, count events for this conversation
|
||||
return await event_service.count_events(
|
||||
conversation_id=conversation_id,
|
||||
kind__eq=kind__eq,
|
||||
timestamp__gte=timestamp__gte,
|
||||
timestamp__lt=timestamp__lt,
|
||||
)
|
||||
|
||||
|
||||
class AwsSharedEventServiceInjector(SharedEventServiceInjector):
|
||||
bucket_name: str | None = Field(
|
||||
default_factory=lambda: os.environ.get('FILE_STORE_PATH')
|
||||
)
|
||||
|
||||
async def inject(
|
||||
self, state: InjectorState, request: Request | None = None
|
||||
) -> AsyncGenerator[SharedEventService, None]:
|
||||
# Define inline to prevent circular lookup
|
||||
from openhands.app_server.config import get_db_session
|
||||
|
||||
async with get_db_session(state, request) as db_session:
|
||||
shared_conversation_info_service = SQLSharedConversationInfoService(
|
||||
db_session=db_session
|
||||
)
|
||||
|
||||
bucket_name = self.bucket_name
|
||||
if bucket_name is None:
|
||||
raise ValueError(
|
||||
'bucket_name is required. Set FILE_STORE_PATH environment variable.'
|
||||
)
|
||||
|
||||
# Use role-based authentication - boto3 will automatically
|
||||
# use IAM role credentials when running in AWS
|
||||
s3_client = boto3.client(
|
||||
's3',
|
||||
endpoint_url=os.getenv('AWS_S3_ENDPOINT'),
|
||||
)
|
||||
|
||||
service = AwsSharedEventService(
|
||||
shared_conversation_info_service=shared_conversation_info_service,
|
||||
s3_client=s3_client,
|
||||
bucket_name=bucket_name,
|
||||
)
|
||||
yield service
|
||||
@@ -5,45 +5,19 @@ from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from server.sharing.shared_event_service import (
|
||||
SharedEventService,
|
||||
SharedEventServiceInjector,
|
||||
from server.sharing.google_cloud_shared_event_service import (
|
||||
GoogleCloudSharedEventServiceInjector,
|
||||
)
|
||||
from server.sharing.shared_event_service import SharedEventService
|
||||
|
||||
from openhands.agent_server.models import EventPage, EventSortOrder
|
||||
from openhands.app_server.event_callback.event_callback_models import EventKind
|
||||
from openhands.sdk import Event
|
||||
from openhands.utils.environment import StorageProvider, get_storage_provider
|
||||
|
||||
|
||||
def get_shared_event_service_injector() -> SharedEventServiceInjector:
|
||||
"""Get the appropriate SharedEventServiceInjector based on configuration.
|
||||
|
||||
Uses get_storage_provider() to determine the storage backend.
|
||||
See openhands.utils.environment for supported environment variables.
|
||||
|
||||
Note: Shared events only support AWS and GCP storage. Filesystem storage
|
||||
falls back to GCP for shared events.
|
||||
"""
|
||||
provider = get_storage_provider()
|
||||
|
||||
if provider == StorageProvider.AWS:
|
||||
from server.sharing.aws_shared_event_service import (
|
||||
AwsSharedEventServiceInjector,
|
||||
)
|
||||
|
||||
return AwsSharedEventServiceInjector()
|
||||
else:
|
||||
# GCP is the default for shared events (including filesystem fallback)
|
||||
from server.sharing.google_cloud_shared_event_service import (
|
||||
GoogleCloudSharedEventServiceInjector,
|
||||
)
|
||||
|
||||
return GoogleCloudSharedEventServiceInjector()
|
||||
|
||||
|
||||
router = APIRouter(prefix='/api/shared-events', tags=['Sharing'])
|
||||
shared_event_service_dependency = Depends(get_shared_event_service_injector().depends)
|
||||
shared_event_service_dependency = Depends(
|
||||
GoogleCloudSharedEventServiceInjector().depends
|
||||
)
|
||||
|
||||
|
||||
# Read methods
|
||||
|
||||
@@ -119,7 +119,6 @@ class SaasSQLAppConversationInfoService(SQLAppConversationInfoService):
|
||||
created_at__lt: datetime | None = None,
|
||||
updated_at__gte: datetime | None = None,
|
||||
updated_at__lt: datetime | None = None,
|
||||
sandbox_id__eq: str | None = None,
|
||||
sort_order: AppConversationSortOrder = AppConversationSortOrder.CREATED_AT_DESC,
|
||||
page_id: str | None = None,
|
||||
limit: int = 100,
|
||||
@@ -142,7 +141,6 @@ class SaasSQLAppConversationInfoService(SQLAppConversationInfoService):
|
||||
created_at__lt=created_at__lt,
|
||||
updated_at__gte=updated_at__gte,
|
||||
updated_at__lt=updated_at__lt,
|
||||
sandbox_id__eq=sandbox_id__eq,
|
||||
)
|
||||
|
||||
# Add sort order
|
||||
@@ -200,7 +198,6 @@ class SaasSQLAppConversationInfoService(SQLAppConversationInfoService):
|
||||
created_at__lt: datetime | None = None,
|
||||
updated_at__gte: datetime | None = None,
|
||||
updated_at__lt: datetime | None = None,
|
||||
sandbox_id__eq: str | None = None,
|
||||
) -> int:
|
||||
"""Count conversations matching the given filters with SAAS metadata."""
|
||||
query = (
|
||||
@@ -223,7 +220,6 @@ class SaasSQLAppConversationInfoService(SQLAppConversationInfoService):
|
||||
created_at__lt=created_at__lt,
|
||||
updated_at__gte=updated_at__gte,
|
||||
updated_at__lt=updated_at__lt,
|
||||
sandbox_id__eq=sandbox_id__eq,
|
||||
)
|
||||
|
||||
result = await self.db_session.execute(query)
|
||||
@@ -238,7 +234,6 @@ class SaasSQLAppConversationInfoService(SQLAppConversationInfoService):
|
||||
created_at__lt: datetime | None = None,
|
||||
updated_at__gte: datetime | None = None,
|
||||
updated_at__lt: datetime | None = None,
|
||||
sandbox_id__eq: str | None = None,
|
||||
):
|
||||
"""Apply filters to query that includes SAAS metadata."""
|
||||
# Apply the same filters as the base class
|
||||
@@ -264,9 +259,6 @@ class SaasSQLAppConversationInfoService(SQLAppConversationInfoService):
|
||||
StoredConversationMetadata.last_updated_at < updated_at__lt
|
||||
)
|
||||
|
||||
if sandbox_id__eq is not None:
|
||||
conditions.append(StoredConversationMetadata.sandbox_id == sandbox_id__eq)
|
||||
|
||||
if conditions:
|
||||
query = query.where(*conditions)
|
||||
return query
|
||||
@@ -342,10 +334,7 @@ class SaasSQLAppConversationInfoService(SQLAppConversationInfoService):
|
||||
await super().save_app_conversation_info(info)
|
||||
|
||||
# Get current user_id for SAAS metadata
|
||||
# Fall back to info.created_by_user_id for webhook callbacks (which use ADMIN context)
|
||||
user_id_str = await self.user_context.get_user_id()
|
||||
if not user_id_str and info.created_by_user_id:
|
||||
user_id_str = info.created_by_user_id
|
||||
if user_id_str:
|
||||
# Convert string user_id to UUID
|
||||
user_id_uuid = UUID(user_id_str)
|
||||
|
||||
@@ -1,38 +0,0 @@
|
||||
from typing import Literal
|
||||
|
||||
from fastapi import Request
|
||||
from server.constants import IS_FEATURE_ENV, IS_LOCAL_ENV, IS_STAGING_ENV
|
||||
from starlette.datastructures import URL
|
||||
|
||||
from openhands.app_server.config import get_global_config
|
||||
|
||||
|
||||
def get_web_url(request: Request):
|
||||
web_url = get_global_config().web_url
|
||||
if not web_url:
|
||||
scheme = 'http' if request.url.hostname == 'localhost' else 'https'
|
||||
web_url = f'{scheme}://{request.url.netloc}'
|
||||
else:
|
||||
web_url = web_url.rstrip('/')
|
||||
return web_url
|
||||
|
||||
|
||||
def get_cookie_domain() -> str | None:
|
||||
config = get_global_config()
|
||||
web_url = config.web_url
|
||||
# for now just use the full hostname except for staging stacks.
|
||||
return (
|
||||
URL(web_url).hostname
|
||||
if web_url and not (IS_FEATURE_ENV or IS_STAGING_ENV or IS_LOCAL_ENV)
|
||||
else None
|
||||
)
|
||||
|
||||
|
||||
def get_cookie_samesite() -> Literal['lax', 'strict']:
|
||||
# for localhost and feature/staging stacks we set it to 'lax' as the cookie domain won't allow 'strict'
|
||||
web_url = get_global_config().web_url
|
||||
return (
|
||||
'strict'
|
||||
if web_url and not (IS_FEATURE_ENV or IS_STAGING_ENV or IS_LOCAL_ENV)
|
||||
else 'lax'
|
||||
)
|
||||
@@ -0,0 +1,555 @@
|
||||
"""Automation executor — processes events, claims and executes runs.
|
||||
|
||||
The executor is a long-running process with three phases:
|
||||
1. Process inbox: match NEW events to automations, create PENDING runs
|
||||
2. Claim and execute: claim PENDING runs, submit to V1 API, heartbeat
|
||||
3. Stale recovery: recover RUNNING runs with expired heartbeats
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import socket
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from uuid import uuid4
|
||||
|
||||
from services.openhands_api_client import OpenHandsAPIClient
|
||||
from sqlalchemy import or_, select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from storage.automation import Automation, AutomationRun
|
||||
from storage.automation_event import AutomationEvent
|
||||
|
||||
logger = logging.getLogger('saas.automation.executor')
|
||||
|
||||
# Environment-configurable settings
|
||||
POLL_INTERVAL_SECONDS = float(os.getenv('POLL_INTERVAL_SECONDS', '30'))
|
||||
HEARTBEAT_INTERVAL_SECONDS = float(os.getenv('HEARTBEAT_INTERVAL_SECONDS', '60'))
|
||||
RUN_TIMEOUT_SECONDS = float(os.getenv('RUN_TIMEOUT_SECONDS', '7200'))
|
||||
MAX_CONCURRENT_RUNS = int(os.getenv('MAX_CONCURRENT_RUNS', '5'))
|
||||
STALE_THRESHOLD_MINUTES = 5
|
||||
MAX_EVENTS_PER_BATCH = 50
|
||||
MAX_RETRIES_DEFAULT = 3
|
||||
|
||||
# Terminal conversation statuses
|
||||
TERMINAL_STATUSES = frozenset({'STOPPED', 'ERROR', 'COMPLETED', 'CANCELLED'})
|
||||
|
||||
# Shutdown flag — set by signal handlers
|
||||
_shutdown_event: asyncio.Event | None = None
|
||||
|
||||
# Background task tracking for graceful shutdown
|
||||
_pending_tasks: set[asyncio.Task] = set()
|
||||
|
||||
|
||||
def utc_now() -> datetime:
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
def get_shutdown_event() -> asyncio.Event:
|
||||
global _shutdown_event
|
||||
if _shutdown_event is None:
|
||||
_shutdown_event = asyncio.Event()
|
||||
return _shutdown_event
|
||||
|
||||
|
||||
def should_continue() -> bool:
|
||||
return not get_shutdown_event().is_set()
|
||||
|
||||
|
||||
def request_shutdown() -> None:
|
||||
get_shutdown_event().set()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Phase 1: Process inbox (event matching)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def find_matching_automations(
|
||||
session: AsyncSession, event: AutomationEvent
|
||||
) -> list[Automation]:
|
||||
"""Find automations that match the given event.
|
||||
|
||||
Phase 1 supports cron and manual triggers only — both carry
|
||||
``automation_id`` in the event payload.
|
||||
"""
|
||||
source_type = event.source_type
|
||||
payload = event.payload
|
||||
if payload is None:
|
||||
logger.error('Event %s has None payload — possible data corruption', event.id)
|
||||
return []
|
||||
|
||||
if source_type in ('cron', 'manual'):
|
||||
automation_id = payload.get('automation_id')
|
||||
if not automation_id:
|
||||
logger.warning(
|
||||
'Event %s (source=%s) missing automation_id in payload',
|
||||
event.id,
|
||||
source_type,
|
||||
)
|
||||
return []
|
||||
|
||||
result = await session.execute(
|
||||
select(Automation).where(
|
||||
Automation.id == automation_id,
|
||||
Automation.enabled.is_(True),
|
||||
)
|
||||
)
|
||||
automation = result.scalar_one_or_none()
|
||||
return [automation] if automation else []
|
||||
|
||||
logger.debug('Unhandled event source_type=%s for event %s', source_type, event.id)
|
||||
return []
|
||||
|
||||
|
||||
async def process_new_events(session: AsyncSession) -> int:
|
||||
"""Claim NEW events from inbox, match to automations, create runs.
|
||||
|
||||
Returns the number of events processed.
|
||||
"""
|
||||
result = await session.execute(
|
||||
select(AutomationEvent)
|
||||
.where(AutomationEvent.status == 'NEW')
|
||||
.order_by(AutomationEvent.created_at)
|
||||
.limit(MAX_EVENTS_PER_BATCH)
|
||||
.with_for_update(skip_locked=True)
|
||||
)
|
||||
events = list(result.scalars())
|
||||
|
||||
processed = 0
|
||||
for event in events:
|
||||
try:
|
||||
automations = await find_matching_automations(session, event)
|
||||
if not automations:
|
||||
event.status = 'NO_MATCH'
|
||||
event.processed_at = utc_now()
|
||||
else:
|
||||
for automation in automations:
|
||||
run = AutomationRun(
|
||||
id=uuid4().hex,
|
||||
automation_id=automation.id,
|
||||
event_id=event.id,
|
||||
status='PENDING',
|
||||
event_payload=event.payload,
|
||||
)
|
||||
session.add(run)
|
||||
event.status = 'PROCESSED'
|
||||
event.processed_at = utc_now()
|
||||
processed += 1
|
||||
except Exception as e:
|
||||
logger.exception('Error processing event %s', event.id)
|
||||
event.status = 'ERROR'
|
||||
event.error_detail = f'Failed during event matching: {type(e).__name__}: {e}'
|
||||
event.processed_at = utc_now()
|
||||
|
||||
if processed:
|
||||
await session.commit()
|
||||
logger.info('Processed %d events', processed)
|
||||
|
||||
return processed
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Phase 2: Claim and execute runs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def resolve_user_api_key(session: AsyncSession, user_id: str) -> str | None:
|
||||
"""Look up a user's API key from the api_keys table.
|
||||
|
||||
Returns the first active key found, or None.
|
||||
"""
|
||||
from storage.api_key import ApiKey
|
||||
|
||||
result = await session.execute(
|
||||
select(ApiKey.key).where(ApiKey.user_id == user_id).limit(1)
|
||||
)
|
||||
row = result.scalar_one_or_none()
|
||||
return row
|
||||
|
||||
|
||||
async def download_automation_file(file_store_key: str) -> bytes:
|
||||
"""Download the automation .py file from object storage."""
|
||||
try:
|
||||
from openhands.server.shared import file_store
|
||||
except ImportError as exc:
|
||||
raise RuntimeError(
|
||||
'file_store is not available — ensure the enterprise server '
|
||||
'has been initialised before calling download_automation_file'
|
||||
) from exc
|
||||
|
||||
content = file_store.read(file_store_key)
|
||||
if isinstance(content, str):
|
||||
return content.encode('utf-8')
|
||||
return content
|
||||
|
||||
|
||||
def is_terminal(conversation: dict) -> bool:
|
||||
"""Check if a conversation has reached a terminal status."""
|
||||
status = (conversation.get('status') or '').upper()
|
||||
return status in TERMINAL_STATUSES
|
||||
|
||||
|
||||
async def _prepare_run(
|
||||
run: AutomationRun,
|
||||
automation: Automation,
|
||||
session_factory: object,
|
||||
) -> tuple[str, bytes]:
|
||||
"""Resolve the user's API key and download the automation file.
|
||||
|
||||
Returns:
|
||||
(api_key, automation_file) tuple ready for submission.
|
||||
|
||||
Raises:
|
||||
ValueError: If no API key is found.
|
||||
RuntimeError: If file_store is unavailable.
|
||||
"""
|
||||
async with session_factory() as key_session:
|
||||
api_key = await resolve_user_api_key(key_session, automation.user_id)
|
||||
|
||||
if not api_key:
|
||||
raise ValueError(f'No API key found for user {automation.user_id}')
|
||||
|
||||
automation_file = await download_automation_file(automation.file_store_key)
|
||||
return api_key, automation_file
|
||||
|
||||
|
||||
async def _monitor_conversation(
|
||||
run: AutomationRun,
|
||||
conversation_id: str,
|
||||
api_client: OpenHandsAPIClient,
|
||||
api_key: str,
|
||||
session_factory: object,
|
||||
) -> bool:
|
||||
"""Monitor a conversation until completion or timeout.
|
||||
|
||||
Returns True if completed successfully, False if shutdown requested.
|
||||
|
||||
Raises:
|
||||
TimeoutError: If the run exceeds RUN_TIMEOUT_SECONDS.
|
||||
"""
|
||||
start_time = utc_now()
|
||||
while should_continue():
|
||||
elapsed = (utc_now() - start_time).total_seconds()
|
||||
if elapsed > RUN_TIMEOUT_SECONDS:
|
||||
raise TimeoutError(f'Run exceeded {RUN_TIMEOUT_SECONDS}s timeout')
|
||||
|
||||
await asyncio.sleep(HEARTBEAT_INTERVAL_SECONDS)
|
||||
|
||||
# Update heartbeat
|
||||
async with session_factory() as session:
|
||||
run_obj = await session.get(AutomationRun, run.id)
|
||||
if run_obj:
|
||||
run_obj.heartbeat_at = utc_now()
|
||||
await session.commit()
|
||||
|
||||
# Check conversation status
|
||||
conversation = (
|
||||
await api_client.get_conversation(api_key, conversation_id) or {}
|
||||
)
|
||||
if is_terminal(conversation):
|
||||
return True
|
||||
|
||||
return False # shutdown requested
|
||||
|
||||
|
||||
async def _submit_and_monitor(
|
||||
run: AutomationRun,
|
||||
api_key: str,
|
||||
automation_file: bytes,
|
||||
automation: Automation,
|
||||
api_client: OpenHandsAPIClient,
|
||||
session_factory: object,
|
||||
) -> None:
|
||||
"""Submit the automation to the V1 API and monitor until completion.
|
||||
|
||||
Updates the run's conversation_id, sends heartbeats, and marks the
|
||||
final status when the conversation reaches a terminal state.
|
||||
"""
|
||||
conversation = await api_client.start_conversation(
|
||||
api_key=api_key,
|
||||
automation_file=automation_file,
|
||||
title=f'Automation: {automation.name}',
|
||||
event_payload=run.event_payload,
|
||||
)
|
||||
|
||||
conversation_id = conversation.get('app_conversation_id') or conversation.get(
|
||||
'conversation_id'
|
||||
)
|
||||
|
||||
# Persist conversation ID
|
||||
async with session_factory() as update_session:
|
||||
run_obj = await update_session.get(AutomationRun, run.id)
|
||||
if run_obj:
|
||||
run_obj.conversation_id = conversation_id
|
||||
await update_session.commit()
|
||||
|
||||
# Monitor with heartbeats
|
||||
completed = await _monitor_conversation(
|
||||
run, conversation_id, api_client, api_key, session_factory
|
||||
)
|
||||
|
||||
# Update final status
|
||||
async with session_factory() as final_session:
|
||||
run_obj = await final_session.get(AutomationRun, run.id)
|
||||
if run_obj:
|
||||
if not completed:
|
||||
# Leave as RUNNING — stale recovery will handle it if needed.
|
||||
# The conversation may still be running on the API side.
|
||||
logger.info(
|
||||
'Run %s left as RUNNING due to executor shutdown', run.id
|
||||
)
|
||||
else:
|
||||
run_obj.status = 'COMPLETED'
|
||||
run_obj.completed_at = utc_now()
|
||||
logger.info('Run %s completed successfully', run.id)
|
||||
await final_session.commit()
|
||||
|
||||
|
||||
async def execute_run(
|
||||
run: AutomationRun,
|
||||
automation: Automation,
|
||||
api_client: OpenHandsAPIClient,
|
||||
session_factory: object,
|
||||
) -> None:
|
||||
"""Execute a single automation run end-to-end.
|
||||
|
||||
Orchestrates preparation (API key + file download) and submission/monitoring.
|
||||
On failure, marks the run for retry or dead-letter.
|
||||
"""
|
||||
try:
|
||||
api_key, automation_file = await _prepare_run(
|
||||
run, automation, session_factory
|
||||
)
|
||||
await _submit_and_monitor(
|
||||
run, api_key, automation_file, automation, api_client, session_factory
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception('Run %s failed: %s', run.id, e)
|
||||
await _mark_run_failed(run, str(e), session_factory)
|
||||
|
||||
|
||||
async def _mark_run_failed(
|
||||
run: AutomationRun, error: str, session_factory: object
|
||||
) -> None:
|
||||
"""Mark a run as FAILED or return to PENDING for retry."""
|
||||
async with session_factory() as session:
|
||||
run_obj = await session.get(AutomationRun, run.id)
|
||||
if not run_obj:
|
||||
return
|
||||
|
||||
run_obj.retry_count = (run_obj.retry_count or 0) + 1
|
||||
run_obj.error_detail = error
|
||||
|
||||
if run_obj.retry_count >= (run_obj.max_retries or MAX_RETRIES_DEFAULT):
|
||||
run_obj.status = 'DEAD_LETTER'
|
||||
run_obj.completed_at = utc_now()
|
||||
logger.error(
|
||||
'Run %s moved to DEAD_LETTER after %d retries',
|
||||
run.id,
|
||||
run_obj.retry_count,
|
||||
)
|
||||
else:
|
||||
run_obj.status = 'PENDING'
|
||||
run_obj.claimed_by = None
|
||||
backoff_seconds = 30 * (2 ** (run_obj.retry_count - 1))
|
||||
run_obj.next_retry_at = utc_now() + timedelta(seconds=backoff_seconds)
|
||||
logger.warning(
|
||||
'Run %s returned to PENDING, retry %d/%d in %ds',
|
||||
run.id,
|
||||
run_obj.retry_count,
|
||||
run_obj.max_retries or MAX_RETRIES_DEFAULT,
|
||||
backoff_seconds,
|
||||
)
|
||||
|
||||
await session.commit()
|
||||
|
||||
|
||||
async def claim_and_execute_runs(
|
||||
session: AsyncSession,
|
||||
executor_id: str,
|
||||
api_client: OpenHandsAPIClient,
|
||||
session_factory: object,
|
||||
) -> bool:
|
||||
"""Claim a PENDING run and start executing it.
|
||||
|
||||
Returns True if a run was claimed, False otherwise.
|
||||
"""
|
||||
result = await session.execute(
|
||||
select(AutomationRun)
|
||||
.where(
|
||||
AutomationRun.status == 'PENDING',
|
||||
or_(
|
||||
AutomationRun.next_retry_at.is_(None),
|
||||
AutomationRun.next_retry_at <= utc_now(),
|
||||
),
|
||||
)
|
||||
.order_by(AutomationRun.created_at)
|
||||
.limit(1)
|
||||
.with_for_update(skip_locked=True)
|
||||
)
|
||||
run = result.scalar_one_or_none()
|
||||
if not run:
|
||||
return False
|
||||
|
||||
# Claim the run
|
||||
run.status = 'RUNNING'
|
||||
run.claimed_by = executor_id
|
||||
run.claimed_at = utc_now()
|
||||
run.heartbeat_at = utc_now()
|
||||
run.started_at = utc_now()
|
||||
await session.commit()
|
||||
|
||||
# Load automation for the run
|
||||
auto_result = await session.execute(
|
||||
select(Automation).where(Automation.id == run.automation_id)
|
||||
)
|
||||
automation = auto_result.scalar_one_or_none()
|
||||
if not automation:
|
||||
logger.error('Automation %s not found for run %s', run.automation_id, run.id)
|
||||
await _mark_run_failed(
|
||||
run, f'Automation {run.automation_id} not found', session_factory
|
||||
)
|
||||
return True
|
||||
|
||||
# Execute in background (long-running) with task tracking
|
||||
task = asyncio.create_task(
|
||||
execute_run(run, automation, api_client, session_factory),
|
||||
name=f'execute-run-{run.id}',
|
||||
)
|
||||
_pending_tasks.add(task)
|
||||
task.add_done_callback(_pending_tasks.discard)
|
||||
|
||||
logger.info(
|
||||
'Claimed run %s (automation=%s) by executor %s',
|
||||
run.id,
|
||||
run.automation_id,
|
||||
executor_id,
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Phase 3: Stale run recovery
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def recover_stale_runs(session: AsyncSession) -> int:
|
||||
"""Mark RUNNING runs with expired heartbeats as PENDING for retry.
|
||||
|
||||
Returns the number of recovered runs.
|
||||
"""
|
||||
stale_threshold = utc_now() - timedelta(minutes=STALE_THRESHOLD_MINUTES)
|
||||
timeout_threshold = utc_now() - timedelta(seconds=RUN_TIMEOUT_SECONDS)
|
||||
|
||||
# Recover stale runs (heartbeat expired)
|
||||
result = await session.execute(
|
||||
update(AutomationRun)
|
||||
.where(
|
||||
AutomationRun.status == 'RUNNING',
|
||||
AutomationRun.heartbeat_at < stale_threshold,
|
||||
AutomationRun.heartbeat_at >= timeout_threshold,
|
||||
)
|
||||
.values(
|
||||
status='PENDING',
|
||||
claimed_by=None,
|
||||
retry_count=AutomationRun.retry_count + 1,
|
||||
next_retry_at=utc_now() + timedelta(seconds=30),
|
||||
)
|
||||
.returning(AutomationRun.id)
|
||||
)
|
||||
recovered_rows = result.fetchall()
|
||||
|
||||
# Mark truly timed-out runs as DEAD_LETTER
|
||||
timeout_result = await session.execute(
|
||||
update(AutomationRun)
|
||||
.where(
|
||||
AutomationRun.status == 'RUNNING',
|
||||
AutomationRun.heartbeat_at < timeout_threshold,
|
||||
)
|
||||
.values(
|
||||
status='DEAD_LETTER',
|
||||
error_detail='Run exceeded timeout',
|
||||
completed_at=utc_now(),
|
||||
)
|
||||
.returning(AutomationRun.id)
|
||||
)
|
||||
timed_out_rows = timeout_result.fetchall()
|
||||
|
||||
await session.commit()
|
||||
|
||||
recovered_count = len(recovered_rows)
|
||||
timed_out_count = len(timed_out_rows)
|
||||
|
||||
if recovered_count:
|
||||
logger.warning('Recovered %d stale automation runs', recovered_count)
|
||||
if timed_out_count:
|
||||
logger.warning(
|
||||
'Marked %d automation runs as DEAD_LETTER (timeout)', timed_out_count
|
||||
)
|
||||
|
||||
return recovered_count + timed_out_count
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main executor loop
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def executor_main(session_factory: object | None = None) -> None:
|
||||
"""Main executor loop.
|
||||
|
||||
Args:
|
||||
session_factory: Async context manager that yields AsyncSession instances.
|
||||
If None, uses the default ``a_session_maker`` from database module.
|
||||
"""
|
||||
if session_factory is None:
|
||||
from storage.database import a_session_maker
|
||||
|
||||
session_factory = a_session_maker
|
||||
|
||||
executor_id = f'executor-{socket.gethostname()}-{os.getpid()}'
|
||||
api_url = os.getenv('OPENHANDS_API_URL', 'http://openhands-service:3000')
|
||||
api_client = OpenHandsAPIClient(base_url=api_url)
|
||||
|
||||
logger.info(
|
||||
'Automation executor %s starting (api_url=%s, poll=%ss, heartbeat=%ss)',
|
||||
executor_id,
|
||||
api_url,
|
||||
POLL_INTERVAL_SECONDS,
|
||||
HEARTBEAT_INTERVAL_SECONDS,
|
||||
)
|
||||
|
||||
try:
|
||||
while should_continue():
|
||||
try:
|
||||
async with session_factory() as session:
|
||||
await process_new_events(session)
|
||||
|
||||
async with session_factory() as session:
|
||||
await claim_and_execute_runs(
|
||||
session, executor_id, api_client, session_factory
|
||||
)
|
||||
|
||||
async with session_factory() as session:
|
||||
await recover_stale_runs(session)
|
||||
|
||||
except Exception:
|
||||
logger.exception('Error in executor main loop iteration')
|
||||
|
||||
# Wait for next poll interval (or early wakeup on shutdown)
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
get_shutdown_event().wait(),
|
||||
timeout=POLL_INTERVAL_SECONDS,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
pass # Normal — poll interval elapsed
|
||||
|
||||
finally:
|
||||
if _pending_tasks:
|
||||
logger.info(
|
||||
'Waiting for %d running tasks to complete...', len(_pending_tasks)
|
||||
)
|
||||
await asyncio.gather(*_pending_tasks, return_exceptions=True)
|
||||
await api_client.close()
|
||||
logger.info('Automation executor %s shut down', executor_id)
|
||||
@@ -0,0 +1,93 @@
|
||||
"""HTTP client for the main OpenHands V1 API (internal cluster calls).
|
||||
|
||||
Used by the automation executor to create and monitor conversations
|
||||
in the main OpenHands server.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import logging
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger('saas.automation.api_client')
|
||||
|
||||
|
||||
def _raise_with_body(resp: httpx.Response) -> None:
|
||||
"""Call raise_for_status, enriching the error with the response body."""
|
||||
try:
|
||||
resp.raise_for_status()
|
||||
except httpx.HTTPStatusError as e:
|
||||
error_body = resp.text[:500] if resp.text else 'no response body'
|
||||
raise httpx.HTTPStatusError(
|
||||
f'{e.args[0]} — Response: {error_body}',
|
||||
request=e.request,
|
||||
response=e.response,
|
||||
) from e
|
||||
|
||||
|
||||
class OpenHandsAPIClient:
|
||||
"""Async HTTP client for the OpenHands V1 API."""
|
||||
|
||||
def __init__(self, base_url: str = 'http://openhands-service:3000'):
|
||||
self.base_url = base_url.rstrip('/')
|
||||
self.client = httpx.AsyncClient(base_url=self.base_url, timeout=30.0)
|
||||
|
||||
async def start_conversation(
|
||||
self,
|
||||
api_key: str,
|
||||
automation_file: bytes,
|
||||
title: str,
|
||||
event_payload: dict | None = None,
|
||||
) -> dict:
|
||||
"""Submit an SDK script for sandboxed execution via V1 API.
|
||||
|
||||
Args:
|
||||
api_key: User's API key for authentication.
|
||||
automation_file: Raw bytes of the .py automation script.
|
||||
title: Display title for the conversation.
|
||||
event_payload: Optional trigger event data (injected as env var).
|
||||
|
||||
Returns:
|
||||
Parsed JSON response containing conversation details.
|
||||
|
||||
Raises:
|
||||
httpx.HTTPStatusError: If the API returns a non-2xx status.
|
||||
"""
|
||||
resp = await self.client.post(
|
||||
'/api/v1/app-conversations',
|
||||
json={
|
||||
'automation_file': base64.b64encode(automation_file).decode(),
|
||||
'trigger': 'automation',
|
||||
'title': title,
|
||||
'event_payload': event_payload,
|
||||
},
|
||||
headers={'Authorization': f'Bearer {api_key}'},
|
||||
)
|
||||
_raise_with_body(resp)
|
||||
return resp.json()
|
||||
|
||||
async def get_conversation(self, api_key: str, conversation_id: str) -> dict | None:
|
||||
"""Get conversation status.
|
||||
|
||||
Args:
|
||||
api_key: User's API key for authentication.
|
||||
conversation_id: The conversation ID to look up.
|
||||
|
||||
Returns:
|
||||
Conversation data dict, or None if not found.
|
||||
|
||||
Raises:
|
||||
httpx.HTTPStatusError: If the API returns a non-2xx status.
|
||||
"""
|
||||
resp = await self.client.get(
|
||||
'/api/v1/app-conversations',
|
||||
params={'ids': [conversation_id]},
|
||||
headers={'Authorization': f'Bearer {api_key}'},
|
||||
)
|
||||
_raise_with_body(resp)
|
||||
conversations = resp.json()
|
||||
return conversations[0] if conversations else None
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the underlying HTTP client."""
|
||||
await self.client.aclose()
|
||||
@@ -0,0 +1,77 @@
|
||||
"""SQLAlchemy models for automations and automation runs.
|
||||
|
||||
Stub for Task 1 (Data Foundation). These models will be replaced when Task 1
|
||||
is merged into automations-phase1.
|
||||
"""
|
||||
|
||||
from sqlalchemy import (
|
||||
Boolean,
|
||||
Column,
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
Integer,
|
||||
String,
|
||||
Text,
|
||||
text,
|
||||
)
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.types import JSON
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
class Automation(Base):
|
||||
__tablename__ = 'automations'
|
||||
|
||||
id = Column(String, primary_key=True)
|
||||
user_id = Column(String, nullable=False, index=True)
|
||||
org_id = Column(String, nullable=True, index=True)
|
||||
|
||||
name = Column(String, nullable=False)
|
||||
enabled = Column(Boolean, nullable=False, server_default=text('true'))
|
||||
|
||||
config = Column(JSON, nullable=False)
|
||||
trigger_type = Column(String, nullable=False)
|
||||
file_store_key = Column(String, nullable=False)
|
||||
|
||||
last_triggered_at = Column(DateTime(timezone=True), nullable=True)
|
||||
created_at = Column(
|
||||
DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=text('CURRENT_TIMESTAMP'),
|
||||
)
|
||||
updated_at = Column(
|
||||
DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=text('CURRENT_TIMESTAMP'),
|
||||
)
|
||||
|
||||
runs = relationship('AutomationRun', back_populates='automation')
|
||||
|
||||
|
||||
class AutomationRun(Base):
|
||||
__tablename__ = 'automation_runs'
|
||||
|
||||
id = Column(String, primary_key=True)
|
||||
automation_id = Column(
|
||||
String, ForeignKey('automations.id', ondelete='CASCADE'), nullable=False
|
||||
)
|
||||
event_id = Column(Integer, ForeignKey('automation_events.id'), nullable=True)
|
||||
conversation_id = Column(String, nullable=True)
|
||||
status = Column(String, nullable=False, server_default=text("'PENDING'"))
|
||||
claimed_by = Column(String, nullable=True)
|
||||
claimed_at = Column(DateTime(timezone=True), nullable=True)
|
||||
heartbeat_at = Column(DateTime(timezone=True), nullable=True)
|
||||
retry_count = Column(Integer, nullable=False, server_default=text('0'))
|
||||
max_retries = Column(Integer, nullable=False, server_default=text('3'))
|
||||
next_retry_at = Column(DateTime(timezone=True), nullable=True)
|
||||
event_payload = Column(JSON, nullable=True)
|
||||
error_detail = Column(Text, nullable=True)
|
||||
started_at = Column(DateTime(timezone=True), nullable=True)
|
||||
completed_at = Column(DateTime(timezone=True), nullable=True)
|
||||
created_at = Column(
|
||||
DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=text('CURRENT_TIMESTAMP'),
|
||||
)
|
||||
|
||||
automation = relationship('Automation', back_populates='runs')
|
||||
@@ -0,0 +1,27 @@
|
||||
"""SQLAlchemy model for automation events (the inbox).
|
||||
|
||||
Stub for Task 1 (Data Foundation). This model will be replaced when Task 1
|
||||
is merged into automations-phase1.
|
||||
"""
|
||||
|
||||
from sqlalchemy import Column, DateTime, Integer, String, Text, text
|
||||
from sqlalchemy.types import JSON
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
class AutomationEvent(Base):
|
||||
__tablename__ = 'automation_events'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
source_type = Column(String, nullable=False)
|
||||
payload = Column(JSON, nullable=False)
|
||||
metadata_ = Column('metadata', JSON, nullable=True)
|
||||
dedup_key = Column(String, nullable=False, unique=True)
|
||||
status = Column(String, nullable=False, server_default=text("'NEW'"))
|
||||
error_detail = Column(Text, nullable=True)
|
||||
created_at = Column(
|
||||
DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=text('CURRENT_TIMESTAMP'),
|
||||
)
|
||||
processed_at = Column(DateTime(timezone=True), nullable=True)
|
||||
@@ -29,15 +29,6 @@ KEY_VERIFICATION_TIMEOUT = 5.0
|
||||
# A very large number to represent "unlimited" until LiteLLM fixes their unlimited update bug.
|
||||
UNLIMITED_BUDGET_SETTING = 1000000000.0
|
||||
|
||||
try:
|
||||
DEFAULT_INITIAL_BUDGET = float(os.environ.get('DEFAULT_INITIAL_BUDGET', 0.0))
|
||||
if DEFAULT_INITIAL_BUDGET < 0:
|
||||
raise ValueError(
|
||||
f'DEFAULT_INITIAL_BUDGET must be non-negative, got {DEFAULT_INITIAL_BUDGET}'
|
||||
)
|
||||
except ValueError as e:
|
||||
raise ValueError(f'Invalid DEFAULT_INITIAL_BUDGET environment variable: {e}') from e
|
||||
|
||||
|
||||
def get_openhands_cloud_key_alias(keycloak_user_id: str, org_id: str) -> str:
|
||||
"""Generate the key alias for OpenHands Cloud managed keys."""
|
||||
@@ -110,7 +101,7 @@ class LiteLlmManager:
|
||||
) as client:
|
||||
# Check if team already exists and get its budget
|
||||
# New users joining existing orgs should inherit the team's budget
|
||||
team_budget: float = DEFAULT_INITIAL_BUDGET
|
||||
team_budget = 0.0
|
||||
try:
|
||||
existing_team = await LiteLlmManager._get_team(client, org_id)
|
||||
if existing_team:
|
||||
|
||||
@@ -28,9 +28,6 @@ class OrgMemberStore:
|
||||
role_id: int,
|
||||
llm_api_key: str,
|
||||
status: Optional[str] = None,
|
||||
llm_model: Optional[str] = None,
|
||||
llm_base_url: Optional[str] = None,
|
||||
max_iterations: Optional[int] = None,
|
||||
) -> OrgMember:
|
||||
"""Add a user to an organization with a specific role."""
|
||||
async with a_session_maker() as session:
|
||||
@@ -40,9 +37,6 @@ class OrgMemberStore:
|
||||
role_id=role_id,
|
||||
llm_api_key=llm_api_key,
|
||||
status=status,
|
||||
llm_model=llm_model,
|
||||
llm_base_url=llm_base_url,
|
||||
max_iterations=max_iterations,
|
||||
)
|
||||
session.add(org_member)
|
||||
await session.commit()
|
||||
|
||||
@@ -6,7 +6,6 @@ from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from server.constants import (
|
||||
DEFAULT_V1_ENABLED,
|
||||
LITE_LLM_API_URL,
|
||||
ORG_SETTINGS_VERSION,
|
||||
get_default_litellm_model,
|
||||
@@ -37,8 +36,6 @@ class OrgStore:
|
||||
org = Org(**kwargs)
|
||||
org.org_version = ORG_SETTINGS_VERSION
|
||||
org.default_llm_model = get_default_litellm_model()
|
||||
if org.v1_enabled is None:
|
||||
org.v1_enabled = DEFAULT_V1_ENABLED
|
||||
session.add(org)
|
||||
await session.commit()
|
||||
await session.refresh(org)
|
||||
|
||||
@@ -187,18 +187,6 @@ class SaasSettingsStore(SettingsStore):
|
||||
if hasattr(model, key):
|
||||
setattr(model, key, value)
|
||||
|
||||
# Map Settings fields to Org fields with 'default_' prefix
|
||||
# The generic loop above doesn't update these because Org uses
|
||||
# 'default_llm_model' not 'llm_model', etc.
|
||||
# Use exclude_unset to only update explicitly-set fields (allows clearing with null)
|
||||
settings_data = item.model_dump(exclude_unset=True)
|
||||
if 'llm_model' in settings_data:
|
||||
org.default_llm_model = settings_data['llm_model']
|
||||
if 'llm_base_url' in settings_data:
|
||||
org.default_llm_base_url = settings_data['llm_base_url']
|
||||
if 'max_iterations' in settings_data:
|
||||
org.default_max_iterations = settings_data['max_iterations']
|
||||
|
||||
# Propagate LLM settings to all org members
|
||||
# This ensures all members see the same LLM configuration when an admin saves
|
||||
# Note: Concurrent saves by multiple admins will result in last-write-wins.
|
||||
|
||||
@@ -7,7 +7,6 @@ from uuid import UUID
|
||||
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.constants import (
|
||||
DEFAULT_V1_ENABLED,
|
||||
LITE_LLM_API_URL,
|
||||
ORG_SETTINGS_VERSION,
|
||||
PERSONAL_WORKSPACE_VERSION_TO_MODEL,
|
||||
@@ -242,10 +241,6 @@ class UserStore:
|
||||
if hasattr(org, key):
|
||||
setattr(org, key, value)
|
||||
|
||||
# Apply DEFAULT_V1_ENABLED for migrated orgs if v1_enabled was not set
|
||||
if org.v1_enabled is None:
|
||||
org.v1_enabled = DEFAULT_V1_ENABLED
|
||||
|
||||
user_kwargs = UserStore.get_kwargs_from_user_settings(
|
||||
decrypted_user_settings
|
||||
)
|
||||
@@ -897,8 +892,6 @@ class UserStore:
|
||||
language='en', enable_proactive_conversation_starters=True
|
||||
)
|
||||
|
||||
default_settings.v1_enabled = DEFAULT_V1_ENABLED
|
||||
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
|
||||
settings = await LiteLlmManager.create_entries(
|
||||
|
||||
Executable
+562
@@ -0,0 +1,562 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Common Room Sync
|
||||
|
||||
This script queries the database to count conversations created by each user,
|
||||
then creates or updates a signal in Common Room for each user with their
|
||||
conversation count.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
import requests
|
||||
from sqlalchemy import text
|
||||
|
||||
# Add the parent directory to the path so we can import from storage
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
from server.auth.token_manager import get_keycloak_admin
|
||||
from storage.database import get_engine
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger('common_room_sync')
|
||||
|
||||
# Common Room API configuration
|
||||
COMMON_ROOM_API_KEY = os.environ.get('COMMON_ROOM_API_KEY')
|
||||
COMMON_ROOM_DESTINATION_SOURCE_ID = os.environ.get('COMMON_ROOM_DESTINATION_SOURCE_ID')
|
||||
COMMON_ROOM_API_BASE_URL = 'https://api.commonroom.io/community/v1'
|
||||
|
||||
# Sync configuration
|
||||
BATCH_SIZE = int(os.environ.get('BATCH_SIZE', '100'))
|
||||
KEYCLOAK_BATCH_SIZE = int(os.environ.get('KEYCLOAK_BATCH_SIZE', '20'))
|
||||
MAX_RETRIES = int(os.environ.get('MAX_RETRIES', '3'))
|
||||
INITIAL_BACKOFF_SECONDS = float(os.environ.get('INITIAL_BACKOFF_SECONDS', '1'))
|
||||
MAX_BACKOFF_SECONDS = float(os.environ.get('MAX_BACKOFF_SECONDS', '60'))
|
||||
BACKOFF_FACTOR = float(os.environ.get('BACKOFF_FACTOR', '2'))
|
||||
RATE_LIMIT = float(os.environ.get('RATE_LIMIT', '2')) # Requests per second
|
||||
|
||||
|
||||
class CommonRoomSyncError(Exception):
|
||||
"""Base exception for Common Room sync errors."""
|
||||
|
||||
|
||||
class DatabaseError(CommonRoomSyncError):
|
||||
"""Exception for database errors."""
|
||||
|
||||
|
||||
class CommonRoomAPIError(CommonRoomSyncError):
|
||||
"""Exception for Common Room API errors."""
|
||||
|
||||
|
||||
class KeycloakClientError(CommonRoomSyncError):
|
||||
"""Exception for Keycloak client errors."""
|
||||
|
||||
|
||||
def get_recent_conversations(minutes: int = 60) -> List[Dict[str, Any]]:
|
||||
"""Get conversations created in the past N minutes.
|
||||
|
||||
Args:
|
||||
minutes: Number of minutes to look back for new conversations.
|
||||
|
||||
Returns:
|
||||
A list of dictionaries, each containing conversation details.
|
||||
|
||||
Raises:
|
||||
DatabaseError: If the database query fails.
|
||||
"""
|
||||
try:
|
||||
# Use a different syntax for the interval that works with pg8000
|
||||
query = text("""
|
||||
SELECT
|
||||
conversation_id, user_id, title, created_at
|
||||
FROM
|
||||
conversation_metadata
|
||||
WHERE
|
||||
created_at >= NOW() - (INTERVAL '1 minute' * :minutes)
|
||||
ORDER BY
|
||||
created_at DESC
|
||||
""")
|
||||
|
||||
with get_engine().connect() as connection:
|
||||
result = connection.execute(query, {'minutes': minutes})
|
||||
conversations = [
|
||||
{
|
||||
'conversation_id': row[0],
|
||||
'user_id': row[1],
|
||||
'title': row[2],
|
||||
'created_at': row[3].isoformat() if row[3] else None,
|
||||
}
|
||||
for row in result
|
||||
]
|
||||
|
||||
logger.info(
|
||||
f'Retrieved {len(conversations)} conversations created in the past {minutes} minutes'
|
||||
)
|
||||
return conversations
|
||||
except Exception as e:
|
||||
logger.exception(f'Error querying recent conversations: {e}')
|
||||
raise DatabaseError(f'Failed to query recent conversations: {e}')
|
||||
|
||||
|
||||
async def get_users_from_keycloak(user_ids: Set[str]) -> Dict[str, Dict[str, Any]]:
|
||||
"""Get user information from Keycloak for a set of user IDs.
|
||||
|
||||
Args:
|
||||
user_ids: A set of user IDs to look up.
|
||||
|
||||
Returns:
|
||||
A dictionary mapping user IDs to user information dictionaries.
|
||||
|
||||
Raises:
|
||||
KeycloakClientError: If the Keycloak API call fails.
|
||||
"""
|
||||
try:
|
||||
# Get Keycloak admin client
|
||||
keycloak_admin = get_keycloak_admin()
|
||||
|
||||
# Create a dictionary to store user information
|
||||
user_info_dict = {}
|
||||
|
||||
# Convert set to list for easier batching
|
||||
user_id_list = list(user_ids)
|
||||
|
||||
# Process user IDs in batches
|
||||
for i in range(0, len(user_id_list), KEYCLOAK_BATCH_SIZE):
|
||||
batch = user_id_list[i : i + KEYCLOAK_BATCH_SIZE]
|
||||
batch_tasks = []
|
||||
|
||||
# Create tasks for each user ID in the batch
|
||||
for user_id in batch:
|
||||
# Use the Keycloak admin client to get user by ID
|
||||
batch_tasks.append(get_user_by_id(keycloak_admin, user_id))
|
||||
|
||||
# Run the batch of tasks concurrently
|
||||
batch_results = await asyncio.gather(*batch_tasks, return_exceptions=True)
|
||||
|
||||
# Process the results
|
||||
for user_id, result in zip(batch, batch_results):
|
||||
if isinstance(result, Exception):
|
||||
logger.warning(f'Error getting user {user_id}: {result}')
|
||||
continue
|
||||
|
||||
if result and isinstance(result, dict):
|
||||
user_info_dict[user_id] = {
|
||||
'username': result.get('username'),
|
||||
'email': result.get('email'),
|
||||
'id': result.get('id'),
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f'Retrieved information for {len(user_info_dict)} users from Keycloak'
|
||||
)
|
||||
return user_info_dict
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f'Error getting users from Keycloak: {e}'
|
||||
logger.exception(error_msg)
|
||||
raise KeycloakClientError(error_msg)
|
||||
|
||||
|
||||
async def get_user_by_id(keycloak_admin, user_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get a user from Keycloak by ID.
|
||||
|
||||
Args:
|
||||
keycloak_admin: The Keycloak admin client.
|
||||
user_id: The user ID to look up.
|
||||
|
||||
Returns:
|
||||
A dictionary with the user's information, or None if not found.
|
||||
"""
|
||||
try:
|
||||
# Use the Keycloak admin client to get user by ID
|
||||
user = keycloak_admin.get_user(user_id)
|
||||
if user:
|
||||
logger.debug(
|
||||
f"Found user in Keycloak: {user.get('username')}, {user.get('email')}"
|
||||
)
|
||||
return user
|
||||
else:
|
||||
logger.warning(f'User {user_id} not found in Keycloak')
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.warning(f'Error getting user {user_id} from Keycloak: {e}')
|
||||
return None
|
||||
|
||||
|
||||
def get_user_info(
|
||||
user_id: str, user_info_cache: Dict[str, Dict[str, Any]]
|
||||
) -> Optional[Dict[str, str]]:
|
||||
"""Get the email address and GitHub username for a user from the cache.
|
||||
|
||||
Args:
|
||||
user_id: The user ID to look up.
|
||||
user_info_cache: A dictionary mapping user IDs to user information.
|
||||
|
||||
Returns:
|
||||
A dictionary with the user's email and username, or None if not found.
|
||||
"""
|
||||
# Check if the user is in the cache
|
||||
if user_id in user_info_cache:
|
||||
user_info = user_info_cache[user_id]
|
||||
logger.debug(
|
||||
f"Found user info in cache: {user_info.get('username')}, {user_info.get('email')}"
|
||||
)
|
||||
return user_info
|
||||
else:
|
||||
logger.warning(f'User {user_id} not found in user info cache')
|
||||
return None
|
||||
|
||||
|
||||
def register_user_in_common_room(
|
||||
user_id: str, email: str, github_username: str
|
||||
) -> Dict[str, Any]:
|
||||
"""Create or update a user in Common Room.
|
||||
|
||||
Args:
|
||||
user_id: The user ID.
|
||||
email: The user's email address.
|
||||
github_username: The user's GitHub username.
|
||||
|
||||
Returns:
|
||||
The API response from Common Room.
|
||||
|
||||
Raises:
|
||||
CommonRoomAPIError: If the Common Room API request fails.
|
||||
"""
|
||||
if not COMMON_ROOM_API_KEY:
|
||||
raise CommonRoomAPIError('COMMON_ROOM_API_KEY environment variable not set')
|
||||
|
||||
if not COMMON_ROOM_DESTINATION_SOURCE_ID:
|
||||
raise CommonRoomAPIError(
|
||||
'COMMON_ROOM_DESTINATION_SOURCE_ID environment variable not set'
|
||||
)
|
||||
|
||||
try:
|
||||
headers = {
|
||||
'Authorization': f'Bearer {COMMON_ROOM_API_KEY}',
|
||||
'Content-Type': 'application/json',
|
||||
}
|
||||
|
||||
# Create or update user in Common Room
|
||||
user_data = {
|
||||
'id': user_id,
|
||||
'email': email,
|
||||
'username': github_username,
|
||||
'github': {'type': 'handle', 'value': github_username},
|
||||
}
|
||||
|
||||
user_url = f'{COMMON_ROOM_API_BASE_URL}/source/{COMMON_ROOM_DESTINATION_SOURCE_ID}/user'
|
||||
user_response = requests.post(user_url, headers=headers, json=user_data)
|
||||
|
||||
if user_response.status_code not in (200, 202):
|
||||
logger.error(
|
||||
f'Failed to create/update user in Common Room: {user_response.text}'
|
||||
)
|
||||
logger.error(f'Response status code: {user_response.status_code}')
|
||||
raise CommonRoomAPIError(
|
||||
f'Failed to create/update user: {user_response.text}'
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f'Registered/updated user {user_id} (GitHub: {github_username}) in Common Room'
|
||||
)
|
||||
return user_response.json()
|
||||
except requests.RequestException as e:
|
||||
logger.exception(f'Error communicating with Common Room API: {e}')
|
||||
raise CommonRoomAPIError(f'Failed to communicate with Common Room API: {e}')
|
||||
|
||||
|
||||
def register_conversation_activity(
|
||||
user_id: str,
|
||||
conversation_id: str,
|
||||
conversation_title: str,
|
||||
created_at: datetime,
|
||||
email: str,
|
||||
github_username: str,
|
||||
) -> Dict[str, Any]:
|
||||
"""Create an activity in Common Room for a new conversation.
|
||||
|
||||
Args:
|
||||
user_id: The user ID who created the conversation.
|
||||
conversation_id: The ID of the conversation.
|
||||
conversation_title: The title of the conversation.
|
||||
created_at: The datetime object when the conversation was created.
|
||||
email: The user's email address.
|
||||
github_username: The user's GitHub username.
|
||||
|
||||
Returns:
|
||||
The API response from Common Room.
|
||||
|
||||
Raises:
|
||||
CommonRoomAPIError: If the Common Room API request fails.
|
||||
"""
|
||||
if not COMMON_ROOM_API_KEY:
|
||||
raise CommonRoomAPIError('COMMON_ROOM_API_KEY environment variable not set')
|
||||
|
||||
if not COMMON_ROOM_DESTINATION_SOURCE_ID:
|
||||
raise CommonRoomAPIError(
|
||||
'COMMON_ROOM_DESTINATION_SOURCE_ID environment variable not set'
|
||||
)
|
||||
|
||||
try:
|
||||
headers = {
|
||||
'Authorization': f'Bearer {COMMON_ROOM_API_KEY}',
|
||||
'Content-Type': 'application/json',
|
||||
}
|
||||
|
||||
# Format the datetime object to the expected ISO format
|
||||
formatted_timestamp = (
|
||||
created_at.strftime('%Y-%m-%dT%H:%M:%SZ')
|
||||
if created_at
|
||||
else time.strftime('%Y-%m-%dT%H:%M:%SZ', time.gmtime())
|
||||
)
|
||||
|
||||
# Create activity for the conversation
|
||||
activity_data = {
|
||||
'id': f'conversation_{conversation_id}', # Use conversation ID to ensure uniqueness
|
||||
'activityType': 'started_session',
|
||||
'user': {
|
||||
'id': user_id,
|
||||
'email': email,
|
||||
'github': {'type': 'handle', 'value': github_username},
|
||||
'username': github_username,
|
||||
},
|
||||
'activityTitle': {
|
||||
'type': 'text',
|
||||
'value': conversation_title or 'New Conversation',
|
||||
},
|
||||
'content': {
|
||||
'type': 'text',
|
||||
'value': f'Started a new conversation: {conversation_title or "Untitled"}',
|
||||
},
|
||||
'timestamp': formatted_timestamp,
|
||||
'url': f'https://app.all-hands.dev/conversations/{conversation_id}',
|
||||
}
|
||||
|
||||
# Log the activity data for debugging
|
||||
logger.info(f'Activity data payload: {activity_data}')
|
||||
|
||||
activity_url = f'{COMMON_ROOM_API_BASE_URL}/source/{COMMON_ROOM_DESTINATION_SOURCE_ID}/activity'
|
||||
activity_response = requests.post(
|
||||
activity_url, headers=headers, json=activity_data
|
||||
)
|
||||
|
||||
if activity_response.status_code not in (200, 202):
|
||||
logger.error(
|
||||
f'Failed to create activity in Common Room: {activity_response.text}'
|
||||
)
|
||||
logger.error(f'Response status code: {activity_response.status_code}')
|
||||
raise CommonRoomAPIError(
|
||||
f'Failed to create activity: {activity_response.text}'
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f'Registered conversation activity for user {user_id}, conversation {conversation_id}'
|
||||
)
|
||||
return activity_response.json()
|
||||
except requests.RequestException as e:
|
||||
logger.exception(f'Error communicating with Common Room API: {e}')
|
||||
raise CommonRoomAPIError(f'Failed to communicate with Common Room API: {e}')
|
||||
|
||||
|
||||
def retry_with_backoff(func, *args, **kwargs):
|
||||
"""Retry a function with exponential backoff.
|
||||
|
||||
Args:
|
||||
func: The function to retry.
|
||||
*args: Positional arguments to pass to the function.
|
||||
**kwargs: Keyword arguments to pass to the function.
|
||||
|
||||
Returns:
|
||||
The result of the function call.
|
||||
|
||||
Raises:
|
||||
The last exception raised by the function.
|
||||
"""
|
||||
backoff = INITIAL_BACKOFF_SECONDS
|
||||
last_exception = None
|
||||
|
||||
for attempt in range(MAX_RETRIES):
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
last_exception = e
|
||||
logger.warning(f'Attempt {attempt + 1}/{MAX_RETRIES} failed: {e}')
|
||||
|
||||
if attempt < MAX_RETRIES - 1:
|
||||
sleep_time = min(backoff, MAX_BACKOFF_SECONDS)
|
||||
logger.info(f'Retrying in {sleep_time:.2f} seconds...')
|
||||
time.sleep(sleep_time)
|
||||
backoff *= BACKOFF_FACTOR
|
||||
else:
|
||||
logger.exception(f'All {MAX_RETRIES} attempts failed')
|
||||
raise last_exception
|
||||
|
||||
|
||||
async def retry_with_backoff_async(func, *args, **kwargs):
|
||||
"""Retry an async function with exponential backoff.
|
||||
|
||||
Args:
|
||||
func: The async function to retry.
|
||||
*args: Positional arguments to pass to the function.
|
||||
**kwargs: Keyword arguments to pass to the function.
|
||||
|
||||
Returns:
|
||||
The result of the function call.
|
||||
|
||||
Raises:
|
||||
The last exception raised by the function.
|
||||
"""
|
||||
backoff = INITIAL_BACKOFF_SECONDS
|
||||
last_exception = None
|
||||
|
||||
for attempt in range(MAX_RETRIES):
|
||||
try:
|
||||
return await func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
last_exception = e
|
||||
logger.warning(f'Attempt {attempt + 1}/{MAX_RETRIES} failed: {e}')
|
||||
|
||||
if attempt < MAX_RETRIES - 1:
|
||||
sleep_time = min(backoff, MAX_BACKOFF_SECONDS)
|
||||
logger.info(f'Retrying in {sleep_time:.2f} seconds...')
|
||||
await asyncio.sleep(sleep_time)
|
||||
backoff *= BACKOFF_FACTOR
|
||||
else:
|
||||
logger.exception(f'All {MAX_RETRIES} attempts failed')
|
||||
raise last_exception
|
||||
|
||||
|
||||
async def async_sync_recent_conversations_to_common_room(minutes: int = 60):
|
||||
"""Async main function to sync recent conversations to Common Room.
|
||||
|
||||
Args:
|
||||
minutes: Number of minutes to look back for new conversations.
|
||||
"""
|
||||
logger.info(
|
||||
f'Starting Common Room recent conversations sync (past {minutes} minutes)'
|
||||
)
|
||||
|
||||
stats = {
|
||||
'total_conversations': 0,
|
||||
'registered_users': 0,
|
||||
'registered_activities': 0,
|
||||
'errors': 0,
|
||||
'missing_user_info': 0,
|
||||
}
|
||||
|
||||
try:
|
||||
# Get conversations created in the past N minutes
|
||||
recent_conversations = retry_with_backoff(get_recent_conversations, minutes)
|
||||
stats['total_conversations'] = len(recent_conversations)
|
||||
|
||||
logger.info(f'Processing {len(recent_conversations)} recent conversations')
|
||||
|
||||
if not recent_conversations:
|
||||
logger.info('No recent conversations found, exiting')
|
||||
return
|
||||
|
||||
# Extract all unique user IDs
|
||||
user_ids = {conv['user_id'] for conv in recent_conversations if conv['user_id']}
|
||||
|
||||
# Get user information for all users in batches
|
||||
user_info_cache = await retry_with_backoff_async(
|
||||
get_users_from_keycloak, user_ids
|
||||
)
|
||||
|
||||
# Track registered users to avoid duplicate registrations
|
||||
registered_users = set()
|
||||
|
||||
# Process each conversation
|
||||
for conversation in recent_conversations:
|
||||
conversation_id = conversation['conversation_id']
|
||||
user_id = conversation['user_id']
|
||||
title = conversation['title']
|
||||
created_at = conversation[
|
||||
'created_at'
|
||||
] # This might be a string or datetime object
|
||||
|
||||
try:
|
||||
# Get user info from cache
|
||||
user_info = get_user_info(user_id, user_info_cache)
|
||||
if not user_info:
|
||||
logger.warning(
|
||||
f'Could not find user info for user {user_id}, skipping conversation {conversation_id}'
|
||||
)
|
||||
stats['missing_user_info'] += 1
|
||||
continue
|
||||
|
||||
email = user_info['email']
|
||||
github_username = user_info['username']
|
||||
|
||||
if not email:
|
||||
logger.warning(
|
||||
f'User {user_id} has no email, skipping conversation {conversation_id}'
|
||||
)
|
||||
stats['errors'] += 1
|
||||
continue
|
||||
|
||||
# Register user in Common Room if not already registered in this run
|
||||
if user_id not in registered_users:
|
||||
register_user_in_common_room(user_id, email, github_username)
|
||||
registered_users.add(user_id)
|
||||
stats['registered_users'] += 1
|
||||
|
||||
# If created_at is a string, parse it to a datetime object
|
||||
# If it's already a datetime object, use it as is
|
||||
# If it's None, use current time
|
||||
created_at_datetime = (
|
||||
created_at
|
||||
if isinstance(created_at, datetime)
|
||||
else datetime.fromisoformat(created_at.replace('Z', '+00:00'))
|
||||
if created_at
|
||||
else datetime.now(UTC)
|
||||
)
|
||||
|
||||
# Register conversation activity with email and github username
|
||||
register_conversation_activity(
|
||||
user_id,
|
||||
conversation_id,
|
||||
title,
|
||||
created_at_datetime,
|
||||
email,
|
||||
github_username,
|
||||
)
|
||||
stats['registered_activities'] += 1
|
||||
|
||||
# Sleep to respect rate limit
|
||||
await asyncio.sleep(1 / RATE_LIMIT)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f'Error processing conversation {conversation_id} for user {user_id}: {e}'
|
||||
)
|
||||
stats['errors'] += 1
|
||||
except Exception as e:
|
||||
logger.exception(f'Sync failed: {e}')
|
||||
raise
|
||||
finally:
|
||||
logger.info(f'Sync completed. Stats: {stats}')
|
||||
|
||||
|
||||
def sync_recent_conversations_to_common_room(minutes: int = 60):
|
||||
"""Main function to sync recent conversations to Common Room.
|
||||
|
||||
Args:
|
||||
minutes: Number of minutes to look back for new conversations.
|
||||
"""
|
||||
# Run the async function in the event loop
|
||||
asyncio.run(async_sync_recent_conversations_to_common_room(minutes))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Default to looking back 60 minutes for new conversations
|
||||
minutes = int(os.environ.get('SYNC_MINUTES', '60'))
|
||||
sync_recent_conversations_to_common_room(minutes)
|
||||
Executable
+51
@@ -0,0 +1,51 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script for Common Room conversation count sync.
|
||||
|
||||
This script tests the functionality of the Common Room sync script
|
||||
without making any API calls to Common Room or database connections.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
from sync.common_room_sync import (
|
||||
retry_with_backoff,
|
||||
)
|
||||
|
||||
|
||||
class TestCommonRoomSync(unittest.TestCase):
|
||||
"""Test cases for Common Room sync functionality."""
|
||||
|
||||
def test_retry_with_backoff(self):
|
||||
"""Test the retry_with_backoff function."""
|
||||
# Mock function that succeeds on the second attempt
|
||||
mock_func = MagicMock(
|
||||
side_effect=[Exception('First attempt failed'), 'success']
|
||||
)
|
||||
|
||||
# Set environment variables for testing
|
||||
with patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
'MAX_RETRIES': '3',
|
||||
'INITIAL_BACKOFF_SECONDS': '0.01',
|
||||
'BACKOFF_FACTOR': '2',
|
||||
'MAX_BACKOFF_SECONDS': '1',
|
||||
},
|
||||
):
|
||||
result = retry_with_backoff(mock_func, 'arg1', 'arg2', kwarg1='kwarg1')
|
||||
|
||||
# Check that the function was called twice
|
||||
self.assertEqual(mock_func.call_count, 2)
|
||||
# Check that the function was called with the correct arguments
|
||||
mock_func.assert_called_with('arg1', 'arg2', kwarg1='kwarg1')
|
||||
# Check that the function returned the expected result
|
||||
self.assertEqual(result, 'success')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
+83
@@ -0,0 +1,83 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Test script to verify the conversation count query.
|
||||
|
||||
This script tests the database query to count conversations by user,
|
||||
without making any API calls to Common Room.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
from sqlalchemy import text
|
||||
|
||||
# Add the parent directory to the path so we can import from storage
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from storage.database import get_engine
|
||||
|
||||
|
||||
def test_conversation_count_query():
|
||||
"""Test the query to count conversations by user."""
|
||||
try:
|
||||
# Query to count conversations by user
|
||||
count_query = text("""
|
||||
SELECT
|
||||
user_id, COUNT(*) as conversation_count
|
||||
FROM
|
||||
conversation_metadata
|
||||
GROUP BY
|
||||
user_id
|
||||
""")
|
||||
|
||||
engine = get_engine()
|
||||
|
||||
with engine.connect() as connection:
|
||||
count_result = connection.execute(count_query)
|
||||
user_counts = [
|
||||
{'user_id': row[0], 'conversation_count': row[1]}
|
||||
for row in count_result
|
||||
]
|
||||
|
||||
print(f'Found {len(user_counts)} users with conversations')
|
||||
|
||||
# Print the first 5 results
|
||||
for i, user_data in enumerate(user_counts[:5]):
|
||||
print(
|
||||
f"User {i+1}: {user_data['user_id']} - {user_data['conversation_count']} conversations"
|
||||
)
|
||||
|
||||
# Test the user_entity query for the first user (if any)
|
||||
if user_counts:
|
||||
first_user_id = user_counts[0]['user_id']
|
||||
|
||||
user_query = text("""
|
||||
SELECT username, email, id
|
||||
FROM user_entity
|
||||
WHERE id = :user_id
|
||||
""")
|
||||
|
||||
with engine.connect() as connection:
|
||||
user_result = connection.execute(user_query, {'user_id': first_user_id})
|
||||
user_row = user_result.fetchone()
|
||||
|
||||
if user_row:
|
||||
print(f'\nUser details for {first_user_id}:')
|
||||
print(f' GitHub Username: {user_row[0]}')
|
||||
print(f' Email: {user_row[1]}')
|
||||
print(f' ID: {user_row[2]}')
|
||||
else:
|
||||
print(
|
||||
f'\nNo user details found for {first_user_id} in user_entity table'
|
||||
)
|
||||
|
||||
print('\nTest completed successfully')
|
||||
except Exception as e:
|
||||
print(f'Error: {str(e)}')
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_conversation_count_query()
|
||||
@@ -0,0 +1,68 @@
|
||||
"""Shared fixtures for services tests.
|
||||
|
||||
Note: We pre-load ``storage`` as a namespace package to avoid the heavy
|
||||
``storage/__init__.py`` that imports the entire enterprise model graph.
|
||||
This must happen *before* any ``from storage.…`` import.
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
import sys
|
||||
import types
|
||||
|
||||
# Prevent storage/__init__.py from loading the full model graph.
|
||||
# We only need the lightweight automation models for these tests.
|
||||
if 'storage' not in sys.modules:
|
||||
import pathlib
|
||||
|
||||
_storage_dir = str(pathlib.Path(__file__).resolve().parents[3] / 'storage')
|
||||
_mod = types.ModuleType('storage')
|
||||
_mod.__path__ = [_storage_dir]
|
||||
sys.modules['storage'] = _mod
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.pool import StaticPool
|
||||
from storage.automation import Automation, AutomationRun # noqa: F401
|
||||
from storage.automation_event import AutomationEvent # noqa: F401
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def async_engine():
|
||||
"""Create an async SQLite engine for testing."""
|
||||
engine = create_async_engine(
|
||||
'sqlite+aiosqlite:///:memory:',
|
||||
poolclass=StaticPool,
|
||||
connect_args={'check_same_thread': False},
|
||||
echo=False,
|
||||
)
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
yield engine
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def async_session_factory(async_engine):
|
||||
"""Create an async session factory that yields context-managed sessions."""
|
||||
factory = async_sessionmaker(
|
||||
bind=async_engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
)
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def _session_ctx():
|
||||
async with factory() as session:
|
||||
yield session
|
||||
|
||||
return _session_ctx
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def async_session(async_session_factory):
|
||||
"""Create a single async session for testing."""
|
||||
async with async_session_factory() as session:
|
||||
yield session
|
||||
@@ -0,0 +1,624 @@
|
||||
"""Tests for the automation executor.
|
||||
|
||||
Uses real SQLite database operations for event processing, run claiming,
|
||||
and stale run recovery. HTTP calls to the V1 API are mocked.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from services.automation_executor import (
|
||||
_mark_run_failed,
|
||||
claim_and_execute_runs,
|
||||
find_matching_automations,
|
||||
is_terminal,
|
||||
process_new_events,
|
||||
recover_stale_runs,
|
||||
utc_now,
|
||||
)
|
||||
from sqlalchemy import select
|
||||
from storage.automation import Automation, AutomationRun
|
||||
from storage.automation_event import AutomationEvent
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def make_automation(
|
||||
automation_id: str = 'auto-1',
|
||||
user_id: str = 'user-1',
|
||||
enabled: bool = True,
|
||||
trigger_type: str = 'cron',
|
||||
name: str = 'Test Automation',
|
||||
) -> Automation:
|
||||
return Automation(
|
||||
id=automation_id,
|
||||
user_id=user_id,
|
||||
org_id='org-1',
|
||||
name=name,
|
||||
enabled=enabled,
|
||||
config={'triggers': {'cron': {'schedule': '0 9 * * 5'}}},
|
||||
trigger_type=trigger_type,
|
||||
file_store_key=f'automations/{automation_id}/script.py',
|
||||
)
|
||||
|
||||
|
||||
def make_event(
|
||||
source_type: str = 'cron',
|
||||
payload: dict | None = None,
|
||||
status: str = 'NEW',
|
||||
dedup_key: str | None = None,
|
||||
) -> AutomationEvent:
|
||||
return AutomationEvent(
|
||||
source_type=source_type,
|
||||
payload=payload or {'automation_id': 'auto-1'},
|
||||
dedup_key=dedup_key or f'dedup-{uuid4().hex[:8]}',
|
||||
status=status,
|
||||
created_at=utc_now(),
|
||||
)
|
||||
|
||||
|
||||
def make_run(
|
||||
run_id: str | None = None,
|
||||
automation_id: str = 'auto-1',
|
||||
status: str = 'PENDING',
|
||||
claimed_by: str | None = None,
|
||||
heartbeat_at: datetime | None = None,
|
||||
retry_count: int = 0,
|
||||
max_retries: int = 3,
|
||||
next_retry_at: datetime | None = None,
|
||||
) -> AutomationRun:
|
||||
return AutomationRun(
|
||||
id=run_id or uuid4().hex,
|
||||
automation_id=automation_id,
|
||||
status=status,
|
||||
claimed_by=claimed_by,
|
||||
heartbeat_at=heartbeat_at,
|
||||
retry_count=retry_count,
|
||||
max_retries=max_retries,
|
||||
next_retry_at=next_retry_at,
|
||||
event_payload={'automation_id': automation_id},
|
||||
created_at=utc_now(),
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# find_matching_automations
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_find_matching_automations_cron_event(async_session):
|
||||
"""Cron events match by automation_id in payload."""
|
||||
automation = make_automation()
|
||||
async_session.add(automation)
|
||||
await async_session.commit()
|
||||
|
||||
event = make_event(source_type='cron', payload={'automation_id': 'auto-1'})
|
||||
async_session.add(event)
|
||||
await async_session.commit()
|
||||
|
||||
result = await find_matching_automations(async_session, event)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].id == 'auto-1'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_find_matching_automations_manual_event(async_session):
|
||||
"""Manual events also match by automation_id in payload."""
|
||||
automation = make_automation()
|
||||
async_session.add(automation)
|
||||
await async_session.commit()
|
||||
|
||||
event = make_event(source_type='manual', payload={'automation_id': 'auto-1'})
|
||||
async_session.add(event)
|
||||
await async_session.commit()
|
||||
|
||||
result = await find_matching_automations(async_session, event)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].id == 'auto-1'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_find_matching_automations_disabled_automation(async_session):
|
||||
"""Disabled automations are not matched."""
|
||||
automation = make_automation(enabled=False)
|
||||
async_session.add(automation)
|
||||
await async_session.commit()
|
||||
|
||||
event = make_event(payload={'automation_id': 'auto-1'})
|
||||
async_session.add(event)
|
||||
await async_session.commit()
|
||||
|
||||
result = await find_matching_automations(async_session, event)
|
||||
|
||||
assert len(result) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_find_matching_automations_missing_automation_id(async_session):
|
||||
"""Events without automation_id in payload return empty list."""
|
||||
event = make_event(payload={'something_else': 'value'})
|
||||
async_session.add(event)
|
||||
await async_session.commit()
|
||||
|
||||
result = await find_matching_automations(async_session, event)
|
||||
|
||||
assert len(result) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_find_matching_automations_nonexistent_automation(async_session):
|
||||
"""Events referencing a non-existent automation return empty list."""
|
||||
event = make_event(payload={'automation_id': 'nonexistent'})
|
||||
async_session.add(event)
|
||||
await async_session.commit()
|
||||
|
||||
result = await find_matching_automations(async_session, event)
|
||||
|
||||
assert len(result) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_find_matching_automations_unknown_source_type(async_session):
|
||||
"""Unknown source types return empty list."""
|
||||
event = make_event(source_type='unknown', payload={'automation_id': 'auto-1'})
|
||||
async_session.add(event)
|
||||
await async_session.commit()
|
||||
|
||||
result = await find_matching_automations(async_session, event)
|
||||
|
||||
assert len(result) == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# process_new_events
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_new_events_creates_runs(async_session):
|
||||
"""Processing NEW events creates PENDING runs and marks events PROCESSED."""
|
||||
automation = make_automation()
|
||||
event = make_event(payload={'automation_id': 'auto-1'})
|
||||
async_session.add_all([automation, event])
|
||||
await async_session.commit()
|
||||
|
||||
count = await process_new_events(async_session)
|
||||
|
||||
assert count == 1
|
||||
|
||||
# Event should be PROCESSED
|
||||
await async_session.refresh(event)
|
||||
assert event.status == 'PROCESSED'
|
||||
assert event.processed_at is not None
|
||||
|
||||
# A run should have been created
|
||||
runs = (await async_session.execute(select(AutomationRun))).scalars().all()
|
||||
assert len(runs) == 1
|
||||
assert runs[0].automation_id == 'auto-1'
|
||||
assert runs[0].status == 'PENDING'
|
||||
assert runs[0].event_payload == {'automation_id': 'auto-1'}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_new_events_no_match(async_session):
|
||||
"""Events with no matching automation are marked NO_MATCH."""
|
||||
event = make_event(payload={'automation_id': 'nonexistent'})
|
||||
async_session.add(event)
|
||||
await async_session.commit()
|
||||
|
||||
count = await process_new_events(async_session)
|
||||
|
||||
assert count == 1
|
||||
|
||||
await async_session.refresh(event)
|
||||
assert event.status == 'NO_MATCH'
|
||||
assert event.processed_at is not None
|
||||
|
||||
# No runs created
|
||||
runs = (await async_session.execute(select(AutomationRun))).scalars().all()
|
||||
assert len(runs) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_new_events_skips_processed(async_session):
|
||||
"""Already processed events are not re-processed."""
|
||||
event = make_event(status='PROCESSED')
|
||||
async_session.add(event)
|
||||
await async_session.commit()
|
||||
|
||||
count = await process_new_events(async_session)
|
||||
|
||||
assert count == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_process_new_events_multiple_events(async_session):
|
||||
"""Multiple NEW events are processed in one batch."""
|
||||
auto1 = make_automation(automation_id='auto-1')
|
||||
auto2 = make_automation(automation_id='auto-2', name='Auto 2')
|
||||
event1 = make_event(payload={'automation_id': 'auto-1'}, dedup_key='dedup-1')
|
||||
event2 = make_event(payload={'automation_id': 'auto-2'}, dedup_key='dedup-2')
|
||||
event3 = make_event(payload={'automation_id': 'nonexistent'}, dedup_key='dedup-3')
|
||||
async_session.add_all([auto1, auto2, event1, event2, event3])
|
||||
await async_session.commit()
|
||||
|
||||
count = await process_new_events(async_session)
|
||||
|
||||
assert count == 3
|
||||
|
||||
# Two runs created (for auto-1 and auto-2), none for nonexistent
|
||||
runs = (await async_session.execute(select(AutomationRun))).scalars().all()
|
||||
assert len(runs) == 2
|
||||
|
||||
await async_session.refresh(event1)
|
||||
await async_session.refresh(event2)
|
||||
await async_session.refresh(event3)
|
||||
assert event1.status == 'PROCESSED'
|
||||
assert event2.status == 'PROCESSED'
|
||||
assert event3.status == 'NO_MATCH'
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# claim_and_execute_runs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_claim_and_execute_runs_claims_pending(
|
||||
async_session, async_session_factory
|
||||
):
|
||||
"""Claims a PENDING run and transitions to RUNNING."""
|
||||
automation = make_automation()
|
||||
run = make_run(run_id='run-1')
|
||||
async_session.add_all([automation, run])
|
||||
await async_session.commit()
|
||||
|
||||
api_client = AsyncMock()
|
||||
|
||||
with patch('services.automation_executor.execute_run', new_callable=AsyncMock):
|
||||
claimed = await claim_and_execute_runs(
|
||||
async_session, 'executor-test-1', api_client, async_session_factory
|
||||
)
|
||||
|
||||
assert claimed is True
|
||||
|
||||
await async_session.refresh(run)
|
||||
assert run.status == 'RUNNING'
|
||||
assert run.claimed_by == 'executor-test-1'
|
||||
assert run.claimed_at is not None
|
||||
assert run.heartbeat_at is not None
|
||||
assert run.started_at is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_claim_and_execute_runs_no_pending(async_session, async_session_factory):
|
||||
"""Returns False when no PENDING runs exist."""
|
||||
api_client = AsyncMock()
|
||||
|
||||
claimed = await claim_and_execute_runs(
|
||||
async_session, 'executor-test-1', api_client, async_session_factory
|
||||
)
|
||||
|
||||
assert claimed is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_claim_and_execute_runs_respects_next_retry_at(
|
||||
async_session, async_session_factory
|
||||
):
|
||||
"""Runs with future next_retry_at are not claimed."""
|
||||
automation = make_automation()
|
||||
run = make_run(
|
||||
run_id='run-retry',
|
||||
next_retry_at=utc_now() + timedelta(hours=1),
|
||||
)
|
||||
async_session.add_all([automation, run])
|
||||
await async_session.commit()
|
||||
|
||||
api_client = AsyncMock()
|
||||
|
||||
claimed = await claim_and_execute_runs(
|
||||
async_session, 'executor-test-1', api_client, async_session_factory
|
||||
)
|
||||
|
||||
assert claimed is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_claim_and_execute_runs_past_retry_at(
|
||||
async_session, async_session_factory
|
||||
):
|
||||
"""Runs with past next_retry_at are claimable."""
|
||||
automation = make_automation()
|
||||
run = make_run(
|
||||
run_id='run-retry-past',
|
||||
next_retry_at=utc_now() - timedelta(minutes=5),
|
||||
)
|
||||
async_session.add_all([automation, run])
|
||||
await async_session.commit()
|
||||
|
||||
api_client = AsyncMock()
|
||||
|
||||
with patch('services.automation_executor.execute_run', new_callable=AsyncMock):
|
||||
claimed = await claim_and_execute_runs(
|
||||
async_session, 'executor-test-1', api_client, async_session_factory
|
||||
)
|
||||
|
||||
assert claimed is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_claim_skips_running_runs(async_session, async_session_factory):
|
||||
"""RUNNING runs are not claimed."""
|
||||
automation = make_automation()
|
||||
run = make_run(run_id='run-running', status='RUNNING', claimed_by='other-executor')
|
||||
async_session.add_all([automation, run])
|
||||
await async_session.commit()
|
||||
|
||||
api_client = AsyncMock()
|
||||
|
||||
claimed = await claim_and_execute_runs(
|
||||
async_session, 'executor-test-1', api_client, async_session_factory
|
||||
)
|
||||
|
||||
assert claimed is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# recover_stale_runs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_recover_stale_runs_recovers_stale(async_session):
|
||||
"""RUNNING runs with expired heartbeats are recovered to PENDING."""
|
||||
automation = make_automation()
|
||||
stale_run = make_run(
|
||||
run_id='stale-1',
|
||||
status='RUNNING',
|
||||
claimed_by='crashed-executor',
|
||||
heartbeat_at=utc_now() - timedelta(minutes=10),
|
||||
retry_count=0,
|
||||
)
|
||||
async_session.add_all([automation, stale_run])
|
||||
await async_session.commit()
|
||||
|
||||
count = await recover_stale_runs(async_session)
|
||||
|
||||
assert count >= 1
|
||||
|
||||
await async_session.refresh(stale_run)
|
||||
assert stale_run.status == 'PENDING'
|
||||
assert stale_run.claimed_by is None
|
||||
assert stale_run.retry_count == 1
|
||||
assert stale_run.next_retry_at is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_recover_stale_runs_ignores_fresh(async_session):
|
||||
"""RUNNING runs with recent heartbeats are not recovered."""
|
||||
automation = make_automation()
|
||||
fresh_run = make_run(
|
||||
run_id='fresh-1',
|
||||
status='RUNNING',
|
||||
claimed_by='active-executor',
|
||||
heartbeat_at=utc_now() - timedelta(seconds=30),
|
||||
)
|
||||
async_session.add_all([automation, fresh_run])
|
||||
await async_session.commit()
|
||||
|
||||
count = await recover_stale_runs(async_session)
|
||||
|
||||
assert count == 0
|
||||
|
||||
await async_session.refresh(fresh_run)
|
||||
assert fresh_run.status == 'RUNNING'
|
||||
assert fresh_run.claimed_by == 'active-executor'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_recover_stale_runs_ignores_pending(async_session):
|
||||
"""PENDING runs are not affected by recovery."""
|
||||
automation = make_automation()
|
||||
pending_run = make_run(run_id='pending-1', status='PENDING')
|
||||
async_session.add_all([automation, pending_run])
|
||||
await async_session.commit()
|
||||
|
||||
count = await recover_stale_runs(async_session)
|
||||
|
||||
assert count == 0
|
||||
|
||||
await async_session.refresh(pending_run)
|
||||
assert pending_run.status == 'PENDING'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_recover_stale_runs_increments_retry_count(async_session):
|
||||
"""Recovery increments the retry_count."""
|
||||
automation = make_automation()
|
||||
stale_run = make_run(
|
||||
run_id='stale-retry',
|
||||
status='RUNNING',
|
||||
claimed_by='old-executor',
|
||||
heartbeat_at=utc_now() - timedelta(minutes=10),
|
||||
retry_count=2,
|
||||
)
|
||||
async_session.add_all([automation, stale_run])
|
||||
await async_session.commit()
|
||||
|
||||
await recover_stale_runs(async_session)
|
||||
|
||||
await async_session.refresh(stale_run)
|
||||
assert stale_run.retry_count == 3
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _mark_run_failed (error handling)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mark_run_failed_retries(async_session_factory):
|
||||
"""Failed runs with retries left return to PENDING."""
|
||||
async with async_session_factory() as session:
|
||||
automation = make_automation()
|
||||
run = make_run(run_id='fail-retry', retry_count=0, max_retries=3)
|
||||
session.add_all([automation, run])
|
||||
await session.commit()
|
||||
|
||||
async with async_session_factory() as session:
|
||||
run_obj = await session.get(AutomationRun, 'fail-retry')
|
||||
await _mark_run_failed(run_obj, 'API error', async_session_factory)
|
||||
|
||||
async with async_session_factory() as session:
|
||||
run_obj = await session.get(AutomationRun, 'fail-retry')
|
||||
assert run_obj.status == 'PENDING'
|
||||
assert run_obj.retry_count == 1
|
||||
assert run_obj.error_detail == 'API error'
|
||||
assert run_obj.next_retry_at is not None
|
||||
assert run_obj.claimed_by is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mark_run_failed_dead_letter(async_session_factory):
|
||||
"""Failed runs that exceed max_retries go to DEAD_LETTER."""
|
||||
async with async_session_factory() as session:
|
||||
automation = make_automation()
|
||||
run = make_run(run_id='fail-dead', retry_count=2, max_retries=3)
|
||||
session.add_all([automation, run])
|
||||
await session.commit()
|
||||
|
||||
async with async_session_factory() as session:
|
||||
run_obj = await session.get(AutomationRun, 'fail-dead')
|
||||
await _mark_run_failed(run_obj, 'Final failure', async_session_factory)
|
||||
|
||||
async with async_session_factory() as session:
|
||||
run_obj = await session.get(AutomationRun, 'fail-dead')
|
||||
assert run_obj.status == 'DEAD_LETTER'
|
||||
assert run_obj.retry_count == 3
|
||||
assert run_obj.error_detail == 'Final failure'
|
||||
assert run_obj.completed_at is not None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# is_terminal
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_is_terminal_stopped():
|
||||
assert is_terminal({'status': 'STOPPED'}) is True
|
||||
|
||||
|
||||
def test_is_terminal_error():
|
||||
assert is_terminal({'status': 'ERROR'}) is True
|
||||
|
||||
|
||||
def test_is_terminal_completed():
|
||||
assert is_terminal({'status': 'COMPLETED'}) is True
|
||||
|
||||
|
||||
def test_is_terminal_cancelled():
|
||||
assert is_terminal({'status': 'CANCELLED'}) is True
|
||||
|
||||
|
||||
def test_is_terminal_running():
|
||||
assert is_terminal({'status': 'RUNNING'}) is False
|
||||
|
||||
|
||||
def test_is_terminal_empty():
|
||||
assert is_terminal({}) is False
|
||||
|
||||
|
||||
def test_is_terminal_case_insensitive():
|
||||
assert is_terminal({'status': 'stopped'}) is True
|
||||
assert is_terminal({'status': 'Completed'}) is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# find_matching_automations — None payload
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_find_matching_automations_none_payload(async_session):
|
||||
"""Events with None payload return empty list (data corruption guard)."""
|
||||
event = make_event(source_type='cron')
|
||||
event.payload = None
|
||||
async_session.add(event)
|
||||
await async_session.commit()
|
||||
|
||||
result = await find_matching_automations(async_session, event)
|
||||
|
||||
assert result == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration: event → run creation → claim
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_integration_event_to_run_to_claim(
|
||||
async_session_factory,
|
||||
):
|
||||
"""Full flow: create event + automation → process_new_events → claim_and_execute_runs.
|
||||
|
||||
Uses a real SQLite database; only the external API client is mocked.
|
||||
"""
|
||||
# 1. Seed an automation and a NEW event
|
||||
async with async_session_factory() as session:
|
||||
automation = make_automation(automation_id='integ-auto')
|
||||
event = make_event(
|
||||
source_type='cron',
|
||||
payload={'automation_id': 'integ-auto'},
|
||||
dedup_key='integ-dedup',
|
||||
)
|
||||
session.add_all([automation, event])
|
||||
await session.commit()
|
||||
event_id = event.id
|
||||
|
||||
# 2. Process inbox — should match and create a PENDING run
|
||||
async with async_session_factory() as session:
|
||||
processed = await process_new_events(session)
|
||||
|
||||
assert processed == 1
|
||||
|
||||
# Verify event is PROCESSED and run was created
|
||||
async with async_session_factory() as session:
|
||||
evt = await session.get(AutomationEvent, event_id)
|
||||
assert evt.status == 'PROCESSED'
|
||||
|
||||
runs = (await session.execute(select(AutomationRun))).scalars().all()
|
||||
assert len(runs) == 1
|
||||
run = runs[0]
|
||||
assert run.automation_id == 'integ-auto'
|
||||
assert run.status == 'PENDING'
|
||||
assert run.event_payload == {'automation_id': 'integ-auto'}
|
||||
|
||||
# 3. Claim the run — mock execute_run to avoid real API calls
|
||||
api_client = AsyncMock()
|
||||
|
||||
with patch('services.automation_executor.execute_run', new_callable=AsyncMock):
|
||||
async with async_session_factory() as session:
|
||||
claimed = await claim_and_execute_runs(
|
||||
session, 'executor-integ', api_client, async_session_factory
|
||||
)
|
||||
|
||||
assert claimed is True
|
||||
|
||||
# 4. Verify the run moved to RUNNING with correct executor
|
||||
async with async_session_factory() as session:
|
||||
runs = (await session.execute(select(AutomationRun))).scalars().all()
|
||||
assert len(runs) == 1
|
||||
run = runs[0]
|
||||
assert run.status == 'RUNNING'
|
||||
assert run.claimed_by == 'executor-integ'
|
||||
assert run.started_at is not None
|
||||
assert run.heartbeat_at is not None
|
||||
@@ -0,0 +1,185 @@
|
||||
"""Tests for OpenHandsAPIClient with mocked HTTP responses."""
|
||||
|
||||
import base64
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from services.openhands_api_client import OpenHandsAPIClient
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def api_client():
|
||||
client = OpenHandsAPIClient(base_url='http://test-server:3000')
|
||||
yield client
|
||||
# close handled in tests that need it
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# start_conversation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_conversation_sends_correct_request(api_client, respx_mock):
|
||||
"""start_conversation sends properly formatted POST with auth header."""
|
||||
automation_file = b'print("hello")'
|
||||
expected_b64 = base64.b64encode(automation_file).decode()
|
||||
|
||||
route = respx_mock.post('http://test-server:3000/api/v1/app-conversations').mock(
|
||||
return_value=httpx.Response(
|
||||
200,
|
||||
json={
|
||||
'app_conversation_id': 'conv-123',
|
||||
'status': 'RUNNING',
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
result = await api_client.start_conversation(
|
||||
api_key='sk-oh-test123',
|
||||
automation_file=automation_file,
|
||||
title='Test Automation',
|
||||
event_payload={'automation_id': 'auto-1'},
|
||||
)
|
||||
|
||||
assert route.called
|
||||
request = route.calls[0].request
|
||||
|
||||
assert request.headers['Authorization'] == 'Bearer sk-oh-test123'
|
||||
|
||||
import json
|
||||
|
||||
body = json.loads(request.content)
|
||||
assert body['automation_file'] == expected_b64
|
||||
assert body['trigger'] == 'automation'
|
||||
assert body['title'] == 'Test Automation'
|
||||
assert body['event_payload'] == {'automation_id': 'auto-1'}
|
||||
|
||||
assert result == {'app_conversation_id': 'conv-123', 'status': 'RUNNING'}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_conversation_without_event_payload(api_client, respx_mock):
|
||||
"""start_conversation works with event_payload=None."""
|
||||
respx_mock.post('http://test-server:3000/api/v1/app-conversations').mock(
|
||||
return_value=httpx.Response(200, json={'app_conversation_id': 'conv-456'})
|
||||
)
|
||||
|
||||
result = await api_client.start_conversation(
|
||||
api_key='sk-oh-test',
|
||||
automation_file=b'code',
|
||||
title='Test',
|
||||
event_payload=None,
|
||||
)
|
||||
|
||||
assert result['app_conversation_id'] == 'conv-456'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_conversation_http_error(api_client, respx_mock):
|
||||
"""start_conversation raises on HTTP errors."""
|
||||
respx_mock.post('http://test-server:3000/api/v1/app-conversations').mock(
|
||||
return_value=httpx.Response(500, json={'error': 'Internal Server Error'})
|
||||
)
|
||||
|
||||
with pytest.raises(httpx.HTTPStatusError) as exc_info:
|
||||
await api_client.start_conversation(
|
||||
api_key='sk-oh-test',
|
||||
automation_file=b'code',
|
||||
title='Test',
|
||||
)
|
||||
|
||||
assert exc_info.value.response.status_code == 500
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_conversation_auth_error(api_client, respx_mock):
|
||||
"""start_conversation raises on 401 Unauthorized."""
|
||||
respx_mock.post('http://test-server:3000/api/v1/app-conversations').mock(
|
||||
return_value=httpx.Response(401, json={'error': 'Unauthorized'})
|
||||
)
|
||||
|
||||
with pytest.raises(httpx.HTTPStatusError) as exc_info:
|
||||
await api_client.start_conversation(
|
||||
api_key='bad-key',
|
||||
automation_file=b'code',
|
||||
title='Test',
|
||||
)
|
||||
|
||||
assert exc_info.value.response.status_code == 401
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_conversation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_conversation_returns_data(api_client, respx_mock):
|
||||
"""get_conversation returns the first conversation from the list."""
|
||||
respx_mock.get('http://test-server:3000/api/v1/app-conversations').mock(
|
||||
return_value=httpx.Response(
|
||||
200,
|
||||
json=[
|
||||
{
|
||||
'conversation_id': 'conv-123',
|
||||
'status': 'RUNNING',
|
||||
'title': 'My Automation',
|
||||
}
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
result = await api_client.get_conversation('sk-oh-test', 'conv-123')
|
||||
|
||||
assert result is not None
|
||||
assert result['conversation_id'] == 'conv-123'
|
||||
assert result['status'] == 'RUNNING'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_conversation_returns_none_when_empty(api_client, respx_mock):
|
||||
"""get_conversation returns None when API returns empty list."""
|
||||
respx_mock.get('http://test-server:3000/api/v1/app-conversations').mock(
|
||||
return_value=httpx.Response(200, json=[])
|
||||
)
|
||||
|
||||
result = await api_client.get_conversation('sk-oh-test', 'nonexistent')
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_conversation_sends_auth_header(api_client, respx_mock):
|
||||
"""get_conversation sends the correct authorization header."""
|
||||
route = respx_mock.get('http://test-server:3000/api/v1/app-conversations').mock(
|
||||
return_value=httpx.Response(200, json=[])
|
||||
)
|
||||
|
||||
await api_client.get_conversation('sk-oh-mykey', 'conv-1')
|
||||
|
||||
assert route.called
|
||||
request = route.calls[0].request
|
||||
assert request.headers['Authorization'] == 'Bearer sk-oh-mykey'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_conversation_http_error(api_client, respx_mock):
|
||||
"""get_conversation raises on HTTP errors."""
|
||||
respx_mock.get('http://test-server:3000/api/v1/app-conversations').mock(
|
||||
return_value=httpx.Response(503, text='Service Unavailable')
|
||||
)
|
||||
|
||||
with pytest.raises(httpx.HTTPStatusError):
|
||||
await api_client.get_conversation('sk-oh-test', 'conv-1')
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# close
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close(api_client):
|
||||
"""close() shuts down the HTTP client without errors."""
|
||||
await api_client.close()
|
||||
@@ -10,9 +10,6 @@ from unittest.mock import AsyncMock, MagicMock
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pytest
|
||||
from server.utils.saas_app_conversation_info_injector import (
|
||||
SaasSQLAppConversationInfoService,
|
||||
)
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.pool import StaticPool
|
||||
@@ -20,6 +17,9 @@ from storage.base import Base
|
||||
from storage.org import Org
|
||||
from storage.user import User
|
||||
|
||||
from enterprise.server.utils.saas_app_conversation_info_injector import (
|
||||
SaasSQLAppConversationInfoService,
|
||||
)
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
AppConversationInfo,
|
||||
)
|
||||
@@ -663,330 +663,3 @@ class TestSaasSQLAppConversationInfoServiceAdminContext:
|
||||
|
||||
admin_page = await admin_service.search_app_conversation_info()
|
||||
assert len(admin_page.items) == 5
|
||||
|
||||
|
||||
class TestSaasSQLAppConversationInfoServiceWebhookFallback:
|
||||
"""Test suite for webhook callback fallback using info.created_by_user_id."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_with_admin_context_uses_created_by_user_id_fallback(
|
||||
self,
|
||||
async_session_with_users: AsyncSession,
|
||||
):
|
||||
"""Test that save_app_conversation_info uses info.created_by_user_id when user_context returns None.
|
||||
|
||||
This is the key fix for SDK-created conversations: when the webhook endpoint
|
||||
uses ADMIN context (user_id=None), the service should fall back to using
|
||||
the created_by_user_id from the AppConversationInfo object.
|
||||
"""
|
||||
from storage.stored_conversation_metadata_saas import (
|
||||
StoredConversationMetadataSaas,
|
||||
)
|
||||
|
||||
from openhands.app_server.user.specifiy_user_context import ADMIN
|
||||
|
||||
# Arrange: Create service with ADMIN context (user_id=None)
|
||||
admin_service = SaasSQLAppConversationInfoService(
|
||||
db_session=async_session_with_users,
|
||||
user_context=ADMIN,
|
||||
)
|
||||
|
||||
# Create conversation info with created_by_user_id set (as would come from sandbox_info)
|
||||
conv_id = uuid4()
|
||||
conv_info = AppConversationInfo(
|
||||
id=conv_id,
|
||||
created_by_user_id=str(USER1_ID), # This should be used as fallback
|
||||
sandbox_id='sandbox_webhook_test',
|
||||
title='Webhook Created Conversation',
|
||||
)
|
||||
|
||||
# Act: Save using ADMIN context
|
||||
await admin_service.save_app_conversation_info(conv_info)
|
||||
|
||||
# Assert: SAAS metadata should be created with user_id from info.created_by_user_id
|
||||
saas_query = select(StoredConversationMetadataSaas).where(
|
||||
StoredConversationMetadataSaas.conversation_id == str(conv_id)
|
||||
)
|
||||
result = await async_session_with_users.execute(saas_query)
|
||||
saas_metadata = result.scalar_one_or_none()
|
||||
|
||||
assert saas_metadata is not None, 'SAAS metadata should be created'
|
||||
assert (
|
||||
saas_metadata.user_id == USER1_ID
|
||||
), 'user_id should match info.created_by_user_id'
|
||||
assert saas_metadata.org_id == ORG1_ID, 'org_id should match user current org'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_with_admin_context_no_user_id_skips_saas_metadata(
|
||||
self,
|
||||
async_session_with_users: AsyncSession,
|
||||
):
|
||||
"""Test that save_app_conversation_info skips SAAS metadata when both user_context and info have no user_id."""
|
||||
from storage.stored_conversation_metadata_saas import (
|
||||
StoredConversationMetadataSaas,
|
||||
)
|
||||
|
||||
from openhands.app_server.user.specifiy_user_context import ADMIN
|
||||
|
||||
# Arrange: Create service with ADMIN context (user_id=None)
|
||||
admin_service = SaasSQLAppConversationInfoService(
|
||||
db_session=async_session_with_users,
|
||||
user_context=ADMIN,
|
||||
)
|
||||
|
||||
# Create conversation info without created_by_user_id
|
||||
conv_id = uuid4()
|
||||
conv_info = AppConversationInfo(
|
||||
id=conv_id,
|
||||
created_by_user_id=None, # No user_id available
|
||||
sandbox_id='sandbox_no_user',
|
||||
title='No User Conversation',
|
||||
)
|
||||
|
||||
# Act: Save using ADMIN context with no user_id fallback
|
||||
await admin_service.save_app_conversation_info(conv_info)
|
||||
|
||||
# Assert: SAAS metadata should NOT be created
|
||||
saas_query = select(StoredConversationMetadataSaas).where(
|
||||
StoredConversationMetadataSaas.conversation_id == str(conv_id)
|
||||
)
|
||||
result = await async_session_with_users.execute(saas_query)
|
||||
saas_metadata = result.scalar_one_or_none()
|
||||
|
||||
assert (
|
||||
saas_metadata is None
|
||||
), 'SAAS metadata should not be created without user_id'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_webhook_created_conversation_visible_to_user(
|
||||
self,
|
||||
async_session_with_users: AsyncSession,
|
||||
):
|
||||
"""Test end-to-end: conversation saved via webhook is visible to the owning user."""
|
||||
from openhands.app_server.user.specifiy_user_context import ADMIN
|
||||
|
||||
# Arrange: Save conversation using ADMIN context (simulating webhook)
|
||||
admin_service = SaasSQLAppConversationInfoService(
|
||||
db_session=async_session_with_users,
|
||||
user_context=ADMIN,
|
||||
)
|
||||
|
||||
conv_id = uuid4()
|
||||
conv_info = AppConversationInfo(
|
||||
id=conv_id,
|
||||
created_by_user_id=str(USER1_ID),
|
||||
sandbox_id='sandbox_webhook_e2e',
|
||||
title='E2E Webhook Conversation',
|
||||
)
|
||||
await admin_service.save_app_conversation_info(conv_info)
|
||||
|
||||
# Act: Query as the owning user
|
||||
user1_service = SaasSQLAppConversationInfoService(
|
||||
db_session=async_session_with_users,
|
||||
user_context=SpecifyUserContext(user_id=str(USER1_ID)),
|
||||
)
|
||||
user1_page = await user1_service.search_app_conversation_info()
|
||||
|
||||
# Assert: User should see the webhook-created conversation
|
||||
assert len(user1_page.items) == 1
|
||||
assert user1_page.items[0].id == conv_id
|
||||
assert user1_page.items[0].title == 'E2E Webhook Conversation'
|
||||
|
||||
|
||||
class TestSandboxIdFilterSaas:
|
||||
"""Test suite for sandbox_id__eq filter parameter in SAAS service."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_by_sandbox_id(
|
||||
self,
|
||||
async_session_with_users: AsyncSession,
|
||||
):
|
||||
"""Test searching conversations by exact sandbox_id match with SAAS user filtering."""
|
||||
# Create service for user1
|
||||
user1_service = SaasSQLAppConversationInfoService(
|
||||
db_session=async_session_with_users,
|
||||
user_context=SpecifyUserContext(user_id=str(USER1_ID)),
|
||||
)
|
||||
|
||||
# Create conversations with different sandbox IDs for user1
|
||||
conv1 = AppConversationInfo(
|
||||
id=uuid4(),
|
||||
created_by_user_id=str(USER1_ID),
|
||||
sandbox_id='sandbox_alpha',
|
||||
title='Conversation Alpha',
|
||||
created_at=datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc),
|
||||
updated_at=datetime(2024, 1, 1, 12, 30, 0, tzinfo=timezone.utc),
|
||||
)
|
||||
conv2 = AppConversationInfo(
|
||||
id=uuid4(),
|
||||
created_by_user_id=str(USER1_ID),
|
||||
sandbox_id='sandbox_beta',
|
||||
title='Conversation Beta',
|
||||
created_at=datetime(2024, 1, 1, 13, 0, 0, tzinfo=timezone.utc),
|
||||
updated_at=datetime(2024, 1, 1, 13, 30, 0, tzinfo=timezone.utc),
|
||||
)
|
||||
conv3 = AppConversationInfo(
|
||||
id=uuid4(),
|
||||
created_by_user_id=str(USER1_ID),
|
||||
sandbox_id='sandbox_alpha',
|
||||
title='Conversation Gamma',
|
||||
created_at=datetime(2024, 1, 1, 14, 0, 0, tzinfo=timezone.utc),
|
||||
updated_at=datetime(2024, 1, 1, 14, 30, 0, tzinfo=timezone.utc),
|
||||
)
|
||||
|
||||
# Save all conversations
|
||||
await user1_service.save_app_conversation_info(conv1)
|
||||
await user1_service.save_app_conversation_info(conv2)
|
||||
await user1_service.save_app_conversation_info(conv3)
|
||||
|
||||
# Search for sandbox_alpha - should return 2 conversations
|
||||
page = await user1_service.search_app_conversation_info(
|
||||
sandbox_id__eq='sandbox_alpha'
|
||||
)
|
||||
assert len(page.items) == 2
|
||||
sandbox_ids = {item.sandbox_id for item in page.items}
|
||||
assert sandbox_ids == {'sandbox_alpha'}
|
||||
conversation_ids = {item.id for item in page.items}
|
||||
assert conv1.id in conversation_ids
|
||||
assert conv3.id in conversation_ids
|
||||
|
||||
# Search for sandbox_beta - should return 1 conversation
|
||||
page = await user1_service.search_app_conversation_info(
|
||||
sandbox_id__eq='sandbox_beta'
|
||||
)
|
||||
assert len(page.items) == 1
|
||||
assert page.items[0].id == conv2.id
|
||||
|
||||
# Search for non-existent sandbox - should return 0 conversations
|
||||
page = await user1_service.search_app_conversation_info(
|
||||
sandbox_id__eq='sandbox_nonexistent'
|
||||
)
|
||||
assert len(page.items) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_count_by_sandbox_id(
|
||||
self,
|
||||
async_session_with_users: AsyncSession,
|
||||
):
|
||||
"""Test counting conversations by exact sandbox_id match with SAAS user filtering."""
|
||||
# Create service for user1
|
||||
user1_service = SaasSQLAppConversationInfoService(
|
||||
db_session=async_session_with_users,
|
||||
user_context=SpecifyUserContext(user_id=str(USER1_ID)),
|
||||
)
|
||||
|
||||
# Create conversations with different sandbox IDs
|
||||
conv1 = AppConversationInfo(
|
||||
id=uuid4(),
|
||||
created_by_user_id=str(USER1_ID),
|
||||
sandbox_id='sandbox_x',
|
||||
title='Conversation X1',
|
||||
created_at=datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc),
|
||||
updated_at=datetime(2024, 1, 1, 12, 30, 0, tzinfo=timezone.utc),
|
||||
)
|
||||
conv2 = AppConversationInfo(
|
||||
id=uuid4(),
|
||||
created_by_user_id=str(USER1_ID),
|
||||
sandbox_id='sandbox_y',
|
||||
title='Conversation Y1',
|
||||
created_at=datetime(2024, 1, 1, 13, 0, 0, tzinfo=timezone.utc),
|
||||
updated_at=datetime(2024, 1, 1, 13, 30, 0, tzinfo=timezone.utc),
|
||||
)
|
||||
conv3 = AppConversationInfo(
|
||||
id=uuid4(),
|
||||
created_by_user_id=str(USER1_ID),
|
||||
sandbox_id='sandbox_x',
|
||||
title='Conversation X2',
|
||||
created_at=datetime(2024, 1, 1, 14, 0, 0, tzinfo=timezone.utc),
|
||||
updated_at=datetime(2024, 1, 1, 14, 30, 0, tzinfo=timezone.utc),
|
||||
)
|
||||
|
||||
# Save all conversations
|
||||
await user1_service.save_app_conversation_info(conv1)
|
||||
await user1_service.save_app_conversation_info(conv2)
|
||||
await user1_service.save_app_conversation_info(conv3)
|
||||
|
||||
# Count for sandbox_x - should be 2
|
||||
count = await user1_service.count_app_conversation_info(
|
||||
sandbox_id__eq='sandbox_x'
|
||||
)
|
||||
assert count == 2
|
||||
|
||||
# Count for sandbox_y - should be 1
|
||||
count = await user1_service.count_app_conversation_info(
|
||||
sandbox_id__eq='sandbox_y'
|
||||
)
|
||||
assert count == 1
|
||||
|
||||
# Count for non-existent sandbox - should be 0
|
||||
count = await user1_service.count_app_conversation_info(
|
||||
sandbox_id__eq='sandbox_nonexistent'
|
||||
)
|
||||
assert count == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sandbox_id_filter_respects_user_isolation(
|
||||
self,
|
||||
async_session_with_users: AsyncSession,
|
||||
):
|
||||
"""Test that sandbox_id filter respects user isolation in SAAS environment."""
|
||||
# Create services for both users
|
||||
user1_service = SaasSQLAppConversationInfoService(
|
||||
db_session=async_session_with_users,
|
||||
user_context=SpecifyUserContext(user_id=str(USER1_ID)),
|
||||
)
|
||||
user2_service = SaasSQLAppConversationInfoService(
|
||||
db_session=async_session_with_users,
|
||||
user_context=SpecifyUserContext(user_id=str(USER2_ID)),
|
||||
)
|
||||
|
||||
# Create conversation with same sandbox_id for both users
|
||||
shared_sandbox_id = 'sandbox_shared'
|
||||
|
||||
conv_user1 = AppConversationInfo(
|
||||
id=uuid4(),
|
||||
created_by_user_id=str(USER1_ID),
|
||||
sandbox_id=shared_sandbox_id,
|
||||
title='User1 Conversation',
|
||||
created_at=datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc),
|
||||
updated_at=datetime(2024, 1, 1, 12, 30, 0, tzinfo=timezone.utc),
|
||||
)
|
||||
conv_user2 = AppConversationInfo(
|
||||
id=uuid4(),
|
||||
created_by_user_id=str(USER2_ID),
|
||||
sandbox_id=shared_sandbox_id,
|
||||
title='User2 Conversation',
|
||||
created_at=datetime(2024, 1, 1, 13, 0, 0, tzinfo=timezone.utc),
|
||||
updated_at=datetime(2024, 1, 1, 13, 30, 0, tzinfo=timezone.utc),
|
||||
)
|
||||
|
||||
# Save conversations
|
||||
await user1_service.save_app_conversation_info(conv_user1)
|
||||
await user2_service.save_app_conversation_info(conv_user2)
|
||||
|
||||
# User1 should only see their own conversation with this sandbox_id
|
||||
page = await user1_service.search_app_conversation_info(
|
||||
sandbox_id__eq=shared_sandbox_id
|
||||
)
|
||||
assert len(page.items) == 1
|
||||
assert page.items[0].id == conv_user1.id
|
||||
assert page.items[0].title == 'User1 Conversation'
|
||||
|
||||
# User2 should only see their own conversation with this sandbox_id
|
||||
page = await user2_service.search_app_conversation_info(
|
||||
sandbox_id__eq=shared_sandbox_id
|
||||
)
|
||||
assert len(page.items) == 1
|
||||
assert page.items[0].id == conv_user2.id
|
||||
assert page.items[0].title == 'User2 Conversation'
|
||||
|
||||
# Count should also respect user isolation
|
||||
count = await user1_service.count_app_conversation_info(
|
||||
sandbox_id__eq=shared_sandbox_id
|
||||
)
|
||||
assert count == 1
|
||||
|
||||
count = await user2_service.count_app_conversation_info(
|
||||
sandbox_id__eq=shared_sandbox_id
|
||||
)
|
||||
assert count == 1
|
||||
|
||||
@@ -11,6 +11,7 @@ from server.auth.auth_error import AuthError
|
||||
from server.auth.saas_user_auth import SaasUserAuth
|
||||
from server.auth.user.user_authorizer import UserAuthorizationResponse, UserAuthorizer
|
||||
from server.routes.auth import (
|
||||
_extract_recaptcha_state,
|
||||
accept_tos,
|
||||
authenticate,
|
||||
keycloak_callback,
|
||||
@@ -54,12 +55,11 @@ def mock_response():
|
||||
def test_set_response_cookie(mock_response, mock_request):
|
||||
"""Test setting the auth cookie on a response."""
|
||||
|
||||
with (
|
||||
patch('server.routes.auth.config') as mock_config,
|
||||
patch('server.utils.url_utils.get_global_config') as get_global_config,
|
||||
):
|
||||
with patch('server.routes.auth.config') as mock_config:
|
||||
mock_config.jwt_secret.get_secret_value.return_value = 'test_secret'
|
||||
get_global_config.return_value = MagicMock(web_url='https://example.com')
|
||||
|
||||
# Configure mock_request.url.hostname
|
||||
mock_request.url.hostname = 'example.com'
|
||||
|
||||
set_response_cookie(
|
||||
request=mock_request,
|
||||
@@ -1036,6 +1036,79 @@ async def test_keycloak_callback_no_email_in_user_info(
|
||||
mock_token_manager.check_duplicate_base_email.assert_not_called()
|
||||
|
||||
|
||||
class TestExtractRecaptchaState:
|
||||
"""Tests for _extract_recaptcha_state() helper function."""
|
||||
|
||||
def test_should_extract_redirect_url_and_token_from_new_json_format(self):
|
||||
"""Test extraction from new base64-encoded JSON format."""
|
||||
# Arrange
|
||||
state_data = {
|
||||
'redirect_url': 'https://example.com',
|
||||
'recaptcha_token': 'test-token',
|
||||
}
|
||||
encoded_state = base64.urlsafe_b64encode(
|
||||
json.dumps(state_data).encode()
|
||||
).decode()
|
||||
|
||||
# Act
|
||||
redirect_url, token = _extract_recaptcha_state(encoded_state)
|
||||
|
||||
# Assert
|
||||
assert redirect_url == 'https://example.com'
|
||||
assert token == 'test-token'
|
||||
|
||||
def test_should_handle_old_format_plain_redirect_url(self):
|
||||
"""Test handling of old format (plain redirect URL string)."""
|
||||
# Arrange
|
||||
state = 'https://example.com'
|
||||
|
||||
# Act
|
||||
redirect_url, token = _extract_recaptcha_state(state)
|
||||
|
||||
# Assert
|
||||
assert redirect_url == 'https://example.com'
|
||||
assert token is None
|
||||
|
||||
def test_should_handle_none_state(self):
|
||||
"""Test handling of None state."""
|
||||
# Arrange
|
||||
state = None
|
||||
|
||||
# Act
|
||||
redirect_url, token = _extract_recaptcha_state(state)
|
||||
|
||||
# Assert
|
||||
assert redirect_url == ''
|
||||
assert token is None
|
||||
|
||||
def test_should_handle_invalid_base64_gracefully(self):
|
||||
"""Test handling of invalid base64/JSON (fallback to old format)."""
|
||||
# Arrange
|
||||
state = 'not-valid-base64!!!'
|
||||
|
||||
# Act
|
||||
redirect_url, token = _extract_recaptcha_state(state)
|
||||
|
||||
# Assert
|
||||
assert redirect_url == state
|
||||
assert token is None
|
||||
|
||||
def test_should_handle_missing_redirect_url_in_json(self):
|
||||
"""Test handling when redirect_url is missing in JSON."""
|
||||
# Arrange
|
||||
state_data = {'recaptcha_token': 'test-token'}
|
||||
encoded_state = base64.urlsafe_b64encode(
|
||||
json.dumps(state_data).encode()
|
||||
).decode()
|
||||
|
||||
# Act
|
||||
redirect_url, token = _extract_recaptcha_state(encoded_state)
|
||||
|
||||
# Assert
|
||||
assert redirect_url == ''
|
||||
assert token == 'test-token'
|
||||
|
||||
|
||||
class TestKeycloakCallbackRecaptcha:
|
||||
"""Tests for reCAPTCHA integration in keycloak_callback()."""
|
||||
|
||||
|
||||
@@ -48,7 +48,7 @@ def mock_checkout_request():
|
||||
'server': ('test.com', 80),
|
||||
}
|
||||
)
|
||||
request._url = URL('http://test.com/')
|
||||
request._base_url = URL('http://test.com/')
|
||||
return request
|
||||
|
||||
|
||||
@@ -62,7 +62,7 @@ def mock_subscription_request():
|
||||
'server': ('test.com', 80),
|
||||
}
|
||||
)
|
||||
request._url = URL('http://test.com/')
|
||||
request._base_url = URL('http://test.com/')
|
||||
return request
|
||||
|
||||
|
||||
@@ -264,7 +264,7 @@ async def test_create_checkout_session_success(
|
||||
async def test_success_callback_session_not_found(async_session_maker):
|
||||
"""Test success callback when billing session is not found."""
|
||||
mock_request = Request(scope={'type': 'http'})
|
||||
mock_request._url = URL('http://test.com/')
|
||||
mock_request._base_url = URL('http://test.com/')
|
||||
|
||||
with (
|
||||
patch('server.routes.billing.a_session_maker', async_session_maker),
|
||||
@@ -281,7 +281,7 @@ async def test_success_callback_stripe_incomplete(
|
||||
):
|
||||
"""Test success callback when Stripe session is not complete."""
|
||||
mock_request = Request(scope={'type': 'http'})
|
||||
mock_request._url = URL('http://test.com/')
|
||||
mock_request._base_url = URL('http://test.com/')
|
||||
|
||||
session_id = 'test_incomplete_session'
|
||||
async with async_session_maker() as session:
|
||||
@@ -319,7 +319,7 @@ async def test_success_callback_stripe_incomplete(
|
||||
async def test_success_callback_success(async_session_maker, test_org, test_user):
|
||||
"""Test successful payment completion and credit update."""
|
||||
mock_request = Request(scope={'type': 'http'})
|
||||
mock_request._url = URL('http://test.com/')
|
||||
mock_request._base_url = URL('http://test.com/')
|
||||
|
||||
session_id = 'test_success_session'
|
||||
async with async_session_maker() as session:
|
||||
@@ -391,7 +391,7 @@ async def test_success_callback_lite_llm_error(
|
||||
):
|
||||
"""Test handling of LiteLLM API errors during success callback."""
|
||||
mock_request = Request(scope={'type': 'http'})
|
||||
mock_request._url = URL('http://test.com/')
|
||||
mock_request._base_url = URL('http://test.com/')
|
||||
|
||||
session_id = 'test_litellm_error_session'
|
||||
async with async_session_maker() as session:
|
||||
@@ -445,7 +445,7 @@ async def test_success_callback_lite_llm_update_budget_error_rollback(
|
||||
the database transaction rolls back.
|
||||
"""
|
||||
mock_request = Request(scope={'type': 'http'})
|
||||
mock_request._url = URL('http://test.com/')
|
||||
mock_request._base_url = URL('http://test.com/')
|
||||
|
||||
session_id = 'test_budget_rollback_session'
|
||||
async with async_session_maker() as session:
|
||||
@@ -502,7 +502,7 @@ async def test_success_callback_lite_llm_update_budget_error_rollback(
|
||||
async def test_cancel_callback_session_not_found(async_session_maker):
|
||||
"""Test cancel callback when billing session is not found."""
|
||||
mock_request = Request(scope={'type': 'http'})
|
||||
mock_request._url = URL('http://test.com/')
|
||||
mock_request._base_url = URL('http://test.com/')
|
||||
|
||||
with patch('server.routes.billing.a_session_maker', async_session_maker):
|
||||
response = await cancel_callback('nonexistent_session_id', mock_request)
|
||||
@@ -517,7 +517,7 @@ async def test_cancel_callback_session_not_found(async_session_maker):
|
||||
async def test_cancel_callback_success(async_session_maker, test_org, test_user):
|
||||
"""Test successful cancellation of billing session."""
|
||||
mock_request = Request(scope={'type': 'http'})
|
||||
mock_request._url = URL('http://test.com/')
|
||||
mock_request._base_url = URL('http://test.com/')
|
||||
|
||||
session_id = 'test_cancel_session'
|
||||
async with async_session_maker() as session:
|
||||
@@ -588,7 +588,7 @@ async def test_create_customer_setup_session_success():
|
||||
'headers': [],
|
||||
}
|
||||
)
|
||||
mock_request._url = URL('http://test.com/')
|
||||
mock_request._base_url = URL('http://test.com/')
|
||||
|
||||
mock_customer_info = {'customer_id': 'mock-customer-id', 'org_id': 'mock-org-id'}
|
||||
mock_session = MagicMock()
|
||||
@@ -613,6 +613,6 @@ async def test_create_customer_setup_session_success():
|
||||
customer='mock-customer-id',
|
||||
mode='setup',
|
||||
payment_method_types=['card'],
|
||||
success_url='https://test.com?setup=success',
|
||||
cancel_url='https://test.com',
|
||||
success_url='https://test.com/?setup=success',
|
||||
cancel_url='https://test.com/',
|
||||
)
|
||||
|
||||
@@ -2,9 +2,7 @@
|
||||
Unit tests for LiteLlmManager class.
|
||||
"""
|
||||
|
||||
import importlib
|
||||
import os
|
||||
import sys
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import httpx
|
||||
@@ -23,71 +21,6 @@ from storage.user_settings import UserSettings
|
||||
from openhands.server.settings import Settings
|
||||
|
||||
|
||||
class TestDefaultInitialBudget:
|
||||
"""Test cases for DEFAULT_INITIAL_BUDGET configuration."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def restore_module_state(self):
|
||||
"""Ensure module is properly restored after each test."""
|
||||
# Save original module if it exists
|
||||
original_module = sys.modules.get('storage.lite_llm_manager')
|
||||
|
||||
yield
|
||||
|
||||
# Restore module state after each test
|
||||
if 'storage.lite_llm_manager' in sys.modules:
|
||||
del sys.modules['storage.lite_llm_manager']
|
||||
|
||||
# Clear the env var
|
||||
os.environ.pop('DEFAULT_INITIAL_BUDGET', None)
|
||||
|
||||
# Restore original module or reimport fresh
|
||||
if original_module is not None:
|
||||
sys.modules['storage.lite_llm_manager'] = original_module
|
||||
else:
|
||||
importlib.import_module('storage.lite_llm_manager')
|
||||
|
||||
def test_default_initial_budget_defaults_to_zero(self):
|
||||
"""Test that DEFAULT_INITIAL_BUDGET defaults to 0.0 when env var not set."""
|
||||
# Temporarily remove the module so we can reimport with different env vars
|
||||
if 'storage.lite_llm_manager' in sys.modules:
|
||||
del sys.modules['storage.lite_llm_manager']
|
||||
|
||||
# Clear the env var and reimport
|
||||
os.environ.pop('DEFAULT_INITIAL_BUDGET', None)
|
||||
module = importlib.import_module('storage.lite_llm_manager')
|
||||
assert module.DEFAULT_INITIAL_BUDGET == 0.0
|
||||
|
||||
def test_default_initial_budget_uses_env_var(self):
|
||||
"""Test that DEFAULT_INITIAL_BUDGET uses value from environment variable."""
|
||||
if 'storage.lite_llm_manager' in sys.modules:
|
||||
del sys.modules['storage.lite_llm_manager']
|
||||
|
||||
os.environ['DEFAULT_INITIAL_BUDGET'] = '100.0'
|
||||
module = importlib.import_module('storage.lite_llm_manager')
|
||||
assert module.DEFAULT_INITIAL_BUDGET == 100.0
|
||||
|
||||
def test_default_initial_budget_rejects_invalid_value(self):
|
||||
"""Test that DEFAULT_INITIAL_BUDGET raises ValueError for invalid values."""
|
||||
if 'storage.lite_llm_manager' in sys.modules:
|
||||
del sys.modules['storage.lite_llm_manager']
|
||||
|
||||
os.environ['DEFAULT_INITIAL_BUDGET'] = 'abc'
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
importlib.import_module('storage.lite_llm_manager')
|
||||
assert 'Invalid DEFAULT_INITIAL_BUDGET' in str(exc_info.value)
|
||||
|
||||
def test_default_initial_budget_rejects_negative_value(self):
|
||||
"""Test that DEFAULT_INITIAL_BUDGET raises ValueError for negative values."""
|
||||
if 'storage.lite_llm_manager' in sys.modules:
|
||||
del sys.modules['storage.lite_llm_manager']
|
||||
|
||||
os.environ['DEFAULT_INITIAL_BUDGET'] = '-10.0'
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
importlib.import_module('storage.lite_llm_manager')
|
||||
assert 'must be non-negative' in str(exc_info.value)
|
||||
|
||||
|
||||
class TestLiteLlmManager:
|
||||
"""Test cases for LiteLlmManager class."""
|
||||
|
||||
@@ -309,10 +242,10 @@ class TestLiteLlmManager:
|
||||
assert add_user_call[1]['json']['max_budget_in_team'] == 30.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_entries_new_org_uses_default_initial_budget(
|
||||
async def test_create_entries_new_org_uses_zero_budget(
|
||||
self, mock_settings, mock_response
|
||||
):
|
||||
"""Test that create_entries uses DEFAULT_INITIAL_BUDGET for new org."""
|
||||
"""Test that create_entries uses budget=0 for new org (team doesn't exist)."""
|
||||
mock_404_response = MagicMock()
|
||||
mock_404_response.status_code = 404
|
||||
mock_404_response.is_success = False
|
||||
@@ -340,7 +273,6 @@ class TestLiteLlmManager:
|
||||
patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'),
|
||||
patch('storage.lite_llm_manager.TokenManager', mock_token_manager),
|
||||
patch('httpx.AsyncClient', mock_client_class),
|
||||
patch('storage.lite_llm_manager.DEFAULT_INITIAL_BUDGET', 0.0),
|
||||
):
|
||||
result = await LiteLlmManager.create_entries(
|
||||
'test-org-id', 'test-user-id', mock_settings, create_user=False
|
||||
@@ -348,67 +280,16 @@ class TestLiteLlmManager:
|
||||
|
||||
assert result is not None
|
||||
|
||||
# Verify _create_team was called with DEFAULT_INITIAL_BUDGET (0.0)
|
||||
# Verify _create_team was called with budget=0
|
||||
create_team_call = mock_client.post.call_args_list[0]
|
||||
assert 'team/new' in create_team_call[0][0]
|
||||
assert create_team_call[1]['json']['max_budget'] == 0.0
|
||||
|
||||
# Verify _add_user_to_team was called with DEFAULT_INITIAL_BUDGET (0.0)
|
||||
# Verify _add_user_to_team was called with budget=0
|
||||
add_user_call = mock_client.post.call_args_list[1]
|
||||
assert 'team/member_add' in add_user_call[0][0]
|
||||
assert add_user_call[1]['json']['max_budget_in_team'] == 0.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_entries_new_org_uses_custom_default_budget(
|
||||
self, mock_settings, mock_response
|
||||
):
|
||||
"""Test that create_entries uses custom DEFAULT_INITIAL_BUDGET for new org."""
|
||||
mock_404_response = MagicMock()
|
||||
mock_404_response.status_code = 404
|
||||
mock_404_response.is_success = False
|
||||
|
||||
mock_token_manager = MagicMock()
|
||||
mock_token_manager.return_value.get_user_info_from_user_id = AsyncMock(
|
||||
return_value={'email': 'test@example.com'}
|
||||
)
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get.return_value = mock_404_response
|
||||
mock_client.get.return_value.raise_for_status.side_effect = (
|
||||
httpx.HTTPStatusError(
|
||||
message='Not Found', request=MagicMock(), response=mock_404_response
|
||||
)
|
||||
)
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
mock_client_class = MagicMock()
|
||||
mock_client_class.return_value.__aenter__.return_value = mock_client
|
||||
|
||||
custom_budget = 50.0
|
||||
with (
|
||||
patch.dict(os.environ, {'LOCAL_DEPLOYMENT': ''}),
|
||||
patch('storage.lite_llm_manager.LITE_LLM_API_KEY', 'test-key'),
|
||||
patch('storage.lite_llm_manager.LITE_LLM_API_URL', 'http://test.com'),
|
||||
patch('storage.lite_llm_manager.TokenManager', mock_token_manager),
|
||||
patch('httpx.AsyncClient', mock_client_class),
|
||||
patch('storage.lite_llm_manager.DEFAULT_INITIAL_BUDGET', custom_budget),
|
||||
):
|
||||
result = await LiteLlmManager.create_entries(
|
||||
'test-org-id', 'test-user-id', mock_settings, create_user=False
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
|
||||
# Verify _create_team was called with custom DEFAULT_INITIAL_BUDGET
|
||||
create_team_call = mock_client.post.call_args_list[0]
|
||||
assert 'team/new' in create_team_call[0][0]
|
||||
assert create_team_call[1]['json']['max_budget'] == custom_budget
|
||||
|
||||
# Verify _add_user_to_team was called with custom DEFAULT_INITIAL_BUDGET
|
||||
add_user_call = mock_client.post.call_args_list[1]
|
||||
assert 'team/member_add' in add_user_call[0][0]
|
||||
assert add_user_call[1]['json']['max_budget_in_team'] == custom_budget
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_entries_propagates_non_404_errors(self, mock_settings):
|
||||
"""Test that create_entries propagates non-404 errors from _get_team."""
|
||||
|
||||
@@ -98,11 +98,6 @@ class TestAcceptInvitationEmailValidation:
|
||||
|
||||
mock_keycloak_user_info = {'email': 'alice@example.com'} # Email from Keycloak
|
||||
|
||||
mock_org = MagicMock()
|
||||
mock_org.default_llm_model = 'test-model'
|
||||
mock_org.default_llm_base_url = None
|
||||
mock_org.default_max_iterations = None
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.services.org_invitation_service.OrgInvitationStore.get_invitation_by_token',
|
||||
@@ -126,10 +121,6 @@ class TestAcceptInvitationEmailValidation:
|
||||
'server.services.org_invitation_service.OrgService.create_litellm_integration',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_create_litellm,
|
||||
patch(
|
||||
'server.services.org_invitation_service.OrgStore.get_org_by_id',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_org,
|
||||
patch(
|
||||
'server.services.org_invitation_service.OrgMemberStore.add_user_to_org',
|
||||
new_callable=AsyncMock,
|
||||
@@ -154,7 +145,6 @@ class TestAcceptInvitationEmailValidation:
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.llm_api_key = SecretStr('test-key')
|
||||
mock_create_litellm.return_value = mock_settings
|
||||
mock_get_org.return_value = mock_org
|
||||
mock_update_status.return_value = mock_invitation
|
||||
|
||||
# Act - should not raise error because Keycloak email matches
|
||||
@@ -224,11 +214,6 @@ class TestAcceptInvitationEmailValidation:
|
||||
|
||||
mock_invitation.email = 'alice@example.com' # Lowercase in invitation
|
||||
|
||||
mock_org = MagicMock()
|
||||
mock_org.default_llm_model = 'test-model'
|
||||
mock_org.default_llm_base_url = None
|
||||
mock_org.default_max_iterations = None
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.services.org_invitation_service.OrgInvitationStore.get_invitation_by_token',
|
||||
@@ -249,10 +234,6 @@ class TestAcceptInvitationEmailValidation:
|
||||
'server.services.org_invitation_service.OrgService.create_litellm_integration',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_create_litellm,
|
||||
patch(
|
||||
'server.services.org_invitation_service.OrgStore.get_org_by_id',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_org,
|
||||
patch(
|
||||
'server.services.org_invitation_service.OrgMemberStore.add_user_to_org',
|
||||
new_callable=AsyncMock,
|
||||
@@ -269,7 +250,6 @@ class TestAcceptInvitationEmailValidation:
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.llm_api_key = SecretStr('test-key')
|
||||
mock_create_litellm.return_value = mock_settings
|
||||
mock_get_org.return_value = mock_org
|
||||
mock_update_status.return_value = mock_invitation
|
||||
|
||||
# Act - should not raise error because emails match case-insensitively
|
||||
@@ -278,75 +258,6 @@ class TestAcceptInvitationEmailValidation:
|
||||
# Assert - invitation was accepted (update_invitation_status was called)
|
||||
mock_update_status.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_accept_invitation_inherits_org_llm_settings(self, mock_invitation):
|
||||
"""Test that new members inherit the organization's LLM settings when accepting invitation."""
|
||||
# Arrange
|
||||
user_id = UUID('87654321-4321-8765-4321-876543218765')
|
||||
token = 'inv-test-token-12345'
|
||||
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = user_id
|
||||
mock_user.email = 'alice@example.com'
|
||||
|
||||
mock_org = MagicMock()
|
||||
mock_org.default_llm_model = 'claude-sonnet-4'
|
||||
mock_org.default_llm_base_url = 'https://api.anthropic.com'
|
||||
mock_org.default_max_iterations = 100
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.services.org_invitation_service.OrgInvitationStore.get_invitation_by_token',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_invitation,
|
||||
patch(
|
||||
'server.services.org_invitation_service.OrgInvitationStore.is_token_expired'
|
||||
) as mock_is_expired,
|
||||
patch(
|
||||
'server.services.org_invitation_service.UserStore.get_user_by_id',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_user,
|
||||
patch(
|
||||
'server.services.org_invitation_service.OrgMemberStore.get_org_member',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_member,
|
||||
patch(
|
||||
'server.services.org_invitation_service.OrgService.create_litellm_integration',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_create_litellm,
|
||||
patch(
|
||||
'server.services.org_invitation_service.OrgStore.get_org_by_id',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_org,
|
||||
patch(
|
||||
'server.services.org_invitation_service.OrgMemberStore.add_user_to_org',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_add_user,
|
||||
patch(
|
||||
'server.services.org_invitation_service.OrgInvitationStore.update_invitation_status',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_update_status,
|
||||
):
|
||||
mock_get_invitation.return_value = mock_invitation
|
||||
mock_is_expired.return_value = False
|
||||
mock_get_user.return_value = mock_user
|
||||
mock_get_member.return_value = None
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.llm_api_key = SecretStr('test-key')
|
||||
mock_create_litellm.return_value = mock_settings
|
||||
mock_get_org.return_value = mock_org
|
||||
mock_update_status.return_value = mock_invitation
|
||||
|
||||
# Act
|
||||
await OrgInvitationService.accept_invitation(token, user_id)
|
||||
|
||||
# Assert - verify add_user_to_org was called with org's LLM settings
|
||||
mock_add_user.assert_called_once()
|
||||
call_kwargs = mock_add_user.call_args.kwargs
|
||||
assert call_kwargs['llm_model'] == 'claude-sonnet-4'
|
||||
assert call_kwargs['llm_base_url'] == 'https://api.anthropic.com'
|
||||
assert call_kwargs['max_iterations'] == 100
|
||||
|
||||
|
||||
class TestCreateInvitationsBatch:
|
||||
"""Test cases for batch invitation creation."""
|
||||
|
||||
@@ -246,43 +246,6 @@ async def test_add_user_to_org(async_session_maker):
|
||||
assert org_member.status == 'active'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_user_to_org_with_llm_settings(async_session_maker):
|
||||
"""Test that add_user_to_org correctly sets inherited LLM settings from organization."""
|
||||
# Arrange
|
||||
async with async_session_maker() as session:
|
||||
org = Org(name='test-org-llm')
|
||||
session.add(org)
|
||||
await session.flush()
|
||||
|
||||
user = User(id=uuid.uuid4(), current_org_id=org.id)
|
||||
role = Role(name='member', rank=2)
|
||||
session.add_all([user, role])
|
||||
await session.commit()
|
||||
org_id = org.id
|
||||
user_id = user.id
|
||||
role_id = role.id
|
||||
|
||||
# Act
|
||||
with patch('storage.org_member_store.a_session_maker', async_session_maker):
|
||||
org_member = await OrgMemberStore.add_user_to_org(
|
||||
org_id=org_id,
|
||||
user_id=user_id,
|
||||
role_id=role_id,
|
||||
llm_api_key='test-api-key',
|
||||
status='active',
|
||||
llm_model='claude-sonnet-4',
|
||||
llm_base_url='https://api.example.com',
|
||||
max_iterations=50,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert org_member is not None
|
||||
assert org_member.llm_model == 'claude-sonnet-4'
|
||||
assert org_member.llm_base_url == 'https://api.example.com'
|
||||
assert org_member.max_iterations == 50
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_user_role_in_org(async_session_maker):
|
||||
# Test updating user role in org
|
||||
|
||||
@@ -144,86 +144,6 @@ async def test_create_org(async_session_maker, mock_litellm_api):
|
||||
assert org.id is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_org_v1_enabled_defaults_to_true_when_default_is_true(
|
||||
async_session_maker, mock_litellm_api
|
||||
):
|
||||
"""
|
||||
GIVEN: DEFAULT_V1_ENABLED is True and org.v1_enabled is not specified (None)
|
||||
WHEN: create_org is called
|
||||
THEN: org.v1_enabled should be set to True
|
||||
"""
|
||||
with (
|
||||
patch('storage.org_store.a_session_maker', async_session_maker),
|
||||
patch('storage.org_store.DEFAULT_V1_ENABLED', True),
|
||||
):
|
||||
org = await OrgStore.create_org(kwargs={'name': 'test-org-v1-default-true'})
|
||||
|
||||
assert org is not None
|
||||
assert org.v1_enabled is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_org_v1_enabled_defaults_to_false_when_default_is_false(
|
||||
async_session_maker, mock_litellm_api
|
||||
):
|
||||
"""
|
||||
GIVEN: DEFAULT_V1_ENABLED is False and org.v1_enabled is not specified (None)
|
||||
WHEN: create_org is called
|
||||
THEN: org.v1_enabled should be set to False
|
||||
"""
|
||||
with (
|
||||
patch('storage.org_store.a_session_maker', async_session_maker),
|
||||
patch('storage.org_store.DEFAULT_V1_ENABLED', False),
|
||||
):
|
||||
org = await OrgStore.create_org(kwargs={'name': 'test-org-v1-default-false'})
|
||||
|
||||
assert org is not None
|
||||
assert org.v1_enabled is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_org_v1_enabled_explicit_false_overrides_default_true(
|
||||
async_session_maker, mock_litellm_api
|
||||
):
|
||||
"""
|
||||
GIVEN: DEFAULT_V1_ENABLED is True but org.v1_enabled is explicitly set to False
|
||||
WHEN: create_org is called
|
||||
THEN: org.v1_enabled should stay False (explicit value wins over default)
|
||||
"""
|
||||
with (
|
||||
patch('storage.org_store.a_session_maker', async_session_maker),
|
||||
patch('storage.org_store.DEFAULT_V1_ENABLED', True),
|
||||
):
|
||||
org = await OrgStore.create_org(
|
||||
kwargs={'name': 'test-org-v1-explicit-false', 'v1_enabled': False}
|
||||
)
|
||||
|
||||
assert org is not None
|
||||
assert org.v1_enabled is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_org_v1_enabled_explicit_true_overrides_default_false(
|
||||
async_session_maker, mock_litellm_api
|
||||
):
|
||||
"""
|
||||
GIVEN: DEFAULT_V1_ENABLED is False but org.v1_enabled is explicitly set to True
|
||||
WHEN: create_org is called
|
||||
THEN: org.v1_enabled should stay True (explicit value wins over default)
|
||||
"""
|
||||
with (
|
||||
patch('storage.org_store.a_session_maker', async_session_maker),
|
||||
patch('storage.org_store.DEFAULT_V1_ENABLED', False),
|
||||
):
|
||||
org = await OrgStore.create_org(
|
||||
kwargs={'name': 'test-org-v1-explicit-true', 'v1_enabled': True}
|
||||
)
|
||||
|
||||
assert org is not None
|
||||
assert org.v1_enabled is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_org_by_name(async_session_maker, mock_litellm_api):
|
||||
# Test getting org by name
|
||||
|
||||
@@ -396,44 +396,3 @@ async def test_store_propagates_llm_settings_to_all_org_members(
|
||||
assert (
|
||||
decrypted_key == 'new-shared-api-key'
|
||||
), f'Expected llm_api_key to decrypt to new-shared-api-key for member {member.user_id}'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_updates_org_default_llm_settings(
|
||||
session_maker, async_session_maker, mock_config, org_with_multiple_members_fixture
|
||||
):
|
||||
"""When admin saves LLM settings, org's default_llm_model/base_url/max_iterations should be updated.
|
||||
|
||||
This test verifies that the Org table's default settings are updated so that
|
||||
new members joining later will inherit the correct LLM configuration.
|
||||
"""
|
||||
from sqlalchemy import select
|
||||
from storage.org import Org
|
||||
|
||||
# Arrange
|
||||
fixture = org_with_multiple_members_fixture
|
||||
org_id = fixture['org_id']
|
||||
admin_user_id = str(fixture['admin_user_id'])
|
||||
|
||||
store = SaasSettingsStore(admin_user_id, mock_config)
|
||||
|
||||
new_settings = DataSettings(
|
||||
llm_model='anthropic/claude-sonnet-4',
|
||||
llm_base_url='https://api.anthropic.com/v1',
|
||||
max_iterations=75,
|
||||
llm_api_key=SecretStr('test-api-key'),
|
||||
)
|
||||
|
||||
# Act
|
||||
with patch('storage.saas_settings_store.a_session_maker', async_session_maker):
|
||||
await store.store(new_settings)
|
||||
|
||||
# Assert - verify org's default fields were updated
|
||||
with session_maker() as session:
|
||||
result = session.execute(select(Org).where(Org.id == org_id))
|
||||
org = result.scalars().first()
|
||||
|
||||
assert org is not None
|
||||
assert org.default_llm_model == 'anthropic/claude-sonnet-4'
|
||||
assert org.default_llm_base_url == 'https://api.anthropic.com/v1'
|
||||
assert org.default_max_iterations == 75
|
||||
|
||||
@@ -1,555 +0,0 @@
|
||||
"""Tests for AwsSharedEventService."""
|
||||
|
||||
import os
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from server.sharing.aws_shared_event_service import (
|
||||
AwsSharedEventService,
|
||||
AwsSharedEventServiceInjector,
|
||||
)
|
||||
from server.sharing.shared_conversation_info_service import (
|
||||
SharedConversationInfoService,
|
||||
)
|
||||
from server.sharing.shared_conversation_models import SharedConversation
|
||||
|
||||
from openhands.agent_server.models import EventPage, EventSortOrder
|
||||
from openhands.app_server.event.event_service import EventService
|
||||
from openhands.sdk.llm import MetricsSnapshot
|
||||
from openhands.sdk.llm.utils.metrics import TokenUsage
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_shared_conversation_info_service():
|
||||
"""Create a mock SharedConversationInfoService."""
|
||||
return AsyncMock(spec=SharedConversationInfoService)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_s3_client():
|
||||
"""Create a mock S3 client."""
|
||||
return MagicMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_event_service():
|
||||
"""Create a mock EventService for returned by get_event_service."""
|
||||
return AsyncMock(spec=EventService)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def aws_shared_event_service(mock_shared_conversation_info_service, mock_s3_client):
|
||||
"""Create an AwsSharedEventService for testing."""
|
||||
return AwsSharedEventService(
|
||||
shared_conversation_info_service=mock_shared_conversation_info_service,
|
||||
s3_client=mock_s3_client,
|
||||
bucket_name='test-bucket',
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_public_conversation():
|
||||
"""Create a sample public conversation."""
|
||||
return SharedConversation(
|
||||
id=uuid4(),
|
||||
created_by_user_id='test_user',
|
||||
sandbox_id='test_sandbox',
|
||||
title='Test Public Conversation',
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
metrics=MetricsSnapshot(
|
||||
accumulated_cost=0.0,
|
||||
max_budget_per_task=10.0,
|
||||
accumulated_token_usage=TokenUsage(),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_event():
|
||||
"""Create a sample event."""
|
||||
# For testing purposes, we'll just use a mock that the EventPage can accept
|
||||
# The actual event creation is complex and not the focus of these tests
|
||||
return None
|
||||
|
||||
|
||||
class TestAwsSharedEventService:
|
||||
"""Test cases for AwsSharedEventService."""
|
||||
|
||||
async def test_get_shared_event_returns_event_for_public_conversation(
|
||||
self,
|
||||
aws_shared_event_service,
|
||||
mock_shared_conversation_info_service,
|
||||
mock_event_service,
|
||||
sample_public_conversation,
|
||||
sample_event,
|
||||
):
|
||||
"""Test that get_shared_event returns an event for a public conversation."""
|
||||
conversation_id = sample_public_conversation.id
|
||||
event_id = uuid4()
|
||||
|
||||
# Mock the public conversation service to return a public conversation
|
||||
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = sample_public_conversation
|
||||
|
||||
# Mock get_event_service to return our mock event service
|
||||
aws_shared_event_service.get_event_service = AsyncMock(
|
||||
return_value=mock_event_service
|
||||
)
|
||||
|
||||
# Mock the event service to return an event
|
||||
mock_event_service.get_event.return_value = sample_event
|
||||
|
||||
# Call the method
|
||||
result = await aws_shared_event_service.get_shared_event(
|
||||
conversation_id, event_id
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert result == sample_event
|
||||
aws_shared_event_service.get_event_service.assert_called_once_with(
|
||||
conversation_id
|
||||
)
|
||||
mock_event_service.get_event.assert_called_once_with(conversation_id, event_id)
|
||||
|
||||
async def test_get_shared_event_returns_none_for_private_conversation(
|
||||
self,
|
||||
aws_shared_event_service,
|
||||
mock_shared_conversation_info_service,
|
||||
mock_event_service,
|
||||
):
|
||||
"""Test that get_shared_event returns None for a private conversation."""
|
||||
conversation_id = uuid4()
|
||||
event_id = uuid4()
|
||||
|
||||
# Mock get_event_service to return None (private conversation)
|
||||
aws_shared_event_service.get_event_service = AsyncMock(return_value=None)
|
||||
|
||||
# Call the method
|
||||
result = await aws_shared_event_service.get_shared_event(
|
||||
conversation_id, event_id
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert result is None
|
||||
aws_shared_event_service.get_event_service.assert_called_once_with(
|
||||
conversation_id
|
||||
)
|
||||
# Event service should not be called since get_event_service returns None
|
||||
mock_event_service.get_event.assert_not_called()
|
||||
|
||||
async def test_search_shared_events_returns_events_for_public_conversation(
|
||||
self,
|
||||
aws_shared_event_service,
|
||||
mock_shared_conversation_info_service,
|
||||
mock_event_service,
|
||||
sample_public_conversation,
|
||||
sample_event,
|
||||
):
|
||||
"""Test that search_shared_events returns events for a public conversation."""
|
||||
conversation_id = sample_public_conversation.id
|
||||
|
||||
# Mock get_event_service to return our mock event service
|
||||
aws_shared_event_service.get_event_service = AsyncMock(
|
||||
return_value=mock_event_service
|
||||
)
|
||||
|
||||
# Mock the event service to return events
|
||||
mock_event_page = EventPage(items=[], next_page_id=None)
|
||||
mock_event_service.search_events.return_value = mock_event_page
|
||||
|
||||
# Call the method
|
||||
result = await aws_shared_event_service.search_shared_events(
|
||||
conversation_id=conversation_id,
|
||||
kind__eq='ActionEvent',
|
||||
limit=10,
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert result == mock_event_page
|
||||
assert len(result.items) == 0 # Empty list as we mocked
|
||||
|
||||
aws_shared_event_service.get_event_service.assert_called_once_with(
|
||||
conversation_id
|
||||
)
|
||||
mock_event_service.search_events.assert_called_once_with(
|
||||
conversation_id=conversation_id,
|
||||
kind__eq='ActionEvent',
|
||||
timestamp__gte=None,
|
||||
timestamp__lt=None,
|
||||
sort_order=EventSortOrder.TIMESTAMP,
|
||||
page_id=None,
|
||||
limit=10,
|
||||
)
|
||||
|
||||
async def test_search_shared_events_returns_empty_for_private_conversation(
|
||||
self,
|
||||
aws_shared_event_service,
|
||||
mock_shared_conversation_info_service,
|
||||
mock_event_service,
|
||||
):
|
||||
"""Test that search_shared_events returns empty page for a private conversation."""
|
||||
conversation_id = uuid4()
|
||||
|
||||
# Mock get_event_service to return None (private conversation)
|
||||
aws_shared_event_service.get_event_service = AsyncMock(return_value=None)
|
||||
|
||||
# Call the method
|
||||
result = await aws_shared_event_service.search_shared_events(
|
||||
conversation_id=conversation_id,
|
||||
limit=10,
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert isinstance(result, EventPage)
|
||||
assert len(result.items) == 0
|
||||
assert result.next_page_id is None
|
||||
|
||||
aws_shared_event_service.get_event_service.assert_called_once_with(
|
||||
conversation_id
|
||||
)
|
||||
# Event service should not be called
|
||||
mock_event_service.search_events.assert_not_called()
|
||||
|
||||
async def test_count_shared_events_returns_count_for_public_conversation(
|
||||
self,
|
||||
aws_shared_event_service,
|
||||
mock_shared_conversation_info_service,
|
||||
mock_event_service,
|
||||
sample_public_conversation,
|
||||
):
|
||||
"""Test that count_shared_events returns count for a public conversation."""
|
||||
conversation_id = sample_public_conversation.id
|
||||
|
||||
# Mock get_event_service to return our mock event service
|
||||
aws_shared_event_service.get_event_service = AsyncMock(
|
||||
return_value=mock_event_service
|
||||
)
|
||||
|
||||
# Mock the event service to return a count
|
||||
mock_event_service.count_events.return_value = 5
|
||||
|
||||
# Call the method
|
||||
result = await aws_shared_event_service.count_shared_events(
|
||||
conversation_id=conversation_id,
|
||||
kind__eq='ActionEvent',
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert result == 5
|
||||
|
||||
aws_shared_event_service.get_event_service.assert_called_once_with(
|
||||
conversation_id
|
||||
)
|
||||
mock_event_service.count_events.assert_called_once_with(
|
||||
conversation_id=conversation_id,
|
||||
kind__eq='ActionEvent',
|
||||
timestamp__gte=None,
|
||||
timestamp__lt=None,
|
||||
)
|
||||
|
||||
async def test_count_shared_events_returns_zero_for_private_conversation(
|
||||
self,
|
||||
aws_shared_event_service,
|
||||
mock_shared_conversation_info_service,
|
||||
mock_event_service,
|
||||
):
|
||||
"""Test that count_shared_events returns 0 for a private conversation."""
|
||||
conversation_id = uuid4()
|
||||
|
||||
# Mock get_event_service to return None (private conversation)
|
||||
aws_shared_event_service.get_event_service = AsyncMock(return_value=None)
|
||||
|
||||
# Call the method
|
||||
result = await aws_shared_event_service.count_shared_events(
|
||||
conversation_id=conversation_id,
|
||||
)
|
||||
|
||||
# Verify the result
|
||||
assert result == 0
|
||||
|
||||
aws_shared_event_service.get_event_service.assert_called_once_with(
|
||||
conversation_id
|
||||
)
|
||||
# Event service should not be called
|
||||
mock_event_service.count_events.assert_not_called()
|
||||
|
||||
async def test_batch_get_shared_events_returns_events_for_public_conversation(
|
||||
self,
|
||||
aws_shared_event_service,
|
||||
mock_shared_conversation_info_service,
|
||||
mock_event_service,
|
||||
sample_public_conversation,
|
||||
sample_event,
|
||||
):
|
||||
"""Test that batch_get_shared_events returns events for a public conversation."""
|
||||
conversation_id = sample_public_conversation.id
|
||||
event_ids = [uuid4() for _ in range(3)]
|
||||
|
||||
# Mock get_event_service to return our mock event service
|
||||
aws_shared_event_service.get_event_service = AsyncMock(
|
||||
return_value=mock_event_service
|
||||
)
|
||||
|
||||
# Mock the event service to return events
|
||||
mock_event_service.get_event.return_value = sample_event
|
||||
|
||||
# Call the method
|
||||
results = await aws_shared_event_service.batch_get_shared_events(
|
||||
conversation_id, event_ids
|
||||
)
|
||||
|
||||
# Verify the results
|
||||
assert len(results) == 3
|
||||
assert all(result == sample_event for result in results)
|
||||
|
||||
|
||||
class TestAwsSharedEventServiceGetEventService:
|
||||
"""Test cases for AwsSharedEventService.get_event_service method."""
|
||||
|
||||
async def test_get_event_service_returns_event_service_for_shared_conversation(
|
||||
self,
|
||||
aws_shared_event_service,
|
||||
mock_shared_conversation_info_service,
|
||||
sample_public_conversation,
|
||||
):
|
||||
"""Test that get_event_service returns an EventService for a shared conversation."""
|
||||
conversation_id = sample_public_conversation.id
|
||||
|
||||
# Mock the shared conversation info service to return a shared conversation
|
||||
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = sample_public_conversation
|
||||
|
||||
# Call the method
|
||||
result = await aws_shared_event_service.get_event_service(conversation_id)
|
||||
|
||||
# Verify the result
|
||||
assert result is not None
|
||||
mock_shared_conversation_info_service.get_shared_conversation_info.assert_called_once_with(
|
||||
conversation_id
|
||||
)
|
||||
|
||||
async def test_get_event_service_returns_none_for_non_shared_conversation(
|
||||
self,
|
||||
aws_shared_event_service,
|
||||
mock_shared_conversation_info_service,
|
||||
):
|
||||
"""Test that get_event_service returns None for a non-shared conversation."""
|
||||
conversation_id = uuid4()
|
||||
|
||||
# Mock the shared conversation info service to return None
|
||||
mock_shared_conversation_info_service.get_shared_conversation_info.return_value = None
|
||||
|
||||
# Call the method
|
||||
result = await aws_shared_event_service.get_event_service(conversation_id)
|
||||
|
||||
# Verify the result
|
||||
assert result is None
|
||||
mock_shared_conversation_info_service.get_shared_conversation_info.assert_called_once_with(
|
||||
conversation_id
|
||||
)
|
||||
|
||||
|
||||
class TestAwsSharedEventServiceInjector:
|
||||
"""Test cases for AwsSharedEventServiceInjector."""
|
||||
|
||||
def test_bucket_name_from_environment_variable(self):
|
||||
"""Test that bucket_name is read from FILE_STORE_PATH environment variable."""
|
||||
test_bucket_name = 'test-bucket-name'
|
||||
with patch.dict(os.environ, {'FILE_STORE_PATH': test_bucket_name}):
|
||||
# Create a new injector instance to pick up the environment variable
|
||||
# Note: The class attribute is evaluated at class definition time,
|
||||
# so we need to test that the attribute exists and can be overridden
|
||||
injector = AwsSharedEventServiceInjector()
|
||||
injector.bucket_name = os.environ.get('FILE_STORE_PATH')
|
||||
assert injector.bucket_name == test_bucket_name
|
||||
|
||||
def test_bucket_name_default_value_when_env_not_set(self):
|
||||
"""Test that bucket_name is None when FILE_STORE_PATH is not set."""
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
# Remove FILE_STORE_PATH if it exists
|
||||
os.environ.pop('FILE_STORE_PATH', None)
|
||||
injector = AwsSharedEventServiceInjector()
|
||||
# The bucket_name will be whatever was set at class definition time
|
||||
# or None if FILE_STORE_PATH was not set when the class was defined
|
||||
assert hasattr(injector, 'bucket_name')
|
||||
|
||||
async def test_injector_yields_aws_shared_event_service(self):
|
||||
"""Test that the injector yields an AwsSharedEventService instance."""
|
||||
mock_state = MagicMock()
|
||||
mock_request = MagicMock()
|
||||
mock_db_session = AsyncMock()
|
||||
|
||||
# Create the injector
|
||||
injector = AwsSharedEventServiceInjector()
|
||||
injector.bucket_name = 'test-bucket'
|
||||
|
||||
# Mock the get_db_session context manager
|
||||
mock_db_context = AsyncMock()
|
||||
mock_db_context.__aenter__.return_value = mock_db_session
|
||||
mock_db_context.__aexit__.return_value = None
|
||||
|
||||
# Mock boto3.client
|
||||
mock_s3_client = MagicMock()
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.sharing.aws_shared_event_service.boto3.client',
|
||||
return_value=mock_s3_client,
|
||||
),
|
||||
patch(
|
||||
'openhands.app_server.config.get_db_session',
|
||||
return_value=mock_db_context,
|
||||
),
|
||||
):
|
||||
# Call the inject method
|
||||
async for service in injector.inject(mock_state, mock_request):
|
||||
# Verify the service is an instance of AwsSharedEventService
|
||||
assert isinstance(service, AwsSharedEventService)
|
||||
assert service.s3_client == mock_s3_client
|
||||
assert service.bucket_name == 'test-bucket'
|
||||
|
||||
async def test_injector_uses_bucket_name_from_instance(self):
|
||||
"""Test that the injector uses the bucket_name from the instance."""
|
||||
mock_state = MagicMock()
|
||||
mock_request = MagicMock()
|
||||
mock_db_session = AsyncMock()
|
||||
|
||||
# Create the injector with a specific bucket name
|
||||
injector = AwsSharedEventServiceInjector()
|
||||
injector.bucket_name = 'my-custom-bucket'
|
||||
|
||||
# Mock the get_db_session context manager
|
||||
mock_db_context = AsyncMock()
|
||||
mock_db_context.__aenter__.return_value = mock_db_session
|
||||
mock_db_context.__aexit__.return_value = None
|
||||
|
||||
# Mock boto3.client
|
||||
mock_s3_client = MagicMock()
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.sharing.aws_shared_event_service.boto3.client',
|
||||
return_value=mock_s3_client,
|
||||
),
|
||||
patch(
|
||||
'openhands.app_server.config.get_db_session',
|
||||
return_value=mock_db_context,
|
||||
),
|
||||
):
|
||||
# Call the inject method
|
||||
async for service in injector.inject(mock_state, mock_request):
|
||||
assert service.bucket_name == 'my-custom-bucket'
|
||||
|
||||
async def test_injector_creates_sql_shared_conversation_info_service(self):
|
||||
"""Test that the injector creates SQLSharedConversationInfoService with db_session."""
|
||||
mock_state = MagicMock()
|
||||
mock_request = MagicMock()
|
||||
mock_db_session = AsyncMock()
|
||||
|
||||
# Create the injector
|
||||
injector = AwsSharedEventServiceInjector()
|
||||
injector.bucket_name = 'test-bucket'
|
||||
|
||||
# Mock the get_db_session context manager
|
||||
mock_db_context = AsyncMock()
|
||||
mock_db_context.__aenter__.return_value = mock_db_session
|
||||
mock_db_context.__aexit__.return_value = None
|
||||
|
||||
# Mock boto3.client
|
||||
mock_s3_client = MagicMock()
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.sharing.aws_shared_event_service.boto3.client',
|
||||
return_value=mock_s3_client,
|
||||
),
|
||||
patch(
|
||||
'openhands.app_server.config.get_db_session',
|
||||
return_value=mock_db_context,
|
||||
),
|
||||
patch(
|
||||
'server.sharing.aws_shared_event_service.SQLSharedConversationInfoService'
|
||||
) as mock_sql_service_class,
|
||||
):
|
||||
mock_sql_service = MagicMock()
|
||||
mock_sql_service_class.return_value = mock_sql_service
|
||||
|
||||
# Call the inject method
|
||||
async for service in injector.inject(mock_state, mock_request):
|
||||
# Verify the service has the correct shared_conversation_info_service
|
||||
assert service.shared_conversation_info_service == mock_sql_service
|
||||
|
||||
# Verify SQLSharedConversationInfoService was created with db_session
|
||||
mock_sql_service_class.assert_called_once_with(db_session=mock_db_session)
|
||||
|
||||
async def test_injector_works_without_request(self):
|
||||
"""Test that the injector works when request is None."""
|
||||
mock_state = MagicMock()
|
||||
mock_db_session = AsyncMock()
|
||||
|
||||
# Create the injector
|
||||
injector = AwsSharedEventServiceInjector()
|
||||
injector.bucket_name = 'test-bucket'
|
||||
|
||||
# Mock the get_db_session context manager
|
||||
mock_db_context = AsyncMock()
|
||||
mock_db_context.__aenter__.return_value = mock_db_session
|
||||
mock_db_context.__aexit__.return_value = None
|
||||
|
||||
# Mock boto3.client
|
||||
mock_s3_client = MagicMock()
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.sharing.aws_shared_event_service.boto3.client',
|
||||
return_value=mock_s3_client,
|
||||
),
|
||||
patch(
|
||||
'openhands.app_server.config.get_db_session',
|
||||
return_value=mock_db_context,
|
||||
),
|
||||
):
|
||||
# Call the inject method with request=None
|
||||
async for service in injector.inject(mock_state, request=None):
|
||||
assert isinstance(service, AwsSharedEventService)
|
||||
|
||||
async def test_injector_uses_role_based_authentication(self):
|
||||
"""Test that the injector uses role-based authentication (no explicit credentials)."""
|
||||
mock_state = MagicMock()
|
||||
mock_request = MagicMock()
|
||||
mock_db_session = AsyncMock()
|
||||
|
||||
# Create the injector
|
||||
injector = AwsSharedEventServiceInjector()
|
||||
injector.bucket_name = 'test-bucket'
|
||||
|
||||
# Mock the get_db_session context manager
|
||||
mock_db_context = AsyncMock()
|
||||
mock_db_context.__aenter__.return_value = mock_db_session
|
||||
mock_db_context.__aexit__.return_value = None
|
||||
|
||||
# Mock boto3.client
|
||||
mock_s3_client = MagicMock()
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.sharing.aws_shared_event_service.boto3.client',
|
||||
return_value=mock_s3_client,
|
||||
) as mock_boto3_client,
|
||||
patch(
|
||||
'openhands.app_server.config.get_db_session',
|
||||
return_value=mock_db_context,
|
||||
),
|
||||
patch.dict(os.environ, {'AWS_S3_ENDPOINT': 'https://s3.example.com'}),
|
||||
):
|
||||
# Call the inject method
|
||||
async for service in injector.inject(mock_state, mock_request):
|
||||
pass
|
||||
|
||||
# Verify boto3.client was called with 's3' and endpoint_url
|
||||
# but without explicit credentials (role-based auth)
|
||||
mock_boto3_client.assert_called_once_with(
|
||||
's3',
|
||||
endpoint_url='https://s3.example.com',
|
||||
)
|
||||
@@ -1,171 +0,0 @@
|
||||
"""Tests for shared_event_router provider selection.
|
||||
|
||||
This module tests the get_shared_event_service_injector function which
|
||||
determines which SharedEventServiceInjector to use based on environment variables.
|
||||
"""
|
||||
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
from server.sharing.aws_shared_event_service import AwsSharedEventServiceInjector
|
||||
from server.sharing.google_cloud_shared_event_service import (
|
||||
GoogleCloudSharedEventServiceInjector,
|
||||
)
|
||||
from server.sharing.shared_event_router import get_shared_event_service_injector
|
||||
|
||||
|
||||
class TestGetSharedEventServiceInjector:
|
||||
"""Test cases for get_shared_event_service_injector function."""
|
||||
|
||||
def test_defaults_to_google_cloud_when_no_env_set(self):
|
||||
"""Test that GoogleCloudSharedEventServiceInjector is used when no env is set."""
|
||||
with patch.dict(
|
||||
os.environ,
|
||||
{},
|
||||
clear=True,
|
||||
):
|
||||
os.environ.pop('SHARED_EVENT_STORAGE_PROVIDER', None)
|
||||
os.environ.pop('FILE_STORE', None)
|
||||
|
||||
injector = get_shared_event_service_injector()
|
||||
|
||||
assert isinstance(injector, GoogleCloudSharedEventServiceInjector)
|
||||
|
||||
def test_uses_google_cloud_when_file_store_google_cloud(self):
|
||||
"""Test that GoogleCloudSharedEventServiceInjector is used when FILE_STORE=google_cloud."""
|
||||
with patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
'FILE_STORE': 'google_cloud',
|
||||
},
|
||||
clear=True,
|
||||
):
|
||||
os.environ.pop('SHARED_EVENT_STORAGE_PROVIDER', None)
|
||||
|
||||
injector = get_shared_event_service_injector()
|
||||
|
||||
assert isinstance(injector, GoogleCloudSharedEventServiceInjector)
|
||||
|
||||
def test_uses_aws_when_provider_aws(self):
|
||||
"""Test that AwsSharedEventServiceInjector is used when SHARED_EVENT_STORAGE_PROVIDER=aws."""
|
||||
with patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
'SHARED_EVENT_STORAGE_PROVIDER': 'aws',
|
||||
},
|
||||
clear=True,
|
||||
):
|
||||
injector = get_shared_event_service_injector()
|
||||
|
||||
assert isinstance(injector, AwsSharedEventServiceInjector)
|
||||
|
||||
def test_uses_gcp_when_provider_gcp(self):
|
||||
"""Test that GoogleCloudSharedEventServiceInjector is used when SHARED_EVENT_STORAGE_PROVIDER=gcp."""
|
||||
with patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
'SHARED_EVENT_STORAGE_PROVIDER': 'gcp',
|
||||
},
|
||||
clear=True,
|
||||
):
|
||||
injector = get_shared_event_service_injector()
|
||||
|
||||
assert isinstance(injector, GoogleCloudSharedEventServiceInjector)
|
||||
|
||||
def test_uses_gcp_when_provider_google_cloud(self):
|
||||
"""Test that GoogleCloudSharedEventServiceInjector is used when SHARED_EVENT_STORAGE_PROVIDER=google_cloud."""
|
||||
with patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
'SHARED_EVENT_STORAGE_PROVIDER': 'google_cloud',
|
||||
},
|
||||
clear=True,
|
||||
):
|
||||
injector = get_shared_event_service_injector()
|
||||
|
||||
assert isinstance(injector, GoogleCloudSharedEventServiceInjector)
|
||||
|
||||
def test_provider_takes_precedence_over_file_store(self):
|
||||
"""Test that SHARED_EVENT_STORAGE_PROVIDER takes precedence over FILE_STORE."""
|
||||
with patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
'SHARED_EVENT_STORAGE_PROVIDER': 'aws',
|
||||
'FILE_STORE': 'google_cloud',
|
||||
},
|
||||
clear=True,
|
||||
):
|
||||
injector = get_shared_event_service_injector()
|
||||
|
||||
# Should use AWS because SHARED_EVENT_STORAGE_PROVIDER takes precedence
|
||||
assert isinstance(injector, AwsSharedEventServiceInjector)
|
||||
|
||||
def test_provider_gcp_takes_precedence_over_file_store_s3(self):
|
||||
"""Test that SHARED_EVENT_STORAGE_PROVIDER=gcp takes precedence over FILE_STORE=s3."""
|
||||
with patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
'SHARED_EVENT_STORAGE_PROVIDER': 'gcp',
|
||||
'FILE_STORE': 's3',
|
||||
},
|
||||
clear=True,
|
||||
):
|
||||
injector = get_shared_event_service_injector()
|
||||
|
||||
# Should use GCP because SHARED_EVENT_STORAGE_PROVIDER takes precedence
|
||||
assert isinstance(injector, GoogleCloudSharedEventServiceInjector)
|
||||
|
||||
def test_provider_is_case_insensitive_aws(self):
|
||||
"""Test that SHARED_EVENT_STORAGE_PROVIDER is case insensitive for AWS."""
|
||||
with patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
'SHARED_EVENT_STORAGE_PROVIDER': 'AWS',
|
||||
},
|
||||
clear=True,
|
||||
):
|
||||
injector = get_shared_event_service_injector()
|
||||
|
||||
assert isinstance(injector, AwsSharedEventServiceInjector)
|
||||
|
||||
def test_provider_is_case_insensitive_gcp(self):
|
||||
"""Test that SHARED_EVENT_STORAGE_PROVIDER is case insensitive for GCP."""
|
||||
with patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
'SHARED_EVENT_STORAGE_PROVIDER': 'GCP',
|
||||
},
|
||||
clear=True,
|
||||
):
|
||||
injector = get_shared_event_service_injector()
|
||||
|
||||
assert isinstance(injector, GoogleCloudSharedEventServiceInjector)
|
||||
|
||||
def test_unknown_provider_defaults_to_google_cloud(self):
|
||||
"""Test that unknown provider defaults to GoogleCloudSharedEventServiceInjector."""
|
||||
with patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
'SHARED_EVENT_STORAGE_PROVIDER': 'unknown_provider',
|
||||
},
|
||||
clear=True,
|
||||
):
|
||||
injector = get_shared_event_service_injector()
|
||||
|
||||
# Should default to GCP for unknown providers
|
||||
assert isinstance(injector, GoogleCloudSharedEventServiceInjector)
|
||||
|
||||
def test_empty_provider_falls_back_to_file_store(self):
|
||||
"""Test that empty SHARED_EVENT_STORAGE_PROVIDER falls back to FILE_STORE."""
|
||||
with patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
'SHARED_EVENT_STORAGE_PROVIDER': '',
|
||||
'FILE_STORE': 'google_cloud',
|
||||
},
|
||||
clear=True,
|
||||
):
|
||||
injector = get_shared_event_service_injector()
|
||||
|
||||
# Should default to GCP for unknown providers
|
||||
assert isinstance(injector, GoogleCloudSharedEventServiceInjector)
|
||||
@@ -101,72 +101,6 @@ async def test_create_default_settings_with_litellm(mock_litellm_api):
|
||||
assert settings.llm_base_url == 'http://test.url'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_default_settings_v1_enabled_true_when_default_is_true(
|
||||
mock_litellm_api,
|
||||
):
|
||||
"""
|
||||
GIVEN: DEFAULT_V1_ENABLED is True
|
||||
WHEN: create_default_settings is called
|
||||
THEN: The default_settings.v1_enabled should be set to True
|
||||
"""
|
||||
org_id = str(uuid.uuid4())
|
||||
user_id = str(uuid.uuid4())
|
||||
|
||||
# Track the settings passed to LiteLlmManager.create_entries
|
||||
captured_settings = None
|
||||
|
||||
async def capture_create_entries(_org_id, _user_id, settings, _create_user):
|
||||
nonlocal captured_settings
|
||||
captured_settings = settings
|
||||
return settings
|
||||
|
||||
with (
|
||||
patch('storage.user_store.DEFAULT_V1_ENABLED', True),
|
||||
patch(
|
||||
'storage.lite_llm_manager.LiteLlmManager.create_entries',
|
||||
side_effect=capture_create_entries,
|
||||
),
|
||||
):
|
||||
await UserStore.create_default_settings(org_id, user_id)
|
||||
|
||||
assert captured_settings is not None
|
||||
assert captured_settings.v1_enabled is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_default_settings_v1_enabled_false_when_default_is_false(
|
||||
mock_litellm_api,
|
||||
):
|
||||
"""
|
||||
GIVEN: DEFAULT_V1_ENABLED is False
|
||||
WHEN: create_default_settings is called
|
||||
THEN: The default_settings.v1_enabled should be set to False
|
||||
"""
|
||||
org_id = str(uuid.uuid4())
|
||||
user_id = str(uuid.uuid4())
|
||||
|
||||
# Track the settings passed to LiteLlmManager.create_entries
|
||||
captured_settings = None
|
||||
|
||||
async def capture_create_entries(_org_id, _user_id, settings, _create_user):
|
||||
nonlocal captured_settings
|
||||
captured_settings = settings
|
||||
return settings
|
||||
|
||||
with (
|
||||
patch('storage.user_store.DEFAULT_V1_ENABLED', False),
|
||||
patch(
|
||||
'storage.lite_llm_manager.LiteLlmManager.create_entries',
|
||||
side_effect=capture_create_entries,
|
||||
),
|
||||
):
|
||||
await UserStore.create_default_settings(org_id, user_id)
|
||||
|
||||
assert captured_settings is not None
|
||||
assert captured_settings.v1_enabled is False
|
||||
|
||||
|
||||
# --- Tests for get_user_by_id ---
|
||||
|
||||
|
||||
@@ -1309,19 +1243,3 @@ async def test_migrate_user_sql_multiple_conversations(async_session_maker):
|
||||
assert (
|
||||
row.org_id == user_uuid_str
|
||||
), f'org_id should match: {row.org_id} vs {user_uuid_str}'
|
||||
|
||||
|
||||
# Note: The v1_enabled logic in migrate_user follows the same pattern as OrgStore.create_org:
|
||||
# if org.v1_enabled is None:
|
||||
# org.v1_enabled = DEFAULT_V1_ENABLED
|
||||
#
|
||||
# This behavior is tested in test_org_store.py via:
|
||||
# - test_create_org_v1_enabled_defaults_to_true_when_default_is_true
|
||||
# - test_create_org_v1_enabled_defaults_to_false_when_default_is_false
|
||||
# - test_create_org_v1_enabled_explicit_false_overrides_default_true
|
||||
# - test_create_org_v1_enabled_explicit_true_overrides_default_false
|
||||
#
|
||||
# Testing migrate_user directly is impractical due to its complex raw SQL migration
|
||||
# statements that have SQLite/UUID compatibility issues in the test environment.
|
||||
# The SQL migration tests above (test_migrate_user_sql_type_handling, etc.) verify
|
||||
# the SQL operations work correctly with proper type handling.
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
# Tests for enterprise server utils
|
||||
@@ -1,425 +0,0 @@
|
||||
"""Tests for URL utility functions that prevent URL hijacking attacks."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestGetWebUrl:
|
||||
"""Tests for get_web_url function."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_request(self):
|
||||
"""Create a mock FastAPI request object."""
|
||||
request = MagicMock()
|
||||
request.url = MagicMock()
|
||||
return request
|
||||
|
||||
def test_configured_web_url_is_used(self, mock_request):
|
||||
"""When web_url is configured, it should be used instead of request URL."""
|
||||
from server.utils.url_utils import get_web_url
|
||||
|
||||
mock_request.url.hostname = 'evil-attacker.com'
|
||||
mock_request.url.netloc = 'evil-attacker.com:443'
|
||||
|
||||
mock_config = MagicMock()
|
||||
mock_config.web_url = 'https://app.all-hands.dev'
|
||||
|
||||
with patch(
|
||||
'server.utils.url_utils.get_global_config', return_value=mock_config
|
||||
):
|
||||
result = get_web_url(mock_request)
|
||||
|
||||
assert result == 'https://app.all-hands.dev'
|
||||
# Should not use any info from the potentially poisoned request
|
||||
assert 'evil-attacker.com' not in result
|
||||
|
||||
def test_configured_web_url_trailing_slash_stripped(self, mock_request):
|
||||
"""Configured web_url should have trailing slashes stripped."""
|
||||
from server.utils.url_utils import get_web_url
|
||||
|
||||
mock_config = MagicMock()
|
||||
mock_config.web_url = 'https://app.all-hands.dev/'
|
||||
|
||||
with patch(
|
||||
'server.utils.url_utils.get_global_config', return_value=mock_config
|
||||
):
|
||||
result = get_web_url(mock_request)
|
||||
|
||||
assert result == 'https://app.all-hands.dev'
|
||||
assert not result.endswith('/')
|
||||
|
||||
def test_unconfigured_web_url_localhost_uses_http(self, mock_request):
|
||||
"""When web_url is not configured and hostname is localhost, use http."""
|
||||
from server.utils.url_utils import get_web_url
|
||||
|
||||
mock_request.url.hostname = 'localhost'
|
||||
mock_request.url.netloc = 'localhost:3000'
|
||||
|
||||
mock_config = MagicMock()
|
||||
mock_config.web_url = None
|
||||
|
||||
with patch(
|
||||
'server.utils.url_utils.get_global_config', return_value=mock_config
|
||||
):
|
||||
result = get_web_url(mock_request)
|
||||
|
||||
assert result == 'http://localhost:3000'
|
||||
|
||||
def test_unconfigured_web_url_non_localhost_uses_https(self, mock_request):
|
||||
"""When web_url is not configured and hostname is not localhost, use https."""
|
||||
from server.utils.url_utils import get_web_url
|
||||
|
||||
mock_request.url.hostname = 'example.com'
|
||||
mock_request.url.netloc = 'example.com:443'
|
||||
|
||||
mock_config = MagicMock()
|
||||
mock_config.web_url = None
|
||||
|
||||
with patch(
|
||||
'server.utils.url_utils.get_global_config', return_value=mock_config
|
||||
):
|
||||
result = get_web_url(mock_request)
|
||||
|
||||
assert result == 'https://example.com:443'
|
||||
|
||||
def test_unconfigured_web_url_empty_string_fallback(self, mock_request):
|
||||
"""Empty string web_url should trigger fallback."""
|
||||
from server.utils.url_utils import get_web_url
|
||||
|
||||
mock_request.url.hostname = 'localhost'
|
||||
mock_request.url.netloc = 'localhost:3000'
|
||||
|
||||
mock_config = MagicMock()
|
||||
mock_config.web_url = ''
|
||||
|
||||
with patch(
|
||||
'server.utils.url_utils.get_global_config', return_value=mock_config
|
||||
):
|
||||
result = get_web_url(mock_request)
|
||||
|
||||
assert result == 'http://localhost:3000'
|
||||
|
||||
|
||||
class TestGetCookieDomain:
|
||||
"""Tests for get_cookie_domain function."""
|
||||
|
||||
def test_production_with_configured_web_url(self):
|
||||
"""In production with web_url configured, should return hostname."""
|
||||
from server.utils.url_utils import get_cookie_domain
|
||||
|
||||
mock_config = MagicMock()
|
||||
mock_config.web_url = 'https://app.all-hands.dev'
|
||||
|
||||
with (
|
||||
patch('server.utils.url_utils.get_global_config', return_value=mock_config),
|
||||
patch('server.utils.url_utils.IS_FEATURE_ENV', False),
|
||||
patch('server.utils.url_utils.IS_STAGING_ENV', False),
|
||||
patch('server.utils.url_utils.IS_LOCAL_ENV', False),
|
||||
):
|
||||
result = get_cookie_domain()
|
||||
|
||||
assert result == 'app.all-hands.dev'
|
||||
|
||||
def test_production_without_web_url_returns_none(self):
|
||||
"""In production without web_url configured, should return None."""
|
||||
from server.utils.url_utils import get_cookie_domain
|
||||
|
||||
mock_config = MagicMock()
|
||||
mock_config.web_url = None
|
||||
|
||||
with (
|
||||
patch('server.utils.url_utils.get_global_config', return_value=mock_config),
|
||||
patch('server.utils.url_utils.IS_FEATURE_ENV', False),
|
||||
patch('server.utils.url_utils.IS_STAGING_ENV', False),
|
||||
patch('server.utils.url_utils.IS_LOCAL_ENV', False),
|
||||
):
|
||||
result = get_cookie_domain()
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_local_env_returns_none(self):
|
||||
"""In local environment, should return None for cookie domain."""
|
||||
from server.utils.url_utils import get_cookie_domain
|
||||
|
||||
mock_config = MagicMock()
|
||||
mock_config.web_url = 'https://app.all-hands.dev'
|
||||
|
||||
with (
|
||||
patch('server.utils.url_utils.get_global_config', return_value=mock_config),
|
||||
patch('server.utils.url_utils.IS_FEATURE_ENV', False),
|
||||
patch('server.utils.url_utils.IS_STAGING_ENV', False),
|
||||
patch('server.utils.url_utils.IS_LOCAL_ENV', True),
|
||||
):
|
||||
result = get_cookie_domain()
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_staging_env_returns_none(self):
|
||||
"""In staging environment, should return None for cookie domain."""
|
||||
from server.utils.url_utils import get_cookie_domain
|
||||
|
||||
mock_config = MagicMock()
|
||||
mock_config.web_url = 'https://staging.all-hands.dev'
|
||||
|
||||
with (
|
||||
patch('server.utils.url_utils.get_global_config', return_value=mock_config),
|
||||
patch('server.utils.url_utils.IS_FEATURE_ENV', False),
|
||||
patch('server.utils.url_utils.IS_STAGING_ENV', True),
|
||||
patch('server.utils.url_utils.IS_LOCAL_ENV', False),
|
||||
):
|
||||
result = get_cookie_domain()
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_feature_env_returns_none(self):
|
||||
"""In feature environment, should return None for cookie domain."""
|
||||
from server.utils.url_utils import get_cookie_domain
|
||||
|
||||
mock_config = MagicMock()
|
||||
mock_config.web_url = 'https://feature-123.staging.all-hands.dev'
|
||||
|
||||
with (
|
||||
patch('server.utils.url_utils.get_global_config', return_value=mock_config),
|
||||
patch('server.utils.url_utils.IS_FEATURE_ENV', True),
|
||||
patch('server.utils.url_utils.IS_STAGING_ENV', True),
|
||||
patch('server.utils.url_utils.IS_LOCAL_ENV', False),
|
||||
):
|
||||
result = get_cookie_domain()
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestGetCookieSamesite:
|
||||
"""Tests for get_cookie_samesite function."""
|
||||
|
||||
def test_production_with_configured_web_url_returns_strict(self):
|
||||
"""In production with web_url configured, should return 'strict'."""
|
||||
from server.utils.url_utils import get_cookie_samesite
|
||||
|
||||
mock_config = MagicMock()
|
||||
mock_config.web_url = 'https://app.all-hands.dev'
|
||||
|
||||
with (
|
||||
patch('server.utils.url_utils.get_global_config', return_value=mock_config),
|
||||
patch('server.utils.url_utils.IS_FEATURE_ENV', False),
|
||||
patch('server.utils.url_utils.IS_STAGING_ENV', False),
|
||||
patch('server.utils.url_utils.IS_LOCAL_ENV', False),
|
||||
):
|
||||
result = get_cookie_samesite()
|
||||
|
||||
assert result == 'strict'
|
||||
|
||||
def test_production_without_web_url_returns_lax(self):
|
||||
"""In production without web_url configured, should return 'lax'."""
|
||||
from server.utils.url_utils import get_cookie_samesite
|
||||
|
||||
mock_config = MagicMock()
|
||||
mock_config.web_url = None
|
||||
|
||||
with (
|
||||
patch('server.utils.url_utils.get_global_config', return_value=mock_config),
|
||||
patch('server.utils.url_utils.IS_FEATURE_ENV', False),
|
||||
patch('server.utils.url_utils.IS_STAGING_ENV', False),
|
||||
patch('server.utils.url_utils.IS_LOCAL_ENV', False),
|
||||
):
|
||||
result = get_cookie_samesite()
|
||||
|
||||
assert result == 'lax'
|
||||
|
||||
def test_local_env_returns_lax(self):
|
||||
"""In local environment, should return 'lax'."""
|
||||
from server.utils.url_utils import get_cookie_samesite
|
||||
|
||||
mock_config = MagicMock()
|
||||
mock_config.web_url = 'http://localhost:3000'
|
||||
|
||||
with (
|
||||
patch('server.utils.url_utils.get_global_config', return_value=mock_config),
|
||||
patch('server.utils.url_utils.IS_FEATURE_ENV', False),
|
||||
patch('server.utils.url_utils.IS_STAGING_ENV', False),
|
||||
patch('server.utils.url_utils.IS_LOCAL_ENV', True),
|
||||
):
|
||||
result = get_cookie_samesite()
|
||||
|
||||
assert result == 'lax'
|
||||
|
||||
def test_staging_env_returns_lax(self):
|
||||
"""In staging environment, should return 'lax'."""
|
||||
from server.utils.url_utils import get_cookie_samesite
|
||||
|
||||
mock_config = MagicMock()
|
||||
mock_config.web_url = 'https://staging.all-hands.dev'
|
||||
|
||||
with (
|
||||
patch('server.utils.url_utils.get_global_config', return_value=mock_config),
|
||||
patch('server.utils.url_utils.IS_FEATURE_ENV', False),
|
||||
patch('server.utils.url_utils.IS_STAGING_ENV', True),
|
||||
patch('server.utils.url_utils.IS_LOCAL_ENV', False),
|
||||
):
|
||||
result = get_cookie_samesite()
|
||||
|
||||
assert result == 'lax'
|
||||
|
||||
def test_feature_env_returns_lax(self):
|
||||
"""In feature environment, should return 'lax'."""
|
||||
from server.utils.url_utils import get_cookie_samesite
|
||||
|
||||
mock_config = MagicMock()
|
||||
mock_config.web_url = 'https://feature-xyz.staging.all-hands.dev'
|
||||
|
||||
with (
|
||||
patch('server.utils.url_utils.get_global_config', return_value=mock_config),
|
||||
patch('server.utils.url_utils.IS_FEATURE_ENV', True),
|
||||
patch('server.utils.url_utils.IS_STAGING_ENV', True),
|
||||
patch('server.utils.url_utils.IS_LOCAL_ENV', False),
|
||||
):
|
||||
result = get_cookie_samesite()
|
||||
|
||||
assert result == 'lax'
|
||||
|
||||
def test_empty_web_url_returns_lax(self):
|
||||
"""Empty web_url should be treated as unconfigured and return 'lax'."""
|
||||
from server.utils.url_utils import get_cookie_samesite
|
||||
|
||||
mock_config = MagicMock()
|
||||
mock_config.web_url = ''
|
||||
|
||||
with (
|
||||
patch('server.utils.url_utils.get_global_config', return_value=mock_config),
|
||||
patch('server.utils.url_utils.IS_FEATURE_ENV', False),
|
||||
patch('server.utils.url_utils.IS_STAGING_ENV', False),
|
||||
patch('server.utils.url_utils.IS_LOCAL_ENV', False),
|
||||
):
|
||||
result = get_cookie_samesite()
|
||||
|
||||
assert result == 'lax'
|
||||
|
||||
|
||||
class TestSecurityScenarios:
|
||||
"""Tests for security-critical scenarios."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_request(self):
|
||||
"""Create a mock FastAPI request object."""
|
||||
request = MagicMock()
|
||||
request.url = MagicMock()
|
||||
return request
|
||||
|
||||
def test_header_poisoning_attack_blocked_when_configured(self, mock_request):
|
||||
"""
|
||||
When web_url is configured, X-Forwarded-* header poisoning should not affect
|
||||
the returned URL.
|
||||
"""
|
||||
from server.utils.url_utils import get_web_url
|
||||
|
||||
# Simulate a poisoned request where attacker controls headers
|
||||
mock_request.url.hostname = 'evil.com'
|
||||
mock_request.url.netloc = 'evil.com:443'
|
||||
|
||||
mock_config = MagicMock()
|
||||
mock_config.web_url = 'https://app.all-hands.dev'
|
||||
|
||||
with patch(
|
||||
'server.utils.url_utils.get_global_config', return_value=mock_config
|
||||
):
|
||||
result = get_web_url(mock_request)
|
||||
|
||||
# Should use configured web_url, not the poisoned request data
|
||||
assert result == 'https://app.all-hands.dev'
|
||||
assert 'evil' not in result
|
||||
|
||||
def test_cookie_domain_not_set_in_dev_environments(self):
|
||||
"""
|
||||
Cookie domain should not be set in development environments to prevent
|
||||
cookies from leaking to other subdomains.
|
||||
"""
|
||||
from server.utils.url_utils import get_cookie_domain
|
||||
|
||||
mock_config = MagicMock()
|
||||
mock_config.web_url = 'https://my-feature.staging.all-hands.dev'
|
||||
|
||||
# Test each dev environment
|
||||
for env_name, env_config in [
|
||||
(
|
||||
'local',
|
||||
{
|
||||
'IS_LOCAL_ENV': True,
|
||||
'IS_STAGING_ENV': False,
|
||||
'IS_FEATURE_ENV': False,
|
||||
},
|
||||
),
|
||||
(
|
||||
'staging',
|
||||
{
|
||||
'IS_LOCAL_ENV': False,
|
||||
'IS_STAGING_ENV': True,
|
||||
'IS_FEATURE_ENV': False,
|
||||
},
|
||||
),
|
||||
(
|
||||
'feature',
|
||||
{'IS_LOCAL_ENV': False, 'IS_STAGING_ENV': True, 'IS_FEATURE_ENV': True},
|
||||
),
|
||||
]:
|
||||
with (
|
||||
patch(
|
||||
'server.utils.url_utils.get_global_config', return_value=mock_config
|
||||
),
|
||||
patch(
|
||||
'server.utils.url_utils.IS_FEATURE_ENV',
|
||||
env_config['IS_FEATURE_ENV'],
|
||||
),
|
||||
patch(
|
||||
'server.utils.url_utils.IS_STAGING_ENV',
|
||||
env_config['IS_STAGING_ENV'],
|
||||
),
|
||||
patch(
|
||||
'server.utils.url_utils.IS_LOCAL_ENV', env_config['IS_LOCAL_ENV']
|
||||
),
|
||||
):
|
||||
result = get_cookie_domain()
|
||||
assert result is None, f'Expected None for {env_name} environment'
|
||||
|
||||
def test_strict_samesite_only_in_production(self):
|
||||
"""
|
||||
SameSite=strict should only be set in production to ensure proper
|
||||
security without breaking OAuth flows in development.
|
||||
"""
|
||||
from server.utils.url_utils import get_cookie_samesite
|
||||
|
||||
mock_config = MagicMock()
|
||||
mock_config.web_url = 'https://app.all-hands.dev'
|
||||
|
||||
# Production should be strict
|
||||
with (
|
||||
patch('server.utils.url_utils.get_global_config', return_value=mock_config),
|
||||
patch('server.utils.url_utils.IS_FEATURE_ENV', False),
|
||||
patch('server.utils.url_utils.IS_STAGING_ENV', False),
|
||||
patch('server.utils.url_utils.IS_LOCAL_ENV', False),
|
||||
):
|
||||
assert get_cookie_samesite() == 'strict'
|
||||
|
||||
# Dev environments should be lax
|
||||
for env_config in [
|
||||
{'IS_LOCAL_ENV': True, 'IS_STAGING_ENV': False, 'IS_FEATURE_ENV': False},
|
||||
{'IS_LOCAL_ENV': False, 'IS_STAGING_ENV': True, 'IS_FEATURE_ENV': False},
|
||||
{'IS_LOCAL_ENV': False, 'IS_STAGING_ENV': True, 'IS_FEATURE_ENV': True},
|
||||
]:
|
||||
with (
|
||||
patch(
|
||||
'server.utils.url_utils.get_global_config', return_value=mock_config
|
||||
),
|
||||
patch(
|
||||
'server.utils.url_utils.IS_FEATURE_ENV',
|
||||
env_config['IS_FEATURE_ENV'],
|
||||
),
|
||||
patch(
|
||||
'server.utils.url_utils.IS_STAGING_ENV',
|
||||
env_config['IS_STAGING_ENV'],
|
||||
),
|
||||
patch(
|
||||
'server.utils.url_utils.IS_LOCAL_ENV', env_config['IS_LOCAL_ENV']
|
||||
),
|
||||
):
|
||||
assert get_cookie_samesite() == 'lax'
|
||||
@@ -8,4 +8,3 @@ node_modules/
|
||||
/blob-report/
|
||||
/playwright/.cache/
|
||||
.react-router/
|
||||
ralph/
|
||||
|
||||
@@ -113,7 +113,7 @@ describe("ExpandableMessage", () => {
|
||||
|
||||
it("should render the out of credits message when the user is out of credits", async () => {
|
||||
const getConfigSpy = vi.spyOn(OptionService, "getConfig");
|
||||
// @ts-expect-error - partial mock for testing
|
||||
// @ts-expect-error - We only care about the app_mode and feature_flags fields
|
||||
getConfigSpy.mockResolvedValue({
|
||||
app_mode: "saas",
|
||||
feature_flags: {
|
||||
|
||||
@@ -0,0 +1,214 @@
|
||||
import { render, screen } from "@testing-library/react";
|
||||
import userEvent from "@testing-library/user-event";
|
||||
import { afterEach, beforeEach, describe, expect, it, test, vi } from "vitest";
|
||||
import { AccountSettingsContextMenu } from "#/components/features/context-menu/account-settings-context-menu";
|
||||
import { MemoryRouter } from "react-router";
|
||||
import { renderWithProviders } from "../../../test-utils";
|
||||
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
||||
import { createMockWebClientConfig } from "../../helpers/mock-config";
|
||||
|
||||
const mockTrackAddTeamMembersButtonClick = vi.fn();
|
||||
|
||||
vi.mock("#/hooks/use-tracking", () => ({
|
||||
useTracking: () => ({
|
||||
trackAddTeamMembersButtonClick: mockTrackAddTeamMembersButtonClick,
|
||||
}),
|
||||
}));
|
||||
|
||||
// Mock posthog feature flag
|
||||
vi.mock("posthog-js/react", () => ({
|
||||
useFeatureFlagEnabled: vi.fn(),
|
||||
}));
|
||||
|
||||
// Import the mocked module to get access to the mock
|
||||
import * as posthog from "posthog-js/react";
|
||||
|
||||
describe("AccountSettingsContextMenu", () => {
|
||||
const user = userEvent.setup();
|
||||
const onClickAccountSettingsMock = vi.fn();
|
||||
const onLogoutMock = vi.fn();
|
||||
const onCloseMock = vi.fn();
|
||||
|
||||
let queryClient: QueryClient;
|
||||
|
||||
beforeEach(() => {
|
||||
queryClient = new QueryClient({
|
||||
defaultOptions: { queries: { retry: false } },
|
||||
});
|
||||
// Set default feature flag to false
|
||||
vi.mocked(posthog.useFeatureFlagEnabled).mockReturnValue(false);
|
||||
});
|
||||
|
||||
// Create a wrapper with MemoryRouter and renderWithProviders
|
||||
const renderWithRouter = (ui: React.ReactElement) => {
|
||||
return renderWithProviders(<MemoryRouter>{ui}</MemoryRouter>);
|
||||
};
|
||||
|
||||
const renderWithSaasConfig = (ui: React.ReactElement, options?: { analyticsConsent?: boolean }) => {
|
||||
queryClient.setQueryData(["web-client-config"], createMockWebClientConfig({ app_mode: "saas" }));
|
||||
queryClient.setQueryData(["settings"], { user_consents_to_analytics: options?.analyticsConsent ?? true });
|
||||
return render(
|
||||
<QueryClientProvider client={queryClient}>
|
||||
<MemoryRouter>{ui}</MemoryRouter>
|
||||
</QueryClientProvider>
|
||||
);
|
||||
};
|
||||
|
||||
const renderWithOssConfig = (ui: React.ReactElement) => {
|
||||
queryClient.setQueryData(["web-client-config"], createMockWebClientConfig({ app_mode: "oss" }));
|
||||
return render(
|
||||
<QueryClientProvider client={queryClient}>
|
||||
<MemoryRouter>{ui}</MemoryRouter>
|
||||
</QueryClientProvider>
|
||||
);
|
||||
};
|
||||
|
||||
afterEach(() => {
|
||||
onClickAccountSettingsMock.mockClear();
|
||||
onLogoutMock.mockClear();
|
||||
onCloseMock.mockClear();
|
||||
mockTrackAddTeamMembersButtonClick.mockClear();
|
||||
vi.mocked(posthog.useFeatureFlagEnabled).mockClear();
|
||||
});
|
||||
|
||||
it("should always render the right options", () => {
|
||||
renderWithRouter(
|
||||
<AccountSettingsContextMenu
|
||||
onLogout={onLogoutMock}
|
||||
onClose={onCloseMock}
|
||||
/>,
|
||||
);
|
||||
|
||||
expect(
|
||||
screen.getByTestId("account-settings-context-menu"),
|
||||
).toBeInTheDocument();
|
||||
expect(screen.getByText("SIDEBAR$DOCS")).toBeInTheDocument();
|
||||
expect(screen.getByText("ACCOUNT_SETTINGS$LOGOUT")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should render Documentation link with correct attributes", () => {
|
||||
renderWithRouter(
|
||||
<AccountSettingsContextMenu
|
||||
onLogout={onLogoutMock}
|
||||
onClose={onCloseMock}
|
||||
/>,
|
||||
);
|
||||
|
||||
const documentationLink = screen.getByText("SIDEBAR$DOCS").closest("a");
|
||||
expect(documentationLink).toHaveAttribute("href", "https://docs.openhands.dev");
|
||||
expect(documentationLink).toHaveAttribute("target", "_blank");
|
||||
expect(documentationLink).toHaveAttribute("rel", "noopener noreferrer");
|
||||
});
|
||||
|
||||
it("should call onLogout when the logout option is clicked", async () => {
|
||||
renderWithRouter(
|
||||
<AccountSettingsContextMenu
|
||||
onLogout={onLogoutMock}
|
||||
onClose={onCloseMock}
|
||||
/>,
|
||||
);
|
||||
|
||||
const logoutOption = screen.getByText("ACCOUNT_SETTINGS$LOGOUT");
|
||||
await user.click(logoutOption);
|
||||
|
||||
expect(onLogoutMock).toHaveBeenCalledOnce();
|
||||
});
|
||||
|
||||
test("logout button is always enabled", async () => {
|
||||
renderWithRouter(
|
||||
<AccountSettingsContextMenu
|
||||
onLogout={onLogoutMock}
|
||||
onClose={onCloseMock}
|
||||
/>,
|
||||
);
|
||||
|
||||
const logoutOption = screen.getByText("ACCOUNT_SETTINGS$LOGOUT");
|
||||
await user.click(logoutOption);
|
||||
|
||||
expect(onLogoutMock).toHaveBeenCalledOnce();
|
||||
});
|
||||
|
||||
it("should call onClose when clicking outside of the element", async () => {
|
||||
renderWithRouter(
|
||||
<AccountSettingsContextMenu
|
||||
onLogout={onLogoutMock}
|
||||
onClose={onCloseMock}
|
||||
/>,
|
||||
);
|
||||
|
||||
const accountSettingsButton = screen.getByText("ACCOUNT_SETTINGS$LOGOUT");
|
||||
await user.click(accountSettingsButton);
|
||||
await user.click(document.body);
|
||||
|
||||
expect(onCloseMock).toHaveBeenCalledOnce();
|
||||
});
|
||||
|
||||
it("should show Add Team Members button in SaaS mode when feature flag is enabled", () => {
|
||||
vi.mocked(posthog.useFeatureFlagEnabled).mockReturnValue(true);
|
||||
renderWithSaasConfig(
|
||||
<AccountSettingsContextMenu
|
||||
onLogout={onLogoutMock}
|
||||
onClose={onCloseMock}
|
||||
/>,
|
||||
);
|
||||
|
||||
expect(screen.getByTestId("add-team-members-button")).toBeInTheDocument();
|
||||
expect(screen.getByText("SETTINGS$NAV_ADD_TEAM_MEMBERS")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should not show Add Team Members button in SaaS mode when feature flag is disabled", () => {
|
||||
vi.mocked(posthog.useFeatureFlagEnabled).mockReturnValue(false);
|
||||
renderWithSaasConfig(
|
||||
<AccountSettingsContextMenu
|
||||
onLogout={onLogoutMock}
|
||||
onClose={onCloseMock}
|
||||
/>,
|
||||
);
|
||||
|
||||
expect(screen.queryByTestId("add-team-members-button")).not.toBeInTheDocument();
|
||||
expect(screen.queryByText("SETTINGS$NAV_ADD_TEAM_MEMBERS")).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should not show Add Team Members button in OSS mode even when feature flag is enabled", () => {
|
||||
vi.mocked(posthog.useFeatureFlagEnabled).mockReturnValue(true);
|
||||
renderWithOssConfig(
|
||||
<AccountSettingsContextMenu
|
||||
onLogout={onLogoutMock}
|
||||
onClose={onCloseMock}
|
||||
/>,
|
||||
);
|
||||
|
||||
expect(screen.queryByTestId("add-team-members-button")).not.toBeInTheDocument();
|
||||
expect(screen.queryByText("SETTINGS$NAV_ADD_TEAM_MEMBERS")).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should not show Add Team Members button when analytics consent is disabled", () => {
|
||||
vi.mocked(posthog.useFeatureFlagEnabled).mockReturnValue(true);
|
||||
renderWithSaasConfig(
|
||||
<AccountSettingsContextMenu
|
||||
onLogout={onLogoutMock}
|
||||
onClose={onCloseMock}
|
||||
/>,
|
||||
{ analyticsConsent: false },
|
||||
);
|
||||
|
||||
expect(screen.queryByTestId("add-team-members-button")).not.toBeInTheDocument();
|
||||
expect(screen.queryByText("SETTINGS$NAV_ADD_TEAM_MEMBERS")).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should call tracking function and onClose when Add Team Members button is clicked", async () => {
|
||||
vi.mocked(posthog.useFeatureFlagEnabled).mockReturnValue(true);
|
||||
renderWithSaasConfig(
|
||||
<AccountSettingsContextMenu
|
||||
onLogout={onLogoutMock}
|
||||
onClose={onCloseMock}
|
||||
/>,
|
||||
);
|
||||
|
||||
const addTeamMembersButton = screen.getByTestId("add-team-members-button");
|
||||
await user.click(addTeamMembersButton);
|
||||
|
||||
expect(mockTrackAddTeamMembersButtonClick).toHaveBeenCalledOnce();
|
||||
expect(onCloseMock).toHaveBeenCalledOnce();
|
||||
});
|
||||
});
|
||||
@@ -15,7 +15,6 @@ vi.mock("#/hooks/use-auth-url", () => ({
|
||||
bitbucket: "https://bitbucket.org/site/oauth2/authorize",
|
||||
bitbucket_data_center:
|
||||
"https://bitbucket-dc.example.com/site/oauth2/authorize",
|
||||
enterprise_sso: "https://auth.example.com/realms/test/protocol/openid-connect/auth",
|
||||
};
|
||||
if (config.appMode === "saas") {
|
||||
return urls[config.identityProvider] || null;
|
||||
@@ -118,74 +117,6 @@ describe("LoginContent", () => {
|
||||
).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should display Enterprise SSO button when configured", () => {
|
||||
render(
|
||||
<MemoryRouter>
|
||||
<LoginContent
|
||||
githubAuthUrl="https://github.com/oauth/authorize"
|
||||
appMode="saas"
|
||||
authUrl="https://auth.example.com"
|
||||
providersConfigured={["enterprise_sso"]}
|
||||
/>
|
||||
</MemoryRouter>,
|
||||
);
|
||||
|
||||
expect(
|
||||
screen.getByRole("button", { name: /ENTERPRISE_SSO\$CONNECT_TO_ENTERPRISE_SSO/i }),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should display Enterprise SSO alongside other providers when all configured", () => {
|
||||
render(
|
||||
<MemoryRouter>
|
||||
<LoginContent
|
||||
githubAuthUrl="https://github.com/oauth/authorize"
|
||||
appMode="saas"
|
||||
authUrl="https://auth.example.com"
|
||||
providersConfigured={["github", "gitlab", "bitbucket", "enterprise_sso"]}
|
||||
/>
|
||||
</MemoryRouter>,
|
||||
);
|
||||
|
||||
expect(
|
||||
screen.getByRole("button", { name: "GITHUB$CONNECT_TO_GITHUB" }),
|
||||
).toBeInTheDocument();
|
||||
expect(
|
||||
screen.getByRole("button", { name: "GITLAB$CONNECT_TO_GITLAB" }),
|
||||
).toBeInTheDocument();
|
||||
expect(
|
||||
screen.getByRole("button", { name: /BITBUCKET\$CONNECT_TO_BITBUCKET/i }),
|
||||
).toBeInTheDocument();
|
||||
expect(
|
||||
screen.getByRole("button", { name: /ENTERPRISE_SSO\$CONNECT_TO_ENTERPRISE_SSO/i }),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should redirect to Enterprise SSO auth URL when Enterprise SSO button is clicked", async () => {
|
||||
const user = userEvent.setup();
|
||||
const mockUrl = "https://auth.example.com/realms/test/protocol/openid-connect/auth";
|
||||
|
||||
render(
|
||||
<MemoryRouter>
|
||||
<LoginContent
|
||||
githubAuthUrl="https://github.com/oauth/authorize"
|
||||
appMode="saas"
|
||||
authUrl="https://auth.example.com"
|
||||
providersConfigured={["enterprise_sso"]}
|
||||
/>
|
||||
</MemoryRouter>,
|
||||
);
|
||||
|
||||
const enterpriseSsoButton = screen.getByRole("button", {
|
||||
name: /ENTERPRISE_SSO\$CONNECT_TO_ENTERPRISE_SSO/i,
|
||||
});
|
||||
await user.click(enterpriseSsoButton);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(window.location.href).toContain(mockUrl);
|
||||
});
|
||||
});
|
||||
|
||||
it("should display message when no providers are configured", () => {
|
||||
render(
|
||||
<MemoryRouter>
|
||||
|
||||
@@ -1,91 +0,0 @@
|
||||
import React from "react";
|
||||
import { describe, it, expect, vi } from "vitest";
|
||||
import { screen } from "@testing-library/react";
|
||||
import userEvent from "@testing-library/user-event";
|
||||
import { renderWithProviders } from "test-utils";
|
||||
import { ConfirmRemoveMemberModal } from "#/components/features/org/confirm-remove-member-modal";
|
||||
|
||||
vi.mock("react-i18next", async (importOriginal) => ({
|
||||
...(await importOriginal<typeof import("react-i18next")>()),
|
||||
Trans: ({
|
||||
values,
|
||||
components,
|
||||
}: {
|
||||
values: { email: string };
|
||||
components: { email: React.ReactElement };
|
||||
}) => React.cloneElement(components.email, {}, values.email),
|
||||
}));
|
||||
|
||||
describe("ConfirmRemoveMemberModal", () => {
|
||||
it("should display the member email in the confirmation message", () => {
|
||||
// Arrange
|
||||
const memberEmail = "test@example.com";
|
||||
|
||||
// Act
|
||||
renderWithProviders(
|
||||
<ConfirmRemoveMemberModal
|
||||
onConfirm={vi.fn()}
|
||||
onCancel={vi.fn()}
|
||||
memberEmail={memberEmail}
|
||||
/>,
|
||||
);
|
||||
|
||||
// Assert
|
||||
expect(screen.getByText(memberEmail)).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should call onConfirm when the confirm button is clicked", async () => {
|
||||
// Arrange
|
||||
const user = userEvent.setup();
|
||||
const onConfirmMock = vi.fn();
|
||||
renderWithProviders(
|
||||
<ConfirmRemoveMemberModal
|
||||
onConfirm={onConfirmMock}
|
||||
onCancel={vi.fn()}
|
||||
memberEmail="test@example.com"
|
||||
/>,
|
||||
);
|
||||
|
||||
// Act
|
||||
await user.click(screen.getByTestId("confirm-button"));
|
||||
|
||||
// Assert
|
||||
expect(onConfirmMock).toHaveBeenCalledOnce();
|
||||
});
|
||||
|
||||
it("should call onCancel when the cancel button is clicked", async () => {
|
||||
// Arrange
|
||||
const user = userEvent.setup();
|
||||
const onCancelMock = vi.fn();
|
||||
renderWithProviders(
|
||||
<ConfirmRemoveMemberModal
|
||||
onConfirm={vi.fn()}
|
||||
onCancel={onCancelMock}
|
||||
memberEmail="test@example.com"
|
||||
/>,
|
||||
);
|
||||
|
||||
// Act
|
||||
await user.click(screen.getByTestId("cancel-button"));
|
||||
|
||||
// Assert
|
||||
expect(onCancelMock).toHaveBeenCalledOnce();
|
||||
});
|
||||
|
||||
it("should disable buttons and show loading spinner when isLoading is true", () => {
|
||||
// Arrange & Act
|
||||
renderWithProviders(
|
||||
<ConfirmRemoveMemberModal
|
||||
onConfirm={vi.fn()}
|
||||
onCancel={vi.fn()}
|
||||
memberEmail="test@example.com"
|
||||
isLoading
|
||||
/>,
|
||||
);
|
||||
|
||||
// Assert
|
||||
expect(screen.getByTestId("confirm-button")).toBeDisabled();
|
||||
expect(screen.getByTestId("cancel-button")).toBeDisabled();
|
||||
expect(screen.getByTestId("loading-spinner")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
@@ -1,102 +0,0 @@
|
||||
import React from "react";
|
||||
import { describe, it, expect, vi } from "vitest";
|
||||
import { screen } from "@testing-library/react";
|
||||
import userEvent from "@testing-library/user-event";
|
||||
import { renderWithProviders } from "test-utils";
|
||||
import { ConfirmUpdateRoleModal } from "#/components/features/org/confirm-update-role-modal";
|
||||
|
||||
vi.mock("react-i18next", async (importOriginal) => ({
|
||||
...(await importOriginal<typeof import("react-i18next")>()),
|
||||
Trans: ({
|
||||
values,
|
||||
components,
|
||||
}: {
|
||||
values: { email: string; role: string };
|
||||
components: { email: React.ReactElement; role: React.ReactElement };
|
||||
}) => (
|
||||
<>
|
||||
{React.cloneElement(components.email, {}, values.email)}
|
||||
{React.cloneElement(components.role, {}, values.role)}
|
||||
</>
|
||||
),
|
||||
}));
|
||||
|
||||
describe("ConfirmUpdateRoleModal", () => {
|
||||
it("should display the member email and new role in the confirmation message", () => {
|
||||
// Arrange
|
||||
const memberEmail = "test@example.com";
|
||||
const newRole = "admin";
|
||||
|
||||
// Act
|
||||
renderWithProviders(
|
||||
<ConfirmUpdateRoleModal
|
||||
onConfirm={vi.fn()}
|
||||
onCancel={vi.fn()}
|
||||
memberEmail={memberEmail}
|
||||
newRole={newRole}
|
||||
/>,
|
||||
);
|
||||
|
||||
// Assert
|
||||
expect(screen.getByText(memberEmail)).toBeInTheDocument();
|
||||
expect(screen.getByText(newRole)).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should call onConfirm when the confirm button is clicked", async () => {
|
||||
// Arrange
|
||||
const user = userEvent.setup();
|
||||
const onConfirmMock = vi.fn();
|
||||
renderWithProviders(
|
||||
<ConfirmUpdateRoleModal
|
||||
onConfirm={onConfirmMock}
|
||||
onCancel={vi.fn()}
|
||||
memberEmail="test@example.com"
|
||||
newRole="admin"
|
||||
/>,
|
||||
);
|
||||
|
||||
// Act
|
||||
await user.click(screen.getByTestId("confirm-button"));
|
||||
|
||||
// Assert
|
||||
expect(onConfirmMock).toHaveBeenCalledOnce();
|
||||
});
|
||||
|
||||
it("should call onCancel when the cancel button is clicked", async () => {
|
||||
// Arrange
|
||||
const user = userEvent.setup();
|
||||
const onCancelMock = vi.fn();
|
||||
renderWithProviders(
|
||||
<ConfirmUpdateRoleModal
|
||||
onConfirm={vi.fn()}
|
||||
onCancel={onCancelMock}
|
||||
memberEmail="test@example.com"
|
||||
newRole="admin"
|
||||
/>,
|
||||
);
|
||||
|
||||
// Act
|
||||
await user.click(screen.getByTestId("cancel-button"));
|
||||
|
||||
// Assert
|
||||
expect(onCancelMock).toHaveBeenCalledOnce();
|
||||
});
|
||||
|
||||
it("should disable buttons and show loading spinner when isLoading is true", () => {
|
||||
// Arrange & Act
|
||||
renderWithProviders(
|
||||
<ConfirmUpdateRoleModal
|
||||
onConfirm={vi.fn()}
|
||||
onCancel={vi.fn()}
|
||||
memberEmail="test@example.com"
|
||||
newRole="admin"
|
||||
isLoading
|
||||
/>,
|
||||
);
|
||||
|
||||
// Assert
|
||||
expect(screen.getByTestId("confirm-button")).toBeDisabled();
|
||||
expect(screen.getByTestId("cancel-button")).toBeDisabled();
|
||||
expect(screen.getByTestId("loading-spinner")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
@@ -1,141 +0,0 @@
|
||||
import { within, screen, render } from "@testing-library/react";
|
||||
import userEvent from "@testing-library/user-event";
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import { QueryClientProvider, QueryClient } from "@tanstack/react-query";
|
||||
import { organizationService } from "#/api/organization-service/organization-service.api";
|
||||
import { InviteOrganizationMemberModal } from "#/components/features/org/invite-organization-member-modal";
|
||||
import { useSelectedOrganizationStore } from "#/stores/selected-organization-store";
|
||||
import * as ToastHandlers from "#/utils/custom-toast-handlers";
|
||||
|
||||
vi.mock("react-router", () => ({
|
||||
useRevalidator: vi.fn(() => ({ revalidate: vi.fn() })),
|
||||
}));
|
||||
|
||||
const renderInviteOrganizationMemberModal = (config?: {
|
||||
onClose: () => void;
|
||||
}) =>
|
||||
render(
|
||||
<InviteOrganizationMemberModal onClose={config?.onClose || vi.fn()} />,
|
||||
{
|
||||
wrapper: ({ children }) => (
|
||||
<QueryClientProvider client={new QueryClient()}>
|
||||
{children}
|
||||
</QueryClientProvider>
|
||||
),
|
||||
},
|
||||
);
|
||||
|
||||
describe("InviteOrganizationMemberModal", () => {
|
||||
beforeEach(() => {
|
||||
useSelectedOrganizationStore.setState({ organizationId: "1" });
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.clearAllMocks();
|
||||
useSelectedOrganizationStore.setState({ organizationId: null });
|
||||
});
|
||||
|
||||
it("should call onClose the modal when the close button is clicked", async () => {
|
||||
const onCloseMock = vi.fn();
|
||||
renderInviteOrganizationMemberModal({ onClose: onCloseMock });
|
||||
|
||||
const modal = screen.getByTestId("invite-modal");
|
||||
const closeButton = within(modal).getByRole("button", {
|
||||
name: /close/i,
|
||||
});
|
||||
await userEvent.click(closeButton);
|
||||
|
||||
expect(onCloseMock).toHaveBeenCalledOnce();
|
||||
});
|
||||
|
||||
it("should call the batch API to invite a single team member when the form is submitted", async () => {
|
||||
const inviteMembersBatchSpy = vi.spyOn(
|
||||
organizationService,
|
||||
"inviteMembers",
|
||||
);
|
||||
const onCloseMock = vi.fn();
|
||||
|
||||
renderInviteOrganizationMemberModal({ onClose: onCloseMock });
|
||||
|
||||
const modal = screen.getByTestId("invite-modal");
|
||||
|
||||
const badgeInput = within(modal).getByTestId("emails-badge-input");
|
||||
await userEvent.type(badgeInput, "someone@acme.org ");
|
||||
|
||||
// Verify badge is displayed
|
||||
expect(screen.getByText("someone@acme.org")).toBeInTheDocument();
|
||||
|
||||
const submitButton = within(modal).getByRole("button", {
|
||||
name: /add/i,
|
||||
});
|
||||
await userEvent.click(submitButton);
|
||||
|
||||
expect(inviteMembersBatchSpy).toHaveBeenCalledExactlyOnceWith({
|
||||
orgId: "1",
|
||||
emails: ["someone@acme.org"],
|
||||
});
|
||||
|
||||
expect(onCloseMock).toHaveBeenCalledOnce();
|
||||
});
|
||||
|
||||
it("should allow adding multiple emails using badge input and make a batch POST request", async () => {
|
||||
const inviteMembersBatchSpy = vi.spyOn(
|
||||
organizationService,
|
||||
"inviteMembers",
|
||||
);
|
||||
const onCloseMock = vi.fn();
|
||||
|
||||
renderInviteOrganizationMemberModal({ onClose: onCloseMock });
|
||||
|
||||
const modal = screen.getByTestId("invite-modal");
|
||||
|
||||
// Should have badge input instead of regular input
|
||||
const badgeInput = within(modal).getByTestId("emails-badge-input");
|
||||
expect(badgeInput).toBeInTheDocument();
|
||||
|
||||
// Add first email by typing and pressing space
|
||||
await userEvent.type(badgeInput, "user1@acme.org ");
|
||||
|
||||
// Add second email by typing and pressing space
|
||||
await userEvent.type(badgeInput, "user2@acme.org ");
|
||||
|
||||
// Add third email by typing and pressing space
|
||||
await userEvent.type(badgeInput, "user3@acme.org ");
|
||||
|
||||
// Verify badges are displayed
|
||||
expect(screen.getByText("user1@acme.org")).toBeInTheDocument();
|
||||
expect(screen.getByText("user2@acme.org")).toBeInTheDocument();
|
||||
expect(screen.getByText("user3@acme.org")).toBeInTheDocument();
|
||||
|
||||
const submitButton = within(modal).getByRole("button", {
|
||||
name: /add/i,
|
||||
});
|
||||
await userEvent.click(submitButton);
|
||||
|
||||
// Should call batch invite API with all emails
|
||||
expect(inviteMembersBatchSpy).toHaveBeenCalledExactlyOnceWith({
|
||||
orgId: "1",
|
||||
emails: ["user1@acme.org", "user2@acme.org", "user3@acme.org"],
|
||||
});
|
||||
|
||||
expect(onCloseMock).toHaveBeenCalledOnce();
|
||||
});
|
||||
|
||||
it("should display an error toast when clicking add button with no emails added", async () => {
|
||||
// Arrange
|
||||
const displayErrorToastSpy = vi.spyOn(ToastHandlers, "displayErrorToast");
|
||||
const inviteMembersSpy = vi.spyOn(organizationService, "inviteMembers");
|
||||
renderInviteOrganizationMemberModal();
|
||||
|
||||
// Act
|
||||
const modal = screen.getByTestId("invite-modal");
|
||||
const submitButton = within(modal).getByRole("button", { name: /add/i });
|
||||
await userEvent.click(submitButton);
|
||||
|
||||
// Assert
|
||||
expect(displayErrorToastSpy).toHaveBeenCalledWith(
|
||||
"ORG$NO_EMAILS_ADDED_HINT",
|
||||
);
|
||||
expect(inviteMembersSpy).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
@@ -1,203 +0,0 @@
|
||||
import { screen, render, waitFor, within } from "@testing-library/react";
|
||||
import userEvent from "@testing-library/user-event";
|
||||
import { describe, expect, it, vi } from "vitest";
|
||||
import { QueryClientProvider, QueryClient } from "@tanstack/react-query";
|
||||
import { OrgSelector } from "#/components/features/org/org-selector";
|
||||
import { organizationService } from "#/api/organization-service/organization-service.api";
|
||||
import {
|
||||
MOCK_PERSONAL_ORG,
|
||||
MOCK_TEAM_ORG_ACME,
|
||||
createMockOrganization,
|
||||
} from "#/mocks/org-handlers";
|
||||
|
||||
vi.mock("react-router", () => ({
|
||||
useRevalidator: () => ({ revalidate: vi.fn() }),
|
||||
useNavigate: () => vi.fn(),
|
||||
useLocation: () => ({ pathname: "/" }),
|
||||
useMatch: () => null,
|
||||
}));
|
||||
|
||||
vi.mock("#/hooks/query/use-is-authed", () => ({
|
||||
useIsAuthed: () => ({ data: true }),
|
||||
}));
|
||||
|
||||
// Mock useConfig to return SaaS mode (organizations are a SaaS-only feature)
|
||||
vi.mock("#/hooks/query/use-config", () => ({
|
||||
useConfig: () => ({ data: { app_mode: "saas" } }),
|
||||
}));
|
||||
|
||||
vi.mock("react-i18next", async () => {
|
||||
const actual =
|
||||
await vi.importActual<typeof import("react-i18next")>("react-i18next");
|
||||
return {
|
||||
...actual,
|
||||
useTranslation: () => ({
|
||||
t: (key: string) => {
|
||||
const translations: Record<string, string> = {
|
||||
"ORG$SELECT_ORGANIZATION_PLACEHOLDER": "Please select an organization",
|
||||
"ORG$PERSONAL_WORKSPACE": "Personal Workspace",
|
||||
};
|
||||
return translations[key] || key;
|
||||
},
|
||||
i18n: {
|
||||
changeLanguage: vi.fn(),
|
||||
},
|
||||
}),
|
||||
};
|
||||
});
|
||||
|
||||
const renderOrgSelector = () =>
|
||||
render(<OrgSelector />, {
|
||||
wrapper: ({ children }) => (
|
||||
<QueryClientProvider client={new QueryClient()}>
|
||||
{children}
|
||||
</QueryClientProvider>
|
||||
),
|
||||
});
|
||||
|
||||
describe("OrgSelector", () => {
|
||||
it("should not render when user only has a personal workspace", async () => {
|
||||
vi.spyOn(organizationService, "getOrganizations").mockResolvedValue({
|
||||
items: [MOCK_PERSONAL_ORG],
|
||||
currentOrgId: MOCK_PERSONAL_ORG.id,
|
||||
});
|
||||
|
||||
const { container } = renderOrgSelector();
|
||||
|
||||
await waitFor(() => {
|
||||
expect(container).toBeEmptyDOMElement();
|
||||
});
|
||||
});
|
||||
|
||||
it("should render when user only has a team organization", async () => {
|
||||
vi.spyOn(organizationService, "getOrganizations").mockResolvedValue({
|
||||
items: [MOCK_TEAM_ORG_ACME],
|
||||
currentOrgId: MOCK_TEAM_ORG_ACME.id,
|
||||
});
|
||||
|
||||
const { container } = renderOrgSelector();
|
||||
|
||||
await waitFor(() => {
|
||||
expect(container).not.toBeEmptyDOMElement();
|
||||
});
|
||||
expect(screen.getByRole("combobox")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should show a loading indicator when fetching organizations", () => {
|
||||
vi.spyOn(organizationService, "getOrganizations").mockImplementation(
|
||||
() => new Promise(() => {}), // never resolves
|
||||
);
|
||||
|
||||
renderOrgSelector();
|
||||
|
||||
// The dropdown trigger should be disabled while loading
|
||||
const trigger = screen.getByTestId("dropdown-trigger");
|
||||
expect(trigger).toBeDisabled();
|
||||
});
|
||||
|
||||
it("should select the first organization after orgs are loaded", async () => {
|
||||
vi.spyOn(organizationService, "getOrganizations").mockResolvedValue({
|
||||
items: [MOCK_PERSONAL_ORG, MOCK_TEAM_ORG_ACME],
|
||||
currentOrgId: MOCK_PERSONAL_ORG.id,
|
||||
});
|
||||
|
||||
renderOrgSelector();
|
||||
|
||||
// The combobox input should show the first org name
|
||||
await waitFor(() => {
|
||||
const input = screen.getByRole("combobox");
|
||||
expect(input).toHaveValue("Personal Workspace");
|
||||
});
|
||||
});
|
||||
|
||||
it("should show all options when dropdown is opened", async () => {
|
||||
const user = userEvent.setup();
|
||||
vi.spyOn(organizationService, "getOrganizations").mockResolvedValue({
|
||||
items: [
|
||||
MOCK_PERSONAL_ORG,
|
||||
MOCK_TEAM_ORG_ACME,
|
||||
createMockOrganization("3", "Test Organization", 500),
|
||||
],
|
||||
currentOrgId: MOCK_PERSONAL_ORG.id,
|
||||
});
|
||||
|
||||
renderOrgSelector();
|
||||
|
||||
// Wait for the selector to be populated with the first organization
|
||||
await waitFor(() => {
|
||||
const input = screen.getByRole("combobox");
|
||||
expect(input).toHaveValue("Personal Workspace");
|
||||
});
|
||||
|
||||
// Click the trigger to open dropdown
|
||||
const trigger = screen.getByTestId("dropdown-trigger");
|
||||
await user.click(trigger);
|
||||
|
||||
// Verify all 3 options are visible
|
||||
const listbox = await screen.findByRole("listbox");
|
||||
const options = within(listbox).getAllByRole("option");
|
||||
|
||||
expect(options).toHaveLength(3);
|
||||
expect(options[0]).toHaveTextContent("Personal Workspace");
|
||||
expect(options[1]).toHaveTextContent("Acme Corp");
|
||||
expect(options[2]).toHaveTextContent("Test Organization");
|
||||
});
|
||||
|
||||
it("should call switchOrganization API when selecting a different organization", async () => {
|
||||
// Arrange
|
||||
const user = userEvent.setup();
|
||||
vi.spyOn(organizationService, "getOrganizations").mockResolvedValue({
|
||||
items: [MOCK_PERSONAL_ORG, MOCK_TEAM_ORG_ACME],
|
||||
currentOrgId: MOCK_PERSONAL_ORG.id,
|
||||
});
|
||||
const switchOrgSpy = vi
|
||||
.spyOn(organizationService, "switchOrganization")
|
||||
.mockResolvedValue(MOCK_TEAM_ORG_ACME);
|
||||
|
||||
renderOrgSelector();
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByRole("combobox")).toHaveValue("Personal Workspace");
|
||||
});
|
||||
|
||||
// Act
|
||||
const trigger = screen.getByTestId("dropdown-trigger");
|
||||
await user.click(trigger);
|
||||
const listbox = await screen.findByRole("listbox");
|
||||
const acmeOption = within(listbox).getByText("Acme Corp");
|
||||
await user.click(acmeOption);
|
||||
|
||||
// Assert
|
||||
expect(switchOrgSpy).toHaveBeenCalledWith({ orgId: MOCK_TEAM_ORG_ACME.id });
|
||||
});
|
||||
|
||||
it("should show loading state while switching organizations", async () => {
|
||||
// Arrange
|
||||
const user = userEvent.setup();
|
||||
vi.spyOn(organizationService, "getOrganizations").mockResolvedValue({
|
||||
items: [MOCK_PERSONAL_ORG, MOCK_TEAM_ORG_ACME],
|
||||
currentOrgId: MOCK_PERSONAL_ORG.id,
|
||||
});
|
||||
vi.spyOn(organizationService, "switchOrganization").mockImplementation(
|
||||
() => new Promise(() => {}), // never resolves to keep loading state
|
||||
);
|
||||
|
||||
renderOrgSelector();
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByRole("combobox")).toHaveValue("Personal Workspace");
|
||||
});
|
||||
|
||||
// Act
|
||||
const trigger = screen.getByTestId("dropdown-trigger");
|
||||
await user.click(trigger);
|
||||
const listbox = await screen.findByRole("listbox");
|
||||
const acmeOption = within(listbox).getByText("Acme Corp");
|
||||
await user.click(acmeOption);
|
||||
|
||||
// Assert
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId("dropdown-trigger")).toBeDisabled();
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,109 +0,0 @@
|
||||
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
||||
import { render, screen } from "@testing-library/react";
|
||||
import { describe, it, expect, vi, beforeEach } from "vitest";
|
||||
import { MemoryRouter } from "react-router";
|
||||
import { SettingsNavigation } from "#/components/features/settings/settings-navigation";
|
||||
import OptionService from "#/api/option-service/option-service.api";
|
||||
import { useSelectedOrganizationStore } from "#/stores/selected-organization-store";
|
||||
import { SAAS_NAV_ITEMS, SettingsNavItem } from "#/constants/settings-nav";
|
||||
|
||||
vi.mock("react-router", async () => ({
|
||||
...(await vi.importActual("react-router")),
|
||||
useRevalidator: () => ({ revalidate: vi.fn() }),
|
||||
}));
|
||||
|
||||
const mockConfig = () => {
|
||||
vi.spyOn(OptionService, "getConfig").mockResolvedValue({
|
||||
app_mode: "saas",
|
||||
} as Awaited<ReturnType<typeof OptionService.getConfig>>);
|
||||
};
|
||||
|
||||
const ITEMS_WITHOUT_ORG = SAAS_NAV_ITEMS.filter(
|
||||
(item) =>
|
||||
item.to !== "/settings/org" && item.to !== "/settings/org-members",
|
||||
);
|
||||
|
||||
const renderSettingsNavigation = (
|
||||
items: SettingsNavItem[] = SAAS_NAV_ITEMS,
|
||||
) => {
|
||||
const queryClient = new QueryClient({
|
||||
defaultOptions: {
|
||||
queries: {
|
||||
retry: false,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
return render(
|
||||
<QueryClientProvider client={queryClient}>
|
||||
<MemoryRouter>
|
||||
<SettingsNavigation
|
||||
isMobileMenuOpen={false}
|
||||
onCloseMobileMenu={vi.fn()}
|
||||
navigationItems={items}
|
||||
/>
|
||||
</MemoryRouter>
|
||||
</QueryClientProvider>,
|
||||
);
|
||||
};
|
||||
|
||||
describe("SettingsNavigation", () => {
|
||||
beforeEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
mockConfig();
|
||||
useSelectedOrganizationStore.setState({ organizationId: "org-1" });
|
||||
});
|
||||
|
||||
describe("renders navigation items passed via props", () => {
|
||||
it("should render org routes when included in navigation items", async () => {
|
||||
renderSettingsNavigation(SAAS_NAV_ITEMS);
|
||||
|
||||
await screen.findByTestId("settings-navbar");
|
||||
|
||||
const orgMembersLink = await screen.findByText("Organization Members");
|
||||
const orgLink = await screen.findByText("Organization");
|
||||
|
||||
expect(orgMembersLink).toBeInTheDocument();
|
||||
expect(orgLink).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should not render org routes when excluded from navigation items", async () => {
|
||||
renderSettingsNavigation(ITEMS_WITHOUT_ORG);
|
||||
|
||||
await screen.findByTestId("settings-navbar");
|
||||
|
||||
const orgMembersLink = screen.queryByText("Organization Members");
|
||||
const orgLink = screen.queryByText("Organization");
|
||||
|
||||
expect(orgMembersLink).not.toBeInTheDocument();
|
||||
expect(orgLink).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should render all non-org SAAS items regardless of which items are passed", async () => {
|
||||
renderSettingsNavigation(SAAS_NAV_ITEMS);
|
||||
|
||||
await screen.findByTestId("settings-navbar");
|
||||
|
||||
// Verify non-org items are rendered (using their i18n keys as text since
|
||||
// react-i18next returns the key when no translation is loaded)
|
||||
const secretsLink = await screen.findByText("SETTINGS$NAV_SECRETS");
|
||||
const apiKeysLink = await screen.findByText("SETTINGS$NAV_API_KEYS");
|
||||
|
||||
expect(secretsLink).toBeInTheDocument();
|
||||
expect(apiKeysLink).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should render empty nav when given an empty items list", async () => {
|
||||
renderSettingsNavigation([]);
|
||||
|
||||
await screen.findByTestId("settings-navbar");
|
||||
|
||||
// No nav links should be rendered
|
||||
const orgMembersLink = screen.queryByText("Organization Members");
|
||||
const orgLink = screen.queryByText("Organization");
|
||||
|
||||
expect(orgMembersLink).not.toBeInTheDocument();
|
||||
expect(orgLink).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,633 +0,0 @@
|
||||
import { render, screen, waitFor } from "@testing-library/react";
|
||||
import userEvent from "@testing-library/user-event";
|
||||
import { afterEach, beforeEach, describe, expect, it, test, vi } from "vitest";
|
||||
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
||||
import { MemoryRouter } from "react-router";
|
||||
import { UserContextMenu } from "#/components/features/user/user-context-menu";
|
||||
import { organizationService } from "#/api/organization-service/organization-service.api";
|
||||
import { GetComponentPropTypes } from "#/utils/get-component-prop-types";
|
||||
import {
|
||||
INITIAL_MOCK_ORGS,
|
||||
MOCK_PERSONAL_ORG,
|
||||
MOCK_TEAM_ORG_ACME,
|
||||
} from "#/mocks/org-handlers";
|
||||
import AuthService from "#/api/auth-service/auth-service.api";
|
||||
import { SAAS_NAV_ITEMS, OSS_NAV_ITEMS } from "#/constants/settings-nav";
|
||||
import OptionService from "#/api/option-service/option-service.api";
|
||||
import { OrganizationMember } from "#/types/org";
|
||||
import { useSelectedOrganizationStore } from "#/stores/selected-organization-store";
|
||||
import { createMockWebClientConfig } from "#/mocks/settings-handlers";
|
||||
|
||||
type UserContextMenuProps = GetComponentPropTypes<typeof UserContextMenu>;
|
||||
|
||||
function UserContextMenuWithRootOutlet({
|
||||
type,
|
||||
onClose,
|
||||
onOpenInviteModal,
|
||||
}: UserContextMenuProps) {
|
||||
return (
|
||||
<div>
|
||||
<div data-testid="portal-root" id="portal-root" />
|
||||
<UserContextMenu
|
||||
type={type}
|
||||
onClose={onClose}
|
||||
onOpenInviteModal={onOpenInviteModal}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
const renderUserContextMenu = ({
|
||||
type,
|
||||
onClose,
|
||||
onOpenInviteModal,
|
||||
}: UserContextMenuProps) =>
|
||||
render(
|
||||
<UserContextMenuWithRootOutlet
|
||||
type={type}
|
||||
onClose={onClose}
|
||||
onOpenInviteModal={onOpenInviteModal}
|
||||
/>,
|
||||
{
|
||||
wrapper: ({ children }) => (
|
||||
<MemoryRouter>
|
||||
<QueryClientProvider client={new QueryClient()}>
|
||||
{children}
|
||||
</QueryClientProvider>
|
||||
</MemoryRouter>
|
||||
),
|
||||
});
|
||||
|
||||
const { navigateMock } = vi.hoisted(() => ({
|
||||
navigateMock: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock("react-router", async (importActual) => ({
|
||||
...(await importActual()),
|
||||
useNavigate: () => navigateMock,
|
||||
useRevalidator: () => ({
|
||||
revalidate: vi.fn(),
|
||||
}),
|
||||
}));
|
||||
|
||||
// Mock useIsAuthed to return authenticated state
|
||||
vi.mock("#/hooks/query/use-is-authed", () => ({
|
||||
useIsAuthed: () => ({ data: true }),
|
||||
}));
|
||||
|
||||
const createMockUser = (
|
||||
overrides: Partial<OrganizationMember> = {},
|
||||
): OrganizationMember => ({
|
||||
org_id: "org-1",
|
||||
user_id: "user-1",
|
||||
email: "test@example.com",
|
||||
role: "member",
|
||||
llm_api_key: "",
|
||||
max_iterations: 100,
|
||||
llm_model: "gpt-4",
|
||||
llm_api_key_for_byor: null,
|
||||
llm_base_url: "",
|
||||
status: "active",
|
||||
...overrides,
|
||||
});
|
||||
|
||||
const seedActiveUser = (user: Partial<OrganizationMember>) => {
|
||||
useSelectedOrganizationStore.setState({ organizationId: "org-1" });
|
||||
vi.spyOn(organizationService, "getMe").mockResolvedValue(
|
||||
createMockUser(user),
|
||||
);
|
||||
};
|
||||
|
||||
vi.mock("react-i18next", async () => {
|
||||
const actual =
|
||||
await vi.importActual<typeof import("react-i18next")>("react-i18next");
|
||||
return {
|
||||
...actual,
|
||||
useTranslation: () => ({
|
||||
t: (key: string) => {
|
||||
const translations: Record<string, string> = {
|
||||
ORG$SELECT_ORGANIZATION_PLACEHOLDER: "Please select an organization",
|
||||
ORG$PERSONAL_WORKSPACE: "Personal Workspace",
|
||||
};
|
||||
return translations[key] || key;
|
||||
},
|
||||
i18n: {
|
||||
changeLanguage: vi.fn(),
|
||||
},
|
||||
}),
|
||||
};
|
||||
});
|
||||
|
||||
describe("UserContextMenu", () => {
|
||||
beforeEach(() => {
|
||||
// Ensure clean state at the start of each test
|
||||
vi.restoreAllMocks();
|
||||
useSelectedOrganizationStore.setState({ organizationId: null });
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
navigateMock.mockClear();
|
||||
// Reset Zustand store to ensure clean state between tests
|
||||
useSelectedOrganizationStore.setState({ organizationId: null });
|
||||
});
|
||||
|
||||
it("should render the default context items for a user", () => {
|
||||
renderUserContextMenu({ type: "member", onClose: vi.fn, onOpenInviteModal: vi.fn });
|
||||
|
||||
screen.getByTestId("org-selector");
|
||||
screen.getByText("ACCOUNT_SETTINGS$LOGOUT");
|
||||
|
||||
expect(
|
||||
screen.queryByText("ORG$INVITE_ORG_MEMBERS"),
|
||||
).not.toBeInTheDocument();
|
||||
expect(
|
||||
screen.queryByText("ORG$ORGANIZATION_MEMBERS"),
|
||||
).not.toBeInTheDocument();
|
||||
expect(
|
||||
screen.queryByText("COMMON$ORGANIZATION"),
|
||||
).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should render navigation items from SAAS_NAV_ITEMS (except organization-members/org)", async () => {
|
||||
vi.spyOn(OptionService, "getConfig").mockResolvedValue(
|
||||
createMockWebClientConfig({
|
||||
app_mode: "saas",
|
||||
feature_flags: {
|
||||
enable_billing: true,
|
||||
hide_llm_settings: false,
|
||||
enable_jira: false,
|
||||
enable_jira_dc: false,
|
||||
enable_linear: false,
|
||||
hide_users_page: false,
|
||||
hide_billing_page: false,
|
||||
hide_integrations_page: false,
|
||||
},
|
||||
}),
|
||||
);
|
||||
|
||||
renderUserContextMenu({ type: "member", onClose: vi.fn, onOpenInviteModal: vi.fn });
|
||||
|
||||
// Wait for config to load and verify that navigation items are rendered (except organization-members/org which are filtered out)
|
||||
const expectedItems = SAAS_NAV_ITEMS.filter(
|
||||
(item) =>
|
||||
item.to !== "/settings/org-members" &&
|
||||
item.to !== "/settings/org" &&
|
||||
item.to !== "/settings/billing",
|
||||
);
|
||||
|
||||
await waitFor(() => {
|
||||
expectedItems.forEach((item) => {
|
||||
expect(screen.getByText(item.text)).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
it("should render navigation items from SAAS_NAV_ITEMS when user role is admin (except organization-members/org)", async () => {
|
||||
vi.spyOn(OptionService, "getConfig").mockResolvedValue(
|
||||
createMockWebClientConfig({
|
||||
app_mode: "saas",
|
||||
feature_flags: {
|
||||
enable_billing: true,
|
||||
hide_llm_settings: false,
|
||||
enable_jira: false,
|
||||
enable_jira_dc: false,
|
||||
enable_linear: false,
|
||||
hide_users_page: false,
|
||||
hide_billing_page: false,
|
||||
hide_integrations_page: false,
|
||||
},
|
||||
}),
|
||||
);
|
||||
|
||||
seedActiveUser({ role: "admin" });
|
||||
|
||||
renderUserContextMenu({ type: "admin", onClose: vi.fn, onOpenInviteModal: vi.fn });
|
||||
|
||||
// Wait for config to load and verify that navigation items are rendered (except organization-members/org which are filtered out)
|
||||
const expectedItems = SAAS_NAV_ITEMS.filter(
|
||||
(item) =>
|
||||
item.to !== "/settings/org-members" && item.to !== "/settings/org",
|
||||
);
|
||||
|
||||
await waitFor(() => {
|
||||
expectedItems.forEach((item) => {
|
||||
expect(screen.getByText(item.text)).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
it("should not display Organization Members menu item for regular users (filtered out)", () => {
|
||||
renderUserContextMenu({ type: "member", onClose: vi.fn, onOpenInviteModal: vi.fn });
|
||||
|
||||
// Organization Members is filtered out from nav items for all users
|
||||
expect(screen.queryByText("Organization Members")).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should render a documentation link", () => {
|
||||
renderUserContextMenu({ type: "member", onClose: vi.fn, onOpenInviteModal: vi.fn });
|
||||
|
||||
const docsLink = screen.getByText("SIDEBAR$DOCS").closest("a");
|
||||
expect(docsLink).toHaveAttribute("href", "https://docs.openhands.dev");
|
||||
expect(docsLink).toHaveAttribute("target", "_blank");
|
||||
});
|
||||
|
||||
describe("OSS mode", () => {
|
||||
beforeEach(() => {
|
||||
vi.spyOn(OptionService, "getConfig").mockResolvedValue(
|
||||
createMockWebClientConfig({
|
||||
app_mode: "oss",
|
||||
feature_flags: {
|
||||
enable_billing: false,
|
||||
hide_llm_settings: false,
|
||||
enable_jira: false,
|
||||
enable_jira_dc: false,
|
||||
enable_linear: false,
|
||||
hide_users_page: false,
|
||||
hide_billing_page: false,
|
||||
hide_integrations_page: false,
|
||||
},
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it("should render OSS_NAV_ITEMS when in OSS mode", async () => {
|
||||
renderUserContextMenu({ type: "member", onClose: vi.fn, onOpenInviteModal: vi.fn });
|
||||
|
||||
// Wait for the config to load and OSS nav items to appear
|
||||
await waitFor(() => {
|
||||
OSS_NAV_ITEMS.forEach((item) => {
|
||||
expect(screen.getByText(item.text)).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
// Verify SAAS-only items are NOT rendered (e.g., Billing)
|
||||
expect(
|
||||
screen.queryByText("SETTINGS$NAV_BILLING"),
|
||||
).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should not display Organization Members menu item in OSS mode", async () => {
|
||||
renderUserContextMenu({ type: "member", onClose: vi.fn, onOpenInviteModal: vi.fn });
|
||||
|
||||
// Wait for the config to load
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("SETTINGS$NAV_LLM")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
// Verify Organization Members is NOT rendered in OSS mode
|
||||
expect(
|
||||
screen.queryByText("Organization Members"),
|
||||
).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe("HIDE_LLM_SETTINGS feature flag", () => {
|
||||
it("should hide LLM settings link when HIDE_LLM_SETTINGS is true", async () => {
|
||||
vi.spyOn(OptionService, "getConfig").mockResolvedValue(
|
||||
createMockWebClientConfig({
|
||||
app_mode: "saas",
|
||||
feature_flags: {
|
||||
enable_billing: false,
|
||||
hide_llm_settings: true,
|
||||
enable_jira: false,
|
||||
enable_jira_dc: false,
|
||||
enable_linear: false,
|
||||
hide_users_page: false,
|
||||
hide_billing_page: false,
|
||||
hide_integrations_page: false,
|
||||
},
|
||||
}),
|
||||
);
|
||||
|
||||
renderUserContextMenu({ type: "member", onClose: vi.fn, onOpenInviteModal: vi.fn });
|
||||
|
||||
await waitFor(() => {
|
||||
// Other nav items should still be visible
|
||||
expect(screen.getByText("SETTINGS$NAV_USER")).toBeInTheDocument();
|
||||
// LLM settings (to: "/settings") should NOT be visible
|
||||
expect(
|
||||
screen.queryByText("COMMON$LANGUAGE_MODEL_LLM"),
|
||||
).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
it("should show LLM settings link when HIDE_LLM_SETTINGS is false", async () => {
|
||||
vi.spyOn(OptionService, "getConfig").mockResolvedValue(
|
||||
createMockWebClientConfig({
|
||||
app_mode: "saas",
|
||||
feature_flags: {
|
||||
enable_billing: false,
|
||||
hide_llm_settings: false,
|
||||
enable_jira: false,
|
||||
enable_jira_dc: false,
|
||||
enable_linear: false,
|
||||
hide_users_page: false,
|
||||
hide_billing_page: false,
|
||||
hide_integrations_page: false,
|
||||
},
|
||||
}),
|
||||
);
|
||||
|
||||
renderUserContextMenu({ type: "member", onClose: vi.fn, onOpenInviteModal: vi.fn });
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByText("COMMON$LANGUAGE_MODEL_LLM"),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
it("should render additional context items when user is an admin", () => {
|
||||
renderUserContextMenu({ type: "admin", onClose: vi.fn, onOpenInviteModal: vi.fn });
|
||||
|
||||
screen.getByTestId("org-selector");
|
||||
screen.getByText("ORG$INVITE_ORG_MEMBERS");
|
||||
screen.getByText("ORG$ORGANIZATION_MEMBERS");
|
||||
screen.getByText("COMMON$ORGANIZATION");
|
||||
});
|
||||
|
||||
it("should render additional context items when user is an owner", () => {
|
||||
renderUserContextMenu({ type: "owner", onClose: vi.fn, onOpenInviteModal: vi.fn });
|
||||
|
||||
screen.getByTestId("org-selector");
|
||||
screen.getByText("ORG$INVITE_ORG_MEMBERS");
|
||||
screen.getByText("ORG$ORGANIZATION_MEMBERS");
|
||||
screen.getByText("COMMON$ORGANIZATION");
|
||||
});
|
||||
|
||||
it("should call the logout handler when Logout is clicked", async () => {
|
||||
const logoutSpy = vi.spyOn(AuthService, "logout");
|
||||
renderUserContextMenu({ type: "member", onClose: vi.fn, onOpenInviteModal: vi.fn });
|
||||
|
||||
const logoutButton = screen.getByText("ACCOUNT_SETTINGS$LOGOUT");
|
||||
await userEvent.click(logoutButton);
|
||||
|
||||
expect(logoutSpy).toHaveBeenCalledOnce();
|
||||
});
|
||||
|
||||
it("should have correct navigation links for nav items", async () => {
|
||||
vi.spyOn(OptionService, "getConfig").mockResolvedValue(
|
||||
createMockWebClientConfig({
|
||||
app_mode: "saas",
|
||||
feature_flags: {
|
||||
enable_billing: true, // Enable billing so billing link is shown
|
||||
hide_llm_settings: false,
|
||||
enable_jira: false,
|
||||
enable_jira_dc: false,
|
||||
enable_linear: false,
|
||||
hide_users_page: false,
|
||||
hide_billing_page: false,
|
||||
hide_integrations_page: false,
|
||||
},
|
||||
}),
|
||||
);
|
||||
|
||||
seedActiveUser({ role: "admin" });
|
||||
|
||||
renderUserContextMenu({ type: "admin", onClose: vi.fn, onOpenInviteModal: vi.fn });
|
||||
|
||||
// Wait for config to load and test a few representative nav items have the correct href
|
||||
await waitFor(() => {
|
||||
const userLink = screen.getByText("SETTINGS$NAV_USER").closest("a");
|
||||
expect(userLink).toHaveAttribute("href", "/settings/user");
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
const billingLink = screen.getByText("SETTINGS$NAV_BILLING").closest("a");
|
||||
expect(billingLink).toHaveAttribute("href", "/settings/billing");
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
const integrationsLink = screen
|
||||
.getByText("SETTINGS$NAV_INTEGRATIONS")
|
||||
.closest("a");
|
||||
expect(integrationsLink).toHaveAttribute(
|
||||
"href",
|
||||
"/settings/integrations",
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
it("should navigate to /settings/org-members when Manage Organization Members is clicked", async () => {
|
||||
// Mock a team org so org management buttons are visible (not personal org)
|
||||
vi.spyOn(organizationService, "getOrganizations").mockResolvedValue({
|
||||
items: [MOCK_TEAM_ORG_ACME],
|
||||
currentOrgId: MOCK_TEAM_ORG_ACME.id,
|
||||
});
|
||||
|
||||
renderUserContextMenu({ type: "admin", onClose: vi.fn, onOpenInviteModal: vi.fn });
|
||||
|
||||
// Wait for orgs to load so org management buttons are visible
|
||||
const manageOrganizationMembersButton = await screen.findByText(
|
||||
"ORG$ORGANIZATION_MEMBERS",
|
||||
);
|
||||
await userEvent.click(manageOrganizationMembersButton);
|
||||
|
||||
expect(navigateMock).toHaveBeenCalledExactlyOnceWith(
|
||||
"/settings/org-members",
|
||||
);
|
||||
});
|
||||
|
||||
it("should navigate to /settings/org when Manage Account is clicked", async () => {
|
||||
// Mock a team org so org management buttons are visible (not personal org)
|
||||
vi.spyOn(organizationService, "getOrganizations").mockResolvedValue({
|
||||
items: [MOCK_TEAM_ORG_ACME],
|
||||
currentOrgId: MOCK_TEAM_ORG_ACME.id,
|
||||
});
|
||||
|
||||
renderUserContextMenu({ type: "admin", onClose: vi.fn, onOpenInviteModal: vi.fn });
|
||||
|
||||
// Wait for orgs to load so org management buttons are visible
|
||||
const manageAccountButton = await screen.findByText(
|
||||
"COMMON$ORGANIZATION",
|
||||
);
|
||||
await userEvent.click(manageAccountButton);
|
||||
|
||||
expect(navigateMock).toHaveBeenCalledExactlyOnceWith("/settings/org");
|
||||
});
|
||||
|
||||
it("should call the onClose handler when clicking outside the context menu", async () => {
|
||||
const onCloseMock = vi.fn();
|
||||
renderUserContextMenu({ type: "member", onClose: onCloseMock, onOpenInviteModal: vi.fn });
|
||||
|
||||
const contextMenu = screen.getByTestId("user-context-menu");
|
||||
await userEvent.click(contextMenu);
|
||||
|
||||
expect(onCloseMock).not.toHaveBeenCalled();
|
||||
|
||||
// Simulate clicking outside the context menu
|
||||
await userEvent.click(document.body);
|
||||
|
||||
expect(onCloseMock).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it("should call the onClose handler after each action", async () => {
|
||||
// Mock a team org so org management buttons are visible
|
||||
vi.spyOn(organizationService, "getOrganizations").mockResolvedValue({
|
||||
items: [MOCK_TEAM_ORG_ACME],
|
||||
currentOrgId: MOCK_TEAM_ORG_ACME.id,
|
||||
});
|
||||
|
||||
const onCloseMock = vi.fn();
|
||||
renderUserContextMenu({ type: "owner", onClose: onCloseMock, onOpenInviteModal: vi.fn });
|
||||
|
||||
const logoutButton = screen.getByText("ACCOUNT_SETTINGS$LOGOUT");
|
||||
await userEvent.click(logoutButton);
|
||||
expect(onCloseMock).toHaveBeenCalledTimes(1);
|
||||
|
||||
// Wait for orgs to load so org management buttons are visible
|
||||
const manageOrganizationMembersButton = await screen.findByText(
|
||||
"ORG$ORGANIZATION_MEMBERS",
|
||||
);
|
||||
await userEvent.click(manageOrganizationMembersButton);
|
||||
expect(onCloseMock).toHaveBeenCalledTimes(2);
|
||||
|
||||
const manageAccountButton = screen.getByText("COMMON$ORGANIZATION");
|
||||
await userEvent.click(manageAccountButton);
|
||||
expect(onCloseMock).toHaveBeenCalledTimes(3);
|
||||
});
|
||||
|
||||
describe("Personal org vs team org visibility", () => {
|
||||
it("should not show Organization and Organization Members settings items when personal org is selected", async () => {
|
||||
vi.spyOn(organizationService, "getOrganizations").mockResolvedValue({
|
||||
items: [MOCK_PERSONAL_ORG],
|
||||
currentOrgId: MOCK_PERSONAL_ORG.id,
|
||||
});
|
||||
vi.spyOn(organizationService, "getMe").mockResolvedValue({
|
||||
org_id: "1",
|
||||
user_id: "99",
|
||||
email: "me@test.com",
|
||||
role: "admin",
|
||||
llm_api_key: "**********",
|
||||
max_iterations: 20,
|
||||
llm_model: "gpt-4",
|
||||
llm_api_key_for_byor: null,
|
||||
llm_base_url: "https://api.openai.com",
|
||||
status: "active",
|
||||
});
|
||||
|
||||
// Pre-select the personal org in the Zustand store
|
||||
useSelectedOrganizationStore.setState({ organizationId: "1" });
|
||||
|
||||
renderUserContextMenu({ type: "admin", onClose: vi.fn, onOpenInviteModal: vi.fn });
|
||||
|
||||
// Wait for org selector to load and org management buttons to disappear
|
||||
// (they disappear when personal org is selected)
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.queryByText("ORG$ORGANIZATION_MEMBERS"),
|
||||
).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
expect(
|
||||
screen.queryByText("COMMON$ORGANIZATION"),
|
||||
).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should not show Billing settings item when team org is selected", async () => {
|
||||
vi.spyOn(organizationService, "getOrganizations").mockResolvedValue({
|
||||
items: [MOCK_TEAM_ORG_ACME],
|
||||
currentOrgId: MOCK_TEAM_ORG_ACME.id,
|
||||
});
|
||||
vi.spyOn(organizationService, "getMe").mockResolvedValue({
|
||||
org_id: "1",
|
||||
user_id: "99",
|
||||
email: "me@test.com",
|
||||
role: "admin",
|
||||
llm_api_key: "**********",
|
||||
max_iterations: 20,
|
||||
llm_model: "gpt-4",
|
||||
llm_api_key_for_byor: null,
|
||||
llm_base_url: "https://api.openai.com",
|
||||
status: "active",
|
||||
});
|
||||
|
||||
renderUserContextMenu({ type: "admin", onClose: vi.fn, onOpenInviteModal: vi.fn });
|
||||
|
||||
// Wait for org selector to load and billing to disappear
|
||||
// (billing disappears when team org is selected)
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.queryByText("SETTINGS$NAV_BILLING"),
|
||||
).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
it("should call onOpenInviteModal and onClose when Invite Organization Member is clicked", async () => {
|
||||
// Mock a team org so org management buttons are visible (not personal org)
|
||||
vi.spyOn(organizationService, "getOrganizations").mockResolvedValue({
|
||||
items: [MOCK_TEAM_ORG_ACME],
|
||||
currentOrgId: MOCK_TEAM_ORG_ACME.id,
|
||||
});
|
||||
|
||||
const onCloseMock = vi.fn();
|
||||
const onOpenInviteModalMock = vi.fn();
|
||||
renderUserContextMenu({
|
||||
type: "admin",
|
||||
onClose: onCloseMock,
|
||||
onOpenInviteModal: onOpenInviteModalMock,
|
||||
});
|
||||
|
||||
// Wait for orgs to load so org management buttons are visible
|
||||
const inviteButton = await screen.findByText("ORG$INVITE_ORG_MEMBERS");
|
||||
await userEvent.click(inviteButton);
|
||||
|
||||
expect(onOpenInviteModalMock).toHaveBeenCalledOnce();
|
||||
expect(onCloseMock).toHaveBeenCalledOnce();
|
||||
});
|
||||
|
||||
test("the user can change orgs", async () => {
|
||||
// Mock SaaS mode and organizations for this test
|
||||
vi.spyOn(OptionService, "getConfig").mockResolvedValue(
|
||||
createMockWebClientConfig({
|
||||
app_mode: "saas",
|
||||
feature_flags: {
|
||||
enable_billing: true,
|
||||
hide_llm_settings: false,
|
||||
enable_jira: false,
|
||||
enable_jira_dc: false,
|
||||
enable_linear: false,
|
||||
hide_users_page: false,
|
||||
hide_billing_page: false,
|
||||
hide_integrations_page: false,
|
||||
},
|
||||
}),
|
||||
);
|
||||
vi.spyOn(organizationService, "getOrganizations").mockResolvedValue({
|
||||
items: INITIAL_MOCK_ORGS,
|
||||
currentOrgId: INITIAL_MOCK_ORGS[0].id,
|
||||
});
|
||||
|
||||
const user = userEvent.setup();
|
||||
const onCloseMock = vi.fn();
|
||||
renderUserContextMenu({ type: "member", onClose: onCloseMock, onOpenInviteModal: vi.fn });
|
||||
|
||||
// Wait for org selector to appear (it may take a moment for config to load)
|
||||
const orgSelector = await screen.findByTestId("org-selector");
|
||||
expect(orgSelector).toBeInTheDocument();
|
||||
|
||||
// Wait for organizations to load (indicated by org name appearing in the dropdown)
|
||||
// INITIAL_MOCK_ORGS[0] is a personal org, so it displays "Personal Workspace"
|
||||
await waitFor(() => {
|
||||
expect(screen.getByRole("combobox")).toHaveValue("Personal Workspace");
|
||||
});
|
||||
|
||||
// Open the dropdown by clicking the trigger
|
||||
const trigger = screen.getByTestId("dropdown-trigger");
|
||||
await user.click(trigger);
|
||||
|
||||
// Select a different organization
|
||||
const orgOption = screen.getByRole("option", {
|
||||
name: INITIAL_MOCK_ORGS[1].name,
|
||||
});
|
||||
await user.click(orgOption);
|
||||
|
||||
expect(onCloseMock).not.toHaveBeenCalled();
|
||||
|
||||
// Verify that the dropdown shows the selected organization
|
||||
expect(screen.getByRole("combobox")).toHaveValue(INITIAL_MOCK_ORGS[1].name);
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,219 @@
|
||||
import { render, screen } from "@testing-library/react";
|
||||
import { test, expect, describe, vi } from "vitest";
|
||||
import { useTranslation } from "react-i18next";
|
||||
import translations from "../../src/i18n/translation.json";
|
||||
import { UserAvatar } from "../../src/components/features/sidebar/user-avatar";
|
||||
|
||||
vi.mock("@heroui/react", () => ({
|
||||
Tooltip: ({
|
||||
content,
|
||||
children,
|
||||
}: {
|
||||
content: string;
|
||||
children: React.ReactNode;
|
||||
}) => (
|
||||
<div>
|
||||
{children}
|
||||
<div>{content}</div>
|
||||
</div>
|
||||
),
|
||||
}));
|
||||
|
||||
const supportedLanguages = [
|
||||
"en",
|
||||
"ja",
|
||||
"zh-CN",
|
||||
"zh-TW",
|
||||
"ko-KR",
|
||||
"de",
|
||||
"no",
|
||||
"it",
|
||||
"pt",
|
||||
"es",
|
||||
"ar",
|
||||
"fr",
|
||||
"tr",
|
||||
];
|
||||
|
||||
// Helper function to check if a translation exists for all supported languages
|
||||
function checkTranslationExists(key: string) {
|
||||
const missingTranslations: string[] = [];
|
||||
|
||||
const translationEntry = (
|
||||
translations as Record<string, Record<string, string>>
|
||||
)[key];
|
||||
if (!translationEntry) {
|
||||
throw new Error(
|
||||
`Translation key "${key}" does not exist in translation.json`,
|
||||
);
|
||||
}
|
||||
|
||||
for (const lang of supportedLanguages) {
|
||||
if (!translationEntry[lang]) {
|
||||
missingTranslations.push(lang);
|
||||
}
|
||||
}
|
||||
|
||||
return missingTranslations;
|
||||
}
|
||||
|
||||
// Helper function to find duplicate translation keys
|
||||
function findDuplicateKeys(obj: Record<string, any>) {
|
||||
const seen = new Set<string>();
|
||||
const duplicates = new Set<string>();
|
||||
|
||||
// Only check top-level keys as these are our translation keys
|
||||
for (const key in obj) {
|
||||
if (seen.has(key)) {
|
||||
duplicates.add(key);
|
||||
} else {
|
||||
seen.add(key);
|
||||
}
|
||||
}
|
||||
|
||||
return Array.from(duplicates);
|
||||
}
|
||||
|
||||
vi.mock("react-i18next", () => ({
|
||||
useTranslation: () => ({
|
||||
t: (key: string) => {
|
||||
const translationEntry = (
|
||||
translations as Record<string, Record<string, string>>
|
||||
)[key];
|
||||
return translationEntry?.ja || key;
|
||||
},
|
||||
}),
|
||||
}));
|
||||
|
||||
describe("Landing page translations", () => {
|
||||
test("should render Japanese translations correctly", () => {
|
||||
// Mock a simple component that uses the translations
|
||||
const TestComponent = () => {
|
||||
const { t } = useTranslation();
|
||||
return (
|
||||
<div>
|
||||
<UserAvatar onClick={() => {}} />
|
||||
<div data-testid="main-content">
|
||||
<h1>{t("LANDING$TITLE")}</h1>
|
||||
<button>{t("VSCODE$OPEN")}</button>
|
||||
<button>{t("SUGGESTIONS$INCREASE_TEST_COVERAGE")}</button>
|
||||
<button>{t("SUGGESTIONS$AUTO_MERGE_PRS")}</button>
|
||||
<button>{t("SUGGESTIONS$FIX_README")}</button>
|
||||
<button>{t("SUGGESTIONS$CLEAN_DEPENDENCIES")}</button>
|
||||
</div>
|
||||
<div data-testid="tabs">
|
||||
<span>{t("WORKSPACE$TERMINAL_TAB_LABEL")}</span>
|
||||
<span>{t("WORKSPACE$BROWSER_TAB_LABEL")}</span>
|
||||
<span>{t("WORKSPACE$JUPYTER_TAB_LABEL")}</span>
|
||||
<span>{t("WORKSPACE$CODE_EDITOR_TAB_LABEL")}</span>
|
||||
</div>
|
||||
<div data-testid="workspace-label">{t("WORKSPACE$TITLE")}</div>
|
||||
<button data-testid="new-project">{t("PROJECT$NEW_PROJECT")}</button>
|
||||
<div data-testid="status">
|
||||
<span>{t("TERMINAL$WAITING_FOR_CLIENT")}</span>
|
||||
<span>{t("STATUS$CONNECTED")}</span>
|
||||
<span>{t("STATUS$CONNECTED_TO_SERVER")}</span>
|
||||
</div>
|
||||
<div data-testid="time">
|
||||
<span>{`5 ${t("TIME$MINUTES_AGO")}`}</span>
|
||||
<span>{`2 ${t("TIME$HOURS_AGO")}`}</span>
|
||||
<span>{`3 ${t("TIME$DAYS_AGO")}`}</span>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
render(<TestComponent />);
|
||||
|
||||
// Check main content translations
|
||||
expect(screen.getByText("開発を始めましょう!")).toBeInTheDocument();
|
||||
expect(screen.getByText("VS Codeで開く")).toBeInTheDocument();
|
||||
expect(
|
||||
screen.getByText("テストカバレッジを向上させる"),
|
||||
).toBeInTheDocument();
|
||||
expect(screen.getByText("Dependabot PRを自動マージ")).toBeInTheDocument();
|
||||
expect(screen.getByText("READMEを改善")).toBeInTheDocument();
|
||||
expect(screen.getByText("依存関係を整理")).toBeInTheDocument();
|
||||
|
||||
// Check tab labels
|
||||
const tabs = screen.getByTestId("tabs");
|
||||
expect(tabs).toHaveTextContent("ターミナル");
|
||||
expect(tabs).toHaveTextContent("ブラウザ");
|
||||
expect(tabs).toHaveTextContent("Jupyter");
|
||||
expect(tabs).toHaveTextContent("コードエディタ");
|
||||
|
||||
// Check workspace label and new project button
|
||||
expect(screen.getByTestId("workspace-label")).toHaveTextContent(
|
||||
"ワークスペース",
|
||||
);
|
||||
expect(screen.getByTestId("new-project")).toHaveTextContent(
|
||||
"新規プロジェクト",
|
||||
);
|
||||
|
||||
// Check status messages
|
||||
const status = screen.getByTestId("status");
|
||||
expect(status).toHaveTextContent("クライアントの準備を待機中");
|
||||
expect(status).toHaveTextContent("接続済み");
|
||||
expect(status).toHaveTextContent("サーバーに接続済み");
|
||||
|
||||
// Check time-related translations
|
||||
const time = screen.getByTestId("time");
|
||||
expect(time).toHaveTextContent("5 分前");
|
||||
expect(time).toHaveTextContent("2 時間前");
|
||||
expect(time).toHaveTextContent("3 日前");
|
||||
});
|
||||
|
||||
test("all translation keys should have translations for all supported languages", () => {
|
||||
// Test all translation keys used in the component
|
||||
const translationKeys = [
|
||||
"LANDING$TITLE",
|
||||
"VSCODE$OPEN",
|
||||
"SUGGESTIONS$INCREASE_TEST_COVERAGE",
|
||||
"SUGGESTIONS$AUTO_MERGE_PRS",
|
||||
"SUGGESTIONS$FIX_README",
|
||||
"SUGGESTIONS$CLEAN_DEPENDENCIES",
|
||||
"WORKSPACE$TERMINAL_TAB_LABEL",
|
||||
"WORKSPACE$BROWSER_TAB_LABEL",
|
||||
"WORKSPACE$JUPYTER_TAB_LABEL",
|
||||
"WORKSPACE$CODE_EDITOR_TAB_LABEL",
|
||||
"WORKSPACE$TITLE",
|
||||
"PROJECT$NEW_PROJECT",
|
||||
"TERMINAL$WAITING_FOR_CLIENT",
|
||||
"STATUS$CONNECTED",
|
||||
"STATUS$CONNECTED_TO_SERVER",
|
||||
"TIME$MINUTES_AGO",
|
||||
"TIME$HOURS_AGO",
|
||||
"TIME$DAYS_AGO",
|
||||
];
|
||||
|
||||
// Check all keys and collect missing translations
|
||||
const missingTranslationsMap = new Map<string, string[]>();
|
||||
translationKeys.forEach((key) => {
|
||||
const missing = checkTranslationExists(key);
|
||||
if (missing.length > 0) {
|
||||
missingTranslationsMap.set(key, missing);
|
||||
}
|
||||
});
|
||||
|
||||
// If any translations are missing, throw an error with all missing translations
|
||||
if (missingTranslationsMap.size > 0) {
|
||||
const errorMessage = Array.from(missingTranslationsMap.entries())
|
||||
.map(
|
||||
([key, langs]) =>
|
||||
`\n- "${key}" is missing translations for: ${langs.join(", ")}`,
|
||||
)
|
||||
.join("");
|
||||
throw new Error(`Missing translations:${errorMessage}`);
|
||||
}
|
||||
});
|
||||
|
||||
test("translation file should not have duplicate keys", () => {
|
||||
const duplicates = findDuplicateKeys(translations);
|
||||
|
||||
if (duplicates.length > 0) {
|
||||
throw new Error(
|
||||
`Found duplicate translation keys: ${duplicates.join(", ")}`,
|
||||
);
|
||||
}
|
||||
});
|
||||
});
|
||||
@@ -1,429 +0,0 @@
|
||||
import { render, screen } from "@testing-library/react";
|
||||
import userEvent from "@testing-library/user-event";
|
||||
import { describe, it, expect, vi } from "vitest";
|
||||
import { Dropdown } from "#/ui/dropdown/dropdown";
|
||||
|
||||
const mockOptions = [
|
||||
{ value: "1", label: "Option 1" },
|
||||
{ value: "2", label: "Option 2" },
|
||||
{ value: "3", label: "Option 3" },
|
||||
];
|
||||
|
||||
describe("Dropdown", () => {
|
||||
describe("Trigger", () => {
|
||||
it("should render a custom trigger button", () => {
|
||||
render(<Dropdown options={mockOptions} />);
|
||||
|
||||
const trigger = screen.getByTestId("dropdown-trigger");
|
||||
|
||||
expect(trigger).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should open dropdown on trigger click", async () => {
|
||||
const user = userEvent.setup();
|
||||
render(<Dropdown options={mockOptions} />);
|
||||
|
||||
expect(screen.queryByText("Option 1")).not.toBeInTheDocument();
|
||||
|
||||
const trigger = screen.getByTestId("dropdown-trigger");
|
||||
await user.click(trigger);
|
||||
|
||||
const listbox = screen.getByRole("listbox");
|
||||
expect(listbox).toBeInTheDocument();
|
||||
expect(screen.getByText("Option 1")).toBeInTheDocument();
|
||||
expect(screen.getByText("Option 2")).toBeInTheDocument();
|
||||
expect(screen.getByText("Option 3")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe("Type-ahead / Search", () => {
|
||||
it("should filter options based on input text", async () => {
|
||||
const user = userEvent.setup();
|
||||
render(<Dropdown options={mockOptions} />);
|
||||
|
||||
const trigger = screen.getByTestId("dropdown-trigger");
|
||||
await user.click(trigger);
|
||||
|
||||
const input = screen.getByRole("combobox");
|
||||
await user.type(input, "Option 1");
|
||||
|
||||
expect(screen.getByText("Option 1")).toBeInTheDocument();
|
||||
expect(screen.queryByText("Option 2")).not.toBeInTheDocument();
|
||||
expect(screen.queryByText("Option 3")).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should be case-insensitive by default", async () => {
|
||||
const user = userEvent.setup();
|
||||
render(<Dropdown options={mockOptions} />);
|
||||
|
||||
const trigger = screen.getByTestId("dropdown-trigger");
|
||||
await user.click(trigger);
|
||||
|
||||
const input = screen.getByRole("combobox");
|
||||
await user.type(input, "option 1");
|
||||
|
||||
expect(screen.getByText("Option 1")).toBeInTheDocument();
|
||||
expect(screen.queryByText("Option 2")).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should show all options when search is cleared", async () => {
|
||||
const user = userEvent.setup();
|
||||
render(<Dropdown options={mockOptions} />);
|
||||
|
||||
const trigger = screen.getByTestId("dropdown-trigger");
|
||||
await user.click(trigger);
|
||||
|
||||
const input = screen.getByRole("combobox");
|
||||
await user.type(input, "Option 1");
|
||||
await user.clear(input);
|
||||
|
||||
expect(screen.getByText("Option 1")).toBeInTheDocument();
|
||||
expect(screen.getByText("Option 2")).toBeInTheDocument();
|
||||
expect(screen.getByText("Option 3")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe("Empty state", () => {
|
||||
it("should display empty state when no options provided", async () => {
|
||||
const user = userEvent.setup();
|
||||
render(<Dropdown options={[]} />);
|
||||
|
||||
const trigger = screen.getByTestId("dropdown-trigger");
|
||||
await user.click(trigger);
|
||||
|
||||
expect(screen.getByText("No options")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should render custom empty state message", async () => {
|
||||
const user = userEvent.setup();
|
||||
render(<Dropdown options={[]} emptyMessage="Nothing found" />);
|
||||
|
||||
const trigger = screen.getByTestId("dropdown-trigger");
|
||||
await user.click(trigger);
|
||||
|
||||
expect(screen.getByText("Nothing found")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe("Single selection", () => {
|
||||
it("should select an option on click", async () => {
|
||||
const user = userEvent.setup();
|
||||
render(<Dropdown options={mockOptions} />);
|
||||
|
||||
const trigger = screen.getByTestId("dropdown-trigger");
|
||||
await user.click(trigger);
|
||||
|
||||
const option = screen.getByText("Option 1");
|
||||
await user.click(option);
|
||||
|
||||
expect(screen.getByRole("combobox")).toHaveValue("Option 1");
|
||||
});
|
||||
|
||||
it("should close dropdown after selection", async () => {
|
||||
const user = userEvent.setup();
|
||||
render(<Dropdown options={mockOptions} />);
|
||||
|
||||
const trigger = screen.getByTestId("dropdown-trigger");
|
||||
await user.click(trigger);
|
||||
await user.click(screen.getByText("Option 1"));
|
||||
|
||||
expect(screen.queryByText("Option 2")).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should display selected option in input", async () => {
|
||||
const user = userEvent.setup();
|
||||
render(<Dropdown options={mockOptions} />);
|
||||
|
||||
const trigger = screen.getByTestId("dropdown-trigger");
|
||||
await user.click(trigger);
|
||||
await user.click(screen.getByText("Option 1"));
|
||||
|
||||
expect(screen.getByRole("combobox")).toHaveValue("Option 1");
|
||||
});
|
||||
|
||||
it("should highlight currently selected option in list", async () => {
|
||||
const user = userEvent.setup();
|
||||
render(<Dropdown options={mockOptions} />);
|
||||
|
||||
const trigger = screen.getByTestId("dropdown-trigger");
|
||||
await user.click(trigger);
|
||||
await user.click(screen.getByRole("option", { name: "Option 1" }));
|
||||
|
||||
await user.click(trigger);
|
||||
|
||||
const selectedOption = screen.getByRole("option", { name: "Option 1" });
|
||||
expect(selectedOption).toHaveAttribute("aria-selected", "true");
|
||||
});
|
||||
|
||||
it("should preserve selected value in input and show all options when reopening dropdown", async () => {
|
||||
const user = userEvent.setup();
|
||||
render(<Dropdown options={mockOptions} />);
|
||||
|
||||
const trigger = screen.getByTestId("dropdown-trigger");
|
||||
await user.click(trigger);
|
||||
await user.click(screen.getByRole("option", { name: "Option 1" }));
|
||||
|
||||
// Reopen the dropdown
|
||||
await user.click(trigger);
|
||||
|
||||
const input = screen.getByRole("combobox");
|
||||
expect(input).toHaveValue("Option 1");
|
||||
expect(
|
||||
screen.getByRole("option", { name: "Option 1" }),
|
||||
).toBeInTheDocument();
|
||||
expect(
|
||||
screen.getByRole("option", { name: "Option 2" }),
|
||||
).toBeInTheDocument();
|
||||
expect(
|
||||
screen.getByRole("option", { name: "Option 3" }),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe("Clear button", () => {
|
||||
it("should not render clear button by default", () => {
|
||||
render(<Dropdown options={mockOptions} />);
|
||||
|
||||
expect(screen.queryByTestId("dropdown-clear")).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should render clear button when clearable prop is true and has value", async () => {
|
||||
const user = userEvent.setup();
|
||||
render(<Dropdown options={mockOptions} clearable />);
|
||||
|
||||
const trigger = screen.getByTestId("dropdown-trigger");
|
||||
await user.click(trigger);
|
||||
await user.click(screen.getByRole("option", { name: "Option 1" }));
|
||||
|
||||
expect(screen.getByTestId("dropdown-clear")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should clear selection and search input when clear button is clicked", async () => {
|
||||
const user = userEvent.setup();
|
||||
render(<Dropdown options={mockOptions} clearable />);
|
||||
|
||||
const trigger = screen.getByTestId("dropdown-trigger");
|
||||
await user.click(trigger);
|
||||
await user.click(screen.getByRole("option", { name: "Option 1" }));
|
||||
|
||||
const clearButton = screen.getByTestId("dropdown-clear");
|
||||
await user.click(clearButton);
|
||||
|
||||
expect(screen.getByRole("combobox")).toHaveValue("");
|
||||
});
|
||||
|
||||
it("should not render clear button when there is no selection", () => {
|
||||
render(<Dropdown options={mockOptions} clearable />);
|
||||
|
||||
expect(screen.queryByTestId("dropdown-clear")).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should show placeholder after clearing selection", async () => {
|
||||
const user = userEvent.setup();
|
||||
render(
|
||||
<Dropdown
|
||||
options={mockOptions}
|
||||
clearable
|
||||
placeholder="Select an option"
|
||||
/>,
|
||||
);
|
||||
|
||||
const trigger = screen.getByTestId("dropdown-trigger");
|
||||
await user.click(trigger);
|
||||
await user.click(screen.getByRole("option", { name: "Option 1" }));
|
||||
|
||||
const clearButton = screen.getByTestId("dropdown-clear");
|
||||
await user.click(clearButton);
|
||||
|
||||
const input = screen.getByRole("combobox");
|
||||
expect(input).toHaveValue("");
|
||||
});
|
||||
});
|
||||
|
||||
describe("Loading state", () => {
|
||||
it("should not display loading indicator by default", () => {
|
||||
render(<Dropdown options={mockOptions} />);
|
||||
|
||||
expect(screen.queryByTestId("dropdown-loading")).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should display loading indicator when loading prop is true", () => {
|
||||
render(<Dropdown options={mockOptions} loading />);
|
||||
|
||||
expect(screen.getByTestId("dropdown-loading")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should disable interaction while loading", async () => {
|
||||
const user = userEvent.setup();
|
||||
render(<Dropdown options={mockOptions} loading />);
|
||||
|
||||
const trigger = screen.getByTestId("dropdown-trigger");
|
||||
await user.click(trigger);
|
||||
|
||||
expect(trigger).toHaveAttribute("aria-expanded", "false");
|
||||
});
|
||||
});
|
||||
|
||||
describe("Disabled state", () => {
|
||||
it("should not open dropdown when disabled", async () => {
|
||||
const user = userEvent.setup();
|
||||
render(<Dropdown options={mockOptions} disabled />);
|
||||
|
||||
const trigger = screen.getByTestId("dropdown-trigger");
|
||||
await user.click(trigger);
|
||||
|
||||
expect(trigger).toHaveAttribute("aria-expanded", "false");
|
||||
});
|
||||
|
||||
it("should have disabled attribute on trigger", () => {
|
||||
render(<Dropdown options={mockOptions} disabled />);
|
||||
|
||||
const trigger = screen.getByTestId("dropdown-trigger");
|
||||
expect(trigger).toBeDisabled();
|
||||
});
|
||||
});
|
||||
|
||||
describe("Placeholder", () => {
|
||||
it("should display placeholder text when no value selected", () => {
|
||||
render(<Dropdown options={mockOptions} placeholder="Select an option" />);
|
||||
|
||||
const input = screen.getByRole("combobox");
|
||||
expect(input).toHaveAttribute("placeholder", "Select an option");
|
||||
});
|
||||
});
|
||||
|
||||
describe("Default value", () => {
|
||||
it("should display defaultValue in input on mount", () => {
|
||||
render(<Dropdown options={mockOptions} defaultValue={mockOptions[0]} />);
|
||||
|
||||
const input = screen.getByRole("combobox");
|
||||
expect(input).toHaveValue("Option 1");
|
||||
});
|
||||
|
||||
it("should show all options when opened with defaultValue", async () => {
|
||||
const user = userEvent.setup();
|
||||
render(<Dropdown options={mockOptions} defaultValue={mockOptions[0]} />);
|
||||
|
||||
const trigger = screen.getByTestId("dropdown-trigger");
|
||||
await user.click(trigger);
|
||||
|
||||
expect(
|
||||
screen.getByRole("option", { name: "Option 1" }),
|
||||
).toBeInTheDocument();
|
||||
expect(
|
||||
screen.getByRole("option", { name: "Option 2" }),
|
||||
).toBeInTheDocument();
|
||||
expect(
|
||||
screen.getByRole("option", { name: "Option 3" }),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should restore input value when closed with Escape", async () => {
|
||||
const user = userEvent.setup();
|
||||
render(<Dropdown options={mockOptions} defaultValue={mockOptions[0]} />);
|
||||
|
||||
const trigger = screen.getByTestId("dropdown-trigger");
|
||||
await user.click(trigger);
|
||||
|
||||
const input = screen.getByRole("combobox");
|
||||
await user.type(input, "test");
|
||||
await user.keyboard("{Escape}");
|
||||
|
||||
expect(input).toHaveValue("Option 1");
|
||||
});
|
||||
});
|
||||
|
||||
describe("onChange", () => {
|
||||
it("should call onChange with selected item when option is clicked", async () => {
|
||||
const user = userEvent.setup();
|
||||
const onChangeMock = vi.fn();
|
||||
render(<Dropdown options={mockOptions} onChange={onChangeMock} />);
|
||||
|
||||
const trigger = screen.getByTestId("dropdown-trigger");
|
||||
await user.click(trigger);
|
||||
await user.click(screen.getByRole("option", { name: "Option 1" }));
|
||||
|
||||
expect(onChangeMock).toHaveBeenCalledWith(mockOptions[0]);
|
||||
});
|
||||
|
||||
it("should call onChange with null when selection is cleared", async () => {
|
||||
const user = userEvent.setup();
|
||||
const onChangeMock = vi.fn();
|
||||
render(
|
||||
<Dropdown
|
||||
options={mockOptions}
|
||||
clearable
|
||||
defaultValue={mockOptions[0]}
|
||||
onChange={onChangeMock}
|
||||
/>,
|
||||
);
|
||||
|
||||
const clearButton = screen.getByTestId("dropdown-clear");
|
||||
await user.click(clearButton);
|
||||
|
||||
expect(onChangeMock).toHaveBeenCalledWith(null);
|
||||
});
|
||||
});
|
||||
|
||||
describe("Controlled mode", () => {
|
||||
it.todo("should reflect external value changes");
|
||||
it.todo("should call onChange when selection changes");
|
||||
it.todo("should not update internal state when controlled");
|
||||
});
|
||||
|
||||
describe("Uncontrolled mode", () => {
|
||||
it.todo("should manage selection state internally");
|
||||
it.todo("should call onChange when selection changes");
|
||||
it.todo("should support defaultValue prop");
|
||||
});
|
||||
|
||||
describe("testId prop", () => {
|
||||
it("should apply custom testId to the root container", () => {
|
||||
render(<Dropdown options={mockOptions} testId="org-dropdown" />);
|
||||
|
||||
expect(screen.getByTestId("org-dropdown")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe("Cursor position preservation", () => {
|
||||
it("should keep menu open when clicking the input while dropdown is open", async () => {
|
||||
// Without a stateReducer, Downshift's default InputClick behavior
|
||||
// toggles the menu (closes it if already open). The stateReducer
|
||||
// should override this to keep the menu open so users can click
|
||||
// to reposition their cursor without losing the dropdown.
|
||||
const user = userEvent.setup();
|
||||
render(<Dropdown options={mockOptions} />);
|
||||
|
||||
const trigger = screen.getByTestId("dropdown-trigger");
|
||||
await user.click(trigger);
|
||||
|
||||
// Menu should be open
|
||||
expect(screen.getByText("Option 1")).toBeInTheDocument();
|
||||
|
||||
// Click the input itself (simulates clicking to reposition cursor)
|
||||
const input = screen.getByRole("combobox");
|
||||
await user.click(input);
|
||||
|
||||
// Menu should still be open — not toggled closed
|
||||
expect(screen.getByText("Option 1")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should still filter options correctly after typing with cursor fix", async () => {
|
||||
// Verifies that the direct onChange handler (which bypasses Downshift's
|
||||
// default onInputValueChange for cursor preservation) still updates
|
||||
// the search/filter state correctly.
|
||||
const user = userEvent.setup();
|
||||
render(<Dropdown options={mockOptions} />);
|
||||
|
||||
const trigger = screen.getByTestId("dropdown-trigger");
|
||||
await user.click(trigger);
|
||||
|
||||
const input = screen.getByRole("combobox");
|
||||
await user.type(input, "Option 1");
|
||||
|
||||
expect(screen.getByText("Option 1")).toBeInTheDocument();
|
||||
expect(screen.queryByText("Option 2")).not.toBeInTheDocument();
|
||||
expect(screen.queryByText("Option 3")).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,64 +1,11 @@
|
||||
import { render, screen, waitFor } from "@testing-library/react";
|
||||
import { describe, expect, it, vi, afterEach, beforeEach, test } from "vitest";
|
||||
import { render, screen } from "@testing-library/react";
|
||||
import { describe, expect, it, test, vi, afterEach, beforeEach } from "vitest";
|
||||
import userEvent from "@testing-library/user-event";
|
||||
import { QueryClientProvider, QueryClient } from "@tanstack/react-query";
|
||||
import { UserActions } from "#/components/features/sidebar/user-actions";
|
||||
import { MemoryRouter } from "react-router";
|
||||
import { ReactElement } from "react";
|
||||
import { UserActions } from "#/components/features/sidebar/user-actions";
|
||||
import { organizationService } from "#/api/organization-service/organization-service.api";
|
||||
import { MOCK_PERSONAL_ORG, MOCK_TEAM_ORG_ACME } from "#/mocks/org-handlers";
|
||||
import { useSelectedOrganizationStore } from "#/stores/selected-organization-store";
|
||||
import { renderWithProviders } from "../../test-utils";
|
||||
|
||||
vi.mock("react-router", async (importActual) => ({
|
||||
...(await importActual()),
|
||||
useNavigate: () => vi.fn(),
|
||||
useRevalidator: () => ({
|
||||
revalidate: vi.fn(),
|
||||
}),
|
||||
}));
|
||||
|
||||
vi.mock("react-i18next", async () => {
|
||||
const actual =
|
||||
await vi.importActual<typeof import("react-i18next")>("react-i18next");
|
||||
return {
|
||||
...actual,
|
||||
useTranslation: () => ({
|
||||
t: (key: string) => {
|
||||
const translations: Record<string, string> = {
|
||||
ORG$SELECT_ORGANIZATION_PLACEHOLDER: "Please select an organization",
|
||||
ORG$PERSONAL_WORKSPACE: "Personal Workspace",
|
||||
};
|
||||
return translations[key] || key;
|
||||
},
|
||||
i18n: {
|
||||
changeLanguage: vi.fn(),
|
||||
},
|
||||
}),
|
||||
};
|
||||
});
|
||||
|
||||
const renderUserActions = (props = { hasAvatar: true }) => {
|
||||
render(
|
||||
<UserActions
|
||||
user={
|
||||
props.hasAvatar
|
||||
? { avatar_url: "https://example.com/avatar.png" }
|
||||
: undefined
|
||||
}
|
||||
/>,
|
||||
{
|
||||
wrapper: ({ children }) => (
|
||||
<MemoryRouter>
|
||||
<QueryClientProvider client={new QueryClient()}>
|
||||
{children}
|
||||
</QueryClientProvider>
|
||||
</MemoryRouter>
|
||||
),
|
||||
},
|
||||
);
|
||||
};
|
||||
|
||||
// Create mocks for all the hooks we need
|
||||
const useIsAuthedMock = vi
|
||||
.fn()
|
||||
@@ -91,8 +38,9 @@ describe("UserActions", () => {
|
||||
const onLogoutMock = vi.fn();
|
||||
|
||||
// Create a wrapper with MemoryRouter and renderWithProviders
|
||||
const renderWithRouter = (ui: ReactElement) =>
|
||||
renderWithProviders(<MemoryRouter>{ui}</MemoryRouter>);
|
||||
const renderWithRouter = (ui: ReactElement) => {
|
||||
return renderWithProviders(<MemoryRouter>{ui}</MemoryRouter>);
|
||||
};
|
||||
|
||||
beforeEach(() => {
|
||||
// Reset all mocks to default values before each test
|
||||
@@ -113,11 +61,29 @@ describe("UserActions", () => {
|
||||
});
|
||||
|
||||
it("should render", () => {
|
||||
renderUserActions();
|
||||
renderWithRouter(<UserActions onLogout={onLogoutMock} />);
|
||||
|
||||
expect(screen.getByTestId("user-actions")).toBeInTheDocument();
|
||||
expect(screen.getByTestId("user-avatar")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should call onLogout and close the menu when the logout option is clicked", async () => {
|
||||
renderWithRouter(
|
||||
<UserActions
|
||||
onLogout={onLogoutMock}
|
||||
user={{ avatar_url: "https://example.com/avatar.png" }}
|
||||
/>,
|
||||
);
|
||||
|
||||
const userAvatar = screen.getByTestId("user-avatar");
|
||||
await user.click(userAvatar);
|
||||
|
||||
const logoutOption = screen.getByText("ACCOUNT_SETTINGS$LOGOUT");
|
||||
await user.click(logoutOption);
|
||||
|
||||
expect(onLogoutMock).toHaveBeenCalledOnce();
|
||||
});
|
||||
|
||||
it("should NOT show context menu when user is not authenticated and avatar is clicked", async () => {
|
||||
// Set isAuthed to false for this test
|
||||
useIsAuthedMock.mockReturnValue({ data: false, isLoading: false });
|
||||
@@ -130,31 +96,29 @@ describe("UserActions", () => {
|
||||
providers: [{ id: "github", name: "GitHub" }],
|
||||
});
|
||||
|
||||
renderUserActions();
|
||||
renderWithRouter(<UserActions onLogout={onLogoutMock} />);
|
||||
|
||||
const userAvatar = screen.getByTestId("user-avatar");
|
||||
await user.click(userAvatar);
|
||||
|
||||
// Context menu should NOT appear because user is not authenticated
|
||||
expect(screen.queryByTestId("user-context-menu")).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should NOT show context menu when user is undefined and avatar is hovered", async () => {
|
||||
renderUserActions({ hasAvatar: false });
|
||||
const userActions = screen.getByTestId("user-actions");
|
||||
await user.hover(userActions);
|
||||
|
||||
// Context menu should NOT appear because user is undefined
|
||||
expect(screen.queryByTestId("user-context-menu")).not.toBeInTheDocument();
|
||||
expect(
|
||||
screen.queryByTestId("account-settings-context-menu"),
|
||||
).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should show context menu even when user has no avatar_url", async () => {
|
||||
renderUserActions();
|
||||
const userActions = screen.getByTestId("user-actions");
|
||||
await user.hover(userActions);
|
||||
renderWithRouter(
|
||||
<UserActions onLogout={onLogoutMock} user={{ avatar_url: "" }} />,
|
||||
);
|
||||
|
||||
const userAvatar = screen.getByTestId("user-avatar");
|
||||
await user.click(userAvatar);
|
||||
|
||||
// Context menu SHOULD appear because user object exists (even with empty avatar_url)
|
||||
expect(screen.getByTestId("user-context-menu")).toBeInTheDocument();
|
||||
expect(
|
||||
screen.getByTestId("account-settings-context-menu"),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should NOT be able to access logout when user is not authenticated", async () => {
|
||||
@@ -169,13 +133,15 @@ describe("UserActions", () => {
|
||||
providers: [{ id: "github", name: "GitHub" }],
|
||||
});
|
||||
|
||||
renderWithRouter(<UserActions />);
|
||||
renderWithRouter(<UserActions onLogout={onLogoutMock} />);
|
||||
|
||||
const userAvatar = screen.getByTestId("user-avatar");
|
||||
await user.click(userAvatar);
|
||||
|
||||
// Context menu should NOT appear because user is not authenticated
|
||||
expect(screen.queryByTestId("user-context-menu")).not.toBeInTheDocument();
|
||||
expect(
|
||||
screen.queryByTestId("account-settings-context-menu"),
|
||||
).not.toBeInTheDocument();
|
||||
|
||||
// Logout option should NOT be accessible when user is not authenticated
|
||||
expect(
|
||||
@@ -195,12 +161,16 @@ describe("UserActions", () => {
|
||||
providers: [{ id: "github", name: "GitHub" }],
|
||||
});
|
||||
|
||||
const { unmount } = renderWithRouter(<UserActions />);
|
||||
const { unmount } = renderWithRouter(
|
||||
<UserActions onLogout={onLogoutMock} />,
|
||||
);
|
||||
|
||||
// Initially no user and not authenticated - menu should not appear
|
||||
const userActions = screen.getByTestId("user-actions");
|
||||
await user.hover(userActions);
|
||||
expect(screen.queryByTestId("user-context-menu")).not.toBeInTheDocument();
|
||||
let userAvatar = screen.getByTestId("user-avatar");
|
||||
await user.click(userAvatar);
|
||||
expect(
|
||||
screen.queryByTestId("account-settings-context-menu"),
|
||||
).not.toBeInTheDocument();
|
||||
|
||||
// Unmount the first component
|
||||
unmount();
|
||||
@@ -218,7 +188,10 @@ describe("UserActions", () => {
|
||||
|
||||
// Render a new component with user prop and authentication
|
||||
renderWithRouter(
|
||||
<UserActions user={{ avatar_url: "https://example.com/avatar.png" }} />,
|
||||
<UserActions
|
||||
onLogout={onLogoutMock}
|
||||
user={{ avatar_url: "https://example.com/avatar.png" }}
|
||||
/>,
|
||||
);
|
||||
|
||||
// Component should render correctly
|
||||
@@ -226,10 +199,12 @@ describe("UserActions", () => {
|
||||
expect(screen.getByTestId("user-avatar")).toBeInTheDocument();
|
||||
|
||||
// Menu should now work with user defined and authenticated
|
||||
const userActionsEl = screen.getByTestId("user-actions");
|
||||
await user.hover(userActionsEl);
|
||||
userAvatar = screen.getByTestId("user-avatar");
|
||||
await user.click(userAvatar);
|
||||
|
||||
expect(screen.getByTestId("user-context-menu")).toBeInTheDocument();
|
||||
expect(
|
||||
screen.getByTestId("account-settings-context-menu"),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should handle user prop changing from defined to undefined", async () => {
|
||||
@@ -244,13 +219,18 @@ describe("UserActions", () => {
|
||||
});
|
||||
|
||||
const { rerender } = renderWithRouter(
|
||||
<UserActions user={{ avatar_url: "https://example.com/avatar.png" }} />,
|
||||
<UserActions
|
||||
onLogout={onLogoutMock}
|
||||
user={{ avatar_url: "https://example.com/avatar.png" }}
|
||||
/>,
|
||||
);
|
||||
|
||||
// Hover to open menu
|
||||
const userActions = screen.getByTestId("user-actions");
|
||||
await user.hover(userActions);
|
||||
expect(screen.getByTestId("user-context-menu")).toBeInTheDocument();
|
||||
// Click to open menu
|
||||
const userAvatar = screen.getByTestId("user-avatar");
|
||||
await user.click(userAvatar);
|
||||
expect(
|
||||
screen.getByTestId("account-settings-context-menu"),
|
||||
).toBeInTheDocument();
|
||||
|
||||
// Set authentication to false for the rerender
|
||||
useIsAuthedMock.mockReturnValue({ data: false, isLoading: false });
|
||||
@@ -266,12 +246,14 @@ describe("UserActions", () => {
|
||||
// Remove user prop - menu should disappear because user is no longer authenticated
|
||||
rerender(
|
||||
<MemoryRouter>
|
||||
<UserActions />
|
||||
<UserActions onLogout={onLogoutMock} />
|
||||
</MemoryRouter>,
|
||||
);
|
||||
|
||||
// Context menu should NOT be visible when user becomes unauthenticated
|
||||
expect(screen.queryByTestId("user-context-menu")).not.toBeInTheDocument();
|
||||
expect(
|
||||
screen.queryByTestId("account-settings-context-menu"),
|
||||
).not.toBeInTheDocument();
|
||||
|
||||
// Logout option should not be accessible
|
||||
expect(
|
||||
@@ -290,168 +272,20 @@ describe("UserActions", () => {
|
||||
providers: [{ id: "github", name: "GitHub" }],
|
||||
});
|
||||
|
||||
renderUserActions();
|
||||
const userActions = screen.getByTestId("user-actions");
|
||||
await user.hover(userActions);
|
||||
renderWithRouter(
|
||||
<UserActions
|
||||
onLogout={onLogoutMock}
|
||||
user={{ avatar_url: "https://example.com/avatar.png" }}
|
||||
isLoading={true}
|
||||
/>,
|
||||
);
|
||||
|
||||
const userAvatar = screen.getByTestId("user-avatar");
|
||||
await user.click(userAvatar);
|
||||
|
||||
// Context menu should still appear even when loading
|
||||
expect(screen.getByTestId("user-context-menu")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("context menu should default to user role", async () => {
|
||||
renderUserActions();
|
||||
const userActions = screen.getByTestId("user-actions");
|
||||
await user.hover(userActions);
|
||||
|
||||
// Verify logout is present
|
||||
expect(screen.getByTestId("user-context-menu")).toHaveTextContent(
|
||||
"ACCOUNT_SETTINGS$LOGOUT",
|
||||
);
|
||||
// Verify nav items are present (e.g., settings nav items)
|
||||
expect(screen.getByTestId("user-context-menu")).toHaveTextContent(
|
||||
"SETTINGS$NAV_USER",
|
||||
);
|
||||
// Verify admin-only items are NOT present for user role
|
||||
expect(
|
||||
screen.queryByText("ORG$MANAGE_ORGANIZATION_MEMBERS"),
|
||||
).not.toBeInTheDocument();
|
||||
expect(
|
||||
screen.queryByText("ORG$MANAGE_ORGANIZATION"),
|
||||
).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("should NOT show Team and Organization nav items when personal workspace is selected", async () => {
|
||||
renderUserActions();
|
||||
const userActions = screen.getByTestId("user-actions");
|
||||
await user.hover(userActions);
|
||||
|
||||
// Team and Organization nav links should NOT be visible when no org is selected (personal workspace)
|
||||
expect(screen.queryByText("Team")).not.toBeInTheDocument();
|
||||
expect(screen.queryByText("Organization")).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should show context menu on hover", async () => {
|
||||
renderUserActions();
|
||||
|
||||
const userActions = screen.getByTestId("user-actions");
|
||||
const contextMenu = screen.getByTestId("user-context-menu");
|
||||
|
||||
// Menu is in DOM but hidden via CSS (opacity-0, pointer-events-none)
|
||||
expect(contextMenu.parentElement).toHaveClass("opacity-0");
|
||||
expect(contextMenu.parentElement).toHaveClass("pointer-events-none");
|
||||
|
||||
// Hover over the user actions area
|
||||
await user.hover(userActions);
|
||||
|
||||
// Menu should be visible on hover (CSS classes change via group-hover)
|
||||
expect(contextMenu).toBeVisible();
|
||||
});
|
||||
|
||||
it("should have pointer-events-none on hover bridge pseudo-element to allow menu item clicks", async () => {
|
||||
renderUserActions();
|
||||
|
||||
const userActions = screen.getByTestId("user-actions");
|
||||
await user.hover(userActions);
|
||||
|
||||
const contextMenu = screen.getByTestId("user-context-menu");
|
||||
const hoverBridgeContainer = contextMenu.parentElement;
|
||||
|
||||
// The hover bridge uses a ::before pseudo-element for diagonal mouse movement
|
||||
// This pseudo-element MUST have pointer-events-none to allow clicks through to menu items
|
||||
// The class should include "before:pointer-events-none" to prevent the hover bridge from blocking clicks
|
||||
expect(hoverBridgeContainer?.className).toContain(
|
||||
"before:pointer-events-none",
|
||||
);
|
||||
});
|
||||
|
||||
describe("Org selector dropdown state reset when context menu hides", () => {
|
||||
// These tests verify that the org selector dropdown resets its internal
|
||||
// state (search text, open/closed) when the context menu hides and
|
||||
// reappears. Without this, stale state persists because the context
|
||||
// menu is hidden via CSS (opacity/pointer-events) rather than unmounted.
|
||||
|
||||
beforeEach(() => {
|
||||
vi.spyOn(organizationService, "getOrganizations").mockResolvedValue({
|
||||
items: [MOCK_PERSONAL_ORG, MOCK_TEAM_ORG_ACME],
|
||||
currentOrgId: MOCK_PERSONAL_ORG.id,
|
||||
});
|
||||
useSelectedOrganizationStore.setState({ organizationId: null });
|
||||
});
|
||||
|
||||
it("should reset org selector search text when context menu hides and reappears", async () => {
|
||||
renderUserActions();
|
||||
const userActions = screen.getByTestId("user-actions");
|
||||
|
||||
// Hover to show context menu
|
||||
await user.hover(userActions);
|
||||
|
||||
// Wait for orgs to load and auto-select
|
||||
await waitFor(() => {
|
||||
expect(screen.getByRole("combobox")).toHaveValue(
|
||||
MOCK_PERSONAL_ORG.name,
|
||||
);
|
||||
});
|
||||
|
||||
// Open dropdown and type search text
|
||||
const trigger = screen.getByTestId("dropdown-trigger");
|
||||
await user.click(trigger);
|
||||
const input = screen.getByRole("combobox");
|
||||
await user.clear(input);
|
||||
await user.type(input, "search text");
|
||||
expect(input).toHaveValue("search text");
|
||||
|
||||
// Unhover to hide context menu, then hover again
|
||||
await user.unhover(userActions);
|
||||
await user.hover(userActions);
|
||||
|
||||
// Org selector should be reset — showing selected org name, not search text
|
||||
await waitFor(() => {
|
||||
expect(screen.getByRole("combobox")).toHaveValue(
|
||||
MOCK_PERSONAL_ORG.name,
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
it("should reset dropdown to collapsed state when context menu hides and reappears", async () => {
|
||||
renderUserActions();
|
||||
const userActions = screen.getByTestId("user-actions");
|
||||
|
||||
// Hover to show context menu
|
||||
await user.hover(userActions);
|
||||
|
||||
// Wait for orgs to load
|
||||
await waitFor(() => {
|
||||
expect(screen.getByRole("combobox")).toHaveValue(
|
||||
MOCK_PERSONAL_ORG.name,
|
||||
);
|
||||
});
|
||||
|
||||
// Open dropdown and type to change its state
|
||||
const trigger = screen.getByTestId("dropdown-trigger");
|
||||
await user.click(trigger);
|
||||
const input = screen.getByRole("combobox");
|
||||
await user.clear(input);
|
||||
await user.type(input, "Acme");
|
||||
expect(input).toHaveValue("Acme");
|
||||
|
||||
// Unhover to hide context menu, then hover again
|
||||
await user.unhover(userActions);
|
||||
await user.hover(userActions);
|
||||
|
||||
// Wait for fresh component with org data
|
||||
await waitFor(() => {
|
||||
expect(screen.getByRole("combobox")).toHaveValue(
|
||||
MOCK_PERSONAL_ORG.name,
|
||||
);
|
||||
});
|
||||
|
||||
// Dropdown should be collapsed (closed) after reset
|
||||
expect(screen.getByTestId("dropdown-trigger")).toHaveAttribute(
|
||||
"aria-expanded",
|
||||
"false",
|
||||
);
|
||||
// No option elements should be rendered
|
||||
expect(screen.queryAllByRole("option")).toHaveLength(0);
|
||||
});
|
||||
screen.getByTestId("account-settings-context-menu"),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,18 +1,40 @@
|
||||
import { render, screen } from "@testing-library/react";
|
||||
import { describe, expect, it } from "vitest";
|
||||
import userEvent from "@testing-library/user-event";
|
||||
import { afterEach, describe, expect, it, vi } from "vitest";
|
||||
import { UserAvatar } from "#/components/features/sidebar/user-avatar";
|
||||
|
||||
describe("UserAvatar", () => {
|
||||
const onClickMock = vi.fn();
|
||||
|
||||
afterEach(() => {
|
||||
onClickMock.mockClear();
|
||||
});
|
||||
|
||||
it("(default) should render the placeholder avatar when the user is logged out", () => {
|
||||
render(<UserAvatar />);
|
||||
render(<UserAvatar onClick={onClickMock} />);
|
||||
expect(screen.getByTestId("user-avatar")).toBeInTheDocument();
|
||||
expect(
|
||||
screen.getByLabelText("USER$AVATAR_PLACEHOLDER"),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should call onClick when clicked", async () => {
|
||||
const user = userEvent.setup();
|
||||
render(<UserAvatar onClick={onClickMock} />);
|
||||
|
||||
const userAvatarContainer = screen.getByTestId("user-avatar");
|
||||
await user.click(userAvatarContainer);
|
||||
|
||||
expect(onClickMock).toHaveBeenCalledOnce();
|
||||
});
|
||||
|
||||
it("should display the user's avatar when available", () => {
|
||||
render(<UserAvatar avatarUrl="https://example.com/avatar.png" />);
|
||||
render(
|
||||
<UserAvatar
|
||||
onClick={onClickMock}
|
||||
avatarUrl="https://example.com/avatar.png"
|
||||
/>,
|
||||
);
|
||||
|
||||
expect(screen.getByAltText("AVATAR$ALT_TEXT")).toBeInTheDocument();
|
||||
expect(
|
||||
@@ -21,20 +43,24 @@ describe("UserAvatar", () => {
|
||||
});
|
||||
|
||||
it("should display a loading spinner instead of an avatar when isLoading is true", () => {
|
||||
const { rerender } = render(<UserAvatar />);
|
||||
const { rerender } = render(<UserAvatar onClick={onClickMock} />);
|
||||
expect(screen.queryByTestId("loading-spinner")).not.toBeInTheDocument();
|
||||
expect(
|
||||
screen.getByLabelText("USER$AVATAR_PLACEHOLDER"),
|
||||
).toBeInTheDocument();
|
||||
|
||||
rerender(<UserAvatar isLoading />);
|
||||
rerender(<UserAvatar onClick={onClickMock} isLoading />);
|
||||
expect(screen.getByTestId("loading-spinner")).toBeInTheDocument();
|
||||
expect(
|
||||
screen.queryByLabelText("USER$AVATAR_PLACEHOLDER"),
|
||||
).not.toBeInTheDocument();
|
||||
|
||||
rerender(
|
||||
<UserAvatar avatarUrl="https://example.com/avatar.png" isLoading />,
|
||||
<UserAvatar
|
||||
onClick={onClickMock}
|
||||
avatarUrl="https://example.com/avatar.png"
|
||||
isLoading
|
||||
/>,
|
||||
);
|
||||
expect(screen.getByTestId("loading-spinner")).toBeInTheDocument();
|
||||
expect(screen.queryByAltText("AVATAR$ALT_TEXT")).not.toBeInTheDocument();
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
import { WebClientConfig } from "#/api/option-service/option.types";
|
||||
|
||||
/**
|
||||
* Creates a mock WebClientConfig with all required fields.
|
||||
* Use this helper to create test config objects with sensible defaults.
|
||||
*/
|
||||
export const createMockWebClientConfig = (
|
||||
overrides: Partial<WebClientConfig> = {},
|
||||
): WebClientConfig => ({
|
||||
app_mode: "oss",
|
||||
posthog_client_key: "test-posthog-key",
|
||||
feature_flags: {
|
||||
enable_billing: false,
|
||||
hide_llm_settings: false,
|
||||
enable_jira: false,
|
||||
enable_jira_dc: false,
|
||||
enable_linear: false,
|
||||
hide_users_page: false,
|
||||
hide_billing_page: false,
|
||||
hide_integrations_page: false,
|
||||
...overrides.feature_flags,
|
||||
},
|
||||
providers_configured: [],
|
||||
maintenance_start_time: null,
|
||||
auth_url: null,
|
||||
recaptcha_site_key: null,
|
||||
faulty_models: [],
|
||||
error_message: null,
|
||||
updated_at: new Date().toISOString(),
|
||||
github_app_slug: null,
|
||||
...overrides,
|
||||
});
|
||||
@@ -1,45 +0,0 @@
|
||||
import { renderHook, waitFor } from "@testing-library/react";
|
||||
import { describe, expect, it, vi, beforeEach } from "vitest";
|
||||
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
||||
import { useInviteMembersBatch } from "#/hooks/mutation/use-invite-members-batch";
|
||||
import { useSelectedOrganizationStore } from "#/stores/selected-organization-store";
|
||||
|
||||
// Mock the useRevalidator hook from react-router
|
||||
vi.mock("react-router", () => ({
|
||||
useRevalidator: () => ({
|
||||
revalidate: vi.fn(),
|
||||
}),
|
||||
}));
|
||||
|
||||
describe("useInviteMembersBatch", () => {
|
||||
beforeEach(() => {
|
||||
useSelectedOrganizationStore.setState({ organizationId: null });
|
||||
});
|
||||
|
||||
it("should throw an error when organizationId is null", async () => {
|
||||
const queryClient = new QueryClient({
|
||||
defaultOptions: {
|
||||
mutations: {
|
||||
retry: false,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
const { result } = renderHook(() => useInviteMembersBatch(), {
|
||||
wrapper: ({ children }) => (
|
||||
<QueryClientProvider client={queryClient}>{children}</QueryClientProvider>
|
||||
),
|
||||
});
|
||||
|
||||
// Attempt to mutate without organizationId
|
||||
result.current.mutate({ emails: ["test@example.com"] });
|
||||
|
||||
// Should fail with an error about missing organizationId
|
||||
await waitFor(() => {
|
||||
expect(result.current.isError).toBe(true);
|
||||
});
|
||||
|
||||
expect(result.current.error).toBeInstanceOf(Error);
|
||||
expect(result.current.error?.message).toBe("Organization ID is required");
|
||||
});
|
||||
});
|
||||
@@ -1,45 +0,0 @@
|
||||
import { renderHook, waitFor } from "@testing-library/react";
|
||||
import { describe, expect, it, vi, beforeEach } from "vitest";
|
||||
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
||||
import { useRemoveMember } from "#/hooks/mutation/use-remove-member";
|
||||
import { useSelectedOrganizationStore } from "#/stores/selected-organization-store";
|
||||
|
||||
// Mock the useRevalidator hook from react-router
|
||||
vi.mock("react-router", () => ({
|
||||
useRevalidator: () => ({
|
||||
revalidate: vi.fn(),
|
||||
}),
|
||||
}));
|
||||
|
||||
describe("useRemoveMember", () => {
|
||||
beforeEach(() => {
|
||||
useSelectedOrganizationStore.setState({ organizationId: null });
|
||||
});
|
||||
|
||||
it("should throw an error when organizationId is null", async () => {
|
||||
const queryClient = new QueryClient({
|
||||
defaultOptions: {
|
||||
mutations: {
|
||||
retry: false,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
const { result } = renderHook(() => useRemoveMember(), {
|
||||
wrapper: ({ children }) => (
|
||||
<QueryClientProvider client={queryClient}>{children}</QueryClientProvider>
|
||||
),
|
||||
});
|
||||
|
||||
// Attempt to mutate without organizationId
|
||||
result.current.mutate({ userId: "user-123" });
|
||||
|
||||
// Should fail with an error about missing organizationId
|
||||
await waitFor(() => {
|
||||
expect(result.current.isError).toBe(true);
|
||||
});
|
||||
|
||||
expect(result.current.error).toBeInstanceOf(Error);
|
||||
expect(result.current.error?.message).toBe("Organization ID is required");
|
||||
});
|
||||
});
|
||||
@@ -1,174 +0,0 @@
|
||||
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
||||
import { renderHook, waitFor } from "@testing-library/react";
|
||||
import { describe, expect, it, vi, beforeEach } from "vitest";
|
||||
import { organizationService } from "#/api/organization-service/organization-service.api";
|
||||
import { useOrganizations } from "#/hooks/query/use-organizations";
|
||||
import type { Organization } from "#/types/org";
|
||||
|
||||
vi.mock("#/api/organization-service/organization-service.api", () => ({
|
||||
organizationService: {
|
||||
getOrganizations: vi.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
// Mock useIsAuthed to return authenticated
|
||||
vi.mock("#/hooks/query/use-is-authed", () => ({
|
||||
useIsAuthed: () => ({ data: true }),
|
||||
}));
|
||||
|
||||
// Mock useConfig to return SaaS mode (organizations are a SaaS-only feature)
|
||||
vi.mock("#/hooks/query/use-config", () => ({
|
||||
useConfig: () => ({ data: { app_mode: "saas" } }),
|
||||
}));
|
||||
|
||||
const mockGetOrganizations = vi.mocked(organizationService.getOrganizations);
|
||||
|
||||
function createMinimalOrg(
|
||||
id: string,
|
||||
name: string,
|
||||
is_personal?: boolean,
|
||||
): Organization {
|
||||
return {
|
||||
id,
|
||||
name,
|
||||
is_personal,
|
||||
contact_name: "",
|
||||
contact_email: "",
|
||||
conversation_expiration: 0,
|
||||
agent: "",
|
||||
default_max_iterations: 0,
|
||||
security_analyzer: "",
|
||||
confirmation_mode: false,
|
||||
default_llm_model: "",
|
||||
default_llm_api_key_for_byor: "",
|
||||
default_llm_base_url: "",
|
||||
remote_runtime_resource_factor: 0,
|
||||
enable_default_condenser: false,
|
||||
billing_margin: 0,
|
||||
enable_proactive_conversation_starters: false,
|
||||
sandbox_base_container_image: "",
|
||||
sandbox_runtime_container_image: "",
|
||||
org_version: 0,
|
||||
mcp_config: { tools: [], settings: {} },
|
||||
search_api_key: null,
|
||||
sandbox_api_key: null,
|
||||
max_budget_per_task: 0,
|
||||
enable_solvability_analysis: false,
|
||||
v1_enabled: false,
|
||||
credits: 0,
|
||||
};
|
||||
}
|
||||
|
||||
describe("useOrganizations", () => {
|
||||
let queryClient: QueryClient;
|
||||
|
||||
const wrapper = ({ children }: { children: React.ReactNode }) => (
|
||||
<QueryClientProvider client={queryClient}>{children}</QueryClientProvider>
|
||||
);
|
||||
|
||||
beforeEach(() => {
|
||||
queryClient = new QueryClient({
|
||||
defaultOptions: {
|
||||
queries: { retry: false },
|
||||
},
|
||||
});
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
it("sorts personal workspace first, then non-personal alphabetically by name", async () => {
|
||||
// API returns unsorted: Beta, Personal, Acme, All Hands
|
||||
mockGetOrganizations.mockResolvedValue({
|
||||
items: [
|
||||
createMinimalOrg("3", "Beta LLC", false),
|
||||
createMinimalOrg("1", "Personal Workspace", true),
|
||||
createMinimalOrg("2", "Acme Corp", false),
|
||||
createMinimalOrg("4", "All Hands AI", false),
|
||||
],
|
||||
currentOrgId: "1",
|
||||
});
|
||||
|
||||
const { result } = renderHook(() => useOrganizations(), { wrapper });
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.isSuccess).toBe(true);
|
||||
});
|
||||
|
||||
const { organizations } = result.current.data!;
|
||||
expect(organizations).toHaveLength(4);
|
||||
expect(organizations[0].id).toBe("1");
|
||||
expect(organizations[0].is_personal).toBe(true);
|
||||
expect(organizations[0].name).toBe("Personal Workspace");
|
||||
expect(organizations[1].name).toBe("Acme Corp");
|
||||
expect(organizations[2].name).toBe("All Hands AI");
|
||||
expect(organizations[3].name).toBe("Beta LLC");
|
||||
});
|
||||
|
||||
it("treats missing is_personal as false and sorts by name", async () => {
|
||||
mockGetOrganizations.mockResolvedValue({
|
||||
items: [
|
||||
createMinimalOrg("1", "Zebra Org"), // no is_personal
|
||||
createMinimalOrg("2", "Alpha Org", true), // personal first
|
||||
createMinimalOrg("3", "Mango Org"), // no is_personal
|
||||
],
|
||||
currentOrgId: "2",
|
||||
});
|
||||
|
||||
const { result } = renderHook(() => useOrganizations(), { wrapper });
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.isSuccess).toBe(true);
|
||||
});
|
||||
|
||||
const { organizations } = result.current.data!;
|
||||
expect(organizations[0].id).toBe("2");
|
||||
expect(organizations[0].is_personal).toBe(true);
|
||||
expect(organizations[1].name).toBe("Mango Org");
|
||||
expect(organizations[2].name).toBe("Zebra Org");
|
||||
});
|
||||
|
||||
it("handles missing name by treating as empty string for sort", async () => {
|
||||
const orgWithName = createMinimalOrg("2", "Beta", false);
|
||||
const orgNoName = { ...createMinimalOrg("1", "Alpha", false) };
|
||||
delete (orgNoName as Record<string, unknown>).name;
|
||||
mockGetOrganizations.mockResolvedValue({
|
||||
items: [orgWithName, orgNoName] as Organization[],
|
||||
currentOrgId: "1",
|
||||
});
|
||||
|
||||
const { result } = renderHook(() => useOrganizations(), { wrapper });
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.isSuccess).toBe(true);
|
||||
});
|
||||
|
||||
const { organizations } = result.current.data!;
|
||||
// undefined name is coerced to ""; "" sorts before "Beta"
|
||||
expect(organizations[0].id).toBe("1");
|
||||
expect(organizations[1].id).toBe("2");
|
||||
expect(organizations[1].name).toBe("Beta");
|
||||
});
|
||||
|
||||
it("does not mutate the original array from the API", async () => {
|
||||
const apiOrgs = [
|
||||
createMinimalOrg("2", "Acme", false),
|
||||
createMinimalOrg("1", "Personal", true),
|
||||
];
|
||||
mockGetOrganizations.mockResolvedValue({
|
||||
items: apiOrgs,
|
||||
currentOrgId: "1",
|
||||
});
|
||||
|
||||
const { result } = renderHook(() => useOrganizations(), { wrapper });
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.isSuccess).toBe(true);
|
||||
});
|
||||
|
||||
// Hook sorts a copy ([...data]), so API order unchanged
|
||||
expect(apiOrgs[0].id).toBe("2");
|
||||
expect(apiOrgs[1].id).toBe("1");
|
||||
// Returned data is sorted
|
||||
expect(result.current.data!.organizations[0].id).toBe("1");
|
||||
expect(result.current.data!.organizations[1].id).toBe("2");
|
||||
});
|
||||
});
|
||||
@@ -1,134 +0,0 @@
|
||||
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
||||
import { renderHook, waitFor } from "@testing-library/react";
|
||||
import { describe, it, expect, vi, beforeEach } from "vitest";
|
||||
import { useOrgTypeAndAccess } from "#/hooks/use-org-type-and-access";
|
||||
|
||||
// Mock the dependencies
|
||||
vi.mock("#/context/use-selected-organization", () => ({
|
||||
useSelectedOrganizationId: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock("#/hooks/query/use-organizations", () => ({
|
||||
useOrganizations: vi.fn(),
|
||||
}));
|
||||
|
||||
// Import mocked modules
|
||||
import { useSelectedOrganizationId } from "#/context/use-selected-organization";
|
||||
import { useOrganizations } from "#/hooks/query/use-organizations";
|
||||
|
||||
const mockUseSelectedOrganizationId = vi.mocked(useSelectedOrganizationId);
|
||||
const mockUseOrganizations = vi.mocked(useOrganizations);
|
||||
|
||||
const queryClient = new QueryClient();
|
||||
const wrapper = ({ children }: { children: React.ReactNode }) => (
|
||||
<QueryClientProvider client={queryClient}>{children}</QueryClientProvider>
|
||||
);
|
||||
|
||||
describe("useOrgTypeAndAccess", () => {
|
||||
beforeEach(() => {
|
||||
queryClient.clear();
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
it("should return false for all booleans when no organization is selected", async () => {
|
||||
mockUseSelectedOrganizationId.mockReturnValue({
|
||||
organizationId: null,
|
||||
setOrganizationId: vi.fn(),
|
||||
});
|
||||
mockUseOrganizations.mockReturnValue({
|
||||
data: { organizations: [], currentOrgId: null },
|
||||
} as unknown as ReturnType<typeof useOrganizations>);
|
||||
|
||||
const { result } = renderHook(() => useOrgTypeAndAccess(), { wrapper });
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.selectedOrg).toBeUndefined();
|
||||
expect(result.current.isPersonalOrg).toBe(false);
|
||||
expect(result.current.isTeamOrg).toBe(false);
|
||||
expect(result.current.canViewOrgRoutes).toBe(false);
|
||||
expect(result.current.organizationId).toBeNull();
|
||||
});
|
||||
});
|
||||
|
||||
it("should return isPersonalOrg=true and isTeamOrg=false for personal org", async () => {
|
||||
const personalOrg = { id: "org-1", is_personal: true, name: "Personal" };
|
||||
mockUseSelectedOrganizationId.mockReturnValue({
|
||||
organizationId: "org-1",
|
||||
setOrganizationId: vi.fn(),
|
||||
});
|
||||
mockUseOrganizations.mockReturnValue({
|
||||
data: { organizations: [personalOrg], currentOrgId: "org-1" },
|
||||
} as unknown as ReturnType<typeof useOrganizations>);
|
||||
|
||||
const { result } = renderHook(() => useOrgTypeAndAccess(), { wrapper });
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.selectedOrg).toEqual(personalOrg);
|
||||
expect(result.current.isPersonalOrg).toBe(true);
|
||||
expect(result.current.isTeamOrg).toBe(false);
|
||||
expect(result.current.canViewOrgRoutes).toBe(false);
|
||||
expect(result.current.organizationId).toBe("org-1");
|
||||
});
|
||||
});
|
||||
|
||||
it("should return isPersonalOrg=false and isTeamOrg=true for team org", async () => {
|
||||
const teamOrg = { id: "org-2", is_personal: false, name: "Team" };
|
||||
mockUseSelectedOrganizationId.mockReturnValue({
|
||||
organizationId: "org-2",
|
||||
setOrganizationId: vi.fn(),
|
||||
});
|
||||
mockUseOrganizations.mockReturnValue({
|
||||
data: { organizations: [teamOrg], currentOrgId: "org-2" },
|
||||
} as unknown as ReturnType<typeof useOrganizations>);
|
||||
|
||||
const { result } = renderHook(() => useOrgTypeAndAccess(), { wrapper });
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.selectedOrg).toEqual(teamOrg);
|
||||
expect(result.current.isPersonalOrg).toBe(false);
|
||||
expect(result.current.isTeamOrg).toBe(true);
|
||||
expect(result.current.canViewOrgRoutes).toBe(true);
|
||||
expect(result.current.organizationId).toBe("org-2");
|
||||
});
|
||||
});
|
||||
|
||||
it("should return canViewOrgRoutes=true only when isTeamOrg AND organizationId is truthy", async () => {
|
||||
const teamOrg = { id: "org-3", is_personal: false, name: "Team" };
|
||||
mockUseSelectedOrganizationId.mockReturnValue({
|
||||
organizationId: "org-3",
|
||||
setOrganizationId: vi.fn(),
|
||||
});
|
||||
mockUseOrganizations.mockReturnValue({
|
||||
data: { organizations: [teamOrg], currentOrgId: "org-3" },
|
||||
} as unknown as ReturnType<typeof useOrganizations>);
|
||||
|
||||
const { result } = renderHook(() => useOrgTypeAndAccess(), { wrapper });
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.isTeamOrg).toBe(true);
|
||||
expect(result.current.organizationId).toBe("org-3");
|
||||
expect(result.current.canViewOrgRoutes).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
it("should treat undefined is_personal field as team org", async () => {
|
||||
// Organization without is_personal field (undefined)
|
||||
const orgWithoutPersonalField = { id: "org-4", name: "Unknown Type" };
|
||||
mockUseSelectedOrganizationId.mockReturnValue({
|
||||
organizationId: "org-4",
|
||||
setOrganizationId: vi.fn(),
|
||||
});
|
||||
mockUseOrganizations.mockReturnValue({
|
||||
data: { organizations: [orgWithoutPersonalField], currentOrgId: "org-4" },
|
||||
} as unknown as ReturnType<typeof useOrganizations>);
|
||||
|
||||
const { result } = renderHook(() => useOrgTypeAndAccess(), { wrapper });
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.selectedOrg).toEqual(orgWithoutPersonalField);
|
||||
expect(result.current.isPersonalOrg).toBe(false);
|
||||
expect(result.current.isTeamOrg).toBe(true);
|
||||
expect(result.current.canViewOrgRoutes).toBe(true);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,98 +0,0 @@
|
||||
import { describe, it, expect } from "vitest";
|
||||
import { renderHook } from "@testing-library/react";
|
||||
import { usePermission } from "#/hooks/organizations/use-permissions";
|
||||
import { rolePermissions } from "#/utils/org/permissions";
|
||||
import { OrganizationUserRole } from "#/types/org";
|
||||
|
||||
describe("usePermission", () => {
|
||||
const setup = (role: OrganizationUserRole) =>
|
||||
renderHook(() => usePermission(role)).result.current;
|
||||
|
||||
describe("hasPermission", () => {
|
||||
it("returns true when the role has the permission", () => {
|
||||
const { hasPermission } = setup("admin");
|
||||
|
||||
expect(hasPermission("invite_user_to_organization")).toBe(true);
|
||||
});
|
||||
|
||||
it("returns false when the role does not have the permission", () => {
|
||||
const { hasPermission } = setup("member");
|
||||
|
||||
expect(hasPermission("invite_user_to_organization")).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe("rolePermissions integration", () => {
|
||||
it("matches the permissions defined for the role", () => {
|
||||
const { hasPermission } = setup("member");
|
||||
|
||||
rolePermissions.member.forEach((permission) => {
|
||||
expect(hasPermission(permission)).toBe(true);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("change_user_role permission behavior", () => {
|
||||
const run = (
|
||||
activeUserRole: OrganizationUserRole,
|
||||
targetUserId: string,
|
||||
targetRole: OrganizationUserRole,
|
||||
activeUserId = "123",
|
||||
) => {
|
||||
const { hasPermission } = renderHook(() =>
|
||||
usePermission(activeUserRole),
|
||||
).result.current;
|
||||
|
||||
// users can't change their own roles
|
||||
if (activeUserId === targetUserId) return false;
|
||||
|
||||
return hasPermission(`change_user_role:${targetRole}`);
|
||||
};
|
||||
|
||||
describe("member role", () => {
|
||||
it("cannot change any roles", () => {
|
||||
expect(run("member", "u2", "member")).toBe(false);
|
||||
expect(run("member", "u2", "admin")).toBe(false);
|
||||
expect(run("member", "u2", "owner")).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe("admin role", () => {
|
||||
it("cannot change owner role", () => {
|
||||
expect(run("admin", "u2", "owner")).toBe(false);
|
||||
});
|
||||
|
||||
it("can change member or admin roles", () => {
|
||||
expect(run("admin", "u2", "member")).toBe(
|
||||
rolePermissions.admin.includes("change_user_role:member")
|
||||
);
|
||||
expect(run("admin", "u2", "admin")).toBe(
|
||||
rolePermissions.admin.includes("change_user_role:admin")
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe("owner role", () => {
|
||||
it("can change owner, admin, and member roles", () => {
|
||||
expect(run("owner", "u2", "admin")).toBe(
|
||||
rolePermissions.owner.includes("change_user_role:admin"),
|
||||
);
|
||||
|
||||
expect(run("owner", "u2", "member")).toBe(
|
||||
rolePermissions.owner.includes("change_user_role:member"),
|
||||
);
|
||||
|
||||
expect(run("owner", "u2", "owner")).toBe(
|
||||
rolePermissions.owner.includes("change_user_role:owner"),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe("self role change", () => {
|
||||
it("is always disallowed", () => {
|
||||
expect(run("owner", "u2", "member", "u2")).toBe(false);
|
||||
expect(run("admin", "u2", "member", "u2")).toBe(false);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -6,54 +6,18 @@ import OptionService from "#/api/option-service/option-service.api";
|
||||
import { useSettingsNavItems } from "#/hooks/use-settings-nav-items";
|
||||
import { WebClientFeatureFlags } from "#/api/option-service/option.types";
|
||||
|
||||
// Mock useOrgTypeAndAccess
|
||||
const mockOrgTypeAndAccess = vi.hoisted(() => ({
|
||||
isPersonalOrg: false,
|
||||
isTeamOrg: false,
|
||||
organizationId: null as string | null,
|
||||
selectedOrg: null,
|
||||
canViewOrgRoutes: false,
|
||||
}));
|
||||
|
||||
vi.mock("#/hooks/use-org-type-and-access", () => ({
|
||||
useOrgTypeAndAccess: () => mockOrgTypeAndAccess,
|
||||
}));
|
||||
|
||||
// Mock useMe
|
||||
const mockMe = vi.hoisted(() => ({
|
||||
data: null as { role: string } | null | undefined,
|
||||
}));
|
||||
|
||||
vi.mock("#/hooks/query/use-me", () => ({
|
||||
useMe: () => mockMe,
|
||||
}));
|
||||
|
||||
const queryClient = new QueryClient();
|
||||
const wrapper = ({ children }: { children: React.ReactNode }) => (
|
||||
<QueryClientProvider client={queryClient}>{children}</QueryClientProvider>
|
||||
);
|
||||
|
||||
const mockConfig = (
|
||||
appMode: "saas" | "oss",
|
||||
hideLlmSettings = false,
|
||||
enableBilling = true,
|
||||
) => {
|
||||
const mockConfig = (appMode: "saas" | "oss", hideLlmSettings = false) => {
|
||||
vi.spyOn(OptionService, "getConfig").mockResolvedValue({
|
||||
app_mode: appMode,
|
||||
feature_flags: {
|
||||
hide_llm_settings: hideLlmSettings,
|
||||
enable_billing: enableBilling,
|
||||
enable_jira: false,
|
||||
enable_jira_dc: false,
|
||||
enable_linear: false,
|
||||
},
|
||||
feature_flags: { hide_llm_settings: hideLlmSettings },
|
||||
} as Awaited<ReturnType<typeof OptionService.getConfig>>);
|
||||
};
|
||||
|
||||
vi.mock("react-router", () => ({
|
||||
useRevalidator: () => ({ revalidate: vi.fn() }),
|
||||
}));
|
||||
|
||||
const mockConfigWithFeatureFlags = (
|
||||
appMode: "saas" | "oss",
|
||||
featureFlags: Partial<WebClientFeatureFlags>,
|
||||
@@ -61,7 +25,7 @@ const mockConfigWithFeatureFlags = (
|
||||
vi.spyOn(OptionService, "getConfig").mockResolvedValue({
|
||||
app_mode: appMode,
|
||||
feature_flags: {
|
||||
enable_billing: true, // Enable billing by default so it's not hidden
|
||||
enable_billing: false,
|
||||
hide_llm_settings: false,
|
||||
enable_jira: false,
|
||||
enable_jira_dc: false,
|
||||
@@ -77,38 +41,19 @@ const mockConfigWithFeatureFlags = (
|
||||
describe("useSettingsNavItems", () => {
|
||||
beforeEach(() => {
|
||||
queryClient.clear();
|
||||
vi.restoreAllMocks();
|
||||
// Reset mock state
|
||||
mockOrgTypeAndAccess.isPersonalOrg = false;
|
||||
mockOrgTypeAndAccess.isTeamOrg = false;
|
||||
mockOrgTypeAndAccess.organizationId = null;
|
||||
mockOrgTypeAndAccess.selectedOrg = null;
|
||||
mockOrgTypeAndAccess.canViewOrgRoutes = false;
|
||||
mockMe.data = null;
|
||||
});
|
||||
|
||||
it("should return SAAS_NAV_ITEMS minus billing/org/org-members when userRole is 'member'", async () => {
|
||||
it("should return SAAS_NAV_ITEMS when app_mode is 'saas'", async () => {
|
||||
mockConfig("saas");
|
||||
mockMe.data = { role: "member" };
|
||||
mockOrgTypeAndAccess.organizationId = "org-1";
|
||||
|
||||
const { result } = renderHook(() => useSettingsNavItems(), { wrapper });
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current).toEqual(
|
||||
SAAS_NAV_ITEMS.filter(
|
||||
(item) =>
|
||||
item.to !== "/settings/billing" &&
|
||||
item.to !== "/settings/org" &&
|
||||
item.to !== "/settings/org-members",
|
||||
),
|
||||
);
|
||||
expect(result.current).toEqual(SAAS_NAV_ITEMS);
|
||||
});
|
||||
});
|
||||
|
||||
it("should return OSS_NAV_ITEMS when app_mode is 'oss'", async () => {
|
||||
mockConfig("oss");
|
||||
mockMe.data = { role: "admin" };
|
||||
const { result } = renderHook(() => useSettingsNavItems(), { wrapper });
|
||||
|
||||
await waitFor(() => {
|
||||
@@ -118,8 +63,6 @@ describe("useSettingsNavItems", () => {
|
||||
|
||||
it("should filter out '/settings' item when hide_llm_settings feature flag is enabled", async () => {
|
||||
mockConfig("saas", true);
|
||||
mockMe.data = { role: "admin" };
|
||||
mockOrgTypeAndAccess.organizationId = "org-1";
|
||||
const { result } = renderHook(() => useSettingsNavItems(), { wrapper });
|
||||
|
||||
await waitFor(() => {
|
||||
@@ -129,163 +72,7 @@ describe("useSettingsNavItems", () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe("org-type and role-based filtering", () => {
|
||||
it("should include org routes by default for team org admin", async () => {
|
||||
mockConfig("saas");
|
||||
mockOrgTypeAndAccess.isTeamOrg = true;
|
||||
mockOrgTypeAndAccess.organizationId = "org-123";
|
||||
mockMe.data = { role: "admin" };
|
||||
|
||||
const { result } = renderHook(() => useSettingsNavItems(), { wrapper });
|
||||
|
||||
// Wait for config to load (check that any SAAS item is present)
|
||||
await waitFor(() => {
|
||||
expect(result.current.length).toBeGreaterThan(0);
|
||||
expect(
|
||||
result.current.find((item) => item.to === "/settings/user"),
|
||||
).toBeDefined();
|
||||
});
|
||||
|
||||
// Org routes should be included for team org admin
|
||||
expect(
|
||||
result.current.find((item) => item.to === "/settings/org"),
|
||||
).toBeDefined();
|
||||
expect(
|
||||
result.current.find((item) => item.to === "/settings/org-members"),
|
||||
).toBeDefined();
|
||||
});
|
||||
|
||||
it("should hide org routes when isPersonalOrg is true", async () => {
|
||||
mockConfig("saas");
|
||||
mockOrgTypeAndAccess.isPersonalOrg = true;
|
||||
mockOrgTypeAndAccess.organizationId = "org-123";
|
||||
mockMe.data = { role: "admin" };
|
||||
|
||||
const { result } = renderHook(() => useSettingsNavItems(), { wrapper });
|
||||
|
||||
// Wait for config to load (check that any SAAS item is present)
|
||||
await waitFor(() => {
|
||||
expect(result.current.length).toBeGreaterThan(0);
|
||||
expect(
|
||||
result.current.find((item) => item.to === "/settings/user"),
|
||||
).toBeDefined();
|
||||
});
|
||||
|
||||
// Org routes should be filtered out for personal orgs
|
||||
expect(
|
||||
result.current.find((item) => item.to === "/settings/org"),
|
||||
).toBeUndefined();
|
||||
expect(
|
||||
result.current.find((item) => item.to === "/settings/org-members"),
|
||||
).toBeUndefined();
|
||||
});
|
||||
|
||||
it("should hide org routes when user role is member", async () => {
|
||||
mockConfig("saas");
|
||||
mockOrgTypeAndAccess.isTeamOrg = true;
|
||||
mockOrgTypeAndAccess.organizationId = "org-123";
|
||||
mockMe.data = { role: "member" };
|
||||
|
||||
const { result } = renderHook(() => useSettingsNavItems(), { wrapper });
|
||||
|
||||
// Wait for config to load
|
||||
await waitFor(() => {
|
||||
expect(result.current.length).toBeGreaterThan(0);
|
||||
expect(
|
||||
result.current.find((item) => item.to === "/settings/user"),
|
||||
).toBeDefined();
|
||||
});
|
||||
|
||||
// Org routes should be hidden for members
|
||||
expect(
|
||||
result.current.find((item) => item.to === "/settings/org"),
|
||||
).toBeUndefined();
|
||||
expect(
|
||||
result.current.find((item) => item.to === "/settings/org-members"),
|
||||
).toBeUndefined();
|
||||
});
|
||||
|
||||
it("should hide org routes when no organization is selected", async () => {
|
||||
mockConfig("saas");
|
||||
mockOrgTypeAndAccess.isTeamOrg = false;
|
||||
mockOrgTypeAndAccess.isPersonalOrg = false;
|
||||
mockOrgTypeAndAccess.organizationId = null;
|
||||
mockMe.data = { role: "admin" };
|
||||
|
||||
const { result } = renderHook(() => useSettingsNavItems(), { wrapper });
|
||||
|
||||
// Wait for config to load
|
||||
await waitFor(() => {
|
||||
expect(result.current.length).toBeGreaterThan(0);
|
||||
expect(
|
||||
result.current.find((item) => item.to === "/settings/user"),
|
||||
).toBeDefined();
|
||||
});
|
||||
|
||||
// Org routes should be hidden when no org is selected
|
||||
expect(
|
||||
result.current.find((item) => item.to === "/settings/org"),
|
||||
).toBeUndefined();
|
||||
expect(
|
||||
result.current.find((item) => item.to === "/settings/org-members"),
|
||||
).toBeUndefined();
|
||||
});
|
||||
|
||||
it("should hide billing route when isTeamOrg is true", async () => {
|
||||
mockConfig("saas");
|
||||
mockOrgTypeAndAccess.isTeamOrg = true;
|
||||
mockOrgTypeAndAccess.organizationId = "org-123";
|
||||
mockMe.data = { role: "admin" };
|
||||
|
||||
const { result } = renderHook(() => useSettingsNavItems(), { wrapper });
|
||||
|
||||
// Wait for config to load
|
||||
await waitFor(() => {
|
||||
expect(result.current.length).toBeGreaterThan(0);
|
||||
expect(
|
||||
result.current.find((item) => item.to === "/settings/user"),
|
||||
).toBeDefined();
|
||||
});
|
||||
|
||||
// Billing should be hidden for team orgs
|
||||
expect(
|
||||
result.current.find((item) => item.to === "/settings/billing"),
|
||||
).toBeUndefined();
|
||||
});
|
||||
|
||||
it("should show billing route for personal org", async () => {
|
||||
mockConfig("saas");
|
||||
mockOrgTypeAndAccess.isPersonalOrg = true;
|
||||
mockOrgTypeAndAccess.isTeamOrg = false;
|
||||
mockOrgTypeAndAccess.organizationId = "org-123";
|
||||
mockMe.data = { role: "admin" };
|
||||
|
||||
const { result } = renderHook(() => useSettingsNavItems(), { wrapper });
|
||||
|
||||
// Wait for config to load
|
||||
await waitFor(() => {
|
||||
expect(result.current.length).toBeGreaterThan(0);
|
||||
expect(
|
||||
result.current.find((item) => item.to === "/settings/user"),
|
||||
).toBeDefined();
|
||||
});
|
||||
|
||||
// Billing should be visible for personal orgs
|
||||
expect(
|
||||
result.current.find((item) => item.to === "/settings/billing"),
|
||||
).toBeDefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe("hide page feature flags", () => {
|
||||
beforeEach(() => {
|
||||
// Set up user as admin with org context so billing is accessible
|
||||
mockMe.data = { role: "admin" };
|
||||
mockOrgTypeAndAccess.isPersonalOrg = true; // Personal org shows billing
|
||||
mockOrgTypeAndAccess.isTeamOrg = false;
|
||||
mockOrgTypeAndAccess.organizationId = "org-1";
|
||||
});
|
||||
|
||||
it("should filter out '/settings/user' when hide_users_page is true", async () => {
|
||||
mockConfigWithFeatureFlags("saas", { hide_users_page: true });
|
||||
const { result } = renderHook(() => useSettingsNavItems(), { wrapper });
|
||||
|
||||
@@ -0,0 +1,79 @@
|
||||
import { screen } from "@testing-library/react";
|
||||
import { describe, expect, it, vi } from "vitest";
|
||||
import i18n from "../../src/i18n";
|
||||
import { AccountSettingsContextMenu } from "../../src/components/features/context-menu/account-settings-context-menu";
|
||||
import { renderWithProviders } from "../../test-utils";
|
||||
import { MemoryRouter } from "react-router";
|
||||
|
||||
describe("Translations", () => {
|
||||
it("should render translated text", () => {
|
||||
i18n.changeLanguage("en");
|
||||
renderWithProviders(
|
||||
<MemoryRouter>
|
||||
<AccountSettingsContextMenu onLogout={() => {}} onClose={() => {}} />
|
||||
</MemoryRouter>,
|
||||
);
|
||||
expect(
|
||||
screen.getByTestId("account-settings-context-menu"),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should not attempt to load unsupported language codes", async () => {
|
||||
// Test that the configuration prevents 404 errors by not attempting to load
|
||||
// unsupported language codes like 'en-US@posix'
|
||||
const originalLanguage = i18n.language;
|
||||
|
||||
try {
|
||||
// With nonExplicitSupportedLngs: false, i18next will not attempt to load
|
||||
// unsupported language codes, preventing 404 errors
|
||||
|
||||
// Test with a language code that includes region but is not in supportedLngs
|
||||
await i18n.changeLanguage("en-US@posix");
|
||||
|
||||
// Since "en-US@posix" is not in supportedLngs and nonExplicitSupportedLngs is false,
|
||||
// i18next should fall back to the fallbackLng ("en")
|
||||
expect(i18n.language).toBe("en");
|
||||
|
||||
// Test another unsupported region code
|
||||
await i18n.changeLanguage("ja-JP");
|
||||
|
||||
// Even with nonExplicitSupportedLngs: false, i18next still falls back to base language
|
||||
// if it exists in supportedLngs, but importantly, it won't make a 404 request first
|
||||
expect(i18n.language).toBe("ja");
|
||||
|
||||
// Test that supported languages still work
|
||||
await i18n.changeLanguage("ja");
|
||||
expect(i18n.language).toBe("ja");
|
||||
|
||||
await i18n.changeLanguage("zh-CN");
|
||||
expect(i18n.language).toBe("zh-CN");
|
||||
|
||||
} finally {
|
||||
// Restore the original language
|
||||
await i18n.changeLanguage(originalLanguage);
|
||||
}
|
||||
});
|
||||
|
||||
it("should have proper i18n configuration", () => {
|
||||
// Test that the i18n instance has the expected configuration
|
||||
expect(i18n.options.supportedLngs).toBeDefined();
|
||||
|
||||
// nonExplicitSupportedLngs should be false to prevent 404 errors
|
||||
expect(i18n.options.nonExplicitSupportedLngs).toBe(false);
|
||||
|
||||
// fallbackLng can be a string or array, check if it includes "en"
|
||||
const fallbackLng = i18n.options.fallbackLng;
|
||||
if (Array.isArray(fallbackLng)) {
|
||||
expect(fallbackLng).toContain("en");
|
||||
} else {
|
||||
expect(fallbackLng).toBe("en");
|
||||
}
|
||||
|
||||
// Test that supported languages include both base and region-specific codes
|
||||
const supportedLngs = i18n.options.supportedLngs as string[];
|
||||
expect(supportedLngs).toContain("en");
|
||||
expect(supportedLngs).toContain("zh-CN");
|
||||
expect(supportedLngs).toContain("zh-TW");
|
||||
expect(supportedLngs).toContain("ko-KR");
|
||||
});
|
||||
});
|
||||
@@ -1,10 +0,0 @@
|
||||
import { describe, expect, it } from "vitest";
|
||||
import { clientLoader } from "#/routes/api-keys";
|
||||
|
||||
describe("clientLoader permission checks", () => {
|
||||
it("should export a clientLoader for route protection", () => {
|
||||
// This test verifies the clientLoader is exported (for consistency with other routes)
|
||||
expect(clientLoader).toBeDefined();
|
||||
expect(typeof clientLoader).toBe("function");
|
||||
});
|
||||
});
|
||||
@@ -2,7 +2,7 @@ import { render, screen, waitFor } from "@testing-library/react";
|
||||
import { afterEach, describe, expect, it, vi } from "vitest";
|
||||
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
||||
import userEvent from "@testing-library/user-event";
|
||||
import AppSettingsScreen, { clientLoader } from "#/routes/app-settings";
|
||||
import AppSettingsScreen from "#/routes/app-settings";
|
||||
import SettingsService from "#/api/settings-service/settings-service.api";
|
||||
import { MOCK_DEFAULT_USER_SETTINGS } from "#/mocks/handlers";
|
||||
import { AvailableLanguages } from "#/i18n";
|
||||
@@ -18,14 +18,6 @@ const renderAppSettingsScreen = () =>
|
||||
),
|
||||
});
|
||||
|
||||
describe("clientLoader permission checks", () => {
|
||||
it("should export a clientLoader for route protection", () => {
|
||||
// This test verifies the clientLoader is exported (for consistency with other routes)
|
||||
expect(clientLoader).toBeDefined();
|
||||
expect(typeof clientLoader).toBe("function");
|
||||
});
|
||||
});
|
||||
|
||||
describe("Content", () => {
|
||||
it("should render the screen", () => {
|
||||
renderAppSettingsScreen();
|
||||
|
||||
@@ -1,367 +0,0 @@
|
||||
import { render, screen, waitFor } from "@testing-library/react";
|
||||
import { createRoutesStub } from "react-router";
|
||||
import { describe, expect, it, vi, beforeEach, afterEach } from "vitest";
|
||||
import { QueryClientProvider } from "@tanstack/react-query";
|
||||
import BillingSettingsScreen, { clientLoader } from "#/routes/billing";
|
||||
import { PaymentForm } from "#/components/features/payment/payment-form";
|
||||
import OptionService from "#/api/option-service/option-service.api";
|
||||
import { OrganizationMember } from "#/types/org";
|
||||
import * as orgStore from "#/stores/selected-organization-store";
|
||||
import { organizationService } from "#/api/organization-service/organization-service.api";
|
||||
import { createMockWebClientConfig } from "#/mocks/settings-handlers";
|
||||
|
||||
// Mock the i18next hook
|
||||
vi.mock("react-i18next", async () => {
|
||||
const actual =
|
||||
await vi.importActual<typeof import("react-i18next")>("react-i18next");
|
||||
return {
|
||||
...actual,
|
||||
useTranslation: () => ({
|
||||
t: (key: string) => key,
|
||||
i18n: {
|
||||
changeLanguage: vi.fn(),
|
||||
},
|
||||
}),
|
||||
};
|
||||
});
|
||||
|
||||
// Mock useTracking hook
|
||||
vi.mock("#/hooks/use-tracking", () => ({
|
||||
useTracking: () => ({
|
||||
trackCreditsPurchased: vi.fn(),
|
||||
}),
|
||||
}));
|
||||
|
||||
// Mock useBalance hook
|
||||
const mockUseBalance = vi.fn();
|
||||
vi.mock("#/hooks/query/use-balance", () => ({
|
||||
useBalance: () => mockUseBalance(),
|
||||
}));
|
||||
|
||||
// Mock useCreateStripeCheckoutSession hook
|
||||
vi.mock(
|
||||
"#/hooks/mutation/stripe/use-create-stripe-checkout-session",
|
||||
() => ({
|
||||
useCreateStripeCheckoutSession: () => ({
|
||||
mutate: vi.fn(),
|
||||
isPending: false,
|
||||
}),
|
||||
}),
|
||||
);
|
||||
|
||||
describe("Billing Route", () => {
|
||||
const { mockQueryClient } = vi.hoisted(() => ({
|
||||
mockQueryClient: (() => {
|
||||
const { QueryClient } = require("@tanstack/react-query");
|
||||
return new QueryClient({
|
||||
defaultOptions: {
|
||||
queries: { retry: false },
|
||||
},
|
||||
});
|
||||
})(),
|
||||
}));
|
||||
|
||||
// Mock queryClient to use our test instance
|
||||
vi.mock("#/query-client-config", () => ({
|
||||
queryClient: mockQueryClient,
|
||||
}));
|
||||
|
||||
const createMockUser = (
|
||||
overrides: Partial<OrganizationMember> = {},
|
||||
): OrganizationMember => ({
|
||||
org_id: "org-1",
|
||||
user_id: "user-1",
|
||||
email: "test@example.com",
|
||||
role: "member",
|
||||
llm_api_key: "",
|
||||
max_iterations: 100,
|
||||
llm_model: "gpt-4",
|
||||
llm_api_key_for_byor: null,
|
||||
llm_base_url: "",
|
||||
status: "active",
|
||||
...overrides,
|
||||
});
|
||||
|
||||
const seedActiveUser = (user: Partial<OrganizationMember>) => {
|
||||
orgStore.useSelectedOrganizationStore.setState({ organizationId: "org-1" });
|
||||
vi.spyOn(organizationService, "getMe").mockResolvedValue(
|
||||
createMockUser(user),
|
||||
);
|
||||
};
|
||||
|
||||
const setupSaasMode = (featureFlags = {}) => {
|
||||
vi.spyOn(OptionService, "getConfig").mockResolvedValue(
|
||||
createMockWebClientConfig({
|
||||
app_mode: "saas",
|
||||
feature_flags: {
|
||||
enable_billing: false,
|
||||
hide_llm_settings: false,
|
||||
enable_jira: false,
|
||||
enable_jira_dc: false,
|
||||
enable_linear: false,
|
||||
hide_users_page: false,
|
||||
hide_billing_page: false,
|
||||
hide_integrations_page: false,
|
||||
...featureFlags,
|
||||
},
|
||||
}),
|
||||
);
|
||||
};
|
||||
|
||||
beforeEach(() => {
|
||||
mockQueryClient.clear();
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
describe("clientLoader cache key", () => {
|
||||
it("should use the 'web-client-config' query key to read cached config", async () => {
|
||||
// Arrange: pre-populate the cache under the canonical key
|
||||
seedActiveUser({ role: "admin" });
|
||||
const cachedConfig = {
|
||||
app_mode: "saas" as const,
|
||||
posthog_client_key: "test",
|
||||
feature_flags: {
|
||||
enable_billing: true,
|
||||
hide_llm_settings: false,
|
||||
enable_jira: false,
|
||||
enable_jira_dc: false,
|
||||
enable_linear: false,
|
||||
},
|
||||
};
|
||||
mockQueryClient.setQueryData(["web-client-config"], cachedConfig);
|
||||
|
||||
const getConfigSpy = vi.spyOn(OptionService, "getConfig");
|
||||
|
||||
// Act: invoke the clientLoader directly
|
||||
const result = await clientLoader();
|
||||
|
||||
// Assert: the loader should have found the cached config and NOT called getConfig
|
||||
expect(getConfigSpy).not.toHaveBeenCalled();
|
||||
expect(result).toBeNull(); // admin with billing enabled = no redirect
|
||||
});
|
||||
});
|
||||
|
||||
describe("clientLoader permission checks", () => {
|
||||
it("should redirect members to /settings/user when accessing billing directly", async () => {
|
||||
// Arrange
|
||||
setupSaasMode();
|
||||
seedActiveUser({ role: "member" });
|
||||
|
||||
const RouterStub = createRoutesStub([
|
||||
{
|
||||
Component: BillingSettingsScreen,
|
||||
loader: clientLoader,
|
||||
path: "/settings/billing",
|
||||
},
|
||||
{
|
||||
Component: () => <div data-testid="user-settings-screen" />,
|
||||
path: "/settings/user",
|
||||
},
|
||||
]);
|
||||
|
||||
// Act
|
||||
render(<RouterStub initialEntries={["/settings/billing"]} />, {
|
||||
wrapper: ({ children }) => (
|
||||
<QueryClientProvider client={mockQueryClient}>
|
||||
{children}
|
||||
</QueryClientProvider>
|
||||
),
|
||||
});
|
||||
|
||||
// Assert - should be redirected to user settings
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId("user-settings-screen")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
it("should allow admins to access billing route", async () => {
|
||||
// Arrange
|
||||
setupSaasMode();
|
||||
seedActiveUser({ role: "admin" });
|
||||
|
||||
const RouterStub = createRoutesStub([
|
||||
{
|
||||
Component: BillingSettingsScreen,
|
||||
loader: clientLoader,
|
||||
path: "/settings/billing",
|
||||
},
|
||||
{
|
||||
Component: () => <div data-testid="user-settings-screen" />,
|
||||
path: "/settings/user",
|
||||
},
|
||||
]);
|
||||
|
||||
// Act
|
||||
render(<RouterStub initialEntries={["/settings/billing"]} />, {
|
||||
wrapper: ({ children }) => (
|
||||
<QueryClientProvider client={mockQueryClient}>
|
||||
{children}
|
||||
</QueryClientProvider>
|
||||
),
|
||||
});
|
||||
|
||||
// Assert - should stay on billing page (component renders PaymentForm)
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.queryByTestId("user-settings-screen"),
|
||||
).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
it("should allow owners to access billing route", async () => {
|
||||
// Arrange
|
||||
setupSaasMode();
|
||||
seedActiveUser({ role: "owner" });
|
||||
|
||||
const RouterStub = createRoutesStub([
|
||||
{
|
||||
Component: BillingSettingsScreen,
|
||||
loader: clientLoader,
|
||||
path: "/settings/billing",
|
||||
},
|
||||
{
|
||||
Component: () => <div data-testid="user-settings-screen" />,
|
||||
path: "/settings/user",
|
||||
},
|
||||
]);
|
||||
|
||||
// Act
|
||||
render(<RouterStub initialEntries={["/settings/billing"]} />, {
|
||||
wrapper: ({ children }) => (
|
||||
<QueryClientProvider client={mockQueryClient}>
|
||||
{children}
|
||||
</QueryClientProvider>
|
||||
),
|
||||
});
|
||||
|
||||
// Assert - should stay on billing page
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.queryByTestId("user-settings-screen"),
|
||||
).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
it("should redirect when user is undefined (no org selected)", async () => {
|
||||
// Arrange: no org selected, so getActiveOrganizationUser returns undefined
|
||||
setupSaasMode();
|
||||
// Explicitly clear org store so getActiveOrganizationUser returns undefined
|
||||
orgStore.useSelectedOrganizationStore.setState({ organizationId: null });
|
||||
|
||||
const RouterStub = createRoutesStub([
|
||||
{
|
||||
Component: BillingSettingsScreen,
|
||||
loader: clientLoader,
|
||||
path: "/settings/billing",
|
||||
},
|
||||
{
|
||||
Component: () => <div data-testid="user-settings-screen" />,
|
||||
path: "/settings/user",
|
||||
},
|
||||
]);
|
||||
|
||||
// Act
|
||||
render(<RouterStub initialEntries={["/settings/billing"]} />, {
|
||||
wrapper: ({ children }) => (
|
||||
<QueryClientProvider client={mockQueryClient}>
|
||||
{children}
|
||||
</QueryClientProvider>
|
||||
),
|
||||
});
|
||||
|
||||
// Assert - should be redirected to user settings
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId("user-settings-screen")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
it("should redirect all users when enable_billing is false", async () => {
|
||||
// Arrange: enable_billing=false means billing is hidden for everyone
|
||||
setupSaasMode({ enable_billing: false });
|
||||
seedActiveUser({ role: "owner" }); // Even owners should be redirected
|
||||
|
||||
const RouterStub = createRoutesStub([
|
||||
{
|
||||
Component: BillingSettingsScreen,
|
||||
loader: clientLoader,
|
||||
path: "/settings/billing",
|
||||
},
|
||||
{
|
||||
Component: () => <div data-testid="user-settings-screen" />,
|
||||
path: "/settings/user",
|
||||
},
|
||||
]);
|
||||
|
||||
// Act
|
||||
render(<RouterStub initialEntries={["/settings/billing"]} />, {
|
||||
wrapper: ({ children }) => (
|
||||
<QueryClientProvider client={mockQueryClient}>
|
||||
{children}
|
||||
</QueryClientProvider>
|
||||
),
|
||||
});
|
||||
|
||||
// Assert - should be redirected to user settings
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId("user-settings-screen")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("PaymentForm permission behavior", () => {
|
||||
beforeEach(() => {
|
||||
mockUseBalance.mockReturnValue({
|
||||
data: "150.00",
|
||||
isLoading: false,
|
||||
});
|
||||
});
|
||||
|
||||
it("should disable input and button when isDisabled is true, but show balance", async () => {
|
||||
// Arrange & Act
|
||||
render(<PaymentForm isDisabled />, {
|
||||
wrapper: ({ children }) => (
|
||||
<QueryClientProvider client={mockQueryClient}>
|
||||
{children}
|
||||
</QueryClientProvider>
|
||||
),
|
||||
});
|
||||
|
||||
// Assert - balance is visible
|
||||
const balance = screen.getByTestId("user-balance");
|
||||
expect(balance).toBeInTheDocument();
|
||||
expect(balance).toHaveTextContent("$150.00");
|
||||
|
||||
// Assert - input is disabled
|
||||
const topUpInput = screen.getByTestId("top-up-input");
|
||||
expect(topUpInput).toBeDisabled();
|
||||
|
||||
// Assert - button is disabled
|
||||
const submitButton = screen.getByRole("button");
|
||||
expect(submitButton).toBeDisabled();
|
||||
});
|
||||
|
||||
it("should enable input and button when isDisabled is false", async () => {
|
||||
// Arrange & Act
|
||||
render(<PaymentForm isDisabled={false} />, {
|
||||
wrapper: ({ children }) => (
|
||||
<QueryClientProvider client={mockQueryClient}>
|
||||
{children}
|
||||
</QueryClientProvider>
|
||||
),
|
||||
});
|
||||
|
||||
// Assert - input is enabled
|
||||
const topUpInput = screen.getByTestId("top-up-input");
|
||||
expect(topUpInput).not.toBeDisabled();
|
||||
|
||||
// Assert - button starts disabled (no amount entered) but is NOT
|
||||
// permanently disabled by the isDisabled prop
|
||||
const submitButton = screen.getByRole("button");
|
||||
// The button is disabled because no valid amount is entered, not because of isDisabled
|
||||
expect(submitButton).toBeDisabled();
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -5,7 +5,7 @@ import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
||||
import userEvent from "@testing-library/user-event";
|
||||
import i18next from "i18next";
|
||||
import { I18nextProvider } from "react-i18next";
|
||||
import GitSettingsScreen, { clientLoader } from "#/routes/git-settings";
|
||||
import GitSettingsScreen from "#/routes/git-settings";
|
||||
import SettingsService from "#/api/settings-service/settings-service.api";
|
||||
import OptionService from "#/api/option-service/option-service.api";
|
||||
import AuthService from "#/api/auth-service/auth-service.api";
|
||||
@@ -13,6 +13,7 @@ import { MOCK_DEFAULT_USER_SETTINGS } from "#/mocks/handlers";
|
||||
import { WebClientConfig } from "#/api/option-service/option.types";
|
||||
import * as ToastHandlers from "#/utils/custom-toast-handlers";
|
||||
import { SecretsService } from "#/api/secrets-service";
|
||||
import { integrationService } from "#/api/integration-service/integration-service.api";
|
||||
|
||||
const VALID_OSS_CONFIG: WebClientConfig = {
|
||||
app_mode: "oss",
|
||||
@@ -656,10 +657,3 @@ describe("GitLab Webhook Manager Integration", () => {
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("clientLoader permission checks", () => {
|
||||
it("should export a clientLoader for route protection", () => {
|
||||
expect(clientLoader).toBeDefined();
|
||||
expect(typeof clientLoader).toBe("function");
|
||||
});
|
||||
});
|
||||
|
||||
@@ -541,7 +541,7 @@ describe("Settings 404", () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe("New user welcome toast", () => {
|
||||
describe("Setup Payment modal", () => {
|
||||
const getConfigSpy = vi.spyOn(OptionService, "getConfig");
|
||||
const getSettingsSpy = vi.spyOn(SettingsService, "getSettings");
|
||||
|
||||
@@ -593,7 +593,7 @@ describe("New user welcome toast", () => {
|
||||
vi.unstubAllGlobals();
|
||||
});
|
||||
|
||||
it("should not show the setup payment modal (removed) in SaaS mode for new users", async () => {
|
||||
it("should only render if SaaS mode and is new user", async () => {
|
||||
getSettingsSpy.mockResolvedValue({
|
||||
...MOCK_DEFAULT_USER_SETTINGS,
|
||||
is_new_user: true,
|
||||
@@ -603,9 +603,9 @@ describe("New user welcome toast", () => {
|
||||
|
||||
await screen.findByTestId("root-layout");
|
||||
|
||||
// SetupPaymentModal was removed; verify it no longer renders
|
||||
expect(
|
||||
screen.queryByTestId("proceed-to-stripe-button"),
|
||||
).not.toBeInTheDocument();
|
||||
const setupPaymentModal = await screen.findByTestId(
|
||||
"proceed-to-stripe-button",
|
||||
);
|
||||
expect(setupPaymentModal).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
@@ -11,23 +11,12 @@ import {
|
||||
import * as AdvancedSettingsUtlls from "#/utils/has-advanced-settings-set";
|
||||
import * as ToastHandlers from "#/utils/custom-toast-handlers";
|
||||
import OptionService from "#/api/option-service/option-service.api";
|
||||
import { organizationService } from "#/api/organization-service/organization-service.api";
|
||||
import { useSelectedOrganizationStore } from "#/stores/selected-organization-store";
|
||||
import type { OrganizationMember } from "#/types/org";
|
||||
|
||||
// Mock react-router hooks
|
||||
const mockUseSearchParams = vi.fn();
|
||||
vi.mock("react-router", async () => {
|
||||
const actual =
|
||||
await vi.importActual<typeof import("react-router")>("react-router");
|
||||
return {
|
||||
...actual,
|
||||
useSearchParams: () => mockUseSearchParams(),
|
||||
useRevalidator: () => ({
|
||||
revalidate: vi.fn(),
|
||||
}),
|
||||
};
|
||||
});
|
||||
vi.mock("react-router", () => ({
|
||||
useSearchParams: () => mockUseSearchParams(),
|
||||
}));
|
||||
|
||||
// Mock useIsAuthed hook
|
||||
const mockUseIsAuthed = vi.fn();
|
||||
@@ -35,63 +24,14 @@ vi.mock("#/hooks/query/use-is-authed", () => ({
|
||||
useIsAuthed: () => mockUseIsAuthed(),
|
||||
}));
|
||||
|
||||
// Mock useConfig hook
|
||||
const mockUseConfig = vi.fn();
|
||||
vi.mock("#/hooks/query/use-config", () => ({
|
||||
useConfig: () => mockUseConfig(),
|
||||
}));
|
||||
|
||||
const renderLlmSettingsScreen = (
|
||||
orgId: string | null = null,
|
||||
meData?: {
|
||||
org_id: string;
|
||||
user_id: string;
|
||||
email: string;
|
||||
role: string;
|
||||
status: string;
|
||||
llm_api_key: string;
|
||||
max_iterations: number;
|
||||
llm_model: string;
|
||||
llm_api_key_for_byor: string | null;
|
||||
llm_base_url: string;
|
||||
},
|
||||
) => {
|
||||
const queryClient = new QueryClient({
|
||||
defaultOptions: {
|
||||
queries: {
|
||||
retry: false,
|
||||
},
|
||||
},
|
||||
});
|
||||
// Default to orgId "1" if not provided (for backward compatibility)
|
||||
const finalOrgId = orgId ?? "1";
|
||||
useSelectedOrganizationStore.setState({ organizationId: finalOrgId });
|
||||
|
||||
// Pre-populate React Query cache with me data
|
||||
// If meData is provided, use it; otherwise use default owner data
|
||||
const defaultMeData = {
|
||||
org_id: finalOrgId,
|
||||
user_id: "99",
|
||||
email: "owner@example.com",
|
||||
role: "owner",
|
||||
status: "active",
|
||||
llm_api_key: "",
|
||||
max_iterations: 20,
|
||||
llm_model: "",
|
||||
llm_api_key_for_byor: null,
|
||||
llm_base_url: "",
|
||||
};
|
||||
queryClient.setQueryData(
|
||||
["organizations", finalOrgId, "me"],
|
||||
meData || defaultMeData,
|
||||
);
|
||||
|
||||
return render(<LlmSettingsScreen />, {
|
||||
const renderLlmSettingsScreen = () =>
|
||||
render(<LlmSettingsScreen />, {
|
||||
wrapper: ({ children }) => (
|
||||
<QueryClientProvider client={queryClient}>{children}</QueryClientProvider>
|
||||
<QueryClientProvider client={new QueryClient()}>
|
||||
{children}
|
||||
</QueryClientProvider>
|
||||
),
|
||||
});
|
||||
};
|
||||
|
||||
beforeEach(() => {
|
||||
vi.resetAllMocks();
|
||||
@@ -107,58 +47,22 @@ beforeEach(() => {
|
||||
|
||||
// Default mock for useIsAuthed - returns authenticated by default
|
||||
mockUseIsAuthed.mockReturnValue({ data: true, isLoading: false });
|
||||
|
||||
// Default mock for useConfig - returns SaaS mode by default
|
||||
mockUseConfig.mockReturnValue({
|
||||
data: { app_mode: "saas" },
|
||||
isLoading: false,
|
||||
});
|
||||
|
||||
// Default mock for organizationService.getMe - returns owner role by default (full access)
|
||||
const defaultMeData: OrganizationMember = {
|
||||
org_id: "1",
|
||||
user_id: "99",
|
||||
email: "owner@example.com",
|
||||
role: "owner",
|
||||
status: "active",
|
||||
llm_api_key: "",
|
||||
max_iterations: 20,
|
||||
llm_model: "",
|
||||
llm_api_key_for_byor: null,
|
||||
llm_base_url: "",
|
||||
};
|
||||
vi.spyOn(organizationService, "getMe").mockResolvedValue(defaultMeData);
|
||||
|
||||
// Reset organization store
|
||||
useSelectedOrganizationStore.setState({ organizationId: "1" });
|
||||
});
|
||||
|
||||
describe("Content", () => {
|
||||
describe("Basic form", () => {
|
||||
it("should render the basic form by default", async () => {
|
||||
// Use OSS mode so API key input is visible
|
||||
mockUseConfig.mockReturnValue({
|
||||
data: { app_mode: "oss" },
|
||||
isLoading: false,
|
||||
});
|
||||
|
||||
renderLlmSettingsScreen();
|
||||
await screen.findByTestId("llm-settings-screen");
|
||||
|
||||
const basicForm = screen.getByTestId("llm-settings-form-basic");
|
||||
within(basicForm).getByTestId("llm-provider-input");
|
||||
within(basicForm).getByTestId("llm-model-input");
|
||||
within(basicForm).getByTestId("llm-api-key-input");
|
||||
within(basicForm).getByTestId("llm-api-key-help-anchor");
|
||||
const basicFom = screen.getByTestId("llm-settings-form-basic");
|
||||
within(basicFom).getByTestId("llm-provider-input");
|
||||
within(basicFom).getByTestId("llm-model-input");
|
||||
within(basicFom).getByTestId("llm-api-key-input");
|
||||
within(basicFom).getByTestId("llm-api-key-help-anchor");
|
||||
});
|
||||
|
||||
it("should render the default values if non exist", async () => {
|
||||
// Use OSS mode so API key input is visible
|
||||
mockUseConfig.mockReturnValue({
|
||||
data: { app_mode: "oss" },
|
||||
isLoading: false,
|
||||
});
|
||||
|
||||
renderLlmSettingsScreen();
|
||||
await screen.findByTestId("llm-settings-screen");
|
||||
|
||||
@@ -238,12 +142,6 @@ describe("Content", () => {
|
||||
});
|
||||
|
||||
it("should render the advanced form if the switch is toggled", async () => {
|
||||
// Use OSS mode so agent-input is visible
|
||||
mockUseConfig.mockReturnValue({
|
||||
data: { app_mode: "oss" },
|
||||
isLoading: false,
|
||||
});
|
||||
|
||||
renderLlmSettingsScreen();
|
||||
await screen.findByTestId("llm-settings-screen");
|
||||
|
||||
@@ -278,12 +176,6 @@ describe("Content", () => {
|
||||
});
|
||||
|
||||
it("should render the default advanced settings", async () => {
|
||||
// Use OSS mode so agent-input is visible
|
||||
mockUseConfig.mockReturnValue({
|
||||
data: { app_mode: "oss" },
|
||||
isLoading: false,
|
||||
});
|
||||
|
||||
renderLlmSettingsScreen();
|
||||
await screen.findByTestId("llm-settings-screen");
|
||||
|
||||
@@ -323,12 +215,6 @@ describe("Content", () => {
|
||||
});
|
||||
|
||||
it("should render existing advanced settings correctly", async () => {
|
||||
// Use OSS mode so agent-input is visible
|
||||
mockUseConfig.mockReturnValue({
|
||||
data: { app_mode: "oss" },
|
||||
isLoading: false,
|
||||
});
|
||||
|
||||
const getSettingsSpy = vi.spyOn(SettingsService, "getSettings");
|
||||
getSettingsSpy.mockResolvedValue({
|
||||
...MOCK_DEFAULT_USER_SETTINGS,
|
||||
@@ -450,10 +336,10 @@ describe("Content", () => {
|
||||
|
||||
describe("API key visibility in Basic Settings", () => {
|
||||
it("should hide API key input when SaaS mode is enabled and OpenHands provider is selected", async () => {
|
||||
// SaaS mode is already the default from beforeEach, but let's be explicit
|
||||
mockUseConfig.mockReturnValue({
|
||||
data: { app_mode: "saas" },
|
||||
isLoading: false,
|
||||
const getConfigSpy = vi.spyOn(OptionService, "getConfig");
|
||||
// @ts-expect-error - only return app_mode for these tests
|
||||
getConfigSpy.mockResolvedValue({
|
||||
app_mode: "saas",
|
||||
});
|
||||
|
||||
renderLlmSettingsScreen();
|
||||
@@ -477,10 +363,10 @@ describe("Content", () => {
|
||||
});
|
||||
|
||||
it("should show API key input when SaaS mode is enabled and non-OpenHands provider is selected", async () => {
|
||||
// SaaS mode is already the default from beforeEach, but let's be explicit
|
||||
mockUseConfig.mockReturnValue({
|
||||
data: { app_mode: "saas" },
|
||||
isLoading: false,
|
||||
const getConfigSpy = vi.spyOn(OptionService, "getConfig");
|
||||
// @ts-expect-error - only return app_mode for these tests
|
||||
getConfigSpy.mockResolvedValue({
|
||||
app_mode: "saas",
|
||||
});
|
||||
|
||||
renderLlmSettingsScreen();
|
||||
@@ -508,9 +394,10 @@ describe("Content", () => {
|
||||
});
|
||||
|
||||
it("should show API key input when OSS mode is enabled and OpenHands provider is selected", async () => {
|
||||
mockUseConfig.mockReturnValue({
|
||||
data: { app_mode: "oss" },
|
||||
isLoading: false,
|
||||
const getConfigSpy = vi.spyOn(OptionService, "getConfig");
|
||||
// @ts-expect-error - only return app_mode for these tests
|
||||
getConfigSpy.mockResolvedValue({
|
||||
app_mode: "oss",
|
||||
});
|
||||
|
||||
renderLlmSettingsScreen();
|
||||
@@ -534,9 +421,10 @@ describe("Content", () => {
|
||||
});
|
||||
|
||||
it("should show API key input when OSS mode is enabled and non-OpenHands provider is selected", async () => {
|
||||
mockUseConfig.mockReturnValue({
|
||||
data: { app_mode: "oss" },
|
||||
isLoading: false,
|
||||
const getConfigSpy = vi.spyOn(OptionService, "getConfig");
|
||||
// @ts-expect-error - only return app_mode for these tests
|
||||
getConfigSpy.mockResolvedValue({
|
||||
app_mode: "oss",
|
||||
});
|
||||
|
||||
renderLlmSettingsScreen();
|
||||
@@ -564,10 +452,10 @@ describe("Content", () => {
|
||||
});
|
||||
|
||||
it("should hide API key input when switching from non-OpenHands to OpenHands provider in SaaS mode", async () => {
|
||||
// SaaS mode is already the default from beforeEach, but let's be explicit
|
||||
mockUseConfig.mockReturnValue({
|
||||
data: { app_mode: "saas" },
|
||||
isLoading: false,
|
||||
const getConfigSpy = vi.spyOn(OptionService, "getConfig");
|
||||
// @ts-expect-error - only return app_mode for these tests
|
||||
getConfigSpy.mockResolvedValue({
|
||||
app_mode: "saas",
|
||||
});
|
||||
|
||||
renderLlmSettingsScreen();
|
||||
@@ -609,10 +497,10 @@ describe("Content", () => {
|
||||
});
|
||||
|
||||
it("should show API key input when switching from OpenHands to non-OpenHands provider in SaaS mode", async () => {
|
||||
// SaaS mode is already the default from beforeEach, but let's be explicit
|
||||
mockUseConfig.mockReturnValue({
|
||||
data: { app_mode: "saas" },
|
||||
isLoading: false,
|
||||
const getConfigSpy = vi.spyOn(OptionService, "getConfig");
|
||||
// @ts-expect-error - only return app_mode for these tests
|
||||
getConfigSpy.mockResolvedValue({
|
||||
app_mode: "saas",
|
||||
});
|
||||
|
||||
renderLlmSettingsScreen();
|
||||
@@ -660,17 +548,15 @@ describe("Form submission", () => {
|
||||
|
||||
const provider = screen.getByTestId("llm-provider-input");
|
||||
const model = screen.getByTestId("llm-model-input");
|
||||
const apiKey = screen.getByTestId("llm-api-key-input");
|
||||
|
||||
// select provider (switch to OpenAI so API key input becomes visible)
|
||||
// select provider
|
||||
await userEvent.click(provider);
|
||||
const providerOption = screen.getByText("OpenAI");
|
||||
await userEvent.click(providerOption);
|
||||
await waitFor(() => {
|
||||
expect(provider).toHaveValue("OpenAI");
|
||||
});
|
||||
expect(provider).toHaveValue("OpenAI");
|
||||
|
||||
// enter api key (now visible after switching provider)
|
||||
const apiKey = await screen.findByTestId("llm-api-key-input");
|
||||
// enter api key
|
||||
await userEvent.type(apiKey, "test-api-key");
|
||||
|
||||
// select model
|
||||
@@ -691,12 +577,6 @@ describe("Form submission", () => {
|
||||
});
|
||||
|
||||
it("should submit the advanced form with the correct values", async () => {
|
||||
// Use OSS mode so agent-input is visible
|
||||
mockUseConfig.mockReturnValue({
|
||||
data: { app_mode: "oss" },
|
||||
isLoading: false,
|
||||
});
|
||||
|
||||
const saveSettingsSpy = vi.spyOn(SettingsService, "saveSettings");
|
||||
|
||||
renderLlmSettingsScreen();
|
||||
@@ -805,12 +685,6 @@ describe("Form submission", () => {
|
||||
});
|
||||
|
||||
it("should disable the button if there are no changes in the advanced form", async () => {
|
||||
// Use OSS mode so agent-input is visible
|
||||
mockUseConfig.mockReturnValue({
|
||||
data: { app_mode: "oss" },
|
||||
isLoading: false,
|
||||
});
|
||||
|
||||
const getSettingsSpy = vi.spyOn(SettingsService, "getSettings");
|
||||
getSettingsSpy.mockResolvedValue({
|
||||
...MOCK_DEFAULT_USER_SETTINGS,
|
||||
@@ -944,17 +818,8 @@ describe("Form submission", () => {
|
||||
|
||||
expect(submitButton).toBeDisabled();
|
||||
|
||||
// Switch to a non-OpenHands provider first so API key input is visible
|
||||
const provider = screen.getByTestId("llm-provider-input");
|
||||
await userEvent.click(provider);
|
||||
const providerOption = screen.getByText("OpenAI");
|
||||
await userEvent.click(providerOption);
|
||||
await waitFor(() => {
|
||||
expect(provider).toHaveValue("OpenAI");
|
||||
});
|
||||
|
||||
// dirty the basic form
|
||||
const apiKey = await screen.findByTestId("llm-api-key-input");
|
||||
const apiKey = screen.getByTestId("llm-api-key-input");
|
||||
await userEvent.type(apiKey, "test-api-key");
|
||||
expect(submitButton).not.toBeDisabled();
|
||||
|
||||
@@ -1144,9 +1009,21 @@ describe("View persistence after saving advanced settings", () => {
|
||||
|
||||
it("should remain on Advanced view after saving when search API key is set", async () => {
|
||||
// Arrange: Start with default settings (non-SaaS mode to show search API key field)
|
||||
mockUseConfig.mockReturnValue({
|
||||
data: { app_mode: "oss" },
|
||||
isLoading: false,
|
||||
const getConfigSpy = vi.spyOn(OptionService, "getConfig");
|
||||
// @ts-expect-error - partial mock for testing
|
||||
getConfigSpy.mockResolvedValue({
|
||||
app_mode: "oss",
|
||||
posthog_client_key: "fake-posthog-client-key",
|
||||
feature_flags: {
|
||||
enable_billing: false,
|
||||
hide_llm_settings: false,
|
||||
enable_jira: false,
|
||||
enable_jira_dc: false,
|
||||
enable_linear: false,
|
||||
hide_users_page: false,
|
||||
hide_billing_page: false,
|
||||
hide_integrations_page: false,
|
||||
},
|
||||
});
|
||||
|
||||
const getSettingsSpy = vi.spyOn(SettingsService, "getSettings");
|
||||
@@ -1203,37 +1080,12 @@ describe("Status toasts", () => {
|
||||
);
|
||||
|
||||
renderLlmSettingsScreen();
|
||||
await screen.findByTestId("llm-settings-screen");
|
||||
|
||||
// Switch to a non-OpenHands provider so API key input is visible
|
||||
const provider = screen.getByTestId("llm-provider-input");
|
||||
await userEvent.click(provider);
|
||||
const providerOption = screen.getByText("OpenAI");
|
||||
await userEvent.click(providerOption);
|
||||
await waitFor(() => {
|
||||
expect(provider).toHaveValue("OpenAI");
|
||||
});
|
||||
|
||||
// Wait for API key input to appear
|
||||
// Toggle setting to change
|
||||
const apiKeyInput = await screen.findByTestId("llm-api-key-input");
|
||||
|
||||
// Also change the model to ensure form is dirty
|
||||
const model = screen.getByTestId("llm-model-input");
|
||||
await userEvent.click(model);
|
||||
const modelOption = screen.getByText("gpt-4o");
|
||||
await userEvent.click(modelOption);
|
||||
await waitFor(() => {
|
||||
expect(model).toHaveValue("gpt-4o");
|
||||
});
|
||||
|
||||
// Enter API key
|
||||
await userEvent.type(apiKeyInput, "test-api-key");
|
||||
|
||||
// Wait for submit button to be enabled
|
||||
const submit = await screen.findByTestId("submit-button");
|
||||
await waitFor(() => {
|
||||
expect(submit).not.toBeDisabled();
|
||||
});
|
||||
await userEvent.click(submit);
|
||||
|
||||
expect(saveSettingsSpy).toHaveBeenCalled();
|
||||
@@ -1248,37 +1100,12 @@ describe("Status toasts", () => {
|
||||
saveSettingsSpy.mockRejectedValue(new Error("Failed to save settings"));
|
||||
|
||||
renderLlmSettingsScreen();
|
||||
await screen.findByTestId("llm-settings-screen");
|
||||
|
||||
// Switch to a non-OpenHands provider so API key input is visible
|
||||
const provider = screen.getByTestId("llm-provider-input");
|
||||
await userEvent.click(provider);
|
||||
const providerOption = screen.getByText("OpenAI");
|
||||
await userEvent.click(providerOption);
|
||||
await waitFor(() => {
|
||||
expect(provider).toHaveValue("OpenAI");
|
||||
});
|
||||
|
||||
// Wait for API key input to appear
|
||||
// Toggle setting to change
|
||||
const apiKeyInput = await screen.findByTestId("llm-api-key-input");
|
||||
|
||||
// Also change the model to ensure form is dirty
|
||||
const model = screen.getByTestId("llm-model-input");
|
||||
await userEvent.click(model);
|
||||
const modelOption = screen.getByText("gpt-4o");
|
||||
await userEvent.click(modelOption);
|
||||
await waitFor(() => {
|
||||
expect(model).toHaveValue("gpt-4o");
|
||||
});
|
||||
|
||||
// Enter API key
|
||||
await userEvent.type(apiKeyInput, "test-api-key");
|
||||
|
||||
// Wait for submit button to be enabled
|
||||
const submit = await screen.findByTestId("submit-button");
|
||||
await waitFor(() => {
|
||||
expect(submit).not.toBeDisabled();
|
||||
});
|
||||
await userEvent.click(submit);
|
||||
|
||||
expect(saveSettingsSpy).toHaveBeenCalled();
|
||||
@@ -1288,12 +1115,6 @@ describe("Status toasts", () => {
|
||||
|
||||
describe("Advanced form", () => {
|
||||
it("should call displaySuccessToast when the settings are saved", async () => {
|
||||
// Use OSS mode to ensure API key input is visible
|
||||
mockUseConfig.mockReturnValue({
|
||||
data: { app_mode: "oss" },
|
||||
isLoading: false,
|
||||
});
|
||||
|
||||
const saveSettingsSpy = vi.spyOn(SettingsService, "saveSettings");
|
||||
|
||||
const displaySuccessToastSpy = vi.spyOn(
|
||||
@@ -1312,11 +1133,7 @@ describe("Status toasts", () => {
|
||||
const apiKeyInput = await screen.findByTestId("llm-api-key-input");
|
||||
await userEvent.type(apiKeyInput, "test-api-key");
|
||||
|
||||
// Wait for submit button to be enabled
|
||||
const submit = await screen.findByTestId("submit-button");
|
||||
await waitFor(() => {
|
||||
expect(submit).not.toBeDisabled();
|
||||
});
|
||||
await userEvent.click(submit);
|
||||
|
||||
expect(saveSettingsSpy).toHaveBeenCalled();
|
||||
@@ -1324,12 +1141,6 @@ describe("Status toasts", () => {
|
||||
});
|
||||
|
||||
it("should call displayErrorToast when the settings fail to save", async () => {
|
||||
// Use OSS mode to ensure API key input is visible
|
||||
mockUseConfig.mockReturnValue({
|
||||
data: { app_mode: "oss" },
|
||||
isLoading: false,
|
||||
});
|
||||
|
||||
const saveSettingsSpy = vi.spyOn(SettingsService, "saveSettings");
|
||||
|
||||
const displayErrorToastSpy = vi.spyOn(ToastHandlers, "displayErrorToast");
|
||||
@@ -1347,11 +1158,7 @@ describe("Status toasts", () => {
|
||||
const apiKeyInput = await screen.findByTestId("llm-api-key-input");
|
||||
await userEvent.type(apiKeyInput, "test-api-key");
|
||||
|
||||
// Wait for submit button to be enabled
|
||||
const submit = await screen.findByTestId("submit-button");
|
||||
await waitFor(() => {
|
||||
expect(submit).not.toBeDisabled();
|
||||
});
|
||||
await userEvent.click(submit);
|
||||
|
||||
expect(saveSettingsSpy).toHaveBeenCalled();
|
||||
@@ -1359,411 +1166,3 @@ describe("Status toasts", () => {
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("Role-based permissions", () => {
|
||||
const getMeSpy = vi.spyOn(organizationService, "getMe");
|
||||
|
||||
beforeEach(() => {
|
||||
mockUseConfig.mockReturnValue({
|
||||
data: { app_mode: "saas" },
|
||||
isLoading: false,
|
||||
});
|
||||
});
|
||||
|
||||
describe("User role (read-only)", () => {
|
||||
const memberData: OrganizationMember = {
|
||||
org_id: "2",
|
||||
user_id: "99",
|
||||
email: "user@example.com",
|
||||
role: "member",
|
||||
status: "active",
|
||||
llm_api_key: "",
|
||||
max_iterations: 20,
|
||||
llm_model: "",
|
||||
llm_api_key_for_byor: null,
|
||||
llm_base_url: "",
|
||||
};
|
||||
|
||||
beforeEach(() => {
|
||||
// Mock user role
|
||||
getMeSpy.mockResolvedValue(memberData);
|
||||
});
|
||||
|
||||
it("should disable all input fields in basic view", async () => {
|
||||
// Arrange
|
||||
renderLlmSettingsScreen("2", memberData); // orgId "2" returns user role
|
||||
|
||||
// Act
|
||||
await screen.findByTestId("llm-settings-screen");
|
||||
const basicForm = screen.getByTestId("llm-settings-form-basic");
|
||||
|
||||
// Assert
|
||||
const providerInput = within(basicForm).getByTestId("llm-provider-input");
|
||||
const modelInput = within(basicForm).getByTestId("llm-model-input");
|
||||
|
||||
await waitFor(() => {
|
||||
expect(providerInput).toBeDisabled();
|
||||
expect(modelInput).toBeDisabled();
|
||||
});
|
||||
|
||||
// API key input may be hidden if OpenHands provider is selected in SaaS mode
|
||||
// If it exists, it should be disabled
|
||||
const apiKeyInput = within(basicForm).queryByTestId("llm-api-key-input");
|
||||
if (apiKeyInput) {
|
||||
expect(apiKeyInput).toBeDisabled();
|
||||
}
|
||||
});
|
||||
|
||||
// Note: No "should disable all input fields in advanced view" test for members
|
||||
// because members cannot access the advanced view (the toggle is disabled).
|
||||
|
||||
it("should not render submit button", async () => {
|
||||
// Arrange
|
||||
renderLlmSettingsScreen("2", memberData);
|
||||
|
||||
// Act
|
||||
await screen.findByTestId("llm-settings-screen");
|
||||
const submitButton = screen.queryByTestId("submit-button");
|
||||
|
||||
// Assert
|
||||
expect(submitButton).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should disable the advanced/basic toggle for read-only users", async () => {
|
||||
// Arrange
|
||||
renderLlmSettingsScreen("2", memberData);
|
||||
|
||||
// Act
|
||||
await screen.findByTestId("llm-settings-screen");
|
||||
const advancedSwitch = screen.getByTestId("advanced-settings-switch");
|
||||
|
||||
// Assert - toggle should be disabled for members who lack edit_llm_settings
|
||||
await waitFor(() => {
|
||||
expect(advancedSwitch).toBeDisabled();
|
||||
});
|
||||
|
||||
// Basic form should remain visible (members can't switch to advanced)
|
||||
expect(
|
||||
screen.getByTestId("llm-settings-form-basic"),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
});
|
||||
|
||||
describe("Owner role (full access)", () => {
|
||||
beforeEach(() => {
|
||||
// Mock owner role
|
||||
getMeSpy.mockResolvedValue({
|
||||
org_id: "1",
|
||||
user_id: "99",
|
||||
email: "owner@example.com",
|
||||
role: "owner",
|
||||
status: "active",
|
||||
llm_api_key: "",
|
||||
max_iterations: 20,
|
||||
llm_model: "",
|
||||
llm_api_key_for_byor: null,
|
||||
llm_base_url: "",
|
||||
});
|
||||
});
|
||||
|
||||
it("should enable all input fields in basic view", async () => {
|
||||
// Arrange
|
||||
renderLlmSettingsScreen("1"); // orgId "1" returns owner role
|
||||
|
||||
// Act
|
||||
await screen.findByTestId("llm-settings-screen");
|
||||
const basicForm = screen.getByTestId("llm-settings-form-basic");
|
||||
|
||||
// Assert
|
||||
const providerInput = within(basicForm).getByTestId("llm-provider-input");
|
||||
const modelInput = within(basicForm).getByTestId("llm-model-input");
|
||||
|
||||
await waitFor(() => {
|
||||
expect(providerInput).not.toBeDisabled();
|
||||
expect(modelInput).not.toBeDisabled();
|
||||
});
|
||||
|
||||
// API key input may be hidden if OpenHands provider is selected in SaaS mode
|
||||
// If it exists, it should be enabled
|
||||
const apiKeyInput = within(basicForm).queryByTestId("llm-api-key-input");
|
||||
if (apiKeyInput) {
|
||||
expect(apiKeyInput).not.toBeDisabled();
|
||||
}
|
||||
});
|
||||
|
||||
it("should enable all input fields in advanced view", async () => {
|
||||
// Arrange
|
||||
renderLlmSettingsScreen("1");
|
||||
|
||||
// Act
|
||||
await screen.findByTestId("llm-settings-screen");
|
||||
const advancedSwitch = screen.getByTestId("advanced-settings-switch");
|
||||
|
||||
// Assert - owners can toggle between views
|
||||
expect(advancedSwitch).not.toBeDisabled();
|
||||
|
||||
await userEvent.click(advancedSwitch);
|
||||
const advancedForm = await screen.findByTestId(
|
||||
"llm-settings-form-advanced",
|
||||
);
|
||||
|
||||
// Assert
|
||||
const modelInput = within(advancedForm).getByTestId(
|
||||
"llm-custom-model-input",
|
||||
);
|
||||
const baseUrlInput = within(advancedForm).getByTestId("base-url-input");
|
||||
const condenserSwitch = within(advancedForm).getByTestId(
|
||||
"enable-memory-condenser-switch",
|
||||
);
|
||||
const confirmationSwitch = within(advancedForm).getByTestId(
|
||||
"enable-confirmation-mode-switch",
|
||||
);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(modelInput).not.toBeDisabled();
|
||||
expect(baseUrlInput).not.toBeDisabled();
|
||||
expect(condenserSwitch).not.toBeDisabled();
|
||||
expect(confirmationSwitch).not.toBeDisabled();
|
||||
});
|
||||
|
||||
// API key input may be hidden if OpenHands provider is selected in SaaS mode
|
||||
// If it exists, it should be enabled
|
||||
const apiKeyInput =
|
||||
within(advancedForm).queryByTestId("llm-api-key-input");
|
||||
if (apiKeyInput) {
|
||||
expect(apiKeyInput).not.toBeDisabled();
|
||||
}
|
||||
});
|
||||
|
||||
it("should enable submit button when form is dirty", async () => {
|
||||
// Arrange
|
||||
renderLlmSettingsScreen("1");
|
||||
|
||||
// Act
|
||||
await screen.findByTestId("llm-settings-screen");
|
||||
const submitButton = screen.getByTestId("submit-button");
|
||||
const providerInput = screen.getByTestId("llm-provider-input");
|
||||
|
||||
// Assert - initially disabled (no changes)
|
||||
expect(submitButton).toBeDisabled();
|
||||
|
||||
// Act - make a change by selecting a different provider
|
||||
await userEvent.click(providerInput);
|
||||
const openAIOption = await screen.findByText("OpenAI");
|
||||
await userEvent.click(openAIOption);
|
||||
|
||||
// Assert - button should be enabled
|
||||
await waitFor(() => {
|
||||
expect(submitButton).not.toBeDisabled();
|
||||
});
|
||||
});
|
||||
|
||||
it("should allow submitting form changes", async () => {
|
||||
// Arrange
|
||||
const saveSettingsSpy = vi.spyOn(SettingsService, "saveSettings");
|
||||
renderLlmSettingsScreen("1");
|
||||
|
||||
// Act
|
||||
await screen.findByTestId("llm-settings-screen");
|
||||
const providerInput = screen.getByTestId("llm-provider-input");
|
||||
const modelInput = screen.getByTestId("llm-model-input");
|
||||
|
||||
// Select a different provider to make form dirty
|
||||
await userEvent.click(providerInput);
|
||||
const openAIOption = await screen.findByText("OpenAI");
|
||||
await userEvent.click(openAIOption);
|
||||
await waitFor(() => {
|
||||
expect(providerInput).toHaveValue("OpenAI");
|
||||
});
|
||||
|
||||
// Select a different model to ensure form is dirty
|
||||
await userEvent.click(modelInput);
|
||||
const modelOption = await screen.findByText("gpt-4o");
|
||||
await userEvent.click(modelOption);
|
||||
await waitFor(() => {
|
||||
expect(modelInput).toHaveValue("gpt-4o");
|
||||
});
|
||||
|
||||
// Wait for form to be marked as dirty
|
||||
const submitButton = await screen.findByTestId("submit-button");
|
||||
await waitFor(() => {
|
||||
expect(submitButton).not.toBeDisabled();
|
||||
});
|
||||
|
||||
await userEvent.click(submitButton);
|
||||
|
||||
// Assert
|
||||
await waitFor(() => {
|
||||
expect(saveSettingsSpy).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
// Note: The former "should disable security analyzer dropdown when confirmation mode
|
||||
// is enabled" test was removed. It was in the member block and only passed because
|
||||
// members have isReadOnly=true (all fields disabled), not because confirmation mode
|
||||
// disables the analyzer. For owners/admins, the security analyzer is enabled
|
||||
// regardless of confirmation mode.
|
||||
});
|
||||
|
||||
describe("Admin role (full access)", () => {
|
||||
beforeEach(() => {
|
||||
// Mock admin role
|
||||
getMeSpy.mockResolvedValue({
|
||||
org_id: "3",
|
||||
user_id: "99",
|
||||
email: "admin@example.com",
|
||||
role: "admin",
|
||||
status: "active",
|
||||
llm_api_key: "",
|
||||
max_iterations: 20,
|
||||
llm_model: "",
|
||||
llm_api_key_for_byor: null,
|
||||
llm_base_url: "",
|
||||
});
|
||||
});
|
||||
|
||||
it("should enable all input fields in basic view", async () => {
|
||||
// Arrange
|
||||
renderLlmSettingsScreen("3"); // orgId "3" returns admin role
|
||||
|
||||
// Act
|
||||
await screen.findByTestId("llm-settings-screen");
|
||||
const basicForm = screen.getByTestId("llm-settings-form-basic");
|
||||
|
||||
// Assert
|
||||
const providerInput = within(basicForm).getByTestId("llm-provider-input");
|
||||
const modelInput = within(basicForm).getByTestId("llm-model-input");
|
||||
|
||||
await waitFor(() => {
|
||||
expect(providerInput).not.toBeDisabled();
|
||||
expect(modelInput).not.toBeDisabled();
|
||||
});
|
||||
|
||||
// API key input may be hidden if OpenHands provider is selected in SaaS mode
|
||||
// If it exists, it should be enabled
|
||||
const apiKeyInput = within(basicForm).queryByTestId("llm-api-key-input");
|
||||
if (apiKeyInput) {
|
||||
expect(apiKeyInput).not.toBeDisabled();
|
||||
}
|
||||
});
|
||||
|
||||
it("should enable all input fields in advanced view", async () => {
|
||||
// Arrange
|
||||
renderLlmSettingsScreen("3");
|
||||
|
||||
// Act
|
||||
await screen.findByTestId("llm-settings-screen");
|
||||
const advancedSwitch = screen.getByTestId("advanced-settings-switch");
|
||||
|
||||
// Assert - admins can toggle between views
|
||||
expect(advancedSwitch).not.toBeDisabled();
|
||||
|
||||
await userEvent.click(advancedSwitch);
|
||||
const advancedForm = await screen.findByTestId(
|
||||
"llm-settings-form-advanced",
|
||||
);
|
||||
|
||||
// Assert
|
||||
const modelInput = within(advancedForm).getByTestId(
|
||||
"llm-custom-model-input",
|
||||
);
|
||||
const baseUrlInput = within(advancedForm).getByTestId("base-url-input");
|
||||
const condenserSwitch = within(advancedForm).getByTestId(
|
||||
"enable-memory-condenser-switch",
|
||||
);
|
||||
const confirmationSwitch = within(advancedForm).getByTestId(
|
||||
"enable-confirmation-mode-switch",
|
||||
);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(modelInput).not.toBeDisabled();
|
||||
expect(baseUrlInput).not.toBeDisabled();
|
||||
expect(condenserSwitch).not.toBeDisabled();
|
||||
expect(confirmationSwitch).not.toBeDisabled();
|
||||
});
|
||||
|
||||
// API key input may be hidden if OpenHands provider is selected in SaaS mode
|
||||
// If it exists, it should be enabled
|
||||
const apiKeyInput =
|
||||
within(advancedForm).queryByTestId("llm-api-key-input");
|
||||
if (apiKeyInput) {
|
||||
expect(apiKeyInput).not.toBeDisabled();
|
||||
}
|
||||
});
|
||||
|
||||
it("should enable submit button when form is dirty", async () => {
|
||||
// Arrange
|
||||
renderLlmSettingsScreen("3");
|
||||
|
||||
// Act
|
||||
await screen.findByTestId("llm-settings-screen");
|
||||
const submitButton = screen.getByTestId("submit-button");
|
||||
const providerInput = screen.getByTestId("llm-provider-input");
|
||||
|
||||
// Assert - initially disabled (no changes)
|
||||
expect(submitButton).toBeDisabled();
|
||||
|
||||
// Act - make a change by selecting a different provider
|
||||
await userEvent.click(providerInput);
|
||||
const openAIOption = await screen.findByText("OpenAI");
|
||||
await userEvent.click(openAIOption);
|
||||
|
||||
// Assert - button should be enabled
|
||||
await waitFor(() => {
|
||||
expect(submitButton).not.toBeDisabled();
|
||||
});
|
||||
});
|
||||
|
||||
it("should allow submitting form changes", async () => {
|
||||
// Arrange
|
||||
const saveSettingsSpy = vi.spyOn(SettingsService, "saveSettings");
|
||||
renderLlmSettingsScreen("3");
|
||||
|
||||
// Act
|
||||
await screen.findByTestId("llm-settings-screen");
|
||||
const providerInput = screen.getByTestId("llm-provider-input");
|
||||
const modelInput = screen.getByTestId("llm-model-input");
|
||||
|
||||
// Select a different provider to make form dirty
|
||||
await userEvent.click(providerInput);
|
||||
const openAIOption = await screen.findByText("OpenAI");
|
||||
await userEvent.click(openAIOption);
|
||||
await waitFor(() => {
|
||||
expect(providerInput).toHaveValue("OpenAI");
|
||||
});
|
||||
|
||||
// Select a different model to ensure form is dirty
|
||||
await userEvent.click(modelInput);
|
||||
const modelOption = await screen.findByText("gpt-4o");
|
||||
await userEvent.click(modelOption);
|
||||
await waitFor(() => {
|
||||
expect(modelInput).toHaveValue("gpt-4o");
|
||||
});
|
||||
|
||||
// Wait for form to be marked as dirty
|
||||
const submitButton = await screen.findByTestId("submit-button");
|
||||
await waitFor(() => {
|
||||
expect(submitButton).not.toBeDisabled();
|
||||
});
|
||||
|
||||
await userEvent.click(submitButton);
|
||||
|
||||
// Assert
|
||||
await waitFor(() => {
|
||||
expect(saveSettingsSpy).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("clientLoader permission checks", () => {
|
||||
it("should export a clientLoader for route protection", async () => {
|
||||
// This test verifies the clientLoader is exported for consistency with other routes
|
||||
// Note: All roles have view_llm_settings permission, so this guard ensures
|
||||
// the route is protected and can be restricted in the future if needed
|
||||
const { clientLoader } = await import("#/routes/llm-settings");
|
||||
expect(clientLoader).toBeDefined();
|
||||
expect(typeof clientLoader).toBe("function");
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,954 +0,0 @@
|
||||
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
||||
import { render, screen, waitFor, within } from "@testing-library/react";
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import userEvent from "@testing-library/user-event";
|
||||
import { createRoutesStub } from "react-router";
|
||||
import { selectOrganization } from "test-utils";
|
||||
import ManageOrg from "#/routes/manage-org";
|
||||
import { organizationService } from "#/api/organization-service/organization-service.api";
|
||||
import SettingsScreen, { clientLoader } from "#/routes/settings";
|
||||
import {
|
||||
resetOrgMockData,
|
||||
MOCK_TEAM_ORG_ACME,
|
||||
INITIAL_MOCK_ORGS,
|
||||
} from "#/mocks/org-handlers";
|
||||
import OptionService from "#/api/option-service/option-service.api";
|
||||
import BillingService from "#/api/billing-service/billing-service.api";
|
||||
import { OrganizationMember } from "#/types/org";
|
||||
import { useSelectedOrganizationStore } from "#/stores/selected-organization-store";
|
||||
import { createMockWebClientConfig } from "#/mocks/settings-handlers";
|
||||
|
||||
const mockQueryClient = vi.hoisted(() => {
|
||||
const { QueryClient } = require("@tanstack/react-query");
|
||||
return new QueryClient();
|
||||
});
|
||||
|
||||
vi.mock("#/query-client-config", () => ({
|
||||
queryClient: mockQueryClient,
|
||||
}));
|
||||
|
||||
function ManageOrgWithPortalRoot() {
|
||||
return (
|
||||
<div>
|
||||
<ManageOrg />
|
||||
<div data-testid="portal-root" id="portal-root" />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
const RouteStub = createRoutesStub([
|
||||
{
|
||||
Component: () => <div data-testid="home-screen" />,
|
||||
path: "/",
|
||||
},
|
||||
{
|
||||
// @ts-expect-error - type mismatch
|
||||
loader: clientLoader,
|
||||
Component: SettingsScreen,
|
||||
path: "/settings",
|
||||
HydrateFallback: () => <div>Loading...</div>,
|
||||
children: [
|
||||
{
|
||||
Component: ManageOrgWithPortalRoot,
|
||||
path: "/settings/org",
|
||||
},
|
||||
],
|
||||
},
|
||||
]);
|
||||
|
||||
let queryClient: QueryClient;
|
||||
|
||||
const renderManageOrg = () =>
|
||||
render(<RouteStub initialEntries={["/settings/org"]} />, {
|
||||
wrapper: ({ children }) => (
|
||||
<QueryClientProvider client={queryClient}>{children}</QueryClientProvider>
|
||||
),
|
||||
});
|
||||
|
||||
const { navigateMock } = vi.hoisted(() => ({
|
||||
navigateMock: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock("react-i18next", async () => {
|
||||
const actual =
|
||||
await vi.importActual<typeof import("react-i18next")>("react-i18next");
|
||||
return {
|
||||
...actual,
|
||||
useTranslation: () => ({
|
||||
t: (key: string) => {
|
||||
const translations: Record<string, string> = {
|
||||
ORG$SELECT_ORGANIZATION_PLACEHOLDER: "Please select an organization",
|
||||
ORG$PERSONAL_WORKSPACE: "Personal Workspace",
|
||||
};
|
||||
return translations[key] || key;
|
||||
},
|
||||
i18n: {
|
||||
changeLanguage: vi.fn(),
|
||||
},
|
||||
}),
|
||||
};
|
||||
});
|
||||
|
||||
vi.mock("react-router", async () => ({
|
||||
...(await vi.importActual("react-router")),
|
||||
useNavigate: () => navigateMock,
|
||||
}));
|
||||
|
||||
vi.mock("#/hooks/query/use-is-authed", () => ({
|
||||
useIsAuthed: () => ({ data: true }),
|
||||
}));
|
||||
|
||||
describe("Manage Org Route", () => {
|
||||
const getMeSpy = vi.spyOn(organizationService, "getMe");
|
||||
|
||||
// Test data constants
|
||||
const TEST_USERS: Record<"OWNER" | "ADMIN", OrganizationMember> = {
|
||||
OWNER: {
|
||||
org_id: "1",
|
||||
user_id: "1",
|
||||
email: "test@example.com",
|
||||
role: "owner",
|
||||
llm_api_key: "**********",
|
||||
max_iterations: 20,
|
||||
llm_model: "gpt-4",
|
||||
llm_api_key_for_byor: null,
|
||||
llm_base_url: "https://api.openai.com",
|
||||
status: "active",
|
||||
},
|
||||
ADMIN: {
|
||||
org_id: "1",
|
||||
user_id: "1",
|
||||
email: "test@example.com",
|
||||
role: "admin",
|
||||
llm_api_key: "**********",
|
||||
max_iterations: 20,
|
||||
llm_model: "gpt-4",
|
||||
llm_api_key_for_byor: null,
|
||||
llm_base_url: "https://api.openai.com",
|
||||
status: "active",
|
||||
},
|
||||
};
|
||||
|
||||
// Helper function to set up user mock
|
||||
const setupUserMock = (userData: {
|
||||
org_id: string;
|
||||
user_id: string;
|
||||
email: string;
|
||||
role: "owner" | "admin" | "member";
|
||||
llm_api_key: string;
|
||||
max_iterations: number;
|
||||
llm_model: string;
|
||||
llm_api_key_for_byor: string | null;
|
||||
llm_base_url: string;
|
||||
status: "active" | "invited" | "inactive";
|
||||
}) => {
|
||||
getMeSpy.mockResolvedValue(userData);
|
||||
};
|
||||
|
||||
beforeEach(() => {
|
||||
// Set Zustand store to a team org so clientLoader's org route protection allows access
|
||||
useSelectedOrganizationStore.setState({
|
||||
organizationId: MOCK_TEAM_ORG_ACME.id,
|
||||
});
|
||||
// Seed organizations into the module-level queryClient used by clientLoader
|
||||
mockQueryClient.setQueryData(["organizations"], {
|
||||
items: [MOCK_TEAM_ORG_ACME],
|
||||
currentOrgId: MOCK_TEAM_ORG_ACME.id,
|
||||
});
|
||||
|
||||
queryClient = new QueryClient();
|
||||
// Pre-seed organizations so org selector renders immediately (avoids flaky race with API fetch)
|
||||
queryClient.setQueryData(["organizations"], {
|
||||
items: INITIAL_MOCK_ORGS,
|
||||
currentOrgId: MOCK_TEAM_ORG_ACME.id,
|
||||
});
|
||||
|
||||
const getConfigSpy = vi.spyOn(OptionService, "getConfig");
|
||||
getConfigSpy.mockResolvedValue(
|
||||
createMockWebClientConfig({
|
||||
app_mode: "saas",
|
||||
feature_flags: {
|
||||
enable_billing: true, // Enable billing by default so billing UI is shown
|
||||
hide_llm_settings: false,
|
||||
enable_jira: false,
|
||||
enable_jira_dc: false,
|
||||
enable_linear: false,
|
||||
hide_users_page: false,
|
||||
hide_billing_page: false,
|
||||
hide_integrations_page: false,
|
||||
},
|
||||
}),
|
||||
);
|
||||
|
||||
// Set default mock for user (owner role has all permissions)
|
||||
setupUserMock(TEST_USERS.OWNER);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.clearAllMocks();
|
||||
// Reset organization mock data to ensure clean state between tests
|
||||
resetOrgMockData();
|
||||
// Reset Zustand store to ensure clean state between tests
|
||||
useSelectedOrganizationStore.setState({ organizationId: null });
|
||||
// Clear module-level queryClient used by clientLoader
|
||||
mockQueryClient.clear();
|
||||
// Clear test queryClient
|
||||
queryClient?.clear();
|
||||
});
|
||||
|
||||
it("should render the available credits", async () => {
|
||||
renderManageOrg();
|
||||
await screen.findByTestId("manage-org-screen");
|
||||
|
||||
await selectOrganization({ orgIndex: 0 });
|
||||
|
||||
await waitFor(() => {
|
||||
const credits = screen.getByTestId("available-credits");
|
||||
expect(credits).toHaveTextContent("100");
|
||||
});
|
||||
});
|
||||
|
||||
it("should render account details", async () => {
|
||||
renderManageOrg();
|
||||
await screen.findByTestId("manage-org-screen");
|
||||
|
||||
await selectOrganization({ orgIndex: 0 });
|
||||
|
||||
await waitFor(() => {
|
||||
const orgName = screen.getByTestId("org-name");
|
||||
expect(orgName).toHaveTextContent("Personal Workspace");
|
||||
});
|
||||
});
|
||||
|
||||
it("should be able to add credits", async () => {
|
||||
const createCheckoutSessionSpy = vi.spyOn(
|
||||
BillingService,
|
||||
"createCheckoutSession",
|
||||
);
|
||||
|
||||
renderManageOrg();
|
||||
await screen.findByTestId("manage-org-screen");
|
||||
|
||||
await selectOrganization({ orgIndex: 0 }); // user is owner in org 1
|
||||
|
||||
expect(screen.queryByTestId("add-credits-form")).not.toBeInTheDocument();
|
||||
// Simulate adding credits — wait for permissions-dependent button
|
||||
const addCreditsButton = await waitFor(() => screen.getByText(/add/i));
|
||||
await userEvent.click(addCreditsButton);
|
||||
|
||||
const addCreditsForm = screen.getByTestId("add-credits-form");
|
||||
expect(addCreditsForm).toBeInTheDocument();
|
||||
|
||||
const amountInput = within(addCreditsForm).getByTestId("amount-input");
|
||||
const nextButton = within(addCreditsForm).getByRole("button", {
|
||||
name: /next/i,
|
||||
});
|
||||
|
||||
await userEvent.type(amountInput, "1000");
|
||||
await userEvent.click(nextButton);
|
||||
|
||||
// expect redirect to payment page
|
||||
expect(createCheckoutSessionSpy).toHaveBeenCalledWith(1000);
|
||||
|
||||
await waitFor(() =>
|
||||
expect(screen.queryByTestId("add-credits-form")).not.toBeInTheDocument(),
|
||||
);
|
||||
});
|
||||
|
||||
it("should close the modal when clicking cancel", async () => {
|
||||
const createCheckoutSessionSpy = vi.spyOn(
|
||||
BillingService,
|
||||
"createCheckoutSession",
|
||||
);
|
||||
renderManageOrg();
|
||||
await screen.findByTestId("manage-org-screen");
|
||||
|
||||
await selectOrganization({ orgIndex: 0 }); // user is owner in org 1
|
||||
|
||||
expect(screen.queryByTestId("add-credits-form")).not.toBeInTheDocument();
|
||||
// Simulate adding credits — wait for permissions-dependent button
|
||||
const addCreditsButton = await waitFor(() => screen.getByText(/add/i));
|
||||
await userEvent.click(addCreditsButton);
|
||||
|
||||
const addCreditsForm = screen.getByTestId("add-credits-form");
|
||||
expect(addCreditsForm).toBeInTheDocument();
|
||||
|
||||
const cancelButton = within(addCreditsForm).getByRole("button", {
|
||||
name: /close/i,
|
||||
});
|
||||
|
||||
await userEvent.click(cancelButton);
|
||||
|
||||
expect(screen.queryByTestId("add-credits-form")).not.toBeInTheDocument();
|
||||
expect(createCheckoutSessionSpy).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
describe("AddCreditsModal", () => {
|
||||
const openAddCreditsModal = async () => {
|
||||
const user = userEvent.setup();
|
||||
renderManageOrg();
|
||||
await screen.findByTestId("manage-org-screen");
|
||||
|
||||
await selectOrganization({ orgIndex: 0 }); // user is owner in org 1
|
||||
|
||||
const addCreditsButton = await waitFor(() => screen.getByText(/add/i));
|
||||
await user.click(addCreditsButton);
|
||||
|
||||
const addCreditsForm = screen.getByTestId("add-credits-form");
|
||||
expect(addCreditsForm).toBeInTheDocument();
|
||||
|
||||
return { user, addCreditsForm };
|
||||
};
|
||||
|
||||
describe("Button State Management", () => {
|
||||
it("should enable submit button initially when modal opens", async () => {
|
||||
await openAddCreditsModal();
|
||||
|
||||
const nextButton = screen.getByRole("button", { name: /next/i });
|
||||
expect(nextButton).not.toBeDisabled();
|
||||
});
|
||||
|
||||
it("should enable submit button when input contains invalid value", async () => {
|
||||
const { user } = await openAddCreditsModal();
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
const nextButton = screen.getByRole("button", { name: /next/i });
|
||||
|
||||
await user.type(amountInput, "-50");
|
||||
|
||||
expect(nextButton).not.toBeDisabled();
|
||||
});
|
||||
|
||||
it("should enable submit button when input contains valid value", async () => {
|
||||
const { user } = await openAddCreditsModal();
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
const nextButton = screen.getByRole("button", { name: /next/i });
|
||||
|
||||
await user.type(amountInput, "100");
|
||||
|
||||
expect(nextButton).not.toBeDisabled();
|
||||
});
|
||||
|
||||
it("should enable submit button after validation error is shown", async () => {
|
||||
const { user } = await openAddCreditsModal();
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
const nextButton = screen.getByRole("button", { name: /next/i });
|
||||
|
||||
await user.type(amountInput, "9");
|
||||
await user.click(nextButton);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId("amount-error")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
expect(nextButton).not.toBeDisabled();
|
||||
});
|
||||
});
|
||||
|
||||
describe("Input Attributes & Placeholder", () => {
|
||||
it("should have min attribute set to 10", async () => {
|
||||
await openAddCreditsModal();
|
||||
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
expect(amountInput).toHaveAttribute("min", "10");
|
||||
});
|
||||
|
||||
it("should have max attribute set to 25000", async () => {
|
||||
await openAddCreditsModal();
|
||||
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
expect(amountInput).toHaveAttribute("max", "25000");
|
||||
});
|
||||
|
||||
it("should have step attribute set to 1", async () => {
|
||||
await openAddCreditsModal();
|
||||
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
expect(amountInput).toHaveAttribute("step", "1");
|
||||
});
|
||||
});
|
||||
|
||||
describe("Error Message Display", () => {
|
||||
it("should not display error message initially when modal opens", async () => {
|
||||
await openAddCreditsModal();
|
||||
|
||||
const errorMessage = screen.queryByTestId("amount-error");
|
||||
expect(errorMessage).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should display error message after submitting amount above maximum", async () => {
|
||||
const { user } = await openAddCreditsModal();
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
const nextButton = screen.getByRole("button", { name: /next/i });
|
||||
|
||||
await user.type(amountInput, "25001");
|
||||
await user.click(nextButton);
|
||||
|
||||
await waitFor(() => {
|
||||
const errorMessage = screen.getByTestId("amount-error");
|
||||
expect(errorMessage).toHaveTextContent(
|
||||
"PAYMENT$ERROR_MAXIMUM_AMOUNT",
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
it("should display error message after submitting decimal value", async () => {
|
||||
const { user } = await openAddCreditsModal();
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
const nextButton = screen.getByRole("button", { name: /next/i });
|
||||
|
||||
await user.type(amountInput, "50.5");
|
||||
await user.click(nextButton);
|
||||
|
||||
await waitFor(() => {
|
||||
const errorMessage = screen.getByTestId("amount-error");
|
||||
expect(errorMessage).toHaveTextContent(
|
||||
"PAYMENT$ERROR_MUST_BE_WHOLE_NUMBER",
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
it("should replace error message when submitting different invalid value", async () => {
|
||||
const { user } = await openAddCreditsModal();
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
const nextButton = screen.getByRole("button", { name: /next/i });
|
||||
|
||||
await user.type(amountInput, "9");
|
||||
await user.click(nextButton);
|
||||
|
||||
await waitFor(() => {
|
||||
const errorMessage = screen.getByTestId("amount-error");
|
||||
expect(errorMessage).toHaveTextContent(
|
||||
"PAYMENT$ERROR_MINIMUM_AMOUNT",
|
||||
);
|
||||
});
|
||||
|
||||
await user.clear(amountInput);
|
||||
await user.type(amountInput, "25001");
|
||||
await user.click(nextButton);
|
||||
|
||||
await waitFor(() => {
|
||||
const errorMessage = screen.getByTestId("amount-error");
|
||||
expect(errorMessage).toHaveTextContent(
|
||||
"PAYMENT$ERROR_MAXIMUM_AMOUNT",
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("Form Submission Behavior", () => {
|
||||
it("should prevent submission when amount is invalid", async () => {
|
||||
const createCheckoutSessionSpy = vi.spyOn(
|
||||
BillingService,
|
||||
"createCheckoutSession",
|
||||
);
|
||||
const { user } = await openAddCreditsModal();
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
const nextButton = screen.getByRole("button", { name: /next/i });
|
||||
|
||||
await user.type(amountInput, "9");
|
||||
await user.click(nextButton);
|
||||
|
||||
expect(createCheckoutSessionSpy).not.toHaveBeenCalled();
|
||||
await waitFor(() => {
|
||||
const errorMessage = screen.getByTestId("amount-error");
|
||||
expect(errorMessage).toHaveTextContent(
|
||||
"PAYMENT$ERROR_MINIMUM_AMOUNT",
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
it("should call createCheckoutSession with correct amount when valid", async () => {
|
||||
const createCheckoutSessionSpy = vi.spyOn(
|
||||
BillingService,
|
||||
"createCheckoutSession",
|
||||
);
|
||||
const { user } = await openAddCreditsModal();
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
const nextButton = screen.getByRole("button", { name: /next/i });
|
||||
|
||||
await user.type(amountInput, "1000");
|
||||
await user.click(nextButton);
|
||||
|
||||
expect(createCheckoutSessionSpy).toHaveBeenCalledWith(1000);
|
||||
const errorMessage = screen.queryByTestId("amount-error");
|
||||
expect(errorMessage).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should not call createCheckoutSession when validation fails", async () => {
|
||||
const createCheckoutSessionSpy = vi.spyOn(
|
||||
BillingService,
|
||||
"createCheckoutSession",
|
||||
);
|
||||
const { user } = await openAddCreditsModal();
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
const nextButton = screen.getByRole("button", { name: /next/i });
|
||||
|
||||
await user.type(amountInput, "-50");
|
||||
await user.click(nextButton);
|
||||
|
||||
// Verify mutation was not called
|
||||
expect(createCheckoutSessionSpy).not.toHaveBeenCalled();
|
||||
await waitFor(() => {
|
||||
const errorMessage = screen.getByTestId("amount-error");
|
||||
expect(errorMessage).toHaveTextContent(
|
||||
"PAYMENT$ERROR_NEGATIVE_AMOUNT",
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
it("should close modal on successful submission", async () => {
|
||||
const createCheckoutSessionSpy = vi
|
||||
.spyOn(BillingService, "createCheckoutSession")
|
||||
.mockResolvedValue("https://checkout.stripe.com/test-session");
|
||||
|
||||
const { user } = await openAddCreditsModal();
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
const nextButton = screen.getByRole("button", { name: /next/i });
|
||||
|
||||
await user.type(amountInput, "1000");
|
||||
await user.click(nextButton);
|
||||
|
||||
expect(createCheckoutSessionSpy).toHaveBeenCalledWith(1000);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.queryByTestId("add-credits-form"),
|
||||
).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
it("should allow API call when validation passes and clear any previous errors", async () => {
|
||||
const createCheckoutSessionSpy = vi.spyOn(
|
||||
BillingService,
|
||||
"createCheckoutSession",
|
||||
);
|
||||
|
||||
const { user } = await openAddCreditsModal();
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
const nextButton = screen.getByRole("button", { name: /next/i });
|
||||
|
||||
// First submit invalid value
|
||||
await user.type(amountInput, "9");
|
||||
await user.click(nextButton);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId("amount-error")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
// Then submit valid value
|
||||
await user.clear(amountInput);
|
||||
await user.type(amountInput, "100");
|
||||
await user.click(nextButton);
|
||||
|
||||
expect(createCheckoutSessionSpy).toHaveBeenCalledWith(100);
|
||||
const errorMessage = screen.queryByTestId("amount-error");
|
||||
expect(errorMessage).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe("Edge Cases", () => {
|
||||
it("should handle zero value correctly", async () => {
|
||||
const { user } = await openAddCreditsModal();
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
const nextButton = screen.getByRole("button", { name: /next/i });
|
||||
|
||||
await user.type(amountInput, "0");
|
||||
await user.click(nextButton);
|
||||
|
||||
await waitFor(() => {
|
||||
const errorMessage = screen.getByTestId("amount-error");
|
||||
expect(errorMessage).toHaveTextContent(
|
||||
"PAYMENT$ERROR_MINIMUM_AMOUNT",
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
it("should handle whitespace-only input correctly", async () => {
|
||||
const createCheckoutSessionSpy = vi.spyOn(
|
||||
BillingService,
|
||||
"createCheckoutSession",
|
||||
);
|
||||
const { user } = await openAddCreditsModal();
|
||||
const amountInput = screen.getByTestId("amount-input");
|
||||
const nextButton = screen.getByRole("button", { name: /next/i });
|
||||
|
||||
// Number inputs typically don't accept spaces, but test the behavior
|
||||
await user.type(amountInput, " ");
|
||||
await user.click(nextButton);
|
||||
|
||||
// Should not call API (empty/invalid input)
|
||||
expect(createCheckoutSessionSpy).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
it("should show add credits option for ADMIN role", async () => {
|
||||
renderManageOrg();
|
||||
await screen.findByTestId("manage-org-screen");
|
||||
|
||||
await selectOrganization({ orgIndex: 3 }); // user is admin in org 4 (All Hands AI)
|
||||
|
||||
// Verify credits are shown
|
||||
await waitFor(() => {
|
||||
const credits = screen.getByTestId("available-credits");
|
||||
expect(credits).toBeInTheDocument();
|
||||
});
|
||||
|
||||
// Verify add credits button is present (admins can add credits)
|
||||
const addButton = screen.getByText(/add/i);
|
||||
expect(addButton).toBeInTheDocument();
|
||||
});
|
||||
|
||||
describe("actions", () => {
|
||||
it("should be able to update the organization name", async () => {
|
||||
const updateOrgNameSpy = vi.spyOn(
|
||||
organizationService,
|
||||
"updateOrganization",
|
||||
);
|
||||
const getConfigSpy = vi.spyOn(OptionService, "getConfig");
|
||||
|
||||
getConfigSpy.mockResolvedValue(
|
||||
createMockWebClientConfig({
|
||||
app_mode: "saas", // required to enable getMe
|
||||
}),
|
||||
);
|
||||
|
||||
renderManageOrg();
|
||||
await screen.findByTestId("manage-org-screen");
|
||||
|
||||
await selectOrganization({ orgIndex: 0 });
|
||||
|
||||
const orgName = screen.getByTestId("org-name");
|
||||
await waitFor(() =>
|
||||
expect(orgName).toHaveTextContent("Personal Workspace"),
|
||||
);
|
||||
|
||||
expect(
|
||||
screen.queryByTestId("update-org-name-form"),
|
||||
).not.toBeInTheDocument();
|
||||
|
||||
const changeOrgNameButton = within(orgName).getByRole("button", {
|
||||
name: /change/i,
|
||||
});
|
||||
await userEvent.click(changeOrgNameButton);
|
||||
|
||||
const orgNameForm = screen.getByTestId("update-org-name-form");
|
||||
const orgNameInput = within(orgNameForm).getByRole("textbox");
|
||||
const saveButton = within(orgNameForm).getByRole("button", {
|
||||
name: /save/i,
|
||||
});
|
||||
|
||||
await userEvent.type(orgNameInput, "New Org Name");
|
||||
await userEvent.click(saveButton);
|
||||
|
||||
expect(updateOrgNameSpy).toHaveBeenCalledWith({
|
||||
orgId: "1",
|
||||
name: "New Org Name",
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.queryByTestId("update-org-name-form"),
|
||||
).not.toBeInTheDocument();
|
||||
expect(orgName).toHaveTextContent("New Org Name");
|
||||
});
|
||||
});
|
||||
|
||||
it("should NOT allow roles other than owners to change org name", async () => {
|
||||
// Set admin role before rendering
|
||||
setupUserMock(TEST_USERS.ADMIN);
|
||||
|
||||
renderManageOrg();
|
||||
await screen.findByTestId("manage-org-screen");
|
||||
|
||||
await selectOrganization({ orgIndex: 3 }); // user is admin in org 4 (All Hands AI)
|
||||
|
||||
const orgName = screen.getByTestId("org-name");
|
||||
const changeOrgNameButton = within(orgName).queryByRole("button", {
|
||||
name: /change/i,
|
||||
});
|
||||
expect(changeOrgNameButton).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should NOT allow roles other than owners to delete an organization", async () => {
|
||||
setupUserMock(TEST_USERS.ADMIN);
|
||||
|
||||
const getConfigSpy = vi.spyOn(OptionService, "getConfig");
|
||||
getConfigSpy.mockResolvedValue(
|
||||
createMockWebClientConfig({
|
||||
app_mode: "saas", // required to enable getMe
|
||||
}),
|
||||
);
|
||||
|
||||
renderManageOrg();
|
||||
await screen.findByTestId("manage-org-screen");
|
||||
|
||||
await selectOrganization({ orgIndex: 3 }); // user is admin in org 4 (All Hands AI)
|
||||
|
||||
const deleteOrgButton = screen.queryByRole("button", {
|
||||
name: /ORG\$DELETE_ORGANIZATION/i,
|
||||
});
|
||||
expect(deleteOrgButton).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should be able to delete an organization", async () => {
|
||||
const deleteOrgSpy = vi.spyOn(organizationService, "deleteOrganization");
|
||||
|
||||
renderManageOrg();
|
||||
await screen.findByTestId("manage-org-screen");
|
||||
|
||||
await selectOrganization({ orgIndex: 0 });
|
||||
|
||||
expect(
|
||||
screen.queryByTestId("delete-org-confirmation"),
|
||||
).not.toBeInTheDocument();
|
||||
|
||||
const deleteOrgButton = await waitFor(() =>
|
||||
screen.getByRole("button", {
|
||||
name: /ORG\$DELETE_ORGANIZATION/i,
|
||||
}),
|
||||
);
|
||||
await userEvent.click(deleteOrgButton);
|
||||
|
||||
const deleteConfirmation = screen.getByTestId("delete-org-confirmation");
|
||||
const confirmButton = within(deleteConfirmation).getByRole("button", {
|
||||
name: /BUTTON\$CONFIRM/i,
|
||||
});
|
||||
|
||||
await userEvent.click(confirmButton);
|
||||
|
||||
expect(deleteOrgSpy).toHaveBeenCalledWith({ orgId: "1" });
|
||||
expect(
|
||||
screen.queryByTestId("delete-org-confirmation"),
|
||||
).not.toBeInTheDocument();
|
||||
|
||||
// expect to have navigated to home screen
|
||||
await screen.findByTestId("home-screen");
|
||||
});
|
||||
|
||||
it.todo("should be able to update the organization billing info");
|
||||
});
|
||||
|
||||
describe("Role-based delete organization permission behavior", () => {
|
||||
it("should show delete organization button when user has canDeleteOrganization permission (Owner role)", async () => {
|
||||
setupUserMock(TEST_USERS.OWNER);
|
||||
|
||||
renderManageOrg();
|
||||
await screen.findByTestId("manage-org-screen");
|
||||
|
||||
await selectOrganization({ orgIndex: 0 });
|
||||
|
||||
const deleteButton = await screen.findByRole("button", {
|
||||
name: /ORG\$DELETE_ORGANIZATION/i,
|
||||
});
|
||||
|
||||
expect(deleteButton).toBeInTheDocument();
|
||||
expect(deleteButton).not.toBeDisabled();
|
||||
});
|
||||
|
||||
it("should not show delete organization button when user lacks canDeleteOrganization permission ('Admin' role)", async () => {
|
||||
setupUserMock({
|
||||
org_id: "1",
|
||||
user_id: "1",
|
||||
email: "test@example.com",
|
||||
role: "admin",
|
||||
llm_api_key: "**********",
|
||||
max_iterations: 20,
|
||||
llm_model: "gpt-4",
|
||||
llm_api_key_for_byor: null,
|
||||
llm_base_url: "https://api.openai.com",
|
||||
status: "active",
|
||||
});
|
||||
|
||||
renderManageOrg();
|
||||
await screen.findByTestId("manage-org-screen");
|
||||
|
||||
await selectOrganization({ orgIndex: 0 });
|
||||
|
||||
const deleteButton = screen.queryByRole("button", {
|
||||
name: /ORG\$DELETE_ORGANIZATION/i,
|
||||
});
|
||||
|
||||
expect(deleteButton).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should not show delete organization button when user lacks canDeleteOrganization permission ('Member' role)", async () => {
|
||||
setupUserMock({
|
||||
org_id: "1",
|
||||
user_id: "1",
|
||||
email: "test@example.com",
|
||||
role: "member",
|
||||
llm_api_key: "**********",
|
||||
max_iterations: 20,
|
||||
llm_model: "gpt-4",
|
||||
llm_api_key_for_byor: null,
|
||||
llm_base_url: "https://api.openai.com",
|
||||
status: "active",
|
||||
});
|
||||
|
||||
// Members lack view_billing permission, so the clientLoader redirects away from /settings/org
|
||||
renderManageOrg();
|
||||
|
||||
// The manage-org screen should NOT be accessible — clientLoader redirects
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.queryByTestId("manage-org-screen"),
|
||||
).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
it("should open delete confirmation modal when delete button is clicked (with permission)", async () => {
|
||||
setupUserMock(TEST_USERS.OWNER);
|
||||
|
||||
renderManageOrg();
|
||||
await screen.findByTestId("manage-org-screen");
|
||||
|
||||
await selectOrganization({ orgIndex: 0 });
|
||||
|
||||
expect(
|
||||
screen.queryByTestId("delete-org-confirmation"),
|
||||
).not.toBeInTheDocument();
|
||||
|
||||
const deleteButton = await screen.findByRole("button", {
|
||||
name: /ORG\$DELETE_ORGANIZATION/i,
|
||||
});
|
||||
await userEvent.click(deleteButton);
|
||||
|
||||
expect(screen.getByTestId("delete-org-confirmation")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe("enable_billing feature flag", () => {
|
||||
it("should show credits section when enable_billing is true", async () => {
|
||||
// Arrange
|
||||
const getConfigSpy = vi.spyOn(OptionService, "getConfig");
|
||||
getConfigSpy.mockResolvedValue(
|
||||
createMockWebClientConfig({
|
||||
app_mode: "saas",
|
||||
feature_flags: {
|
||||
enable_billing: true,
|
||||
hide_llm_settings: false,
|
||||
enable_jira: false,
|
||||
enable_jira_dc: false,
|
||||
enable_linear: false,
|
||||
hide_users_page: false,
|
||||
hide_billing_page: false,
|
||||
hide_integrations_page: false,
|
||||
},
|
||||
}),
|
||||
);
|
||||
|
||||
// Act
|
||||
renderManageOrg();
|
||||
await screen.findByTestId("manage-org-screen");
|
||||
await selectOrganization({ orgIndex: 0 });
|
||||
|
||||
// Assert
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId("available-credits")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
getConfigSpy.mockRestore();
|
||||
});
|
||||
|
||||
it("should show organization name section when enable_billing is true", async () => {
|
||||
// Arrange
|
||||
const getConfigSpy = vi.spyOn(OptionService, "getConfig");
|
||||
getConfigSpy.mockResolvedValue(
|
||||
createMockWebClientConfig({
|
||||
app_mode: "saas",
|
||||
feature_flags: {
|
||||
enable_billing: true,
|
||||
hide_llm_settings: false,
|
||||
enable_jira: false,
|
||||
enable_jira_dc: false,
|
||||
enable_linear: false,
|
||||
hide_users_page: false,
|
||||
hide_billing_page: false,
|
||||
hide_integrations_page: false,
|
||||
},
|
||||
}),
|
||||
);
|
||||
|
||||
// Act
|
||||
renderManageOrg();
|
||||
await screen.findByTestId("manage-org-screen");
|
||||
await selectOrganization({ orgIndex: 0 });
|
||||
|
||||
// Assert
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId("org-name")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
getConfigSpy.mockRestore();
|
||||
});
|
||||
|
||||
it("should show Add Credits button when enable_billing is true", async () => {
|
||||
// Arrange
|
||||
const getConfigSpy = vi.spyOn(OptionService, "getConfig");
|
||||
getConfigSpy.mockResolvedValue(
|
||||
createMockWebClientConfig({
|
||||
app_mode: "saas",
|
||||
feature_flags: {
|
||||
enable_billing: true,
|
||||
hide_llm_settings: false,
|
||||
enable_jira: false,
|
||||
enable_jira_dc: false,
|
||||
enable_linear: false,
|
||||
hide_users_page: false,
|
||||
hide_billing_page: false,
|
||||
hide_integrations_page: false,
|
||||
},
|
||||
}),
|
||||
);
|
||||
|
||||
// Act
|
||||
renderManageOrg();
|
||||
await screen.findByTestId("manage-org-screen");
|
||||
await selectOrganization({ orgIndex: 0 });
|
||||
|
||||
// Assert
|
||||
await waitFor(() => {
|
||||
const addButton = screen.getByText(/add/i);
|
||||
expect(addButton).toBeInTheDocument();
|
||||
});
|
||||
|
||||
getConfigSpy.mockRestore();
|
||||
});
|
||||
|
||||
it("should hide all billing-related elements when enable_billing is false", async () => {
|
||||
// Arrange
|
||||
const getConfigSpy = vi.spyOn(OptionService, "getConfig");
|
||||
getConfigSpy.mockResolvedValue(
|
||||
createMockWebClientConfig({
|
||||
app_mode: "saas",
|
||||
feature_flags: {
|
||||
enable_billing: false,
|
||||
hide_llm_settings: false,
|
||||
enable_jira: false,
|
||||
enable_jira_dc: false,
|
||||
enable_linear: false,
|
||||
hide_users_page: false,
|
||||
hide_billing_page: false,
|
||||
hide_integrations_page: false,
|
||||
},
|
||||
}),
|
||||
);
|
||||
|
||||
// Act
|
||||
renderManageOrg();
|
||||
await screen.findByTestId("manage-org-screen");
|
||||
await selectOrganization({ orgIndex: 0 });
|
||||
|
||||
// Assert
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.queryByTestId("available-credits"),
|
||||
).not.toBeInTheDocument();
|
||||
expect(screen.queryByText(/add/i)).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
getConfigSpy.mockRestore();
|
||||
});
|
||||
});
|
||||
});
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,10 +0,0 @@
|
||||
import { describe, expect, it } from "vitest";
|
||||
import { clientLoader } from "#/routes/mcp-settings";
|
||||
|
||||
describe("clientLoader permission checks", () => {
|
||||
it("should export a clientLoader for route protection", () => {
|
||||
// This test verifies the clientLoader is exported (for consistency with other routes)
|
||||
expect(clientLoader).toBeDefined();
|
||||
expect(typeof clientLoader).toBe("function");
|
||||
});
|
||||
});
|
||||
@@ -1,5 +1,5 @@
|
||||
import { screen, act } from "@testing-library/react";
|
||||
import { describe, expect, it, vi, beforeEach, afterEach } from "vitest";
|
||||
import { screen } from "@testing-library/react";
|
||||
import { describe, expect, it, vi, beforeEach } from "vitest";
|
||||
import PlannerTab from "#/routes/planner-tab";
|
||||
import { renderWithProviders } from "../../test-utils";
|
||||
import { useConversationStore } from "#/stores/conversation-store";
|
||||
@@ -12,15 +12,8 @@ vi.mock("#/hooks/use-handle-plan-click", () => ({
|
||||
}));
|
||||
|
||||
describe("PlannerTab", () => {
|
||||
const originalRAF = global.requestAnimationFrame;
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
// Make requestAnimationFrame execute synchronously for testing
|
||||
global.requestAnimationFrame = (cb: FrameRequestCallback) => {
|
||||
cb(0);
|
||||
return 0;
|
||||
};
|
||||
// Reset store state to defaults
|
||||
useConversationStore.setState({
|
||||
planContent: null,
|
||||
@@ -28,10 +21,6 @@ describe("PlannerTab", () => {
|
||||
});
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
global.requestAnimationFrame = originalRAF;
|
||||
});
|
||||
|
||||
describe("Create a plan button", () => {
|
||||
it("should be enabled when conversation mode is 'code'", () => {
|
||||
// Arrange
|
||||
@@ -63,71 +52,4 @@ describe("PlannerTab", () => {
|
||||
expect(button).toBeDisabled();
|
||||
});
|
||||
});
|
||||
|
||||
describe("Auto-scroll behavior", () => {
|
||||
it("should scroll to bottom when plan content is updated", () => {
|
||||
// Arrange
|
||||
const scrollTopSetter = vi.fn();
|
||||
const mockScrollHeight = 500;
|
||||
|
||||
// Mock scroll properties on HTMLElement prototype
|
||||
const originalScrollHeightDescriptor = Object.getOwnPropertyDescriptor(
|
||||
HTMLElement.prototype,
|
||||
"scrollHeight",
|
||||
);
|
||||
const originalScrollTopDescriptor = Object.getOwnPropertyDescriptor(
|
||||
HTMLElement.prototype,
|
||||
"scrollTop",
|
||||
);
|
||||
|
||||
Object.defineProperty(HTMLElement.prototype, "scrollHeight", {
|
||||
get: () => mockScrollHeight,
|
||||
configurable: true,
|
||||
});
|
||||
Object.defineProperty(HTMLElement.prototype, "scrollTop", {
|
||||
get: () => 0,
|
||||
set: scrollTopSetter,
|
||||
configurable: true,
|
||||
});
|
||||
|
||||
try {
|
||||
// Render with initial plan content
|
||||
useConversationStore.setState({
|
||||
planContent: "# Initial Plan",
|
||||
conversationMode: "plan",
|
||||
});
|
||||
|
||||
renderWithProviders(<PlannerTab />);
|
||||
|
||||
// Clear calls from initial render
|
||||
scrollTopSetter.mockClear();
|
||||
|
||||
// Act - Update plan content which should trigger auto-scroll
|
||||
act(() => {
|
||||
useConversationStore.setState({
|
||||
planContent: "# Updated Plan\n\nMore content added here.",
|
||||
});
|
||||
});
|
||||
|
||||
// Assert - scrollTop should be set to scrollHeight
|
||||
expect(scrollTopSetter).toHaveBeenCalledWith(mockScrollHeight);
|
||||
} finally {
|
||||
// Restore original descriptors
|
||||
if (originalScrollHeightDescriptor) {
|
||||
Object.defineProperty(
|
||||
HTMLElement.prototype,
|
||||
"scrollHeight",
|
||||
originalScrollHeightDescriptor,
|
||||
);
|
||||
}
|
||||
if (originalScrollTopDescriptor) {
|
||||
Object.defineProperty(
|
||||
HTMLElement.prototype,
|
||||
"scrollTop",
|
||||
originalScrollTopDescriptor,
|
||||
);
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -3,15 +3,12 @@ import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
||||
import userEvent from "@testing-library/user-event";
|
||||
import { createRoutesStub, Outlet } from "react-router";
|
||||
import SecretsSettingsScreen, { clientLoader } from "#/routes/secrets-settings";
|
||||
import SecretsSettingsScreen from "#/routes/secrets-settings";
|
||||
import { SecretsService } from "#/api/secrets-service";
|
||||
import { GetSecretsResponse } from "#/api/secrets-service.types";
|
||||
import SettingsService from "#/api/settings-service/settings-service.api";
|
||||
import OptionService from "#/api/option-service/option-service.api";
|
||||
import { MOCK_DEFAULT_USER_SETTINGS } from "#/mocks/handlers";
|
||||
import { OrganizationMember } from "#/types/org";
|
||||
import * as orgStore from "#/stores/selected-organization-store";
|
||||
import { organizationService } from "#/api/organization-service/organization-service.api";
|
||||
|
||||
const MOCK_GET_SECRETS_RESPONSE: GetSecretsResponse["custom_secrets"] = [
|
||||
{
|
||||
@@ -69,75 +66,6 @@ afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
describe("clientLoader permission checks", () => {
|
||||
const createMockUser = (
|
||||
overrides: Partial<OrganizationMember> = {},
|
||||
): OrganizationMember => ({
|
||||
org_id: "org-1",
|
||||
user_id: "user-1",
|
||||
email: "test@example.com",
|
||||
role: "member",
|
||||
llm_api_key: "",
|
||||
max_iterations: 100,
|
||||
llm_model: "gpt-4",
|
||||
llm_api_key_for_byor: null,
|
||||
llm_base_url: "",
|
||||
status: "active",
|
||||
...overrides,
|
||||
});
|
||||
|
||||
const seedActiveUser = (user: Partial<OrganizationMember>) => {
|
||||
orgStore.useSelectedOrganizationStore.setState({ organizationId: "org-1" });
|
||||
vi.spyOn(organizationService, "getMe").mockResolvedValue(
|
||||
createMockUser(user),
|
||||
);
|
||||
};
|
||||
|
||||
it("should export a clientLoader for route protection", () => {
|
||||
// This test verifies the clientLoader is exported (for consistency with other routes)
|
||||
expect(clientLoader).toBeDefined();
|
||||
expect(typeof clientLoader).toBe("function");
|
||||
});
|
||||
|
||||
it("should allow members to access secrets settings (all roles have manage_secrets)", async () => {
|
||||
// Arrange
|
||||
seedActiveUser({ role: "member" });
|
||||
|
||||
const RouterStub = createRoutesStub([
|
||||
{
|
||||
Component: SecretsSettingsScreen,
|
||||
loader: clientLoader,
|
||||
path: "/settings/secrets",
|
||||
},
|
||||
{
|
||||
Component: () => <div data-testid="user-settings-screen" />,
|
||||
path: "/settings/user",
|
||||
},
|
||||
]);
|
||||
|
||||
// Act
|
||||
render(<RouterStub initialEntries={["/settings/secrets"]} />, {
|
||||
wrapper: ({ children }) => (
|
||||
<QueryClientProvider
|
||||
client={
|
||||
new QueryClient({
|
||||
defaultOptions: { queries: { retry: false } },
|
||||
})
|
||||
}
|
||||
>
|
||||
{children}
|
||||
</QueryClientProvider>
|
||||
),
|
||||
});
|
||||
|
||||
// Assert - should stay on secrets settings page (not redirected)
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId("secrets-settings-screen")).toBeInTheDocument();
|
||||
});
|
||||
expect(screen.queryByTestId("user-settings-screen")).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe("Content", () => {
|
||||
it("should render the secrets settings screen", () => {
|
||||
renderSecretsSettings();
|
||||
|
||||
@@ -1,22 +1,22 @@
|
||||
import { render, screen, within } from "@testing-library/react";
|
||||
import { screen, within } from "@testing-library/react";
|
||||
import userEvent from "@testing-library/user-event";
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import { createRoutesStub } from "react-router";
|
||||
import { renderWithProviders } from "test-utils";
|
||||
import SettingsScreen from "#/routes/settings";
|
||||
import { PaymentForm } from "#/components/features/payment/payment-form";
|
||||
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
||||
|
||||
let queryClient: QueryClient;
|
||||
|
||||
// Mock the useSettings hook
|
||||
vi.mock("#/hooks/query/use-settings", async () => {
|
||||
const actual = await vi.importActual<typeof import("#/hooks/query/use-settings")>(
|
||||
"#/hooks/query/use-settings"
|
||||
);
|
||||
const actual = await vi.importActual<
|
||||
typeof import("#/hooks/query/use-settings")
|
||||
>("#/hooks/query/use-settings");
|
||||
return {
|
||||
...actual,
|
||||
useSettings: vi.fn().mockReturnValue({
|
||||
data: { EMAIL_VERIFIED: true },
|
||||
data: {
|
||||
EMAIL_VERIFIED: true, // Mock email as verified to prevent redirection
|
||||
},
|
||||
isLoading: false,
|
||||
}),
|
||||
};
|
||||
@@ -52,36 +52,21 @@ vi.mock("react-i18next", async () => {
|
||||
});
|
||||
|
||||
// Mock useConfig hook
|
||||
const { mockUseConfig, mockUseMe, mockUsePermission } = vi.hoisted(() => ({
|
||||
const { mockUseConfig } = vi.hoisted(() => ({
|
||||
mockUseConfig: vi.fn(),
|
||||
mockUseMe: vi.fn(),
|
||||
mockUsePermission: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock("#/hooks/query/use-config", () => ({
|
||||
useConfig: mockUseConfig,
|
||||
}));
|
||||
|
||||
vi.mock("#/hooks/query/use-me", () => ({
|
||||
useMe: mockUseMe,
|
||||
}));
|
||||
|
||||
vi.mock("#/hooks/organizations/use-permissions", () => ({
|
||||
usePermission: () => ({
|
||||
hasPermission: mockUsePermission,
|
||||
}),
|
||||
}));
|
||||
|
||||
describe("Settings Billing", () => {
|
||||
beforeEach(() => {
|
||||
queryClient = new QueryClient({
|
||||
defaultOptions: { queries: { retry: false } },
|
||||
});
|
||||
|
||||
// Set default config to OSS mode with lowercase keys
|
||||
// Set default config to OSS mode
|
||||
mockUseConfig.mockReturnValue({
|
||||
data: {
|
||||
app_mode: "oss",
|
||||
github_client_id: "123",
|
||||
posthog_client_key: "456",
|
||||
feature_flags: {
|
||||
enable_billing: false,
|
||||
hide_llm_settings: false,
|
||||
@@ -95,13 +80,6 @@ describe("Settings Billing", () => {
|
||||
},
|
||||
isLoading: false,
|
||||
});
|
||||
|
||||
mockUseMe.mockReturnValue({
|
||||
data: { role: "admin" },
|
||||
isLoading: false,
|
||||
});
|
||||
|
||||
mockUsePermission.mockReturnValue(false); // default: no billing access
|
||||
});
|
||||
|
||||
const RoutesStub = createRoutesStub([
|
||||
@@ -126,38 +104,14 @@ describe("Settings Billing", () => {
|
||||
]);
|
||||
|
||||
const renderSettingsScreen = () =>
|
||||
render(<RoutesStub initialEntries={["/settings"]} />, {
|
||||
wrapper: ({ children }) => (
|
||||
<QueryClientProvider client={queryClient}>
|
||||
{children}
|
||||
</QueryClientProvider>
|
||||
),
|
||||
});
|
||||
renderWithProviders(<RoutesStub initialEntries={["/settings/billing"]} />);
|
||||
|
||||
afterEach(() => vi.clearAllMocks());
|
||||
afterEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
it("should not render the billing tab if OSS mode", async () => {
|
||||
mockUseConfig.mockReturnValue({
|
||||
data: {
|
||||
app_mode: "oss",
|
||||
feature_flags: {
|
||||
enable_billing: true,
|
||||
hide_llm_settings: false,
|
||||
enable_jira: false,
|
||||
enable_jira_dc: false,
|
||||
enable_linear: false,
|
||||
},
|
||||
},
|
||||
isLoading: false,
|
||||
});
|
||||
|
||||
mockUseMe.mockReturnValue({
|
||||
data: { role: "admin" },
|
||||
isLoading: false,
|
||||
});
|
||||
|
||||
mockUsePermission.mockReturnValue(true);
|
||||
|
||||
// OSS mode is set by default in beforeEach
|
||||
renderSettingsScreen();
|
||||
|
||||
const navbar = await screen.findByTestId("settings-navbar");
|
||||
@@ -165,10 +119,12 @@ describe("Settings Billing", () => {
|
||||
expect(credits).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should render the billing tab if: SaaS mode, billing enabled, admin user", async () => {
|
||||
it("should render the billing tab if SaaS mode and billing is enabled", async () => {
|
||||
mockUseConfig.mockReturnValue({
|
||||
data: {
|
||||
app_mode: "saas",
|
||||
github_client_id: "123",
|
||||
posthog_client_key: "456",
|
||||
feature_flags: {
|
||||
enable_billing: true,
|
||||
hide_llm_settings: false,
|
||||
@@ -183,74 +139,33 @@ describe("Settings Billing", () => {
|
||||
isLoading: false,
|
||||
});
|
||||
|
||||
mockUseMe.mockReturnValue({
|
||||
data: { role: "admin" },
|
||||
isLoading: false,
|
||||
});
|
||||
|
||||
mockUsePermission.mockReturnValue(true);
|
||||
|
||||
renderSettingsScreen();
|
||||
|
||||
const navbar = await screen.findByTestId("settings-navbar");
|
||||
expect(within(navbar).getByText("Billing")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should NOT render the billing tab if: SaaS mode, billing is enabled, and member user", async () => {
|
||||
mockUseConfig.mockReturnValue({
|
||||
data: {
|
||||
app_mode: "saas",
|
||||
feature_flags: {
|
||||
enable_billing: true,
|
||||
hide_llm_settings: false,
|
||||
enable_jira: false,
|
||||
enable_jira_dc: false,
|
||||
enable_linear: false,
|
||||
hide_users_page: false,
|
||||
hide_billing_page: false,
|
||||
hide_integrations_page: false,
|
||||
},
|
||||
},
|
||||
isLoading: false,
|
||||
});
|
||||
|
||||
mockUseMe.mockReturnValue({
|
||||
data: { role: "member" },
|
||||
isLoading: false,
|
||||
});
|
||||
|
||||
mockUsePermission.mockReturnValue(false);
|
||||
|
||||
renderSettingsScreen();
|
||||
|
||||
const navbar = await screen.findByTestId("settings-navbar");
|
||||
expect(within(navbar).queryByText("Billing")).not.toBeInTheDocument();
|
||||
within(navbar).getByText("Billing");
|
||||
});
|
||||
|
||||
it("should render the billing settings if clicking the billing item", async () => {
|
||||
const user = userEvent.setup();
|
||||
// When enable_billing is true, the billing nav item is shown
|
||||
mockUseConfig.mockReturnValue({
|
||||
data: {
|
||||
app_mode: "saas",
|
||||
github_client_id: "123",
|
||||
posthog_client_key: "456",
|
||||
feature_flags: {
|
||||
enable_billing: true,
|
||||
hide_llm_settings: false,
|
||||
enable_jira: false,
|
||||
enable_jira_dc: false,
|
||||
enable_linear: false,
|
||||
hide_users_page: false,
|
||||
hide_billing_page: false,
|
||||
hide_integrations_page: false,
|
||||
},
|
||||
},
|
||||
isLoading: false,
|
||||
});
|
||||
|
||||
mockUseMe.mockReturnValue({
|
||||
data: { role: "admin" },
|
||||
isLoading: false,
|
||||
});
|
||||
|
||||
mockUsePermission.mockReturnValue(true);
|
||||
|
||||
renderSettingsScreen();
|
||||
|
||||
const navbar = await screen.findByTestId("settings-navbar");
|
||||
|
||||
@@ -1,16 +1,13 @@
|
||||
import { render, screen, waitFor, within } from "@testing-library/react";
|
||||
import { render, screen, within } from "@testing-library/react";
|
||||
import { createRoutesStub } from "react-router";
|
||||
import { describe, expect, it, vi, beforeEach, afterEach } from "vitest";
|
||||
import { beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import { QueryClientProvider } from "@tanstack/react-query";
|
||||
import SettingsScreen, { clientLoader } from "#/routes/settings";
|
||||
import { getFirstAvailablePath } from "#/utils/settings-utils";
|
||||
import SettingsScreen, {
|
||||
clientLoader,
|
||||
getFirstAvailablePath,
|
||||
} from "#/routes/settings";
|
||||
import OptionService from "#/api/option-service/option-service.api";
|
||||
import { OrganizationMember } from "#/types/org";
|
||||
import { organizationService } from "#/api/organization-service/organization-service.api";
|
||||
import { MOCK_PERSONAL_ORG, MOCK_TEAM_ORG_ACME } from "#/mocks/org-handlers";
|
||||
import { useSelectedOrganizationStore } from "#/stores/selected-organization-store";
|
||||
import { WebClientFeatureFlags } from "#/api/option-service/option.types";
|
||||
import { createMockWebClientConfig } from "#/mocks/settings-handlers";
|
||||
|
||||
// Module-level mocks using vi.hoisted
|
||||
const { handleLogoutMock, mockQueryClient } = vi.hoisted(() => ({
|
||||
@@ -60,44 +57,17 @@ vi.mock("react-i18next", async () => {
|
||||
});
|
||||
|
||||
describe("Settings Screen", () => {
|
||||
const createMockUser = (
|
||||
overrides: Partial<OrganizationMember> = {},
|
||||
): OrganizationMember => ({
|
||||
org_id: "org-1",
|
||||
user_id: "user-1",
|
||||
email: "test@example.com",
|
||||
role: "member",
|
||||
llm_api_key: "",
|
||||
max_iterations: 100,
|
||||
llm_model: "gpt-4",
|
||||
llm_api_key_for_byor: null,
|
||||
llm_base_url: "",
|
||||
status: "active",
|
||||
...overrides,
|
||||
});
|
||||
|
||||
const seedActiveUser = (user: Partial<OrganizationMember>) => {
|
||||
useSelectedOrganizationStore.setState({ organizationId: "org-1" });
|
||||
vi.spyOn(organizationService, "getMe").mockResolvedValue(
|
||||
createMockUser(user),
|
||||
);
|
||||
};
|
||||
|
||||
const RouterStub = createRoutesStub([
|
||||
{
|
||||
Component: SettingsScreen,
|
||||
// @ts-expect-error - custom loader
|
||||
loader: clientLoader,
|
||||
clientLoader,
|
||||
path: "/settings",
|
||||
children: [
|
||||
{
|
||||
Component: () => <div data-testid="llm-settings-screen" />,
|
||||
path: "/settings",
|
||||
},
|
||||
{
|
||||
Component: () => <div data-testid="user-settings-screen" />,
|
||||
path: "/settings/user",
|
||||
},
|
||||
{
|
||||
Component: () => <div data-testid="git-settings-screen" />,
|
||||
path: "/settings/integrations",
|
||||
@@ -114,15 +84,6 @@ describe("Settings Screen", () => {
|
||||
Component: () => <div data-testid="api-keys-settings-screen" />,
|
||||
path: "/settings/api-keys",
|
||||
},
|
||||
{
|
||||
Component: () => <div data-testid="org-members-settings-screen" />,
|
||||
path: "/settings/org-members",
|
||||
handle: { hideTitle: true },
|
||||
},
|
||||
{
|
||||
Component: () => <div data-testid="organization-settings-screen" />,
|
||||
path: "/settings/org",
|
||||
},
|
||||
],
|
||||
},
|
||||
]);
|
||||
@@ -168,21 +129,11 @@ describe("Settings Screen", () => {
|
||||
});
|
||||
|
||||
it("should render the saas navbar", async () => {
|
||||
const saasConfig = {
|
||||
app_mode: "saas",
|
||||
feature_flags: {
|
||||
enable_billing: true,
|
||||
hide_llm_settings: false,
|
||||
enable_jira: false,
|
||||
enable_jira_dc: false,
|
||||
enable_linear: false,
|
||||
},
|
||||
};
|
||||
const saasConfig = { app_mode: "saas" };
|
||||
|
||||
// Clear any existing query data and set the config
|
||||
mockQueryClient.clear();
|
||||
mockQueryClient.setQueryData(["web-client-config"], saasConfig);
|
||||
seedActiveUser({ role: "admin" });
|
||||
|
||||
const sectionsToInclude = [
|
||||
"llm", // LLM settings are now always shown in SaaS mode
|
||||
@@ -198,9 +149,6 @@ describe("Settings Screen", () => {
|
||||
renderSettingsScreen();
|
||||
|
||||
const navbar = await screen.findByTestId("settings-navbar");
|
||||
await waitFor(() => {
|
||||
expect(within(navbar).getByText("Billing")).toBeInTheDocument();
|
||||
});
|
||||
sectionsToInclude.forEach((section) => {
|
||||
const sectionElement = within(navbar).getByText(section, {
|
||||
exact: false, // case insensitive
|
||||
@@ -252,367 +200,12 @@ describe("Settings Screen", () => {
|
||||
|
||||
it.todo("should not be able to access oss-only routes in saas mode");
|
||||
|
||||
describe("Personal org vs team org visibility", () => {
|
||||
it("should not show Organization and Organization Members settings items when personal org is selected", async () => {
|
||||
vi.spyOn(organizationService, "getOrganizations").mockResolvedValue({
|
||||
items: [MOCK_PERSONAL_ORG],
|
||||
currentOrgId: MOCK_PERSONAL_ORG.id,
|
||||
});
|
||||
vi.spyOn(organizationService, "getMe").mockResolvedValue({
|
||||
org_id: "1",
|
||||
user_id: "99",
|
||||
email: "me@test.com",
|
||||
role: "admin",
|
||||
llm_api_key: "**********",
|
||||
max_iterations: 20,
|
||||
llm_model: "gpt-4",
|
||||
llm_api_key_for_byor: null,
|
||||
llm_base_url: "https://api.openai.com",
|
||||
status: "active",
|
||||
});
|
||||
|
||||
renderSettingsScreen();
|
||||
|
||||
const navbar = await screen.findByTestId("settings-navbar");
|
||||
|
||||
// Organization and Organization Members should NOT be visible for personal org
|
||||
expect(
|
||||
within(navbar).queryByText("Organization Members"),
|
||||
).not.toBeInTheDocument();
|
||||
expect(
|
||||
within(navbar).queryByText("Organization"),
|
||||
).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should not show Billing settings item when team org is selected", async () => {
|
||||
// Set up SaaS mode (which has Billing in nav items)
|
||||
mockQueryClient.clear();
|
||||
mockQueryClient.setQueryData(["web-client-config"], { app_mode: "saas" });
|
||||
// Pre-select the team org in the query client and Zustand store
|
||||
mockQueryClient.setQueryData(["organizations"], {
|
||||
items: [MOCK_TEAM_ORG_ACME],
|
||||
currentOrgId: MOCK_TEAM_ORG_ACME.id,
|
||||
});
|
||||
useSelectedOrganizationStore.setState({ organizationId: "2" });
|
||||
|
||||
vi.spyOn(organizationService, "getOrganizations").mockResolvedValue({
|
||||
items: [MOCK_TEAM_ORG_ACME],
|
||||
currentOrgId: MOCK_TEAM_ORG_ACME.id,
|
||||
});
|
||||
vi.spyOn(organizationService, "getMe").mockResolvedValue({
|
||||
org_id: "2",
|
||||
user_id: "99",
|
||||
email: "me@test.com",
|
||||
role: "admin",
|
||||
llm_api_key: "**********",
|
||||
max_iterations: 20,
|
||||
llm_model: "gpt-4",
|
||||
llm_api_key_for_byor: null,
|
||||
llm_base_url: "https://api.openai.com",
|
||||
status: "active",
|
||||
});
|
||||
|
||||
renderSettingsScreen();
|
||||
|
||||
const navbar = await screen.findByTestId("settings-navbar");
|
||||
|
||||
// Wait for orgs to load, then verify Billing is hidden for team orgs
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
within(navbar).queryByText("Billing", { exact: false }),
|
||||
).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
it("should not allow direct URL access to /settings/org when personal org is selected", async () => {
|
||||
// Set up orgs in query client so clientLoader can access them
|
||||
mockQueryClient.setQueryData(["organizations"], {
|
||||
items: [MOCK_PERSONAL_ORG],
|
||||
currentOrgId: MOCK_PERSONAL_ORG.id,
|
||||
});
|
||||
// Use Zustand store instead of query client for selected org ID
|
||||
// This is the correct pattern - the query client key ["selected_organization"] is never set in production
|
||||
useSelectedOrganizationStore.setState({ organizationId: "1" });
|
||||
|
||||
vi.spyOn(organizationService, "getOrganizations").mockResolvedValue({
|
||||
items: [MOCK_PERSONAL_ORG],
|
||||
currentOrgId: MOCK_PERSONAL_ORG.id,
|
||||
});
|
||||
vi.spyOn(organizationService, "getMe").mockResolvedValue({
|
||||
org_id: "1",
|
||||
user_id: "99",
|
||||
email: "me@test.com",
|
||||
role: "admin",
|
||||
llm_api_key: "**********",
|
||||
max_iterations: 20,
|
||||
llm_model: "gpt-4",
|
||||
llm_api_key_for_byor: null,
|
||||
llm_base_url: "https://api.openai.com",
|
||||
status: "active",
|
||||
});
|
||||
|
||||
renderSettingsScreen("/settings/org");
|
||||
|
||||
// Should redirect away from org settings for personal org
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.queryByTestId("organization-settings-screen"),
|
||||
).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
it("should not allow direct URL access to /settings/org-members when personal org is selected", async () => {
|
||||
// Set up config and organizations in query client so clientLoader can access them
|
||||
mockQueryClient.setQueryData(["web-client-config"], { app_mode: "saas" });
|
||||
mockQueryClient.setQueryData(["organizations"], {
|
||||
items: [MOCK_PERSONAL_ORG],
|
||||
currentOrgId: MOCK_PERSONAL_ORG.id,
|
||||
});
|
||||
// Use Zustand store for selected org ID
|
||||
useSelectedOrganizationStore.setState({ organizationId: "1" });
|
||||
|
||||
// Mock getMe so getActiveOrganizationUser returns admin
|
||||
vi.spyOn(organizationService, "getMe").mockResolvedValue(
|
||||
createMockUser({ role: "admin", org_id: "1" }),
|
||||
);
|
||||
|
||||
// Act: Call clientLoader directly with the REAL route path (as defined in routes.ts)
|
||||
const request = new Request("http://localhost/settings/org-members");
|
||||
// @ts-expect-error - test only needs request and params, not full loader args
|
||||
const result = await clientLoader({ request, params: {} });
|
||||
|
||||
// Assert: Should redirect away from org-members settings for personal org
|
||||
expect(result).not.toBeNull();
|
||||
expect(result).toBeInstanceOf(Response);
|
||||
const response = result as Response;
|
||||
expect(response.status).toBe(302);
|
||||
expect(response.headers.get("Location")).toBe("/settings");
|
||||
});
|
||||
|
||||
it("should not allow direct URL access to /settings/billing when team org is selected", async () => {
|
||||
// Set up orgs in query client so clientLoader can access them
|
||||
mockQueryClient.setQueryData(["organizations"], {
|
||||
items: [MOCK_TEAM_ORG_ACME],
|
||||
currentOrgId: MOCK_TEAM_ORG_ACME.id,
|
||||
});
|
||||
// Use Zustand store instead of query client for selected org ID
|
||||
useSelectedOrganizationStore.setState({ organizationId: "2" });
|
||||
|
||||
vi.spyOn(organizationService, "getOrganizations").mockResolvedValue({
|
||||
items: [MOCK_TEAM_ORG_ACME],
|
||||
currentOrgId: MOCK_TEAM_ORG_ACME.id,
|
||||
});
|
||||
vi.spyOn(organizationService, "getMe").mockResolvedValue({
|
||||
org_id: "1",
|
||||
user_id: "99",
|
||||
email: "me@test.com",
|
||||
role: "admin",
|
||||
llm_api_key: "**********",
|
||||
max_iterations: 20,
|
||||
llm_model: "gpt-4",
|
||||
llm_api_key_for_byor: null,
|
||||
llm_base_url: "https://api.openai.com",
|
||||
status: "active",
|
||||
});
|
||||
|
||||
renderSettingsScreen("/settings/billing");
|
||||
|
||||
// Should redirect away from billing settings for team org
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.queryByTestId("billing-settings-screen"),
|
||||
).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("enable_billing feature flag", () => {
|
||||
it("should show billing navigation item when enable_billing is true", async () => {
|
||||
// Arrange
|
||||
const getConfigSpy = vi.spyOn(OptionService, "getConfig");
|
||||
getConfigSpy.mockResolvedValue(
|
||||
createMockWebClientConfig({
|
||||
app_mode: "saas",
|
||||
feature_flags: {
|
||||
enable_billing: true, // When enable_billing is true, billing nav is shown
|
||||
hide_llm_settings: false,
|
||||
enable_jira: false,
|
||||
enable_jira_dc: false,
|
||||
enable_linear: false,
|
||||
hide_users_page: false,
|
||||
hide_billing_page: false,
|
||||
hide_integrations_page: false,
|
||||
},
|
||||
}),
|
||||
);
|
||||
|
||||
mockQueryClient.clear();
|
||||
// Set up personal org (billing is only shown for personal orgs, not team orgs)
|
||||
mockQueryClient.setQueryData(["organizations"], {
|
||||
items: [MOCK_PERSONAL_ORG],
|
||||
currentOrgId: MOCK_PERSONAL_ORG.id,
|
||||
});
|
||||
useSelectedOrganizationStore.setState({ organizationId: "1" });
|
||||
|
||||
vi.spyOn(organizationService, "getOrganizations").mockResolvedValue({
|
||||
items: [MOCK_PERSONAL_ORG],
|
||||
currentOrgId: MOCK_PERSONAL_ORG.id,
|
||||
});
|
||||
vi.spyOn(organizationService, "getMe").mockResolvedValue({
|
||||
org_id: "1",
|
||||
user_id: "99",
|
||||
email: "me@test.com",
|
||||
role: "admin",
|
||||
llm_api_key: "**********",
|
||||
max_iterations: 20,
|
||||
llm_model: "gpt-4",
|
||||
llm_api_key_for_byor: null,
|
||||
llm_base_url: "https://api.openai.com",
|
||||
status: "active",
|
||||
});
|
||||
|
||||
// Act
|
||||
renderSettingsScreen();
|
||||
|
||||
// Assert
|
||||
const navbar = await screen.findByTestId("settings-navbar");
|
||||
await waitFor(() => {
|
||||
expect(within(navbar).getByText("Billing")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
getConfigSpy.mockRestore();
|
||||
});
|
||||
|
||||
it("should hide billing navigation item when enable_billing is false", async () => {
|
||||
// Arrange
|
||||
const getConfigSpy = vi.spyOn(OptionService, "getConfig");
|
||||
getConfigSpy.mockResolvedValue(
|
||||
createMockWebClientConfig({
|
||||
app_mode: "saas",
|
||||
feature_flags: {
|
||||
enable_billing: false, // When enable_billing is false, billing nav is hidden
|
||||
hide_llm_settings: false,
|
||||
enable_jira: false,
|
||||
enable_jira_dc: false,
|
||||
enable_linear: false,
|
||||
hide_users_page: false,
|
||||
hide_billing_page: false,
|
||||
hide_integrations_page: false,
|
||||
},
|
||||
}),
|
||||
);
|
||||
|
||||
mockQueryClient.clear();
|
||||
|
||||
// Act
|
||||
renderSettingsScreen();
|
||||
|
||||
// Assert
|
||||
const navbar = await screen.findByTestId("settings-navbar");
|
||||
expect(within(navbar).queryByText("Billing")).not.toBeInTheDocument();
|
||||
|
||||
getConfigSpy.mockRestore();
|
||||
});
|
||||
});
|
||||
|
||||
describe("clientLoader reads org ID from Zustand store", () => {
|
||||
beforeEach(() => {
|
||||
mockQueryClient.clear();
|
||||
useSelectedOrganizationStore.setState({ organizationId: null });
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
it("should redirect away from /settings/org when personal org is selected in Zustand store", async () => {
|
||||
// Arrange: Set up config and organizations in query client
|
||||
mockQueryClient.setQueryData(["web-client-config"], { app_mode: "saas" });
|
||||
mockQueryClient.setQueryData(["organizations"], {
|
||||
items: [MOCK_PERSONAL_ORG],
|
||||
currentOrgId: MOCK_PERSONAL_ORG.id,
|
||||
});
|
||||
|
||||
// Set org ID ONLY in Zustand store (not in query client)
|
||||
// This tests that clientLoader reads from the correct source
|
||||
useSelectedOrganizationStore.setState({ organizationId: "1" });
|
||||
|
||||
// Mock getMe so getActiveOrganizationUser returns admin
|
||||
vi.spyOn(organizationService, "getMe").mockResolvedValue(
|
||||
createMockUser({ role: "admin", org_id: "1" }),
|
||||
);
|
||||
|
||||
// Act: Call clientLoader directly
|
||||
const request = new Request("http://localhost/settings/org");
|
||||
// @ts-expect-error - test only needs request and params, not full loader args
|
||||
const result = await clientLoader({ request, params: {} });
|
||||
|
||||
// Assert: Should redirect away from org settings for personal org
|
||||
expect(result).not.toBeNull();
|
||||
// In React Router, redirect returns a Response object
|
||||
expect(result).toBeInstanceOf(Response);
|
||||
const response = result as Response;
|
||||
expect(response.status).toBe(302);
|
||||
expect(response.headers.get("Location")).toBe("/settings");
|
||||
});
|
||||
|
||||
it("should redirect away from /settings/billing when team org is selected in Zustand store", async () => {
|
||||
// Arrange: Set up config and organizations in query client
|
||||
mockQueryClient.setQueryData(["web-client-config"], { app_mode: "saas" });
|
||||
mockQueryClient.setQueryData(["organizations"], {
|
||||
items: [MOCK_TEAM_ORG_ACME],
|
||||
currentOrgId: MOCK_TEAM_ORG_ACME.id,
|
||||
});
|
||||
|
||||
// Set org ID ONLY in Zustand store (not in query client)
|
||||
useSelectedOrganizationStore.setState({ organizationId: "2" });
|
||||
|
||||
// Mock getMe so getActiveOrganizationUser returns admin
|
||||
vi.spyOn(organizationService, "getMe").mockResolvedValue(
|
||||
createMockUser({ role: "admin", org_id: "2" }),
|
||||
);
|
||||
|
||||
// Act: Call clientLoader directly
|
||||
const request = new Request("http://localhost/settings/billing");
|
||||
// @ts-expect-error - test only needs request and params, not full loader args
|
||||
const result = await clientLoader({ request, params: {} });
|
||||
|
||||
// Assert: Should redirect away from billing settings for team org
|
||||
expect(result).not.toBeNull();
|
||||
expect(result).toBeInstanceOf(Response);
|
||||
const response = result as Response;
|
||||
expect(response.status).toBe(302);
|
||||
expect(response.headers.get("Location")).toBe("/settings/user");
|
||||
});
|
||||
});
|
||||
|
||||
describe("hide page feature flags", () => {
|
||||
beforeEach(() => {
|
||||
// Set up as personal org admin so billing is accessible
|
||||
mockQueryClient.setQueryData(["organizations"], {
|
||||
items: [MOCK_PERSONAL_ORG],
|
||||
currentOrgId: MOCK_PERSONAL_ORG.id,
|
||||
});
|
||||
useSelectedOrganizationStore.setState({ organizationId: "1" });
|
||||
vi.spyOn(organizationService, "getMe").mockResolvedValue({
|
||||
org_id: "1",
|
||||
user_id: "99",
|
||||
email: "me@test.com",
|
||||
role: "admin",
|
||||
llm_api_key: "**********",
|
||||
max_iterations: 20,
|
||||
llm_model: "gpt-4",
|
||||
llm_api_key_for_byor: null,
|
||||
llm_base_url: "https://api.openai.com",
|
||||
status: "active",
|
||||
});
|
||||
});
|
||||
|
||||
it("should hide users page in navbar when hide_users_page is true", async () => {
|
||||
const saasConfig = {
|
||||
app_mode: "saas",
|
||||
feature_flags: {
|
||||
enable_billing: true, // Enable billing so it's not hidden by isBillingHidden
|
||||
enable_billing: false,
|
||||
hide_llm_settings: false,
|
||||
enable_jira: false,
|
||||
enable_jira_dc: false,
|
||||
@@ -625,14 +218,6 @@ describe("Settings Screen", () => {
|
||||
|
||||
mockQueryClient.clear();
|
||||
mockQueryClient.setQueryData(["web-client-config"], saasConfig);
|
||||
// Set up personal org so billing is visible
|
||||
mockQueryClient.setQueryData(["organizations"], {
|
||||
items: [MOCK_PERSONAL_ORG],
|
||||
currentOrgId: MOCK_PERSONAL_ORG.id,
|
||||
});
|
||||
useSelectedOrganizationStore.setState({ organizationId: "1" });
|
||||
// Pre-populate user data in cache so useMe() returns admin role immediately
|
||||
mockQueryClient.setQueryData(["organizations", "1", "me"], createMockUser({ role: "admin", org_id: "1" }));
|
||||
|
||||
renderSettingsScreen();
|
||||
|
||||
@@ -653,7 +238,7 @@ describe("Settings Screen", () => {
|
||||
const saasConfig = {
|
||||
app_mode: "saas",
|
||||
feature_flags: {
|
||||
enable_billing: true,
|
||||
enable_billing: false,
|
||||
hide_llm_settings: false,
|
||||
enable_jira: false,
|
||||
enable_jira_dc: false,
|
||||
@@ -666,11 +251,6 @@ describe("Settings Screen", () => {
|
||||
|
||||
mockQueryClient.clear();
|
||||
mockQueryClient.setQueryData(["web-client-config"], saasConfig);
|
||||
mockQueryClient.setQueryData(["organizations"], {
|
||||
items: [MOCK_PERSONAL_ORG],
|
||||
currentOrgId: MOCK_PERSONAL_ORG.id,
|
||||
});
|
||||
useSelectedOrganizationStore.setState({ organizationId: "1" });
|
||||
|
||||
renderSettingsScreen();
|
||||
|
||||
@@ -691,7 +271,7 @@ describe("Settings Screen", () => {
|
||||
const saasConfig = {
|
||||
app_mode: "saas",
|
||||
feature_flags: {
|
||||
enable_billing: true,
|
||||
enable_billing: false,
|
||||
hide_llm_settings: false,
|
||||
enable_jira: false,
|
||||
enable_jira_dc: false,
|
||||
@@ -704,13 +284,6 @@ describe("Settings Screen", () => {
|
||||
|
||||
mockQueryClient.clear();
|
||||
mockQueryClient.setQueryData(["web-client-config"], saasConfig);
|
||||
mockQueryClient.setQueryData(["organizations"], {
|
||||
items: [MOCK_PERSONAL_ORG],
|
||||
currentOrgId: MOCK_PERSONAL_ORG.id,
|
||||
});
|
||||
useSelectedOrganizationStore.setState({ organizationId: "1" });
|
||||
// Pre-populate user data in cache so useMe() returns admin role immediately
|
||||
mockQueryClient.setQueryData(["organizations", "1", "me"], createMockUser({ role: "admin", org_id: "1" }));
|
||||
|
||||
renderSettingsScreen();
|
||||
|
||||
|
||||
@@ -1,51 +0,0 @@
|
||||
import { act, renderHook } from "@testing-library/react";
|
||||
import { describe, expect, it } from "vitest";
|
||||
import { useSelectedOrganizationStore } from "#/stores/selected-organization-store";
|
||||
|
||||
describe("useSelectedOrganizationStore", () => {
|
||||
it("should have null as initial organizationId", () => {
|
||||
const { result } = renderHook(() => useSelectedOrganizationStore());
|
||||
expect(result.current.organizationId).toBeNull();
|
||||
});
|
||||
|
||||
it("should update organizationId when setOrganizationId is called", () => {
|
||||
const { result } = renderHook(() => useSelectedOrganizationStore());
|
||||
|
||||
act(() => {
|
||||
result.current.setOrganizationId("org-123");
|
||||
});
|
||||
|
||||
expect(result.current.organizationId).toBe("org-123");
|
||||
});
|
||||
|
||||
it("should allow setting organizationId to null", () => {
|
||||
const { result } = renderHook(() => useSelectedOrganizationStore());
|
||||
|
||||
act(() => {
|
||||
result.current.setOrganizationId("org-123");
|
||||
});
|
||||
|
||||
expect(result.current.organizationId).toBe("org-123");
|
||||
|
||||
act(() => {
|
||||
result.current.setOrganizationId(null);
|
||||
});
|
||||
|
||||
expect(result.current.organizationId).toBeNull();
|
||||
});
|
||||
|
||||
it("should share state across multiple hook instances", () => {
|
||||
const { result: result1 } = renderHook(() =>
|
||||
useSelectedOrganizationStore(),
|
||||
);
|
||||
const { result: result2 } = renderHook(() =>
|
||||
useSelectedOrganizationStore(),
|
||||
);
|
||||
|
||||
act(() => {
|
||||
result1.current.setOrganizationId("shared-organization");
|
||||
});
|
||||
|
||||
expect(result2.current.organizationId).toBe("shared-organization");
|
||||
});
|
||||
});
|
||||
@@ -1,50 +0,0 @@
|
||||
import { describe, it, expect } from "vitest";
|
||||
import { isBillingHidden } from "#/utils/org/billing-visibility";
|
||||
import { WebClientConfig } from "#/api/option-service/option.types";
|
||||
|
||||
describe("isBillingHidden", () => {
|
||||
const createConfig = (
|
||||
featureFlagOverrides: Partial<WebClientConfig["feature_flags"]> = {},
|
||||
): WebClientConfig =>
|
||||
({
|
||||
app_mode: "saas",
|
||||
posthog_client_key: "test",
|
||||
feature_flags: {
|
||||
enable_billing: true,
|
||||
hide_llm_settings: false,
|
||||
enable_jira: false,
|
||||
enable_jira_dc: false,
|
||||
enable_linear: false,
|
||||
...featureFlagOverrides,
|
||||
},
|
||||
}) as WebClientConfig;
|
||||
|
||||
it("should return true when config is undefined (safe default)", () => {
|
||||
expect(isBillingHidden(undefined, true)).toBe(true);
|
||||
});
|
||||
|
||||
it("should return true when enable_billing is false", () => {
|
||||
const config = createConfig({ enable_billing: false });
|
||||
expect(isBillingHidden(config, true)).toBe(true);
|
||||
});
|
||||
|
||||
it("should return true when user lacks view_billing permission", () => {
|
||||
const config = createConfig();
|
||||
expect(isBillingHidden(config, false)).toBe(true);
|
||||
});
|
||||
|
||||
it("should return true when both enable_billing is false and user lacks permission", () => {
|
||||
const config = createConfig({ enable_billing: false });
|
||||
expect(isBillingHidden(config, false)).toBe(true);
|
||||
});
|
||||
|
||||
it("should return false when enable_billing is true and user has view_billing permission", () => {
|
||||
const config = createConfig();
|
||||
expect(isBillingHidden(config, true)).toBe(false);
|
||||
});
|
||||
|
||||
it("should treat enable_billing as true by default (billing visible, subject to permission)", () => {
|
||||
const config = createConfig({ enable_billing: true });
|
||||
expect(isBillingHidden(config, true)).toBe(false);
|
||||
});
|
||||
});
|
||||
@@ -1,172 +0,0 @@
|
||||
import { describe, expect, test } from "vitest";
|
||||
import {
|
||||
isValidEmail,
|
||||
getInvalidEmails,
|
||||
areAllEmailsValid,
|
||||
hasDuplicates,
|
||||
} from "#/utils/input-validation";
|
||||
|
||||
describe("isValidEmail", () => {
|
||||
describe("valid email formats", () => {
|
||||
test("accepts standard email formats", () => {
|
||||
expect(isValidEmail("user@example.com")).toBe(true);
|
||||
expect(isValidEmail("john.doe@company.org")).toBe(true);
|
||||
expect(isValidEmail("test@subdomain.domain.com")).toBe(true);
|
||||
});
|
||||
|
||||
test("accepts emails with numbers", () => {
|
||||
expect(isValidEmail("user123@example.com")).toBe(true);
|
||||
expect(isValidEmail("123user@example.com")).toBe(true);
|
||||
expect(isValidEmail("user@example123.com")).toBe(true);
|
||||
});
|
||||
|
||||
test("accepts emails with special characters in local part", () => {
|
||||
expect(isValidEmail("user.name@example.com")).toBe(true);
|
||||
expect(isValidEmail("user+tag@example.com")).toBe(true);
|
||||
expect(isValidEmail("user_name@example.com")).toBe(true);
|
||||
expect(isValidEmail("user-name@example.com")).toBe(true);
|
||||
expect(isValidEmail("user%tag@example.com")).toBe(true);
|
||||
});
|
||||
|
||||
test("accepts emails with various TLDs", () => {
|
||||
expect(isValidEmail("user@example.io")).toBe(true);
|
||||
expect(isValidEmail("user@example.co.uk")).toBe(true);
|
||||
expect(isValidEmail("user@example.travel")).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe("invalid email formats", () => {
|
||||
test("rejects empty strings", () => {
|
||||
expect(isValidEmail("")).toBe(false);
|
||||
});
|
||||
|
||||
test("rejects strings without @", () => {
|
||||
expect(isValidEmail("userexample.com")).toBe(false);
|
||||
expect(isValidEmail("user.example.com")).toBe(false);
|
||||
});
|
||||
|
||||
test("rejects strings without domain", () => {
|
||||
expect(isValidEmail("user@")).toBe(false);
|
||||
expect(isValidEmail("user@.com")).toBe(false);
|
||||
});
|
||||
|
||||
test("rejects strings without local part", () => {
|
||||
expect(isValidEmail("@example.com")).toBe(false);
|
||||
});
|
||||
|
||||
test("rejects strings without TLD", () => {
|
||||
expect(isValidEmail("user@example")).toBe(false);
|
||||
expect(isValidEmail("user@example.")).toBe(false);
|
||||
});
|
||||
|
||||
test("rejects strings with single character TLD", () => {
|
||||
expect(isValidEmail("user@example.c")).toBe(false);
|
||||
});
|
||||
|
||||
test("rejects plain text", () => {
|
||||
expect(isValidEmail("test")).toBe(false);
|
||||
expect(isValidEmail("just some text")).toBe(false);
|
||||
});
|
||||
|
||||
test("rejects emails with spaces", () => {
|
||||
expect(isValidEmail("user @example.com")).toBe(false);
|
||||
expect(isValidEmail("user@ example.com")).toBe(false);
|
||||
expect(isValidEmail(" user@example.com")).toBe(false);
|
||||
expect(isValidEmail("user@example.com ")).toBe(false);
|
||||
});
|
||||
|
||||
test("rejects emails with multiple @ symbols", () => {
|
||||
expect(isValidEmail("user@@example.com")).toBe(false);
|
||||
expect(isValidEmail("user@domain@example.com")).toBe(false);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("getInvalidEmails", () => {
|
||||
test("returns empty array when all emails are valid", () => {
|
||||
const emails = ["user@example.com", "test@domain.org"];
|
||||
expect(getInvalidEmails(emails)).toEqual([]);
|
||||
});
|
||||
|
||||
test("returns all invalid emails", () => {
|
||||
const emails = ["valid@example.com", "invalid", "test@", "another@valid.org"];
|
||||
expect(getInvalidEmails(emails)).toEqual(["invalid", "test@"]);
|
||||
});
|
||||
|
||||
test("returns all emails when none are valid", () => {
|
||||
const emails = ["invalid", "also-invalid", "no-at-symbol"];
|
||||
expect(getInvalidEmails(emails)).toEqual(emails);
|
||||
});
|
||||
|
||||
test("handles empty array", () => {
|
||||
expect(getInvalidEmails([])).toEqual([]);
|
||||
});
|
||||
|
||||
test("handles array with single invalid email", () => {
|
||||
expect(getInvalidEmails(["invalid"])).toEqual(["invalid"]);
|
||||
});
|
||||
|
||||
test("handles array with single valid email", () => {
|
||||
expect(getInvalidEmails(["valid@example.com"])).toEqual([]);
|
||||
});
|
||||
});
|
||||
|
||||
describe("areAllEmailsValid", () => {
|
||||
test("returns true when all emails are valid", () => {
|
||||
const emails = ["user@example.com", "test@domain.org", "admin@company.io"];
|
||||
expect(areAllEmailsValid(emails)).toBe(true);
|
||||
});
|
||||
|
||||
test("returns false when any email is invalid", () => {
|
||||
const emails = ["user@example.com", "invalid", "test@domain.org"];
|
||||
expect(areAllEmailsValid(emails)).toBe(false);
|
||||
});
|
||||
|
||||
test("returns false when all emails are invalid", () => {
|
||||
const emails = ["invalid", "also-invalid"];
|
||||
expect(areAllEmailsValid(emails)).toBe(false);
|
||||
});
|
||||
|
||||
test("returns true for empty array", () => {
|
||||
expect(areAllEmailsValid([])).toBe(true);
|
||||
});
|
||||
|
||||
test("returns true for single valid email", () => {
|
||||
expect(areAllEmailsValid(["valid@example.com"])).toBe(true);
|
||||
});
|
||||
|
||||
test("returns false for single invalid email", () => {
|
||||
expect(areAllEmailsValid(["invalid"])).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe("hasDuplicates", () => {
|
||||
test("returns false when all values are unique", () => {
|
||||
expect(hasDuplicates(["a@test.com", "b@test.com", "c@test.com"])).toBe(
|
||||
false,
|
||||
);
|
||||
});
|
||||
|
||||
test("returns true when duplicates exist", () => {
|
||||
expect(hasDuplicates(["a@test.com", "b@test.com", "a@test.com"])).toBe(true);
|
||||
});
|
||||
|
||||
test("returns true for case-insensitive duplicates", () => {
|
||||
expect(hasDuplicates(["User@Test.com", "user@test.com"])).toBe(true);
|
||||
expect(hasDuplicates(["A@EXAMPLE.COM", "a@example.com"])).toBe(true);
|
||||
});
|
||||
|
||||
test("returns false for empty array", () => {
|
||||
expect(hasDuplicates([])).toBe(false);
|
||||
});
|
||||
|
||||
test("returns false for single item array", () => {
|
||||
expect(hasDuplicates(["single@test.com"])).toBe(false);
|
||||
});
|
||||
|
||||
test("handles multiple duplicates", () => {
|
||||
expect(
|
||||
hasDuplicates(["a@test.com", "a@test.com", "b@test.com", "b@test.com"]),
|
||||
).toBe(true);
|
||||
});
|
||||
});
|
||||
@@ -1,79 +0,0 @@
|
||||
import { describe, expect, it, vi, beforeEach } from "vitest";
|
||||
import { PermissionKey } from "#/utils/org/permissions";
|
||||
|
||||
// Mock dependencies for getActiveOrganizationUser tests
|
||||
vi.mock("#/api/organization-service/organization-service.api", () => ({
|
||||
organizationService: {
|
||||
getMe: vi.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
vi.mock("#/stores/selected-organization-store", () => ({
|
||||
getSelectedOrganizationIdFromStore: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock("#/utils/query-client-getters", () => ({
|
||||
getMeFromQueryClient: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock("#/query-client-config", () => ({
|
||||
queryClient: {
|
||||
setQueryData: vi.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
// Import after mocks are set up
|
||||
import {
|
||||
getAvailableRolesAUserCanAssign,
|
||||
getActiveOrganizationUser,
|
||||
} from "#/utils/org/permission-checks";
|
||||
import { organizationService } from "#/api/organization-service/organization-service.api";
|
||||
import { getSelectedOrganizationIdFromStore } from "#/stores/selected-organization-store";
|
||||
import { getMeFromQueryClient } from "#/utils/query-client-getters";
|
||||
|
||||
describe("getAvailableRolesAUserCanAssign", () => {
|
||||
it("returns empty array if user has no permissions", () => {
|
||||
const result = getAvailableRolesAUserCanAssign([]);
|
||||
expect(result).toEqual([]);
|
||||
});
|
||||
|
||||
it("returns only roles the user has permission for", () => {
|
||||
const userPermissions: PermissionKey[] = [
|
||||
"change_user_role:member",
|
||||
"change_user_role:admin",
|
||||
];
|
||||
const result = getAvailableRolesAUserCanAssign(userPermissions);
|
||||
expect(result.sort()).toEqual(["admin", "member"].sort());
|
||||
});
|
||||
|
||||
it("returns all roles if user has all permissions", () => {
|
||||
const allPermissions: PermissionKey[] = [
|
||||
"change_user_role:member",
|
||||
"change_user_role:admin",
|
||||
"change_user_role:owner",
|
||||
];
|
||||
const result = getAvailableRolesAUserCanAssign(allPermissions);
|
||||
expect(result.sort()).toEqual(["member", "admin", "owner"].sort());
|
||||
});
|
||||
});
|
||||
|
||||
describe("getActiveOrganizationUser", () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
it("should return undefined when API call throws an error", async () => {
|
||||
// Arrange: orgId exists, cache is empty, API call fails
|
||||
vi.mocked(getSelectedOrganizationIdFromStore).mockReturnValue("org-1");
|
||||
vi.mocked(getMeFromQueryClient).mockReturnValue(undefined);
|
||||
vi.mocked(organizationService.getMe).mockRejectedValue(
|
||||
new Error("Network error"),
|
||||
);
|
||||
|
||||
// Act
|
||||
const result = await getActiveOrganizationUser();
|
||||
|
||||
// Assert: should return undefined instead of propagating the error
|
||||
expect(result).toBeUndefined();
|
||||
});
|
||||
});
|
||||
@@ -1,175 +0,0 @@
|
||||
import { describe, expect, it, vi, beforeEach, afterEach } from "vitest";
|
||||
import { redirect } from "react-router";
|
||||
|
||||
// Mock dependencies before importing the module under test
|
||||
vi.mock("react-router", () => ({
|
||||
redirect: vi.fn((path: string) => ({ type: "redirect", path })),
|
||||
}));
|
||||
|
||||
vi.mock("#/utils/org/permission-checks", () => ({
|
||||
getActiveOrganizationUser: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock("#/api/option-service/option-service.api", () => ({
|
||||
default: {
|
||||
getConfig: vi.fn().mockResolvedValue({
|
||||
app_mode: "saas",
|
||||
feature_flags: {
|
||||
hide_users_page: false,
|
||||
hide_billing_page: false,
|
||||
hide_integrations_page: false,
|
||||
hide_llm_settings: false,
|
||||
},
|
||||
}),
|
||||
},
|
||||
}));
|
||||
|
||||
const mockConfig = {
|
||||
app_mode: "saas",
|
||||
feature_flags: {
|
||||
hide_users_page: false,
|
||||
hide_billing_page: false,
|
||||
hide_integrations_page: false,
|
||||
hide_llm_settings: false,
|
||||
},
|
||||
};
|
||||
|
||||
vi.mock("#/query-client-config", () => ({
|
||||
queryClient: {
|
||||
getQueryData: vi.fn(() => mockConfig),
|
||||
setQueryData: vi.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
// Import after mocks are set up
|
||||
import { createPermissionGuard } from "#/utils/org/permission-guard";
|
||||
import { getActiveOrganizationUser } from "#/utils/org/permission-checks";
|
||||
|
||||
// Helper to create a mock request
|
||||
const createMockRequest = (pathname: string = "/settings/billing") => ({
|
||||
request: new Request(`http://localhost${pathname}`),
|
||||
});
|
||||
|
||||
describe("createPermissionGuard", () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.resetAllMocks();
|
||||
});
|
||||
|
||||
describe("permission checking", () => {
|
||||
it("should redirect when user lacks required permission", async () => {
|
||||
// Arrange: member lacks view_billing permission
|
||||
vi.mocked(getActiveOrganizationUser).mockResolvedValue({
|
||||
org_id: "org-1",
|
||||
user_id: "user-1",
|
||||
email: "test@example.com",
|
||||
role: "member",
|
||||
llm_api_key: "",
|
||||
max_iterations: 100,
|
||||
llm_model: "gpt-4",
|
||||
llm_api_key_for_byor: null,
|
||||
llm_base_url: "",
|
||||
status: "active",
|
||||
});
|
||||
|
||||
// Act
|
||||
const guard = createPermissionGuard("view_billing");
|
||||
await guard(createMockRequest("/settings/billing"));
|
||||
|
||||
// Assert: should redirect to first available path (/settings/user in SaaS mode)
|
||||
expect(redirect).toHaveBeenCalledWith("/settings/user");
|
||||
});
|
||||
|
||||
it("should allow access when user has required permission", async () => {
|
||||
// Arrange: admin has view_billing permission
|
||||
vi.mocked(getActiveOrganizationUser).mockResolvedValue({
|
||||
org_id: "org-1",
|
||||
user_id: "user-1",
|
||||
email: "admin@example.com",
|
||||
role: "admin",
|
||||
llm_api_key: "",
|
||||
max_iterations: 100,
|
||||
llm_model: "gpt-4",
|
||||
llm_api_key_for_byor: null,
|
||||
llm_base_url: "",
|
||||
status: "active",
|
||||
});
|
||||
|
||||
// Act
|
||||
const guard = createPermissionGuard("view_billing");
|
||||
const result = await guard(createMockRequest("/settings/billing"));
|
||||
|
||||
// Assert: should not redirect, return null
|
||||
expect(redirect).not.toHaveBeenCalled();
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
|
||||
it("should redirect when user is undefined (no org selected)", async () => {
|
||||
// Arrange: no user (e.g., no organization selected)
|
||||
vi.mocked(getActiveOrganizationUser).mockResolvedValue(undefined);
|
||||
|
||||
// Act
|
||||
const guard = createPermissionGuard("view_billing");
|
||||
await guard(createMockRequest("/settings/billing"));
|
||||
|
||||
// Assert: should redirect to first available path
|
||||
expect(redirect).toHaveBeenCalledWith("/settings/user");
|
||||
});
|
||||
|
||||
it("should redirect when user is undefined even for member-level permissions", async () => {
|
||||
// Arrange: no user — manage_secrets is a member-level permission,
|
||||
// but undefined user should NOT get member access
|
||||
vi.mocked(getActiveOrganizationUser).mockResolvedValue(undefined);
|
||||
|
||||
// Act
|
||||
const guard = createPermissionGuard("manage_secrets");
|
||||
await guard(createMockRequest("/settings/secrets"));
|
||||
|
||||
// Assert: should redirect, not silently grant member-level access
|
||||
expect(redirect).toHaveBeenCalledWith("/settings/user");
|
||||
});
|
||||
});
|
||||
|
||||
describe("custom redirect path", () => {
|
||||
it("should redirect to custom path when specified", async () => {
|
||||
// Arrange: member lacks permission
|
||||
vi.mocked(getActiveOrganizationUser).mockResolvedValue({
|
||||
org_id: "org-1",
|
||||
user_id: "user-1",
|
||||
email: "test@example.com",
|
||||
role: "member",
|
||||
llm_api_key: "",
|
||||
max_iterations: 100,
|
||||
llm_model: "gpt-4",
|
||||
llm_api_key_for_byor: null,
|
||||
llm_base_url: "",
|
||||
status: "active",
|
||||
});
|
||||
|
||||
// Act
|
||||
const guard = createPermissionGuard("view_billing", "/custom/redirect");
|
||||
await guard(createMockRequest("/settings/billing"));
|
||||
|
||||
// Assert: should redirect to custom path
|
||||
expect(redirect).toHaveBeenCalledWith("/custom/redirect");
|
||||
});
|
||||
});
|
||||
|
||||
describe("infinite loop prevention", () => {
|
||||
it("should return null instead of redirecting when fallback path equals current path", async () => {
|
||||
// Arrange: no user
|
||||
vi.mocked(getActiveOrganizationUser).mockResolvedValue(undefined);
|
||||
|
||||
// Act: access /settings/user when fallback would also be /settings/user
|
||||
const guard = createPermissionGuard("view_billing");
|
||||
const result = await guard(createMockRequest("/settings/user"));
|
||||
|
||||
// Assert: should NOT redirect to avoid infinite loop
|
||||
expect(redirect).not.toHaveBeenCalled();
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
});
|
||||
});
|
||||
Generated
+2
-2
@@ -1,12 +1,12 @@
|
||||
{
|
||||
"name": "openhands-frontend",
|
||||
"version": "1.5.0",
|
||||
"version": "1.4.0",
|
||||
"lockfileVersion": 3,
|
||||
"requires": true,
|
||||
"packages": {
|
||||
"": {
|
||||
"name": "openhands-frontend",
|
||||
"version": "1.5.0",
|
||||
"version": "1.4.0",
|
||||
"dependencies": {
|
||||
"@heroui/react": "2.8.7",
|
||||
"@microlink/react-json-view": "^1.27.1",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "openhands-frontend",
|
||||
"version": "1.5.0",
|
||||
"version": "1.4.0",
|
||||
"private": true,
|
||||
"type": "module",
|
||||
"engines": {
|
||||
|
||||
@@ -75,7 +75,7 @@ export default defineConfig({
|
||||
|
||||
/* Run your local dev server before starting the tests */
|
||||
webServer: {
|
||||
command: "npm run dev:mock:saas -- --port 3001",
|
||||
command: "npm run dev:mock -- --port 3001",
|
||||
url: "http://localhost:3001/",
|
||||
reuseExistingServer: !process.env.CI,
|
||||
},
|
||||
|
||||
@@ -1,159 +0,0 @@
|
||||
import {
|
||||
Organization,
|
||||
OrganizationMember,
|
||||
OrganizationMembersPage,
|
||||
UpdateOrganizationMemberParams,
|
||||
} from "#/types/org";
|
||||
import { openHands } from "../open-hands-axios";
|
||||
|
||||
export const organizationService = {
|
||||
getMe: async ({ orgId }: { orgId: string }) => {
|
||||
const { data } = await openHands.get<OrganizationMember>(
|
||||
`/api/organizations/${orgId}/me`,
|
||||
);
|
||||
|
||||
return data;
|
||||
},
|
||||
|
||||
getOrganization: async ({ orgId }: { orgId: string }) => {
|
||||
const { data } = await openHands.get<Organization>(
|
||||
`/api/organizations/${orgId}`,
|
||||
);
|
||||
return data;
|
||||
},
|
||||
|
||||
getOrganizations: async () => {
|
||||
const { data } = await openHands.get<{
|
||||
items: Organization[];
|
||||
current_org_id: string | null;
|
||||
}>("/api/organizations");
|
||||
return {
|
||||
items: data?.items || [],
|
||||
currentOrgId: data?.current_org_id || null,
|
||||
};
|
||||
},
|
||||
|
||||
updateOrganization: async ({
|
||||
orgId,
|
||||
name,
|
||||
}: {
|
||||
orgId: string;
|
||||
name: string;
|
||||
}) => {
|
||||
const { data } = await openHands.patch<Organization>(
|
||||
`/api/organizations/${orgId}`,
|
||||
{ name },
|
||||
);
|
||||
return data;
|
||||
},
|
||||
|
||||
deleteOrganization: async ({ orgId }: { orgId: string }) => {
|
||||
await openHands.delete(`/api/organizations/${orgId}`);
|
||||
},
|
||||
|
||||
getOrganizationMembers: async ({
|
||||
orgId,
|
||||
page = 1,
|
||||
limit = 10,
|
||||
email,
|
||||
}: {
|
||||
orgId: string;
|
||||
page?: number;
|
||||
limit?: number;
|
||||
email?: string;
|
||||
}) => {
|
||||
const params = new URLSearchParams();
|
||||
|
||||
// Calculate offset from page number (page_id is offset-based)
|
||||
const offset = (page - 1) * limit;
|
||||
params.set("page_id", String(offset));
|
||||
params.set("limit", String(limit));
|
||||
|
||||
if (email) {
|
||||
params.set("email", email);
|
||||
}
|
||||
|
||||
const { data } = await openHands.get<OrganizationMembersPage>(
|
||||
`/api/organizations/${orgId}/members?${params.toString()}`,
|
||||
);
|
||||
|
||||
return data;
|
||||
},
|
||||
|
||||
getOrganizationMembersCount: async ({
|
||||
orgId,
|
||||
email,
|
||||
}: {
|
||||
orgId: string;
|
||||
email?: string;
|
||||
}) => {
|
||||
const params = new URLSearchParams();
|
||||
|
||||
if (email) {
|
||||
params.set("email", email);
|
||||
}
|
||||
|
||||
const { data } = await openHands.get<number>(
|
||||
`/api/organizations/${orgId}/members/count?${params.toString()}`,
|
||||
);
|
||||
|
||||
return data;
|
||||
},
|
||||
|
||||
getOrganizationPaymentInfo: async ({ orgId }: { orgId: string }) => {
|
||||
const { data } = await openHands.get<{
|
||||
cardNumber: string;
|
||||
}>(`/api/organizations/${orgId}/payment`);
|
||||
return data;
|
||||
},
|
||||
|
||||
updateMember: async ({
|
||||
orgId,
|
||||
userId,
|
||||
...updateData
|
||||
}: {
|
||||
orgId: string;
|
||||
userId: string;
|
||||
} & UpdateOrganizationMemberParams) => {
|
||||
const { data } = await openHands.patch(
|
||||
`/api/organizations/${orgId}/members/${userId}`,
|
||||
updateData,
|
||||
);
|
||||
|
||||
return data;
|
||||
},
|
||||
|
||||
removeMember: async ({
|
||||
orgId,
|
||||
userId,
|
||||
}: {
|
||||
orgId: string;
|
||||
userId: string;
|
||||
}) => {
|
||||
await openHands.delete(`/api/organizations/${orgId}/members/${userId}`);
|
||||
},
|
||||
|
||||
inviteMembers: async ({
|
||||
orgId,
|
||||
emails,
|
||||
}: {
|
||||
orgId: string;
|
||||
emails: string[];
|
||||
}) => {
|
||||
const { data } = await openHands.post<OrganizationMember[]>(
|
||||
`/api/organizations/${orgId}/members/invite`,
|
||||
{
|
||||
emails,
|
||||
},
|
||||
);
|
||||
|
||||
return data;
|
||||
},
|
||||
|
||||
switchOrganization: async ({ orgId }: { orgId: string }) => {
|
||||
const { data } = await openHands.post<Organization>(
|
||||
`/api/organizations/${orgId}/switch`,
|
||||
);
|
||||
return data;
|
||||
},
|
||||
};
|
||||
@@ -1,5 +1,4 @@
|
||||
import { useTranslation } from "react-i18next";
|
||||
import { FaUserShield } from "react-icons/fa";
|
||||
import { I18nKey } from "#/i18n/declaration";
|
||||
import OpenHandsLogoWhite from "#/assets/branding/openhands-logo-white.svg?react";
|
||||
import GitHubLogo from "#/assets/branding/github-logo.svg?react";
|
||||
@@ -66,12 +65,6 @@ export function LoginContent({
|
||||
authUrl,
|
||||
});
|
||||
|
||||
const enterpriseSsoAuthUrl = useAuthUrl({
|
||||
appMode: appMode || null,
|
||||
identityProvider: "enterprise_sso",
|
||||
authUrl,
|
||||
});
|
||||
|
||||
const handleAuthRedirect = async (
|
||||
redirectUrl: string,
|
||||
provider: Provider,
|
||||
@@ -134,12 +127,6 @@ export function LoginContent({
|
||||
}
|
||||
};
|
||||
|
||||
const handleEnterpriseSsoAuth = () => {
|
||||
if (enterpriseSsoAuthUrl) {
|
||||
handleAuthRedirect(enterpriseSsoAuthUrl, "enterprise_sso");
|
||||
}
|
||||
};
|
||||
|
||||
const showGithub =
|
||||
providersConfigured &&
|
||||
providersConfigured.length > 0 &&
|
||||
@@ -156,10 +143,6 @@ export function LoginContent({
|
||||
providersConfigured &&
|
||||
providersConfigured.length > 0 &&
|
||||
providersConfigured.includes("bitbucket_data_center");
|
||||
const showEnterpriseSso =
|
||||
providersConfigured &&
|
||||
providersConfigured.length > 0 &&
|
||||
providersConfigured.includes("enterprise_sso");
|
||||
|
||||
const noProvidersConfigured =
|
||||
!providersConfigured || providersConfigured.length === 0;
|
||||
@@ -278,19 +261,6 @@ export function LoginContent({
|
||||
</span>
|
||||
</button>
|
||||
)}
|
||||
|
||||
{showEnterpriseSso && (
|
||||
<button
|
||||
type="button"
|
||||
onClick={handleEnterpriseSsoAuth}
|
||||
className={`${buttonBaseClasses} bg-[#374151] text-white`}
|
||||
>
|
||||
<FaUserShield size={14} className="shrink-0" />
|
||||
<span className={buttonLabelClasses}>
|
||||
{t(I18nKey.ENTERPRISE_SSO$CONNECT_TO_ENTERPRISE_SSO)}
|
||||
</span>
|
||||
</button>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
|
||||
@@ -0,0 +1,116 @@
|
||||
import React from "react";
|
||||
import { useTranslation } from "react-i18next";
|
||||
import { Link } from "react-router";
|
||||
import { useFeatureFlagEnabled } from "posthog-js/react";
|
||||
import { ContextMenu } from "#/ui/context-menu";
|
||||
import { ContextMenuListItem } from "./context-menu-list-item";
|
||||
import { Divider } from "#/ui/divider";
|
||||
import { useClickOutsideElement } from "#/hooks/use-click-outside-element";
|
||||
import { I18nKey } from "#/i18n/declaration";
|
||||
import LogOutIcon from "#/icons/log-out.svg?react";
|
||||
import DocumentIcon from "#/icons/document.svg?react";
|
||||
import PlusIcon from "#/icons/plus.svg?react";
|
||||
import { useSettingsNavItems } from "#/hooks/use-settings-nav-items";
|
||||
import { useConfig } from "#/hooks/query/use-config";
|
||||
import { useSettings } from "#/hooks/query/use-settings";
|
||||
import { useTracking } from "#/hooks/use-tracking";
|
||||
|
||||
interface AccountSettingsContextMenuProps {
|
||||
onLogout: () => void;
|
||||
onClose: () => void;
|
||||
}
|
||||
|
||||
export function AccountSettingsContextMenu({
|
||||
onLogout,
|
||||
onClose,
|
||||
}: AccountSettingsContextMenuProps) {
|
||||
const ref = useClickOutsideElement<HTMLUListElement>(onClose);
|
||||
const { t } = useTranslation();
|
||||
const { trackAddTeamMembersButtonClick } = useTracking();
|
||||
const { data: config } = useConfig();
|
||||
const { data: settings } = useSettings();
|
||||
const isAddTeamMemberEnabled = useFeatureFlagEnabled(
|
||||
"exp_add_team_member_button",
|
||||
);
|
||||
// Get navigation items and filter out LLM settings if the feature flag is enabled
|
||||
const items = useSettingsNavItems();
|
||||
|
||||
const isSaasMode = config?.app_mode === "saas";
|
||||
const hasAnalyticsConsent = settings?.user_consents_to_analytics === true;
|
||||
const showAddTeamMembers =
|
||||
isSaasMode && isAddTeamMemberEnabled && hasAnalyticsConsent;
|
||||
|
||||
const navItems = items.map((item) => ({
|
||||
...item,
|
||||
icon: React.cloneElement(item.icon, {
|
||||
width: 16,
|
||||
height: 16,
|
||||
} as React.SVGProps<SVGSVGElement>),
|
||||
}));
|
||||
const handleNavigationClick = () => onClose();
|
||||
|
||||
const handleAddTeamMembers = () => {
|
||||
trackAddTeamMembersButtonClick();
|
||||
onClose();
|
||||
};
|
||||
|
||||
return (
|
||||
<ContextMenu
|
||||
testId="account-settings-context-menu"
|
||||
ref={ref}
|
||||
alignment="right"
|
||||
className="mt-0 md:right-full md:left-full md:bottom-0 ml-0 w-fit z-[9999]"
|
||||
>
|
||||
{showAddTeamMembers && (
|
||||
<ContextMenuListItem
|
||||
testId="add-team-members-button"
|
||||
onClick={handleAddTeamMembers}
|
||||
className="flex items-center gap-2 p-2 hover:bg-[#5C5D62] rounded h-[30px]"
|
||||
>
|
||||
<PlusIcon width={16} height={16} />
|
||||
<span className="text-white text-sm">
|
||||
{t(I18nKey.SETTINGS$NAV_ADD_TEAM_MEMBERS)}
|
||||
</span>
|
||||
</ContextMenuListItem>
|
||||
)}
|
||||
{navItems.map(({ to, text, icon }) => (
|
||||
<Link key={to} to={to} className="text-decoration-none">
|
||||
<ContextMenuListItem
|
||||
onClick={handleNavigationClick}
|
||||
className="flex items-center gap-2 p-2 hover:bg-[#5C5D62] rounded h-[30px]"
|
||||
>
|
||||
{icon}
|
||||
<span className="text-white text-sm">{t(text)}</span>
|
||||
</ContextMenuListItem>
|
||||
</Link>
|
||||
))}
|
||||
|
||||
<Divider />
|
||||
|
||||
<a
|
||||
href="https://docs.openhands.dev"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
className="text-decoration-none"
|
||||
>
|
||||
<ContextMenuListItem
|
||||
onClick={onClose}
|
||||
className="flex items-center gap-2 p-2 hover:bg-[#5C5D62] rounded h-[30px]"
|
||||
>
|
||||
<DocumentIcon width={16} height={16} />
|
||||
<span className="text-white text-sm">{t(I18nKey.SIDEBAR$DOCS)}</span>
|
||||
</ContextMenuListItem>
|
||||
</a>
|
||||
|
||||
<ContextMenuListItem
|
||||
onClick={onLogout}
|
||||
className="flex items-center gap-2 p-2 hover:bg-[#5C5D62] rounded h-[30px]"
|
||||
>
|
||||
<LogOutIcon width={16} height={16} />
|
||||
<span className="text-white text-sm">
|
||||
{t(I18nKey.ACCOUNT_SETTINGS$LOGOUT)}
|
||||
</span>
|
||||
</ContextMenuListItem>
|
||||
</ContextMenu>
|
||||
);
|
||||
}
|
||||
@@ -1,45 +0,0 @@
|
||||
import { useState } from "react";
|
||||
import { useTranslation } from "react-i18next";
|
||||
import { OrgModal } from "#/components/shared/modals/org-modal";
|
||||
import { I18nKey } from "#/i18n/declaration";
|
||||
import { useUpdateOrganization } from "#/hooks/mutation/use-update-organization";
|
||||
|
||||
interface ChangeOrgNameModalProps {
|
||||
onClose: () => void;
|
||||
}
|
||||
|
||||
export function ChangeOrgNameModal({ onClose }: ChangeOrgNameModalProps) {
|
||||
const { t } = useTranslation();
|
||||
const { mutate: updateOrganization, isPending } = useUpdateOrganization();
|
||||
const [orgName, setOrgName] = useState<string>("");
|
||||
|
||||
const handleSubmit = () => {
|
||||
if (orgName?.trim()) {
|
||||
updateOrganization(orgName, {
|
||||
onSuccess: () => {
|
||||
onClose();
|
||||
},
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<OrgModal
|
||||
testId="update-org-name-form"
|
||||
title={t(I18nKey.ORG$CHANGE_ORG_NAME)}
|
||||
description={t(I18nKey.ORG$MODIFY_ORG_NAME_DESCRIPTION)}
|
||||
primaryButtonText={t(I18nKey.BUTTON$SAVE)}
|
||||
onPrimaryClick={handleSubmit}
|
||||
onClose={onClose}
|
||||
isLoading={isPending}
|
||||
>
|
||||
<input
|
||||
data-testid="org-name"
|
||||
value={orgName}
|
||||
placeholder={t(I18nKey.ORG$ENTER_NEW_ORGANIZATION_NAME)}
|
||||
onChange={(e) => setOrgName(e.target.value)}
|
||||
className="bg-tertiary border border-[#717888] h-10 w-full rounded-sm p-2 placeholder:italic placeholder:text-tertiary-alt"
|
||||
/>
|
||||
</OrgModal>
|
||||
);
|
||||
}
|
||||
@@ -1,42 +0,0 @@
|
||||
import { Trans, useTranslation } from "react-i18next";
|
||||
import { OrgModal } from "#/components/shared/modals/org-modal";
|
||||
import { I18nKey } from "#/i18n/declaration";
|
||||
|
||||
interface ConfirmRemoveMemberModalProps {
|
||||
onConfirm: () => void;
|
||||
onCancel: () => void;
|
||||
memberEmail: string;
|
||||
isLoading?: boolean;
|
||||
}
|
||||
|
||||
export function ConfirmRemoveMemberModal({
|
||||
onConfirm,
|
||||
onCancel,
|
||||
memberEmail,
|
||||
isLoading = false,
|
||||
}: ConfirmRemoveMemberModalProps) {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const confirmationMessage = (
|
||||
<Trans
|
||||
i18nKey={I18nKey.ORG$REMOVE_MEMBER_WARNING}
|
||||
values={{ email: memberEmail }}
|
||||
components={{ email: <span className="text-white" /> }}
|
||||
/>
|
||||
);
|
||||
|
||||
return (
|
||||
<OrgModal
|
||||
title={t(I18nKey.ORG$CONFIRM_REMOVE_MEMBER)}
|
||||
description={confirmationMessage}
|
||||
primaryButtonText={t(I18nKey.BUTTON$CONFIRM)}
|
||||
secondaryButtonText={t(I18nKey.BUTTON$CANCEL)}
|
||||
onPrimaryClick={onConfirm}
|
||||
onClose={onCancel}
|
||||
isLoading={isLoading}
|
||||
primaryButtonTestId="confirm-button"
|
||||
secondaryButtonTestId="cancel-button"
|
||||
fullWidthButtons
|
||||
/>
|
||||
);
|
||||
}
|
||||
@@ -1,47 +0,0 @@
|
||||
import { Trans, useTranslation } from "react-i18next";
|
||||
import { OrgModal } from "#/components/shared/modals/org-modal";
|
||||
import { I18nKey } from "#/i18n/declaration";
|
||||
import { OrganizationUserRole } from "#/types/org";
|
||||
|
||||
interface ConfirmUpdateRoleModalProps {
|
||||
onConfirm: () => void;
|
||||
onCancel: () => void;
|
||||
memberEmail: string;
|
||||
newRole: OrganizationUserRole;
|
||||
isLoading?: boolean;
|
||||
}
|
||||
|
||||
export function ConfirmUpdateRoleModal({
|
||||
onConfirm,
|
||||
onCancel,
|
||||
memberEmail,
|
||||
newRole,
|
||||
isLoading = false,
|
||||
}: ConfirmUpdateRoleModalProps) {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const confirmationMessage = (
|
||||
<Trans
|
||||
i18nKey={I18nKey.ORG$UPDATE_ROLE_WARNING}
|
||||
values={{ email: memberEmail, role: newRole }}
|
||||
components={{
|
||||
email: <span className="text-white" />,
|
||||
role: <span className="text-white capitalize" />,
|
||||
}}
|
||||
/>
|
||||
);
|
||||
|
||||
return (
|
||||
<OrgModal
|
||||
title={t(I18nKey.ORG$CONFIRM_UPDATE_ROLE)}
|
||||
description={confirmationMessage}
|
||||
primaryButtonText={t(I18nKey.BUTTON$CONFIRM)}
|
||||
onPrimaryClick={onConfirm}
|
||||
onClose={onCancel}
|
||||
isLoading={isLoading}
|
||||
primaryButtonTestId="confirm-button"
|
||||
secondaryButtonTestId="cancel-button"
|
||||
fullWidthButtons
|
||||
/>
|
||||
);
|
||||
}
|
||||
@@ -1,52 +0,0 @@
|
||||
import { Trans, useTranslation } from "react-i18next";
|
||||
import { OrgModal } from "#/components/shared/modals/org-modal";
|
||||
import { I18nKey } from "#/i18n/declaration";
|
||||
import { useDeleteOrganization } from "#/hooks/mutation/use-delete-organization";
|
||||
import { useOrganization } from "#/hooks/query/use-organization";
|
||||
import { displayErrorToast } from "#/utils/custom-toast-handlers";
|
||||
|
||||
interface DeleteOrgConfirmationModalProps {
|
||||
onClose: () => void;
|
||||
}
|
||||
|
||||
export function DeleteOrgConfirmationModal({
|
||||
onClose,
|
||||
}: DeleteOrgConfirmationModalProps) {
|
||||
const { t } = useTranslation();
|
||||
const { mutate: deleteOrganization, isPending } = useDeleteOrganization();
|
||||
const { data: organization } = useOrganization();
|
||||
|
||||
const handleConfirm = () => {
|
||||
deleteOrganization(undefined, {
|
||||
onSuccess: onClose,
|
||||
onError: () => {
|
||||
displayErrorToast(t(I18nKey.ORG$DELETE_ORGANIZATION_ERROR));
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
const confirmationMessage = organization?.name ? (
|
||||
<Trans
|
||||
i18nKey={I18nKey.ORG$DELETE_ORGANIZATION_WARNING_WITH_NAME}
|
||||
values={{ name: organization.name }}
|
||||
components={{ name: <span className="text-white" /> }}
|
||||
/>
|
||||
) : (
|
||||
t(I18nKey.ORG$DELETE_ORGANIZATION_WARNING)
|
||||
);
|
||||
|
||||
return (
|
||||
<OrgModal
|
||||
testId="delete-org-confirmation"
|
||||
title={t(I18nKey.ORG$DELETE_ORGANIZATION)}
|
||||
description={confirmationMessage}
|
||||
primaryButtonText={t(I18nKey.BUTTON$CONFIRM)}
|
||||
onPrimaryClick={handleConfirm}
|
||||
onClose={onClose}
|
||||
isLoading={isPending}
|
||||
secondaryButtonTestId="cancel-button"
|
||||
ariaLabel={t(I18nKey.ORG$DELETE_ORGANIZATION)}
|
||||
fullWidthButtons
|
||||
/>
|
||||
);
|
||||
}
|
||||
@@ -1,68 +0,0 @@
|
||||
import React from "react";
|
||||
import { useTranslation } from "react-i18next";
|
||||
import { OrgModal } from "#/components/shared/modals/org-modal";
|
||||
import { useInviteMembersBatch } from "#/hooks/mutation/use-invite-members-batch";
|
||||
import { BadgeInput } from "#/components/shared/inputs/badge-input";
|
||||
import { I18nKey } from "#/i18n/declaration";
|
||||
import { displayErrorToast } from "#/utils/custom-toast-handlers";
|
||||
import { areAllEmailsValid, hasDuplicates } from "#/utils/input-validation";
|
||||
|
||||
interface InviteOrganizationMemberModalProps {
|
||||
onClose: (event?: React.MouseEvent<HTMLButtonElement>) => void;
|
||||
}
|
||||
|
||||
export function InviteOrganizationMemberModal({
|
||||
onClose,
|
||||
}: InviteOrganizationMemberModalProps) {
|
||||
const { t } = useTranslation();
|
||||
const { mutate: inviteMembers, isPending } = useInviteMembersBatch();
|
||||
const [emails, setEmails] = React.useState<string[]>([]);
|
||||
|
||||
const handleEmailsChange = (newEmails: string[]) => {
|
||||
const trimmedEmails = newEmails.map((email) => email.trim());
|
||||
setEmails(trimmedEmails);
|
||||
};
|
||||
|
||||
const handleSubmit = () => {
|
||||
if (emails.length === 0) {
|
||||
displayErrorToast(t(I18nKey.ORG$NO_EMAILS_ADDED_HINT));
|
||||
return;
|
||||
}
|
||||
|
||||
if (!areAllEmailsValid(emails)) {
|
||||
displayErrorToast(t(I18nKey.SETTINGS$INVALID_EMAIL_FORMAT));
|
||||
return;
|
||||
}
|
||||
|
||||
if (hasDuplicates(emails)) {
|
||||
displayErrorToast(t(I18nKey.ORG$DUPLICATE_EMAILS_ERROR));
|
||||
return;
|
||||
}
|
||||
|
||||
inviteMembers(
|
||||
{ emails },
|
||||
{
|
||||
onSuccess: () => onClose(),
|
||||
},
|
||||
);
|
||||
};
|
||||
|
||||
return (
|
||||
<OrgModal
|
||||
testId="invite-modal"
|
||||
title={t(I18nKey.ORG$INVITE_ORG_MEMBERS)}
|
||||
description={t(I18nKey.ORG$INVITE_USERS_DESCRIPTION)}
|
||||
primaryButtonText={t(I18nKey.BUTTON$ADD)}
|
||||
onPrimaryClick={handleSubmit}
|
||||
onClose={onClose}
|
||||
isLoading={isPending}
|
||||
>
|
||||
<BadgeInput
|
||||
name="emails-badge-input"
|
||||
value={emails}
|
||||
placeholder={t(I18nKey.COMMON$TYPE_EMAIL_AND_PRESS_SPACE)}
|
||||
onChange={handleEmailsChange}
|
||||
/>
|
||||
</OrgModal>
|
||||
);
|
||||
}
|
||||
@@ -1,61 +0,0 @@
|
||||
import React from "react";
|
||||
import { useTranslation } from "react-i18next";
|
||||
import { useSelectedOrganizationId } from "#/context/use-selected-organization";
|
||||
import { useSwitchOrganization } from "#/hooks/mutation/use-switch-organization";
|
||||
import { useOrganizations } from "#/hooks/query/use-organizations";
|
||||
import { useShouldHideOrgSelector } from "#/hooks/use-should-hide-org-selector";
|
||||
import { I18nKey } from "#/i18n/declaration";
|
||||
import { Organization } from "#/types/org";
|
||||
import { Dropdown } from "#/ui/dropdown/dropdown";
|
||||
|
||||
export function OrgSelector() {
|
||||
const { t } = useTranslation();
|
||||
const { organizationId } = useSelectedOrganizationId();
|
||||
const { data, isLoading } = useOrganizations();
|
||||
const organizations = data?.organizations;
|
||||
const { mutate: switchOrganization, isPending: isSwitching } =
|
||||
useSwitchOrganization();
|
||||
const shouldHideSelector = useShouldHideOrgSelector();
|
||||
|
||||
const getOrgDisplayName = React.useCallback(
|
||||
(org: Organization) =>
|
||||
org.is_personal ? t(I18nKey.ORG$PERSONAL_WORKSPACE) : org.name,
|
||||
[t],
|
||||
);
|
||||
|
||||
const selectedOrg = React.useMemo(() => {
|
||||
if (organizationId) {
|
||||
return organizations?.find((org) => org.id === organizationId);
|
||||
}
|
||||
|
||||
return organizations?.[0];
|
||||
}, [organizationId, organizations]);
|
||||
|
||||
if (shouldHideSelector) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<Dropdown
|
||||
testId="org-selector"
|
||||
key={`${selectedOrg?.id}-${selectedOrg?.name}`}
|
||||
defaultValue={{
|
||||
label: selectedOrg ? getOrgDisplayName(selectedOrg) : "",
|
||||
value: selectedOrg?.id || "",
|
||||
}}
|
||||
onChange={(item) => {
|
||||
if (item && item.value !== organizationId) {
|
||||
switchOrganization(item.value);
|
||||
}
|
||||
}}
|
||||
placeholder={t(I18nKey.ORG$SELECT_ORGANIZATION_PLACEHOLDER)}
|
||||
loading={isLoading || isSwitching}
|
||||
options={
|
||||
organizations?.map((org) => ({
|
||||
value: org.id,
|
||||
label: getOrgDisplayName(org),
|
||||
})) || []
|
||||
}
|
||||
/>
|
||||
);
|
||||
}
|
||||
@@ -1,85 +0,0 @@
|
||||
import React from "react";
|
||||
import { useTranslation } from "react-i18next";
|
||||
import { ChevronDown } from "lucide-react";
|
||||
import { OrganizationMember, OrganizationUserRole } from "#/types/org";
|
||||
import { cn } from "#/utils/utils";
|
||||
import { I18nKey } from "#/i18n/declaration";
|
||||
import { OrganizationMemberRoleContextMenu } from "./organization-member-role-context-menu";
|
||||
|
||||
interface OrganizationMemberListItemProps {
|
||||
email: OrganizationMember["email"];
|
||||
role: OrganizationMember["role"];
|
||||
status: OrganizationMember["status"];
|
||||
hasPermissionToChangeRole: boolean;
|
||||
availableRolesToChangeTo: OrganizationUserRole[];
|
||||
|
||||
onRoleChange: (role: OrganizationUserRole) => void;
|
||||
onRemove?: () => void;
|
||||
}
|
||||
|
||||
export function OrganizationMemberListItem({
|
||||
email,
|
||||
role,
|
||||
status,
|
||||
hasPermissionToChangeRole,
|
||||
availableRolesToChangeTo,
|
||||
onRoleChange,
|
||||
onRemove,
|
||||
}: OrganizationMemberListItemProps) {
|
||||
const { t } = useTranslation();
|
||||
const [contextMenuOpen, setContextMenuOpen] = React.useState(false);
|
||||
|
||||
const roleSelectionIsPermitted =
|
||||
status !== "invited" && hasPermissionToChangeRole;
|
||||
|
||||
const handleRoleClick = (event: React.MouseEvent<HTMLSpanElement>) => {
|
||||
if (roleSelectionIsPermitted) {
|
||||
event.preventDefault();
|
||||
event.stopPropagation();
|
||||
setContextMenuOpen(true);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="flex items-center justify-between py-4">
|
||||
<div className="flex items-center gap-2">
|
||||
<span
|
||||
className={cn(
|
||||
"text-sm font-semibold leading-6",
|
||||
status === "invited" && "text-gray-400",
|
||||
)}
|
||||
>
|
||||
{email}
|
||||
</span>
|
||||
|
||||
{status === "invited" && (
|
||||
<span className="text-xs text-tertiary-light border border-tertiary px-2 py-1 rounded-lg">
|
||||
{t(I18nKey.ORG$STATUS_INVITED)}
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div className="relative">
|
||||
<span
|
||||
onClick={handleRoleClick}
|
||||
className={cn(
|
||||
"text-xs font-normal leading-4 text-org-text flex items-center gap-1 capitalize",
|
||||
roleSelectionIsPermitted ? "cursor-pointer" : "cursor-not-allowed",
|
||||
)}
|
||||
>
|
||||
{role}
|
||||
{hasPermissionToChangeRole && <ChevronDown size={14} />}
|
||||
</span>
|
||||
|
||||
{roleSelectionIsPermitted && contextMenuOpen && (
|
||||
<OrganizationMemberRoleContextMenu
|
||||
onClose={() => setContextMenuOpen(false)}
|
||||
onRoleChange={onRoleChange}
|
||||
onRemove={onRemove}
|
||||
availableRolesToChangeTo={availableRolesToChangeTo}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,123 +0,0 @@
|
||||
import React from "react";
|
||||
import { useTranslation } from "react-i18next";
|
||||
import { I18nKey } from "#/i18n/declaration";
|
||||
import { ContextMenu } from "#/ui/context-menu";
|
||||
import { ContextMenuListItem } from "../context-menu/context-menu-list-item";
|
||||
import { ContextMenuIconText } from "#/ui/context-menu-icon-text";
|
||||
import { useClickOutsideElement } from "#/hooks/use-click-outside-element";
|
||||
import { OrganizationUserRole } from "#/types/org";
|
||||
import { cn } from "#/utils/utils";
|
||||
import UserIcon from "#/icons/user.svg?react";
|
||||
import DeleteIcon from "#/icons/u-delete.svg?react";
|
||||
import AdminIcon from "#/icons/admin.svg?react";
|
||||
|
||||
const contextMenuListItemClassName = cn(
|
||||
"cursor-pointer p-0 h-auto hover:bg-transparent",
|
||||
);
|
||||
|
||||
interface OrganizationMemberRoleContextMenuProps {
|
||||
onClose: () => void;
|
||||
onRoleChange: (role: OrganizationUserRole) => void;
|
||||
onRemove?: () => void;
|
||||
availableRolesToChangeTo: OrganizationUserRole[];
|
||||
}
|
||||
|
||||
export function OrganizationMemberRoleContextMenu({
|
||||
onClose,
|
||||
onRoleChange,
|
||||
onRemove,
|
||||
availableRolesToChangeTo,
|
||||
}: OrganizationMemberRoleContextMenuProps) {
|
||||
const { t } = useTranslation();
|
||||
const menuRef = useClickOutsideElement<HTMLUListElement>(onClose);
|
||||
|
||||
const handleRoleChangeClick = (
|
||||
event: React.MouseEvent<HTMLButtonElement>,
|
||||
role: OrganizationUserRole,
|
||||
) => {
|
||||
event.preventDefault();
|
||||
event.stopPropagation();
|
||||
onRoleChange(role);
|
||||
onClose();
|
||||
};
|
||||
|
||||
const handleRemoveClick = (event: React.MouseEvent<HTMLButtonElement>) => {
|
||||
event.preventDefault();
|
||||
event.stopPropagation();
|
||||
onRemove?.();
|
||||
onClose();
|
||||
};
|
||||
|
||||
return (
|
||||
<ContextMenu
|
||||
ref={menuRef}
|
||||
testId="organization-member-role-context-menu"
|
||||
position="bottom"
|
||||
alignment="right"
|
||||
className="min-h-fit mb-2 min-w-[195px] max-w-[195px] gap-0"
|
||||
>
|
||||
{availableRolesToChangeTo.includes("owner") && (
|
||||
<ContextMenuListItem
|
||||
testId="owner-option"
|
||||
onClick={(event) => handleRoleChangeClick(event, "owner")}
|
||||
className={contextMenuListItemClassName}
|
||||
>
|
||||
<ContextMenuIconText
|
||||
icon={
|
||||
<AdminIcon
|
||||
width={16}
|
||||
height={16}
|
||||
className="text-white pl-[2px]"
|
||||
/>
|
||||
}
|
||||
text={t(I18nKey.ORG$ROLE_OWNER)}
|
||||
className="capitalize"
|
||||
/>
|
||||
</ContextMenuListItem>
|
||||
)}
|
||||
{availableRolesToChangeTo.includes("admin") && (
|
||||
<ContextMenuListItem
|
||||
testId="admin-option"
|
||||
onClick={(event) => handleRoleChangeClick(event, "admin")}
|
||||
className={contextMenuListItemClassName}
|
||||
>
|
||||
<ContextMenuIconText
|
||||
icon={
|
||||
<AdminIcon
|
||||
width={16}
|
||||
height={16}
|
||||
className="text-white pl-[2px]"
|
||||
/>
|
||||
}
|
||||
text={t(I18nKey.ORG$ROLE_ADMIN)}
|
||||
className="capitalize"
|
||||
/>
|
||||
</ContextMenuListItem>
|
||||
)}
|
||||
{availableRolesToChangeTo.includes("member") && (
|
||||
<ContextMenuListItem
|
||||
testId="member-option"
|
||||
onClick={(event) => handleRoleChangeClick(event, "member")}
|
||||
className={contextMenuListItemClassName}
|
||||
>
|
||||
<ContextMenuIconText
|
||||
icon={<UserIcon width={16} height={16} className="text-white" />}
|
||||
text={t(I18nKey.ORG$ROLE_MEMBER)}
|
||||
className="capitalize"
|
||||
/>
|
||||
</ContextMenuListItem>
|
||||
)}
|
||||
<ContextMenuListItem
|
||||
testId="remove-option"
|
||||
onClick={handleRemoveClick}
|
||||
className={contextMenuListItemClassName}
|
||||
>
|
||||
<ContextMenuIconText
|
||||
icon={<DeleteIcon width={16} height={16} className="text-red-500" />}
|
||||
text={t(I18nKey.ORG$REMOVE)}
|
||||
className="text-red-500 capitalize"
|
||||
/>
|
||||
</ContextMenuListItem>
|
||||
</ContextMenu>
|
||||
);
|
||||
}
|
||||
@@ -11,7 +11,7 @@ import { amountIsValid } from "#/utils/amount-is-valid";
|
||||
import { I18nKey } from "#/i18n/declaration";
|
||||
import { PoweredByStripeTag } from "./powered-by-stripe-tag";
|
||||
|
||||
export function PaymentForm({ isDisabled }: { isDisabled?: boolean }) {
|
||||
export function PaymentForm() {
|
||||
const { t } = useTranslation();
|
||||
const { data: balance, isLoading } = useBalance();
|
||||
const { mutate: addBalance, isPending } = useCreateStripeCheckoutSession();
|
||||
@@ -69,14 +69,13 @@ export function PaymentForm({ isDisabled }: { isDisabled?: boolean }) {
|
||||
min={10}
|
||||
max={25000}
|
||||
step={1}
|
||||
isDisabled={isDisabled}
|
||||
/>
|
||||
|
||||
<div className="flex items-center w-[680px] gap-2">
|
||||
<BrandButton
|
||||
variant="primary"
|
||||
type="submit"
|
||||
isDisabled={isPending || buttonIsDisabled || isDisabled}
|
||||
isDisabled={isPending || buttonIsDisabled}
|
||||
>
|
||||
{t(I18nKey.PAYMENT$ADD_CREDIT)}
|
||||
</BrandButton>
|
||||
|
||||
@@ -31,7 +31,7 @@ export function SetupPaymentModal() {
|
||||
variant="primary"
|
||||
className="w-full"
|
||||
isDisabled={isPending}
|
||||
onClick={() => mutate()}
|
||||
onClick={mutate}
|
||||
>
|
||||
{t(I18nKey.BILLING$PROCEED_TO_STRIPE)}
|
||||
</BrandButton>
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user