Compare commits

..

1 Commits

Author SHA1 Message Date
openhands 1f7335fc15 feat: add notifications scope to GitHub OAuth defaultScope
Add the 'notifications' scope to the GitHub identity provider's
defaultScope in the Keycloak realm configuration. This enables
agents to read and manage GitHub notifications via the API
(list notifications, mark as read/done).

Co-authored-by: openhands <openhands@all-hands.dev>
2026-04-10 23:34:45 +00:00
130 changed files with 1500 additions and 6144 deletions
@@ -12,7 +12,7 @@ jobs:
steps:
- name: Checkout repository
uses: actions/checkout@v6
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v6
+2 -2
View File
@@ -192,7 +192,7 @@ jobs:
- name: Upload test results
if: always()
uses: actions/upload-artifact@v7
uses: actions/upload-artifact@v6
with:
name: playwright-report
path: tests/e2e/test-results/
@@ -200,7 +200,7 @@ jobs:
- name: Upload OpenHands logs
if: always()
uses: actions/upload-artifact@v7
uses: actions/upload-artifact@v6
with:
name: openhands-logs
path: |
+1 -1
View File
@@ -41,7 +41,7 @@ jobs:
working-directory: ./frontend
run: npx playwright test --project=chromium
- name: Upload Playwright report
uses: actions/upload-artifact@v7
uses: actions/upload-artifact@v6
if: always()
with:
name: playwright-report
+39 -197
View File
@@ -33,48 +33,47 @@ jobs:
runs-on: ubuntu-latest
outputs:
base_image: ${{ steps.define-base-images.outputs.base_image }}
architectures: ${{ steps.define-base-images.outputs.architectures }}
platforms: ${{ steps.define-base-images.outputs.platforms }}
steps:
- name: Define base images
shell: bash
id: define-base-images
run: |
architectures='["amd64", "arm64"]'
if [[ "$GITHUB_EVENT_NAME" == "pull_request" ]]; then
json=$(jq -n -c '[
{ image: "nikolaik/python-nodejs:python3.12-nodejs22-slim", tag: "nikolaik" }
platforms="linux/amd64"
json=$(jq -n -c --arg platforms "$platforms" '[
{ image: "nikolaik/python-nodejs:python3.12-nodejs22-slim", tag: "nikolaik", platforms: $platforms }
]')
else
json=$(jq -n -c '[
{ image: "nikolaik/python-nodejs:python3.12-nodejs22-slim", tag: "nikolaik" },
{ image: "ubuntu:24.04", tag: "ubuntu" }
platforms="linux/amd64,linux/arm64"
json=$(jq -n -c --arg platforms "$platforms" '[
{ image: "nikolaik/python-nodejs:python3.12-nodejs22-slim", tag: "nikolaik", platforms: $platforms },
{ image: "ubuntu:24.04", tag: "ubuntu", platforms: $platforms }
]')
fi
echo "base_image=$json" >> "$GITHUB_OUTPUT"
echo "architectures=$architectures" >> "$GITHUB_OUTPUT"
echo "platforms=$platforms" >> "$GITHUB_OUTPUT"
# Builds the OpenHands Docker images (one per architecture, natively)
# Builds the OpenHands Docker images
ghcr_build_app:
name: Build App Image (${{ matrix.arch }})
runs-on: ${{ matrix.arch == 'arm64' && 'ubuntu-24.04-arm' || 'ubuntu-22.04' }}
name: Build App Image
runs-on: ubuntu-22.04
if: "!(github.event_name == 'push' && startsWith(github.ref, 'refs/tags/ext-v'))"
needs: define-matrix
outputs:
# All arch variants produce the same base tags, so any entry works
base_tags: ${{ steps.build.outputs.base_tags }}
permissions:
contents: read
packages: write
strategy:
matrix:
arch: ${{ fromJson(needs.define-matrix.outputs.architectures) }}
steps:
- name: Checkout
uses: actions/checkout@v6
with:
ref: ${{ github.event.pull_request.head.sha }}
- name: Set up QEMU
uses: docker/setup-qemu-action@v3.7.0
with:
image: tonistiigi/binfmt:latest
- name: Login to GHCR
uses: docker/login-action@v4
uses: docker/login-action@v3
with:
registry: ghcr.io
username: ${{ github.repository_owner }}
@@ -86,79 +85,33 @@ jobs:
run: |
echo REPO_OWNER=$(echo ${{ github.repository_owner }} | tr '[:upper:]' '[:lower:]') >> $GITHUB_ENV
- name: Build and push app image
id: build
if: "!github.event.pull_request.head.repo.fork"
run: |
./containers/build.sh -i openhands -o ${{ env.REPO_OWNER }} --push --arch ${{ matrix.arch }}
./containers/build.sh -i openhands -o ${{ env.REPO_OWNER }} --push -p ${{ needs.define-matrix.outputs.platforms }}
# Output base tags (without arch suffix) for the merge job
./containers/build.sh -i openhands -o ${{ env.REPO_OWNER }} --arch ${{ matrix.arch }} --dry
BASE_TAGS=$(jq -r '.base_tags | join("\n")' docker-build-dry.json)
echo "base_tags<<EOF" >> "$GITHUB_OUTPUT"
echo "$BASE_TAGS" >> "$GITHUB_OUTPUT"
echo "EOF" >> "$GITHUB_OUTPUT"
# Merges per-architecture app images into a multi-arch manifest
ghcr_build_app_merge:
name: Merge App Multi-Arch Manifest
runs-on: ubuntu-22.04
needs: [define-matrix, ghcr_build_app]
if: github.event.pull_request.head.repo.fork != true
permissions:
packages: write
steps:
- name: Login to GHCR
uses: docker/login-action@v3
with:
registry: ghcr.io
username: ${{ github.repository_owner }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Merge multi-arch manifest
run: |
ARCHS='${{ join(fromJson(needs.define-matrix.outputs.architectures), ' ') }}'
TAGS="${{ needs.ghcr_build_app.outputs.base_tags }}"
while IFS= read -r tag; do
[[ -z "$tag" ]] && continue
sources=""
for arch in $ARCHS; do
if ! docker buildx imagetools inspect "${tag}-${arch}" > /dev/null 2>&1; then
echo "::error::Missing image ${tag}-${arch}"
exit 1
fi
sources+=" ${tag}-${arch}"
done
echo "Creating manifest for $tag from:$sources"
docker buildx imagetools create -t "$tag" $sources
done <<< "$TAGS"
# Builds the runtime Docker images (one per architecture x base_image, natively)
# Builds the runtime Docker images
ghcr_build_runtime:
name: Build Runtime Image (${{ matrix.base_image.tag }}, ${{ matrix.arch }})
runs-on: ${{ matrix.arch == 'arm64' && 'ubuntu-24.04-arm' || 'ubuntu-22.04' }}
name: Build Runtime Image
runs-on: ubuntu-22.04
if: "!(github.event_name == 'push' && startsWith(github.ref, 'refs/tags/ext-v'))"
permissions:
contents: read
packages: write
needs: define-matrix
outputs:
# Keyed by base_image tag so the merge job can access per-variant tags.
# Matrix outputs from different entries with the same key overwrite each other,
# but all arch variants of the same base_image produce identical base tags.
base_tags_nikolaik: ${{ steps.params.outputs.base_tags_nikolaik }}
base_tags_ubuntu: ${{ steps.params.outputs.base_tags_ubuntu }}
strategy:
matrix:
base_image: ${{ fromJson(needs.define-matrix.outputs.base_image) }}
arch: ${{ fromJson(needs.define-matrix.outputs.architectures) }}
steps:
- name: Checkout
uses: actions/checkout@v6
with:
ref: ${{ github.event.pull_request.head.sha }}
- name: Set up QEMU
uses: docker/setup-qemu-action@v3.7.0
with:
image: tonistiigi/binfmt:latest
- name: Login to GHCR
uses: docker/login-action@v4
uses: docker/login-action@v3
with:
registry: ghcr.io
username: ${{ github.repository_owner }}
@@ -184,22 +137,16 @@ jobs:
run: |
echo SHORT_SHA=$(git rev-parse --short "$RELEVANT_SHA") >> $GITHUB_ENV
- name: Determine docker build params
id: params
if: github.event.pull_request.head.repo.fork != true
shell: bash
run: |
./containers/build.sh -i runtime -o ${{ env.REPO_OWNER }} -t ${{ matrix.base_image.tag }} --arch ${{ matrix.arch }} --dry
./containers/build.sh -i runtime -o ${{ env.REPO_OWNER }} -t ${{ matrix.base_image.tag }} --dry -p ${{ matrix.base_image.platforms }}
DOCKER_BUILD_JSON=$(jq -c . < docker-build-dry.json)
echo "DOCKER_TAGS=$(echo "$DOCKER_BUILD_JSON" | jq -r '.tags | join(",")')" >> $GITHUB_ENV
echo "DOCKER_PLATFORM=$(echo "$DOCKER_BUILD_JSON" | jq -r '.platform')" >> $GITHUB_ENV
echo "DOCKER_BUILD_ARGS=$(echo "$DOCKER_BUILD_JSON" | jq -r '.build_args | join(",")')" >> $GITHUB_ENV
# Output base tags (without arch suffix) keyed by base_image tag for the merge job
BASE_TAGS=$(echo "$DOCKER_BUILD_JSON" | jq -r '.base_tags | join("\n")')
echo "base_tags_${{ matrix.base_image.tag }}<<EOF" >> "$GITHUB_OUTPUT"
echo "$BASE_TAGS" >> "$GITHUB_OUTPUT"
echo "EOF" >> "$GITHUB_OUTPUT"
- name: Build and push runtime image ${{ matrix.base_image.image }}
if: github.event.pull_request.head.repo.fork != true
uses: docker/build-push-action@v6
@@ -207,8 +154,9 @@ jobs:
push: true
tags: ${{ env.DOCKER_TAGS }}
platforms: ${{ env.DOCKER_PLATFORM }}
cache-from: type=registry,ref=ghcr.io/${{ env.REPO_OWNER }}/runtime:buildcache-${{ matrix.base_image.tag }}-${{ matrix.arch }}
cache-to: type=registry,ref=ghcr.io/${{ env.REPO_OWNER }}/runtime:buildcache-${{ matrix.base_image.tag }}-${{ matrix.arch }},mode=max
# Caching directives to boost performance
cache-from: type=registry,ref=ghcr.io/${{ env.REPO_OWNER }}/runtime:buildcache-${{ matrix.base_image.tag }}
cache-to: type=registry,ref=ghcr.io/${{ env.REPO_OWNER }}/runtime:buildcache-${{ matrix.base_image.tag }},mode=max
build-args: ${{ env.DOCKER_BUILD_ARGS }}
context: containers/runtime
provenance: false
@@ -221,66 +169,20 @@ jobs:
context: containers/runtime
- name: Upload runtime source for fork
if: github.event.pull_request.head.repo.fork
uses: actions/upload-artifact@v7
uses: actions/upload-artifact@v6
with:
name: runtime-src-${{ matrix.base_image.tag }}-${{ matrix.arch }}
name: runtime-src-${{ matrix.base_image.tag }}
path: containers/runtime
# Merges per-architecture runtime images into multi-arch manifests
ghcr_build_runtime_merge:
name: Merge Runtime Multi-Arch Manifest
runs-on: ubuntu-22.04
needs: [define-matrix, ghcr_build_runtime]
if: github.event.pull_request.head.repo.fork != true
permissions:
packages: write
steps:
- name: Login to GHCR
uses: docker/login-action@v3
with:
registry: ghcr.io
username: ${{ github.repository_owner }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Merge multi-arch manifests
run: |
ARCHS='${{ join(fromJson(needs.define-matrix.outputs.architectures), ' ') }}'
# Merge all runtime base_image variants
for variant_tags in \
"${{ needs.ghcr_build_runtime.outputs.base_tags_nikolaik }}" \
"${{ needs.ghcr_build_runtime.outputs.base_tags_ubuntu }}"; do
while IFS= read -r tag; do
[[ -z "$tag" ]] && continue
sources=""
for arch in $ARCHS; do
if ! docker buildx imagetools inspect "${tag}-${arch}" > /dev/null 2>&1; then
echo "::error::Missing image ${tag}-${arch}"
exit 1
fi
sources+=" ${tag}-${arch}"
done
echo "Creating manifest for $tag from:$sources"
docker buildx imagetools create -t "$tag" $sources
done <<< "$variant_tags"
done
ghcr_build_enterprise:
name: Push Enterprise Image (${{ matrix.arch }})
runs-on: ${{ matrix.arch == 'arm64' && 'ubuntu-24.04-arm' || 'ubuntu-22.04' }}
name: Push Enterprise Image
runs-on: ubuntu-22.04
permissions:
contents: read
packages: write
needs: [define-matrix, ghcr_build_app_merge]
needs: [define-matrix, ghcr_build_app]
# Do not build enterprise in forks
if: github.event.pull_request.head.repo.fork != true
outputs:
# Tags without arch suffix, for the merge job
base_tags: ${{ steps.meta_base.outputs.tags }}
strategy:
matrix:
arch: ${{ fromJson(needs.define-matrix.outputs.architectures) }}
steps:
- name: Checkout
uses: actions/checkout@v6
@@ -294,7 +196,7 @@ jobs:
driver-opts: network=host
- name: Login to GHCR
uses: docker/login-action@v4
uses: docker/login-action@v3
with:
registry: ghcr.io
username: ${{ github.repository_owner }}
@@ -302,28 +204,6 @@ jobs:
- name: Extract metadata (tags, labels) for Docker
id: meta
uses: docker/metadata-action@v6
with:
images: ghcr.io/openhands/enterprise-server
tags: |
type=ref,event=branch
type=ref,event=pr
type=sha
type=sha,format=long
type=semver,pattern={{version}}
type=semver,pattern={{major}}.{{minor}}
type=semver,pattern={{major}}
type=match,pattern=cloud-\d+\.\d+\.\d+
flavor: |
latest=auto
prefix=
suffix=-${{ matrix.arch }}
env:
DOCKER_METADATA_PR_HEAD_SHA: true
# Also compute base tags (no arch suffix) for the merge job output
- name: Extract base metadata for merge
id: meta_base
uses: docker/metadata-action@v5
with:
images: ghcr.io/openhands/enterprise-server
@@ -342,7 +222,6 @@ jobs:
suffix=
env:
DOCKER_METADATA_PR_HEAD_SHA: true
- name: Determine app image tag
shell: bash
run: |
@@ -359,49 +238,12 @@ jobs:
labels: ${{ steps.meta.outputs.labels }}
build-args: |
OPENHANDS_VERSION=${{ env.OPENHANDS_DOCKER_TAG }}
platforms: linux/${{ matrix.arch }}
platforms: linux/amd64
# Add build provenance
provenance: true
# Add build attestations for better security
sbom: true
# Merges per-architecture enterprise images into a multi-arch manifest
ghcr_build_enterprise_merge:
name: Merge Enterprise Multi-Arch Manifest
runs-on: ubuntu-22.04
permissions:
packages: write
needs: [define-matrix, ghcr_build_enterprise]
if: github.event.pull_request.head.repo.fork != true
steps:
- name: Login to GHCR
uses: docker/login-action@v3
with:
registry: ghcr.io
username: ${{ github.repository_owner }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Merge multi-arch manifest
run: |
ARCHS='${{ join(fromJson(needs.define-matrix.outputs.architectures), ' ') }}'
TAGS="${{ needs.ghcr_build_enterprise.outputs.base_tags }}"
while IFS= read -r tag; do
[[ -z "$tag" ]] && continue
sources=""
for arch in $ARCHS; do
if ! docker buildx imagetools inspect "${tag}-${arch}" > /dev/null 2>&1; then
echo "::error::Missing image ${tag}-${arch}"
exit 1
fi
sources+=" ${tag}-${arch}"
done
echo "Creating manifest for $tag from:$sources"
docker buildx imagetools create -t "$tag" $sources
done <<< "$TAGS"
# "All Runtime Tests Passed" is a required job for PRs to merge
# We can remove this once the config changes
runtime_tests_check_success:
@@ -414,7 +256,7 @@ jobs:
update_pr_description:
name: Update PR Description
if: github.event_name == 'pull_request' && !github.event.pull_request.head.repo.fork && github.actor != 'dependabot[bot]'
needs: [ghcr_build_runtime_merge]
needs: [ghcr_build_runtime]
runs-on: ubuntu-22.04
steps:
- name: Checkout
+1 -1
View File
@@ -269,7 +269,7 @@ jobs:
fi
- name: Upload output.jsonl as artifact
uses: actions/upload-artifact@v7
uses: actions/upload-artifact@v6
if: always() # Upload even if the previous steps fail
with:
name: resolver-output
+2 -2
View File
@@ -31,7 +31,7 @@ jobs:
echo "is_fork=false" >> $GITHUB_OUTPUT
fi
- uses: actions/checkout@v6
- uses: actions/checkout@v5
if: steps.check-fork.outputs.is_fork == 'false'
with:
ref: ${{ github.event.pull_request.head.ref }}
@@ -93,7 +93,7 @@ jobs:
contents: read
pull-requests: write
steps:
- uses: actions/checkout@v6
- uses: actions/checkout@v5
- name: Check for .pr/ directory
id: check
+2 -2
View File
@@ -51,7 +51,7 @@ jobs:
# Always checkout main branch for security - cannot test script changes in PRs
- name: Checkout extensions repository
if: steps.check-trace.outputs.trace_exists == 'true'
uses: actions/checkout@v6
uses: actions/checkout@v5
with:
repository: OpenHands/extensions
path: extensions
@@ -77,7 +77,7 @@ jobs:
--trace-file trace-info/laminar_trace_info.json
- name: Upload evaluation logs
uses: actions/upload-artifact@v7
uses: actions/upload-artifact@v5
if: always() && steps.check-trace.outputs.trace_exists == 'true'
with:
name: pr-review-evaluation-${{ github.event.pull_request.number }}
+2 -2
View File
@@ -65,7 +65,7 @@ jobs:
env:
COVERAGE_FILE: ".coverage.runtime.${{ matrix.python_version }}"
- name: Store coverage file
uses: actions/upload-artifact@v7
uses: actions/upload-artifact@v6
with:
name: coverage-openhands
path: |
@@ -97,7 +97,7 @@ jobs:
env:
COVERAGE_FILE: ".coverage.enterprise.${{ matrix.python_version }}"
- name: Store coverage file
uses: actions/upload-artifact@v7
uses: actions/upload-artifact@v6
with:
name: coverage-enterprise
path: ".coverage.enterprise.${{ matrix.python_version }}"
+9 -30
View File
@@ -8,18 +8,18 @@ push=0
load=0
tag_suffix=""
dry_run=0
arch_suffix=""
platform_override=""
# Function to display usage information
usage() {
echo "Usage: $0 -i <image_name> [-o <org_name>] [--push] [--load] [-t <tag_suffix>] [--dry] [--arch <arch>]"
echo "Usage: $0 -i <image_name> [-o <org_name>] [--push] [--load] [-t <tag_suffix>] [-p <platform>] [--dry]"
echo " -i: Image name (required)"
echo " -o: Organization name"
echo " --push: Push the image"
echo " --load: Load the image"
echo " -t: Tag suffix"
echo " -p: Platform(s) to build for (e.g. linux/amd64 or linux/amd64,linux/arm64)"
echo " --dry: Don't build, only create build-args.json"
echo " --arch: Architecture suffix (e.g. amd64 or arm64). Appends -<arch> to tags and forces single-platform build"
exit 1
}
@@ -31,8 +31,8 @@ while [[ $# -gt 0 ]]; do
--push) push=1; shift ;;
--load) load=1; shift ;;
-t) tag_suffix="$2"; shift 2 ;;
-p) platform_override="$2"; shift 2 ;;
--dry) dry_run=1; shift ;;
--arch) arch_suffix="$2"; shift 2 ;;
*) usage ;;
esac
done
@@ -78,7 +78,7 @@ if [[ -n $tag_suffix ]]; then
done
fi
echo "Tags (before arch suffix): ${tags[@]}"
echo "Tags: ${tags[@]}"
if [[ "$image_name" == "openhands" ]]; then
dir="./containers/app"
@@ -113,21 +113,10 @@ if [[ -n "$DOCKER_IMAGE_TAG" ]]; then
tags+=("$DOCKER_IMAGE_TAG")
fi
# Apply architecture suffix for split-arch builds (after all tags are collected)
if [[ -n "$arch_suffix" ]]; then
cache_tag+="-${arch_suffix}"
for i in "${!tags[@]}"; do
tags[$i]="${tags[$i]}-${arch_suffix}"
done
# Force single-platform build for this architecture
arch_platform="linux/${arch_suffix}"
fi
DOCKER_REPOSITORY="$DOCKER_REGISTRY/$DOCKER_ORG/$DOCKER_IMAGE"
DOCKER_REPOSITORY=${DOCKER_REPOSITORY,,} # lowercase
echo "Repo: $DOCKER_REPOSITORY"
echo "Base dir: $DOCKER_BASE_DIR"
echo "Tags: ${tags[@]}"
args=""
full_tags=()
@@ -136,6 +125,7 @@ for tag in "${tags[@]}"; do
full_tags+=("$DOCKER_REPOSITORY:$tag")
done
if [[ $push -eq 1 ]]; then
args+=" --push"
args+=" --cache-to=type=registry,ref=$DOCKER_REPOSITORY:$cache_tag,mode=max"
@@ -148,8 +138,8 @@ fi
echo "Args: $args"
# Determine the platform(s) to build for
if [[ -n "$arch_platform" ]]; then
platform="$arch_platform"
if [[ -n "$platform_override" ]]; then
platform="$platform_override"
elif [[ $load -eq 1 ]]; then
# When loading, build only for the current platform
platform=$(docker version -f '{{.Server.Os}}/{{.Server.Arch}}')
@@ -159,24 +149,13 @@ else
fi
if [[ $dry_run -eq 1 ]]; then
echo "Dry Run is enabled. Writing build config to docker-build-dry.json"
# Compute base tags (arch suffix stripped) for use by merge jobs
base_tags=()
for ftag in "${full_tags[@]}"; do
if [[ -n "$arch_suffix" ]]; then
base_tags+=("${ftag%-${arch_suffix}}")
else
base_tags+=("$ftag")
fi
done
jq -n \
--argjson tags "$(printf '%s\n' "${full_tags[@]}" | jq -R . | jq -s .)" \
--argjson base_tags "$(printf '%s\n' "${base_tags[@]}" | jq -R . | jq -s .)" \
--arg platform "$platform" \
--arg openhands_build_version "$OPENHANDS_BUILD_VERSION" \
--arg dockerfile "$dir/Dockerfile" \
'{
tags: $tags,
base_tags: $base_tags,
platform: $platform,
build_args: [
"OPENHANDS_BUILD_VERSION=" + $openhands_build_version
@@ -195,7 +174,7 @@ docker buildx build \
$args \
--build-arg OPENHANDS_BUILD_VERSION="$OPENHANDS_BUILD_VERSION" \
--cache-from=type=registry,ref=$DOCKER_REPOSITORY:$cache_tag \
--cache-from=type=registry,ref=$DOCKER_REPOSITORY:${cache_tag_base}-main${arch_suffix:+-${arch_suffix}} \
--cache-from=type=registry,ref=$DOCKER_REPOSITORY:$cache_tag_base-main \
--platform $platform \
--provenance=false \
-f "$dir/Dockerfile" \
-2
View File
@@ -1,7 +1,5 @@
# PolyForm Free Trial License 1.0.0
Copyright (c) 2026 All Hands AI
## Acceptance
In order to get any license under these terms, you must agree
+1 -1
View File
@@ -59,7 +59,7 @@ handlers = console
qualname =
[logger_sqlalchemy]
level = WARNING
level = DEBUG
handlers =
qualname = sqlalchemy.engine
@@ -1729,7 +1729,7 @@
"syncMode": "IMPORT",
"clientSecret": "$GITHUB_APP_CLIENT_SECRET",
"caseSensitiveOriginalUsername": "false",
"defaultScope": "openid email profile",
"defaultScope": "openid email profile notifications",
"baseUrl": "$GITHUB_BASE_URL"
}
},
+21 -4
View File
@@ -24,20 +24,20 @@ from integrations.jira.jira_types import (
RepositoryNotFoundError,
StartingConvoException,
)
from integrations.jira.jira_view import JiraFactory
from integrations.jira.jira_view import JiraFactory, JiraNewConversationView
from integrations.manager import Manager
from integrations.models import Message
from integrations.utils import (
HOST,
HOST_URL,
OPENHANDS_RESOLVER_TEMPLATES_DIR,
format_jira_comment_body,
get_oh_labels,
get_session_expired_message,
)
from jinja2 import Environment, FileSystemLoader
from server.auth.saas_user_auth import get_user_auth_from_keycloak_id
from server.auth.token_manager import TokenManager
from server.utils.conversation_callback_utils import register_callback_processor
from storage.jira_integration_store import JiraIntegrationStore
from storage.jira_user import JiraUser
from storage.jira_workspace import JiraWorkspace
@@ -259,6 +259,11 @@ class JiraManager(Manager[JiraViewInterface]):
async def start_job(self, view: JiraViewInterface) -> None:
"""Start a Jira job/conversation."""
# Import here to prevent circular import
from server.conversation_callback_processor.jira_callback_processor import (
JiraCallbackProcessor,
)
try:
logger.info(
'[Jira] Starting job',
@@ -280,7 +285,19 @@ class JiraManager(Manager[JiraViewInterface]):
},
)
# Create success message
# Register callback processor for updates
if isinstance(view, JiraNewConversationView):
processor = JiraCallbackProcessor(
issue_key=view.payload.issue_key,
workspace_name=view.jira_workspace.name,
)
register_callback_processor(conversation_id, processor)
logger.info(
'[Jira] Callback processor registered',
extra={'conversation_id': conversation_id},
)
# Send success response
msg_info = view.get_response_msg()
except MissingSettingsError as e:
@@ -342,7 +359,7 @@ class JiraManager(Manager[JiraViewInterface]):
url = (
f'{JIRA_CLOUD_API_URL}/{jira_cloud_id}/rest/api/2/issue/{issue_key}/comment'
)
data = format_jira_comment_body(message)
data = {'body': message}
async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
response = await client.post(
url, auth=(svc_acc_email, svc_acc_api_key), json=data
+5 -4
View File
@@ -136,10 +136,11 @@ class JiraPayloadParser:
items = changelog.get('items', [])
# Extract labels that were added
labels = set()
for item in items:
if item.get('field') == 'labels' and item.get('toString'):
labels.update(item['toString'].split())
labels = [
item.get('toString', '')
for item in items
if item.get('field') == 'labels' and 'toString' in item
]
if self.oh_label not in labels:
return JiraPayloadSkipped(
@@ -1,238 +0,0 @@
import logging
from uuid import UUID
import httpx
from integrations.utils import format_jira_comment_body, get_summary_instruction
from pydantic import Field
from openhands.agent_server.models import AskAgentRequest, AskAgentResponse
from openhands.app_server.event_callback.event_callback_models import (
EventCallback,
EventCallbackProcessor,
)
from openhands.app_server.event_callback.event_callback_result_models import (
EventCallbackResult,
EventCallbackResultStatus,
)
from openhands.app_server.event_callback.util import (
ensure_conversation_found,
ensure_running_sandbox,
get_agent_server_url_from_sandbox,
)
from openhands.sdk import Event
from openhands.sdk.event import ConversationStateUpdateEvent
from openhands.utils.http_session import httpx_verify_option
_logger = logging.getLogger(__name__)
JIRA_CLOUD_API_URL = 'https://api.atlassian.com/ex/jira'
class JiraV1CallbackProcessor(EventCallbackProcessor):
"""Callback processor for Jira V1 integrations."""
should_request_summary: bool = Field(default=True)
svc_acc_email: str
decrypted_api_key: str
issue_key: str
jira_cloud_id: str
async def __call__(
self,
conversation_id: UUID,
callback: EventCallback,
event: Event,
) -> EventCallbackResult | None:
"""Process events for Jira V1 integration."""
# Only handle ConversationStateUpdateEvent for execution_status
if not isinstance(event, ConversationStateUpdateEvent):
return None
if event.key != 'execution_status':
return None
_logger.info('[Jira] Callback agent state was %s', event)
# Only request summary when execution has finished successfully
if event.value != 'finished':
return None
_logger.info('[Jira] Should request summary: %s', self.should_request_summary)
if not self.should_request_summary:
return None
self.should_request_summary = False
try:
_logger.info(f'[Jira] Requesting summary {conversation_id}')
summary = await self._request_summary(conversation_id)
_logger.info(
f'[Jira] Posting summary {conversation_id}',
extra={'summary': summary},
)
await self._post_summary_to_jira(summary)
return EventCallbackResult(
status=EventCallbackResultStatus.SUCCESS,
event_callback_id=callback.id,
event_id=event.id,
conversation_id=conversation_id,
detail=summary,
)
except Exception as e:
_logger.exception(f'[Jira] Failed to post summary: {e}', stack_info=True)
return EventCallbackResult(
status=EventCallbackResultStatus.ERROR,
event_callback_id=callback.id,
event_id=event.id,
conversation_id=conversation_id,
detail=str(e),
)
async def _request_summary(self, conversation_id: UUID) -> str:
"""Ask the agent to produce a summary of its work and return the agent response."""
# Import services within the method to avoid circular imports
from openhands.app_server.config import (
get_app_conversation_info_service,
get_httpx_client,
get_sandbox_service,
)
from openhands.app_server.services.injector import InjectorState
from openhands.app_server.user.specifiy_user_context import (
ADMIN,
USER_CONTEXT_ATTR,
)
# Create injector state for dependency injection
state = InjectorState()
setattr(state, USER_CONTEXT_ATTR, ADMIN)
async with (
get_app_conversation_info_service(state) as app_conversation_info_service,
get_sandbox_service(state) as sandbox_service,
get_httpx_client(state) as httpx_client,
):
# 1. Conversation lookup
app_conversation_info = ensure_conversation_found(
await app_conversation_info_service.get_app_conversation_info(
conversation_id
),
conversation_id,
)
# 2. Sandbox lookup + validation
sandbox = ensure_running_sandbox(
await sandbox_service.get_sandbox(app_conversation_info.sandbox_id),
app_conversation_info.sandbox_id,
)
assert (
sandbox.session_api_key is not None
), f'No session API key for sandbox: {sandbox.id}'
# 3. URL + instruction
agent_server_url = get_agent_server_url_from_sandbox(sandbox)
# Prepare message based on agent state
message_content = get_summary_instruction()
# Ask the agent and return the response text
return await self._ask_question(
httpx_client=httpx_client,
agent_server_url=agent_server_url,
conversation_id=conversation_id,
session_api_key=sandbox.session_api_key,
message_content=message_content,
)
async def _ask_question(
self,
httpx_client: httpx.AsyncClient,
agent_server_url: str,
conversation_id: UUID,
session_api_key: str,
message_content: str,
) -> str:
"""Send a message to the agent server via the V1 API and return response text."""
send_message_request = AskAgentRequest(question=message_content)
url = (
f"{agent_server_url.rstrip('/')}"
f"/api/conversations/{conversation_id}/ask_agent"
)
headers = {'X-Session-API-Key': session_api_key}
payload = send_message_request.model_dump()
try:
response = await httpx_client.post(
url,
json=payload,
headers=headers,
timeout=30.0,
)
response.raise_for_status()
agent_response = AskAgentResponse.model_validate(response.json())
return agent_response.response
except httpx.HTTPStatusError as e:
error_detail = f'HTTP {e.response.status_code} error'
try:
error_body = e.response.text
if error_body:
error_detail += f': {error_body}'
except Exception:
pass
_logger.exception(
'[Jira] HTTP error sending message to %s: %s. '
'Request payload: %s. Response headers: %s',
url,
error_detail,
payload,
dict(e.response.headers),
stack_info=True,
)
raise Exception(f'Failed to send message to agent server: {error_detail}')
except httpx.TimeoutException:
error_detail = f'Request timeout after 30 seconds to {url}'
_logger.exception(
'[Jira] Timeout error: %s. Request payload: %s',
error_detail,
payload,
stack_info=True,
)
raise Exception(f'Failed to send message to agent server: {error_detail}')
async def _post_summary_to_jira(self, summary: str):
"""Post the summary back to the Jira issue."""
if not all(
[
self.svc_acc_email,
self.decrypted_api_key,
self.issue_key,
self.jira_cloud_id,
]
):
_logger.warning('[Jira] Missing required data for posting summary')
return
# Add a comment to the Jira issue with the summary
comment_url = (
f'{JIRA_CLOUD_API_URL}/{self.jira_cloud_id}'
f'/rest/api/2/issue/{self.issue_key}/comment'
)
message = f'OpenHands resolved this issue:\n\n{summary}'
comment_body = format_jira_comment_body(message)
async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
response = await client.post(
comment_url,
auth=(self.svc_acc_email, self.decrypted_api_key),
json=comment_body,
)
response.raise_for_status()
_logger.info(f'[Jira] Posted summary to {self.issue_key}')
+103 -142
View File
@@ -7,7 +7,7 @@ Views are responsible for:
"""
from dataclasses import dataclass, field
from uuid import UUID, uuid4
from uuid import uuid4
import httpx
from integrations.jira.jira_payload import JiraWebhookPayload
@@ -16,37 +16,25 @@ from integrations.jira.jira_types import (
RepositoryNotFoundError,
StartingConvoException,
)
from integrations.jira.jira_v1_callback_processor import (
JiraV1CallbackProcessor,
)
from integrations.resolver_context import ResolverUserContext
from integrations.resolver_org_router import resolve_org_for_repo
from integrations.utils import (
CONVERSATION_URL,
infer_repo_from_message,
)
from integrations.utils import CONVERSATION_URL, infer_repo_from_message
from jinja2 import Environment
from server.config import get_config
from storage.jira_conversation import JiraConversation
from storage.jira_integration_store import JiraIntegrationStore
from storage.jira_user import JiraUser
from storage.jira_workspace import JiraWorkspace
from storage.saas_conversation_store import SaasConversationStore
from openhands.agent_server.models import SendMessageRequest
from openhands.app_server.app_conversation.app_conversation_models import (
AppConversationStartRequest,
AppConversationStartTaskStatus,
)
from openhands.app_server.config import get_app_conversation_service
from openhands.app_server.services.injector import InjectorState
from openhands.app_server.user.specifiy_user_context import USER_CONTEXT_ATTR
from openhands.core.logger import openhands_logger as logger
from openhands.integrations.provider import ProviderHandler, ProviderType
from openhands.sdk import TextContent
from openhands.integrations.provider import ProviderHandler
from openhands.server.services.conversation_service import start_conversation
from openhands.server.user_auth.user_auth import UserAuth
from openhands.storage.data_models.conversation_metadata import (
ConversationMetadata,
ConversationTrigger,
)
from openhands.utils.conversation_summary import get_default_conversation_title
from openhands.utils.http_session import httpx_verify_option
JIRA_CLOUD_API_URL = 'https://api.atlassian.com/ex/jira'
@@ -66,7 +54,7 @@ class JiraNewConversationView(JiraViewInterface):
saas_user_auth: UserAuth
jira_user: JiraUser
jira_workspace: JiraWorkspace
selected_repo: str = ''
selected_repo: str | None = None
conversation_id: str = ''
# Lazy-loaded issue details (cached after first fetch)
@@ -76,9 +64,6 @@ class JiraNewConversationView(JiraViewInterface):
# Decrypted API key (set by factory)
_decrypted_api_key: str = field(default='', repr=False)
# Resolved org ID for V1 conversations
resolved_org_id: UUID | None = None
async def get_issue_details(self) -> tuple[str, str]:
"""Fetch issue details from Jira API (cached after first call).
@@ -184,131 +169,107 @@ class JiraNewConversationView(JiraViewInterface):
if not self.selected_repo:
raise StartingConvoException('No repository selected for this conversation')
jira_conversation = JiraConversation(
conversation_id=self.conversation_id,
issue_id=self.payload.issue_id,
issue_key=self.payload.issue_key,
jira_user_id=self.jira_user.id,
)
await integration_store.create_conversation(jira_conversation)
conversation_metadata = await self._create_v1_metadata()
await self._create_v1_conversation(jinja_env, conversation_metadata)
return self.conversation_id
async def _create_v1_metadata(self) -> ConversationMetadata:
"""Create conversation metadata for V1 conversations.
The JiraConversation mapping is saved to the integration store (above), but
V1 conversation metadata is managed by the app conversation system, not
the legacy conversation store.
"""
logger.info('[Jira]: Creating V1 metadata')
# Generate a dummy conversation for V1 (not saved to store)
self.conversation_id = uuid4().hex
self.resolved_org_id = await self._get_resolved_org_id()
return ConversationMetadata(
conversation_id=self.conversation_id,
selected_repository=self.selected_repo,
)
async def _create_v1_conversation(
self,
jinja_env: Environment,
conversation_metadata: ConversationMetadata,
):
"""Create conversation using the new V1 app conversation system."""
logger.info('[Jira]: Creating V1 conversation')
initial_user_text = await self._get_v1_initial_user_message(jinja_env)
# Create the initial message request
initial_message = SendMessageRequest(
role='user', content=[TextContent(text=initial_user_text)]
)
# Create the Jira V1 callback processor
jira_callback_processor = self._create_jira_v1_callback_processor()
injector_state = InjectorState()
# Create the V1 conversation start request
start_request = AppConversationStartRequest(
conversation_id=UUID(conversation_metadata.conversation_id),
system_message_suffix=None,
initial_message=initial_message,
selected_repository=self.selected_repo,
selected_branch=None,
git_provider=ProviderType.GITHUB,
title=f'Jira Issue {self.payload.issue_key}: {self._issue_title or "Unknown"}',
trigger=ConversationTrigger.JIRA,
processors=[jira_callback_processor],
)
# Set up the Jira user context for the V1 system
jira_user_context = ResolverUserContext(
saas_user_auth=self.saas_user_auth,
resolver_org_id=self.resolved_org_id,
)
setattr(injector_state, USER_CONTEXT_ATTR, jira_user_context)
async with get_app_conversation_service(
injector_state
) as app_conversation_service:
async for task in app_conversation_service.start_app_conversation(
start_request
):
if task.status == AppConversationStartTaskStatus.ERROR:
logger.error(f'Failed to start V1 conversation: {task.detail}')
raise RuntimeError(
f'Failed to start V1 conversation: {task.detail}'
)
async def _get_v1_initial_user_message(self, jinja_env: Environment) -> str:
"""Build the initial user message for V1 resolver conversations."""
issue_title, issue_description = await self.get_issue_details()
user_msg_template = jinja_env.get_template('jira_new_conversation.j2')
user_msg = user_msg_template.render(
issue_key=self.payload.issue_key,
issue_title=issue_title,
issue_description=issue_description,
user_message=self.payload.user_msg,
)
return user_msg
def _create_jira_v1_callback_processor(self):
"""Create a V1 callback processor for Jira integration."""
return JiraV1CallbackProcessor(
svc_acc_email=self.jira_workspace.svc_acc_email,
decrypted_api_key=self._decrypted_api_key,
issue_key=self.payload.issue_key,
jira_cloud_id=self.jira_workspace.jira_cloud_id,
)
async def _get_resolved_org_id(self) -> UUID | None:
"""Resolve the org ID for V1 conversations."""
provider_tokens = await self.saas_user_auth.get_provider_tokens()
if not provider_tokens:
return None
user_secrets = await self.saas_user_auth.get_secrets()
instructions, user_msg = await self._get_instructions(jinja_env)
try:
provider_handler = ProviderHandler(provider_tokens)
repository = await provider_handler.verify_repo_provider(self.selected_repo)
resolved_org_id = await resolve_org_for_repo(
provider=repository.git_provider.value,
full_repo_name=self.selected_repo,
keycloak_user_id=self.jira_user.keycloak_user_id,
user_id = self.jira_user.keycloak_user_id
# Resolve git provider from repository
resolved_git_provider = None
if provider_tokens:
try:
provider_handler = ProviderHandler(provider_tokens)
repository = await provider_handler.verify_repo_provider(
self.selected_repo
)
resolved_git_provider = repository.git_provider
except Exception as e:
logger.warning(
f'[Jira] Failed to resolve git provider for {self.selected_repo}: {e}'
)
# Resolve target org based on claimed git organizations
resolved_org_id = None
if resolved_git_provider and self.selected_repo:
try:
resolved_org_id = await resolve_org_for_repo(
provider=resolved_git_provider.value,
full_repo_name=self.selected_repo,
keycloak_user_id=user_id,
)
except Exception as e:
logger.warning(
f'[Jira] Failed to resolve org for {self.selected_repo}: {e}'
)
# Create the conversation store with resolver org routing
store = await SaasConversationStore.get_resolver_instance(
get_config(),
user_id,
resolved_org_id,
)
return resolved_org_id
conversation_id = uuid4().hex
conversation_metadata = ConversationMetadata(
trigger=ConversationTrigger.JIRA,
conversation_id=conversation_id,
title=get_default_conversation_title(conversation_id),
user_id=user_id,
selected_repository=self.selected_repo,
selected_branch=None,
git_provider=resolved_git_provider,
)
await store.save_metadata(conversation_metadata)
await start_conversation(
user_id=user_id,
git_provider_tokens=provider_tokens,
custom_secrets=user_secrets.custom_secrets if user_secrets else None,
initial_user_msg=user_msg,
image_urls=None,
replay_json=None,
conversation_id=conversation_id,
conversation_metadata=conversation_metadata,
conversation_instructions=instructions,
)
self.conversation_id = conversation_id
logger.info(
'[Jira] Created conversation',
extra={
'conversation_id': self.conversation_id,
'issue_key': self.payload.issue_key,
'selected_repo': self.selected_repo,
'resolved_org_id': str(resolved_org_id)
if resolved_org_id
else None,
},
)
# Store Jira conversation mapping
jira_conversation = JiraConversation(
conversation_id=self.conversation_id,
issue_id=self.payload.issue_id,
issue_key=self.payload.issue_key,
jira_user_id=self.jira_user.id,
)
await integration_store.create_conversation(jira_conversation)
return self.conversation_id
except Exception as e:
logger.warning(
f'[Jira] Failed to resolve org for {self.selected_repo}: {e}'
if isinstance(e, StartingConvoException):
raise
logger.error(
'[Jira] Failed to create conversation',
extra={'issue_key': self.payload.issue_key, 'error': str(e)},
exc_info=True,
)
return None
raise StartingConvoException(f'Failed to create conversation: {str(e)}')
def get_response_msg(self) -> str:
"""Get the response message to send back to Jira."""
@@ -20,7 +20,6 @@ from integrations.utils import (
OPENHANDS_RESOLVER_TEMPLATES_DIR,
filter_potential_repos_by_user_msg,
get_session_expired_message,
markdown_to_jira_markup,
)
from jinja2 import Environment, FileSystemLoader
from server.auth.saas_user_auth import get_user_auth_from_keycloak_id
@@ -469,8 +468,7 @@ class JiraDcManager(Manager[JiraDcViewInterface]):
"""
url = f'{base_api_url}/rest/api/2/issue/{issue_key}/comment'
headers = {'Authorization': f'Bearer {svc_acc_api_key}'}
# Convert standard Markdown to Jira Wiki Markup for proper rendering
data = {'body': markdown_to_jira_markup(message)}
data = {'body': message}
async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
response = await client.post(url, headers=headers, json=data)
response.raise_for_status()
+6 -16
View File
@@ -16,23 +16,21 @@ from openhands.core.logger import openhands_logger as logger
async def resolve_org_for_repo(
provider: str,
full_repo_name: str,
keycloak_user_id: str | None = None,
keycloak_user_id: str,
) -> UUID | None:
"""Determine the OpenHands org_id for a resolver conversation.
If the repo's git organization is claimed by an OpenHands org, returns the
claiming org's ID. When keycloak_user_id is provided, also verifies the user
is a member of that org.
If the repo's git organization is claimed by an OpenHands org AND the user
is a member of that org, returns the claiming org's ID. Otherwise returns
None (caller should fall back to user.current_org_id / personal workspace).
Args:
provider: Git provider name ("github", "gitlab", "bitbucket")
full_repo_name: Full repository name (e.g., "OpenHands/foo")
keycloak_user_id: The user's Keycloak UUID string (optional). If provided,
membership is verified before returning the org_id.
keycloak_user_id: The user's Keycloak UUID string
Returns:
The org_id if the repo's org is claimed (and user is a member when
keycloak_user_id is provided), else None
The org_id if the repo's org is claimed and user is a member, else None
"""
git_org = full_repo_name.split('/')[0].lower()
@@ -46,14 +44,6 @@ async def resolve_org_for_repo(
)
return None
# Skip membership check if no user_id provided
if keycloak_user_id is None:
logger.info(
f'[OrgResolver] Resolved org {claim.org_id} '
f'for {provider}/{git_org} (no user membership check)',
)
return claim.org_id
member = await OrgMemberStore.get_org_member(
claim.org_id, UUID(keycloak_user_id)
)
+4 -21
View File
@@ -436,13 +436,12 @@ def infer_repo_from_message(user_msg: str) -> list[str]:
r'(?=\s|$|}}|[\]\)\'",.:`])' # right boundary
)
# Use dict to preserve ordering
matches: dict[str, bool] = {}
matches: list[str] = []
# Git URLs first (highest priority)
for owner, repo in re.findall(git_url_pattern, normalized_msg):
repo = re.sub(r'\.git$', '', repo)
matches[f'{owner}/{repo}'] = True
matches.append(f'{owner}/{repo}')
# Direct mentions
for owner, repo in re.findall(direct_pattern, normalized_msg):
@@ -458,10 +457,9 @@ def infer_repo_from_message(user_msg: str) -> list[str]:
continue
if full_match not in matches:
matches[full_match] = True
matches.append(full_match)
result = list(matches)
return result
return matches
def filter_potential_repos_by_user_msg(
@@ -597,18 +595,3 @@ def markdown_to_jira_markup(markdown_text: str) -> str:
# Log the error but don't raise it - return original text as fallback
print(f'Error converting markdown to Jira markup: {str(e)}')
return markdown_text or ''
def format_jira_comment_body(message: str) -> dict:
"""Format a message as a Jira API v2 comment body.
This helper ensures consistent comment formatting across all Jira integrations.
Converts markdown to Jira Wiki Markup and wraps in the expected API structure.
Args:
message: The message content to send (may contain markdown)
Returns:
dict: The comment body in Jira API v2 format {'body': ...}
"""
return {'body': markdown_to_jira_markup(message)}
-12
View File
@@ -6,12 +6,6 @@ from logging.config import fileConfig
# These plugin setup messages would otherwise appear before logging is configured
logging.getLogger('alembic.runtime.plugins').setLevel(logging.WARNING)
# Prevent SQLAlchemy engine from logging SQL results at DEBUG level, which can
# leak sensitive column data (e.g. API keys, tokens) into log aggregators.
# This is set before any engine is created so it takes effect immediately.
logging.getLogger('sqlalchemy.engine').setLevel(logging.WARNING)
logging.getLogger('sqlalchemy.engine.Engine').setLevel(logging.WARNING)
from alembic import context # noqa: E402
from google.cloud.sql.connector import Connector # noqa: E402
from sqlalchemy import create_engine, text # noqa: E402
@@ -76,12 +70,6 @@ config = context.config
if config.config_file_name is not None:
fileConfig(config.config_file_name)
# Re-apply SQLAlchemy engine log suppression after fileConfig, which may override
# our earlier settings from alembic.ini. This ensures DEBUG-level SQL result logging
# is always suppressed, preventing sensitive data from leaking into log aggregators.
logging.getLogger('sqlalchemy.engine').setLevel(logging.WARNING)
logging.getLogger('sqlalchemy.engine.Engine').setLevel(logging.WARNING)
def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode.
@@ -6,6 +6,7 @@ Create Date: 2026-03-26
"""
import json
from typing import Sequence, Union
import sqlalchemy as sa
@@ -23,18 +24,18 @@ def upgrade() -> None:
# Migrate existing org-level MCP configs to all members in each org.
# This preserves existing configurations while transitioning to user-specific settings.
# Uses server-side SQL to avoid pulling sensitive config data into the Python process.
op.execute(
sa.text(
"""
UPDATE org_member
SET mcp_config = org.mcp_config
FROM org
WHERE org_member.org_id = org.id
AND org.mcp_config IS NOT NULL
"""
conn = op.get_bind()
orgs_with_config = conn.execute(
sa.text('SELECT id, mcp_config FROM org WHERE mcp_config IS NOT NULL')
).fetchall()
for org_id, mcp_config in orgs_with_config:
conn.execute(
sa.text(
'UPDATE org_member SET mcp_config = :config WHERE org_id = :org_id'
),
{'config': json.dumps(mcp_config), 'org_id': str(org_id)},
)
)
def downgrade() -> None:
@@ -1,31 +0,0 @@
"""Add onboarding_completed column to user table.
Tracks whether a user has completed the onboarding flow.
Used to redirect new SaaS users to /onboarding after accepting TOS.
Revision ID: 107
Revises: 106
Create Date: 2026-03-31
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = '107'
down_revision: Union[str, None] = '106'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.add_column(
'user',
sa.Column('onboarding_completed', sa.Boolean(), nullable=True, default=False),
)
def downgrade() -> None:
op.drop_column('user', 'onboarding_completed')
-9
View File
@@ -87,9 +87,6 @@ class Permission(str, Enum):
# Git organization claims
MANAGE_ORG_CLAIMS = 'manage_org_claims'
# Manage Automations
MANAGE_AUTOMATIONS = 'manage_automations'
class RoleName(str, Enum):
"""Role names used in the system."""
@@ -126,8 +123,6 @@ ROLE_PERMISSIONS: dict[RoleName, frozenset[Permission]] = {
Permission.DELETE_ORGANIZATION,
# Git organization claims
Permission.MANAGE_ORG_CLAIMS,
# Manage Automations
Permission.MANAGE_AUTOMATIONS,
]
),
RoleName.ADMIN: frozenset(
@@ -151,8 +146,6 @@ ROLE_PERMISSIONS: dict[RoleName, frozenset[Permission]] = {
Permission.EDIT_ORG_SETTINGS,
# Git organization claims
Permission.MANAGE_ORG_CLAIMS,
# Manage Automations
Permission.MANAGE_AUTOMATIONS,
]
),
RoleName.MEMBER: frozenset(
@@ -166,8 +159,6 @@ ROLE_PERMISSIONS: dict[RoleName, frozenset[Permission]] = {
# Settings (View only)
Permission.VIEW_ORG_SETTINGS,
Permission.VIEW_LLM_SETTINGS,
# Manage Automations
Permission.MANAGE_AUTOMATIONS,
]
),
}
-17
View File
@@ -56,23 +56,6 @@ RECAPTCHA_SITE_KEY = os.getenv('RECAPTCHA_SITE_KEY', '').strip()
RECAPTCHA_HMAC_SECRET = os.getenv('RECAPTCHA_HMAC_SECRET', '').strip()
RECAPTCHA_BLOCK_THRESHOLD = float(os.getenv('RECAPTCHA_BLOCK_THRESHOLD', '0.3'))
# Automation Service
AUTOMATION_SERVICE_URL = os.getenv('AUTOMATION_SERVICE_URL', '').strip()
if AUTOMATION_SERVICE_URL and not AUTOMATION_SERVICE_URL.startswith(
('http://', 'https://')
):
raise ValueError(
f'AUTOMATION_SERVICE_URL must start with http:// or https://, '
f'got: {AUTOMATION_SERVICE_URL}'
)
AUTOMATION_EVENT_FORWARDING_ENABLED = os.getenv(
'AUTOMATION_EVENT_FORWARDING_ENABLED', 'false'
) in ('1', 'true')
# Shared secret for signing payloads sent to automation service (separate from GitHub webhook secret)
AUTOMATION_WEBHOOK_SECRET = os.getenv('AUTOMATION_WEBHOOK_SECRET', '').strip()
# Default HTTP timeout for automation service requests (seconds)
AUTOMATION_SERVICE_TIMEOUT = int(os.getenv('AUTOMATION_SERVICE_TIMEOUT', '30'))
# Account Defender labels that indicate suspicious activity
SUSPICIOUS_LABELS = {
'SUSPICIOUS_LOGIN_ACTIVITY',
-2
View File
@@ -20,7 +20,6 @@ from server.auth.constants import (
GITLAB_APP_CLIENT_ID,
RECAPTCHA_SITE_KEY,
)
from server.constants import DEPLOYMENT_MODE
from openhands.core.config.utils import load_openhands_config
from openhands.integrations.service_types import ProviderType
@@ -180,7 +179,6 @@ class SaaSServerConfig(ServerConfig):
'ENABLE_JIRA': self.enable_jira,
'ENABLE_JIRA_DC': self.enable_jira_dc,
'ENABLE_LINEAR': self.enable_linear,
'DEPLOYMENT_MODE': DEPLOYMENT_MODE,
},
'PROVIDERS_CONFIGURED': providers_configured,
}
-27
View File
@@ -15,33 +15,6 @@ IS_FEATURE_ENV = (
) # Does not include the staging deployment
IS_LOCAL_ENV = bool(HOST == 'localhost')
# _is_all_hands_managed_domain() can be removed/replaced when a self-hosted specific
# env var is created (e.g is_self_hosted` or `deployment_mode`)
def _is_all_hands_managed_domain(host: str) -> bool:
"""Check if the host is an All-Hands managed domain."""
return (
host == 'app.all-hands.dev'
or host == 'app.openhands.ai'
or host.endswith('.all-hands.dev')
or host.endswith('.openhands.ai')
)
def _get_deployment_mode() -> str:
"""Determine deployment mode based on WEB_HOST.
Returns:
'cloud' for All-Hands managed infrastructure (app.all-hands.dev, etc.)
'self_hosted' for enterprise self-hosted deployments (customer domains)
"""
if _is_all_hands_managed_domain(HOST):
return 'cloud'
return 'self_hosted'
DEPLOYMENT_MODE = _get_deployment_mode()
# Role name constants
ROLE_OWNER = 'owner'
ROLE_ADMIN = 'admin'
+5 -139
View File
@@ -27,10 +27,7 @@ from server.auth.user.user_authorizer import (
depends_user_authorizer,
)
from server.config import sign_token
from server.constants import (
DEPLOYMENT_MODE,
IS_FEATURE_ENV,
)
from server.constants import IS_FEATURE_ENV, IS_LOCAL_ENV
from server.routes.event_webhook import _get_session_api_key, _get_user_id
from server.services.org_invitation_service import (
EmailMismatchError,
@@ -465,20 +462,8 @@ async def keycloak_callback(
tos_redirect_url = f'{tos_redirect_url}&invitation_success=true'
response = RedirectResponse(tos_redirect_url, status_code=302)
else:
# User has accepted TOS - check if they need onboarding
# Only redirect to onboarding if user has a valid offline token,
# otherwise they need to complete the Keycloak offline token flow first
if valid_offline_token and await _should_redirect_to_onboarding(user_id, user):
redirect_url = f'{web_url}/onboarding'
logger.info(
'Redirecting returning user to onboarding',
extra={'user_id': user_id, 'deployment_mode': DEPLOYMENT_MODE},
)
if invitation_token:
if '?' in redirect_url:
redirect_url = f'{redirect_url}&invitation_success=true'
else:
redirect_url = f'{redirect_url}?invitation_success=true'
redirect_url = f'{redirect_url}&invitation_success=true'
response = RedirectResponse(redirect_url, status_code=302)
set_response_cookie(
@@ -486,7 +471,7 @@ async def keycloak_callback(
response=response,
keycloak_access_token=keycloak_access_token,
keycloak_refresh_token=keycloak_refresh_token,
secure=True if web_url.startswith('https') else False,
secure=True if redirect_url.startswith('https') else False,
accepted_tos=has_accepted_tos,
)
@@ -527,23 +512,8 @@ async def keycloak_offline_callback(code: str, state: str, request: Request):
user_id=user_info.sub, offline_token=keycloak_refresh_token
)
user = await UserStore.get_user_by_id(user_info.sub)
has_accepted_tos = user is not None and user.accepted_tos is not None
redirect_url, _, _ = _extract_oauth_state(state)
default_url = redirect_url if redirect_url else web_url
final_url = await _get_post_auth_redirect(user_info.sub, default_url, web_url, user)
response = RedirectResponse(final_url, status_code=302)
set_response_cookie(
request=request,
response=response,
keycloak_access_token=keycloak_access_token,
keycloak_refresh_token=keycloak_refresh_token,
secure=True if web_url.startswith('https') else False,
accepted_tos=has_accepted_tos,
)
return response
return RedirectResponse(redirect_url if redirect_url else web_url, status_code=302)
@oauth_router.get('/github/callback')
@@ -579,74 +549,6 @@ async def authenticate(request: Request):
return response
async def _should_redirect_to_onboarding(user_id: str, user: User) -> bool:
"""Check if user should be redirected to onboarding after TOS acceptance.
Backend always redirects applicable users to /onboarding. The frontend
checks the ENABLE_ONBOARDING feature flag (localStorage) and redirects
to / if the flag is disabled. This avoids needing helm chart changes.
Returns True if:
- User has onboarding_completed explicitly set to False (new users)
- Either:
- Deployment mode is 'cloud' (all users)
- Deployment mode is 'self_hosted' AND user is the super admin
(first owner in their current org to accept TOS)
Returns False if:
- User has onboarding_completed=True (already completed)
- User has onboarding_completed=None (existing users before this feature)
"""
# Already completed onboarding
if user.onboarding_completed is True:
return False
# Existing user before this feature (NULL in database)
if user.onboarding_completed is None:
return False
# Cloud SaaS: all users go to onboarding
if DEPLOYMENT_MODE == 'cloud':
return True
# Self-hosted SaaS: only the super admin (first owner to accept TOS in the org)
if DEPLOYMENT_MODE == 'self_hosted':
first_owner = await UserStore.get_first_owner_in_org(user.current_org_id)
if first_owner and str(first_owner.id) == user_id:
return True
return False
async def _get_post_auth_redirect(
user_id: str, default_url: str, web_url: str, user: User | None = None
) -> str:
"""Determine where to redirect user after authentication completes.
Called after offline token is stored to determine final redirect destination.
Checks for pending user flows (e.g., onboarding) before falling back to default.
Args:
user_id: The user's ID.
default_url: The default URL to redirect to if no special flow is needed.
web_url: The base web URL for constructing absolute paths.
user: Optional user object to avoid refetching.
Returns:
The URL to redirect the user to.
"""
if not user:
user = await UserStore.get_user_by_id(user_id)
if user and await _should_redirect_to_onboarding(user_id, user):
logger.info(
'Redirecting user to onboarding',
extra={'user_id': user_id, 'deployment_mode': DEPLOYMENT_MODE},
)
return f'{web_url}/onboarding'
return default_url
@api_router.post('/accept_tos')
async def accept_tos(request: Request):
user_auth = cast(SaasUserAuth, await get_user_auth(request))
@@ -687,12 +589,6 @@ async def accept_tos(request: Request):
logger.info(f'User {user_id} accepted TOS')
# Determine final redirect - but don't override if it's the offline token flow
# (the offline callback will handle post-auth redirect after storing the token)
is_offline_flow = 'offline' in redirect_url
if not is_offline_flow:
redirect_url = await _get_post_auth_redirect(user_id, redirect_url, web_url)
response = JSONResponse(
status_code=status.HTTP_200_OK, content={'redirect_url': redirect_url}
)
@@ -702,42 +598,12 @@ async def accept_tos(request: Request):
response=response,
keycloak_access_token=access_token.get_secret_value(),
keycloak_refresh_token=refresh_token.get_secret_value(),
secure=True if web_url.startswith('https') else False,
secure=not IS_LOCAL_ENV,
accepted_tos=True,
)
return response
@api_router.post('/complete_onboarding')
async def complete_onboarding(request: Request):
"""Mark onboarding as completed for the current user."""
user_auth = cast(SaasUserAuth, await get_user_auth(request))
user_id = await user_auth.get_user_id()
if not user_id:
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={'error': 'User is not authenticated'},
)
user = await UserStore.mark_onboarding_completed(user_id)
if not user:
return JSONResponse(
status_code=status.HTTP_404_NOT_FOUND,
content={'error': 'User not found'},
)
logger.info(
'User completed onboarding',
extra={'user_id': user_id},
)
return JSONResponse(
status_code=status.HTTP_200_OK,
content={'message': 'Onboarding completed'},
)
@api_router.post('/logout')
async def logout(request: Request):
# Always create the response object first to ensure we can return it even if errors occur
+2 -21
View File
@@ -3,17 +3,13 @@ import hashlib
import hmac
import os
from fastapi import APIRouter, BackgroundTasks, Header, HTTPException, Request
from fastapi import APIRouter, Header, HTTPException, Request
from fastapi.responses import JSONResponse
from integrations.github.data_collector import GitHubDataCollector
from integrations.github.github_manager import GithubManager
from integrations.models import Message, SourceType
from server.auth.constants import (
AUTOMATION_EVENT_FORWARDING_ENABLED,
GITHUB_APP_WEBHOOK_SECRET,
)
from server.auth.constants import GITHUB_APP_WEBHOOK_SECRET
from server.auth.token_manager import TokenManager
from server.services.automation_event_service import AutomationEventService
from openhands.core.logger import openhands_logger as logger
@@ -26,7 +22,6 @@ github_integration_router = APIRouter(prefix='/integration')
token_manager = TokenManager()
data_collector = GitHubDataCollector()
github_manager = GithubManager(token_manager, data_collector)
automation_event_service = AutomationEventService(token_manager)
def verify_github_signature(payload: bytes, signature: str):
@@ -51,9 +46,7 @@ def verify_github_signature(payload: bytes, signature: str):
@github_integration_router.post('/github/events')
async def github_events(
request: Request,
background_tasks: BackgroundTasks,
x_hub_signature_256: str = Header(None),
x_github_event: str = Header(None),
):
# Check if GitHub webhooks are enabled
if not GITHUB_WEBHOOKS_ENABLED:
@@ -79,18 +72,6 @@ async def github_events(
content={'error': 'Installation ID is missing in the payload.'},
)
# Forward to automation service (fire-and-forget background task)
if AUTOMATION_EVENT_FORWARDING_ENABLED:
logger.info(
f'triggering forward_github_event with payload: {payload_data}, installation_id: {installation_id}'
)
background_tasks.add_task(
automation_event_service.forward_github_event,
payload=payload_data,
installation_id=installation_id,
)
# Existing resolver bot processing
message_payload = {'payload': payload_data, 'installation': installation_id}
message = Message(source=SourceType.GITHUB, message=message_payload)
await github_manager.receive_message(message)
+1 -6
View File
@@ -149,12 +149,7 @@ async def verify_jira_signature(body: bytes, signature: str, payload: dict):
workspace_name = jira_manager.get_workspace_name_from_payload(payload)
if workspace_name is None:
logger.warning(
'[Jira] No workspace name found in webhook payload',
extra={
'payload': payload,
},
)
logger.warning('[Jira] No workspace name found in webhook payload')
raise HTTPException(
status_code=403, detail='Workspace name not found in payload'
)
@@ -1,448 +0,0 @@
"""
Service for forwarding GitHub webhook events to the automation service.
This service is optimized for high-traffic scenarios:
1. Resolves GitHub org → OpenHands org_id (via cached OrgGitClaim lookup)
2. For personal repos, resolves to personal org (via cached GitHub→Keycloak mapping)
3. Forwards minimal payload to automation service (just org_id + payload)
4. Access control checks are deferred to automation execution time
The lazy access control approach means:
- Most webhooks only do cached lookups + HTTP forward
- Membership checks only happen when an automation actually matches
Security notes:
- Uses AUTOMATION_WEBHOOK_SECRET (not GitHub webhook secret) for internal service signing
- Negative results are cached to prevent DoS via repeated lookups for unclaimed orgs
"""
import asyncio
import hashlib
import hmac
import json
from dataclasses import dataclass
from typing import Any
from uuid import UUID
import aiohttp
from integrations.resolver_org_router import resolve_org_for_repo
from server.auth.constants import (
AUTOMATION_SERVICE_TIMEOUT,
AUTOMATION_SERVICE_URL,
AUTOMATION_WEBHOOK_SECRET,
)
from server.auth.token_manager import TokenManager
from openhands.core.logger import openhands_logger as logger
from openhands.integrations.provider import ProviderType
from openhands.server.shared import sio
# Cache TTL constants
ORG_CLAIM_CACHE_TTL_SECONDS = 3600 # 1 hour for org claims (rarely change)
USER_ID_CACHE_TTL_SECONDS = 86400 # 24 hours for user ID mappings (never change)
# Cache key prefixes
ORG_CLAIM_CACHE_PREFIX = 'automation:org_claim'
USER_ID_CACHE_PREFIX = 'automation:gh_to_kc_user'
@dataclass
class OrgContext:
"""Context for the resolved organization."""
org_id: UUID
github_org: str
class AutomationEventService:
"""
Service for forwarding webhook events to the automation service.
Optimized for high traffic with:
- Redis caching for org claim lookups (1 hour TTL)
- Redis caching for GitHub→Keycloak user ID mappings (24 hour TTL)
- Lazy access control (membership checks deferred to execution time)
"""
def __init__(self, token_manager: TokenManager):
from server.auth.constants import AUTOMATION_EVENT_FORWARDING_ENABLED
self.token_manager = token_manager
# Fail fast if forwarding is enabled but misconfigured
if AUTOMATION_EVENT_FORWARDING_ENABLED:
if not AUTOMATION_SERVICE_URL:
raise ValueError(
'AUTOMATION_EVENT_FORWARDING_ENABLED=true but '
'AUTOMATION_SERVICE_URL is not configured'
)
if not AUTOMATION_WEBHOOK_SECRET:
raise ValueError(
'AUTOMATION_EVENT_FORWARDING_ENABLED=true but '
'AUTOMATION_WEBHOOK_SECRET is not configured'
)
async def forward_github_event(
self,
payload: dict[str, Any],
installation_id: int,
) -> None:
"""
Forward a GitHub webhook event to the automation service.
This is designed to be called as a fire-and-forget background task.
The forward path is optimized for speed - only org resolution is done here.
Access control checks are deferred to automation execution time.
Args:
payload: The raw GitHub webhook payload
installation_id: The GitHub App installation ID
"""
org_id: UUID | None = None
try:
logger.info(f'Retrieving org context for payload {payload}')
# Resolve org context (org_id and github_org name) - uses Redis cache
org_context = await self._resolve_org_context(payload)
logger.info(f'org context for payload {payload} is {org_context}')
if not org_context:
return
org_id = org_context.org_id
# Build minimal payload and forward immediately
# Access control is NOT computed here - it's deferred to execution time
event_payload = self._build_event_payload(org_context, payload)
logger.info(f'event_payload is {event_payload}')
await self._send_to_automation_service(org_id, event_payload)
except (aiohttp.ClientError, asyncio.TimeoutError) as e:
# Network errors are expected and recoverable
logger.error(
f'[AutomationEventService] Network error forwarding event '
f'(org_id={org_id}): {e}',
exc_info=True,
extra={'installation_id': installation_id},
)
except Exception as e:
# Log unexpected errors. Note: This is a background task, so exceptions
# won't surface to the HTTP caller - they're logged for debugging only.
logger.error(
f'[AutomationEventService] Unexpected error forwarding event '
f'(org_id={org_id}): {e}',
exc_info=True,
extra={'installation_id': installation_id},
)
# Don't re-raise in background task - just log for debugging
async def _resolve_org_context(self, payload: dict[str, Any]) -> OrgContext | None:
"""
Resolve the organization context from the webhook payload.
Uses Redis caching for both org claims and user ID mappings.
Returns None if the org cannot be resolved (not claimed, no personal org).
"""
repo = payload.get('repository', {})
owner = repo.get('owner', {})
git_org_name = owner.get('login')
owner_type = owner.get('type') # 'User' or 'Organization'
if not git_org_name:
logger.warning(
'[AutomationEventService] No repository owner in payload, skipping'
)
return None
# Try to resolve via OrgGitClaim
org_id = await self._resolve_github_org(git_org_name)
# Fallback for personal repos
if not org_id and owner_type == 'User':
org_id = await self._resolve_personal_org(owner.get('id'))
if org_id:
logger.info(
f'[AutomationEventService] Resolved personal repo owner '
f'{git_org_name} to personal org {org_id}'
)
if not org_id:
logger.warning(
f'[AutomationEventService] GitHub org {git_org_name} not claimed '
f'and no personal org found, skipping'
)
return None
return OrgContext(org_id=org_id, github_org=git_org_name)
def _build_event_payload(
self,
org_context: OrgContext,
payload: dict[str, Any],
) -> dict[str, Any]:
"""
Build the minimal event payload to forward to the automation service.
Access control is NOT included here - it's deferred to execution time.
This keeps the forward path fast for high-traffic scenarios.
"""
return {
'organization': {
'github_org': org_context.github_org,
'openhands_org_id': str(org_context.org_id),
},
'payload': payload,
}
# =========================================================================
# Cached Org Resolution Methods
# =========================================================================
async def _resolve_github_org(self, git_org_name: str) -> UUID | None:
"""
Resolve a GitHub organization name to an OpenHands org_id.
Uses Redis caching with 1-hour TTL. Caches both positive and negative
results to avoid repeated DB queries for unclaimed orgs.
Note: GitHub org names are case-insensitive. We normalize to lowercase
for both cache keys and DB queries. This matches the OrgGitClaim schema
which stores git_organization as lowercase (enforced by GitOrgClaimRequest
validator in org_models.py).
"""
normalized_org = git_org_name.lower()
cache_key = f'{ORG_CLAIM_CACHE_PREFIX}:{normalized_org}'
# Check cache first
cached = await self._get_cached_value(cache_key)
if cached is not None:
if cached == 'none':
logger.debug(
f'[AutomationEventService] Cache hit (negative): org {git_org_name} not claimed'
)
return None
logger.debug(
f'[AutomationEventService] Cache hit: org {git_org_name} -> {cached}'
)
return UUID(cached)
# Cache miss - use resolve_org_for_repo without user_id (no membership check)
# Construct a minimal repo name since resolve_org_for_repo extracts the org
org_id = await resolve_org_for_repo(
provider='github',
full_repo_name=f'{normalized_org}/',
)
# Cache the result (including negative results)
if org_id:
await self._set_cached_value(
cache_key, str(org_id), ORG_CLAIM_CACHE_TTL_SECONDS
)
return org_id
else:
# Cache negative result to avoid repeated DB queries
await self._set_cached_value(cache_key, 'none', ORG_CLAIM_CACHE_TTL_SECONDS)
return None
async def _resolve_personal_org(self, github_user_id: int | None) -> UUID | None:
"""
Resolve a GitHub user to their personal OpenHands org.
For personal repos (owner type is 'User'), the OpenHands org_id
is the user's keycloak user ID. This allows users to set up
automations on their personal repos without needing an OrgGitClaim.
Uses Redis caching for the GitHub→Keycloak user ID mapping (24h TTL).
"""
if not github_user_id:
return None
keycloak_id = await self._get_keycloak_user_id_cached(github_user_id)
if keycloak_id:
return UUID(keycloak_id)
return None
async def _get_keycloak_user_id_cached(self, github_user_id: int) -> str | None:
"""
Convert a GitHub user ID to a Keycloak user ID.
Uses Redis caching with 24-hour TTL since this mapping never changes.
Caches negative results to avoid repeated Keycloak queries.
"""
cache_key = f'{USER_ID_CACHE_PREFIX}:{github_user_id}'
# Check cache first
cached = await self._get_cached_value(cache_key)
if cached is not None:
if cached == 'none':
logger.debug(
f'[AutomationEventService] Cache hit (negative): GitHub user {github_user_id} not in Keycloak'
)
return None
logger.debug(
f'[AutomationEventService] Cache hit: GitHub user {github_user_id} -> Keycloak {cached}'
)
return cached
# Cache miss - query Keycloak
try:
keycloak_id = await self.token_manager.get_user_id_from_idp_user_id(
str(github_user_id), ProviderType.GITHUB
)
# Cache the result (including negative results)
if keycloak_id:
await self._set_cached_value(
cache_key, keycloak_id, USER_ID_CACHE_TTL_SECONDS
)
else:
# Cache negative result to prevent repeated Keycloak queries (DoS mitigation)
await self._set_cached_value(
cache_key, 'none', USER_ID_CACHE_TTL_SECONDS
)
return keycloak_id
except Exception as e:
# Log at warning level to surface programmer errors and API issues
logger.warning(
f'[AutomationEventService] Failed to get keycloak ID for GitHub user {github_user_id}: {e}'
)
return None
# =========================================================================
# Generic Redis Cache Helpers
# =========================================================================
async def _get_cached_value(self, cache_key: str) -> str | None:
"""
Get a cached value from Redis.
Returns the cached string value, or None if not cached or Redis unavailable.
Falls back to DB/API queries if Redis is unavailable (graceful degradation).
Warning: When Redis is unavailable, every webhook will hit the DB directly.
Monitor logs for 'Redis unavailable' warnings to detect degradation.
"""
try:
redis = getattr(sio.manager, 'redis', None)
if not redis:
# Log at warning level - this is a significant degradation that
# will cause DB load. Monitor these logs for alerting.
logger.warning(
'[AutomationEventService] Redis unavailable for cache read, '
'falling back to direct DB queries (this will increase DB load)'
)
return None
cached = await redis.get(cache_key)
if cached is None:
return None
# Redis returns bytes, decode to string
return cached.decode('utf-8') if isinstance(cached, bytes) else cached
except Exception as e:
# Log at warning level - cache errors cause DB fallback
logger.warning(
f'[AutomationEventService] Redis cache read error (falling back to DB): {e}'
)
return None
async def _set_cached_value(
self, cache_key: str, value: str, ttl_seconds: int
) -> None:
"""
Set a cached value in Redis with TTL.
Fails silently if Redis is unavailable (graceful degradation).
"""
try:
redis = getattr(sio.manager, 'redis', None)
if not redis:
# Silent failure - read path already logs the warning
return
await redis.setex(cache_key, ttl_seconds, value)
except Exception as e:
# Log at warning level for visibility
logger.warning(f'[AutomationEventService] Redis cache write error: {e}')
def _sign_payload(self, payload_bytes: bytes) -> str:
"""
Sign a payload using the dedicated automation shared secret.
Uses AUTOMATION_WEBHOOK_SECRET (not GitHub webhook secret) to maintain
separate trust boundaries between GitHub webhooks and internal services.
Returns the signature in the format 'sha256=<hex_digest>'.
"""
signature = hmac.new(
AUTOMATION_WEBHOOK_SECRET.encode('utf-8'),
msg=payload_bytes,
digestmod=hashlib.sha256,
).hexdigest()
return f'sha256={signature}'
async def _send_to_automation_service(
self,
org_id: UUID,
payload: dict[str, Any],
) -> None:
"""
Send the normalized payload to the automation service.
The payload is signed using AUTOMATION_WEBHOOK_SECRET so the
automation service can verify it came from the OpenHands server.
"""
if not AUTOMATION_SERVICE_URL:
logger.warning(
'[AutomationEventService] AUTOMATION_SERVICE_URL not configured'
)
return
# Build endpoint URL. AUTOMATION_SERVICE_URL may include path segments
# (e.g., https://example.com/api/automation), so we strip trailing slash
# and append our path.
url = f'{AUTOMATION_SERVICE_URL.rstrip("/")}/v1/events/{org_id}/github'
# Serialize payload to JSON bytes for signing
payload_bytes = json.dumps(payload, separators=(',', ':')).encode('utf-8')
signature = self._sign_payload(payload_bytes)
headers = {
'Content-Type': 'application/json',
'X-Hub-Signature-256': signature,
}
try:
async with aiohttp.ClientSession() as session:
async with session.post(
url,
data=payload_bytes,
headers=headers,
timeout=aiohttp.ClientTimeout(total=AUTOMATION_SERVICE_TIMEOUT),
) as resp:
if resp.status >= 400:
# Try JSON first (expected interface), fall back to text
# for infrastructure errors (502/503 from load balancer)
try:
body = await resp.json()
except (aiohttp.ContentTypeError, ValueError):
body = await resp.text()
logger.warning(
f'[AutomationEventService] Automation service returned '
f'{resp.status} for org {org_id}: {body}'
)
else:
data = await resp.json()
matched = data.get('matched', 0)
logger.info(
f'[AutomationEventService] Forwarded event to org {org_id}: '
f'{matched} automations matched'
)
except asyncio.TimeoutError:
logger.warning(
f'[AutomationEventService] Timeout ({AUTOMATION_SERVICE_TIMEOUT}s) '
'forwarding to automation service'
)
except aiohttp.ClientError as e:
logger.warning(
f'[AutomationEventService] HTTP error forwarding to automation service: {e}'
)
@@ -1,12 +1,4 @@
"""Shared Event router for OpenHands Server.
All endpoints in this router are unauthenticated — shared conversations are
public. To avoid returning internal system state that the viewer does not
need, ``ConversationStateUpdateEvent`` instances are filtered out before the
response is sent. The shared-conversation frontend only renders messages,
actions, observations, errors, and hook-execution events; state snapshots
are consumed exclusively by the authenticated WebSocket path.
"""
"""Shared Event router for OpenHands Server."""
from datetime import datetime
from typing import Annotated
@@ -21,15 +13,9 @@ from server.sharing.shared_event_service import (
from openhands.agent_server.models import EventPage, EventSortOrder
from openhands.app_server.event_callback.event_callback_models import EventKind
from openhands.sdk import Event
from openhands.sdk.event.conversation_state import ConversationStateUpdateEvent
from openhands.utils.environment import StorageProvider, get_storage_provider
def _is_viewable(event: Event) -> bool:
"""Return True if *event* should be included in public shared responses."""
return not isinstance(event, ConversationStateUpdateEvent)
def get_shared_event_service_injector() -> SharedEventServiceInjector:
"""Get the appropriate SharedEventServiceInjector based on configuration.
@@ -101,36 +87,15 @@ async def search_shared_events(
] = 100,
shared_event_service: SharedEventService = shared_event_service_dependency,
) -> EventPage:
"""Search / List events for a shared conversation.
Because non-viewable events (e.g. ``ConversationStateUpdateEvent``) are
filtered out after fetching, a single backend page may yield fewer items
than *limit*. This method transparently fetches additional backend pages
until the requested *limit* is reached or there are no more results.
"""
conv_id = UUID(conversation_id)
viewable: list[Event] = []
cursor = page_id
while len(viewable) < limit:
remaining = limit - len(viewable)
page = await shared_event_service.search_shared_events(
conversation_id=conv_id,
kind__eq=kind__eq,
timestamp__gte=timestamp__gte,
timestamp__lt=timestamp__lt,
sort_order=sort_order,
page_id=cursor,
limit=remaining,
)
viewable.extend(e for e in page.items if _is_viewable(e))
cursor = page.next_page_id
if cursor is None:
break
return EventPage(
items=viewable[:limit],
next_page_id=cursor,
"""Search / List events for a shared conversation."""
return await shared_event_service.search_shared_events(
conversation_id=UUID(conversation_id),
kind__eq=kind__eq,
timestamp__gte=timestamp__gte,
timestamp__lt=timestamp__lt,
sort_order=sort_order,
page_id=page_id,
limit=limit,
)
@@ -182,7 +147,7 @@ async def batch_get_shared_events(
events = await shared_event_service.batch_get_shared_events(
UUID(conversation_id), event_ids
)
return [e if e is not None and _is_viewable(e) else None for e in events]
return events
@router.get('/{conversation_id}/{event_id}')
@@ -192,9 +157,6 @@ async def get_shared_event(
shared_event_service: SharedEventService = shared_event_service_dependency,
) -> Event | None:
"""Get a single event from a shared conversation by conversation_id and event_id."""
event = await shared_event_service.get_shared_event(
return await shared_event_service.get_shared_event(
UUID(conversation_id), UUID(event_id)
)
if event is not None and not _is_viewable(event):
return None
return event
-1
View File
@@ -36,7 +36,6 @@ class User(Base): # type: ignore
git_user_email = Column(String, nullable=True)
sandbox_grouping_strategy = Column(String, nullable=True)
disabled_skills = Column(JSON, nullable=True)
onboarding_completed = Column(Boolean, nullable=True, default=False)
# Relationships
role = relationship('Role', back_populates='users')
-60
View File
@@ -24,7 +24,6 @@ from storage.encrypt_utils import (
)
from storage.org import Org
from storage.org_member import OrgMember
from storage.role import Role
from storage.role_store import RoleStore
from storage.user import User
from storage.user_settings import UserSettings
@@ -750,65 +749,6 @@ class UserStore:
await session.refresh(user)
return user
@staticmethod
async def mark_onboarding_completed(user_id: str) -> Optional[User]:
"""Mark the user's onboarding as completed.
Args:
user_id: The user's ID (Keycloak user ID)
Returns:
User: The updated user object, or None if user not found
"""
async with a_session_maker() as session:
result = await session.execute(
select(User).filter(User.id == uuid.UUID(user_id)).with_for_update()
)
user = result.scalars().first()
if not user:
logger.warning(
'mark_onboarding_completed:user_not_found',
extra={'user_id': user_id},
)
return None
user.onboarding_completed = True
await session.commit()
await session.refresh(user)
logger.info(
'mark_onboarding_completed:success',
extra={'user_id': user_id},
)
return user
@staticmethod
async def get_first_owner_in_org(org_id: UUID) -> Optional[User]:
"""Get the first owner in an organization who accepted the Terms of Service.
This user is considered the super admin for that org in self-hosted deployments.
The super admin is identified as the owner with the earliest accepted_tos timestamp.
Args:
org_id: The organization UUID
Returns:
User: The first owner to accept TOS in this org, or None if not found.
"""
async with a_session_maker() as session:
result = await session.execute(
select(User)
.join(OrgMember, OrgMember.user_id == User.id)
.join(Role, Role.id == OrgMember.role_id)
.filter(
OrgMember.org_id == org_id,
Role.name == 'owner',
User.accepted_tos.isnot(None),
)
.order_by(User.accepted_tos.asc())
.limit(1)
)
return result.scalars().first()
@staticmethod
async def backfill_contact_name(user_id: str, user_info: dict) -> None:
"""Update contact_name on the personal org if it still has a username-style value.
@@ -206,7 +206,7 @@ def new_conversation_view(
sample_webhook_payload, sample_user_auth, sample_jira_user, sample_jira_workspace
):
"""JiraNewConversationView instance for testing"""
view = JiraNewConversationView(
return JiraNewConversationView(
payload=sample_webhook_payload,
saas_user_auth=sample_user_auth,
jira_user=sample_jira_user,
@@ -215,8 +215,6 @@ def new_conversation_view(
conversation_id='conv-123',
_decrypted_api_key='decrypted_key',
)
view.v1_enabled = False
return view
@pytest.fixture
@@ -202,10 +202,14 @@ class TestStartJob:
)
jira_manager._send_comment = AsyncMock()
await jira_manager.start_job(new_conversation_view)
with patch(
'integrations.jira.jira_manager.register_callback_processor'
) as mock_register:
await jira_manager.start_job(new_conversation_view)
new_conversation_view.create_or_update_conversation.assert_called_once()
jira_manager._send_comment.assert_called_once()
new_conversation_view.create_or_update_conversation.assert_called_once()
mock_register.assert_called_once()
jira_manager._send_comment.assert_called_once()
@pytest.mark.asyncio
async def test_start_job_missing_settings_error(
@@ -1,368 +0,0 @@
"""Tests for JiraV1CallbackProcessor.
This module tests the V1 callback processor that handles Jira integration
callbacks when conversations complete.
"""
from unittest.mock import AsyncMock, MagicMock, patch
from uuid import UUID, uuid4
import httpx
import pytest
from integrations.jira.jira_v1_callback_processor import (
JIRA_CLOUD_API_URL,
JiraV1CallbackProcessor,
)
from openhands.app_server.event_callback.event_callback_models import EventCallback
from openhands.app_server.event_callback.event_callback_result_models import (
EventCallbackResultStatus,
)
from openhands.sdk.event import ConversationStateUpdateEvent
@pytest.fixture
def callback_processor():
"""Create a JiraV1CallbackProcessor for testing."""
return JiraV1CallbackProcessor(
svc_acc_email='service@example.com',
decrypted_api_key='test_api_key',
issue_key='TEST-123',
jira_cloud_id='cloud-123',
)
@pytest.fixture
def mock_event_callback():
"""Create a mock EventCallback."""
callback = MagicMock(spec=EventCallback)
callback.id = UUID('12345678-1234-5678-1234-567812345678')
callback.conversation_id = UUID('87654321-4321-8765-4321-876543218765')
return callback
@pytest.fixture
def finished_event():
"""Create a ConversationStateUpdateEvent for finished state."""
return ConversationStateUpdateEvent(
id='event-123',
key='execution_status',
value='finished',
)
@pytest.fixture
def running_event():
"""Create a ConversationStateUpdateEvent for running state."""
return ConversationStateUpdateEvent(
id='event-456',
key='execution_status',
value='running',
)
class TestJiraV1CallbackProcessor:
"""Tests for JiraV1CallbackProcessor."""
@pytest.mark.asyncio
async def test_ignores_non_conversation_state_events(
self, callback_processor, mock_event_callback
):
"""Test that non-ConversationStateUpdateEvent events are ignored."""
# Use a different event type (mock)
other_event = MagicMock()
other_event.__class__ = object # Not ConversationStateUpdateEvent
result = await callback_processor(
conversation_id=uuid4(),
callback=mock_event_callback,
event=other_event,
)
assert result is None
@pytest.mark.asyncio
async def test_ignores_non_execution_status_keys(
self, callback_processor, mock_event_callback
):
"""Test that events with keys other than 'execution_status' are ignored."""
event = ConversationStateUpdateEvent(
id='event-123',
key='agent_status', # Different key
value='finished',
)
result = await callback_processor(
conversation_id=uuid4(),
callback=mock_event_callback,
event=event,
)
assert result is None
@pytest.mark.asyncio
async def test_ignores_non_finished_status(
self, callback_processor, mock_event_callback, running_event
):
"""Test that non-finished execution statuses are ignored."""
result = await callback_processor(
conversation_id=uuid4(),
callback=mock_event_callback,
event=running_event,
)
assert result is None
@pytest.mark.asyncio
async def test_only_requests_summary_once(
self, callback_processor, mock_event_callback, finished_event
):
"""Test that summary is only requested once (should_request_summary flag)."""
callback_processor.should_request_summary = False
result = await callback_processor(
conversation_id=uuid4(),
callback=mock_event_callback,
event=finished_event,
)
assert result is None
@pytest.mark.asyncio
@patch.object(JiraV1CallbackProcessor, '_request_summary')
@patch.object(JiraV1CallbackProcessor, '_post_summary_to_jira')
async def test_successful_summary_flow(
self,
mock_post_summary,
mock_request_summary,
callback_processor,
mock_event_callback,
finished_event,
):
"""Test successful summary request and posting flow."""
conversation_id = uuid4()
mock_request_summary.return_value = 'Test summary content'
mock_post_summary.return_value = None
result = await callback_processor(
conversation_id=conversation_id,
callback=mock_event_callback,
event=finished_event,
)
assert result is not None
assert result.status == EventCallbackResultStatus.SUCCESS
assert result.detail == 'Test summary content'
assert callback_processor.should_request_summary is False
mock_request_summary.assert_called_once_with(conversation_id)
mock_post_summary.assert_called_once_with('Test summary content')
@pytest.mark.asyncio
@patch.object(JiraV1CallbackProcessor, '_request_summary')
async def test_error_handling_on_summary_request_failure(
self,
mock_request_summary,
callback_processor,
mock_event_callback,
finished_event,
):
"""Test error handling when summary request fails."""
conversation_id = uuid4()
mock_request_summary.side_effect = Exception('Agent server unavailable')
result = await callback_processor(
conversation_id=conversation_id,
callback=mock_event_callback,
event=finished_event,
)
assert result is not None
assert result.status == EventCallbackResultStatus.ERROR
assert 'Agent server unavailable' in result.detail
@pytest.mark.asyncio
@patch.object(JiraV1CallbackProcessor, '_request_summary')
@patch.object(JiraV1CallbackProcessor, '_post_summary_to_jira')
async def test_error_handling_on_post_failure(
self,
mock_post_summary,
mock_request_summary,
callback_processor,
mock_event_callback,
finished_event,
):
"""Test error handling when posting to Jira fails."""
conversation_id = uuid4()
mock_request_summary.return_value = 'Test summary'
mock_post_summary.side_effect = Exception('Jira API error')
result = await callback_processor(
conversation_id=conversation_id,
callback=mock_event_callback,
event=finished_event,
)
assert result is not None
assert result.status == EventCallbackResultStatus.ERROR
assert 'Jira API error' in result.detail
class TestPostSummaryToJira:
"""Tests for _post_summary_to_jira method."""
@pytest.mark.asyncio
async def test_skips_when_missing_credentials(self, callback_processor):
"""Test that posting is skipped when credentials are missing."""
callback_processor.svc_acc_email = ''
# Should not raise, just log and return
await callback_processor._post_summary_to_jira('Test summary')
@pytest.mark.asyncio
async def test_skips_when_missing_issue_key(self, callback_processor):
"""Test that posting is skipped when issue key is missing."""
callback_processor.issue_key = ''
# Should not raise, just log and return
await callback_processor._post_summary_to_jira('Test summary')
@pytest.mark.asyncio
async def test_skips_when_missing_cloud_id(self, callback_processor):
"""Test that posting is skipped when cloud ID is missing."""
callback_processor.jira_cloud_id = ''
# Should not raise, just log and return
await callback_processor._post_summary_to_jira('Test summary')
@pytest.mark.asyncio
@patch('httpx.AsyncClient')
async def test_posts_comment_with_correct_format(
self, mock_async_client, callback_processor
):
"""Test that comment is posted with correct format (plain string body)."""
mock_response = MagicMock()
mock_response.raise_for_status = MagicMock()
mock_client_instance = AsyncMock()
mock_client_instance.post = AsyncMock(return_value=mock_response)
mock_async_client.return_value.__aenter__.return_value = mock_client_instance
await callback_processor._post_summary_to_jira('Test summary content')
# Verify the post was called
mock_client_instance.post.assert_called_once()
call_args = mock_client_instance.post.call_args
# Check URL
expected_url = (
f'{JIRA_CLOUD_API_URL}/cloud-123/rest/api/2/issue/TEST-123/comment'
)
assert call_args[0][0] == expected_url
# Check that body contains the summary message
json_body = call_args[1]['json']
assert 'body' in json_body
assert 'OpenHands resolved this issue' in json_body['body']
assert 'Test summary content' in json_body['body']
# Check auth
assert call_args[1]['auth'] == ('service@example.com', 'test_api_key')
@pytest.mark.asyncio
@patch('httpx.AsyncClient')
async def test_raises_on_http_error(self, mock_async_client, callback_processor):
"""Test that HTTP errors are propagated."""
mock_response = MagicMock()
mock_response.raise_for_status.side_effect = httpx.HTTPStatusError(
'Bad Request',
request=MagicMock(),
response=MagicMock(status_code=400),
)
mock_client_instance = AsyncMock()
mock_client_instance.post = AsyncMock(return_value=mock_response)
mock_async_client.return_value.__aenter__.return_value = mock_client_instance
with pytest.raises(httpx.HTTPStatusError):
await callback_processor._post_summary_to_jira('Test summary')
class TestAskQuestion:
"""Tests for _ask_question method."""
@pytest.mark.asyncio
async def test_sends_request_with_correct_payload(self, callback_processor):
"""Test that ask_question sends correct request to agent server."""
mock_httpx_client = AsyncMock()
mock_response = MagicMock()
mock_response.json.return_value = {'response': 'Agent response text'}
mock_response.raise_for_status = MagicMock()
mock_httpx_client.post = AsyncMock(return_value=mock_response)
conversation_id = uuid4()
agent_server_url = 'http://localhost:8000'
session_api_key = 'test_session_key'
message_content = 'Please summarize your work'
result = await callback_processor._ask_question(
httpx_client=mock_httpx_client,
agent_server_url=agent_server_url,
conversation_id=conversation_id,
session_api_key=session_api_key,
message_content=message_content,
)
assert result == 'Agent response text'
# Verify request
mock_httpx_client.post.assert_called_once()
call_args = mock_httpx_client.post.call_args
expected_url = (
f'{agent_server_url}/api/conversations/{conversation_id}/ask_agent'
)
assert call_args[0][0] == expected_url
assert call_args[1]['headers'] == {'X-Session-API-Key': session_api_key}
assert call_args[1]['json'] == {'question': message_content}
assert call_args[1]['timeout'] == 30.0
@pytest.mark.asyncio
async def test_handles_http_error(self, callback_processor):
"""Test that HTTP errors are handled and wrapped."""
mock_httpx_client = AsyncMock()
mock_response = MagicMock()
mock_response.status_code = 500
mock_response.text = 'Internal Server Error'
mock_response.headers = {}
mock_error = httpx.HTTPStatusError(
'Server Error',
request=MagicMock(),
response=mock_response,
)
mock_httpx_client.post = AsyncMock(side_effect=mock_error)
with pytest.raises(Exception, match='Failed to send message to agent server'):
await callback_processor._ask_question(
httpx_client=mock_httpx_client,
agent_server_url='http://localhost:8000',
conversation_id=uuid4(),
session_api_key='test_key',
message_content='test message',
)
@pytest.mark.asyncio
async def test_handles_timeout(self, callback_processor):
"""Test that timeout errors are handled and wrapped."""
mock_httpx_client = AsyncMock()
mock_httpx_client.post = AsyncMock(
side_effect=httpx.TimeoutException('Timeout')
)
with pytest.raises(Exception, match='Failed to send message to agent server'):
await callback_processor._ask_question(
httpx_client=mock_httpx_client,
agent_server_url='http://localhost:8000',
conversation_id=uuid4(),
session_api_key='test_key',
message_content='test message',
)
@@ -3,6 +3,7 @@ Tests for Jira view classes and factory.
"""
from unittest.mock import AsyncMock, MagicMock, patch
from uuid import UUID
import pytest
from integrations.jira.jira_payload import (
@@ -18,6 +19,9 @@ from integrations.jira.jira_view import (
JiraNewConversationView,
)
from openhands.integrations.service_types import ProviderType
from openhands.server.user_auth.user_auth import UserAuth
class TestJiraNewConversationView:
"""Tests for JiraNewConversationView"""
@@ -85,6 +89,51 @@ class TestJiraNewConversationView:
assert 'TEST-123' in user_msg
assert 'Test Issue' in user_msg
@pytest.mark.asyncio
@patch('integrations.jira.jira_view.resolve_org_for_repo', new_callable=AsyncMock)
@patch('integrations.jira.jira_view.ProviderHandler')
@patch(
'integrations.jira.jira_view.SaasConversationStore.get_resolver_instance',
new_callable=AsyncMock,
)
@patch('integrations.jira.jira_view.start_conversation', new_callable=AsyncMock)
@patch('integrations.jira.jira_view.integration_store')
async def test_create_or_update_conversation_success(
self,
mock_integration_store,
mock_start_convo,
mock_get_resolver_instance,
mock_provider_handler_cls,
mock_resolve_org,
new_conversation_view,
mock_jinja_env,
):
"""Test successful conversation creation"""
new_conversation_view._issue_title = 'Test Issue'
new_conversation_view._issue_description = 'Test description'
mock_repo = MagicMock()
mock_repo.git_provider = ProviderType.GITHUB
mock_handler = MagicMock()
mock_handler.verify_repo_provider = AsyncMock(return_value=mock_repo)
mock_provider_handler_cls.return_value = mock_handler
mock_resolve_org.return_value = None
mock_store = MagicMock()
mock_store.save_metadata = AsyncMock()
mock_get_resolver_instance.return_value = mock_store
mock_integration_store.create_conversation = AsyncMock()
result = await new_conversation_view.create_or_update_conversation(
mock_jinja_env
)
assert result is not None
assert isinstance(result, str)
assert len(result) == 32 # uuid4().hex format
mock_start_convo.assert_called_once()
mock_integration_store.create_conversation.assert_called_once()
@pytest.mark.asyncio
async def test_create_or_update_conversation_no_repo(
self, new_conversation_view, mock_jinja_env
@@ -323,6 +372,125 @@ class TestJiraFactory:
)
CLAIMING_ORG_ID = UUID('aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa')
class TestJiraV0ConversationRouting:
"""Test V0 conversation routing logic based on claimed git organizations."""
@pytest.fixture
def routing_view(
self,
sample_webhook_payload,
sample_jira_user,
sample_jira_workspace,
):
"""View with non-empty provider tokens for routing tests."""
user_auth = MagicMock(spec=UserAuth)
user_auth.get_provider_tokens = AsyncMock(
return_value={ProviderType.GITHUB: MagicMock()}
)
user_auth.get_secrets = AsyncMock(return_value=None)
return JiraNewConversationView(
payload=sample_webhook_payload,
saas_user_auth=user_auth,
jira_user=sample_jira_user,
jira_workspace=sample_jira_workspace,
selected_repo='test/repo1',
_issue_title='Test Issue',
_issue_description='Test description',
_decrypted_api_key='decrypted_key',
)
@pytest.mark.asyncio
@patch('integrations.jira.jira_view.resolve_org_for_repo', new_callable=AsyncMock)
@patch('integrations.jira.jira_view.ProviderHandler')
@patch(
'integrations.jira.jira_view.SaasConversationStore.get_resolver_instance',
new_callable=AsyncMock,
)
@patch('integrations.jira.jira_view.start_conversation', new_callable=AsyncMock)
@patch('integrations.jira.jira_view.integration_store')
async def test_routes_to_claimed_org_when_user_is_member(
self,
mock_integration_store,
mock_start_convo,
mock_get_resolver_instance,
mock_provider_handler_cls,
mock_resolve_org,
routing_view,
mock_jinja_env,
):
"""When repo belongs to a claimed org and user is a member, conversation is created in that org."""
# Arrange
mock_repo = MagicMock()
mock_repo.git_provider = ProviderType.GITHUB
mock_handler = MagicMock()
mock_handler.verify_repo_provider = AsyncMock(return_value=mock_repo)
mock_provider_handler_cls.return_value = mock_handler
mock_resolve_org.return_value = CLAIMING_ORG_ID
mock_store = MagicMock()
mock_store.save_metadata = AsyncMock()
mock_get_resolver_instance.return_value = mock_store
mock_integration_store.create_conversation = AsyncMock()
# Act
await routing_view.create_or_update_conversation(mock_jinja_env)
# Assert
mock_resolve_org.assert_called_once_with(
provider='github',
full_repo_name='test/repo1',
keycloak_user_id='test_keycloak_id',
)
call_args = mock_get_resolver_instance.call_args
assert call_args[0][1] == 'test_keycloak_id' # user_id
assert call_args[0][2] == CLAIMING_ORG_ID # resolver_org_id
saved_metadata = mock_store.save_metadata.call_args[0][0]
assert saved_metadata.git_provider == ProviderType.GITHUB
@pytest.mark.asyncio
@patch('integrations.jira.jira_view.resolve_org_for_repo', new_callable=AsyncMock)
@patch('integrations.jira.jira_view.ProviderHandler')
@patch(
'integrations.jira.jira_view.SaasConversationStore.get_resolver_instance',
new_callable=AsyncMock,
)
@patch('integrations.jira.jira_view.start_conversation', new_callable=AsyncMock)
@patch('integrations.jira.jira_view.integration_store')
async def test_falls_back_to_personal_workspace_when_no_claim(
self,
mock_integration_store,
mock_start_convo,
mock_get_resolver_instance,
mock_provider_handler_cls,
mock_resolve_org,
routing_view,
mock_jinja_env,
):
"""When no org has claimed the git org, conversation goes to personal workspace."""
# Arrange
mock_repo = MagicMock()
mock_repo.git_provider = ProviderType.GITHUB
mock_handler = MagicMock()
mock_handler.verify_repo_provider = AsyncMock(return_value=mock_repo)
mock_provider_handler_cls.return_value = mock_handler
mock_resolve_org.return_value = None
mock_store = MagicMock()
mock_store.save_metadata = AsyncMock()
mock_get_resolver_instance.return_value = mock_store
mock_integration_store.create_conversation = AsyncMock()
# Act
await routing_view.create_or_update_conversation(mock_jinja_env)
# Assert
call_args = mock_get_resolver_instance.call_args
assert call_args[0][2] is None # resolver_org_id is None
class TestJiraPayloadParser:
"""Tests for JiraPayloadParser"""
@@ -438,164 +606,3 @@ class TestJiraPayloadParserStagingLabels:
result = staging_parser.parse(payload)
assert isinstance(result, JiraPayloadSkipped)
class TestJiraV1Conversation:
"""Tests for V1 conversation creation and callback processor registration."""
@pytest.mark.asyncio
async def test_create_v1_metadata_generates_conversation_id(
self, new_conversation_view
):
"""Test that _create_v1_metadata generates a new conversation ID."""
new_conversation_view.conversation_id = ''
with patch.object(
new_conversation_view, '_get_resolved_org_id', new_callable=AsyncMock
) as mock_get_org:
mock_get_org.return_value = None
metadata = await new_conversation_view._create_v1_metadata()
# Conversation ID should be generated
assert new_conversation_view.conversation_id != ''
assert len(new_conversation_view.conversation_id) == 32 # UUID hex format
assert metadata.conversation_id == new_conversation_view.conversation_id
mock_get_org.assert_called_once()
@pytest.mark.asyncio
async def test_create_v1_metadata_sets_resolved_org(self, new_conversation_view):
"""Test that _create_v1_metadata sets resolved_org_id."""
from uuid import UUID
test_org_id = UUID('12345678-1234-5678-1234-567812345678')
with patch.object(
new_conversation_view, '_get_resolved_org_id', new_callable=AsyncMock
) as mock_get_org:
mock_get_org.return_value = test_org_id
await new_conversation_view._create_v1_metadata()
assert new_conversation_view.resolved_org_id == test_org_id
def test_create_jira_v1_callback_processor(
self, new_conversation_view, sample_jira_workspace
):
"""Test that _create_jira_v1_callback_processor creates correctly configured processor."""
from integrations.jira.jira_v1_callback_processor import JiraV1CallbackProcessor
processor = new_conversation_view._create_jira_v1_callback_processor()
assert isinstance(processor, JiraV1CallbackProcessor)
assert processor.svc_acc_email == sample_jira_workspace.svc_acc_email
assert processor.decrypted_api_key == 'decrypted_key'
assert processor.issue_key == 'TEST-123'
assert processor.jira_cloud_id == sample_jira_workspace.jira_cloud_id
@pytest.mark.asyncio
async def test_get_v1_initial_user_message(
self, new_conversation_view, mock_jinja_env
):
"""Test _get_v1_initial_user_message renders the template correctly."""
new_conversation_view._issue_title = 'Test Bug'
new_conversation_view._issue_description = 'Description of the bug'
message = await new_conversation_view._get_v1_initial_user_message(
mock_jinja_env
)
assert 'TEST-123' in message
assert 'Test Bug' in message
@pytest.mark.asyncio
@patch('integrations.jira.jira_view.get_app_conversation_service')
@patch('integrations.jira.jira_view.integration_store')
async def test_create_or_update_conversation_v1_flow(
self,
mock_store,
mock_get_service,
new_conversation_view,
mock_jinja_env,
):
"""Test create_or_update_conversation creates V1 conversation correctly."""
from unittest.mock import AsyncMock, MagicMock
from uuid import UUID
from openhands.app_server.app_conversation.app_conversation_models import (
AppConversationStartTaskStatus,
)
# Setup mocks
mock_store.create_conversation = AsyncMock()
# Mock the app conversation service
mock_service = AsyncMock()
async def mock_start_generator(*args, **kwargs):
yield MagicMock(status=AppConversationStartTaskStatus.WORKING)
yield MagicMock(status=AppConversationStartTaskStatus.READY)
mock_service.start_app_conversation = mock_start_generator
mock_get_service.return_value.__aenter__.return_value = mock_service
# Set issue details to avoid fetch
new_conversation_view._issue_title = 'Test Issue'
new_conversation_view._issue_description = 'Test description'
with patch.object(
new_conversation_view, '_get_resolved_org_id', new_callable=AsyncMock
) as mock_get_org:
mock_get_org.return_value = UUID('12345678-1234-5678-1234-567812345678')
result = await new_conversation_view.create_or_update_conversation(
mock_jinja_env
)
# Verify conversation was created
assert result is not None
mock_store.create_conversation.assert_called_once()
@pytest.mark.asyncio
@patch('integrations.jira.jira_view.get_app_conversation_service')
@patch('integrations.jira.jira_view.integration_store')
async def test_create_or_update_conversation_handles_error(
self,
mock_store,
mock_get_service,
new_conversation_view,
mock_jinja_env,
):
"""Test create_or_update_conversation handles V1 errors correctly."""
from unittest.mock import AsyncMock, MagicMock
from openhands.app_server.app_conversation.app_conversation_models import (
AppConversationStartTaskStatus,
)
mock_store.create_conversation = AsyncMock()
# Mock the app conversation service to return error
mock_service = AsyncMock()
async def mock_error_generator(*args, **kwargs):
yield MagicMock(
status=AppConversationStartTaskStatus.ERROR,
detail='Sandbox allocation failed',
)
mock_service.start_app_conversation = mock_error_generator
mock_get_service.return_value.__aenter__.return_value = mock_service
new_conversation_view._issue_title = 'Test Issue'
new_conversation_view._issue_description = 'Test description'
with patch.object(
new_conversation_view, '_get_resolved_org_id', new_callable=AsyncMock
) as mock_get_org:
mock_get_org.return_value = None
with pytest.raises(RuntimeError, match='Failed to start V1 conversation'):
await new_conversation_view.create_or_update_conversation(
mock_jinja_env
)
@@ -109,43 +109,3 @@ async def test_extracts_git_org_lowercase_from_repo_name(mock_stores):
mock_claim_store.get_claim_by_provider_and_git_org.assert_called_once_with(
'github', 'myorg'
)
@pytest.mark.asyncio
async def test_returns_org_id_without_membership_check_when_no_user_id(mock_stores):
"""When user_id is None, skip membership check and return org_id if claim exists."""
from enterprise.integrations.resolver_org_router import resolve_org_for_repo
mock_claim_store, mock_member_store = mock_stores
# Arrange
claim = MagicMock()
claim.org_id = CLAIMING_ORG_ID
mock_claim_store.get_claim_by_provider_and_git_org.return_value = claim
# Act - no user_id provided
result = await resolve_org_for_repo('github', 'OpenHands/foo')
# Assert
assert result == CLAIMING_ORG_ID
mock_claim_store.get_claim_by_provider_and_git_org.assert_called_once_with(
'github', 'openhands'
)
# Membership check should NOT be called
mock_member_store.get_org_member.assert_not_called()
@pytest.mark.asyncio
async def test_returns_none_when_no_claim_and_no_user_id(mock_stores):
"""When no claim exists and no user_id, return None."""
from enterprise.integrations.resolver_org_router import resolve_org_for_repo
mock_claim_store, mock_member_store = mock_stores
mock_claim_store.get_claim_by_provider_and_git_org.return_value = None
# Act - no user_id provided
result = await resolve_org_for_repo('github', 'UnclaimedOrg/repo')
# Assert
assert result is None
mock_member_store.get_org_member.assert_not_called()
@@ -1,330 +0,0 @@
"""Tests for onboarding-related auth routes and functions.
Tests for:
- _should_redirect_to_onboarding() function
- _get_post_auth_redirect() function
- /complete_onboarding endpoint
"""
import uuid
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import Request, status
from fastapi.responses import JSONResponse
from server.auth.saas_user_auth import SaasUserAuth
from server.routes.auth import (
_get_post_auth_redirect,
_should_redirect_to_onboarding,
complete_onboarding,
)
from storage.user import User
# --- Fixtures ---
@pytest.fixture
def mock_request():
"""Create a mock FastAPI Request."""
request = MagicMock(spec=Request)
request.url = MagicMock()
request.url.hostname = 'localhost'
request.url.netloc = 'localhost:8000'
request.base_url = 'http://localhost:8000/'
request.headers = {}
request.cookies = {}
return request
@pytest.fixture
def mock_user():
"""Create a mock User object."""
user = MagicMock(spec=User)
user.id = uuid.uuid4()
user.current_org_id = uuid.uuid4()
user.onboarding_completed = False
return user
# --- Tests for _should_redirect_to_onboarding ---
class TestShouldRedirectToOnboarding:
"""Tests for the _should_redirect_to_onboarding function."""
@pytest.mark.asyncio
async def test_returns_false_when_onboarding_completed(self, mock_user):
"""Test that completed onboarding users are not redirected."""
mock_user.onboarding_completed = True
result = await _should_redirect_to_onboarding('user-123', mock_user)
assert result is False
@pytest.mark.asyncio
async def test_returns_true_for_cloud_mode_new_user(self, mock_user):
"""Test that cloud mode users with incomplete onboarding are redirected."""
mock_user.onboarding_completed = False
with patch('server.routes.auth.DEPLOYMENT_MODE', 'cloud'):
result = await _should_redirect_to_onboarding('user-123', mock_user)
assert result is True
@pytest.mark.asyncio
async def test_returns_true_for_self_hosted_super_admin(self, mock_user):
"""Test that the super admin (first owner to accept TOS) is redirected."""
mock_user.onboarding_completed = False
user_id = str(mock_user.id)
# Mock this user as the first owner in the org (super admin)
first_owner = MagicMock(spec=User)
first_owner.id = mock_user.id
with (
patch('server.routes.auth.DEPLOYMENT_MODE', 'self_hosted'),
patch(
'server.routes.auth.UserStore.get_first_owner_in_org',
new_callable=AsyncMock,
return_value=first_owner,
),
):
result = await _should_redirect_to_onboarding(user_id, mock_user)
assert result is True
@pytest.mark.asyncio
async def test_returns_false_for_self_hosted_non_super_admin_owner(self, mock_user):
"""Test that owners who aren't the super admin are NOT redirected."""
mock_user.onboarding_completed = False
user_id = str(mock_user.id)
# Mock a different user as the first owner (super admin)
first_owner = MagicMock(spec=User)
first_owner.id = uuid.uuid4() # Different user
with (
patch('server.routes.auth.DEPLOYMENT_MODE', 'self_hosted'),
patch(
'server.routes.auth.UserStore.get_first_owner_in_org',
new_callable=AsyncMock,
return_value=first_owner,
),
):
result = await _should_redirect_to_onboarding(user_id, mock_user)
assert result is False
@pytest.mark.asyncio
async def test_returns_false_for_self_hosted_when_no_owner_found(self, mock_user):
"""Test that users are not redirected when no owner is found."""
mock_user.onboarding_completed = False
user_id = str(mock_user.id)
with (
patch('server.routes.auth.DEPLOYMENT_MODE', 'self_hosted'),
patch(
'server.routes.auth.UserStore.get_first_owner_in_org',
new_callable=AsyncMock,
return_value=None,
),
):
result = await _should_redirect_to_onboarding(user_id, mock_user)
assert result is False
@pytest.mark.asyncio
async def test_passes_current_org_id_to_get_first_owner(self, mock_user):
"""Test that get_first_owner_in_org is called with user's current_org_id."""
mock_user.onboarding_completed = False
user_id = str(mock_user.id)
mock_get_first_owner = AsyncMock(return_value=None)
with (
patch('server.routes.auth.DEPLOYMENT_MODE', 'self_hosted'),
patch(
'server.routes.auth.UserStore.get_first_owner_in_org',
mock_get_first_owner,
),
):
await _should_redirect_to_onboarding(user_id, mock_user)
mock_get_first_owner.assert_called_once_with(mock_user.current_org_id)
# --- Tests for _get_post_auth_redirect ---
class TestGetPostAuthRedirect:
"""Tests for the _get_post_auth_redirect function."""
@pytest.mark.asyncio
async def test_returns_onboarding_url_when_onboarding_needed(self, mock_user):
"""Test that onboarding URL is returned when user needs onboarding."""
mock_user.onboarding_completed = False
user_id = str(mock_user.id)
with (
patch('server.routes.auth.DEPLOYMENT_MODE', 'cloud'),
patch(
'server.routes.auth.UserStore.get_user_by_id',
new_callable=AsyncMock,
return_value=mock_user,
),
):
result = await _get_post_auth_redirect(
user_id, 'https://example.com/', 'https://example.com'
)
assert result == 'https://example.com/onboarding'
@pytest.mark.asyncio
async def test_returns_default_url_when_onboarding_completed(self, mock_user):
"""Test that default URL is returned when user has completed onboarding."""
mock_user.onboarding_completed = True
user_id = str(mock_user.id)
with patch(
'server.routes.auth.UserStore.get_user_by_id',
new_callable=AsyncMock,
return_value=mock_user,
):
result = await _get_post_auth_redirect(
user_id, 'https://example.com/dashboard', 'https://example.com'
)
assert result == 'https://example.com/dashboard'
@pytest.mark.asyncio
async def test_returns_default_url_when_user_not_found(self):
"""Test that default URL is returned when user is not found."""
with patch(
'server.routes.auth.UserStore.get_user_by_id',
new_callable=AsyncMock,
return_value=None,
):
result = await _get_post_auth_redirect(
'nonexistent-user', 'https://example.com/', 'https://example.com'
)
assert result == 'https://example.com/'
@pytest.mark.asyncio
async def test_logs_when_redirecting_to_onboarding(self, mock_user):
"""Test that a log message is emitted when redirecting to onboarding."""
mock_user.onboarding_completed = False
user_id = str(mock_user.id)
with (
patch('server.routes.auth.DEPLOYMENT_MODE', 'cloud'),
patch(
'server.routes.auth.UserStore.get_user_by_id',
new_callable=AsyncMock,
return_value=mock_user,
),
patch('server.routes.auth.logger') as mock_logger,
):
await _get_post_auth_redirect(
user_id, 'https://example.com/', 'https://example.com'
)
mock_logger.info.assert_called_once()
call_args = mock_logger.info.call_args
assert call_args[0][0] == 'Redirecting user to onboarding'
assert call_args[1]['extra']['user_id'] == user_id
# --- Tests for /complete_onboarding endpoint ---
class TestCompleteOnboardingEndpoint:
"""Tests for the complete_onboarding API endpoint."""
@pytest.mark.asyncio
async def test_returns_401_when_not_authenticated(self, mock_request):
"""Test that unauthenticated requests return 401."""
mock_user_auth = MagicMock(spec=SaasUserAuth)
mock_user_auth.get_user_id = AsyncMock(return_value=None)
with patch(
'server.routes.auth.get_user_auth',
new_callable=AsyncMock,
return_value=mock_user_auth,
):
result = await complete_onboarding(mock_request)
assert isinstance(result, JSONResponse)
assert result.status_code == status.HTTP_401_UNAUTHORIZED
@pytest.mark.asyncio
async def test_returns_404_when_user_not_found(self, mock_request):
"""Test that request for non-existent user returns 404."""
user_id = str(uuid.uuid4())
mock_user_auth = MagicMock(spec=SaasUserAuth)
mock_user_auth.get_user_id = AsyncMock(return_value=user_id)
with (
patch(
'server.routes.auth.get_user_auth',
new_callable=AsyncMock,
return_value=mock_user_auth,
),
patch(
'server.routes.auth.UserStore.mark_onboarding_completed',
new_callable=AsyncMock,
return_value=None,
),
):
result = await complete_onboarding(mock_request)
assert isinstance(result, JSONResponse)
assert result.status_code == status.HTTP_404_NOT_FOUND
@pytest.mark.asyncio
async def test_returns_200_on_success(self, mock_request, mock_user):
"""Test successful onboarding completion returns 200."""
user_id = str(uuid.uuid4())
mock_user_auth = MagicMock(spec=SaasUserAuth)
mock_user_auth.get_user_id = AsyncMock(return_value=user_id)
with (
patch(
'server.routes.auth.get_user_auth',
new_callable=AsyncMock,
return_value=mock_user_auth,
),
patch(
'server.routes.auth.UserStore.mark_onboarding_completed',
new_callable=AsyncMock,
return_value=mock_user,
),
):
result = await complete_onboarding(mock_request)
assert isinstance(result, JSONResponse)
assert result.status_code == status.HTTP_200_OK
@pytest.mark.asyncio
async def test_calls_mark_onboarding_completed_with_user_id(
self, mock_request, mock_user
):
"""Test that mark_onboarding_completed is called with the correct user_id."""
user_id = str(uuid.uuid4())
mock_user_auth = MagicMock(spec=SaasUserAuth)
mock_user_auth.get_user_id = AsyncMock(return_value=user_id)
mock_mark_completed = AsyncMock(return_value=mock_user)
with (
patch(
'server.routes.auth.get_user_auth',
new_callable=AsyncMock,
return_value=mock_user_auth,
),
patch(
'server.routes.auth.UserStore.mark_onboarding_completed',
mock_mark_completed,
),
):
await complete_onboarding(mock_request)
mock_mark_completed.assert_called_once_with(user_id)
@@ -1,670 +0,0 @@
"""
Unit tests for AutomationEventService.
Tests the service that forwards GitHub webhook events to the automation service.
The service is optimized for high-traffic with:
- Redis caching for org claim lookups (1 hour TTL)
- Redis caching for GitHub→Keycloak user ID mappings (24 hour TTL)
- Lazy access control (membership checks deferred to execution time)
- Separate AUTOMATION_WEBHOOK_SECRET for internal service communication
"""
import uuid
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
# Default patches for constants
CONSTANT_PATCHES = {
'server.services.automation_event_service.AUTOMATION_WEBHOOK_SECRET': 'test-shared-secret',
'server.services.automation_event_service.AUTOMATION_SERVICE_TIMEOUT': 30,
}
@pytest.fixture
def mock_token_manager():
"""Create a mock TokenManager."""
return MagicMock()
@pytest.fixture
def mock_org_git_claim():
"""Create a mock OrgGitClaim."""
claim = MagicMock()
claim.org_id = uuid.UUID('12345678-1234-5678-1234-567812345678')
return claim
@pytest.fixture
def github_org_payload():
"""Create a sample GitHub webhook payload for an organization repo."""
return {
'repository': {
'id': 123456,
'full_name': 'test-org/test-repo',
'private': False,
'default_branch': 'main',
'owner': {
'login': 'test-org',
'id': 789,
'type': 'Organization',
},
},
'sender': {
'id': 12345,
'login': 'testuser',
},
'action': 'opened',
'installation': {
'id': 99999,
},
}
@pytest.fixture
def github_user_payload():
"""Create a sample GitHub webhook payload for a personal/user repo."""
return {
'repository': {
'id': 654321,
'full_name': 'testuser/personal-repo',
'private': True,
'default_branch': 'main',
'owner': {
'login': 'testuser',
'id': 12345,
'type': 'User',
},
},
'sender': {
'id': 12345,
'login': 'testuser',
},
'action': 'opened',
'installation': {
'id': 99999,
},
}
def create_service(mock_token_manager):
"""Helper to create a service with mocked sio and constants."""
with patch('server.services.automation_event_service.sio'), patch.dict(
'os.environ', {}, clear=False
):
for key, value in CONSTANT_PATCHES.items():
patch(key, value).start()
from server.services.automation_event_service import AutomationEventService
return AutomationEventService(mock_token_manager)
class TestResolveGithubOrg:
"""Tests for _resolve_github_org method with caching."""
@pytest.mark.asyncio
async def test_resolve_github_org_cache_miss_found(
self, mock_token_manager, mock_org_git_claim
):
"""
GIVEN: Cache miss and org claim exists in DB
WHEN: _resolve_github_org is called
THEN: Org ID is returned and cached
"""
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(return_value=None) # Cache miss
mock_redis.setex = AsyncMock()
with patch(
'server.services.automation_event_service.resolve_org_for_repo',
new_callable=AsyncMock,
return_value=mock_org_git_claim.org_id,
), patch('server.services.automation_event_service.sio') as mock_sio:
mock_sio.manager.redis = mock_redis
service = create_service(mock_token_manager)
result = await service._resolve_github_org('test-org')
assert result == mock_org_git_claim.org_id
# Verify result was cached
mock_redis.setex.assert_called_once()
@pytest.mark.asyncio
async def test_resolve_github_org_cache_hit(self, mock_token_manager):
"""
GIVEN: Org ID is cached in Redis
WHEN: _resolve_github_org is called
THEN: Cached value is returned without calling resolve_org_for_repo
"""
cached_org_id = '12345678-1234-5678-1234-567812345678'
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(return_value=cached_org_id.encode())
with patch(
'server.services.automation_event_service.resolve_org_for_repo',
new_callable=AsyncMock,
) as mock_resolver, patch(
'server.services.automation_event_service.sio'
) as mock_sio:
mock_sio.manager.redis = mock_redis
service = create_service(mock_token_manager)
result = await service._resolve_github_org('test-org')
assert result == uuid.UUID(cached_org_id)
# resolve_org_for_repo should NOT be called
mock_resolver.assert_not_called()
@pytest.mark.asyncio
async def test_resolve_github_org_cache_miss_not_found(self, mock_token_manager):
"""
GIVEN: Cache miss and org claim does NOT exist in DB
WHEN: _resolve_github_org is called
THEN: None is returned and negative result is cached
"""
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(return_value=None) # Cache miss
mock_redis.setex = AsyncMock()
with patch(
'server.services.automation_event_service.resolve_org_for_repo',
new_callable=AsyncMock,
return_value=None,
), patch('server.services.automation_event_service.sio') as mock_sio:
mock_sio.manager.redis = mock_redis
service = create_service(mock_token_manager)
result = await service._resolve_github_org('unclaimed-org')
assert result is None
# Verify negative result was cached
mock_redis.setex.assert_called_once()
call_args = mock_redis.setex.call_args
# Second positional arg is the value
assert call_args[0][2] == 'none' # Negative cache value
@pytest.mark.asyncio
async def test_resolve_github_org_negative_cache_hit(self, mock_token_manager):
"""
GIVEN: Negative result is cached (org not claimed)
WHEN: _resolve_github_org is called
THEN: None is returned without calling resolve_org_for_repo
"""
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(return_value=b'none') # Cached negative
with patch(
'server.services.automation_event_service.resolve_org_for_repo',
new_callable=AsyncMock,
) as mock_resolver, patch(
'server.services.automation_event_service.sio'
) as mock_sio:
mock_sio.manager.redis = mock_redis
service = create_service(mock_token_manager)
result = await service._resolve_github_org('unclaimed-org')
assert result is None
mock_resolver.assert_not_called()
class TestResolvePersonalOrg:
"""Tests for _resolve_personal_org method with caching."""
@pytest.mark.asyncio
async def test_resolve_personal_org_cache_miss_found(self, mock_token_manager):
"""
GIVEN: Cache miss and user exists in Keycloak
WHEN: _resolve_personal_org is called
THEN: Keycloak ID is returned and cached
"""
keycloak_id = '87654321-4321-8765-4321-876543218765'
mock_token_manager.get_user_id_from_idp_user_id = AsyncMock(
return_value=keycloak_id
)
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(return_value=None) # Cache miss
mock_redis.setex = AsyncMock()
with patch('server.services.automation_event_service.sio') as mock_sio:
mock_sio.manager.redis = mock_redis
service = create_service(mock_token_manager)
result = await service._resolve_personal_org(12345)
assert result == uuid.UUID(keycloak_id)
mock_redis.setex.assert_called_once()
@pytest.mark.asyncio
async def test_resolve_personal_org_cache_hit(self, mock_token_manager):
"""
GIVEN: Keycloak ID is cached in Redis
WHEN: _resolve_personal_org is called
THEN: Cached value is returned without Keycloak query
"""
keycloak_id = '87654321-4321-8765-4321-876543218765'
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(return_value=keycloak_id.encode())
with patch('server.services.automation_event_service.sio') as mock_sio:
mock_sio.manager.redis = mock_redis
service = create_service(mock_token_manager)
result = await service._resolve_personal_org(12345)
assert result == uuid.UUID(keycloak_id)
# Token manager should NOT be called
mock_token_manager.get_user_id_from_idp_user_id.assert_not_called()
@pytest.mark.asyncio
async def test_resolve_personal_org_no_github_user_id(self, mock_token_manager):
"""
GIVEN: No GitHub user ID provided
WHEN: _resolve_personal_org is called
THEN: None is returned immediately
"""
service = create_service(mock_token_manager)
result = await service._resolve_personal_org(None)
assert result is None
class TestForwardGithubEvent:
"""Tests for forward_github_event method (minimal payload, no access control)."""
@pytest.mark.asyncio
async def test_forward_org_event_success(
self, mock_token_manager, github_org_payload, mock_org_git_claim
):
"""
GIVEN: A GitHub event from a claimed organization repo
WHEN: forward_github_event is called
THEN: Minimal payload is forwarded (no access_control)
"""
from server.services.automation_event_service import AutomationEventService
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(return_value=None)
mock_redis.setex = AsyncMock()
with patch(
'server.services.automation_event_service.resolve_org_for_repo',
new_callable=AsyncMock,
return_value=mock_org_git_claim.org_id,
), patch(
'server.services.automation_event_service.sio'
) as mock_sio, patch.object(
AutomationEventService,
'_send_to_automation_service',
new_callable=AsyncMock,
) as mock_send:
mock_sio.manager.redis = mock_redis
service = AutomationEventService(mock_token_manager)
await service.forward_github_event(
payload=github_org_payload,
installation_id=99999,
)
mock_send.assert_called_once()
call_args = mock_send.call_args
assert call_args[0][0] == mock_org_git_claim.org_id
payload = call_args[0][1]
assert payload['organization']['github_org'] == 'test-org'
assert 'payload' in payload
# access_control should NOT be in payload (lazy evaluation)
assert 'access_control' not in payload
@pytest.mark.asyncio
async def test_forward_personal_repo_event_success(
self, mock_token_manager, github_user_payload
):
"""
GIVEN: A GitHub event from a personal repo with linked OpenHands account
WHEN: forward_github_event is called
THEN: Event is forwarded using the user's personal org (keycloak ID)
"""
from server.services.automation_event_service import AutomationEventService
keycloak_id = '87654321-4321-8765-4321-876543218765'
mock_token_manager.get_user_id_from_idp_user_id = AsyncMock(
return_value=keycloak_id
)
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(return_value=None)
mock_redis.setex = AsyncMock()
with patch(
'server.services.automation_event_service.resolve_org_for_repo',
new_callable=AsyncMock,
return_value=None, # No org claim for personal repo
), patch(
'server.services.automation_event_service.sio'
) as mock_sio, patch.object(
AutomationEventService,
'_send_to_automation_service',
new_callable=AsyncMock,
) as mock_send:
mock_sio.manager.redis = mock_redis
service = AutomationEventService(mock_token_manager)
await service.forward_github_event(
payload=github_user_payload,
installation_id=99999,
)
mock_send.assert_called_once()
call_args = mock_send.call_args
# Personal org should be keycloak ID
assert call_args[0][0] == uuid.UUID(keycloak_id)
payload = call_args[0][1]
assert payload['organization']['github_org'] == 'testuser'
assert payload['organization']['openhands_org_id'] == keycloak_id
@pytest.mark.asyncio
async def test_forward_event_no_owner_in_payload(self, mock_token_manager):
"""
GIVEN: A GitHub event with no repository owner in payload
WHEN: forward_github_event is called
THEN: Event is skipped with warning log
"""
from server.services.automation_event_service import AutomationEventService
payload = {
'repository': {},
'sender': {'id': 12345, 'login': 'testuser'},
}
with patch('server.services.automation_event_service.sio'), patch(
'server.services.automation_event_service.logger'
) as mock_logger, patch.object(
AutomationEventService,
'_send_to_automation_service',
new_callable=AsyncMock,
) as mock_send:
service = AutomationEventService(mock_token_manager)
await service.forward_github_event(
payload=payload,
installation_id=99999,
)
mock_send.assert_not_called()
mock_logger.warning.assert_called()
assert 'No repository owner' in str(mock_logger.warning.call_args)
@pytest.mark.asyncio
async def test_forward_event_org_not_claimed_and_not_personal(
self, mock_token_manager, github_org_payload
):
"""
GIVEN: A GitHub event from an org that isn't claimed (and isn't personal)
WHEN: forward_github_event is called
THEN: Event is skipped with warning log
"""
from server.services.automation_event_service import AutomationEventService
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(return_value=None)
mock_redis.setex = AsyncMock()
with patch(
'server.services.automation_event_service.resolve_org_for_repo',
new_callable=AsyncMock,
return_value=None,
), patch('server.services.automation_event_service.sio') as mock_sio, patch(
'server.services.automation_event_service.logger'
) as mock_logger, patch.object(
AutomationEventService,
'_send_to_automation_service',
new_callable=AsyncMock,
) as mock_send:
mock_sio.manager.redis = mock_redis
service = AutomationEventService(mock_token_manager)
await service.forward_github_event(
payload=github_org_payload,
installation_id=99999,
)
mock_send.assert_not_called()
mock_logger.warning.assert_called()
assert 'not claimed' in str(mock_logger.warning.call_args)
class TestBuildEventPayload:
"""Tests for _build_event_payload method."""
def test_build_minimal_payload(self, mock_token_manager):
"""
GIVEN: Org context and payload
WHEN: _build_event_payload is called
THEN: Minimal payload with only org + payload is returned
"""
from server.services.automation_event_service import OrgContext
service = create_service(mock_token_manager)
org_context = OrgContext(
org_id=uuid.UUID('12345678-1234-5678-1234-567812345678'),
github_org='test-org',
)
test_payload = {'action': 'opened', 'sender': {'login': 'user'}}
result = service._build_event_payload(org_context, test_payload)
assert result == {
'organization': {
'github_org': 'test-org',
'openhands_org_id': '12345678-1234-5678-1234-567812345678',
},
'payload': test_payload,
}
# Verify NO access_control in payload
assert 'access_control' not in result
class TestSendToAutomationService:
"""Tests for _send_to_automation_service method."""
@pytest.mark.asyncio
async def test_send_success(self, mock_token_manager):
"""
GIVEN: AUTOMATION_SERVICE_URL is configured
WHEN: _send_to_automation_service is called
THEN: Request is sent with correct signature
"""
org_id = uuid.UUID('12345678-1234-5678-1234-567812345678')
payload = {'organization': {'github_org': 'test'}, 'payload': {}}
mock_response = MagicMock()
mock_response.status = 200
mock_response.json = AsyncMock(return_value={'matched': 2})
mock_post_context = MagicMock()
mock_post_context.__aenter__ = AsyncMock(return_value=mock_response)
mock_post_context.__aexit__ = AsyncMock(return_value=None)
mock_session_instance = MagicMock()
mock_session_instance.post = MagicMock(return_value=mock_post_context)
mock_session_context = MagicMock()
mock_session_context.__aenter__ = AsyncMock(return_value=mock_session_instance)
mock_session_context.__aexit__ = AsyncMock(return_value=None)
with patch(
'server.services.automation_event_service.AUTOMATION_SERVICE_URL',
'https://automation.example.com',
), patch('server.services.automation_event_service.sio'), patch(
'server.services.automation_event_service.aiohttp.ClientSession',
return_value=mock_session_context,
):
service = create_service(mock_token_manager)
await service._send_to_automation_service(org_id, payload)
# Verify the POST was called
mock_session_instance.post.assert_called_once()
@pytest.mark.asyncio
async def test_send_no_url_configured(self, mock_token_manager):
"""
GIVEN: AUTOMATION_SERVICE_URL is not configured
WHEN: _send_to_automation_service is called
THEN: Warning is logged and nothing is sent
"""
org_id = uuid.UUID('12345678-1234-5678-1234-567812345678')
payload = {}
with patch(
'server.services.automation_event_service.AUTOMATION_SERVICE_URL', None
), patch('server.services.automation_event_service.sio'), patch(
'server.services.automation_event_service.logger'
) as mock_logger:
service = create_service(mock_token_manager)
await service._send_to_automation_service(org_id, payload)
mock_logger.warning.assert_called()
assert 'not configured' in str(mock_logger.warning.call_args)
class TestSignPayload:
"""Tests for _sign_payload method."""
def test_sign_payload(self, mock_token_manager):
"""
GIVEN: A payload bytes
WHEN: _sign_payload is called
THEN: HMAC-SHA256 signature is returned in correct format
"""
with patch(
'server.services.automation_event_service.AUTOMATION_WEBHOOK_SECRET',
'test-shared-secret',
), patch('server.services.automation_event_service.sio'):
service = create_service(mock_token_manager)
payload_bytes = b'{"test": "data"}'
signature = service._sign_payload(payload_bytes)
assert signature.startswith('sha256=')
assert len(signature) == 71 # 'sha256=' + 64 hex chars
def test_sign_payload_uses_dedicated_secret(self, mock_token_manager):
"""
GIVEN: AUTOMATION_WEBHOOK_SECRET is configured
WHEN: _sign_payload is called
THEN: The dedicated secret is used (not GitHub webhook secret)
"""
import hashlib
import hmac
# Use the default test secret from CONSTANT_PATCHES
shared_secret = 'test-shared-secret'
payload_bytes = b'{"test": "data"}'
# Calculate expected signature with the shared secret
expected_sig = hmac.new(
shared_secret.encode('utf-8'),
msg=payload_bytes,
digestmod=hashlib.sha256,
).hexdigest()
with patch(
'server.services.automation_event_service.AUTOMATION_WEBHOOK_SECRET',
shared_secret,
), patch('server.services.automation_event_service.sio'):
service = create_service(mock_token_manager)
signature = service._sign_payload(payload_bytes)
assert signature == f'sha256={expected_sig}'
class TestCacheHelpers:
"""Tests for generic cache helper methods."""
@pytest.mark.asyncio
async def test_get_cached_value_hit(self, mock_token_manager):
"""
GIVEN: Value exists in Redis cache
WHEN: _get_cached_value is called
THEN: Decoded string value is returned
"""
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(return_value=b'cached-value')
with patch('server.services.automation_event_service.sio') as mock_sio:
mock_sio.manager.redis = mock_redis
service = create_service(mock_token_manager)
result = await service._get_cached_value('test-key')
assert result == 'cached-value'
@pytest.mark.asyncio
async def test_get_cached_value_miss(self, mock_token_manager):
"""
GIVEN: Value does not exist in Redis cache
WHEN: _get_cached_value is called
THEN: None is returned
"""
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(return_value=None)
with patch('server.services.automation_event_service.sio') as mock_sio:
mock_sio.manager.redis = mock_redis
service = create_service(mock_token_manager)
result = await service._get_cached_value('test-key')
assert result is None
@pytest.mark.asyncio
async def test_get_cached_value_redis_unavailable(self, mock_token_manager):
"""
GIVEN: Redis is unavailable
WHEN: _get_cached_value is called
THEN: None is returned (graceful degradation)
"""
with patch('server.services.automation_event_service.sio') as mock_sio:
mock_sio.manager.redis = None
service = create_service(mock_token_manager)
result = await service._get_cached_value('test-key')
assert result is None
@pytest.mark.asyncio
async def test_set_cached_value_success(self, mock_token_manager):
"""
GIVEN: Redis is available
WHEN: _set_cached_value is called
THEN: Value is stored with TTL
"""
mock_redis = AsyncMock()
mock_redis.setex = AsyncMock()
with patch('server.services.automation_event_service.sio') as mock_sio:
mock_sio.manager.redis = mock_redis
service = create_service(mock_token_manager)
await service._set_cached_value('test-key', 'test-value', 3600)
mock_redis.setex.assert_called_once_with('test-key', 3600, 'test-value')
@pytest.mark.asyncio
async def test_set_cached_value_redis_unavailable(self, mock_token_manager):
"""
GIVEN: Redis is unavailable
WHEN: _set_cached_value is called
THEN: No error is raised (silent failure)
"""
with patch('server.services.automation_event_service.sio') as mock_sio:
mock_sio.manager.redis = None
service = create_service(mock_token_manager)
# Should not raise
await service._set_cached_value('test-key', 'test-value', 3600)
@@ -1,105 +0,0 @@
"""Tests for enterprise server constants, specifically DEPLOYMENT_MODE detection."""
from unittest.mock import patch
import pytest
class TestDeploymentMode:
"""Tests for _get_deployment_mode() and _is_all_hands_managed_domain() functions."""
@pytest.mark.parametrize(
'web_host,expected_mode',
[
# All-Hands managed domains should return 'cloud'
('app.all-hands.dev', 'cloud'),
('staging.all-hands.dev', 'cloud'),
('feature-123.staging.all-hands.dev', 'cloud'),
('pr-456.staging.all-hands.dev', 'cloud'),
('app.openhands.ai', 'cloud'),
# Customer domains should return 'self_hosted'
('openhands.acme.com', 'self_hosted'),
('internal.company.io', 'self_hosted'),
('dev.mycompany.net', 'self_hosted'),
('openhands.example.org', 'self_hosted'),
('localhost', 'self_hosted'), # localhost is not a managed domain
# Edge cases
('all-hands.dev', 'self_hosted'), # Not a subdomain, so not managed
('fake-all-hands.dev', 'self_hosted'),
('app.all-hands.dev.evil.com', 'self_hosted'),
],
)
def test_deployment_mode_detection(self, web_host: str, expected_mode: str):
"""Test that DEPLOYMENT_MODE is correctly determined based on WEB_HOST."""
with patch.dict('os.environ', {'WEB_HOST': web_host}):
# Need to reimport to pick up the mocked environment variable
import importlib
import server.constants as constants_module
importlib.reload(constants_module)
assert constants_module.DEPLOYMENT_MODE == expected_mode
@pytest.mark.parametrize(
'host,expected',
[
('app.all-hands.dev', True),
('staging.all-hands.dev', True),
('feature.staging.all-hands.dev', True),
('app.openhands.ai', True),
('localhost', False), # localhost is not a managed domain
('customer.example.com', False),
('all-hands.dev', False),
],
)
def test_is_all_hands_managed_domain(self, host: str, expected: bool):
"""Test _is_all_hands_managed_domain() helper function."""
from server.constants import _is_all_hands_managed_domain
assert _is_all_hands_managed_domain(host) == expected
def test_deployment_mode_default_is_cloud(self):
"""Test that default WEB_HOST (app.all-hands.dev) results in 'cloud' mode."""
with patch.dict('os.environ', {}, clear=True):
# Remove WEB_HOST to test default
import importlib
import os
if 'WEB_HOST' in os.environ:
del os.environ['WEB_HOST']
import server.constants as constants_module
importlib.reload(constants_module)
# Default WEB_HOST is 'app.all-hands.dev' which should be 'cloud'
assert constants_module.DEPLOYMENT_MODE == 'cloud'
class TestDeploymentModeInConfig:
"""Tests for DEPLOYMENT_MODE being exposed in config API."""
def test_deployment_mode_included_in_feature_flags(self):
"""Test that DEPLOYMENT_MODE is included in FEATURE_FLAGS from get_config()."""
from server.config import SaaSServerConfig
with patch('server.config.DEPLOYMENT_MODE', 'cloud'):
saas_config = SaaSServerConfig()
config = saas_config.get_config()
assert 'FEATURE_FLAGS' in config
assert 'DEPLOYMENT_MODE' in config['FEATURE_FLAGS']
assert config['FEATURE_FLAGS']['DEPLOYMENT_MODE'] == 'cloud'
def test_deployment_mode_self_hosted_in_feature_flags(self):
"""Test that self_hosted DEPLOYMENT_MODE is included in FEATURE_FLAGS."""
from server.config import SaaSServerConfig
with patch('server.config.DEPLOYMENT_MODE', 'self_hosted'):
saas_config = SaaSServerConfig()
config = saas_config.get_config()
assert 'FEATURE_FLAGS' in config
assert 'DEPLOYMENT_MODE' in config['FEATURE_FLAGS']
assert config['FEATURE_FLAGS']['DEPLOYMENT_MODE'] == 'self_hosted'
+4 -217
View File
@@ -187,11 +187,6 @@ async def test_keycloak_callback_success_with_valid_offline_token(
patch('server.routes.auth.set_response_cookie') as mock_set_cookie,
patch('server.routes.auth.UserStore') as mock_user_store,
patch('server.routes.auth.posthog') as mock_posthog,
patch(
'server.routes.auth._should_redirect_to_onboarding',
new_callable=AsyncMock,
return_value=False,
),
):
# Mock user with accepted_tos
mock_user = MagicMock()
@@ -444,11 +439,6 @@ async def test_keycloak_callback_success_without_offline_token(
patch('server.routes.auth.KEYCLOAK_CLIENT_ID', 'test-client'),
patch('server.routes.auth.UserStore') as mock_user_store,
patch('server.routes.auth.posthog') as mock_posthog,
patch(
'server.routes.auth._should_redirect_to_onboarding',
new_callable=AsyncMock,
return_value=False,
),
):
# Mock user with accepted_tos
mock_user = MagicMock()
@@ -494,101 +484,19 @@ async def test_keycloak_callback_success_without_offline_token(
mock_token_manager.store_idp_tokens.assert_called_once_with(
ProviderType.GITHUB, 'test_user_id', 'test_access_token'
)
# secure is based on web_url (http://localhost:8000/), not redirect_url
# So secure=False because web_url starts with 'http://'
# When redirecting to Keycloak for offline token, redirect_url becomes https://keycloak...
# so secure=True is expected
mock_set_cookie.assert_called_once_with(
request=mock_request,
response=result,
keycloak_access_token='test_access_token',
keycloak_refresh_token='test_refresh_token',
secure=False,
secure=True,
accepted_tos=True,
)
mock_posthog.set.assert_called_once()
@pytest.mark.asyncio
async def test_keycloak_callback_redirects_to_keycloak_when_offline_token_invalid(
mock_request, create_keycloak_user_info
):
"""Test that keycloak_callback redirects to Keycloak when offline token is invalid.
When a user doesn't have a valid offline token, they should be redirected
to Keycloak to obtain one, rather than proceeding with invitation processing.
"""
with (
patch('server.routes.auth.token_manager') as mock_token_manager,
patch('server.routes.auth.set_response_cookie') as mock_set_cookie,
patch(
'server.routes.auth.KEYCLOAK_SERVER_URL_EXT', 'https://keycloak.example.com'
),
patch('server.routes.auth.KEYCLOAK_REALM_NAME', 'test-realm'),
patch('server.routes.auth.KEYCLOAK_CLIENT_ID', 'test-client'),
patch('server.routes.auth.UserStore') as mock_user_store,
patch('server.routes.auth.posthog'),
patch('server.routes.auth.OrgInvitationService') as mock_invitation_service,
patch(
'server.routes.auth._should_redirect_to_onboarding',
new_callable=AsyncMock,
return_value=False,
),
):
# Mock user with accepted_tos
mock_user = MagicMock()
mock_user.id = 'test_user_id'
mock_user.current_org_id = 'test_org_id'
mock_user.accepted_tos = '2025-01-01'
mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user)
mock_user_store.backfill_contact_name = AsyncMock()
mock_user_store.backfill_user_email = AsyncMock()
mock_token_manager.get_keycloak_tokens = AsyncMock(
return_value=('test_access_token', 'test_refresh_token')
)
mock_token_manager.get_user_info = AsyncMock(
return_value=create_keycloak_user_info(
sub='test_user_id',
preferred_username='test_user',
identity_provider='github',
email_verified=True,
)
)
mock_token_manager.store_idp_tokens = AsyncMock()
mock_token_manager.validate_offline_token = AsyncMock(return_value=False)
# Call with an invitation token to verify it's NOT processed
import base64
import json
state_data = {
'redirect_url': 'https://example.com/original-page',
'invitation_token': 'inv-test-token-123',
}
encoded_state = base64.urlsafe_b64encode(
json.dumps(state_data).encode()
).decode()
result = await keycloak_callback(
code='test_code',
state=encoded_state,
request=mock_request,
user_authorizer=create_mock_user_authorizer(),
)
# Should redirect to Keycloak for offline token
assert isinstance(result, RedirectResponse)
assert 'keycloak.example.com' in result.headers['location']
assert 'offline_access' in result.headers['location']
# Cookie should be set with accepted_tos=True (user has accepted TOS)
mock_set_cookie.assert_called_once()
assert mock_set_cookie.call_args[1]['accepted_tos'] is True
# Invitation service should NOT be called (early return before invitation processing)
mock_invitation_service.accept_invitation.assert_not_called()
@pytest.mark.asyncio
async def test_keycloak_callback_account_linking_error(mock_request):
"""Test keycloak_callback with account linking error."""
@@ -667,21 +575,7 @@ async def test_keycloak_offline_callback_success(
mock_request, create_keycloak_user_info
):
"""Test successful keycloak_offline_callback."""
with (
patch('server.routes.auth.token_manager') as mock_token_manager,
patch('server.routes.auth.UserStore') as mock_user_store,
patch('server.routes.auth.set_response_cookie'),
patch(
'server.routes.auth._get_post_auth_redirect',
new_callable=AsyncMock,
return_value='test_state',
),
):
# Mock user with accepted_tos
mock_user = MagicMock()
mock_user.accepted_tos = '2025-01-01'
mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user)
with patch('server.routes.auth.token_manager') as mock_token_manager:
mock_token_manager.get_keycloak_tokens = AsyncMock(
return_value=('test_access_token', 'test_refresh_token')
)
@@ -704,43 +598,6 @@ async def test_keycloak_offline_callback_success(
)
@pytest.mark.asyncio
async def test_keycloak_offline_callback_redirects_to_onboarding(
mock_request, create_keycloak_user_info
):
"""Test keycloak_offline_callback redirects to onboarding when needed."""
with (
patch('server.routes.auth.token_manager') as mock_token_manager,
patch('server.routes.auth.UserStore') as mock_user_store,
patch('server.routes.auth.set_response_cookie'),
patch(
'server.routes.auth._get_post_auth_redirect',
new_callable=AsyncMock,
return_value='http://localhost:8000/onboarding',
),
):
# Mock user with accepted_tos
mock_user = MagicMock()
mock_user.accepted_tos = '2025-01-01'
mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user)
mock_token_manager.get_keycloak_tokens = AsyncMock(
return_value=('test_access_token', 'test_refresh_token')
)
mock_token_manager.get_user_info = AsyncMock(
return_value=create_keycloak_user_info(sub='test_user_id')
)
mock_token_manager.store_offline_token = AsyncMock()
result = await keycloak_offline_callback(
'test_code', 'test_state', mock_request
)
assert isinstance(result, RedirectResponse)
assert result.status_code == 302
assert result.headers['location'] == 'http://localhost:8000/onboarding'
@pytest.mark.asyncio
async def test_authenticate_success():
"""Test successful authentication."""
@@ -2185,20 +2042,12 @@ async def test_accept_tos_stores_timezone_naive_datetime(mock_request):
mock_request.json = AsyncMock(return_value={'redirect_url': 'http://example.com'})
# Mock user for onboarding check (user already completed onboarding)
mock_user_for_onboarding = MagicMock()
mock_user_for_onboarding.onboarding_completed = True
with (
patch(
'server.routes.auth.get_user_auth', AsyncMock(return_value=mock_user_auth)
),
patch('server.routes.auth.a_session_maker', return_value=mock_session_context),
patch('server.routes.auth.set_response_cookie'),
patch(
'server.routes.auth._get_post_auth_redirect',
AsyncMock(return_value='http://example.com'),
),
):
# Act
result = await accept_tos(mock_request)
@@ -2209,65 +2058,3 @@ async def test_accept_tos_stores_timezone_naive_datetime(mock_request):
# The datetime assigned to user.accepted_tos must be timezone-naive
# (compatible with TIMESTAMP WITHOUT TIME ZONE database column)
assert mock_user.accepted_tos.tzinfo is None
@pytest.mark.asyncio
async def test_accept_tos_preserves_offline_flow_redirect(mock_request):
"""Test that accept_tos does not override redirect_url when it's the offline token flow."""
# Arrange
test_user_id = '12345678-1234-5678-1234-567812345678'
offline_redirect_url = 'https://auth.example.com/realms/test/protocol/openid-connect/auth?redirect_uri=https://example.com/oauth/keycloak/offline/callback'
mock_user = MagicMock()
mock_user.id = test_user_id
mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = mock_user
mock_session = AsyncMock()
mock_session.execute.return_value = mock_result
mock_session.commit = AsyncMock()
mock_session_context = AsyncMock()
mock_session_context.__aenter__.return_value = mock_session
mock_session_context.__aexit__.return_value = None
mock_user_auth = MagicMock(spec=SaasUserAuth)
mock_user_auth.get_access_token = AsyncMock(
return_value=SecretStr('test_access_token')
)
mock_user_auth.refresh_token = SecretStr('test_refresh_token')
mock_user_auth.get_user_id = AsyncMock(return_value=test_user_id)
mock_request.json = AsyncMock(return_value={'redirect_url': offline_redirect_url})
mock_get_post_auth_redirect = AsyncMock(
return_value='http://example.com/onboarding'
)
with (
patch(
'server.routes.auth.get_user_auth', AsyncMock(return_value=mock_user_auth)
),
patch('server.routes.auth.a_session_maker', return_value=mock_session_context),
patch('server.routes.auth.set_response_cookie'),
patch(
'server.routes.auth._get_post_auth_redirect',
mock_get_post_auth_redirect,
),
):
# Act
result = await accept_tos(mock_request)
# Assert
assert isinstance(result, JSONResponse)
assert result.status_code == status.HTTP_200_OK
# _get_post_auth_redirect should NOT be called for offline flow
mock_get_post_auth_redirect.assert_not_called()
# The redirect_url should be preserved (not overridden to onboarding)
import json
response_body = json.loads(result.body.decode())
assert response_body['redirect_url'] == offline_redirect_url
@@ -56,7 +56,6 @@ class TestPermission:
assert Permission.VIEW_ORG_SETTINGS.value == 'view_org_settings'
assert Permission.CHANGE_ORGANIZATION_NAME.value == 'change_organization_name'
assert Permission.DELETE_ORGANIZATION.value == 'delete_organization'
assert Permission.MANAGE_AUTOMATIONS.value == 'manage_automations'
def test_permission_from_string(self):
"""
@@ -143,7 +142,6 @@ class TestRolePermissions:
assert Permission.CHANGE_USER_ROLE_OWNER in owner_perms
assert Permission.CHANGE_ORGANIZATION_NAME in owner_perms
assert Permission.DELETE_ORGANIZATION in owner_perms
assert Permission.MANAGE_AUTOMATIONS in owner_perms
def test_admin_has_admin_permissions(self):
"""
@@ -161,7 +159,6 @@ class TestRolePermissions:
assert Permission.INVITE_USER_TO_ORGANIZATION in admin_perms
assert Permission.CHANGE_USER_ROLE_MEMBER in admin_perms
assert Permission.CHANGE_USER_ROLE_ADMIN in admin_perms
assert Permission.MANAGE_AUTOMATIONS in admin_perms
# Admin should NOT have owner-only permissions
assert Permission.CHANGE_USER_ROLE_OWNER not in admin_perms
assert Permission.CHANGE_ORGANIZATION_NAME not in admin_perms
@@ -180,7 +177,6 @@ class TestRolePermissions:
assert Permission.MANAGE_INTEGRATIONS in member_perms
assert Permission.MANAGE_APPLICATION_SETTINGS in member_perms
assert Permission.MANAGE_API_KEYS in member_perms
assert Permission.MANAGE_AUTOMATIONS in member_perms
assert Permission.VIEW_LLM_SETTINGS in member_perms
assert Permission.VIEW_ORG_SETTINGS in member_perms
# Member should NOT have admin/owner permissions
@@ -1,284 +0,0 @@
"""Tests for ConversationStateUpdateEvent filtering in shared_event_router.
The shared-events endpoints are unauthenticated, so internal system state
(ConversationStateUpdateEvent) must not be returned. The frontend shared-
conversation viewer never renders these events — it only uses messages,
actions, observations, errors, and hook-execution events.
"""
from __future__ import annotations
from unittest.mock import AsyncMock
from uuid import uuid4
import pytest
from server.sharing.shared_event_router import (
_is_viewable,
batch_get_shared_events,
get_shared_event,
search_shared_events,
)
from openhands.agent_server.models import EventPage
from openhands.sdk.event.conversation_state import ConversationStateUpdateEvent
from openhands.sdk.event.llm_convertible import MessageEvent
from openhands.sdk.llm import Message, TextContent
# ---------------------------------------------------------------------------
# Fixtures / helpers
# ---------------------------------------------------------------------------
def _make_message_event() -> MessageEvent:
return MessageEvent(
source='user',
llm_message=Message(role='user', content=[TextContent(text='Hello')]),
)
def _make_state_event(
key: str = 'full_state', value: dict | str = 'idle'
) -> ConversationStateUpdateEvent:
return ConversationStateUpdateEvent(key=key, value=value)
# ---------------------------------------------------------------------------
# _is_viewable
# ---------------------------------------------------------------------------
class TestIsViewable:
def test_message_event_is_viewable(self):
assert _is_viewable(_make_message_event()) is True
def test_full_state_event_is_not_viewable(self):
assert _is_viewable(_make_state_event('full_state', {'agent': {}})) is False
def test_execution_status_event_is_not_viewable(self):
assert _is_viewable(_make_state_event('execution_status', 'running')) is False
def test_stats_event_is_not_viewable(self):
assert _is_viewable(_make_state_event('stats', {})) is False
# ---------------------------------------------------------------------------
# search_shared_events
# ---------------------------------------------------------------------------
class TestSearchSharedEvents:
@pytest.mark.asyncio
async def test_filters_out_state_events(self):
msg = _make_message_event()
state = _make_state_event()
mock_service = AsyncMock()
mock_service.search_shared_events.return_value = EventPage(
items=[msg, state, msg], next_page_id=None
)
result = await search_shared_events(
conversation_id=uuid4().hex,
shared_event_service=mock_service,
)
assert len(result.items) == 2
assert all(
not isinstance(e, ConversationStateUpdateEvent) for e in result.items
)
@pytest.mark.asyncio
async def test_empty_page_unchanged(self):
mock_service = AsyncMock()
mock_service.search_shared_events.return_value = EventPage(
items=[], next_page_id=None
)
result = await search_shared_events(
conversation_id=uuid4().hex,
shared_event_service=mock_service,
)
assert result.items == []
assert result.next_page_id is None
@pytest.mark.asyncio
async def test_fetches_additional_pages_when_filtering_reduces_count(self):
"""Fetch next page when first page has only state events."""
msg = _make_message_event()
state = _make_state_event()
mock_service = AsyncMock()
mock_service.search_shared_events.side_effect = [
# Page 1: only state events — all filtered out
EventPage(items=[state, state, state], next_page_id='page2'),
# Page 2: all viewable
EventPage(items=[msg, msg, msg], next_page_id=None),
]
result = await search_shared_events(
conversation_id=uuid4().hex,
limit=3,
shared_event_service=mock_service,
)
assert len(result.items) == 3
assert result.next_page_id is None
assert mock_service.search_shared_events.call_count == 2
@pytest.mark.asyncio
async def test_multiple_pages_until_limit_reached(self):
"""Keep fetching mixed pages until limit viewable events accumulated."""
msg = _make_message_event()
state = _make_state_event()
mock_service = AsyncMock()
mock_service.search_shared_events.side_effect = [
EventPage(items=[msg, state], next_page_id='p2'),
EventPage(items=[state, msg], next_page_id='p3'),
EventPage(items=[msg], next_page_id='p4'),
]
result = await search_shared_events(
conversation_id=uuid4().hex,
limit=3,
shared_event_service=mock_service,
)
assert len(result.items) == 3
assert result.next_page_id == 'p4'
assert mock_service.search_shared_events.call_count == 3
@pytest.mark.asyncio
async def test_stops_when_no_more_pages(self):
"""Return partial results when no more backend pages are available."""
msg = _make_message_event()
state = _make_state_event()
mock_service = AsyncMock()
mock_service.search_shared_events.side_effect = [
EventPage(items=[msg, state], next_page_id='p2'),
EventPage(items=[state], next_page_id=None),
]
result = await search_shared_events(
conversation_id=uuid4().hex,
limit=5,
shared_event_service=mock_service,
)
assert len(result.items) == 1
assert result.next_page_id is None
@pytest.mark.asyncio
async def test_passes_remaining_as_limit_to_backend(self):
"""Pass remaining needed count as limit to each backend call."""
msg = _make_message_event()
state = _make_state_event()
conv_id = uuid4().hex
mock_service = AsyncMock()
mock_service.search_shared_events.side_effect = [
# First call: limit=3, returns 1 viewable
EventPage(items=[msg, state, state], next_page_id='p2'),
# Second call: limit should be 2 (remaining)
EventPage(items=[msg, msg], next_page_id=None),
]
await search_shared_events(
conversation_id=conv_id,
limit=3,
shared_event_service=mock_service,
)
calls = mock_service.search_shared_events.call_args_list
assert calls[0].kwargs['limit'] == 3
assert calls[1].kwargs['limit'] == 2
@pytest.mark.asyncio
async def test_preserves_next_page_id_when_all_filtered(self):
"""Continue fetching when all events on a page are filtered out."""
msg = _make_message_event()
state = _make_state_event()
mock_service = AsyncMock()
mock_service.search_shared_events.side_effect = [
EventPage(items=[state], next_page_id='p2'),
EventPage(items=[msg], next_page_id='p3'),
]
result = await search_shared_events(
conversation_id=uuid4().hex,
limit=1,
shared_event_service=mock_service,
)
assert len(result.items) == 1
assert result.next_page_id == 'p3'
# ---------------------------------------------------------------------------
# batch_get_shared_events
# ---------------------------------------------------------------------------
class TestBatchGetSharedEvents:
@pytest.mark.asyncio
async def test_replaces_state_events_with_none(self):
msg = _make_message_event()
state = _make_state_event()
mock_service = AsyncMock()
mock_service.batch_get_shared_events.return_value = [msg, state, None]
result = await batch_get_shared_events(
conversation_id=uuid4().hex,
id=[uuid4().hex, uuid4().hex, uuid4().hex],
shared_event_service=mock_service,
)
assert len(result) == 3
assert result[0] is msg
assert result[1] is None # state event replaced with None
assert result[2] is None # originally None stays None
# ---------------------------------------------------------------------------
# get_shared_event
# ---------------------------------------------------------------------------
class TestGetSharedEvent:
@pytest.mark.asyncio
async def test_returns_message_event(self):
msg = _make_message_event()
mock_service = AsyncMock()
mock_service.get_shared_event.return_value = msg
result = await get_shared_event(
conversation_id=uuid4().hex,
event_id=uuid4().hex,
shared_event_service=mock_service,
)
assert result is msg
@pytest.mark.asyncio
async def test_returns_none_for_state_event(self):
state = _make_state_event()
mock_service = AsyncMock()
mock_service.get_shared_event.return_value = state
result = await get_shared_event(
conversation_id=uuid4().hex,
event_id=uuid4().hex,
shared_event_service=mock_service,
)
assert result is None
@pytest.mark.asyncio
async def test_returns_none_when_not_found(self):
mock_service = AsyncMock()
mock_service.get_shared_event.return_value = None
result = await get_shared_event(
conversation_id=uuid4().hex,
event_id=uuid4().hex,
shared_event_service=mock_service,
)
assert result is None
-351
View File
@@ -1325,354 +1325,3 @@ async def test_migrate_user_sql_multiple_conversations(async_session_maker):
# statements that have SQLite/UUID compatibility issues in the test environment.
# The SQL migration tests above (test_migrate_user_sql_type_handling, etc.) verify
# the SQL operations work correctly with proper type handling.
# --- Tests for mark_onboarding_completed ---
@pytest.mark.asyncio
async def test_mark_onboarding_completed_success(async_session_maker):
"""Test successfully marking onboarding as completed."""
user_id = uuid.uuid4()
org_id = uuid.uuid4()
# Create test data
async with async_session_maker() as session:
org = Org(id=org_id, name='test-org')
session.add(org)
user = User(id=user_id, current_org_id=org_id, onboarding_completed=False)
session.add(user)
await session.commit()
# Test marking onboarding complete
with patch('storage.user_store.a_session_maker', async_session_maker):
result = await UserStore.mark_onboarding_completed(str(user_id))
assert result is not None
assert result.id == user_id
assert result.onboarding_completed is True
@pytest.mark.asyncio
async def test_mark_onboarding_completed_user_not_found(async_session_maker):
"""Test that mark_onboarding_completed returns None for non-existent user."""
non_existent_id = str(uuid.uuid4())
with patch('storage.user_store.a_session_maker', async_session_maker):
result = await UserStore.mark_onboarding_completed(non_existent_id)
assert result is None
@pytest.mark.asyncio
async def test_mark_onboarding_completed_already_completed(async_session_maker):
"""Test marking onboarding complete for user who already completed it."""
user_id = uuid.uuid4()
org_id = uuid.uuid4()
# Create user with onboarding already completed
async with async_session_maker() as session:
org = Org(id=org_id, name='test-org')
session.add(org)
user = User(id=user_id, current_org_id=org_id, onboarding_completed=True)
session.add(user)
await session.commit()
# Should still succeed and return user
with patch('storage.user_store.a_session_maker', async_session_maker):
result = await UserStore.mark_onboarding_completed(str(user_id))
assert result is not None
assert result.id == user_id
assert result.onboarding_completed is True
@pytest.mark.asyncio
async def test_mark_onboarding_completed_user_with_null_onboarding(async_session_maker):
"""Test marking onboarding complete for user with null onboarding_completed value."""
user_id = uuid.uuid4()
org_id = uuid.uuid4()
# Create user with null onboarding_completed (default)
async with async_session_maker() as session:
org = Org(id=org_id, name='test-org')
session.add(org)
user = User(
id=user_id, current_org_id=org_id
) # onboarding_completed defaults to None
session.add(user)
await session.commit()
with patch('storage.user_store.a_session_maker', async_session_maker):
result = await UserStore.mark_onboarding_completed(str(user_id))
assert result is not None
assert result.id == user_id
assert result.onboarding_completed is True
# --- Tests for get_first_owner_in_org ---
@pytest.mark.asyncio
async def test_get_first_owner_in_org_returns_first_owner(async_session_maker):
"""Test that get_first_owner_in_org returns the owner with earliest accepted_tos."""
from datetime import datetime, timedelta
from storage.org_member import OrgMember
from storage.role import Role
org_id = uuid.uuid4()
first_owner_id = uuid.uuid4()
second_owner_id = uuid.uuid4()
async with async_session_maker() as session:
# Create org
org = Org(id=org_id, name='test-org')
session.add(org)
# Create owner role
owner_role = Role(id=1, name='owner', rank=10)
session.add(owner_role)
# Create first owner (earlier TOS acceptance)
first_owner = User(
id=first_owner_id,
current_org_id=org_id,
accepted_tos=datetime.now() - timedelta(days=10),
)
session.add(first_owner)
# Create second owner (later TOS acceptance)
second_owner = User(
id=second_owner_id,
current_org_id=org_id,
accepted_tos=datetime.now() - timedelta(days=5),
)
session.add(second_owner)
await session.flush()
# Add both as org members with owner role
first_member = OrgMember(
org_id=org_id,
user_id=first_owner_id,
role_id=owner_role.id,
llm_api_key='test-key-1',
)
session.add(first_member)
second_member = OrgMember(
org_id=org_id,
user_id=second_owner_id,
role_id=owner_role.id,
llm_api_key='test-key-2',
)
session.add(second_member)
await session.commit()
with patch('storage.user_store.a_session_maker', async_session_maker):
result = await UserStore.get_first_owner_in_org(org_id)
assert result is not None
assert result.id == first_owner_id
@pytest.mark.asyncio
async def test_get_first_owner_in_org_ignores_non_owners(async_session_maker):
"""Test that get_first_owner_in_org ignores users with non-owner roles."""
from datetime import datetime, timedelta
from storage.org_member import OrgMember
from storage.role import Role
org_id = uuid.uuid4()
admin_id = uuid.uuid4()
owner_id = uuid.uuid4()
async with async_session_maker() as session:
# Create org
org = Org(id=org_id, name='test-org')
session.add(org)
# Create roles
owner_role = Role(id=1, name='owner', rank=10)
admin_role = Role(id=2, name='admin', rank=20)
session.add(owner_role)
session.add(admin_role)
# Create admin with earlier TOS acceptance
admin_user = User(
id=admin_id,
current_org_id=org_id,
accepted_tos=datetime.now() - timedelta(days=10),
)
session.add(admin_user)
# Create owner with later TOS acceptance
owner_user = User(
id=owner_id,
current_org_id=org_id,
accepted_tos=datetime.now() - timedelta(days=5),
)
session.add(owner_user)
await session.flush()
# Add admin member
admin_member = OrgMember(
org_id=org_id,
user_id=admin_id,
role_id=admin_role.id,
llm_api_key='test-key-admin',
)
session.add(admin_member)
# Add owner member
owner_member = OrgMember(
org_id=org_id,
user_id=owner_id,
role_id=owner_role.id,
llm_api_key='test-key-owner',
)
session.add(owner_member)
await session.commit()
with patch('storage.user_store.a_session_maker', async_session_maker):
result = await UserStore.get_first_owner_in_org(org_id)
# Should return the owner, not the admin (even though admin has earlier TOS)
assert result is not None
assert result.id == owner_id
@pytest.mark.asyncio
async def test_get_first_owner_in_org_returns_none_when_no_owners(async_session_maker):
"""Test that get_first_owner_in_org returns None when org has no owners."""
from datetime import datetime
from storage.org_member import OrgMember
from storage.role import Role
org_id = uuid.uuid4()
member_id = uuid.uuid4()
async with async_session_maker() as session:
# Create org
org = Org(id=org_id, name='test-org')
session.add(org)
# Create member role only
member_role = Role(id=3, name='member', rank=100)
session.add(member_role)
# Create user with member role
member_user = User(
id=member_id,
current_org_id=org_id,
accepted_tos=datetime.now(),
)
session.add(member_user)
await session.flush()
# Add as member
member = OrgMember(
org_id=org_id,
user_id=member_id,
role_id=member_role.id,
llm_api_key='test-key',
)
session.add(member)
await session.commit()
with patch('storage.user_store.a_session_maker', async_session_maker):
result = await UserStore.get_first_owner_in_org(org_id)
assert result is None
@pytest.mark.asyncio
async def test_get_first_owner_in_org_ignores_owners_without_tos(async_session_maker):
"""Test that get_first_owner_in_org ignores owners who haven't accepted TOS."""
from datetime import datetime
from storage.org_member import OrgMember
from storage.role import Role
org_id = uuid.uuid4()
owner_no_tos_id = uuid.uuid4()
owner_with_tos_id = uuid.uuid4()
async with async_session_maker() as session:
# Create org
org = Org(id=org_id, name='test-org')
session.add(org)
# Create owner role
owner_role = Role(id=1, name='owner', rank=10)
session.add(owner_role)
# Create owner without TOS
owner_no_tos = User(
id=owner_no_tos_id,
current_org_id=org_id,
accepted_tos=None,
)
session.add(owner_no_tos)
# Create owner with TOS
owner_with_tos = User(
id=owner_with_tos_id,
current_org_id=org_id,
accepted_tos=datetime.now(),
)
session.add(owner_with_tos)
await session.flush()
# Add both as owners
member_no_tos = OrgMember(
org_id=org_id,
user_id=owner_no_tos_id,
role_id=owner_role.id,
llm_api_key='test-key-1',
)
session.add(member_no_tos)
member_with_tos = OrgMember(
org_id=org_id,
user_id=owner_with_tos_id,
role_id=owner_role.id,
llm_api_key='test-key-2',
)
session.add(member_with_tos)
await session.commit()
with patch('storage.user_store.a_session_maker', async_session_maker):
result = await UserStore.get_first_owner_in_org(org_id)
# Should return the owner who has accepted TOS
assert result is not None
assert result.id == owner_with_tos_id
@pytest.mark.asyncio
async def test_get_first_owner_in_org_returns_none_for_empty_org(async_session_maker):
"""Test that get_first_owner_in_org returns None for org with no members."""
org_id = uuid.uuid4()
async with async_session_maker() as session:
# Create org only, no members
org = Org(id=org_id, name='empty-org')
session.add(org)
await session.commit()
with patch('storage.user_store.a_session_maker', async_session_maker):
result = await UserStore.get_first_owner_in_org(org_id)
assert result is None
@@ -49,34 +49,17 @@ vi.mock("#/utils/custom-toast-handlers", () => ({
displayErrorToast: vi.fn(),
}));
const mockUseAppMode = vi.fn(() => ({
isOss: false,
isSaas: true,
isCloud: true,
isSelfHosted: false,
isEnterpriseSelfHosted: false,
isEnterpriseCloud: true,
appMode: "saas" as string | undefined,
deploymentMode: "cloud" as string | undefined,
}));
vi.mock("#/hooks/use-app-mode", () => ({
useAppMode: () => mockUseAppMode(),
// Mock feature flags - we'll control the return value in each test
const mockEnableProjUserJourney = vi.fn(() => true);
vi.mock("#/utils/feature-flags", () => ({
ENABLE_PROJ_USER_JOURNEY: () => mockEnableProjUserJourney(),
}));
describe("LoginContent", () => {
beforeEach(() => {
vi.stubGlobal("location", { href: "" });
// Reset mock to return SaaS Cloud (CTA enabled) by default
mockUseAppMode.mockReturnValue({
isOss: false,
isSaas: true,
isCloud: true,
isSelfHosted: false,
isEnterpriseSelfHosted: false,
isEnterpriseCloud: true,
appMode: "saas",
deploymentMode: "cloud",
});
// Reset mock to return true by default
mockEnableProjUserJourney.mockReturnValue(true);
});
afterEach(() => {
@@ -299,18 +282,7 @@ describe("LoginContent", () => {
expect(screen.getByTestId("terms-and-privacy-notice")).toBeInTheDocument();
});
it("should display the enterprise LoginCTA component when in SaaS Cloud mode", () => {
mockUseAppMode.mockReturnValue({
isOss: false,
isSaas: true,
isCloud: true,
isSelfHosted: false,
isEnterpriseSelfHosted: false,
isEnterpriseCloud: true,
appMode: "saas",
deploymentMode: "cloud",
});
it("should display the enterprise LoginCTA component when appMode is saas and feature flag enabled", () => {
render(
<MemoryRouter>
<LoginContent
@@ -324,18 +296,7 @@ describe("LoginContent", () => {
expect(screen.getByTestId("login-cta")).toBeInTheDocument();
});
it("should not display the enterprise LoginCTA component when in OSS mode", () => {
mockUseAppMode.mockReturnValue({
isOss: true,
isSaas: false,
isCloud: false,
isSelfHosted: false,
isEnterpriseSelfHosted: false,
isEnterpriseCloud: false,
appMode: "oss",
deploymentMode: undefined,
});
it("should not display the enterprise LoginCTA component when appMode is oss even with feature flag enabled", () => {
render(
<MemoryRouter>
<LoginContent
@@ -349,17 +310,23 @@ describe("LoginContent", () => {
expect(screen.queryByTestId("login-cta")).not.toBeInTheDocument();
});
it("should not display the enterprise LoginCTA component when in SaaS Self-hosted mode", () => {
mockUseAppMode.mockReturnValue({
isOss: false,
isSaas: true,
isCloud: false,
isSelfHosted: true,
isEnterpriseSelfHosted: true,
isEnterpriseCloud: false,
appMode: "saas",
deploymentMode: "self_hosted",
});
it("should not display the enterprise LoginCTA component when appMode is null", () => {
render(
<MemoryRouter>
<LoginContent
githubAuthUrl="https://github.com/oauth/authorize"
appMode={null}
providersConfigured={["github"]}
/>
</MemoryRouter>,
);
expect(screen.queryByTestId("login-cta")).not.toBeInTheDocument();
});
it("should not display the enterprise LoginCTA component when feature flag is disabled", () => {
// Disable the feature flag
mockEnableProjUserJourney.mockReturnValue(false);
render(
<MemoryRouter>
@@ -0,0 +1,118 @@
import { screen } from "@testing-library/react";
import { beforeEach, describe, expect, it, vi } from "vitest";
import userEvent from "@testing-library/user-event";
import { renderWithProviders } from "test-utils";
import { EnterpriseBanner } from "#/components/features/device-verify/enterprise-banner";
const mockCapture = vi.fn();
vi.mock("posthog-js/react", () => ({
usePostHog: () => ({
capture: mockCapture,
}),
}));
const { ENABLE_PROJ_USER_JOURNEY_MOCK } = vi.hoisted(() => ({
ENABLE_PROJ_USER_JOURNEY_MOCK: vi.fn(() => true),
}));
vi.mock("#/utils/feature-flags", () => ({
ENABLE_PROJ_USER_JOURNEY: () => ENABLE_PROJ_USER_JOURNEY_MOCK(),
}));
describe("EnterpriseBanner", () => {
beforeEach(() => {
vi.clearAllMocks();
ENABLE_PROJ_USER_JOURNEY_MOCK.mockReturnValue(true);
});
describe("Feature Flag", () => {
it("should not render when proj_user_journey feature flag is disabled", () => {
ENABLE_PROJ_USER_JOURNEY_MOCK.mockReturnValue(false);
const { container } = renderWithProviders(<EnterpriseBanner />);
expect(container.firstChild).toBeNull();
expect(screen.queryByText("ENTERPRISE$TITLE")).not.toBeInTheDocument();
});
it("should render when proj_user_journey feature flag is enabled", () => {
ENABLE_PROJ_USER_JOURNEY_MOCK.mockReturnValue(true);
renderWithProviders(<EnterpriseBanner />);
expect(screen.getByText("ENTERPRISE$TITLE")).toBeInTheDocument();
});
});
describe("Rendering", () => {
it("should render the self-hosted label", () => {
renderWithProviders(<EnterpriseBanner />);
expect(screen.getByText("ENTERPRISE$SELF_HOSTED")).toBeInTheDocument();
});
it("should render the enterprise title", () => {
renderWithProviders(<EnterpriseBanner />);
expect(screen.getByText("ENTERPRISE$TITLE")).toBeInTheDocument();
});
it("should render the enterprise description", () => {
renderWithProviders(<EnterpriseBanner />);
expect(screen.getByText("ENTERPRISE$DESCRIPTION")).toBeInTheDocument();
});
it("should render all four enterprise feature items", () => {
renderWithProviders(<EnterpriseBanner />);
expect(
screen.getByText("ENTERPRISE$FEATURE_DATA_PRIVACY"),
).toBeInTheDocument();
expect(
screen.getByText("ENTERPRISE$FEATURE_DEPLOYMENT"),
).toBeInTheDocument();
expect(screen.getByText("ENTERPRISE$FEATURE_SSO")).toBeInTheDocument();
expect(
screen.getByText("ENTERPRISE$FEATURE_SUPPORT"),
).toBeInTheDocument();
});
it("should render the learn more link", () => {
renderWithProviders(<EnterpriseBanner />);
const link = screen.getByRole("link", {
name: "ENTERPRISE$LEARN_MORE_ARIA",
});
expect(link).toBeInTheDocument();
expect(link).toHaveTextContent("ENTERPRISE$LEARN_MORE");
expect(link).toHaveAttribute("href", "https://openhands.dev/enterprise");
expect(link).toHaveAttribute("target", "_blank");
expect(link).toHaveAttribute("rel", "noopener noreferrer");
});
});
describe("Learn More Link Interaction", () => {
it("should capture PostHog event when learn more link is clicked", async () => {
const user = userEvent.setup();
renderWithProviders(<EnterpriseBanner />);
const link = screen.getByRole("link", {
name: "ENTERPRISE$LEARN_MORE_ARIA",
});
await user.click(link);
expect(mockCapture).toHaveBeenCalledWith("saas_selfhosted_inquiry");
});
it("should have correct href attribute for opening in new tab", () => {
renderWithProviders(<EnterpriseBanner />);
const link = screen.getByRole("link", {
name: "ENTERPRISE$LEARN_MORE_ARIA",
});
expect(link).toHaveAttribute("href", "https://openhands.dev/enterprise");
expect(link).toHaveAttribute("target", "_blank");
});
});
});
@@ -90,15 +90,14 @@ describe("RepoConnector", () => {
"retrieveUserGitRepositories",
);
retrieveUserGitRepositoriesSpy.mockResolvedValue({
items: MOCK_RESPOSITORIES,
next_page_id: null,
data: MOCK_RESPOSITORIES,
nextPage: null,
});
// Mock the search function that's used by the dropdown
vi.spyOn(GitService, "searchGitRepositories").mockResolvedValue({
items: MOCK_RESPOSITORIES,
next_page_id: null,
});
vi.spyOn(GitService, "searchGitRepositories").mockResolvedValue(
MOCK_RESPOSITORIES,
);
renderRepoConnector();
@@ -128,8 +127,8 @@ describe("RepoConnector", () => {
"retrieveUserGitRepositories",
);
retrieveUserGitRepositoriesSpy.mockResolvedValue({
items: MOCK_RESPOSITORIES,
next_page_id: null,
data: MOCK_RESPOSITORIES,
nextPage: null,
});
renderRepoConnector();
@@ -139,11 +138,14 @@ describe("RepoConnector", () => {
// Mock the repository branches API call
vi.spyOn(GitService, "getRepositoryBranches").mockResolvedValue({
items: [
branches: [
{ name: "main", commit_sha: "123", protected: false },
{ name: "develop", commit_sha: "456", protected: false },
],
next_page_id: null,
has_next_page: false,
current_page: 1,
per_page: 30,
total_count: 2,
});
// First select the provider
@@ -197,8 +199,8 @@ describe("RepoConnector", () => {
"retrieveUserGitRepositories",
);
retrieveUserGitRepositoriesSpy.mockResolvedValue({
items: MOCK_RESPOSITORIES,
next_page_id: null,
data: MOCK_RESPOSITORIES,
nextPage: null,
});
renderRepoConnector();
@@ -244,8 +246,8 @@ describe("RepoConnector", () => {
"retrieveUserGitRepositories",
);
retrieveUserGitRepositoriesSpy.mockResolvedValue({
items: MOCK_RESPOSITORIES,
next_page_id: null,
data: MOCK_RESPOSITORIES,
nextPage: null,
});
renderRepoConnector();
@@ -288,8 +290,8 @@ describe("RepoConnector", () => {
"retrieveUserGitRepositories",
);
retrieveUserGitRepositoriesSpy.mockResolvedValue({
items: MOCK_RESPOSITORIES,
next_page_id: null,
data: MOCK_RESPOSITORIES,
nextPage: null,
});
renderRepoConnector();
@@ -345,8 +347,8 @@ describe("RepoConnector", () => {
"retrieveUserGitRepositories",
);
retrieveUserGitRepositoriesSpy.mockResolvedValue({
items: MOCK_RESPOSITORIES,
next_page_id: null,
data: MOCK_RESPOSITORIES,
nextPage: null,
});
renderRepoConnector();
@@ -361,11 +363,14 @@ describe("RepoConnector", () => {
// Mock the repository branches API call
vi.spyOn(GitService, "getRepositoryBranches").mockResolvedValue({
items: [
branches: [
{ name: "main", commit_sha: "123", protected: false },
{ name: "develop", commit_sha: "456", protected: false },
],
next_page_id: null,
has_next_page: false,
current_page: 1,
per_page: 30,
total_count: 2,
});
// First select the provider
@@ -422,17 +427,20 @@ describe("RepoConnector", () => {
"retrieveUserGitRepositories",
);
retrieveUserGitRepositoriesSpy.mockResolvedValue({
items: MOCK_RESPOSITORIES,
next_page_id: null,
data: MOCK_RESPOSITORIES,
nextPage: null,
});
// Mock the repository branches API call
vi.spyOn(GitService, "getRepositoryBranches").mockResolvedValue({
items: [
branches: [
{ name: "main", commit_sha: "123", protected: false },
{ name: "develop", commit_sha: "456", protected: false },
],
next_page_id: null,
has_next_page: false,
current_page: 1,
per_page: 30,
total_count: 2,
});
renderRepoConnector();
@@ -187,7 +187,7 @@ describe("RepositorySelectionForm", () => {
},
];
mockUseGitRepositories.mockReturnValue({
data: { pages: [{ items: MOCK_REPOS }] },
data: { pages: [{ data: MOCK_REPOS }] },
isLoading: false,
isError: false,
hasNextPage: false,
@@ -229,7 +229,7 @@ describe("RepositorySelectionForm", () => {
// Create a spy on the API call
const searchGitReposSpy = vi.spyOn(GitService, "searchGitRepositories");
searchGitReposSpy.mockResolvedValue({ items: MOCK_SEARCH_REPOS, next_page_id: null });
searchGitReposSpy.mockResolvedValue(MOCK_SEARCH_REPOS);
mockUseGitRepositories.mockReturnValue({
data: { pages: [] },
@@ -267,7 +267,7 @@ describe("RepositorySelectionForm", () => {
];
mockUseGitRepositories.mockReturnValue({
data: { pages: [{ items: MOCK_SEARCH_REPOS }] },
data: { pages: [{ data: MOCK_SEARCH_REPOS }] },
isLoading: false,
isError: false,
hasNextPage: false,
@@ -115,8 +115,8 @@ describe("TaskCard", () => {
"retrieveUserGitRepositories",
);
retrieveUserGitRepositoriesSpy.mockResolvedValue({
items: MOCK_RESPOSITORIES,
next_page_id: null,
data: MOCK_RESPOSITORIES,
nextPage: null,
});
});
@@ -1,20 +1,15 @@
import { render, screen } from "@testing-library/react";
import { screen } from "@testing-library/react";
import userEvent from "@testing-library/user-event";
import { createRoutesStub } from "react-router";
import { MemoryRouter } from "react-router";
import { beforeEach, describe, expect, it, vi } from "vitest";
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
import { I18nextProvider } from "react-i18next";
import i18n from "i18next";
import { renderWithProviders } from "../../../../test-utils";
import OnboardingForm from "#/routes/onboarding-form";
const mockMutate = vi.fn();
const mockNavigate = vi.fn();
const mockUseMe = vi.fn();
const mockUseConfig = vi.fn();
const mockTrackOnboardingCompleted = vi.fn();
// Loader data set in beforeEach for each test suite
let loaderData: { config: { app_mode: string; feature_flags: { deployment_mode: string } } };
vi.mock("react-router", async (importOriginal) => {
const original = await importOriginal<typeof import("react-router")>();
return {
@@ -29,8 +24,8 @@ vi.mock("#/hooks/mutation/use-submit-onboarding", () => ({
}),
}));
vi.mock("#/hooks/query/use-me", () => ({
useMe: () => mockUseMe(),
vi.mock("#/hooks/query/use-config", () => ({
useConfig: () => mockUseConfig(),
}));
vi.mock("#/hooks/use-tracking", () => ({
@@ -39,71 +34,49 @@ vi.mock("#/hooks/use-tracking", () => ({
}),
}));
const renderOnboardingForm = async () => {
const queryClient = new QueryClient({
defaultOptions: { queries: { retry: false } },
});
const RouterStub = createRoutesStub([
{
path: "/",
Component: OnboardingForm,
loader: () => loaderData,
},
]);
const result = render(
<I18nextProvider i18n={i18n}>
<QueryClientProvider client={queryClient}>
<RouterStub initialEntries={["/"]} />
</QueryClientProvider>
</I18nextProvider>,
const renderOnboardingForm = () => {
return renderWithProviders(
<MemoryRouter>
<OnboardingForm />
</MemoryRouter>,
);
// Wait for the component to render
await screen.findByTestId("onboarding-form");
return result;
};
describe("OnboardingForm - Cloud Mode", () => {
describe("OnboardingForm - SaaS Mode", () => {
beforeEach(() => {
mockMutate.mockClear();
mockNavigate.mockClear();
mockTrackOnboardingCompleted.mockClear();
loaderData = {
config: {
app_mode: "saas",
feature_flags: { deployment_mode: "cloud" },
},
};
// Cloud mode tracks all users, role doesn't matter
mockUseMe.mockReturnValue({ data: { role: "member" } });
mockUseConfig.mockReturnValue({
data: { app_mode: "saas" },
isLoading: false,
});
});
it("should render with the correct test id", async () => {
await renderOnboardingForm();
it("should render with the correct test id", () => {
renderOnboardingForm();
expect(screen.getByTestId("onboarding-form")).toBeInTheDocument();
});
it("should render the first step initially", async () => {
await renderOnboardingForm();
it("should render the first step initially", () => {
renderOnboardingForm();
expect(screen.getByTestId("step-header")).toBeInTheDocument();
expect(screen.getByTestId("step-content")).toBeInTheDocument();
expect(screen.getByTestId("step-actions")).toBeInTheDocument();
});
it("should display step progress indicator with 3 bars for cloud mode", async () => {
await renderOnboardingForm();
it("should display step progress indicator with 3 bars for saas mode", () => {
renderOnboardingForm();
const stepHeader = screen.getByTestId("step-header");
const progressBars = stepHeader.querySelectorAll(".rounded-full");
expect(progressBars).toHaveLength(3);
});
it("should have the Next button disabled when no option is selected", async () => {
await renderOnboardingForm();
it("should have the Next button disabled when no option is selected", () => {
renderOnboardingForm();
const nextButton = screen.getByRole("button", { name: /next/i });
expect(nextButton).toBeDisabled();
@@ -111,7 +84,7 @@ describe("OnboardingForm - Cloud Mode", () => {
it("should enable the Next button when an option is selected", async () => {
const user = userEvent.setup();
await renderOnboardingForm();
renderOnboardingForm();
await user.click(screen.getByTestId("step-option-solo"));
@@ -121,7 +94,7 @@ describe("OnboardingForm - Cloud Mode", () => {
it("should advance to the next step when Next is clicked", async () => {
const user = userEvent.setup();
await renderOnboardingForm();
renderOnboardingForm();
// On step 1, first progress bar should be filled (bg-white)
const stepHeader = screen.getByTestId("step-header");
@@ -138,7 +111,7 @@ describe("OnboardingForm - Cloud Mode", () => {
it("should disable Next button again on new step until option is selected", async () => {
const user = userEvent.setup();
await renderOnboardingForm();
renderOnboardingForm();
await user.click(screen.getByTestId("step-option-solo"));
await user.click(screen.getByRole("button", { name: /next/i }));
@@ -149,7 +122,7 @@ describe("OnboardingForm - Cloud Mode", () => {
it("should call submitOnboarding with selections when finishing the last step", async () => {
const user = userEvent.setup();
await renderOnboardingForm();
renderOnboardingForm();
// Step 1 - select org size (first step in saas mode - single select)
await user.click(screen.getByTestId("step-option-org_2_10"));
@@ -173,11 +146,11 @@ describe("OnboardingForm - Cloud Mode", () => {
});
});
it("should track onboarding completion to PostHog in cloud mode", async () => {
it("should track onboarding completion to PostHog in SaaS mode", async () => {
const user = userEvent.setup();
await renderOnboardingForm();
renderOnboardingForm();
// Complete the full cloud onboarding flow
// Complete the full SaaS onboarding flow
await user.click(screen.getByTestId("step-option-org_2_10"));
await user.click(screen.getByRole("button", { name: /next/i }));
@@ -195,8 +168,8 @@ describe("OnboardingForm - Cloud Mode", () => {
});
});
it("should render 5 options on step 1 (org size question)", async () => {
await renderOnboardingForm();
it("should render 5 options on step 1 (org size question)", () => {
renderOnboardingForm();
const options = screen
.getAllByRole("button")
@@ -208,7 +181,7 @@ describe("OnboardingForm - Cloud Mode", () => {
it("should preserve selections when navigating through steps", async () => {
const user = userEvent.setup();
await renderOnboardingForm();
renderOnboardingForm();
// Select org size on step 1 (single select)
await user.click(screen.getByTestId("step-option-solo"));
@@ -234,7 +207,7 @@ describe("OnboardingForm - Cloud Mode", () => {
it("should allow selecting multiple options on multi-select steps", async () => {
const user = userEvent.setup();
await renderOnboardingForm();
renderOnboardingForm();
// Step 1 - select org size (single select)
await user.click(screen.getByTestId("step-option-solo"));
@@ -261,7 +234,7 @@ describe("OnboardingForm - Cloud Mode", () => {
it("should allow deselecting options on multi-select steps", async () => {
const user = userEvent.setup();
await renderOnboardingForm();
renderOnboardingForm();
// Step 1 - select org size
await user.click(screen.getByTestId("step-option-solo"));
@@ -289,7 +262,7 @@ describe("OnboardingForm - Cloud Mode", () => {
it("should show all progress bars filled on the last step", async () => {
const user = userEvent.setup();
await renderOnboardingForm();
renderOnboardingForm();
// Navigate to step 3
await user.click(screen.getByTestId("step-option-solo"));
@@ -304,8 +277,8 @@ describe("OnboardingForm - Cloud Mode", () => {
expect(progressBars).toHaveLength(3);
});
it("should not render the Back button on the first step", async () => {
await renderOnboardingForm();
it("should not render the Back button on the first step", () => {
renderOnboardingForm();
const backButton = screen.queryByRole("button", { name: /back/i });
expect(backButton).not.toBeInTheDocument();
@@ -313,7 +286,7 @@ describe("OnboardingForm - Cloud Mode", () => {
it("should render the Back button on step 2", async () => {
const user = userEvent.setup();
await renderOnboardingForm();
renderOnboardingForm();
await user.click(screen.getByTestId("step-option-solo"));
await user.click(screen.getByRole("button", { name: /next/i }));
@@ -324,7 +297,7 @@ describe("OnboardingForm - Cloud Mode", () => {
it("should go back to the previous step when Back is clicked", async () => {
const user = userEvent.setup();
await renderOnboardingForm();
renderOnboardingForm();
// Navigate to step 2
await user.click(screen.getByTestId("step-option-solo"));
@@ -343,197 +316,3 @@ describe("OnboardingForm - Cloud Mode", () => {
expect(progressBars).toHaveLength(1);
});
});
describe("OnboardingForm - Self-Hosted Mode", () => {
// Self-hosted mode has 3 steps: org_name, org_size, use_case
// The role question is cloud-only and not shown in self-hosted mode
beforeEach(() => {
mockMutate.mockClear();
mockNavigate.mockClear();
mockTrackOnboardingCompleted.mockClear();
loaderData = {
config: {
app_mode: "saas",
feature_flags: { deployment_mode: "self_hosted" },
},
};
// Self-hosted mode only tracks org owners
mockUseMe.mockReturnValue({ data: { role: "owner" } });
});
it("should render with the correct test id", async () => {
await renderOnboardingForm();
expect(screen.getByTestId("onboarding-form")).toBeInTheDocument();
});
it("should display step progress indicator with 3 bars for self-hosted mode", async () => {
await renderOnboardingForm();
// Self-hosted has 3 steps: org_name, org_size, use_case (role is cloud-only)
const stepHeader = screen.getByTestId("step-header");
const progressBars = stepHeader.querySelectorAll(".rounded-full");
expect(progressBars).toHaveLength(3);
});
it("should start with org_name question as first step with two input fields", async () => {
await renderOnboardingForm();
// The first step in self-hosted mode should be org_name with two inputs
const orgNameInput = screen.getByTestId("form-input-org_name");
const orgDomainInput = screen.getByTestId("form-input-org_domain");
expect(orgNameInput).toBeInTheDocument();
expect(orgDomainInput).toBeInTheDocument();
});
it("should call submitOnboarding with all selections including org_name when finishing", async () => {
const user = userEvent.setup();
await renderOnboardingForm();
// Step 1 - enter org name and domain (input fields)
const orgNameInput = screen.getByTestId("form-input-org_name");
const orgDomainInput = screen.getByTestId("form-input-org_domain");
await user.type(orgNameInput, "Acme Corp");
await user.type(orgDomainInput, "acme.com");
await user.click(screen.getByRole("button", { name: /next/i }));
// Step 2 - select org size (single select)
await user.click(screen.getByTestId("step-option-org_2_10"));
await user.click(screen.getByRole("button", { name: /next/i }));
// Step 3 - select use case (multi-select) - this is the last step in self-hosted mode
await user.click(screen.getByTestId("step-option-new_features"));
await user.click(screen.getByRole("button", { name: /finish/i }));
expect(mockMutate).toHaveBeenCalledTimes(1);
expect(mockMutate).toHaveBeenCalledWith({
selections: {
org_name: "Acme Corp",
org_domain: "acme.com",
org_size: "org_2_10",
use_case: ["new_features"],
},
});
});
it("should track onboarding completion in self-hosted mode", async () => {
const user = userEvent.setup();
await renderOnboardingForm();
// Complete the full self-hosted onboarding flow (3 steps)
const orgNameInput = screen.getByTestId("form-input-org_name");
const orgDomainInput = screen.getByTestId("form-input-org_domain");
await user.type(orgNameInput, "Test Company");
await user.type(orgDomainInput, "test.com");
await user.click(screen.getByRole("button", { name: /next/i }));
await user.click(screen.getByTestId("step-option-org_2_10"));
await user.click(screen.getByRole("button", { name: /next/i }));
await user.click(screen.getByTestId("step-option-new_features"));
await user.click(screen.getByRole("button", { name: /finish/i }));
expect(mockTrackOnboardingCompleted).toHaveBeenCalledTimes(1);
// Note: role is not included since role question is cloud-only
expect(mockTrackOnboardingCompleted).toHaveBeenCalledWith({
role: undefined,
orgSize: "org_2_10",
useCase: ["new_features"],
});
});
it("should show all 3 progress bars filled on the last step", async () => {
const user = userEvent.setup();
await renderOnboardingForm();
// Navigate through all 3 steps
const orgNameInput = screen.getByTestId("form-input-org_name");
const orgDomainInput = screen.getByTestId("form-input-org_domain");
await user.type(orgNameInput, "Test Company");
await user.type(orgDomainInput, "test.com");
await user.click(screen.getByRole("button", { name: /next/i }));
await user.click(screen.getByTestId("step-option-org_2_10"));
await user.click(screen.getByRole("button", { name: /next/i }));
// On step 3, all three progress bars should be filled
const stepHeader = screen.getByTestId("step-header");
const progressBars = stepHeader.querySelectorAll(".bg-white");
expect(progressBars).toHaveLength(3);
});
it("should have Next button disabled when both org_name inputs are empty", async () => {
await renderOnboardingForm();
const nextButton = screen.getByRole("button", { name: /next/i });
expect(nextButton).toBeDisabled();
});
it("should enable Next button when both org_name and org_domain are entered", async () => {
const user = userEvent.setup();
await renderOnboardingForm();
const orgNameInput = screen.getByTestId("form-input-org_name");
const orgDomainInput = screen.getByTestId("form-input-org_domain");
await user.type(orgNameInput, "My Company");
await user.type(orgDomainInput, "mycompany.com");
const nextButton = screen.getByRole("button", { name: /next/i });
expect(nextButton).not.toBeDisabled();
});
it("should NOT track onboarding completion for non-owners in self-hosted mode", async () => {
// Override the mock to return a member (non-owner) role
mockUseMe.mockReturnValue({ data: { role: "member" } });
const user = userEvent.setup();
await renderOnboardingForm();
// Complete the full self-hosted onboarding flow (3 steps)
const orgNameInput = screen.getByTestId("form-input-org_name");
const orgDomainInput = screen.getByTestId("form-input-org_domain");
await user.type(orgNameInput, "Test Company");
await user.type(orgDomainInput, "test.com");
await user.click(screen.getByRole("button", { name: /next/i }));
await user.click(screen.getByTestId("step-option-org_2_10"));
await user.click(screen.getByRole("button", { name: /next/i }));
await user.click(screen.getByTestId("step-option-new_features"));
await user.click(screen.getByRole("button", { name: /finish/i }));
// Tracking should NOT be called for non-owners in self-hosted mode
expect(mockTrackOnboardingCompleted).not.toHaveBeenCalled();
// But onboarding submission should still work
expect(mockMutate).toHaveBeenCalledTimes(1);
});
it("should NOT track onboarding completion for admins in self-hosted mode", async () => {
// Override the mock to return an admin role
mockUseMe.mockReturnValue({ data: { role: "admin" } });
const user = userEvent.setup();
await renderOnboardingForm();
// Complete the full self-hosted onboarding flow (3 steps)
const orgNameInput = screen.getByTestId("form-input-org_name");
const orgDomainInput = screen.getByTestId("form-input-org_domain");
await user.type(orgNameInput, "Test Company");
await user.type(orgDomainInput, "test.com");
await user.click(screen.getByRole("button", { name: /next/i }));
await user.click(screen.getByTestId("step-option-org_2_10"));
await user.click(screen.getByRole("button", { name: /next/i }));
await user.click(screen.getByTestId("step-option-new_features"));
await user.click(screen.getByRole("button", { name: /finish/i }));
// Tracking should NOT be called for admins in self-hosted mode (only owners)
expect(mockTrackOnboardingCompleted).not.toHaveBeenCalled();
// But onboarding submission should still work
expect(mockMutate).toHaveBeenCalledTimes(1);
});
});
@@ -11,7 +11,6 @@ import OptionService from "#/api/option-service/option-service.api";
import { MOCK_DEFAULT_USER_SETTINGS } from "#/mocks/handlers";
import { WebClientConfig } from "#/api/option-service/option.types";
import { useSelectedOrganizationStore } from "#/stores/selected-organization-store";
import * as FeatureFlags from "#/utils/feature-flags";
// Helper to create mock config with sensible defaults
const createMockConfig = (
@@ -186,36 +185,4 @@ describe("Sidebar", () => {
});
});
});
describe("Automations button visibility", () => {
let enableAutomationsSpy: ReturnType<typeof vi.spyOn>;
beforeEach(() => {
enableAutomationsSpy = vi.spyOn(FeatureFlags, "ENABLE_AUTOMATIONS");
});
it("should show automations button when ENABLE_AUTOMATIONS flag is on", async () => {
enableAutomationsSpy.mockReturnValue(true);
renderSidebar();
await waitFor(() => {
expect(
screen.getByTestId("automations-button"),
).toBeInTheDocument();
});
});
it("should hide automations button when ENABLE_AUTOMATIONS flag is off", async () => {
enableAutomationsSpy.mockReturnValue(false);
renderSidebar();
await waitFor(() => expect(getSettingsSpy).toHaveBeenCalled());
expect(
screen.queryByTestId("automations-button"),
).not.toBeInTheDocument();
});
});
});
@@ -23,6 +23,12 @@ vi.mock("#/hooks/use-breakpoint", () => ({
useBreakpoint: vi.fn(() => false), // Default to desktop (not mobile)
}));
// Mock feature flags
const mockEnableProjUserJourney = vi.fn(() => true);
vi.mock("#/utils/feature-flags", () => ({
ENABLE_PROJ_USER_JOURNEY: () => mockEnableProjUserJourney(),
}));
// Mock useTracking hook for CTA
vi.mock("#/hooks/use-tracking", () => ({
useTracking: () => ({
@@ -138,8 +144,9 @@ describe("UserContextMenu", () => {
// Ensure clean state at the start of each test
vi.restoreAllMocks();
useSelectedOrganizationStore.setState({ organizationId: null });
// Reset breakpoint mock to desktop by default
vi.mocked(breakpoint.useBreakpoint).mockReturnValue(false);
// Reset feature flag and breakpoint mocks to defaults
mockEnableProjUserJourney.mockReturnValue(true);
vi.mocked(breakpoint.useBreakpoint).mockReturnValue(false); // Desktop by default
});
afterEach(() => {
@@ -743,26 +750,15 @@ describe("UserContextMenu", () => {
});
describe("Context Menu CTA", () => {
it("should render the CTA component in SaaS Cloud mode on desktop", async () => {
it("should render the CTA component in SaaS mode on desktop with feature flag enabled", async () => {
// Set SaaS mode
vi.spyOn(OptionService, "getConfig").mockResolvedValue(
createMockWebClientConfig({
app_mode: "saas",
feature_flags: {
deployment_mode: "cloud",
enable_billing: false,
hide_llm_settings: false,
enable_jira: false,
enable_jira_dc: false,
enable_linear: false,
hide_users_page: false,
hide_billing_page: false,
hide_integrations_page: false,
},
}),
createMockWebClientConfig({ app_mode: "saas" }),
);
renderUserContextMenu({ type: "member", onClose: vi.fn, onOpenInviteModal: vi.fn });
// Wait for config to load
await waitFor(() => {
expect(screen.getByTestId("context-menu-cta")).toBeInTheDocument();
});
@@ -770,73 +766,59 @@ describe("UserContextMenu", () => {
expect(screen.getByText("CTA$LEARN_MORE")).toBeInTheDocument();
});
it("should not render the CTA component in OSS mode", async () => {
it("should not render the CTA component in OSS mode even with feature flag enabled", async () => {
// Set OSS mode
vi.spyOn(OptionService, "getConfig").mockResolvedValue(
createMockWebClientConfig({ app_mode: "oss" }),
);
renderUserContextMenu({ type: "member", onClose: vi.fn, onOpenInviteModal: vi.fn });
// Wait for config to load
await waitFor(() => {
expect(screen.getByTestId("user-context-menu")).toBeInTheDocument();
});
expect(screen.queryByTestId("context-menu-cta")).not.toBeInTheDocument();
expect(screen.queryByText("CTA$ENTERPRISE_TITLE")).not.toBeInTheDocument();
});
it("should not render the CTA component on mobile even in SaaS Cloud mode", async () => {
it("should not render the CTA component on mobile even in SaaS mode with feature flag enabled", async () => {
// Set SaaS mode
vi.spyOn(OptionService, "getConfig").mockResolvedValue(
createMockWebClientConfig({
app_mode: "saas",
feature_flags: {
deployment_mode: "cloud",
enable_billing: false,
hide_llm_settings: false,
enable_jira: false,
enable_jira_dc: false,
enable_linear: false,
hide_users_page: false,
hide_billing_page: false,
hide_integrations_page: false,
},
}),
createMockWebClientConfig({ app_mode: "saas" }),
);
// Set mobile mode
vi.mocked(breakpoint.useBreakpoint).mockReturnValue(true);
renderUserContextMenu({ type: "member", onClose: vi.fn, onOpenInviteModal: vi.fn });
// Wait for config to load
await waitFor(() => {
expect(screen.getByTestId("user-context-menu")).toBeInTheDocument();
});
expect(screen.queryByTestId("context-menu-cta")).not.toBeInTheDocument();
expect(screen.queryByText("CTA$ENTERPRISE_TITLE")).not.toBeInTheDocument();
});
it("should not render the CTA component in SaaS Self-hosted mode", async () => {
it("should not render the CTA component when feature flag is disabled in SaaS mode", async () => {
// Set SaaS mode
vi.spyOn(OptionService, "getConfig").mockResolvedValue(
createMockWebClientConfig({
app_mode: "saas",
feature_flags: {
deployment_mode: "self_hosted",
enable_billing: false,
hide_llm_settings: false,
enable_jira: false,
enable_jira_dc: false,
enable_linear: false,
hide_users_page: false,
hide_billing_page: false,
hide_integrations_page: false,
},
}),
createMockWebClientConfig({ app_mode: "saas" }),
);
// Disable the feature flag
mockEnableProjUserJourney.mockReturnValue(false);
renderUserContextMenu({ type: "member", onClose: vi.fn, onOpenInviteModal: vi.fn });
// Wait for config to load
await waitFor(() => {
expect(screen.getByTestId("user-context-menu")).toBeInTheDocument();
});
expect(screen.queryByTestId("context-menu-cta")).not.toBeInTheDocument();
expect(screen.queryByText("CTA$ENTERPRISE_TITLE")).not.toBeInTheDocument();
});
});
});
@@ -1,45 +1,12 @@
import { describe, it, expect, vi } from "vitest";
import { render, screen, waitFor } from "@testing-library/react";
import { render, screen } from "@testing-library/react";
import userEvent from "@testing-library/user-event";
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
import { ModelSelector } from "#/components/shared/modals/settings/model-selector";
import type { LLMProvider, LLMModel } from "#/api/config-service/config-service.types";
const mockProviders: LLMProvider[] = [
{ name: "openai", verified: true },
{ name: "azure", verified: false },
{ name: "vertex_ai", verified: false },
];
const mockModelsByProvider: Record<string, LLMModel[]> = {
openai: [
{ provider: "openai", name: "gpt-4o", verified: true },
{ provider: "openai", name: "gpt-4o-mini", verified: true },
],
azure: [
{ provider: "azure", name: "ada", verified: false },
{ provider: "azure", name: "gpt-35-turbo", verified: false },
],
vertex_ai: [
{ provider: "vertex_ai", name: "chat-bison", verified: false },
{ provider: "vertex_ai", name: "chat-bison-32k", verified: false },
],
};
vi.mock("#/hooks/query/use-search-providers", () => ({
useSearchProviders: () => ({ data: mockProviders }),
}));
vi.mock("#/hooks/query/use-provider-models", () => ({
useProviderModels: (provider: string | null) => ({
data: provider ? (mockModelsByProvider[provider] ?? []) : [],
}),
}));
vi.mock("react-i18next", () => ({
useTranslation: () => ({
t: (key: string) => {
const translations: Record<string, string> = {
const translations: { [key: string]: string } = {
LLM$PROVIDER: "LLM Provider",
LLM$MODEL: "LLM Model",
LLM$SELECT_PROVIDER_PLACEHOLDER: "Select a provider",
@@ -50,19 +17,34 @@ vi.mock("react-i18next", () => ({
}),
}));
function renderWithQuery(ui: React.ReactElement) {
const queryClient = new QueryClient({
defaultOptions: { queries: { retry: false } },
});
return render(
<QueryClientProvider client={queryClient}>{ui}</QueryClientProvider>,
);
}
describe("ModelSelector", () => {
const models = {
openai: {
separator: "/",
models: ["gpt-4o", "gpt-4o-mini"],
},
azure: {
separator: "/",
models: ["ada", "gpt-35-turbo"],
},
vertex_ai: {
separator: "/",
models: ["chat-bison", "chat-bison-32k"],
},
};
const verifiedModels = ["gpt-4o", "gpt-4o-mini"];
const verifiedProviders = ["openai"];
it("should display the provider selector", async () => {
const user = userEvent.setup();
renderWithQuery(<ModelSelector />);
render(
<ModelSelector
models={models}
verifiedModels={verifiedModels}
verifiedProviders={verifiedProviders}
/>,
);
const selector = screen.getByLabelText("LLM Provider");
expect(selector).toBeInTheDocument();
@@ -76,7 +58,13 @@ describe("ModelSelector", () => {
it("should disable the model selector if the provider is not selected", async () => {
const user = userEvent.setup();
renderWithQuery(<ModelSelector />);
render(
<ModelSelector
models={models}
verifiedModels={verifiedModels}
verifiedProviders={verifiedProviders}
/>,
);
const modelSelector = screen.getByLabelText("LLM Model");
expect(modelSelector).toBeDisabled();
@@ -92,7 +80,13 @@ describe("ModelSelector", () => {
it("should display the model selector", async () => {
const user = userEvent.setup();
renderWithQuery(<ModelSelector />);
render(
<ModelSelector
models={models}
verifiedModels={verifiedModels}
verifiedProviders={verifiedProviders}
/>,
);
const providerSelector = screen.getByLabelText("LLM Provider");
await user.click(providerSelector);
@@ -111,7 +105,14 @@ describe("ModelSelector", () => {
const user = userEvent.setup();
const onChange = vi.fn();
renderWithQuery(<ModelSelector onChange={onChange} />);
render(
<ModelSelector
models={models}
verifiedModels={verifiedModels}
verifiedProviders={verifiedProviders}
onChange={onChange}
/>,
);
const providerSelector = screen.getByLabelText("LLM Provider");
await user.click(providerSelector);
@@ -125,12 +126,18 @@ describe("ModelSelector", () => {
expect(onChange).toHaveBeenNthCalledWith(2, "azure", "ada");
});
it("should have a default value if passed", async () => {
renderWithQuery(<ModelSelector currentModel="azure/ada" />);
await waitFor(() => {
expect(screen.getByLabelText("LLM Provider")).toHaveValue("Azure");
expect(screen.getByLabelText("LLM Model")).toHaveValue("ada");
});
it("should have a default value if passed", async () => {
render(
<ModelSelector
models={models}
verifiedModels={verifiedModels}
verifiedProviders={verifiedProviders}
currentModel="azure/ada"
/>,
);
expect(screen.getByLabelText("LLM Provider")).toHaveValue("Azure");
expect(screen.getByLabelText("LLM Model")).toHaveValue("ada");
});
});
@@ -1,32 +0,0 @@
import { createEvent, fireEvent, render, screen } from "@testing-library/react";
import { describe, expect, it } from "vitest";
import { AutomationsButton } from "#/components/shared/buttons/automations-button";
describe("AutomationsButton", () => {
it("should render a link to /automations", () => {
render(<AutomationsButton />);
const link = screen.getByTestId("automations-button");
expect(link).toBeInTheDocument();
expect(link).toHaveAttribute("href", "/automations");
});
it("should be focusable and accessible when enabled", () => {
render(<AutomationsButton />);
const link = screen.getByTestId("automations-button");
expect(link).toHaveAttribute("tabIndex", "0");
expect(link).toHaveAttribute("aria-label", "SIDEBAR$AUTOMATIONS");
});
it("should prevent navigation and remove from tab order when disabled", () => {
render(<AutomationsButton disabled />);
const link = screen.getByTestId("automations-button");
expect(link).toHaveAttribute("tabIndex", "-1");
const clickEvent = createEvent.click(link);
fireEvent(link, clickEvent);
expect(clickEvent.defaultPrevented).toBe(true);
});
});
@@ -14,7 +14,13 @@ describe("SettingsForm", () => {
const RouteStub = createRoutesStub([
{
Component: () => (
<SettingsForm settings={DEFAULT_SETTINGS} onClose={onCloseMock} />
<SettingsForm
settings={DEFAULT_SETTINGS}
models={[DEFAULT_SETTINGS.llm_model]}
verifiedModels={[]}
verifiedProviders={["openhands"]}
onClose={onCloseMock}
/>
),
path: "/",
},
@@ -5,21 +5,12 @@ import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
import { createRoutesStub } from "react-router";
import DeviceVerify from "#/routes/device-verify";
const { useIsAuthedMock, mockUseAppMode } = vi.hoisted(() => ({
const { useIsAuthedMock, ENABLE_PROJ_USER_JOURNEY_MOCK } = vi.hoisted(() => ({
useIsAuthedMock: vi.fn(() => ({
data: false as boolean | undefined,
isLoading: false,
})),
mockUseAppMode: vi.fn(() => ({
isOss: false,
isSaas: true,
isCloud: true,
isSelfHosted: false,
isEnterpriseSelfHosted: false,
isEnterpriseCloud: true,
appMode: "saas" as string | undefined,
deploymentMode: "cloud" as string | undefined,
})),
ENABLE_PROJ_USER_JOURNEY_MOCK: vi.fn(() => true),
}));
vi.mock("#/hooks/query/use-is-authed", () => ({
@@ -32,8 +23,8 @@ vi.mock("posthog-js/react", () => ({
}),
}));
vi.mock("#/hooks/use-app-mode", () => ({
useAppMode: () => mockUseAppMode(),
vi.mock("#/utils/feature-flags", () => ({
ENABLE_PROJ_USER_JOURNEY: () => ENABLE_PROJ_USER_JOURNEY_MOCK(),
}));
const RouterStub = createRoutesStub([
@@ -75,17 +66,8 @@ describe("DeviceVerify", () => {
}),
),
);
// Reset useAppMode to SaaS Cloud (CTA enabled) by default
mockUseAppMode.mockReturnValue({
isOss: false,
isSaas: true,
isCloud: true,
isSelfHosted: false,
isEnterpriseSelfHosted: false,
isEnterpriseCloud: true,
appMode: "saas",
deploymentMode: "cloud",
});
// Enable feature flag by default
ENABLE_PROJ_USER_JOURNEY_MOCK.mockReturnValue(true);
});
afterEach(() => {
@@ -253,17 +235,7 @@ describe("DeviceVerify", () => {
});
});
it("should include the LoginCTA component when in SaaS Cloud mode", async () => {
mockUseAppMode.mockReturnValue({
isOss: false,
isSaas: true,
isCloud: true,
isSelfHosted: false,
isEnterpriseSelfHosted: false,
isEnterpriseCloud: true,
appMode: "saas",
deploymentMode: "cloud",
});
it("should include the LoginCTA component when feature flag is enabled", async () => {
useIsAuthedMock.mockReturnValue({
data: true,
isLoading: false,
@@ -281,45 +253,8 @@ describe("DeviceVerify", () => {
});
});
it("should not include the LoginCTA component when in OSS mode", async () => {
mockUseAppMode.mockReturnValue({
isOss: true,
isSaas: false,
isCloud: false,
isSelfHosted: false,
isEnterpriseSelfHosted: false,
isEnterpriseCloud: false,
appMode: "oss",
deploymentMode: undefined,
});
useIsAuthedMock.mockReturnValue({
data: true,
isLoading: false,
});
render(
<RouterStub initialEntries={["/device-verify?user_code=ABC-123"]} />,
{
wrapper: createWrapper(),
},
);
await waitFor(() => {
expect(screen.queryByTestId("login-cta")).not.toBeInTheDocument();
});
});
it("should not include the LoginCTA and be center-aligned when in SaaS Self-hosted mode", async () => {
mockUseAppMode.mockReturnValue({
isOss: false,
isSaas: true,
isCloud: false,
isSelfHosted: true,
isEnterpriseSelfHosted: true,
isEnterpriseCloud: false,
appMode: "saas",
deploymentMode: "self_hosted",
});
it("should not include the LoginCTA and be center-aligned when feature flag is disabled", async () => {
ENABLE_PROJ_USER_JOURNEY_MOCK.mockReturnValue(false);
useIsAuthedMock.mockReturnValue({
data: true,
isLoading: false,
+46 -77
View File
@@ -13,7 +13,7 @@ import AuthService from "#/api/auth-service/auth-service.api";
import MainApp from "#/routes/root-layout";
import { MOCK_DEFAULT_USER_SETTINGS } from "#/mocks/handlers";
const { DEFAULT_FEATURE_FLAGS, useIsAuthedMock, useConfigMock, mockUseAppMode } = vi.hoisted(
const { DEFAULT_FEATURE_FLAGS, useIsAuthedMock, useConfigMock } = vi.hoisted(
() => {
const defaultFeatureFlags = {
enable_billing: false,
@@ -41,16 +41,6 @@ const { DEFAULT_FEATURE_FLAGS, useIsAuthedMock, useConfigMock, mockUseAppMode }
},
isLoading: false,
}),
mockUseAppMode: vi.fn().mockReturnValue({
isOss: true,
isSaas: false,
isCloud: false,
isSelfHosted: false,
isEnterpriseSelfHosted: false,
isEnterpriseCloud: false,
appMode: "oss",
deploymentMode: undefined,
}),
};
},
);
@@ -63,19 +53,6 @@ vi.mock("#/hooks/query/use-config", () => ({
useConfig: () => useConfigMock(),
}));
vi.mock("#/hooks/use-app-mode", () => ({
useAppMode: () => mockUseAppMode() ?? {
isOss: true,
isSaas: false,
isCloud: false,
isSelfHosted: false,
isEnterpriseSelfHosted: false,
isEnterpriseCloud: false,
appMode: "oss",
deploymentMode: undefined,
},
}));
const RouterStub = createRoutesStub([
{
Component: MainApp,
@@ -251,17 +228,20 @@ describe("HomeScreen", () => {
"retrieveUserGitRepositories",
);
retrieveUserGitRepositoriesSpy.mockResolvedValue({
items: MOCK_RESPOSITORIES,
next_page_id: null,
data: MOCK_RESPOSITORIES,
nextPage: null,
});
// Mock the repository branches API call
vi.spyOn(GitService, "getRepositoryBranches").mockResolvedValue({
items: [
branches: [
{ name: "main", commit_sha: "123", protected: false },
{ name: "develop", commit_sha: "456", protected: false },
],
next_page_id: null,
has_next_page: false,
current_page: 1,
per_page: 30,
total_count: 2,
});
renderHomeScreen();
@@ -292,17 +272,20 @@ describe("HomeScreen", () => {
"retrieveUserGitRepositories",
);
retrieveUserGitRepositoriesSpy.mockResolvedValue({
items: MOCK_RESPOSITORIES,
next_page_id: null,
data: MOCK_RESPOSITORIES,
nextPage: null,
});
// Mock the repository branches API call
vi.spyOn(GitService, "getRepositoryBranches").mockResolvedValue({
items: [
branches: [
{ name: "main", commit_sha: "123", protected: false },
{ name: "develop", commit_sha: "456", protected: false },
],
next_page_id: null,
has_next_page: false,
current_page: 1,
per_page: 30,
total_count: 2,
});
renderHomeScreen();
@@ -349,11 +332,14 @@ describe("HomeScreen", () => {
// Mock the repository branches API call
vi.spyOn(GitService, "getRepositoryBranches").mockResolvedValue({
items: [
branches: [
{ name: "main", commit_sha: "123", protected: false },
{ name: "develop", commit_sha: "456", protected: false },
],
next_page_id: null,
has_next_page: false,
current_page: 1,
per_page: 30,
total_count: 2,
});
// Select a repository to enable the repo launch button
@@ -385,8 +371,8 @@ describe("HomeScreen", () => {
"retrieveUserGitRepositories",
);
retrieveUserGitRepositoriesSpy.mockResolvedValue({
items: MOCK_RESPOSITORIES,
next_page_id: null,
data: MOCK_RESPOSITORIES,
nextPage: null,
});
});
@@ -635,9 +621,14 @@ describe("HomepageCTA visibility", () => {
getSettingsSpy.mockResolvedValue(MOCK_DEFAULT_USER_SETTINGS);
// Mock localStorage for CTA dismissal
// Mock localStorage to enable the PROJ_USER_JOURNEY feature flag (CTA dismissal also uses localStorage)
vi.stubGlobal("localStorage", {
getItem: vi.fn(() => null),
getItem: vi.fn((key: string) => {
if (key === "FEATURE_PROJ_USER_JOURNEY") {
return "true";
}
return null;
}),
setItem: vi.fn(),
removeItem: vi.fn(),
clear: vi.fn(),
@@ -649,7 +640,7 @@ describe("HomepageCTA visibility", () => {
vi.unstubAllGlobals();
});
it("should show HomepageCTA in SaaS Cloud mode when not dismissed", async () => {
it("should show HomepageCTA in SaaS mode when not dismissed and feature flag enabled", async () => {
useIsAuthedMock.mockReturnValue({
data: true,
isLoading: false,
@@ -657,26 +648,16 @@ describe("HomepageCTA visibility", () => {
isError: false,
});
useConfigMock.mockReturnValue({
data: { app_mode: "saas", feature_flags: { ...DEFAULT_FEATURE_FLAGS, deployment_mode: "cloud" } },
data: { app_mode: "saas", feature_flags: DEFAULT_FEATURE_FLAGS },
isLoading: false,
});
mockUseAppMode.mockReturnValue({
isOss: false,
isSaas: true,
isCloud: true,
isSelfHosted: false,
isEnterpriseSelfHosted: false,
isEnterpriseCloud: true,
appMode: "saas",
deploymentMode: "cloud",
});
getConfigSpy.mockResolvedValue({
app_mode: "saas",
posthog_client_key: "test-posthog-key",
providers_configured: ["github"],
auth_url: "https://auth.example.com",
feature_flags: { ...DEFAULT_FEATURE_FLAGS, deployment_mode: "cloud" },
feature_flags: DEFAULT_FEATURE_FLAGS,
maintenance_start_time: null,
recaptcha_site_key: null,
faulty_models: [],
@@ -693,7 +674,7 @@ describe("HomepageCTA visibility", () => {
expect(ctaLink).toBeInTheDocument();
});
it("should not show HomepageCTA in OSS mode", async () => {
it("should not show HomepageCTA in OSS mode even with feature flag enabled", async () => {
useIsAuthedMock.mockReturnValue({
data: true,
isLoading: false,
@@ -704,16 +685,6 @@ describe("HomepageCTA visibility", () => {
data: { app_mode: "oss", feature_flags: DEFAULT_FEATURE_FLAGS },
isLoading: false,
});
mockUseAppMode.mockReturnValue({
isOss: true,
isSaas: false,
isCloud: false,
isSelfHosted: false,
isEnterpriseSelfHosted: false,
isEnterpriseCloud: false,
appMode: "oss",
deploymentMode: undefined,
});
getConfigSpy.mockResolvedValue({
app_mode: "oss",
@@ -733,10 +704,18 @@ describe("HomepageCTA visibility", () => {
await screen.findByTestId("home-screen");
expect(screen.queryByTestId("homepage-cta-learn-more")).not.toBeInTheDocument();
expect(screen.queryByText("CTA$ENTERPRISE_TITLE")).not.toBeInTheDocument();
});
it("should not show HomepageCTA in SaaS Self-hosted mode", async () => {
it("should not show HomepageCTA when feature flag is disabled", async () => {
// Override localStorage to disable the feature flag
vi.stubGlobal("localStorage", {
getItem: vi.fn(() => null), // No feature flags set
setItem: vi.fn(),
removeItem: vi.fn(),
clear: vi.fn(),
});
useIsAuthedMock.mockReturnValue({
data: true,
isLoading: false,
@@ -744,26 +723,16 @@ describe("HomepageCTA visibility", () => {
isError: false,
});
useConfigMock.mockReturnValue({
data: { app_mode: "saas", feature_flags: { ...DEFAULT_FEATURE_FLAGS, deployment_mode: "self_hosted" } },
data: { app_mode: "saas", feature_flags: DEFAULT_FEATURE_FLAGS },
isLoading: false,
});
mockUseAppMode.mockReturnValue({
isOss: false,
isSaas: true,
isCloud: false,
isSelfHosted: true,
isEnterpriseSelfHosted: true,
isEnterpriseCloud: false,
appMode: "saas",
deploymentMode: "self_hosted",
});
getConfigSpy.mockResolvedValue({
app_mode: "saas",
posthog_client_key: "test-posthog-key",
providers_configured: ["github"],
auth_url: "https://auth.example.com",
feature_flags: { ...DEFAULT_FEATURE_FLAGS, deployment_mode: "self_hosted" },
feature_flags: DEFAULT_FEATURE_FLAGS,
maintenance_start_time: null,
recaptcha_site_key: null,
faulty_models: [],
@@ -776,7 +745,7 @@ describe("HomepageCTA visibility", () => {
await screen.findByTestId("home-screen");
expect(screen.queryByTestId("homepage-cta-learn-more")).not.toBeInTheDocument();
expect(screen.queryByText("CTA$ENTERPRISE_TITLE")).not.toBeInTheDocument();
});
});
+3 -12
View File
@@ -73,18 +73,9 @@ vi.mock("#/hooks/use-invitation", () => ({
useInvitation: () => useInvitationMock(),
}));
// Mock useAppMode hook - enable CTA by default (SaaS Cloud mode)
vi.mock("#/hooks/use-app-mode", () => ({
useAppMode: () => ({
isOss: false,
isSaas: true,
isCloud: true,
isSelfHosted: false,
isEnterpriseSelfHosted: false,
isEnterpriseCloud: true,
appMode: "saas",
deploymentMode: "cloud",
}),
// Mock feature flags - enable by default for tests
vi.mock("#/utils/feature-flags", () => ({
ENABLE_PROJ_USER_JOURNEY: () => true,
}));
const RouterStub = createRoutesStub([
@@ -0,0 +1,9 @@
import { test, expect } from "vitest";
import { isNumber } from "../../src/utils/is-number";
test("isNumber", () => {
expect(isNumber(1)).toBe(true);
expect(isNumber(0)).toBe(true);
expect(isNumber("3")).toBe(true);
expect(isNumber("0")).toBe(true);
});
@@ -38,7 +38,6 @@ vi.mock("#/query-client-config", () => ({
queryClient: {
getQueryData: vi.fn(() => mockConfig),
setQueryData: vi.fn(),
fetchQuery: vi.fn(() => Promise.resolve(mockConfig)),
},
}));
+24
View File
@@ -1,4 +1,9 @@
import { describe, it, expect, vi, test } from "vitest";
import {
formatTimestamp,
getExtension,
removeApiKey,
} from "../../src/utils/utils";
import { getStatusText } from "#/utils/utils";
import { AgentState } from "#/types/agent-state";
import { I18nKey } from "#/i18n/declaration";
@@ -17,6 +22,25 @@ const t = (key: string) => {
return translations[key] || key;
};
test("removeApiKey", () => {
const data = [{ args: { LLM_API_KEY: "key", LANGUAGE: "en" } }];
expect(removeApiKey(data)).toEqual([{ args: { LANGUAGE: "en" } }]);
});
test("getExtension", () => {
expect(getExtension("main.go")).toBe("go");
expect(getExtension("get-extension.test.ts")).toBe("ts");
expect(getExtension("directory")).toBe("");
});
test("formatTimestamp", () => {
const morningDate = new Date("2021-10-10T10:10:10.000").toISOString();
expect(formatTimestamp(morningDate)).toBe("10/10/2021, 10:10:10");
const eveningDate = new Date("2021-10-10T22:10:10.000").toISOString();
expect(formatTimestamp(eveningDate)).toBe("10/10/2021, 22:10:10");
});
describe("getStatusText", () => {
it("returns STOPPING when pausing", () => {
const result = getStatusText({
@@ -1,43 +0,0 @@
import { openHands } from "../open-hands-axios";
import type {
LLMModelPage,
ProviderPage,
SearchModelsParams,
SearchProvidersParams,
} from "./config-service.types";
function toSearchParams(
params: SearchModelsParams | SearchProvidersParams,
): string {
const searchParams = new URLSearchParams();
for (const [key, value] of Object.entries(params)) {
if (value !== undefined && value !== null) {
searchParams.append(key, String(value));
}
}
return searchParams.toString();
}
class ConfigService {
static async searchModels(
params: SearchModelsParams = {},
): Promise<LLMModelPage> {
const qs = toSearchParams(params);
const { data } = await openHands.get<LLMModelPage>(
`/api/v1/config/models/search?${qs}`,
);
return data;
}
static async searchProviders(
params: SearchProvidersParams = {},
): Promise<ProviderPage> {
const qs = toSearchParams(params);
const { data } = await openHands.get<ProviderPage>(
`/api/v1/config/providers/search?${qs}`,
);
return data;
}
}
export default ConfigService;
@@ -1,37 +0,0 @@
/** V1 Config API types for models and providers */
export interface LLMModel {
provider: string | null;
name: string;
verified: boolean;
}
export interface LLMModelPage {
items: LLMModel[];
next_page_id: string | null;
}
export interface SearchModelsParams {
page_id?: string;
limit?: number;
query?: string;
verified__eq?: boolean;
provider__eq?: string;
}
export interface LLMProvider {
name: string;
verified: boolean;
}
export interface ProviderPage {
items: LLMProvider[];
next_page_id: string | null;
}
export interface SearchProvidersParams {
page_id?: string;
limit?: number;
query?: string;
verified__eq?: boolean;
}
+94 -95
View File
@@ -1,5 +1,7 @@
import { openHands } from "../open-hands-axios";
import { RepositoryPage, BranchPage, InstallationPage } from "#/types/git";
import { Provider } from "#/types/settings";
import { GitRepository, PaginatedBranchesResponse, Branch } from "#/types/git";
import { extractNextPageFromLink } from "#/utils/extract-next-page-from-link";
import { GitChange, GitChangeDiff } from "../open-hands.types";
import ConversationService from "../conversation-service/conversation-service.api";
@@ -10,122 +12,130 @@ class GitService {
/**
* Search for Git repositories
* @param query Search query
* @param provider Git provider to search in (required)
* @param limit Number of results per page
* @param pageId Cursor for pagination
* @param installationId Filter by installation ID
* @param sortOrder Sort order (asc or desc)
* @returns Paginated repository response
* @param per_page Number of results per page
* @param selected_provider Git provider to search in
* @returns List of matching repositories
*/
static async searchGitRepositories(
query: string,
provider: string,
limit = 100,
pageId?: string,
installationId?: string,
): Promise<RepositoryPage> {
const { data } = await openHands.get<RepositoryPage>(
"/api/v1/git/repositories/search",
per_page = 100,
selected_provider?: Provider,
): Promise<GitRepository[]> {
const response = await openHands.get<GitRepository[]>(
"/api/user/search/repositories",
{
params: {
provider,
query,
limit,
page_id: pageId,
installation_id: installationId,
per_page,
selected_provider,
},
},
);
return data;
return response.data;
}
/**
* Retrieve user's Git repositories
* @param provider Git provider
* @param pageId Cursor for pagination
* @param limit Number of results per page
* @param installationId Filter by installation ID
* @param sortOrder Sort order (asc or desc)
* @param selected_provider Git provider
* @param page Page number
* @param per_page Number of results per page
* @returns User's repositories with pagination info
*/
static async retrieveUserGitRepositories(
provider: string,
pageId?: string,
limit = 30,
installationId?: string,
): Promise<RepositoryPage> {
const { data } = await openHands.get<RepositoryPage>(
"/api/v1/git/repositories/search",
selected_provider: Provider,
page = 1,
per_page = 30,
) {
const { data } = await openHands.get<GitRepository[]>(
"/api/user/repositories",
{
params: {
provider,
limit,
page_id: pageId,
installation_id: installationId,
selected_provider,
sort: "pushed",
page,
per_page,
},
},
);
return data;
const link =
data.length > 0 && data[0].link_header ? data[0].link_header : "";
const nextPage = extractNextPageFromLink(link);
return { data, nextPage };
}
/**
* Retrieve repositories from a specific installation
* @param provider Git provider
* @param selected_provider Git provider
* @param installationIndex Current installation index
* @param installations List of installation IDs
* @param pageId Cursor for pagination
* @param limit Number of results per page
* @param page Page number
* @param per_page Number of results per page
* @returns Installation repositories with pagination info
*/
static async retrieveInstallationRepositories(
provider: string,
selected_provider: Provider,
installationIndex: number,
installations: string[],
pageId?: string,
limit = 30,
): Promise<RepositoryPage> {
page = 1,
per_page = 30,
) {
const installationId = installations[installationIndex];
const { data } = await openHands.get<RepositoryPage>(
"/api/v1/git/repositories/search",
const response = await openHands.get<GitRepository[]>(
"/api/user/repositories",
{
params: {
provider,
limit,
page_id: pageId,
selected_provider,
sort: "pushed",
page,
per_page,
installation_id: installationId,
},
},
);
return data;
const link =
response.data.length > 0 && response.data[0].link_header
? response.data[0].link_header
: "";
const nextPage = extractNextPageFromLink(link);
let nextInstallation: number | null;
if (nextPage) {
nextInstallation = installationIndex;
} else if (installationIndex + 1 < installations.length) {
nextInstallation = installationIndex + 1;
} else {
nextInstallation = null;
}
return {
data: response.data,
nextPage,
installationIndex: nextInstallation,
};
}
/**
* Get repository branches
* @param repository Repository name
* @param provider Git provider (required)
* @param query Search query (required - can be empty string)
* @param pageId Cursor for pagination
* @param limit Number of results per page
* @param page Page number
* @param perPage Number of results per page
* @returns Paginated branches response
*/
static async getRepositoryBranches(
repository: string,
provider: string,
query: string = "",
pageId?: string,
limit = 30,
): Promise<BranchPage> {
const { data } = await openHands.get<BranchPage>(
"/api/v1/git/branches/search",
page: number = 1,
perPage: number = 30,
selectedProvider?: Provider,
): Promise<PaginatedBranchesResponse> {
const { data } = await openHands.get<PaginatedBranchesResponse>(
`/api/user/repository/branches`,
{
params: {
provider,
repository,
query,
page_id: pageId,
limit,
page,
per_page: perPage,
selected_provider: selectedProvider,
},
},
);
@@ -135,51 +145,40 @@ class GitService {
/**
* Search repository branches
* @deprecated Use getRepositoryBranches instead - this method is identical
* @param repository Repository name
* @param provider Git provider (required)
* @param query Search query
* @param pageId Cursor for pagination
* @param limit Number of results per page
* @param perPage Number of results per page
* @param selectedProvider Git provider
* @returns List of matching branches
*/
static async searchRepositoryBranches(
repository: string,
provider: string,
query: string,
pageId?: string,
limit = 30,
): Promise<BranchPage> {
return this.getRepositoryBranches(
repository,
provider,
query,
pageId,
limit,
perPage: number = 30,
selectedProvider?: Provider,
): Promise<Branch[]> {
const { data } = await openHands.get<Branch[]>(
`/api/user/search/branches`,
{
params: {
repository,
query,
per_page: perPage,
selected_provider: selectedProvider,
},
},
);
return data;
}
/**
* Get the user installation IDs
* @param provider The provider to get installation IDs for (github, bitbucket, etc.)
* @param pageId Cursor for pagination
* @param limit Max number of results
* @returns Paginated installation response
* @returns List of installation IDs
*/
static async getUserInstallations(
provider: string,
pageId?: string,
limit = 100,
): Promise<InstallationPage> {
const { data } = await openHands.get<InstallationPage>(
"/api/v1/git/installations/search",
{
params: {
provider,
page_id: pageId,
limit,
},
},
static async getUserInstallationIds(provider: Provider): Promise<string[]> {
const { data } = await openHands.get<string[]>(
`/api/user/installations?provider=${provider}`,
);
return data;
}
@@ -1,7 +1,5 @@
import { Provider } from "#/types/settings";
export type DeploymentMode = "cloud" | "self_hosted";
/**
* Structured response from ``GET /api/options/models``.
*
@@ -28,7 +26,6 @@ export interface WebClientFeatureFlags {
hide_users_page: boolean;
hide_billing_page: boolean;
hide_integrations_page: boolean;
deployment_mode?: DeploymentMode;
}
export interface WebClientConfig {
@@ -9,7 +9,7 @@ class SettingsService {
* Get the settings from the server or use the default settings if not found
*/
static async getSettings(): Promise<Settings> {
const { data } = await openHands.get<Settings>("/api/v1/settings");
const { data } = await openHands.get<Settings>("/api/settings");
return data;
}
@@ -18,7 +18,7 @@ class SettingsService {
* @param settings - the settings to save
*/
static async saveSettings(settings: Partial<Settings>): Promise<boolean> {
const data = await openHands.post("/api/v1/settings", settings);
const data = await openHands.post("/api/settings", settings);
return data.status === 200;
}
}
@@ -14,8 +14,8 @@ import { useRecaptcha } from "#/hooks/use-recaptcha";
import { useConfig } from "#/hooks/query/use-config";
import { displayErrorToast } from "#/utils/custom-toast-handlers";
import { cn } from "#/utils/utils";
import { ENABLE_PROJ_USER_JOURNEY } from "#/utils/feature-flags";
import { LoginCTA } from "./login-cta";
import { useAppMode } from "#/hooks/use-app-mode";
export interface LoginContentProps {
githubAuthUrl: string | null;
@@ -45,7 +45,6 @@ export function LoginContent({
const { t } = useTranslation();
const { trackLoginButtonClick } = useTracking();
const { data: config } = useConfig();
const { isEnterpriseCloud } = useAppMode();
// reCAPTCHA - only need token generation, verification happens at backend callback
const { isReady: recaptchaReady, executeRecaptcha } = useRecaptcha({
@@ -307,7 +306,7 @@ export function LoginContent({
<TermsAndPrivacyNotice className="max-w-[320px] text-[#A3A3A3]" />
</div>
{isEnterpriseCloud && <LoginCTA />}
{appMode === "saas" && ENABLE_PROJ_USER_JOURNEY() && <LoginCTA />}
</div>
);
}
@@ -0,0 +1,69 @@
import { useTranslation } from "react-i18next";
import { usePostHog } from "posthog-js/react";
import { I18nKey } from "#/i18n/declaration";
import { H2, Text } from "#/ui/typography";
import CheckCircleFillIcon from "#/icons/check-circle-fill.svg?react";
import { ENABLE_PROJ_USER_JOURNEY } from "#/utils/feature-flags";
const ENTERPRISE_FEATURE_KEYS: I18nKey[] = [
I18nKey.ENTERPRISE$FEATURE_DATA_PRIVACY,
I18nKey.ENTERPRISE$FEATURE_DEPLOYMENT,
I18nKey.ENTERPRISE$FEATURE_SSO,
I18nKey.ENTERPRISE$FEATURE_SUPPORT,
];
export function EnterpriseBanner() {
const { t } = useTranslation();
const posthog = usePostHog();
if (!ENABLE_PROJ_USER_JOURNEY()) {
return null;
}
const handleLearnMore = () => {
posthog?.capture("saas_selfhosted_inquiry");
};
return (
<div className="w-full max-w-md mx-auto lg:mx-0 lg:w-80 p-6 rounded-lg bg-gradient-to-b from-slate-800 to-slate-900 border border-slate-700 h-fit">
{/* Self-Hosted Label */}
<div className="flex justify-center mb-4">
<div className="px-8 py-0.5 rounded-full bg-gradient-to-r from-blue-900 to-blue-950 border border-blue-800">
<Text className="text-xs font-medium text-blue-400 tracking-wider uppercase">
{t(I18nKey.ENTERPRISE$SELF_HOSTED)}
</Text>
</div>
</div>
{/* Title */}
<H2 className="text-center mb-3">{t(I18nKey.ENTERPRISE$TITLE)}</H2>
{/* Description */}
<Text className="text-sm text-gray-400 text-center mb-6 block">
{t(I18nKey.ENTERPRISE$DESCRIPTION)}
</Text>
{/* Features List */}
<ul className="space-y-3 mb-6">
{ENTERPRISE_FEATURE_KEYS.map((featureKey) => (
<li key={featureKey} className="flex items-center gap-2">
<CheckCircleFillIcon className="w-4 h-4 text-blue-400 flex-shrink-0" />
<Text className="text-sm text-gray-300">{t(featureKey)}</Text>
</li>
))}
</ul>
{/* Learn More Button */}
<a
href="https://openhands.dev/enterprise"
target="_blank"
rel="noopener noreferrer"
onClick={handleLearnMore}
aria-label={t(I18nKey.ENTERPRISE$LEARN_MORE_ARIA)}
className="block w-full py-2.5 px-4 rounded-lg bg-blue-600 hover:bg-blue-700 text-white font-medium transition-colors text-center"
>
{t(I18nKey.ENTERPRISE$LEARN_MORE)}
</a>
</div>
);
}
@@ -38,7 +38,7 @@ export function useRepositoryData(
// Combine all repositories from paginated data
const allRepositories = useMemo(
() => repoData?.pages?.flatMap((page) => page.items) || [],
() => repoData?.pages?.flatMap((page) => page.data) || [],
[repoData],
);
@@ -18,11 +18,11 @@ export function useUrlSearch(inputValue: string, provider: Provider) {
try {
const repositories = await GitService.searchGitRepositories(
repoName,
provider,
3,
provider,
);
setUrlSearchResults(repositories.items);
setUrlSearchResults(repositories);
} catch {
setUrlSearchResults([]);
} finally {
@@ -6,7 +6,6 @@ import { UserActions } from "./user-actions";
import { OpenHandsLogoButton } from "#/components/shared/buttons/openhands-logo-button";
import { NewProjectButton } from "#/components/shared/buttons/new-project-button";
import { ConversationPanelButton } from "#/components/shared/buttons/conversation-panel-button";
import { AutomationsButton } from "#/components/shared/buttons/automations-button";
import { SettingsModal } from "#/components/shared/modals/settings/settings-modal";
import { useSettings } from "#/hooks/query/use-settings";
import { ConversationPanel } from "../conversation-panel/conversation-panel";
@@ -15,7 +14,6 @@ import { useConfig } from "#/hooks/query/use-config";
import { displayErrorToast } from "#/utils/custom-toast-handlers";
import { I18nKey } from "#/i18n/declaration";
import { cn } from "#/utils/utils";
import { ENABLE_AUTOMATIONS } from "#/utils/feature-flags";
export function Sidebar() {
const { t } = useTranslation();
@@ -89,11 +87,6 @@ export function Sidebar() {
}
disabled={settings?.email_verified === false}
/>
{ENABLE_AUTOMATIONS() && (
<AutomationsButton
disabled={settings?.email_verified === false}
/>
)}
</div>
<div className="flex flex-row md:flex-col md:items-center gap-[26px]">
@@ -15,9 +15,10 @@ import { ContextMenuCTA } from "../context-menu/context-menu-cta";
import { ContextMenuNavLink } from "../context-menu/context-menu-nav-link";
import { useShouldHideOrgSelector } from "#/hooks/use-should-hide-org-selector";
import { useBreakpoint } from "#/hooks/use-breakpoint";
import { useConfig } from "#/hooks/query/use-config";
import { ENABLE_PROJ_USER_JOURNEY } from "#/utils/feature-flags";
import { SettingsNavHeader } from "../settings/settings-nav-header";
import { SettingsNavDivider } from "../settings/settings-nav-divider";
import { useAppMode } from "#/hooks/use-app-mode";
// Shared className for context menu list items in the user context menu
const contextMenuListItemClassName = cn(
@@ -41,12 +42,13 @@ export function UserContextMenu({
const settingsNavItems = useSettingsNavItems();
const shouldHideSelector = useShouldHideOrgSelector();
const isMobile = useBreakpoint(768);
const { isSaas, isEnterpriseCloud } = useAppMode();
const { data: config } = useConfig();
// Keep all nav items including headers and dividers for proper section grouping
const navItems = settingsNavItems;
const isMember = type === "member";
const isSaasMode = config?.app_mode === "saas";
// Check if the ORG SETTINGS header exists in nav items
const hasOrgHeader = navItems.some(
@@ -58,9 +60,8 @@ export function UserContextMenu({
// Show invite button for admin/owner in team orgs
const showInviteButton = !isMember && !isPersonalOrg;
// CTA only renders in SaaS Cloud desktop mode
const isCTAEnabled = isEnterpriseCloud && !isMobile;
// CTA only renders in SaaS desktop with feature flag enabled
const showCta = isSaasMode && !isMobile && ENABLE_PROJ_USER_JOURNEY();
const handleLogout = () => {
logout();
onClose();
@@ -154,7 +155,7 @@ export function UserContextMenu({
</a>
{/* Only show logout in saas mode - oss mode has no session to invalidate */}
{isSaas && (
{isSaasMode && (
<ContextMenuListItem
onClick={handleLogout}
className={contextMenuListItemClassName}
@@ -166,7 +167,7 @@ export function UserContextMenu({
</div>
</div>
{isCTAEnabled && <ContextMenuCTA />}
{showCta && <ContextMenuCTA />}
</ContextMenuContainer>
);
}
@@ -1,8 +1,6 @@
import React from "react";
import { PostHogProvider } from "posthog-js/react";
import { queryClient } from "#/query-client-config";
import OptionService from "#/api/option-service/option-service.api";
import { QUERY_KEYS, CONFIG_CACHE_OPTIONS } from "#/hooks/query/query-keys";
import { displayErrorToast } from "#/utils/custom-toast-handlers";
const POSTHOG_BOOTSTRAP_KEY = "posthog_bootstrap";
@@ -49,12 +47,7 @@ export function PostHogWrapper({ children }: { children: React.ReactNode }) {
React.useEffect(() => {
(async () => {
try {
// Use fetchQuery for automatic caching and deduplication
const config = await queryClient.fetchQuery({
queryKey: QUERY_KEYS.WEB_CLIENT_CONFIG,
queryFn: OptionService.getConfig,
...CONFIG_CACHE_OPTIONS,
});
const config = await OptionService.getConfig();
setPosthogClientKey(config.posthog_client_key);
} catch {
displayErrorToast("Error fetching PostHog client key");
@@ -1,38 +0,0 @@
import { useTranslation } from "react-i18next";
import { I18nKey } from "#/i18n/declaration";
import { StyledTooltip } from "#/components/shared/buttons/styled-tooltip";
import AutomationsIcon from "#/icons/automations.svg?react";
import { cn } from "#/utils/utils";
interface AutomationsButtonProps {
disabled?: boolean;
}
export function AutomationsButton({
disabled = false,
}: AutomationsButtonProps) {
const { t } = useTranslation();
const label = t(I18nKey.SIDEBAR$AUTOMATIONS);
return (
<StyledTooltip content={label} placement="right">
<a
href="/automations"
data-testid="automations-button"
aria-label={label}
tabIndex={disabled ? -1 : 0}
onClick={(e) => {
if (disabled) {
e.preventDefault();
}
}}
className={cn("inline-flex items-center justify-center", {
"pointer-events-none opacity-50": disabled,
})}
>
<AutomationsIcon width={24} height={24} />
</a>
</StyledTooltip>
);
}
@@ -11,11 +11,14 @@ import { extractModelAndProvider } from "#/utils/extract-model-and-provider";
import { cn } from "#/utils/utils";
import { HelpLink } from "#/ui/help-link";
import { PRODUCT_URL } from "#/utils/constants";
import { useSearchProviders } from "#/hooks/query/use-search-providers";
import { useProviderModels } from "#/hooks/query/use-provider-models";
interface ModelSelectorProps {
isDisabled?: boolean;
models: Record<string, { separator: string; models: string[] }>;
/** Model names (no provider prefix) the backend considers verified. */
verifiedModels: string[];
/** Provider names the backend considers verified. */
verifiedProviders: string[];
currentModel?: string;
onChange?: (provider: string | null, model: string | null) => void;
onDefaultValuesChanged?: (
@@ -28,6 +31,9 @@ interface ModelSelectorProps {
export function ModelSelector({
isDisabled,
models,
verifiedModels,
verifiedProviders,
currentModel,
onChange,
onDefaultValuesChanged,
@@ -40,51 +46,30 @@ export function ModelSelector({
);
const [selectedModel, setSelectedModel] = React.useState<string | null>(null);
const { data: providers = [] } = useSearchProviders();
const {
data: providerModels = [],
isLoading: isLoadingModels,
error: modelsError,
} = useProviderModels(selectedProvider);
const verifiedProviders = React.useMemo(
() => providers.filter((p) => p.verified),
[providers],
);
const unverifiedProviders = React.useMemo(
() => providers.filter((p) => !p.verified),
[providers],
);
const verifiedModels = React.useMemo(
() => providerModels.filter((m) => m.verified),
[providerModels],
);
const unverifiedModels = React.useMemo(
() => providerModels.filter((m) => !m.verified),
[providerModels],
);
React.useEffect(() => {
if (currentModel) {
// runs when resetting to defaults
const { provider, model } = extractModelAndProvider(currentModel);
setLitellmId(currentModel);
setSelectedProvider(provider || null);
setSelectedProvider(provider);
setSelectedModel(model);
onDefaultValuesChanged?.(provider || null, model);
onDefaultValuesChanged?.(provider, model);
}
}, [currentModel]);
const handleChangeProvider = (provider: string) => {
setSelectedProvider(provider);
setSelectedModel(null);
setLitellmId(`${provider}/`);
const separator = models[provider]?.separator || "";
setLitellmId(provider + separator);
onChange?.(provider, null);
};
const handleChangeModel = (model: string) => {
let fullModel = `${selectedProvider}/${model}`;
const separator = models[selectedProvider || ""]?.separator || "";
let fullModel = selectedProvider + separator + model;
if (selectedProvider === "openai") {
// LiteLLM lists OpenAI models without the openai/ prefix
fullModel = model;
@@ -138,22 +123,28 @@ export function ModelSelector({
}}
>
<AutocompleteSection title={t(I18nKey.MODEL_SELECTOR$VERIFIED)}>
{verifiedProviders.map((provider) => (
<AutocompleteItem
data-testid={`provider-item-${provider.name}`}
key={provider.name}
>
{mapProvider(provider.name)}
</AutocompleteItem>
))}
</AutocompleteSection>
{unverifiedProviders.length > 0 ? (
<AutocompleteSection title={t(I18nKey.MODEL_SELECTOR$OTHERS)}>
{unverifiedProviders.map((provider) => (
<AutocompleteItem key={provider.name}>
{mapProvider(provider.name)}
{verifiedProviders
.filter((provider) => models[provider])
.map((provider) => (
<AutocompleteItem
data-testid={`provider-item-${provider}`}
key={provider}
>
{mapProvider(provider)}
</AutocompleteItem>
))}
</AutocompleteSection>
{Object.keys(models).some(
(provider) => !verifiedProviders.includes(provider),
) ? (
<AutocompleteSection title={t(I18nKey.MODEL_SELECTOR$OTHERS)}>
{Object.keys(models)
.filter((provider) => !verifiedProviders.includes(provider))
.map((provider) => (
<AutocompleteItem key={provider}>
{mapProvider(provider)}
</AutocompleteItem>
))}
</AutocompleteSection>
) : null}
</Autocomplete>
@@ -178,7 +169,6 @@ export function ModelSelector({
data-testid="llm-model-input"
isRequired
isVirtualized={false}
isLoading={isLoadingModels}
name="llm-model-input"
aria-label={t(I18nKey.LLM$MODEL)}
placeholder={t(I18nKey.LLM$SELECT_MODEL_PLACEHOLDER)}
@@ -200,28 +190,31 @@ export function ModelSelector({
}}
>
<AutocompleteSection title={t(I18nKey.MODEL_SELECTOR$VERIFIED)}>
{verifiedModels.map((model) => (
<AutocompleteItem key={model.name}>{model.name}</AutocompleteItem>
))}
</AutocompleteSection>
{unverifiedModels.length > 0 ? (
<AutocompleteSection title={t(I18nKey.MODEL_SELECTOR$OTHERS)}>
{unverifiedModels.map((model) => (
<AutocompleteItem
data-testid={`model-item-${model.name}`}
key={model.name}
>
{model.name}
</AutocompleteItem>
{verifiedModels
.filter((model) =>
models[selectedProvider || ""]?.models?.includes(model),
)
.map((model) => (
<AutocompleteItem key={model}>{model}</AutocompleteItem>
))}
</AutocompleteSection>
{models[selectedProvider || ""]?.models?.some(
(model) => !verifiedModels.includes(model),
) ? (
<AutocompleteSection title={t(I18nKey.MODEL_SELECTOR$OTHERS)}>
{models[selectedProvider || ""]?.models
.filter((model) => !verifiedModels.includes(model))
.map((model) => (
<AutocompleteItem
data-testid={`model-item-${model}`}
key={model}
>
{model}
</AutocompleteItem>
))}
</AutocompleteSection>
) : null}
</Autocomplete>
{modelsError && (
<p data-testid="models-error" className="text-danger text-xs">
{t(I18nKey.CONFIGURATION$ERROR_FETCH_MODELS)}
</p>
)}
</fieldset>
</div>
);
@@ -3,6 +3,7 @@ import { useTranslation } from "react-i18next";
import React from "react";
import { usePostHog } from "posthog-js/react";
import { I18nKey } from "#/i18n/declaration";
import { organizeModelsAndProviders } from "#/utils/organize-models-and-providers";
import { DangerModal } from "../confirmation-modals/danger-modal";
import { extractSettings } from "#/utils/settings-utils";
import { ModalBackdrop } from "../modal-backdrop";
@@ -16,10 +17,19 @@ import { SETTINGS_FORM } from "#/utils/constants";
interface SettingsFormProps {
settings: Settings;
models: string[];
verifiedModels: string[];
verifiedProviders: string[];
onClose: () => void;
}
export function SettingsForm({ settings, onClose }: SettingsFormProps) {
export function SettingsForm({
settings,
models,
verifiedModels,
verifiedProviders,
onClose,
}: SettingsFormProps) {
const posthog = usePostHog();
const { mutate: saveUserSettings } = useSaveSettings();
@@ -77,6 +87,9 @@ export function SettingsForm({ settings, onClose }: SettingsFormProps) {
>
<div className="flex flex-col gap-[17px]">
<ModelSelector
models={organizeModelsAndProviders(models)}
verifiedModels={verifiedModels}
verifiedProviders={verifiedProviders}
currentModel={settings.llm_model}
wrapperClassName="!flex-col !gap-[17px]"
labelClassName={SETTINGS_FORM.LABEL_CLASSNAME}
@@ -1,5 +1,7 @@
import { useTranslation } from "react-i18next";
import { useAIConfigOptions } from "#/hooks/query/use-ai-config-options";
import { I18nKey } from "#/i18n/declaration";
import { LoadingSpinner } from "../../loading-spinner";
import { ModalBackdrop } from "../modal-backdrop";
import { SettingsForm } from "./settings-form";
import { Settings } from "#/types/settings";
@@ -12,6 +14,7 @@ interface SettingsModalProps {
}
export function SettingsModal({ onClose, settings }: SettingsModalProps) {
const aiConfigOptions = useAIConfigOptions();
const { t } = useTranslation();
return (
@@ -20,6 +23,9 @@ export function SettingsModal({ onClose, settings }: SettingsModalProps) {
data-testid="ai-config-modal"
className="bg-[#25272D] min-w-full max-w-[475px] m-4 p-6 rounded-xl flex flex-col gap-[17px] border border-tertiary api-configuration-modal"
>
{aiConfigOptions.error && (
<p className="text-danger text-xs">{aiConfigOptions.error.message}</p>
)}
<span className="text-5 leading-6 font-semibold -tracking-[0.2px]">
{t(I18nKey.AI_SETTINGS$TITLE)}
</span>
@@ -34,10 +40,20 @@ export function SettingsModal({ onClose, settings }: SettingsModalProps) {
suffixClassName="text-white"
/>
<SettingsForm
settings={settings || DEFAULT_SETTINGS}
onClose={onClose}
/>
{aiConfigOptions.isLoading && (
<div className="flex justify-center">
<LoadingSpinner size="small" />
</div>
)}
{aiConfigOptions.data && (
<SettingsForm
settings={settings || DEFAULT_SETTINGS}
models={aiConfigOptions.data?.models}
verifiedModels={aiConfigOptions.data?.verifiedModels ?? []}
verifiedProviders={aiConfigOptions.data?.verifiedProviders ?? []}
onClose={onClose}
/>
)}
</div>
</ModalBackdrop>
);
+6 -7
View File
@@ -1,6 +1,6 @@
import { I18nKey } from "#/i18n/declaration";
export type OnboardingAppMode = "cloud" | "self-hosted" | "oss";
export type OnboardingAppMode = "saas" | "self-hosted";
interface BaseOnboardingQuestion {
id: string;
@@ -43,9 +43,9 @@ export const ONBOARDING_FORM: OnboardingQuestion[] = [
{
id: "org_size",
type: "single",
app_mode: ["cloud", "self-hosted"],
app_mode: ["saas", "self-hosted"],
questionKey: I18nKey.ONBOARDING$ORG_SIZE_QUESTION,
subtitleKey: I18nKey.ONBOARDING$SELECT_ONE_SUBTITLE,
subtitleKey: I18nKey.ONBOARDING$ORG_SIZE_SUBTITLE,
answerOptions: [
{ key: I18nKey.ONBOARDING$ORG_SIZE_SOLO, id: "solo" },
{ key: I18nKey.ONBOARDING$ORG_SIZE_2_10, id: "org_2_10" },
@@ -57,9 +57,9 @@ export const ONBOARDING_FORM: OnboardingQuestion[] = [
{
id: "use_case",
type: "multi",
app_mode: ["cloud", "self-hosted"],
app_mode: ["saas", "self-hosted"],
questionKey: I18nKey.ONBOARDING$USE_CASE_QUESTION,
subtitleKey: I18nKey.ONBOARDING$SELECT_MULTIPLE_SUBTITLE,
subtitleKey: I18nKey.ONBOARDING$USE_CASE_SUBTITLE,
answerOptions: [
{ key: I18nKey.ONBOARDING$USE_CASE_NEW_FEATURES, id: "new_features" },
{
@@ -78,9 +78,8 @@ export const ONBOARDING_FORM: OnboardingQuestion[] = [
{
id: "role",
type: "single",
app_mode: ["cloud"],
app_mode: ["saas"],
questionKey: I18nKey.ONBOARDING$ROLE_QUESTION,
subtitleKey: I18nKey.ONBOARDING$SELECT_ONE_SUBTITLE,
answerOptions: [
{
key: I18nKey.ONBOARDING$ROLE_SOFTWARE_ENGINEER,
@@ -1,6 +1,5 @@
import { useMutation, useQueryClient } from "@tanstack/react-query";
import { useMutation } from "@tanstack/react-query";
import { useNavigate } from "react-router";
import { openHands } from "#/api/open-hands-axios";
import { displayErrorToast } from "#/utils/custom-toast-handlers";
type SubmitOnboardingArgs = {
@@ -9,18 +8,14 @@ type SubmitOnboardingArgs = {
export const useSubmitOnboarding = () => {
const navigate = useNavigate();
const queryClient = useQueryClient();
return useMutation({
mutationFn: async ({ selections }: SubmitOnboardingArgs) => {
// Mark onboarding as complete
await openHands.post("/api/complete_onboarding");
return { selections };
},
mutationFn: async ({ selections }: SubmitOnboardingArgs) =>
// TODO: mark onboarding as complete
// TODO: persist user responses
({ selections }),
onSuccess: () => {
queryClient.invalidateQueries({ queryKey: ["settings"] });
const finalRedirectUrl = "/";
const finalRedirectUrl = "/"; // TODO: use redirect url from api response
// Check if the redirect URL is an external URL (starts with http or https)
if (
finalRedirectUrl.startsWith("http://") ||
@@ -5,7 +5,6 @@ import { organizationService } from "#/api/organization-service/organization-ser
import { useSelectedOrganizationId } from "#/context/use-selected-organization";
import { I18nKey } from "#/i18n/declaration";
import { displaySuccessToast } from "#/utils/custom-toast-handlers";
import { setSelectedOrg } from "#/utils/local-storage";
export const useSwitchOrganization = () => {
const { t } = useTranslation();
@@ -34,8 +33,6 @@ export const useSwitchOrganization = () => {
// Update local state - this triggers automatic refetch for all org-scoped queries
// since their query keys include organizationId (e.g., ["settings", orgId], ["secrets", orgId])
setOrganizationId(orgId);
// Broadcast org change to other apps (e.g. Automations) via localStorage
setSelectedOrg(orgId);
// Invalidate conversations to fetch data for the new org context
queryClient.invalidateQueries({ queryKey: ["user", "conversations"] });
// Remove all individual conversation queries to clear any stale/null data
-17
View File
@@ -1,17 +0,0 @@
/**
* Centralized query keys and cache configuration for TanStack Query.
* Using constants ensures type safety and prevents typos.
*/
export const QUERY_KEYS = {
/** Web client configuration from the server */
WEB_CLIENT_CONFIG: ["web-client-config"] as const,
} as const;
/** Cache configuration shared across all config-related queries */
export const CONFIG_CACHE_OPTIONS = {
staleTime: 1000 * 60 * 5, // 5 minutes
gcTime: 1000 * 60 * 15, // 15 minutes
} as const;
export type QueryKeys = (typeof QUERY_KEYS)[keyof typeof QUERY_KEYS];
@@ -1,10 +1,21 @@
import { useQuery } from "@tanstack/react-query";
import OptionService from "#/api/option-service/option-service.api";
const fetchAiConfigOptions = async () => {
const modelsResponse = await OptionService.getModels();
return {
models: modelsResponse.models,
verifiedModels: modelsResponse.verified_models,
verifiedProviders: modelsResponse.verified_providers,
defaultModel: modelsResponse.default_model,
securityAnalyzers: await OptionService.getSecurityAnalyzers(),
};
};
export const useAIConfigOptions = () =>
useQuery({
queryKey: ["ai-config-options"],
queryFn: OptionService.getSecurityAnalyzers,
staleTime: 1000 * 60 * 5,
gcTime: 1000 * 60 * 15,
queryFn: fetchAiConfigOptions,
staleTime: 1000 * 60 * 5, // 5 minutes
gcTime: 1000 * 60 * 15, // 15 minutes
});
@@ -6,9 +6,6 @@ import { useUserProviders } from "../use-user-providers";
import { Provider } from "#/types/settings";
import { shouldUseInstallationRepos } from "#/utils/utils";
/**
* Get the first page of app installations for the provider given.
*/
export const useAppInstallations = (selectedProvider: Provider | null) => {
const { data: config } = useConfig();
const { data: userIsAuthenticated } = useIsAuthed();
@@ -16,7 +13,7 @@ export const useAppInstallations = (selectedProvider: Provider | null) => {
return useQuery({
queryKey: ["installations", providers || [], selectedProvider],
queryFn: () => GitService.getUserInstallations(selectedProvider!),
queryFn: () => GitService.getUserInstallationIds(selectedProvider!),
enabled:
userIsAuthenticated &&
!!selectedProvider &&
+4 -6
View File
@@ -30,11 +30,9 @@ export function useBranchData(
provider,
);
// Combine all branches from paginated data - use .items for V1 response
// Combine all branches from paginated data
const allBranches = useMemo(
() =>
branchData?.pages?.flatMap((page: { items: Branch[] }) => page.items) ||
[],
() => branchData?.pages?.flatMap((page) => page.branches) || [],
[branchData],
);
@@ -42,7 +40,7 @@ export function useBranchData(
const defaultBranchInLoaded = useMemo(
() =>
defaultBranch
? allBranches.find((branch: Branch) => branch.name === defaultBranch)
? allBranches.find((branch) => branch.name === defaultBranch)
: null,
[allBranches, defaultBranch],
);
@@ -77,7 +75,7 @@ export function useBranchData(
if (defaultBranch) {
// Use the already computed defaultBranchInLoaded or check in current branches
let defaultBranchObj = shouldUseSearch
? branchesToUse.find((branch: Branch) => branch.name === defaultBranch)
? branchesToUse.find((branch) => branch.name === defaultBranch)
: defaultBranchInLoaded;
// If not found in current branches, check if we have it from the default branch search
+3 -3
View File
@@ -1,7 +1,6 @@
import { useQuery } from "@tanstack/react-query";
import OptionService from "#/api/option-service/option-service.api";
import { useIsOnIntermediatePage } from "#/hooks/use-is-on-intermediate-page";
import { QUERY_KEYS, CONFIG_CACHE_OPTIONS } from "./query-keys";
interface UseConfigOptions {
enabled?: boolean;
@@ -11,9 +10,10 @@ export const useConfig = (options?: UseConfigOptions) => {
const isOnIntermediatePage = useIsOnIntermediatePage();
return useQuery({
queryKey: QUERY_KEYS.WEB_CLIENT_CONFIG,
queryKey: ["web-client-config"],
queryFn: OptionService.getConfig,
...CONFIG_CACHE_OPTIONS,
staleTime: 1000 * 60 * 5, // 5 minutes
gcTime: 1000 * 60 * 15, // 15 minutes,
enabled: options?.enabled ?? !isOnIntermediatePage,
});
};
@@ -1,8 +1,8 @@
import { useInfiniteQuery, InfiniteData } from "@tanstack/react-query";
import { useInfiniteQuery } from "@tanstack/react-query";
import { useConfig } from "./use-config";
import { useUserProviders } from "../use-user-providers";
import { useAppInstallations } from "./use-app-installations";
import { RepositoryPage } from "../../types/git";
import { GitRepository } from "../../types/git";
import { Provider } from "../../types/settings";
import GitService from "#/api/git-service/git-service.api";
import { shouldUseInstallationRepos } from "#/utils/utils";
@@ -13,27 +13,29 @@ interface UseGitRepositoriesOptions {
enabled?: boolean;
}
type InstallationCursor = { installationIndex: number; pageId: string | null };
type UserCursor = string | null;
type Cursor = InstallationCursor | UserCursor;
interface UserRepositoriesResponse {
data: GitRepository[];
nextPage: number | null;
}
interface InstallationRepositoriesResponse {
data: GitRepository[];
nextPage: number | null;
installationIndex: number | null;
}
export function useGitRepositories(options: UseGitRepositoriesOptions) {
const { provider, pageSize = 30, enabled = true } = options;
const { providers } = useUserProviders();
const { data: config } = useConfig();
const { data: page } = useAppInstallations(provider);
const installations = page?.items;
const { data: installations } = useAppInstallations(provider);
const useInstallationRepos = provider
? shouldUseInstallationRepos(provider, config?.app_mode)
: false;
const repos = useInfiniteQuery<
RepositoryPage,
Error,
InfiniteData<RepositoryPage>,
[string, string[], Provider | null, boolean, number, ...unknown[]],
Cursor
UserRepositoriesResponse | InstallationRepositoriesResponse
>({
queryKey: [
"repositories",
@@ -49,52 +51,56 @@ export function useGitRepositories(options: UseGitRepositoriesOptions) {
}
if (useInstallationRepos) {
const { repoPage, installationIndex } = pageParam as {
installationIndex: number | null;
repoPage: number | null;
};
if (!installations) {
throw new Error("Missing installation list");
}
const cursor = pageParam as InstallationCursor;
const result = await GitService.retrieveInstallationRepositories(
return GitService.retrieveInstallationRepositories(
provider,
cursor.installationIndex,
installationIndex || 0,
installations,
cursor.pageId ?? undefined,
repoPage || 1,
pageSize,
);
return result;
}
const cursor = pageParam as UserCursor;
const result = await GitService.retrieveUserGitRepositories(
return GitService.retrieveUserGitRepositories(
provider,
cursor ?? undefined,
pageParam as number,
pageSize,
);
return result;
},
getNextPageParam: (lastPage, allPages, lastPageParam) => {
if (useInstallationRepos && installations) {
// Installation-based pagination
const currentCursor = lastPageParam as InstallationCursor;
if (lastPage.next_page_id) {
getNextPageParam: (lastPage) => {
if (useInstallationRepos) {
const installationPage = lastPage as InstallationRepositoriesResponse;
if (installationPage.nextPage) {
return {
installationIndex: currentCursor.installationIndex,
pageId: lastPage.next_page_id,
installationIndex: installationPage.installationIndex,
repoPage: installationPage.nextPage,
};
}
// Advance to next installation
const nextInstallationIndex = currentCursor.installationIndex + 1;
if (nextInstallationIndex < installations.length) {
return { installationIndex: nextInstallationIndex, pageId: null };
if (installationPage.installationIndex !== null) {
return {
installationIndex: installationPage.installationIndex,
repoPage: 1,
};
}
return undefined;
return null;
}
// User repositories pagination
return lastPage.next_page_id;
const userPage = lastPage as UserRepositoriesResponse;
return userPage.nextPage;
},
initialPageParam: useInstallationRepos
? { installationIndex: 0, pageId: null }
: null,
? { installationIndex: 0, repoPage: 1 }
: 1,
enabled:
enabled &&
(providers || []).length > 0 &&
@@ -1,36 +0,0 @@
import { useQuery } from "@tanstack/react-query";
import ConfigService from "#/api/config-service/config-service.api";
import type { LLMModel } from "#/api/config-service/config-service.types";
const MAX_PAGINATION_DEPTH = 10;
async function fetchPage(
provider: string,
pageId?: string,
depth = 0,
): Promise<LLMModel[]> {
if (depth >= MAX_PAGINATION_DEPTH) {
throw new Error(`Too many pagination requests for provider ${provider}`);
}
const page = await ConfigService.searchModels({
provider__eq: provider,
limit: 100,
page_id: pageId,
});
if (page.next_page_id) {
const rest = await fetchPage(provider, page.next_page_id, depth + 1);
return [...page.items, ...rest];
}
return page.items;
}
export const useProviderModels = (provider: string | null) =>
useQuery({
queryKey: ["config", "models", provider],
queryFn: () => fetchPage(provider!),
enabled: !!provider,
staleTime: 1000 * 60 * 5,
gcTime: 1000 * 60 * 15,
});
@@ -1,20 +1,35 @@
import { useInfiniteQuery, InfiniteData } from "@tanstack/react-query";
import { useQuery, useInfiniteQuery } from "@tanstack/react-query";
import GitService from "#/api/git-service/git-service.api";
import { BranchPage } from "#/types/git";
import { Branch, PaginatedBranchesResponse } from "#/types/git";
import { Provider } from "#/types/settings";
export const useRepositoryBranches = (
repository: string | null,
selectedProvider?: Provider,
) =>
useQuery<Branch[]>({
queryKey: ["repository", repository, "branches", selectedProvider],
queryFn: async () => {
if (!repository) return [];
const response = await GitService.getRepositoryBranches(
repository,
1,
30,
selectedProvider,
);
// Ensure we return an array even if the response is malformed
return Array.isArray(response.branches) ? response.branches : [];
},
enabled: !!repository,
staleTime: 1000 * 60 * 5, // 5 minutes
});
export const useRepositoryBranchesPaginated = (
repository: string | null,
perPage: number = 30,
selectedProvider?: Provider,
) => {
const result = useInfiniteQuery<
BranchPage,
Error,
InfiniteData<BranchPage>,
[string, string | null, ...unknown[]],
string | null
>({
) =>
useInfiniteQuery<PaginatedBranchesResponse, Error>({
queryKey: [
"repository",
repository,
@@ -23,29 +38,27 @@ export const useRepositoryBranchesPaginated = (
perPage,
selectedProvider,
],
queryFn: async ({ pageParam }) => {
if (!repository || !selectedProvider) {
queryFn: async ({ pageParam = 1 }) => {
if (!repository) {
return {
items: [],
next_page_id: null,
branches: [],
has_next_page: false,
current_page: 1,
per_page: perPage,
total_count: 0,
};
}
return GitService.getRepositoryBranches(
repository,
selectedProvider,
"", // query (empty = list all)
pageParam ?? undefined,
pageParam as number,
perPage,
selectedProvider,
);
},
enabled: !!repository && !!selectedProvider,
enabled: !!repository,
staleTime: 1000 * 60 * 5, // 5 minutes
getNextPageParam: (lastPage) =>
lastPage.next_page_id ? lastPage.next_page_id : undefined,
initialPageParam: null,
// Use the has_next_page flag from the API response
lastPage.has_next_page ? lastPage.current_page + 1 : undefined,
initialPageParam: 1,
});
return {
...result,
};
};
@@ -20,17 +20,15 @@ export function useSearchBranches(
selectedProvider,
],
queryFn: async () => {
if (!repository || !query || !selectedProvider) return [];
const response = await GitService.searchRepositoryBranches(
if (!repository || !query) return [];
return GitService.searchRepositoryBranches(
repository,
selectedProvider,
query,
undefined, // pageId
perPage,
selectedProvider,
);
return response.items;
},
enabled: !!repository && !!query && !!selectedProvider,
enabled: !!repository && !!query,
staleTime: 1000 * 60 * 5,
gcTime: 1000 * 60 * 15,
});
@@ -1,17 +0,0 @@
import { useQuery } from "@tanstack/react-query";
import ConfigService from "#/api/config-service/config-service.api";
import type { LLMProvider } from "#/api/config-service/config-service.types";
async function fetchAllProviders(): Promise<LLMProvider[]> {
// Providers are a small set; fetch all in one call with a high limit.
const page = await ConfigService.searchProviders({ limit: 100 });
return page.items;
}
export const useSearchProviders = () =>
useQuery({
queryKey: ["config", "providers"],
queryFn: fetchAllProviders,
staleTime: 1000 * 60 * 5,
gcTime: 1000 * 60 * 15,
});
@@ -1,6 +1,5 @@
import { useQuery } from "@tanstack/react-query";
import GitService from "#/api/git-service/git-service.api";
import { GitRepository } from "#/types/git";
import { Provider } from "#/types/settings";
export function useSearchRepositories(
@@ -9,20 +8,14 @@ export function useSearchRepositories(
disabled?: boolean,
pageSize: number = 100,
) {
// For backward compatibility, return the items array directly
return useQuery<GitRepository[]>({
return useQuery({
queryKey: ["repositories", "search", query, selectedProvider, pageSize],
queryFn: async () => {
if (!selectedProvider) {
return [];
}
const response = await GitService.searchGitRepositories(
queryFn: () =>
GitService.searchGitRepositories(
query,
selectedProvider, // provider (required)
pageSize,
);
return response.items;
},
selectedProvider || undefined,
),
enabled: !!query && !!selectedProvider && !disabled,
staleTime: 1000 * 60 * 5, // 5 minutes
gcTime: 1000 * 60 * 15, // 15 minutes
-43
View File
@@ -1,43 +0,0 @@
import { useMemo } from "react";
import { useConfig } from "#/hooks/query/use-config";
/**
* Hook that provides boolean checks for app mode deployment mode.
*
* App Mode (app_mode):
* - "oss": Open source version running locally/self-hosted
* - "saas": All-Hands managed SaaS version
*
* Deployment Mode (deployment_mode):
* - "cloud": Enterprise customers running on All-Hands managed infrastructure (*.all-hands.dev, *.openhands.ai)
* - "self_hosted": Enterprise customers running on their own infrastructure
*
* Note: SaaS mode can have either cloud or self_hosted deployment mode.
*/
export function useAppMode() {
const { data: config } = useConfig();
return useMemo(() => {
const appMode = config?.app_mode;
const deploymentMode = config?.feature_flags?.deployment_mode;
return {
// App Mode checks
isOss: appMode === "oss",
isSaas: appMode === "saas",
// Deployment Mode checks
isCloud: deploymentMode === "cloud",
isSelfHosted: deploymentMode === "self_hosted",
/** Enterprise checks */
isEnterpriseSelfHosted:
appMode === "saas" && deploymentMode === "self_hosted",
isEnterpriseCloud: appMode === "saas" && deploymentMode === "cloud",
// Raw values (for cases where the actual value is needed)
appMode,
deploymentMode,
};
}, [config?.app_mode, config?.feature_flags?.deployment_mode]);
}
@@ -1,7 +1,6 @@
import React from "react";
import { useSelectedOrganizationId } from "#/context/use-selected-organization";
import { useOrganizations } from "#/hooks/query/use-organizations";
import { setSelectedOrg } from "#/utils/local-storage";
/**
* Hook that automatically selects an organization when:
@@ -29,8 +28,6 @@ export function useAutoSelectOrganization() {
// Revalidation is only needed when user explicitly switches organizations
// to redirect away from admin-only pages they may no longer have access to.
setOrganizationId(initialOrgId, { skipRevalidation: true });
// Broadcast org selection to other apps (e.g. Automations) via localStorage
setSelectedOrg(initialOrgId);
}
}, [organizationId, organizations, currentOrgId, setOrganizationId]);
}
@@ -0,0 +1,13 @@
import { useMemo } from "react";
/**
* Hook to check if the current domain is an All Hands SaaS environment
* @returns True if the current domain contains "all-hands.dev" or "openhands.dev" postfix
*/
export const useIsAllHandsSaaSEnvironment = (): boolean =>
useMemo(() => {
const { hostname } = window.location;
return (
hostname.endsWith("all-hands.dev") || hostname.endsWith("openhands.dev")
);
}, []);
+2 -3
View File
@@ -470,7 +470,6 @@ export enum I18nKey {
SIDEBAR$NAVIGATION_LABEL = "SIDEBAR$NAVIGATION_LABEL",
FEEDBACK$PUBLIC_LABEL = "FEEDBACK$PUBLIC_LABEL",
FEEDBACK$PRIVATE_LABEL = "FEEDBACK$PRIVATE_LABEL",
SIDEBAR$AUTOMATIONS = "SIDEBAR$AUTOMATIONS",
SIDEBAR$CONVERSATIONS = "SIDEBAR$CONVERSATIONS",
STATUS$CONNECTING_TO_RUNTIME = "STATUS$CONNECTING_TO_RUNTIME",
STATUS$STARTING_RUNTIME = "STATUS$STARTING_RUNTIME",
@@ -1134,14 +1133,14 @@ export enum I18nKey {
ONBOARDING$ORG_NAME_INPUT_NAME = "ONBOARDING$ORG_NAME_INPUT_NAME",
ONBOARDING$ORG_NAME_INPUT_DOMAIN = "ONBOARDING$ORG_NAME_INPUT_DOMAIN",
ONBOARDING$ORG_SIZE_QUESTION = "ONBOARDING$ORG_SIZE_QUESTION",
ONBOARDING$SELECT_ONE_SUBTITLE = "ONBOARDING$SELECT_ONE_SUBTITLE",
ONBOARDING$ORG_SIZE_SUBTITLE = "ONBOARDING$ORG_SIZE_SUBTITLE",
ONBOARDING$ORG_SIZE_SOLO = "ONBOARDING$ORG_SIZE_SOLO",
ONBOARDING$ORG_SIZE_2_10 = "ONBOARDING$ORG_SIZE_2_10",
ONBOARDING$ORG_SIZE_11_50 = "ONBOARDING$ORG_SIZE_11_50",
ONBOARDING$ORG_SIZE_51_200 = "ONBOARDING$ORG_SIZE_51_200",
ONBOARDING$ORG_SIZE_200_PLUS = "ONBOARDING$ORG_SIZE_200_PLUS",
ONBOARDING$USE_CASE_QUESTION = "ONBOARDING$USE_CASE_QUESTION",
ONBOARDING$SELECT_MULTIPLE_SUBTITLE = "ONBOARDING$SELECT_MULTIPLE_SUBTITLE",
ONBOARDING$USE_CASE_SUBTITLE = "ONBOARDING$USE_CASE_SUBTITLE",
ONBOARDING$USE_CASE_NEW_FEATURES = "ONBOARDING$USE_CASE_NEW_FEATURES",
ONBOARDING$USE_CASE_APP_FROM_SCRATCH = "ONBOARDING$USE_CASE_APP_FROM_SCRATCH",
ONBOARDING$USE_CASE_FIXING_BUGS = "ONBOARDING$USE_CASE_FIXING_BUGS",
+2 -19
View File
@@ -7989,23 +7989,6 @@
"uk": "Приватний",
"ca": "Privat"
},
"SIDEBAR$AUTOMATIONS": {
"en": "Automations",
"zh-CN": "自动化",
"zh-TW": "自動化",
"de": "Automatisierungen",
"ko-KR": "자동화",
"no": "Automatiseringer",
"it": "Automazioni",
"pt": "Automatizações",
"es": "Automatizaciones",
"ar": "الأتمتة",
"fr": "Automatisations",
"tr": "Otomasyonlar",
"ja": "自動化",
"uk": "Автоматизації",
"ca": "Automatitzacions"
},
"SIDEBAR$CONVERSATIONS": {
"en": "Conversations",
"ja": "会話",
@@ -19280,7 +19263,7 @@
"uk": "Якого розміру організація, в якій ви працюєте?",
"ca": "De quina mida és la vostra organització?"
},
"ONBOARDING$SELECT_ONE_SUBTITLE": {
"ONBOARDING$ORG_SIZE_SUBTITLE": {
"en": "Select one",
"ja": "1つ選択",
"zh-CN": "选择一个",
@@ -19399,7 +19382,7 @@
"uk": "Для яких випадків використання ви хочете використовувати OpenHands?",
"ca": "Per a quins casos d'ús voleu fer servir OpenHands?"
},
"ONBOARDING$SELECT_MULTIPLE_SUBTITLE": {
"ONBOARDING$USE_CASE_SUBTITLE": {
"en": "Check all that apply",
"ja": "該当するものをすべて選択",
"zh-CN": "选择所有适用的",
-21
View File
@@ -1,21 +0,0 @@
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 20 20" fill="none">
<mask id="mask0_18339_1445" style="mask-type:luminance" maskUnits="userSpaceOnUse" x="1" y="1" width="18" height="18">
<path d="M19 1H1V19H19V1Z" fill="white"/>
</mask>
<g mask="url(#mask0_18339_1445)">
<path d="M10 18.1818V16.5454" stroke="currentColor" stroke-width="1.67" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M10 1.81836V3.45472" stroke="currentColor" stroke-width="1.67" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M1.81824 10H3.4546" stroke="currentColor" stroke-width="1.67" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M18.1818 10H16.5454" stroke="currentColor" stroke-width="1.67" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M5.93359 17.1019L6.74359 15.6782" stroke="currentColor" stroke-width="1.67" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M14.0663 2.89844L13.2563 4.32208" stroke="currentColor" stroke-width="1.67" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M2.89819 5.93359L4.32183 6.74359" stroke="currentColor" stroke-width="1.67" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M17.1019 14.0663L15.6782 13.2563" stroke="currentColor" stroke-width="1.67" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M5.92542 2.90625L6.7436 4.3217" stroke="currentColor" stroke-width="1.67" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M14.0909 17.0854L13.2727 15.6699" stroke="currentColor" stroke-width="1.67" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M17.0855 5.90918L15.67 6.72736" stroke="currentColor" stroke-width="1.67" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M2.91455 14.0911L4.33001 13.2729" stroke="currentColor" stroke-width="1.67" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M10 16.5455C13.615 16.5455 16.5455 13.615 16.5455 10C16.5455 6.38509 13.615 3.45459 10 3.45459C6.38509 3.45459 3.45459 6.38509 3.45459 10C3.45459 13.615 6.38509 16.5455 10 16.5455Z" stroke="currentColor" stroke-width="1.67" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M12.5854 9.77916L8.80545 7.59461C8.63363 7.49643 8.41272 7.61916 8.41272 7.81552V12.1846C8.41272 12.381 8.62545 12.5119 8.80545 12.4055L12.5854 10.221C12.7573 10.1228 12.7573 9.86916 12.5854 9.77098V9.77916Z" fill="currentColor" stroke="currentColor" stroke-width="1.67" stroke-linecap="round" stroke-linejoin="round"/>
</g>
</svg>

Before

Width:  |  Height:  |  Size: 2.5 KiB

+30 -103
View File
@@ -68,120 +68,47 @@ export const resetTestHandlersMockSettings = () => {
MOCK_USER_PREFERENCES.settings = MOCK_DEFAULT_USER_SETTINGS;
};
// Mock model data used by both V0 and V1 endpoints
const MOCK_MODELS = [
"anthropic/claude-3.5",
"anthropic/claude-sonnet-4-20250514",
"anthropic/claude-sonnet-4-5-20250929",
"anthropic/claude-haiku-4-5-20251001",
"anthropic/claude-opus-4-5-20251101",
"openai/gpt-3.5-turbo",
"openai/gpt-4o",
"openai/gpt-4o-mini",
"openhands/claude-sonnet-4-20250514",
"openhands/claude-sonnet-4-5-20250929",
"openhands/claude-haiku-4-5-20251001",
"openhands/claude-opus-4-5-20251101",
"openhands/minimax-m2.7",
"sambanova/Meta-Llama-3.1-8B-Instruct",
];
const MOCK_VERIFIED_MODELS = new Set([
"anthropic/claude-opus-4-5-20251101",
"anthropic/claude-sonnet-4-5-20250929",
"openhands/claude-opus-4-5-20251101",
"openhands/claude-sonnet-4-5-20250929",
"openhands/minimax-m2.7",
]);
const MOCK_VERIFIED_PROVIDERS = [
"openhands",
"anthropic",
"openai",
"mistral",
"gemini",
"deepseek",
"moonshot",
"minimax",
];
// --- Handlers for options/config/settings ---
export const SETTINGS_HANDLERS = [
// V0 (legacy) models endpoint still used for default_model
http.get("/api/options/models", async () =>
HttpResponse.json({
models: MOCK_MODELS,
models: [
"anthropic/claude-3.5",
"anthropic/claude-sonnet-4-20250514",
"anthropic/claude-sonnet-4-5-20250929",
"anthropic/claude-haiku-4-5-20251001",
"anthropic/claude-opus-4-5-20251101",
"openai/gpt-3.5-turbo",
"openai/gpt-4o",
"openai/gpt-4o-mini",
"openhands/claude-sonnet-4-20250514",
"openhands/claude-sonnet-4-5-20250929",
"openhands/claude-haiku-4-5-20251001",
"openhands/claude-opus-4-5-20251101",
"sambanova/Meta-Llama-3.1-8B-Instruct",
],
verified_models: [
"claude-opus-4-5-20251101",
"claude-sonnet-4-5-20250929",
],
verified_providers: MOCK_VERIFIED_PROVIDERS,
verified_providers: [
"openhands",
"anthropic",
"openai",
"mistral",
"gemini",
"deepseek",
"moonshot",
"minimax",
],
default_model: "openhands/claude-opus-4-5-20251101",
}),
),
// V1 providers search
http.get("/api/v1/config/providers/search", async ({ request }) => {
const url = new URL(request.url);
const query = url.searchParams.get("query")?.toLowerCase();
const verifiedEq = url.searchParams.get("verified__eq");
// Build unique provider list from models
const seen = new Set<string>();
let providers: { name: string; verified: boolean }[] = [];
for (const model of MOCK_MODELS) {
const [providerName] = model.split("/");
if (providerName && !seen.has(providerName)) {
seen.add(providerName);
providers.push({
name: providerName,
verified: MOCK_VERIFIED_PROVIDERS.includes(providerName),
});
}
}
if (query) {
providers = providers.filter((p) => p.name.toLowerCase().includes(query));
}
if (verifiedEq !== null && verifiedEq !== undefined) {
const wantVerified = verifiedEq === "true";
providers = providers.filter((p) => p.verified === wantVerified);
}
return HttpResponse.json({ items: providers, next_page_id: null });
}),
// V1 models search
http.get("/api/v1/config/models/search", async ({ request }) => {
const url = new URL(request.url);
const query = url.searchParams.get("query")?.toLowerCase();
const verifiedEq = url.searchParams.get("verified__eq");
const providerEq = url.searchParams.get("provider__eq");
let models = MOCK_MODELS.map((m) => {
const [provider, ...rest] = m.split("/");
const name = rest.join("/");
return {
provider: provider || null,
name,
verified: MOCK_VERIFIED_MODELS.has(m),
};
});
if (providerEq) {
models = models.filter((m) => m.provider === providerEq);
}
if (query) {
models = models.filter((m) => m.name.toLowerCase().includes(query));
}
if (verifiedEq !== null && verifiedEq !== undefined) {
const wantVerified = verifiedEq === "true";
models = models.filter((m) => m.verified === wantVerified);
}
return HttpResponse.json({ items: models, next_page_id: null });
}),
http.get("/api/options/agents", async () =>
HttpResponse.json(["CodeActAgent", "CoActAgent"]),
),
http.get("/api/options/security-analyzers", async () =>
HttpResponse.json(["llm", "none"]),
@@ -218,7 +145,7 @@ export const SETTINGS_HANDLERS = [
return HttpResponse.json(config);
}),
http.get("/api/v1/settings", async () => {
http.get("/api/settings", async () => {
await delay();
const { settings } = MOCK_USER_PREFERENCES;
@@ -227,7 +154,7 @@ export const SETTINGS_HANDLERS = [
return HttpResponse.json(settings);
}),
http.post("/api/v1/settings", async ({ request }) => {
http.post("/api/settings", async ({ request }) => {
await delay();
const body = await request.json();
+5 -6
View File
@@ -16,15 +16,14 @@ import { isBillingHidden } from "#/utils/org/billing-visibility";
import { queryClient } from "#/query-client-config";
import OptionService from "#/api/option-service/option-service.api";
import { WebClientConfig } from "#/api/option-service/option.types";
import { QUERY_KEYS, CONFIG_CACHE_OPTIONS } from "#/hooks/query/query-keys";
import { getFirstAvailablePath } from "#/utils/settings-utils";
export const clientLoader = async () => {
const config = await queryClient.fetchQuery<WebClientConfig>({
queryKey: QUERY_KEYS.WEB_CLIENT_CONFIG,
queryFn: OptionService.getConfig,
...CONFIG_CACHE_OPTIONS,
});
let config = queryClient.getQueryData<WebClientConfig>(["web-client-config"]);
if (!config) {
config = await OptionService.getConfig();
queryClient.setQueryData<WebClientConfig>(["web-client-config"], config);
}
const isSaas = config?.app_mode === "saas";
const featureFlags = config?.feature_flags;

Some files were not shown because too many files have changed in this diff Show More