mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-04-29 03:00:45 -04:00
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 1f7335fc15 |
@@ -12,7 +12,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v6
|
||||
|
||||
@@ -192,7 +192,7 @@ jobs:
|
||||
|
||||
- name: Upload test results
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v7
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: playwright-report
|
||||
path: tests/e2e/test-results/
|
||||
@@ -200,7 +200,7 @@ jobs:
|
||||
|
||||
- name: Upload OpenHands logs
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v7
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: openhands-logs
|
||||
path: |
|
||||
|
||||
@@ -41,7 +41,7 @@ jobs:
|
||||
working-directory: ./frontend
|
||||
run: npx playwright test --project=chromium
|
||||
- name: Upload Playwright report
|
||||
uses: actions/upload-artifact@v7
|
||||
uses: actions/upload-artifact@v6
|
||||
if: always()
|
||||
with:
|
||||
name: playwright-report
|
||||
|
||||
@@ -33,48 +33,47 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
base_image: ${{ steps.define-base-images.outputs.base_image }}
|
||||
architectures: ${{ steps.define-base-images.outputs.architectures }}
|
||||
platforms: ${{ steps.define-base-images.outputs.platforms }}
|
||||
steps:
|
||||
- name: Define base images
|
||||
shell: bash
|
||||
id: define-base-images
|
||||
run: |
|
||||
architectures='["amd64", "arm64"]'
|
||||
if [[ "$GITHUB_EVENT_NAME" == "pull_request" ]]; then
|
||||
json=$(jq -n -c '[
|
||||
{ image: "nikolaik/python-nodejs:python3.12-nodejs22-slim", tag: "nikolaik" }
|
||||
platforms="linux/amd64"
|
||||
json=$(jq -n -c --arg platforms "$platforms" '[
|
||||
{ image: "nikolaik/python-nodejs:python3.12-nodejs22-slim", tag: "nikolaik", platforms: $platforms }
|
||||
]')
|
||||
else
|
||||
json=$(jq -n -c '[
|
||||
{ image: "nikolaik/python-nodejs:python3.12-nodejs22-slim", tag: "nikolaik" },
|
||||
{ image: "ubuntu:24.04", tag: "ubuntu" }
|
||||
platforms="linux/amd64,linux/arm64"
|
||||
json=$(jq -n -c --arg platforms "$platforms" '[
|
||||
{ image: "nikolaik/python-nodejs:python3.12-nodejs22-slim", tag: "nikolaik", platforms: $platforms },
|
||||
{ image: "ubuntu:24.04", tag: "ubuntu", platforms: $platforms }
|
||||
]')
|
||||
fi
|
||||
echo "base_image=$json" >> "$GITHUB_OUTPUT"
|
||||
echo "architectures=$architectures" >> "$GITHUB_OUTPUT"
|
||||
echo "platforms=$platforms" >> "$GITHUB_OUTPUT"
|
||||
|
||||
# Builds the OpenHands Docker images (one per architecture, natively)
|
||||
# Builds the OpenHands Docker images
|
||||
ghcr_build_app:
|
||||
name: Build App Image (${{ matrix.arch }})
|
||||
runs-on: ${{ matrix.arch == 'arm64' && 'ubuntu-24.04-arm' || 'ubuntu-22.04' }}
|
||||
name: Build App Image
|
||||
runs-on: ubuntu-22.04
|
||||
if: "!(github.event_name == 'push' && startsWith(github.ref, 'refs/tags/ext-v'))"
|
||||
needs: define-matrix
|
||||
outputs:
|
||||
# All arch variants produce the same base tags, so any entry works
|
||||
base_tags: ${{ steps.build.outputs.base_tags }}
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
strategy:
|
||||
matrix:
|
||||
arch: ${{ fromJson(needs.define-matrix.outputs.architectures) }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
ref: ${{ github.event.pull_request.head.sha }}
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v3.7.0
|
||||
with:
|
||||
image: tonistiigi/binfmt:latest
|
||||
- name: Login to GHCR
|
||||
uses: docker/login-action@v4
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.repository_owner }}
|
||||
@@ -86,79 +85,33 @@ jobs:
|
||||
run: |
|
||||
echo REPO_OWNER=$(echo ${{ github.repository_owner }} | tr '[:upper:]' '[:lower:]') >> $GITHUB_ENV
|
||||
- name: Build and push app image
|
||||
id: build
|
||||
if: "!github.event.pull_request.head.repo.fork"
|
||||
run: |
|
||||
./containers/build.sh -i openhands -o ${{ env.REPO_OWNER }} --push --arch ${{ matrix.arch }}
|
||||
./containers/build.sh -i openhands -o ${{ env.REPO_OWNER }} --push -p ${{ needs.define-matrix.outputs.platforms }}
|
||||
|
||||
# Output base tags (without arch suffix) for the merge job
|
||||
./containers/build.sh -i openhands -o ${{ env.REPO_OWNER }} --arch ${{ matrix.arch }} --dry
|
||||
BASE_TAGS=$(jq -r '.base_tags | join("\n")' docker-build-dry.json)
|
||||
echo "base_tags<<EOF" >> "$GITHUB_OUTPUT"
|
||||
echo "$BASE_TAGS" >> "$GITHUB_OUTPUT"
|
||||
echo "EOF" >> "$GITHUB_OUTPUT"
|
||||
|
||||
# Merges per-architecture app images into a multi-arch manifest
|
||||
ghcr_build_app_merge:
|
||||
name: Merge App Multi-Arch Manifest
|
||||
runs-on: ubuntu-22.04
|
||||
needs: [define-matrix, ghcr_build_app]
|
||||
if: github.event.pull_request.head.repo.fork != true
|
||||
permissions:
|
||||
packages: write
|
||||
steps:
|
||||
- name: Login to GHCR
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.repository_owner }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
- name: Merge multi-arch manifest
|
||||
run: |
|
||||
ARCHS='${{ join(fromJson(needs.define-matrix.outputs.architectures), ' ') }}'
|
||||
TAGS="${{ needs.ghcr_build_app.outputs.base_tags }}"
|
||||
while IFS= read -r tag; do
|
||||
[[ -z "$tag" ]] && continue
|
||||
sources=""
|
||||
for arch in $ARCHS; do
|
||||
if ! docker buildx imagetools inspect "${tag}-${arch}" > /dev/null 2>&1; then
|
||||
echo "::error::Missing image ${tag}-${arch}"
|
||||
exit 1
|
||||
fi
|
||||
sources+=" ${tag}-${arch}"
|
||||
done
|
||||
echo "Creating manifest for $tag from:$sources"
|
||||
docker buildx imagetools create -t "$tag" $sources
|
||||
done <<< "$TAGS"
|
||||
|
||||
# Builds the runtime Docker images (one per architecture x base_image, natively)
|
||||
# Builds the runtime Docker images
|
||||
ghcr_build_runtime:
|
||||
name: Build Runtime Image (${{ matrix.base_image.tag }}, ${{ matrix.arch }})
|
||||
runs-on: ${{ matrix.arch == 'arm64' && 'ubuntu-24.04-arm' || 'ubuntu-22.04' }}
|
||||
name: Build Runtime Image
|
||||
runs-on: ubuntu-22.04
|
||||
if: "!(github.event_name == 'push' && startsWith(github.ref, 'refs/tags/ext-v'))"
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
needs: define-matrix
|
||||
outputs:
|
||||
# Keyed by base_image tag so the merge job can access per-variant tags.
|
||||
# Matrix outputs from different entries with the same key overwrite each other,
|
||||
# but all arch variants of the same base_image produce identical base tags.
|
||||
base_tags_nikolaik: ${{ steps.params.outputs.base_tags_nikolaik }}
|
||||
base_tags_ubuntu: ${{ steps.params.outputs.base_tags_ubuntu }}
|
||||
strategy:
|
||||
matrix:
|
||||
base_image: ${{ fromJson(needs.define-matrix.outputs.base_image) }}
|
||||
arch: ${{ fromJson(needs.define-matrix.outputs.architectures) }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
ref: ${{ github.event.pull_request.head.sha }}
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v3.7.0
|
||||
with:
|
||||
image: tonistiigi/binfmt:latest
|
||||
- name: Login to GHCR
|
||||
uses: docker/login-action@v4
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.repository_owner }}
|
||||
@@ -184,22 +137,16 @@ jobs:
|
||||
run: |
|
||||
echo SHORT_SHA=$(git rev-parse --short "$RELEVANT_SHA") >> $GITHUB_ENV
|
||||
- name: Determine docker build params
|
||||
id: params
|
||||
if: github.event.pull_request.head.repo.fork != true
|
||||
shell: bash
|
||||
run: |
|
||||
./containers/build.sh -i runtime -o ${{ env.REPO_OWNER }} -t ${{ matrix.base_image.tag }} --arch ${{ matrix.arch }} --dry
|
||||
|
||||
./containers/build.sh -i runtime -o ${{ env.REPO_OWNER }} -t ${{ matrix.base_image.tag }} --dry -p ${{ matrix.base_image.platforms }}
|
||||
|
||||
DOCKER_BUILD_JSON=$(jq -c . < docker-build-dry.json)
|
||||
echo "DOCKER_TAGS=$(echo "$DOCKER_BUILD_JSON" | jq -r '.tags | join(",")')" >> $GITHUB_ENV
|
||||
echo "DOCKER_PLATFORM=$(echo "$DOCKER_BUILD_JSON" | jq -r '.platform')" >> $GITHUB_ENV
|
||||
echo "DOCKER_BUILD_ARGS=$(echo "$DOCKER_BUILD_JSON" | jq -r '.build_args | join(",")')" >> $GITHUB_ENV
|
||||
|
||||
# Output base tags (without arch suffix) keyed by base_image tag for the merge job
|
||||
BASE_TAGS=$(echo "$DOCKER_BUILD_JSON" | jq -r '.base_tags | join("\n")')
|
||||
echo "base_tags_${{ matrix.base_image.tag }}<<EOF" >> "$GITHUB_OUTPUT"
|
||||
echo "$BASE_TAGS" >> "$GITHUB_OUTPUT"
|
||||
echo "EOF" >> "$GITHUB_OUTPUT"
|
||||
- name: Build and push runtime image ${{ matrix.base_image.image }}
|
||||
if: github.event.pull_request.head.repo.fork != true
|
||||
uses: docker/build-push-action@v6
|
||||
@@ -207,8 +154,9 @@ jobs:
|
||||
push: true
|
||||
tags: ${{ env.DOCKER_TAGS }}
|
||||
platforms: ${{ env.DOCKER_PLATFORM }}
|
||||
cache-from: type=registry,ref=ghcr.io/${{ env.REPO_OWNER }}/runtime:buildcache-${{ matrix.base_image.tag }}-${{ matrix.arch }}
|
||||
cache-to: type=registry,ref=ghcr.io/${{ env.REPO_OWNER }}/runtime:buildcache-${{ matrix.base_image.tag }}-${{ matrix.arch }},mode=max
|
||||
# Caching directives to boost performance
|
||||
cache-from: type=registry,ref=ghcr.io/${{ env.REPO_OWNER }}/runtime:buildcache-${{ matrix.base_image.tag }}
|
||||
cache-to: type=registry,ref=ghcr.io/${{ env.REPO_OWNER }}/runtime:buildcache-${{ matrix.base_image.tag }},mode=max
|
||||
build-args: ${{ env.DOCKER_BUILD_ARGS }}
|
||||
context: containers/runtime
|
||||
provenance: false
|
||||
@@ -221,66 +169,20 @@ jobs:
|
||||
context: containers/runtime
|
||||
- name: Upload runtime source for fork
|
||||
if: github.event.pull_request.head.repo.fork
|
||||
uses: actions/upload-artifact@v7
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: runtime-src-${{ matrix.base_image.tag }}-${{ matrix.arch }}
|
||||
name: runtime-src-${{ matrix.base_image.tag }}
|
||||
path: containers/runtime
|
||||
|
||||
# Merges per-architecture runtime images into multi-arch manifests
|
||||
ghcr_build_runtime_merge:
|
||||
name: Merge Runtime Multi-Arch Manifest
|
||||
runs-on: ubuntu-22.04
|
||||
needs: [define-matrix, ghcr_build_runtime]
|
||||
if: github.event.pull_request.head.repo.fork != true
|
||||
permissions:
|
||||
packages: write
|
||||
steps:
|
||||
- name: Login to GHCR
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.repository_owner }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
- name: Merge multi-arch manifests
|
||||
run: |
|
||||
ARCHS='${{ join(fromJson(needs.define-matrix.outputs.architectures), ' ') }}'
|
||||
|
||||
# Merge all runtime base_image variants
|
||||
for variant_tags in \
|
||||
"${{ needs.ghcr_build_runtime.outputs.base_tags_nikolaik }}" \
|
||||
"${{ needs.ghcr_build_runtime.outputs.base_tags_ubuntu }}"; do
|
||||
while IFS= read -r tag; do
|
||||
[[ -z "$tag" ]] && continue
|
||||
sources=""
|
||||
for arch in $ARCHS; do
|
||||
if ! docker buildx imagetools inspect "${tag}-${arch}" > /dev/null 2>&1; then
|
||||
echo "::error::Missing image ${tag}-${arch}"
|
||||
exit 1
|
||||
fi
|
||||
sources+=" ${tag}-${arch}"
|
||||
done
|
||||
echo "Creating manifest for $tag from:$sources"
|
||||
docker buildx imagetools create -t "$tag" $sources
|
||||
done <<< "$variant_tags"
|
||||
done
|
||||
|
||||
ghcr_build_enterprise:
|
||||
name: Push Enterprise Image (${{ matrix.arch }})
|
||||
runs-on: ${{ matrix.arch == 'arm64' && 'ubuntu-24.04-arm' || 'ubuntu-22.04' }}
|
||||
name: Push Enterprise Image
|
||||
runs-on: ubuntu-22.04
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
needs: [define-matrix, ghcr_build_app_merge]
|
||||
needs: [define-matrix, ghcr_build_app]
|
||||
# Do not build enterprise in forks
|
||||
if: github.event.pull_request.head.repo.fork != true
|
||||
outputs:
|
||||
# Tags without arch suffix, for the merge job
|
||||
base_tags: ${{ steps.meta_base.outputs.tags }}
|
||||
strategy:
|
||||
matrix:
|
||||
arch: ${{ fromJson(needs.define-matrix.outputs.architectures) }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v6
|
||||
@@ -294,7 +196,7 @@ jobs:
|
||||
driver-opts: network=host
|
||||
|
||||
- name: Login to GHCR
|
||||
uses: docker/login-action@v4
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.repository_owner }}
|
||||
@@ -302,28 +204,6 @@ jobs:
|
||||
|
||||
- name: Extract metadata (tags, labels) for Docker
|
||||
id: meta
|
||||
uses: docker/metadata-action@v6
|
||||
with:
|
||||
images: ghcr.io/openhands/enterprise-server
|
||||
tags: |
|
||||
type=ref,event=branch
|
||||
type=ref,event=pr
|
||||
type=sha
|
||||
type=sha,format=long
|
||||
type=semver,pattern={{version}}
|
||||
type=semver,pattern={{major}}.{{minor}}
|
||||
type=semver,pattern={{major}}
|
||||
type=match,pattern=cloud-\d+\.\d+\.\d+
|
||||
flavor: |
|
||||
latest=auto
|
||||
prefix=
|
||||
suffix=-${{ matrix.arch }}
|
||||
env:
|
||||
DOCKER_METADATA_PR_HEAD_SHA: true
|
||||
|
||||
# Also compute base tags (no arch suffix) for the merge job output
|
||||
- name: Extract base metadata for merge
|
||||
id: meta_base
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: ghcr.io/openhands/enterprise-server
|
||||
@@ -342,7 +222,6 @@ jobs:
|
||||
suffix=
|
||||
env:
|
||||
DOCKER_METADATA_PR_HEAD_SHA: true
|
||||
|
||||
- name: Determine app image tag
|
||||
shell: bash
|
||||
run: |
|
||||
@@ -359,49 +238,12 @@ jobs:
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
build-args: |
|
||||
OPENHANDS_VERSION=${{ env.OPENHANDS_DOCKER_TAG }}
|
||||
platforms: linux/${{ matrix.arch }}
|
||||
platforms: linux/amd64
|
||||
# Add build provenance
|
||||
provenance: true
|
||||
# Add build attestations for better security
|
||||
sbom: true
|
||||
|
||||
# Merges per-architecture enterprise images into a multi-arch manifest
|
||||
ghcr_build_enterprise_merge:
|
||||
name: Merge Enterprise Multi-Arch Manifest
|
||||
runs-on: ubuntu-22.04
|
||||
permissions:
|
||||
packages: write
|
||||
needs: [define-matrix, ghcr_build_enterprise]
|
||||
if: github.event.pull_request.head.repo.fork != true
|
||||
steps:
|
||||
- name: Login to GHCR
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.repository_owner }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Merge multi-arch manifest
|
||||
run: |
|
||||
ARCHS='${{ join(fromJson(needs.define-matrix.outputs.architectures), ' ') }}'
|
||||
TAGS="${{ needs.ghcr_build_enterprise.outputs.base_tags }}"
|
||||
while IFS= read -r tag; do
|
||||
[[ -z "$tag" ]] && continue
|
||||
sources=""
|
||||
for arch in $ARCHS; do
|
||||
if ! docker buildx imagetools inspect "${tag}-${arch}" > /dev/null 2>&1; then
|
||||
echo "::error::Missing image ${tag}-${arch}"
|
||||
exit 1
|
||||
fi
|
||||
sources+=" ${tag}-${arch}"
|
||||
done
|
||||
echo "Creating manifest for $tag from:$sources"
|
||||
docker buildx imagetools create -t "$tag" $sources
|
||||
done <<< "$TAGS"
|
||||
|
||||
# "All Runtime Tests Passed" is a required job for PRs to merge
|
||||
# We can remove this once the config changes
|
||||
runtime_tests_check_success:
|
||||
@@ -414,7 +256,7 @@ jobs:
|
||||
update_pr_description:
|
||||
name: Update PR Description
|
||||
if: github.event_name == 'pull_request' && !github.event.pull_request.head.repo.fork && github.actor != 'dependabot[bot]'
|
||||
needs: [ghcr_build_runtime_merge]
|
||||
needs: [ghcr_build_runtime]
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout
|
||||
|
||||
@@ -269,7 +269,7 @@ jobs:
|
||||
fi
|
||||
|
||||
- name: Upload output.jsonl as artifact
|
||||
uses: actions/upload-artifact@v7
|
||||
uses: actions/upload-artifact@v6
|
||||
if: always() # Upload even if the previous steps fail
|
||||
with:
|
||||
name: resolver-output
|
||||
|
||||
@@ -31,7 +31,7 @@ jobs:
|
||||
echo "is_fork=false" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v5
|
||||
if: steps.check-fork.outputs.is_fork == 'false'
|
||||
with:
|
||||
ref: ${{ github.event.pull_request.head.ref }}
|
||||
@@ -93,7 +93,7 @@ jobs:
|
||||
contents: read
|
||||
pull-requests: write
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v5
|
||||
|
||||
- name: Check for .pr/ directory
|
||||
id: check
|
||||
|
||||
@@ -51,7 +51,7 @@ jobs:
|
||||
# Always checkout main branch for security - cannot test script changes in PRs
|
||||
- name: Checkout extensions repository
|
||||
if: steps.check-trace.outputs.trace_exists == 'true'
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v5
|
||||
with:
|
||||
repository: OpenHands/extensions
|
||||
path: extensions
|
||||
@@ -77,7 +77,7 @@ jobs:
|
||||
--trace-file trace-info/laminar_trace_info.json
|
||||
|
||||
- name: Upload evaluation logs
|
||||
uses: actions/upload-artifact@v7
|
||||
uses: actions/upload-artifact@v5
|
||||
if: always() && steps.check-trace.outputs.trace_exists == 'true'
|
||||
with:
|
||||
name: pr-review-evaluation-${{ github.event.pull_request.number }}
|
||||
|
||||
@@ -65,7 +65,7 @@ jobs:
|
||||
env:
|
||||
COVERAGE_FILE: ".coverage.runtime.${{ matrix.python_version }}"
|
||||
- name: Store coverage file
|
||||
uses: actions/upload-artifact@v7
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: coverage-openhands
|
||||
path: |
|
||||
@@ -97,7 +97,7 @@ jobs:
|
||||
env:
|
||||
COVERAGE_FILE: ".coverage.enterprise.${{ matrix.python_version }}"
|
||||
- name: Store coverage file
|
||||
uses: actions/upload-artifact@v7
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: coverage-enterprise
|
||||
path: ".coverage.enterprise.${{ matrix.python_version }}"
|
||||
|
||||
+9
-30
@@ -8,18 +8,18 @@ push=0
|
||||
load=0
|
||||
tag_suffix=""
|
||||
dry_run=0
|
||||
arch_suffix=""
|
||||
platform_override=""
|
||||
|
||||
# Function to display usage information
|
||||
usage() {
|
||||
echo "Usage: $0 -i <image_name> [-o <org_name>] [--push] [--load] [-t <tag_suffix>] [--dry] [--arch <arch>]"
|
||||
echo "Usage: $0 -i <image_name> [-o <org_name>] [--push] [--load] [-t <tag_suffix>] [-p <platform>] [--dry]"
|
||||
echo " -i: Image name (required)"
|
||||
echo " -o: Organization name"
|
||||
echo " --push: Push the image"
|
||||
echo " --load: Load the image"
|
||||
echo " -t: Tag suffix"
|
||||
echo " -p: Platform(s) to build for (e.g. linux/amd64 or linux/amd64,linux/arm64)"
|
||||
echo " --dry: Don't build, only create build-args.json"
|
||||
echo " --arch: Architecture suffix (e.g. amd64 or arm64). Appends -<arch> to tags and forces single-platform build"
|
||||
exit 1
|
||||
}
|
||||
|
||||
@@ -31,8 +31,8 @@ while [[ $# -gt 0 ]]; do
|
||||
--push) push=1; shift ;;
|
||||
--load) load=1; shift ;;
|
||||
-t) tag_suffix="$2"; shift 2 ;;
|
||||
-p) platform_override="$2"; shift 2 ;;
|
||||
--dry) dry_run=1; shift ;;
|
||||
--arch) arch_suffix="$2"; shift 2 ;;
|
||||
*) usage ;;
|
||||
esac
|
||||
done
|
||||
@@ -78,7 +78,7 @@ if [[ -n $tag_suffix ]]; then
|
||||
done
|
||||
fi
|
||||
|
||||
echo "Tags (before arch suffix): ${tags[@]}"
|
||||
echo "Tags: ${tags[@]}"
|
||||
|
||||
if [[ "$image_name" == "openhands" ]]; then
|
||||
dir="./containers/app"
|
||||
@@ -113,21 +113,10 @@ if [[ -n "$DOCKER_IMAGE_TAG" ]]; then
|
||||
tags+=("$DOCKER_IMAGE_TAG")
|
||||
fi
|
||||
|
||||
# Apply architecture suffix for split-arch builds (after all tags are collected)
|
||||
if [[ -n "$arch_suffix" ]]; then
|
||||
cache_tag+="-${arch_suffix}"
|
||||
for i in "${!tags[@]}"; do
|
||||
tags[$i]="${tags[$i]}-${arch_suffix}"
|
||||
done
|
||||
# Force single-platform build for this architecture
|
||||
arch_platform="linux/${arch_suffix}"
|
||||
fi
|
||||
|
||||
DOCKER_REPOSITORY="$DOCKER_REGISTRY/$DOCKER_ORG/$DOCKER_IMAGE"
|
||||
DOCKER_REPOSITORY=${DOCKER_REPOSITORY,,} # lowercase
|
||||
echo "Repo: $DOCKER_REPOSITORY"
|
||||
echo "Base dir: $DOCKER_BASE_DIR"
|
||||
echo "Tags: ${tags[@]}"
|
||||
|
||||
args=""
|
||||
full_tags=()
|
||||
@@ -136,6 +125,7 @@ for tag in "${tags[@]}"; do
|
||||
full_tags+=("$DOCKER_REPOSITORY:$tag")
|
||||
done
|
||||
|
||||
|
||||
if [[ $push -eq 1 ]]; then
|
||||
args+=" --push"
|
||||
args+=" --cache-to=type=registry,ref=$DOCKER_REPOSITORY:$cache_tag,mode=max"
|
||||
@@ -148,8 +138,8 @@ fi
|
||||
echo "Args: $args"
|
||||
|
||||
# Determine the platform(s) to build for
|
||||
if [[ -n "$arch_platform" ]]; then
|
||||
platform="$arch_platform"
|
||||
if [[ -n "$platform_override" ]]; then
|
||||
platform="$platform_override"
|
||||
elif [[ $load -eq 1 ]]; then
|
||||
# When loading, build only for the current platform
|
||||
platform=$(docker version -f '{{.Server.Os}}/{{.Server.Arch}}')
|
||||
@@ -159,24 +149,13 @@ else
|
||||
fi
|
||||
if [[ $dry_run -eq 1 ]]; then
|
||||
echo "Dry Run is enabled. Writing build config to docker-build-dry.json"
|
||||
# Compute base tags (arch suffix stripped) for use by merge jobs
|
||||
base_tags=()
|
||||
for ftag in "${full_tags[@]}"; do
|
||||
if [[ -n "$arch_suffix" ]]; then
|
||||
base_tags+=("${ftag%-${arch_suffix}}")
|
||||
else
|
||||
base_tags+=("$ftag")
|
||||
fi
|
||||
done
|
||||
jq -n \
|
||||
--argjson tags "$(printf '%s\n' "${full_tags[@]}" | jq -R . | jq -s .)" \
|
||||
--argjson base_tags "$(printf '%s\n' "${base_tags[@]}" | jq -R . | jq -s .)" \
|
||||
--arg platform "$platform" \
|
||||
--arg openhands_build_version "$OPENHANDS_BUILD_VERSION" \
|
||||
--arg dockerfile "$dir/Dockerfile" \
|
||||
'{
|
||||
tags: $tags,
|
||||
base_tags: $base_tags,
|
||||
platform: $platform,
|
||||
build_args: [
|
||||
"OPENHANDS_BUILD_VERSION=" + $openhands_build_version
|
||||
@@ -195,7 +174,7 @@ docker buildx build \
|
||||
$args \
|
||||
--build-arg OPENHANDS_BUILD_VERSION="$OPENHANDS_BUILD_VERSION" \
|
||||
--cache-from=type=registry,ref=$DOCKER_REPOSITORY:$cache_tag \
|
||||
--cache-from=type=registry,ref=$DOCKER_REPOSITORY:${cache_tag_base}-main${arch_suffix:+-${arch_suffix}} \
|
||||
--cache-from=type=registry,ref=$DOCKER_REPOSITORY:$cache_tag_base-main \
|
||||
--platform $platform \
|
||||
--provenance=false \
|
||||
-f "$dir/Dockerfile" \
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
# PolyForm Free Trial License 1.0.0
|
||||
|
||||
Copyright (c) 2026 All Hands AI
|
||||
|
||||
## Acceptance
|
||||
|
||||
In order to get any license under these terms, you must agree
|
||||
|
||||
@@ -59,7 +59,7 @@ handlers = console
|
||||
qualname =
|
||||
|
||||
[logger_sqlalchemy]
|
||||
level = WARNING
|
||||
level = DEBUG
|
||||
handlers =
|
||||
qualname = sqlalchemy.engine
|
||||
|
||||
|
||||
@@ -1729,7 +1729,7 @@
|
||||
"syncMode": "IMPORT",
|
||||
"clientSecret": "$GITHUB_APP_CLIENT_SECRET",
|
||||
"caseSensitiveOriginalUsername": "false",
|
||||
"defaultScope": "openid email profile",
|
||||
"defaultScope": "openid email profile notifications",
|
||||
"baseUrl": "$GITHUB_BASE_URL"
|
||||
}
|
||||
},
|
||||
|
||||
@@ -24,20 +24,20 @@ from integrations.jira.jira_types import (
|
||||
RepositoryNotFoundError,
|
||||
StartingConvoException,
|
||||
)
|
||||
from integrations.jira.jira_view import JiraFactory
|
||||
from integrations.jira.jira_view import JiraFactory, JiraNewConversationView
|
||||
from integrations.manager import Manager
|
||||
from integrations.models import Message
|
||||
from integrations.utils import (
|
||||
HOST,
|
||||
HOST_URL,
|
||||
OPENHANDS_RESOLVER_TEMPLATES_DIR,
|
||||
format_jira_comment_body,
|
||||
get_oh_labels,
|
||||
get_session_expired_message,
|
||||
)
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
from server.auth.saas_user_auth import get_user_auth_from_keycloak_id
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.utils.conversation_callback_utils import register_callback_processor
|
||||
from storage.jira_integration_store import JiraIntegrationStore
|
||||
from storage.jira_user import JiraUser
|
||||
from storage.jira_workspace import JiraWorkspace
|
||||
@@ -259,6 +259,11 @@ class JiraManager(Manager[JiraViewInterface]):
|
||||
|
||||
async def start_job(self, view: JiraViewInterface) -> None:
|
||||
"""Start a Jira job/conversation."""
|
||||
# Import here to prevent circular import
|
||||
from server.conversation_callback_processor.jira_callback_processor import (
|
||||
JiraCallbackProcessor,
|
||||
)
|
||||
|
||||
try:
|
||||
logger.info(
|
||||
'[Jira] Starting job',
|
||||
@@ -280,7 +285,19 @@ class JiraManager(Manager[JiraViewInterface]):
|
||||
},
|
||||
)
|
||||
|
||||
# Create success message
|
||||
# Register callback processor for updates
|
||||
if isinstance(view, JiraNewConversationView):
|
||||
processor = JiraCallbackProcessor(
|
||||
issue_key=view.payload.issue_key,
|
||||
workspace_name=view.jira_workspace.name,
|
||||
)
|
||||
register_callback_processor(conversation_id, processor)
|
||||
logger.info(
|
||||
'[Jira] Callback processor registered',
|
||||
extra={'conversation_id': conversation_id},
|
||||
)
|
||||
|
||||
# Send success response
|
||||
msg_info = view.get_response_msg()
|
||||
|
||||
except MissingSettingsError as e:
|
||||
@@ -342,7 +359,7 @@ class JiraManager(Manager[JiraViewInterface]):
|
||||
url = (
|
||||
f'{JIRA_CLOUD_API_URL}/{jira_cloud_id}/rest/api/2/issue/{issue_key}/comment'
|
||||
)
|
||||
data = format_jira_comment_body(message)
|
||||
data = {'body': message}
|
||||
async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
|
||||
response = await client.post(
|
||||
url, auth=(svc_acc_email, svc_acc_api_key), json=data
|
||||
|
||||
@@ -136,10 +136,11 @@ class JiraPayloadParser:
|
||||
items = changelog.get('items', [])
|
||||
|
||||
# Extract labels that were added
|
||||
labels = set()
|
||||
for item in items:
|
||||
if item.get('field') == 'labels' and item.get('toString'):
|
||||
labels.update(item['toString'].split())
|
||||
labels = [
|
||||
item.get('toString', '')
|
||||
for item in items
|
||||
if item.get('field') == 'labels' and 'toString' in item
|
||||
]
|
||||
|
||||
if self.oh_label not in labels:
|
||||
return JiraPayloadSkipped(
|
||||
|
||||
@@ -1,238 +0,0 @@
|
||||
import logging
|
||||
from uuid import UUID
|
||||
|
||||
import httpx
|
||||
from integrations.utils import format_jira_comment_body, get_summary_instruction
|
||||
from pydantic import Field
|
||||
|
||||
from openhands.agent_server.models import AskAgentRequest, AskAgentResponse
|
||||
from openhands.app_server.event_callback.event_callback_models import (
|
||||
EventCallback,
|
||||
EventCallbackProcessor,
|
||||
)
|
||||
from openhands.app_server.event_callback.event_callback_result_models import (
|
||||
EventCallbackResult,
|
||||
EventCallbackResultStatus,
|
||||
)
|
||||
from openhands.app_server.event_callback.util import (
|
||||
ensure_conversation_found,
|
||||
ensure_running_sandbox,
|
||||
get_agent_server_url_from_sandbox,
|
||||
)
|
||||
from openhands.sdk import Event
|
||||
from openhands.sdk.event import ConversationStateUpdateEvent
|
||||
from openhands.utils.http_session import httpx_verify_option
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
JIRA_CLOUD_API_URL = 'https://api.atlassian.com/ex/jira'
|
||||
|
||||
|
||||
class JiraV1CallbackProcessor(EventCallbackProcessor):
|
||||
"""Callback processor for Jira V1 integrations."""
|
||||
|
||||
should_request_summary: bool = Field(default=True)
|
||||
svc_acc_email: str
|
||||
decrypted_api_key: str
|
||||
issue_key: str
|
||||
jira_cloud_id: str
|
||||
|
||||
async def __call__(
|
||||
self,
|
||||
conversation_id: UUID,
|
||||
callback: EventCallback,
|
||||
event: Event,
|
||||
) -> EventCallbackResult | None:
|
||||
"""Process events for Jira V1 integration."""
|
||||
# Only handle ConversationStateUpdateEvent for execution_status
|
||||
if not isinstance(event, ConversationStateUpdateEvent):
|
||||
return None
|
||||
|
||||
if event.key != 'execution_status':
|
||||
return None
|
||||
|
||||
_logger.info('[Jira] Callback agent state was %s', event)
|
||||
|
||||
# Only request summary when execution has finished successfully
|
||||
if event.value != 'finished':
|
||||
return None
|
||||
|
||||
_logger.info('[Jira] Should request summary: %s', self.should_request_summary)
|
||||
|
||||
if not self.should_request_summary:
|
||||
return None
|
||||
|
||||
self.should_request_summary = False
|
||||
|
||||
try:
|
||||
_logger.info(f'[Jira] Requesting summary {conversation_id}')
|
||||
summary = await self._request_summary(conversation_id)
|
||||
_logger.info(
|
||||
f'[Jira] Posting summary {conversation_id}',
|
||||
extra={'summary': summary},
|
||||
)
|
||||
await self._post_summary_to_jira(summary)
|
||||
|
||||
return EventCallbackResult(
|
||||
status=EventCallbackResultStatus.SUCCESS,
|
||||
event_callback_id=callback.id,
|
||||
event_id=event.id,
|
||||
conversation_id=conversation_id,
|
||||
detail=summary,
|
||||
)
|
||||
except Exception as e:
|
||||
_logger.exception(f'[Jira] Failed to post summary: {e}', stack_info=True)
|
||||
return EventCallbackResult(
|
||||
status=EventCallbackResultStatus.ERROR,
|
||||
event_callback_id=callback.id,
|
||||
event_id=event.id,
|
||||
conversation_id=conversation_id,
|
||||
detail=str(e),
|
||||
)
|
||||
|
||||
async def _request_summary(self, conversation_id: UUID) -> str:
|
||||
"""Ask the agent to produce a summary of its work and return the agent response."""
|
||||
# Import services within the method to avoid circular imports
|
||||
from openhands.app_server.config import (
|
||||
get_app_conversation_info_service,
|
||||
get_httpx_client,
|
||||
get_sandbox_service,
|
||||
)
|
||||
from openhands.app_server.services.injector import InjectorState
|
||||
from openhands.app_server.user.specifiy_user_context import (
|
||||
ADMIN,
|
||||
USER_CONTEXT_ATTR,
|
||||
)
|
||||
|
||||
# Create injector state for dependency injection
|
||||
state = InjectorState()
|
||||
setattr(state, USER_CONTEXT_ATTR, ADMIN)
|
||||
|
||||
async with (
|
||||
get_app_conversation_info_service(state) as app_conversation_info_service,
|
||||
get_sandbox_service(state) as sandbox_service,
|
||||
get_httpx_client(state) as httpx_client,
|
||||
):
|
||||
# 1. Conversation lookup
|
||||
app_conversation_info = ensure_conversation_found(
|
||||
await app_conversation_info_service.get_app_conversation_info(
|
||||
conversation_id
|
||||
),
|
||||
conversation_id,
|
||||
)
|
||||
|
||||
# 2. Sandbox lookup + validation
|
||||
sandbox = ensure_running_sandbox(
|
||||
await sandbox_service.get_sandbox(app_conversation_info.sandbox_id),
|
||||
app_conversation_info.sandbox_id,
|
||||
)
|
||||
|
||||
assert (
|
||||
sandbox.session_api_key is not None
|
||||
), f'No session API key for sandbox: {sandbox.id}'
|
||||
|
||||
# 3. URL + instruction
|
||||
agent_server_url = get_agent_server_url_from_sandbox(sandbox)
|
||||
|
||||
# Prepare message based on agent state
|
||||
message_content = get_summary_instruction()
|
||||
|
||||
# Ask the agent and return the response text
|
||||
return await self._ask_question(
|
||||
httpx_client=httpx_client,
|
||||
agent_server_url=agent_server_url,
|
||||
conversation_id=conversation_id,
|
||||
session_api_key=sandbox.session_api_key,
|
||||
message_content=message_content,
|
||||
)
|
||||
|
||||
async def _ask_question(
|
||||
self,
|
||||
httpx_client: httpx.AsyncClient,
|
||||
agent_server_url: str,
|
||||
conversation_id: UUID,
|
||||
session_api_key: str,
|
||||
message_content: str,
|
||||
) -> str:
|
||||
"""Send a message to the agent server via the V1 API and return response text."""
|
||||
send_message_request = AskAgentRequest(question=message_content)
|
||||
|
||||
url = (
|
||||
f"{agent_server_url.rstrip('/')}"
|
||||
f"/api/conversations/{conversation_id}/ask_agent"
|
||||
)
|
||||
headers = {'X-Session-API-Key': session_api_key}
|
||||
payload = send_message_request.model_dump()
|
||||
|
||||
try:
|
||||
response = await httpx_client.post(
|
||||
url,
|
||||
json=payload,
|
||||
headers=headers,
|
||||
timeout=30.0,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
agent_response = AskAgentResponse.model_validate(response.json())
|
||||
return agent_response.response
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
error_detail = f'HTTP {e.response.status_code} error'
|
||||
try:
|
||||
error_body = e.response.text
|
||||
if error_body:
|
||||
error_detail += f': {error_body}'
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
_logger.exception(
|
||||
'[Jira] HTTP error sending message to %s: %s. '
|
||||
'Request payload: %s. Response headers: %s',
|
||||
url,
|
||||
error_detail,
|
||||
payload,
|
||||
dict(e.response.headers),
|
||||
stack_info=True,
|
||||
)
|
||||
raise Exception(f'Failed to send message to agent server: {error_detail}')
|
||||
|
||||
except httpx.TimeoutException:
|
||||
error_detail = f'Request timeout after 30 seconds to {url}'
|
||||
_logger.exception(
|
||||
'[Jira] Timeout error: %s. Request payload: %s',
|
||||
error_detail,
|
||||
payload,
|
||||
stack_info=True,
|
||||
)
|
||||
raise Exception(f'Failed to send message to agent server: {error_detail}')
|
||||
|
||||
async def _post_summary_to_jira(self, summary: str):
|
||||
"""Post the summary back to the Jira issue."""
|
||||
if not all(
|
||||
[
|
||||
self.svc_acc_email,
|
||||
self.decrypted_api_key,
|
||||
self.issue_key,
|
||||
self.jira_cloud_id,
|
||||
]
|
||||
):
|
||||
_logger.warning('[Jira] Missing required data for posting summary')
|
||||
return
|
||||
|
||||
# Add a comment to the Jira issue with the summary
|
||||
comment_url = (
|
||||
f'{JIRA_CLOUD_API_URL}/{self.jira_cloud_id}'
|
||||
f'/rest/api/2/issue/{self.issue_key}/comment'
|
||||
)
|
||||
|
||||
message = f'OpenHands resolved this issue:\n\n{summary}'
|
||||
comment_body = format_jira_comment_body(message)
|
||||
|
||||
async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
|
||||
response = await client.post(
|
||||
comment_url,
|
||||
auth=(self.svc_acc_email, self.decrypted_api_key),
|
||||
json=comment_body,
|
||||
)
|
||||
response.raise_for_status()
|
||||
_logger.info(f'[Jira] Posted summary to {self.issue_key}')
|
||||
@@ -7,7 +7,7 @@ Views are responsible for:
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from uuid import UUID, uuid4
|
||||
from uuid import uuid4
|
||||
|
||||
import httpx
|
||||
from integrations.jira.jira_payload import JiraWebhookPayload
|
||||
@@ -16,37 +16,25 @@ from integrations.jira.jira_types import (
|
||||
RepositoryNotFoundError,
|
||||
StartingConvoException,
|
||||
)
|
||||
from integrations.jira.jira_v1_callback_processor import (
|
||||
JiraV1CallbackProcessor,
|
||||
)
|
||||
from integrations.resolver_context import ResolverUserContext
|
||||
from integrations.resolver_org_router import resolve_org_for_repo
|
||||
from integrations.utils import (
|
||||
CONVERSATION_URL,
|
||||
infer_repo_from_message,
|
||||
)
|
||||
from integrations.utils import CONVERSATION_URL, infer_repo_from_message
|
||||
from jinja2 import Environment
|
||||
from server.config import get_config
|
||||
from storage.jira_conversation import JiraConversation
|
||||
from storage.jira_integration_store import JiraIntegrationStore
|
||||
from storage.jira_user import JiraUser
|
||||
from storage.jira_workspace import JiraWorkspace
|
||||
from storage.saas_conversation_store import SaasConversationStore
|
||||
|
||||
from openhands.agent_server.models import SendMessageRequest
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
AppConversationStartRequest,
|
||||
AppConversationStartTaskStatus,
|
||||
)
|
||||
from openhands.app_server.config import get_app_conversation_service
|
||||
from openhands.app_server.services.injector import InjectorState
|
||||
from openhands.app_server.user.specifiy_user_context import USER_CONTEXT_ATTR
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.provider import ProviderHandler, ProviderType
|
||||
from openhands.sdk import TextContent
|
||||
from openhands.integrations.provider import ProviderHandler
|
||||
from openhands.server.services.conversation_service import start_conversation
|
||||
from openhands.server.user_auth.user_auth import UserAuth
|
||||
from openhands.storage.data_models.conversation_metadata import (
|
||||
ConversationMetadata,
|
||||
ConversationTrigger,
|
||||
)
|
||||
from openhands.utils.conversation_summary import get_default_conversation_title
|
||||
from openhands.utils.http_session import httpx_verify_option
|
||||
|
||||
JIRA_CLOUD_API_URL = 'https://api.atlassian.com/ex/jira'
|
||||
@@ -66,7 +54,7 @@ class JiraNewConversationView(JiraViewInterface):
|
||||
saas_user_auth: UserAuth
|
||||
jira_user: JiraUser
|
||||
jira_workspace: JiraWorkspace
|
||||
selected_repo: str = ''
|
||||
selected_repo: str | None = None
|
||||
conversation_id: str = ''
|
||||
|
||||
# Lazy-loaded issue details (cached after first fetch)
|
||||
@@ -76,9 +64,6 @@ class JiraNewConversationView(JiraViewInterface):
|
||||
# Decrypted API key (set by factory)
|
||||
_decrypted_api_key: str = field(default='', repr=False)
|
||||
|
||||
# Resolved org ID for V1 conversations
|
||||
resolved_org_id: UUID | None = None
|
||||
|
||||
async def get_issue_details(self) -> tuple[str, str]:
|
||||
"""Fetch issue details from Jira API (cached after first call).
|
||||
|
||||
@@ -184,131 +169,107 @@ class JiraNewConversationView(JiraViewInterface):
|
||||
if not self.selected_repo:
|
||||
raise StartingConvoException('No repository selected for this conversation')
|
||||
|
||||
jira_conversation = JiraConversation(
|
||||
conversation_id=self.conversation_id,
|
||||
issue_id=self.payload.issue_id,
|
||||
issue_key=self.payload.issue_key,
|
||||
jira_user_id=self.jira_user.id,
|
||||
)
|
||||
await integration_store.create_conversation(jira_conversation)
|
||||
|
||||
conversation_metadata = await self._create_v1_metadata()
|
||||
await self._create_v1_conversation(jinja_env, conversation_metadata)
|
||||
return self.conversation_id
|
||||
|
||||
async def _create_v1_metadata(self) -> ConversationMetadata:
|
||||
"""Create conversation metadata for V1 conversations.
|
||||
|
||||
The JiraConversation mapping is saved to the integration store (above), but
|
||||
V1 conversation metadata is managed by the app conversation system, not
|
||||
the legacy conversation store.
|
||||
"""
|
||||
logger.info('[Jira]: Creating V1 metadata')
|
||||
|
||||
# Generate a dummy conversation for V1 (not saved to store)
|
||||
self.conversation_id = uuid4().hex
|
||||
self.resolved_org_id = await self._get_resolved_org_id()
|
||||
|
||||
return ConversationMetadata(
|
||||
conversation_id=self.conversation_id,
|
||||
selected_repository=self.selected_repo,
|
||||
)
|
||||
|
||||
async def _create_v1_conversation(
|
||||
self,
|
||||
jinja_env: Environment,
|
||||
conversation_metadata: ConversationMetadata,
|
||||
):
|
||||
"""Create conversation using the new V1 app conversation system."""
|
||||
logger.info('[Jira]: Creating V1 conversation')
|
||||
|
||||
initial_user_text = await self._get_v1_initial_user_message(jinja_env)
|
||||
|
||||
# Create the initial message request
|
||||
initial_message = SendMessageRequest(
|
||||
role='user', content=[TextContent(text=initial_user_text)]
|
||||
)
|
||||
|
||||
# Create the Jira V1 callback processor
|
||||
jira_callback_processor = self._create_jira_v1_callback_processor()
|
||||
|
||||
injector_state = InjectorState()
|
||||
|
||||
# Create the V1 conversation start request
|
||||
start_request = AppConversationStartRequest(
|
||||
conversation_id=UUID(conversation_metadata.conversation_id),
|
||||
system_message_suffix=None,
|
||||
initial_message=initial_message,
|
||||
selected_repository=self.selected_repo,
|
||||
selected_branch=None,
|
||||
git_provider=ProviderType.GITHUB,
|
||||
title=f'Jira Issue {self.payload.issue_key}: {self._issue_title or "Unknown"}',
|
||||
trigger=ConversationTrigger.JIRA,
|
||||
processors=[jira_callback_processor],
|
||||
)
|
||||
|
||||
# Set up the Jira user context for the V1 system
|
||||
jira_user_context = ResolverUserContext(
|
||||
saas_user_auth=self.saas_user_auth,
|
||||
resolver_org_id=self.resolved_org_id,
|
||||
)
|
||||
setattr(injector_state, USER_CONTEXT_ATTR, jira_user_context)
|
||||
|
||||
async with get_app_conversation_service(
|
||||
injector_state
|
||||
) as app_conversation_service:
|
||||
async for task in app_conversation_service.start_app_conversation(
|
||||
start_request
|
||||
):
|
||||
if task.status == AppConversationStartTaskStatus.ERROR:
|
||||
logger.error(f'Failed to start V1 conversation: {task.detail}')
|
||||
raise RuntimeError(
|
||||
f'Failed to start V1 conversation: {task.detail}'
|
||||
)
|
||||
|
||||
async def _get_v1_initial_user_message(self, jinja_env: Environment) -> str:
|
||||
"""Build the initial user message for V1 resolver conversations."""
|
||||
issue_title, issue_description = await self.get_issue_details()
|
||||
|
||||
user_msg_template = jinja_env.get_template('jira_new_conversation.j2')
|
||||
user_msg = user_msg_template.render(
|
||||
issue_key=self.payload.issue_key,
|
||||
issue_title=issue_title,
|
||||
issue_description=issue_description,
|
||||
user_message=self.payload.user_msg,
|
||||
)
|
||||
|
||||
return user_msg
|
||||
|
||||
def _create_jira_v1_callback_processor(self):
|
||||
"""Create a V1 callback processor for Jira integration."""
|
||||
return JiraV1CallbackProcessor(
|
||||
svc_acc_email=self.jira_workspace.svc_acc_email,
|
||||
decrypted_api_key=self._decrypted_api_key,
|
||||
issue_key=self.payload.issue_key,
|
||||
jira_cloud_id=self.jira_workspace.jira_cloud_id,
|
||||
)
|
||||
|
||||
async def _get_resolved_org_id(self) -> UUID | None:
|
||||
"""Resolve the org ID for V1 conversations."""
|
||||
provider_tokens = await self.saas_user_auth.get_provider_tokens()
|
||||
if not provider_tokens:
|
||||
return None
|
||||
user_secrets = await self.saas_user_auth.get_secrets()
|
||||
instructions, user_msg = await self._get_instructions(jinja_env)
|
||||
|
||||
try:
|
||||
provider_handler = ProviderHandler(provider_tokens)
|
||||
repository = await provider_handler.verify_repo_provider(self.selected_repo)
|
||||
resolved_org_id = await resolve_org_for_repo(
|
||||
provider=repository.git_provider.value,
|
||||
full_repo_name=self.selected_repo,
|
||||
keycloak_user_id=self.jira_user.keycloak_user_id,
|
||||
user_id = self.jira_user.keycloak_user_id
|
||||
|
||||
# Resolve git provider from repository
|
||||
resolved_git_provider = None
|
||||
if provider_tokens:
|
||||
try:
|
||||
provider_handler = ProviderHandler(provider_tokens)
|
||||
repository = await provider_handler.verify_repo_provider(
|
||||
self.selected_repo
|
||||
)
|
||||
resolved_git_provider = repository.git_provider
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f'[Jira] Failed to resolve git provider for {self.selected_repo}: {e}'
|
||||
)
|
||||
|
||||
# Resolve target org based on claimed git organizations
|
||||
resolved_org_id = None
|
||||
if resolved_git_provider and self.selected_repo:
|
||||
try:
|
||||
resolved_org_id = await resolve_org_for_repo(
|
||||
provider=resolved_git_provider.value,
|
||||
full_repo_name=self.selected_repo,
|
||||
keycloak_user_id=user_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f'[Jira] Failed to resolve org for {self.selected_repo}: {e}'
|
||||
)
|
||||
|
||||
# Create the conversation store with resolver org routing
|
||||
store = await SaasConversationStore.get_resolver_instance(
|
||||
get_config(),
|
||||
user_id,
|
||||
resolved_org_id,
|
||||
)
|
||||
return resolved_org_id
|
||||
|
||||
conversation_id = uuid4().hex
|
||||
conversation_metadata = ConversationMetadata(
|
||||
trigger=ConversationTrigger.JIRA,
|
||||
conversation_id=conversation_id,
|
||||
title=get_default_conversation_title(conversation_id),
|
||||
user_id=user_id,
|
||||
selected_repository=self.selected_repo,
|
||||
selected_branch=None,
|
||||
git_provider=resolved_git_provider,
|
||||
)
|
||||
await store.save_metadata(conversation_metadata)
|
||||
|
||||
await start_conversation(
|
||||
user_id=user_id,
|
||||
git_provider_tokens=provider_tokens,
|
||||
custom_secrets=user_secrets.custom_secrets if user_secrets else None,
|
||||
initial_user_msg=user_msg,
|
||||
image_urls=None,
|
||||
replay_json=None,
|
||||
conversation_id=conversation_id,
|
||||
conversation_metadata=conversation_metadata,
|
||||
conversation_instructions=instructions,
|
||||
)
|
||||
|
||||
self.conversation_id = conversation_id
|
||||
|
||||
logger.info(
|
||||
'[Jira] Created conversation',
|
||||
extra={
|
||||
'conversation_id': self.conversation_id,
|
||||
'issue_key': self.payload.issue_key,
|
||||
'selected_repo': self.selected_repo,
|
||||
'resolved_org_id': str(resolved_org_id)
|
||||
if resolved_org_id
|
||||
else None,
|
||||
},
|
||||
)
|
||||
|
||||
# Store Jira conversation mapping
|
||||
jira_conversation = JiraConversation(
|
||||
conversation_id=self.conversation_id,
|
||||
issue_id=self.payload.issue_id,
|
||||
issue_key=self.payload.issue_key,
|
||||
jira_user_id=self.jira_user.id,
|
||||
)
|
||||
|
||||
await integration_store.create_conversation(jira_conversation)
|
||||
|
||||
return self.conversation_id
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f'[Jira] Failed to resolve org for {self.selected_repo}: {e}'
|
||||
if isinstance(e, StartingConvoException):
|
||||
raise
|
||||
logger.error(
|
||||
'[Jira] Failed to create conversation',
|
||||
extra={'issue_key': self.payload.issue_key, 'error': str(e)},
|
||||
exc_info=True,
|
||||
)
|
||||
return None
|
||||
raise StartingConvoException(f'Failed to create conversation: {str(e)}')
|
||||
|
||||
def get_response_msg(self) -> str:
|
||||
"""Get the response message to send back to Jira."""
|
||||
|
||||
@@ -20,7 +20,6 @@ from integrations.utils import (
|
||||
OPENHANDS_RESOLVER_TEMPLATES_DIR,
|
||||
filter_potential_repos_by_user_msg,
|
||||
get_session_expired_message,
|
||||
markdown_to_jira_markup,
|
||||
)
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
from server.auth.saas_user_auth import get_user_auth_from_keycloak_id
|
||||
@@ -469,8 +468,7 @@ class JiraDcManager(Manager[JiraDcViewInterface]):
|
||||
"""
|
||||
url = f'{base_api_url}/rest/api/2/issue/{issue_key}/comment'
|
||||
headers = {'Authorization': f'Bearer {svc_acc_api_key}'}
|
||||
# Convert standard Markdown to Jira Wiki Markup for proper rendering
|
||||
data = {'body': markdown_to_jira_markup(message)}
|
||||
data = {'body': message}
|
||||
async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
|
||||
response = await client.post(url, headers=headers, json=data)
|
||||
response.raise_for_status()
|
||||
|
||||
@@ -16,23 +16,21 @@ from openhands.core.logger import openhands_logger as logger
|
||||
async def resolve_org_for_repo(
|
||||
provider: str,
|
||||
full_repo_name: str,
|
||||
keycloak_user_id: str | None = None,
|
||||
keycloak_user_id: str,
|
||||
) -> UUID | None:
|
||||
"""Determine the OpenHands org_id for a resolver conversation.
|
||||
|
||||
If the repo's git organization is claimed by an OpenHands org, returns the
|
||||
claiming org's ID. When keycloak_user_id is provided, also verifies the user
|
||||
is a member of that org.
|
||||
If the repo's git organization is claimed by an OpenHands org AND the user
|
||||
is a member of that org, returns the claiming org's ID. Otherwise returns
|
||||
None (caller should fall back to user.current_org_id / personal workspace).
|
||||
|
||||
Args:
|
||||
provider: Git provider name ("github", "gitlab", "bitbucket")
|
||||
full_repo_name: Full repository name (e.g., "OpenHands/foo")
|
||||
keycloak_user_id: The user's Keycloak UUID string (optional). If provided,
|
||||
membership is verified before returning the org_id.
|
||||
keycloak_user_id: The user's Keycloak UUID string
|
||||
|
||||
Returns:
|
||||
The org_id if the repo's org is claimed (and user is a member when
|
||||
keycloak_user_id is provided), else None
|
||||
The org_id if the repo's org is claimed and user is a member, else None
|
||||
"""
|
||||
git_org = full_repo_name.split('/')[0].lower()
|
||||
|
||||
@@ -46,14 +44,6 @@ async def resolve_org_for_repo(
|
||||
)
|
||||
return None
|
||||
|
||||
# Skip membership check if no user_id provided
|
||||
if keycloak_user_id is None:
|
||||
logger.info(
|
||||
f'[OrgResolver] Resolved org {claim.org_id} '
|
||||
f'for {provider}/{git_org} (no user membership check)',
|
||||
)
|
||||
return claim.org_id
|
||||
|
||||
member = await OrgMemberStore.get_org_member(
|
||||
claim.org_id, UUID(keycloak_user_id)
|
||||
)
|
||||
|
||||
@@ -436,13 +436,12 @@ def infer_repo_from_message(user_msg: str) -> list[str]:
|
||||
r'(?=\s|$|}}|[\]\)\'",.:`])' # right boundary
|
||||
)
|
||||
|
||||
# Use dict to preserve ordering
|
||||
matches: dict[str, bool] = {}
|
||||
matches: list[str] = []
|
||||
|
||||
# Git URLs first (highest priority)
|
||||
for owner, repo in re.findall(git_url_pattern, normalized_msg):
|
||||
repo = re.sub(r'\.git$', '', repo)
|
||||
matches[f'{owner}/{repo}'] = True
|
||||
matches.append(f'{owner}/{repo}')
|
||||
|
||||
# Direct mentions
|
||||
for owner, repo in re.findall(direct_pattern, normalized_msg):
|
||||
@@ -458,10 +457,9 @@ def infer_repo_from_message(user_msg: str) -> list[str]:
|
||||
continue
|
||||
|
||||
if full_match not in matches:
|
||||
matches[full_match] = True
|
||||
matches.append(full_match)
|
||||
|
||||
result = list(matches)
|
||||
return result
|
||||
return matches
|
||||
|
||||
|
||||
def filter_potential_repos_by_user_msg(
|
||||
@@ -597,18 +595,3 @@ def markdown_to_jira_markup(markdown_text: str) -> str:
|
||||
# Log the error but don't raise it - return original text as fallback
|
||||
print(f'Error converting markdown to Jira markup: {str(e)}')
|
||||
return markdown_text or ''
|
||||
|
||||
|
||||
def format_jira_comment_body(message: str) -> dict:
|
||||
"""Format a message as a Jira API v2 comment body.
|
||||
|
||||
This helper ensures consistent comment formatting across all Jira integrations.
|
||||
Converts markdown to Jira Wiki Markup and wraps in the expected API structure.
|
||||
|
||||
Args:
|
||||
message: The message content to send (may contain markdown)
|
||||
|
||||
Returns:
|
||||
dict: The comment body in Jira API v2 format {'body': ...}
|
||||
"""
|
||||
return {'body': markdown_to_jira_markup(message)}
|
||||
|
||||
@@ -6,12 +6,6 @@ from logging.config import fileConfig
|
||||
# These plugin setup messages would otherwise appear before logging is configured
|
||||
logging.getLogger('alembic.runtime.plugins').setLevel(logging.WARNING)
|
||||
|
||||
# Prevent SQLAlchemy engine from logging SQL results at DEBUG level, which can
|
||||
# leak sensitive column data (e.g. API keys, tokens) into log aggregators.
|
||||
# This is set before any engine is created so it takes effect immediately.
|
||||
logging.getLogger('sqlalchemy.engine').setLevel(logging.WARNING)
|
||||
logging.getLogger('sqlalchemy.engine.Engine').setLevel(logging.WARNING)
|
||||
|
||||
from alembic import context # noqa: E402
|
||||
from google.cloud.sql.connector import Connector # noqa: E402
|
||||
from sqlalchemy import create_engine, text # noqa: E402
|
||||
@@ -76,12 +70,6 @@ config = context.config
|
||||
if config.config_file_name is not None:
|
||||
fileConfig(config.config_file_name)
|
||||
|
||||
# Re-apply SQLAlchemy engine log suppression after fileConfig, which may override
|
||||
# our earlier settings from alembic.ini. This ensures DEBUG-level SQL result logging
|
||||
# is always suppressed, preventing sensitive data from leaking into log aggregators.
|
||||
logging.getLogger('sqlalchemy.engine').setLevel(logging.WARNING)
|
||||
logging.getLogger('sqlalchemy.engine.Engine').setLevel(logging.WARNING)
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
|
||||
"""Run migrations in 'offline' mode.
|
||||
|
||||
@@ -6,6 +6,7 @@ Create Date: 2026-03-26
|
||||
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
@@ -23,18 +24,18 @@ def upgrade() -> None:
|
||||
|
||||
# Migrate existing org-level MCP configs to all members in each org.
|
||||
# This preserves existing configurations while transitioning to user-specific settings.
|
||||
# Uses server-side SQL to avoid pulling sensitive config data into the Python process.
|
||||
op.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE org_member
|
||||
SET mcp_config = org.mcp_config
|
||||
FROM org
|
||||
WHERE org_member.org_id = org.id
|
||||
AND org.mcp_config IS NOT NULL
|
||||
"""
|
||||
conn = op.get_bind()
|
||||
orgs_with_config = conn.execute(
|
||||
sa.text('SELECT id, mcp_config FROM org WHERE mcp_config IS NOT NULL')
|
||||
).fetchall()
|
||||
|
||||
for org_id, mcp_config in orgs_with_config:
|
||||
conn.execute(
|
||||
sa.text(
|
||||
'UPDATE org_member SET mcp_config = :config WHERE org_id = :org_id'
|
||||
),
|
||||
{'config': json.dumps(mcp_config), 'org_id': str(org_id)},
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
|
||||
@@ -1,31 +0,0 @@
|
||||
"""Add onboarding_completed column to user table.
|
||||
|
||||
Tracks whether a user has completed the onboarding flow.
|
||||
Used to redirect new SaaS users to /onboarding after accepting TOS.
|
||||
|
||||
Revision ID: 107
|
||||
Revises: 106
|
||||
Create Date: 2026-03-31
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '107'
|
||||
down_revision: Union[str, None] = '106'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
'user',
|
||||
sa.Column('onboarding_completed', sa.Boolean(), nullable=True, default=False),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column('user', 'onboarding_completed')
|
||||
@@ -87,9 +87,6 @@ class Permission(str, Enum):
|
||||
# Git organization claims
|
||||
MANAGE_ORG_CLAIMS = 'manage_org_claims'
|
||||
|
||||
# Manage Automations
|
||||
MANAGE_AUTOMATIONS = 'manage_automations'
|
||||
|
||||
|
||||
class RoleName(str, Enum):
|
||||
"""Role names used in the system."""
|
||||
@@ -126,8 +123,6 @@ ROLE_PERMISSIONS: dict[RoleName, frozenset[Permission]] = {
|
||||
Permission.DELETE_ORGANIZATION,
|
||||
# Git organization claims
|
||||
Permission.MANAGE_ORG_CLAIMS,
|
||||
# Manage Automations
|
||||
Permission.MANAGE_AUTOMATIONS,
|
||||
]
|
||||
),
|
||||
RoleName.ADMIN: frozenset(
|
||||
@@ -151,8 +146,6 @@ ROLE_PERMISSIONS: dict[RoleName, frozenset[Permission]] = {
|
||||
Permission.EDIT_ORG_SETTINGS,
|
||||
# Git organization claims
|
||||
Permission.MANAGE_ORG_CLAIMS,
|
||||
# Manage Automations
|
||||
Permission.MANAGE_AUTOMATIONS,
|
||||
]
|
||||
),
|
||||
RoleName.MEMBER: frozenset(
|
||||
@@ -166,8 +159,6 @@ ROLE_PERMISSIONS: dict[RoleName, frozenset[Permission]] = {
|
||||
# Settings (View only)
|
||||
Permission.VIEW_ORG_SETTINGS,
|
||||
Permission.VIEW_LLM_SETTINGS,
|
||||
# Manage Automations
|
||||
Permission.MANAGE_AUTOMATIONS,
|
||||
]
|
||||
),
|
||||
}
|
||||
|
||||
@@ -56,23 +56,6 @@ RECAPTCHA_SITE_KEY = os.getenv('RECAPTCHA_SITE_KEY', '').strip()
|
||||
RECAPTCHA_HMAC_SECRET = os.getenv('RECAPTCHA_HMAC_SECRET', '').strip()
|
||||
RECAPTCHA_BLOCK_THRESHOLD = float(os.getenv('RECAPTCHA_BLOCK_THRESHOLD', '0.3'))
|
||||
|
||||
# Automation Service
|
||||
AUTOMATION_SERVICE_URL = os.getenv('AUTOMATION_SERVICE_URL', '').strip()
|
||||
if AUTOMATION_SERVICE_URL and not AUTOMATION_SERVICE_URL.startswith(
|
||||
('http://', 'https://')
|
||||
):
|
||||
raise ValueError(
|
||||
f'AUTOMATION_SERVICE_URL must start with http:// or https://, '
|
||||
f'got: {AUTOMATION_SERVICE_URL}'
|
||||
)
|
||||
AUTOMATION_EVENT_FORWARDING_ENABLED = os.getenv(
|
||||
'AUTOMATION_EVENT_FORWARDING_ENABLED', 'false'
|
||||
) in ('1', 'true')
|
||||
# Shared secret for signing payloads sent to automation service (separate from GitHub webhook secret)
|
||||
AUTOMATION_WEBHOOK_SECRET = os.getenv('AUTOMATION_WEBHOOK_SECRET', '').strip()
|
||||
# Default HTTP timeout for automation service requests (seconds)
|
||||
AUTOMATION_SERVICE_TIMEOUT = int(os.getenv('AUTOMATION_SERVICE_TIMEOUT', '30'))
|
||||
|
||||
# Account Defender labels that indicate suspicious activity
|
||||
SUSPICIOUS_LABELS = {
|
||||
'SUSPICIOUS_LOGIN_ACTIVITY',
|
||||
|
||||
@@ -20,7 +20,6 @@ from server.auth.constants import (
|
||||
GITLAB_APP_CLIENT_ID,
|
||||
RECAPTCHA_SITE_KEY,
|
||||
)
|
||||
from server.constants import DEPLOYMENT_MODE
|
||||
|
||||
from openhands.core.config.utils import load_openhands_config
|
||||
from openhands.integrations.service_types import ProviderType
|
||||
@@ -180,7 +179,6 @@ class SaaSServerConfig(ServerConfig):
|
||||
'ENABLE_JIRA': self.enable_jira,
|
||||
'ENABLE_JIRA_DC': self.enable_jira_dc,
|
||||
'ENABLE_LINEAR': self.enable_linear,
|
||||
'DEPLOYMENT_MODE': DEPLOYMENT_MODE,
|
||||
},
|
||||
'PROVIDERS_CONFIGURED': providers_configured,
|
||||
}
|
||||
|
||||
@@ -15,33 +15,6 @@ IS_FEATURE_ENV = (
|
||||
) # Does not include the staging deployment
|
||||
IS_LOCAL_ENV = bool(HOST == 'localhost')
|
||||
|
||||
|
||||
# _is_all_hands_managed_domain() can be removed/replaced when a self-hosted specific
|
||||
# env var is created (e.g is_self_hosted` or `deployment_mode`)
|
||||
def _is_all_hands_managed_domain(host: str) -> bool:
|
||||
"""Check if the host is an All-Hands managed domain."""
|
||||
return (
|
||||
host == 'app.all-hands.dev'
|
||||
or host == 'app.openhands.ai'
|
||||
or host.endswith('.all-hands.dev')
|
||||
or host.endswith('.openhands.ai')
|
||||
)
|
||||
|
||||
|
||||
def _get_deployment_mode() -> str:
|
||||
"""Determine deployment mode based on WEB_HOST.
|
||||
|
||||
Returns:
|
||||
'cloud' for All-Hands managed infrastructure (app.all-hands.dev, etc.)
|
||||
'self_hosted' for enterprise self-hosted deployments (customer domains)
|
||||
"""
|
||||
if _is_all_hands_managed_domain(HOST):
|
||||
return 'cloud'
|
||||
return 'self_hosted'
|
||||
|
||||
|
||||
DEPLOYMENT_MODE = _get_deployment_mode()
|
||||
|
||||
# Role name constants
|
||||
ROLE_OWNER = 'owner'
|
||||
ROLE_ADMIN = 'admin'
|
||||
|
||||
@@ -27,10 +27,7 @@ from server.auth.user.user_authorizer import (
|
||||
depends_user_authorizer,
|
||||
)
|
||||
from server.config import sign_token
|
||||
from server.constants import (
|
||||
DEPLOYMENT_MODE,
|
||||
IS_FEATURE_ENV,
|
||||
)
|
||||
from server.constants import IS_FEATURE_ENV, IS_LOCAL_ENV
|
||||
from server.routes.event_webhook import _get_session_api_key, _get_user_id
|
||||
from server.services.org_invitation_service import (
|
||||
EmailMismatchError,
|
||||
@@ -465,20 +462,8 @@ async def keycloak_callback(
|
||||
tos_redirect_url = f'{tos_redirect_url}&invitation_success=true'
|
||||
response = RedirectResponse(tos_redirect_url, status_code=302)
|
||||
else:
|
||||
# User has accepted TOS - check if they need onboarding
|
||||
# Only redirect to onboarding if user has a valid offline token,
|
||||
# otherwise they need to complete the Keycloak offline token flow first
|
||||
if valid_offline_token and await _should_redirect_to_onboarding(user_id, user):
|
||||
redirect_url = f'{web_url}/onboarding'
|
||||
logger.info(
|
||||
'Redirecting returning user to onboarding',
|
||||
extra={'user_id': user_id, 'deployment_mode': DEPLOYMENT_MODE},
|
||||
)
|
||||
if invitation_token:
|
||||
if '?' in redirect_url:
|
||||
redirect_url = f'{redirect_url}&invitation_success=true'
|
||||
else:
|
||||
redirect_url = f'{redirect_url}?invitation_success=true'
|
||||
redirect_url = f'{redirect_url}&invitation_success=true'
|
||||
response = RedirectResponse(redirect_url, status_code=302)
|
||||
|
||||
set_response_cookie(
|
||||
@@ -486,7 +471,7 @@ async def keycloak_callback(
|
||||
response=response,
|
||||
keycloak_access_token=keycloak_access_token,
|
||||
keycloak_refresh_token=keycloak_refresh_token,
|
||||
secure=True if web_url.startswith('https') else False,
|
||||
secure=True if redirect_url.startswith('https') else False,
|
||||
accepted_tos=has_accepted_tos,
|
||||
)
|
||||
|
||||
@@ -527,23 +512,8 @@ async def keycloak_offline_callback(code: str, state: str, request: Request):
|
||||
user_id=user_info.sub, offline_token=keycloak_refresh_token
|
||||
)
|
||||
|
||||
user = await UserStore.get_user_by_id(user_info.sub)
|
||||
has_accepted_tos = user is not None and user.accepted_tos is not None
|
||||
|
||||
redirect_url, _, _ = _extract_oauth_state(state)
|
||||
default_url = redirect_url if redirect_url else web_url
|
||||
final_url = await _get_post_auth_redirect(user_info.sub, default_url, web_url, user)
|
||||
|
||||
response = RedirectResponse(final_url, status_code=302)
|
||||
set_response_cookie(
|
||||
request=request,
|
||||
response=response,
|
||||
keycloak_access_token=keycloak_access_token,
|
||||
keycloak_refresh_token=keycloak_refresh_token,
|
||||
secure=True if web_url.startswith('https') else False,
|
||||
accepted_tos=has_accepted_tos,
|
||||
)
|
||||
return response
|
||||
return RedirectResponse(redirect_url if redirect_url else web_url, status_code=302)
|
||||
|
||||
|
||||
@oauth_router.get('/github/callback')
|
||||
@@ -579,74 +549,6 @@ async def authenticate(request: Request):
|
||||
return response
|
||||
|
||||
|
||||
async def _should_redirect_to_onboarding(user_id: str, user: User) -> bool:
|
||||
"""Check if user should be redirected to onboarding after TOS acceptance.
|
||||
|
||||
Backend always redirects applicable users to /onboarding. The frontend
|
||||
checks the ENABLE_ONBOARDING feature flag (localStorage) and redirects
|
||||
to / if the flag is disabled. This avoids needing helm chart changes.
|
||||
|
||||
|
||||
Returns True if:
|
||||
- User has onboarding_completed explicitly set to False (new users)
|
||||
- Either:
|
||||
- Deployment mode is 'cloud' (all users)
|
||||
- Deployment mode is 'self_hosted' AND user is the super admin
|
||||
(first owner in their current org to accept TOS)
|
||||
|
||||
Returns False if:
|
||||
- User has onboarding_completed=True (already completed)
|
||||
- User has onboarding_completed=None (existing users before this feature)
|
||||
"""
|
||||
# Already completed onboarding
|
||||
if user.onboarding_completed is True:
|
||||
return False
|
||||
|
||||
# Existing user before this feature (NULL in database)
|
||||
if user.onboarding_completed is None:
|
||||
return False
|
||||
|
||||
# Cloud SaaS: all users go to onboarding
|
||||
if DEPLOYMENT_MODE == 'cloud':
|
||||
return True
|
||||
|
||||
# Self-hosted SaaS: only the super admin (first owner to accept TOS in the org)
|
||||
if DEPLOYMENT_MODE == 'self_hosted':
|
||||
first_owner = await UserStore.get_first_owner_in_org(user.current_org_id)
|
||||
if first_owner and str(first_owner.id) == user_id:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
async def _get_post_auth_redirect(
|
||||
user_id: str, default_url: str, web_url: str, user: User | None = None
|
||||
) -> str:
|
||||
"""Determine where to redirect user after authentication completes.
|
||||
|
||||
Called after offline token is stored to determine final redirect destination.
|
||||
Checks for pending user flows (e.g., onboarding) before falling back to default.
|
||||
|
||||
Args:
|
||||
user_id: The user's ID.
|
||||
default_url: The default URL to redirect to if no special flow is needed.
|
||||
web_url: The base web URL for constructing absolute paths.
|
||||
user: Optional user object to avoid refetching.
|
||||
|
||||
Returns:
|
||||
The URL to redirect the user to.
|
||||
"""
|
||||
if not user:
|
||||
user = await UserStore.get_user_by_id(user_id)
|
||||
if user and await _should_redirect_to_onboarding(user_id, user):
|
||||
logger.info(
|
||||
'Redirecting user to onboarding',
|
||||
extra={'user_id': user_id, 'deployment_mode': DEPLOYMENT_MODE},
|
||||
)
|
||||
return f'{web_url}/onboarding'
|
||||
return default_url
|
||||
|
||||
|
||||
@api_router.post('/accept_tos')
|
||||
async def accept_tos(request: Request):
|
||||
user_auth = cast(SaasUserAuth, await get_user_auth(request))
|
||||
@@ -687,12 +589,6 @@ async def accept_tos(request: Request):
|
||||
|
||||
logger.info(f'User {user_id} accepted TOS')
|
||||
|
||||
# Determine final redirect - but don't override if it's the offline token flow
|
||||
# (the offline callback will handle post-auth redirect after storing the token)
|
||||
is_offline_flow = 'offline' in redirect_url
|
||||
if not is_offline_flow:
|
||||
redirect_url = await _get_post_auth_redirect(user_id, redirect_url, web_url)
|
||||
|
||||
response = JSONResponse(
|
||||
status_code=status.HTTP_200_OK, content={'redirect_url': redirect_url}
|
||||
)
|
||||
@@ -702,42 +598,12 @@ async def accept_tos(request: Request):
|
||||
response=response,
|
||||
keycloak_access_token=access_token.get_secret_value(),
|
||||
keycloak_refresh_token=refresh_token.get_secret_value(),
|
||||
secure=True if web_url.startswith('https') else False,
|
||||
secure=not IS_LOCAL_ENV,
|
||||
accepted_tos=True,
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
@api_router.post('/complete_onboarding')
|
||||
async def complete_onboarding(request: Request):
|
||||
"""Mark onboarding as completed for the current user."""
|
||||
user_auth = cast(SaasUserAuth, await get_user_auth(request))
|
||||
user_id = await user_auth.get_user_id()
|
||||
|
||||
if not user_id:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={'error': 'User is not authenticated'},
|
||||
)
|
||||
|
||||
user = await UserStore.mark_onboarding_completed(user_id)
|
||||
if not user:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
content={'error': 'User not found'},
|
||||
)
|
||||
|
||||
logger.info(
|
||||
'User completed onboarding',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_200_OK,
|
||||
content={'message': 'Onboarding completed'},
|
||||
)
|
||||
|
||||
|
||||
@api_router.post('/logout')
|
||||
async def logout(request: Request):
|
||||
# Always create the response object first to ensure we can return it even if errors occur
|
||||
|
||||
@@ -3,17 +3,13 @@ import hashlib
|
||||
import hmac
|
||||
import os
|
||||
|
||||
from fastapi import APIRouter, BackgroundTasks, Header, HTTPException, Request
|
||||
from fastapi import APIRouter, Header, HTTPException, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from integrations.github.data_collector import GitHubDataCollector
|
||||
from integrations.github.github_manager import GithubManager
|
||||
from integrations.models import Message, SourceType
|
||||
from server.auth.constants import (
|
||||
AUTOMATION_EVENT_FORWARDING_ENABLED,
|
||||
GITHUB_APP_WEBHOOK_SECRET,
|
||||
)
|
||||
from server.auth.constants import GITHUB_APP_WEBHOOK_SECRET
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.services.automation_event_service import AutomationEventService
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
@@ -26,7 +22,6 @@ github_integration_router = APIRouter(prefix='/integration')
|
||||
token_manager = TokenManager()
|
||||
data_collector = GitHubDataCollector()
|
||||
github_manager = GithubManager(token_manager, data_collector)
|
||||
automation_event_service = AutomationEventService(token_manager)
|
||||
|
||||
|
||||
def verify_github_signature(payload: bytes, signature: str):
|
||||
@@ -51,9 +46,7 @@ def verify_github_signature(payload: bytes, signature: str):
|
||||
@github_integration_router.post('/github/events')
|
||||
async def github_events(
|
||||
request: Request,
|
||||
background_tasks: BackgroundTasks,
|
||||
x_hub_signature_256: str = Header(None),
|
||||
x_github_event: str = Header(None),
|
||||
):
|
||||
# Check if GitHub webhooks are enabled
|
||||
if not GITHUB_WEBHOOKS_ENABLED:
|
||||
@@ -79,18 +72,6 @@ async def github_events(
|
||||
content={'error': 'Installation ID is missing in the payload.'},
|
||||
)
|
||||
|
||||
# Forward to automation service (fire-and-forget background task)
|
||||
if AUTOMATION_EVENT_FORWARDING_ENABLED:
|
||||
logger.info(
|
||||
f'triggering forward_github_event with payload: {payload_data}, installation_id: {installation_id}'
|
||||
)
|
||||
background_tasks.add_task(
|
||||
automation_event_service.forward_github_event,
|
||||
payload=payload_data,
|
||||
installation_id=installation_id,
|
||||
)
|
||||
|
||||
# Existing resolver bot processing
|
||||
message_payload = {'payload': payload_data, 'installation': installation_id}
|
||||
message = Message(source=SourceType.GITHUB, message=message_payload)
|
||||
await github_manager.receive_message(message)
|
||||
|
||||
@@ -149,12 +149,7 @@ async def verify_jira_signature(body: bytes, signature: str, payload: dict):
|
||||
|
||||
workspace_name = jira_manager.get_workspace_name_from_payload(payload)
|
||||
if workspace_name is None:
|
||||
logger.warning(
|
||||
'[Jira] No workspace name found in webhook payload',
|
||||
extra={
|
||||
'payload': payload,
|
||||
},
|
||||
)
|
||||
logger.warning('[Jira] No workspace name found in webhook payload')
|
||||
raise HTTPException(
|
||||
status_code=403, detail='Workspace name not found in payload'
|
||||
)
|
||||
|
||||
@@ -1,448 +0,0 @@
|
||||
"""
|
||||
Service for forwarding GitHub webhook events to the automation service.
|
||||
|
||||
This service is optimized for high-traffic scenarios:
|
||||
1. Resolves GitHub org → OpenHands org_id (via cached OrgGitClaim lookup)
|
||||
2. For personal repos, resolves to personal org (via cached GitHub→Keycloak mapping)
|
||||
3. Forwards minimal payload to automation service (just org_id + payload)
|
||||
4. Access control checks are deferred to automation execution time
|
||||
|
||||
The lazy access control approach means:
|
||||
- Most webhooks only do cached lookups + HTTP forward
|
||||
- Membership checks only happen when an automation actually matches
|
||||
|
||||
Security notes:
|
||||
- Uses AUTOMATION_WEBHOOK_SECRET (not GitHub webhook secret) for internal service signing
|
||||
- Negative results are cached to prevent DoS via repeated lookups for unclaimed orgs
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
import aiohttp
|
||||
from integrations.resolver_org_router import resolve_org_for_repo
|
||||
from server.auth.constants import (
|
||||
AUTOMATION_SERVICE_TIMEOUT,
|
||||
AUTOMATION_SERVICE_URL,
|
||||
AUTOMATION_WEBHOOK_SECRET,
|
||||
)
|
||||
from server.auth.token_manager import TokenManager
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.provider import ProviderType
|
||||
from openhands.server.shared import sio
|
||||
|
||||
# Cache TTL constants
|
||||
ORG_CLAIM_CACHE_TTL_SECONDS = 3600 # 1 hour for org claims (rarely change)
|
||||
USER_ID_CACHE_TTL_SECONDS = 86400 # 24 hours for user ID mappings (never change)
|
||||
|
||||
# Cache key prefixes
|
||||
ORG_CLAIM_CACHE_PREFIX = 'automation:org_claim'
|
||||
USER_ID_CACHE_PREFIX = 'automation:gh_to_kc_user'
|
||||
|
||||
|
||||
@dataclass
|
||||
class OrgContext:
|
||||
"""Context for the resolved organization."""
|
||||
|
||||
org_id: UUID
|
||||
github_org: str
|
||||
|
||||
|
||||
class AutomationEventService:
|
||||
"""
|
||||
Service for forwarding webhook events to the automation service.
|
||||
|
||||
Optimized for high traffic with:
|
||||
- Redis caching for org claim lookups (1 hour TTL)
|
||||
- Redis caching for GitHub→Keycloak user ID mappings (24 hour TTL)
|
||||
- Lazy access control (membership checks deferred to execution time)
|
||||
"""
|
||||
|
||||
def __init__(self, token_manager: TokenManager):
|
||||
from server.auth.constants import AUTOMATION_EVENT_FORWARDING_ENABLED
|
||||
|
||||
self.token_manager = token_manager
|
||||
|
||||
# Fail fast if forwarding is enabled but misconfigured
|
||||
if AUTOMATION_EVENT_FORWARDING_ENABLED:
|
||||
if not AUTOMATION_SERVICE_URL:
|
||||
raise ValueError(
|
||||
'AUTOMATION_EVENT_FORWARDING_ENABLED=true but '
|
||||
'AUTOMATION_SERVICE_URL is not configured'
|
||||
)
|
||||
if not AUTOMATION_WEBHOOK_SECRET:
|
||||
raise ValueError(
|
||||
'AUTOMATION_EVENT_FORWARDING_ENABLED=true but '
|
||||
'AUTOMATION_WEBHOOK_SECRET is not configured'
|
||||
)
|
||||
|
||||
async def forward_github_event(
|
||||
self,
|
||||
payload: dict[str, Any],
|
||||
installation_id: int,
|
||||
) -> None:
|
||||
"""
|
||||
Forward a GitHub webhook event to the automation service.
|
||||
|
||||
This is designed to be called as a fire-and-forget background task.
|
||||
The forward path is optimized for speed - only org resolution is done here.
|
||||
Access control checks are deferred to automation execution time.
|
||||
|
||||
Args:
|
||||
payload: The raw GitHub webhook payload
|
||||
installation_id: The GitHub App installation ID
|
||||
"""
|
||||
org_id: UUID | None = None
|
||||
try:
|
||||
logger.info(f'Retrieving org context for payload {payload}')
|
||||
# Resolve org context (org_id and github_org name) - uses Redis cache
|
||||
org_context = await self._resolve_org_context(payload)
|
||||
logger.info(f'org context for payload {payload} is {org_context}')
|
||||
if not org_context:
|
||||
return
|
||||
|
||||
org_id = org_context.org_id
|
||||
|
||||
# Build minimal payload and forward immediately
|
||||
# Access control is NOT computed here - it's deferred to execution time
|
||||
event_payload = self._build_event_payload(org_context, payload)
|
||||
logger.info(f'event_payload is {event_payload}')
|
||||
await self._send_to_automation_service(org_id, event_payload)
|
||||
|
||||
except (aiohttp.ClientError, asyncio.TimeoutError) as e:
|
||||
# Network errors are expected and recoverable
|
||||
logger.error(
|
||||
f'[AutomationEventService] Network error forwarding event '
|
||||
f'(org_id={org_id}): {e}',
|
||||
exc_info=True,
|
||||
extra={'installation_id': installation_id},
|
||||
)
|
||||
except Exception as e:
|
||||
# Log unexpected errors. Note: This is a background task, so exceptions
|
||||
# won't surface to the HTTP caller - they're logged for debugging only.
|
||||
logger.error(
|
||||
f'[AutomationEventService] Unexpected error forwarding event '
|
||||
f'(org_id={org_id}): {e}',
|
||||
exc_info=True,
|
||||
extra={'installation_id': installation_id},
|
||||
)
|
||||
# Don't re-raise in background task - just log for debugging
|
||||
|
||||
async def _resolve_org_context(self, payload: dict[str, Any]) -> OrgContext | None:
|
||||
"""
|
||||
Resolve the organization context from the webhook payload.
|
||||
|
||||
Uses Redis caching for both org claims and user ID mappings.
|
||||
Returns None if the org cannot be resolved (not claimed, no personal org).
|
||||
"""
|
||||
repo = payload.get('repository', {})
|
||||
owner = repo.get('owner', {})
|
||||
git_org_name = owner.get('login')
|
||||
owner_type = owner.get('type') # 'User' or 'Organization'
|
||||
|
||||
if not git_org_name:
|
||||
logger.warning(
|
||||
'[AutomationEventService] No repository owner in payload, skipping'
|
||||
)
|
||||
return None
|
||||
|
||||
# Try to resolve via OrgGitClaim
|
||||
org_id = await self._resolve_github_org(git_org_name)
|
||||
|
||||
# Fallback for personal repos
|
||||
if not org_id and owner_type == 'User':
|
||||
org_id = await self._resolve_personal_org(owner.get('id'))
|
||||
if org_id:
|
||||
logger.info(
|
||||
f'[AutomationEventService] Resolved personal repo owner '
|
||||
f'{git_org_name} to personal org {org_id}'
|
||||
)
|
||||
|
||||
if not org_id:
|
||||
logger.warning(
|
||||
f'[AutomationEventService] GitHub org {git_org_name} not claimed '
|
||||
f'and no personal org found, skipping'
|
||||
)
|
||||
return None
|
||||
|
||||
return OrgContext(org_id=org_id, github_org=git_org_name)
|
||||
|
||||
def _build_event_payload(
|
||||
self,
|
||||
org_context: OrgContext,
|
||||
payload: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Build the minimal event payload to forward to the automation service.
|
||||
|
||||
Access control is NOT included here - it's deferred to execution time.
|
||||
This keeps the forward path fast for high-traffic scenarios.
|
||||
"""
|
||||
return {
|
||||
'organization': {
|
||||
'github_org': org_context.github_org,
|
||||
'openhands_org_id': str(org_context.org_id),
|
||||
},
|
||||
'payload': payload,
|
||||
}
|
||||
|
||||
# =========================================================================
|
||||
# Cached Org Resolution Methods
|
||||
# =========================================================================
|
||||
|
||||
async def _resolve_github_org(self, git_org_name: str) -> UUID | None:
|
||||
"""
|
||||
Resolve a GitHub organization name to an OpenHands org_id.
|
||||
|
||||
Uses Redis caching with 1-hour TTL. Caches both positive and negative
|
||||
results to avoid repeated DB queries for unclaimed orgs.
|
||||
|
||||
Note: GitHub org names are case-insensitive. We normalize to lowercase
|
||||
for both cache keys and DB queries. This matches the OrgGitClaim schema
|
||||
which stores git_organization as lowercase (enforced by GitOrgClaimRequest
|
||||
validator in org_models.py).
|
||||
"""
|
||||
normalized_org = git_org_name.lower()
|
||||
cache_key = f'{ORG_CLAIM_CACHE_PREFIX}:{normalized_org}'
|
||||
|
||||
# Check cache first
|
||||
cached = await self._get_cached_value(cache_key)
|
||||
if cached is not None:
|
||||
if cached == 'none':
|
||||
logger.debug(
|
||||
f'[AutomationEventService] Cache hit (negative): org {git_org_name} not claimed'
|
||||
)
|
||||
return None
|
||||
logger.debug(
|
||||
f'[AutomationEventService] Cache hit: org {git_org_name} -> {cached}'
|
||||
)
|
||||
return UUID(cached)
|
||||
|
||||
# Cache miss - use resolve_org_for_repo without user_id (no membership check)
|
||||
# Construct a minimal repo name since resolve_org_for_repo extracts the org
|
||||
org_id = await resolve_org_for_repo(
|
||||
provider='github',
|
||||
full_repo_name=f'{normalized_org}/',
|
||||
)
|
||||
|
||||
# Cache the result (including negative results)
|
||||
if org_id:
|
||||
await self._set_cached_value(
|
||||
cache_key, str(org_id), ORG_CLAIM_CACHE_TTL_SECONDS
|
||||
)
|
||||
return org_id
|
||||
else:
|
||||
# Cache negative result to avoid repeated DB queries
|
||||
await self._set_cached_value(cache_key, 'none', ORG_CLAIM_CACHE_TTL_SECONDS)
|
||||
return None
|
||||
|
||||
async def _resolve_personal_org(self, github_user_id: int | None) -> UUID | None:
|
||||
"""
|
||||
Resolve a GitHub user to their personal OpenHands org.
|
||||
|
||||
For personal repos (owner type is 'User'), the OpenHands org_id
|
||||
is the user's keycloak user ID. This allows users to set up
|
||||
automations on their personal repos without needing an OrgGitClaim.
|
||||
|
||||
Uses Redis caching for the GitHub→Keycloak user ID mapping (24h TTL).
|
||||
"""
|
||||
if not github_user_id:
|
||||
return None
|
||||
|
||||
keycloak_id = await self._get_keycloak_user_id_cached(github_user_id)
|
||||
if keycloak_id:
|
||||
return UUID(keycloak_id)
|
||||
return None
|
||||
|
||||
async def _get_keycloak_user_id_cached(self, github_user_id: int) -> str | None:
|
||||
"""
|
||||
Convert a GitHub user ID to a Keycloak user ID.
|
||||
|
||||
Uses Redis caching with 24-hour TTL since this mapping never changes.
|
||||
Caches negative results to avoid repeated Keycloak queries.
|
||||
"""
|
||||
cache_key = f'{USER_ID_CACHE_PREFIX}:{github_user_id}'
|
||||
|
||||
# Check cache first
|
||||
cached = await self._get_cached_value(cache_key)
|
||||
if cached is not None:
|
||||
if cached == 'none':
|
||||
logger.debug(
|
||||
f'[AutomationEventService] Cache hit (negative): GitHub user {github_user_id} not in Keycloak'
|
||||
)
|
||||
return None
|
||||
logger.debug(
|
||||
f'[AutomationEventService] Cache hit: GitHub user {github_user_id} -> Keycloak {cached}'
|
||||
)
|
||||
return cached
|
||||
|
||||
# Cache miss - query Keycloak
|
||||
try:
|
||||
keycloak_id = await self.token_manager.get_user_id_from_idp_user_id(
|
||||
str(github_user_id), ProviderType.GITHUB
|
||||
)
|
||||
|
||||
# Cache the result (including negative results)
|
||||
if keycloak_id:
|
||||
await self._set_cached_value(
|
||||
cache_key, keycloak_id, USER_ID_CACHE_TTL_SECONDS
|
||||
)
|
||||
else:
|
||||
# Cache negative result to prevent repeated Keycloak queries (DoS mitigation)
|
||||
await self._set_cached_value(
|
||||
cache_key, 'none', USER_ID_CACHE_TTL_SECONDS
|
||||
)
|
||||
|
||||
return keycloak_id
|
||||
except Exception as e:
|
||||
# Log at warning level to surface programmer errors and API issues
|
||||
logger.warning(
|
||||
f'[AutomationEventService] Failed to get keycloak ID for GitHub user {github_user_id}: {e}'
|
||||
)
|
||||
return None
|
||||
|
||||
# =========================================================================
|
||||
# Generic Redis Cache Helpers
|
||||
# =========================================================================
|
||||
|
||||
async def _get_cached_value(self, cache_key: str) -> str | None:
|
||||
"""
|
||||
Get a cached value from Redis.
|
||||
|
||||
Returns the cached string value, or None if not cached or Redis unavailable.
|
||||
Falls back to DB/API queries if Redis is unavailable (graceful degradation).
|
||||
|
||||
Warning: When Redis is unavailable, every webhook will hit the DB directly.
|
||||
Monitor logs for 'Redis unavailable' warnings to detect degradation.
|
||||
"""
|
||||
try:
|
||||
redis = getattr(sio.manager, 'redis', None)
|
||||
if not redis:
|
||||
# Log at warning level - this is a significant degradation that
|
||||
# will cause DB load. Monitor these logs for alerting.
|
||||
logger.warning(
|
||||
'[AutomationEventService] Redis unavailable for cache read, '
|
||||
'falling back to direct DB queries (this will increase DB load)'
|
||||
)
|
||||
return None
|
||||
|
||||
cached = await redis.get(cache_key)
|
||||
if cached is None:
|
||||
return None
|
||||
|
||||
# Redis returns bytes, decode to string
|
||||
return cached.decode('utf-8') if isinstance(cached, bytes) else cached
|
||||
except Exception as e:
|
||||
# Log at warning level - cache errors cause DB fallback
|
||||
logger.warning(
|
||||
f'[AutomationEventService] Redis cache read error (falling back to DB): {e}'
|
||||
)
|
||||
return None
|
||||
|
||||
async def _set_cached_value(
|
||||
self, cache_key: str, value: str, ttl_seconds: int
|
||||
) -> None:
|
||||
"""
|
||||
Set a cached value in Redis with TTL.
|
||||
|
||||
Fails silently if Redis is unavailable (graceful degradation).
|
||||
"""
|
||||
try:
|
||||
redis = getattr(sio.manager, 'redis', None)
|
||||
if not redis:
|
||||
# Silent failure - read path already logs the warning
|
||||
return
|
||||
|
||||
await redis.setex(cache_key, ttl_seconds, value)
|
||||
except Exception as e:
|
||||
# Log at warning level for visibility
|
||||
logger.warning(f'[AutomationEventService] Redis cache write error: {e}')
|
||||
|
||||
def _sign_payload(self, payload_bytes: bytes) -> str:
|
||||
"""
|
||||
Sign a payload using the dedicated automation shared secret.
|
||||
|
||||
Uses AUTOMATION_WEBHOOK_SECRET (not GitHub webhook secret) to maintain
|
||||
separate trust boundaries between GitHub webhooks and internal services.
|
||||
|
||||
Returns the signature in the format 'sha256=<hex_digest>'.
|
||||
"""
|
||||
signature = hmac.new(
|
||||
AUTOMATION_WEBHOOK_SECRET.encode('utf-8'),
|
||||
msg=payload_bytes,
|
||||
digestmod=hashlib.sha256,
|
||||
).hexdigest()
|
||||
return f'sha256={signature}'
|
||||
|
||||
async def _send_to_automation_service(
|
||||
self,
|
||||
org_id: UUID,
|
||||
payload: dict[str, Any],
|
||||
) -> None:
|
||||
"""
|
||||
Send the normalized payload to the automation service.
|
||||
|
||||
The payload is signed using AUTOMATION_WEBHOOK_SECRET so the
|
||||
automation service can verify it came from the OpenHands server.
|
||||
"""
|
||||
if not AUTOMATION_SERVICE_URL:
|
||||
logger.warning(
|
||||
'[AutomationEventService] AUTOMATION_SERVICE_URL not configured'
|
||||
)
|
||||
return
|
||||
|
||||
# Build endpoint URL. AUTOMATION_SERVICE_URL may include path segments
|
||||
# (e.g., https://example.com/api/automation), so we strip trailing slash
|
||||
# and append our path.
|
||||
url = f'{AUTOMATION_SERVICE_URL.rstrip("/")}/v1/events/{org_id}/github'
|
||||
|
||||
# Serialize payload to JSON bytes for signing
|
||||
payload_bytes = json.dumps(payload, separators=(',', ':')).encode('utf-8')
|
||||
signature = self._sign_payload(payload_bytes)
|
||||
|
||||
headers = {
|
||||
'Content-Type': 'application/json',
|
||||
'X-Hub-Signature-256': signature,
|
||||
}
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
url,
|
||||
data=payload_bytes,
|
||||
headers=headers,
|
||||
timeout=aiohttp.ClientTimeout(total=AUTOMATION_SERVICE_TIMEOUT),
|
||||
) as resp:
|
||||
if resp.status >= 400:
|
||||
# Try JSON first (expected interface), fall back to text
|
||||
# for infrastructure errors (502/503 from load balancer)
|
||||
try:
|
||||
body = await resp.json()
|
||||
except (aiohttp.ContentTypeError, ValueError):
|
||||
body = await resp.text()
|
||||
logger.warning(
|
||||
f'[AutomationEventService] Automation service returned '
|
||||
f'{resp.status} for org {org_id}: {body}'
|
||||
)
|
||||
else:
|
||||
data = await resp.json()
|
||||
matched = data.get('matched', 0)
|
||||
logger.info(
|
||||
f'[AutomationEventService] Forwarded event to org {org_id}: '
|
||||
f'{matched} automations matched'
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(
|
||||
f'[AutomationEventService] Timeout ({AUTOMATION_SERVICE_TIMEOUT}s) '
|
||||
'forwarding to automation service'
|
||||
)
|
||||
except aiohttp.ClientError as e:
|
||||
logger.warning(
|
||||
f'[AutomationEventService] HTTP error forwarding to automation service: {e}'
|
||||
)
|
||||
@@ -1,12 +1,4 @@
|
||||
"""Shared Event router for OpenHands Server.
|
||||
|
||||
All endpoints in this router are unauthenticated — shared conversations are
|
||||
public. To avoid returning internal system state that the viewer does not
|
||||
need, ``ConversationStateUpdateEvent`` instances are filtered out before the
|
||||
response is sent. The shared-conversation frontend only renders messages,
|
||||
actions, observations, errors, and hook-execution events; state snapshots
|
||||
are consumed exclusively by the authenticated WebSocket path.
|
||||
"""
|
||||
"""Shared Event router for OpenHands Server."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Annotated
|
||||
@@ -21,15 +13,9 @@ from server.sharing.shared_event_service import (
|
||||
from openhands.agent_server.models import EventPage, EventSortOrder
|
||||
from openhands.app_server.event_callback.event_callback_models import EventKind
|
||||
from openhands.sdk import Event
|
||||
from openhands.sdk.event.conversation_state import ConversationStateUpdateEvent
|
||||
from openhands.utils.environment import StorageProvider, get_storage_provider
|
||||
|
||||
|
||||
def _is_viewable(event: Event) -> bool:
|
||||
"""Return True if *event* should be included in public shared responses."""
|
||||
return not isinstance(event, ConversationStateUpdateEvent)
|
||||
|
||||
|
||||
def get_shared_event_service_injector() -> SharedEventServiceInjector:
|
||||
"""Get the appropriate SharedEventServiceInjector based on configuration.
|
||||
|
||||
@@ -101,36 +87,15 @@ async def search_shared_events(
|
||||
] = 100,
|
||||
shared_event_service: SharedEventService = shared_event_service_dependency,
|
||||
) -> EventPage:
|
||||
"""Search / List events for a shared conversation.
|
||||
|
||||
Because non-viewable events (e.g. ``ConversationStateUpdateEvent``) are
|
||||
filtered out after fetching, a single backend page may yield fewer items
|
||||
than *limit*. This method transparently fetches additional backend pages
|
||||
until the requested *limit* is reached or there are no more results.
|
||||
"""
|
||||
conv_id = UUID(conversation_id)
|
||||
viewable: list[Event] = []
|
||||
cursor = page_id
|
||||
|
||||
while len(viewable) < limit:
|
||||
remaining = limit - len(viewable)
|
||||
page = await shared_event_service.search_shared_events(
|
||||
conversation_id=conv_id,
|
||||
kind__eq=kind__eq,
|
||||
timestamp__gte=timestamp__gte,
|
||||
timestamp__lt=timestamp__lt,
|
||||
sort_order=sort_order,
|
||||
page_id=cursor,
|
||||
limit=remaining,
|
||||
)
|
||||
viewable.extend(e for e in page.items if _is_viewable(e))
|
||||
cursor = page.next_page_id
|
||||
if cursor is None:
|
||||
break
|
||||
|
||||
return EventPage(
|
||||
items=viewable[:limit],
|
||||
next_page_id=cursor,
|
||||
"""Search / List events for a shared conversation."""
|
||||
return await shared_event_service.search_shared_events(
|
||||
conversation_id=UUID(conversation_id),
|
||||
kind__eq=kind__eq,
|
||||
timestamp__gte=timestamp__gte,
|
||||
timestamp__lt=timestamp__lt,
|
||||
sort_order=sort_order,
|
||||
page_id=page_id,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
|
||||
@@ -182,7 +147,7 @@ async def batch_get_shared_events(
|
||||
events = await shared_event_service.batch_get_shared_events(
|
||||
UUID(conversation_id), event_ids
|
||||
)
|
||||
return [e if e is not None and _is_viewable(e) else None for e in events]
|
||||
return events
|
||||
|
||||
|
||||
@router.get('/{conversation_id}/{event_id}')
|
||||
@@ -192,9 +157,6 @@ async def get_shared_event(
|
||||
shared_event_service: SharedEventService = shared_event_service_dependency,
|
||||
) -> Event | None:
|
||||
"""Get a single event from a shared conversation by conversation_id and event_id."""
|
||||
event = await shared_event_service.get_shared_event(
|
||||
return await shared_event_service.get_shared_event(
|
||||
UUID(conversation_id), UUID(event_id)
|
||||
)
|
||||
if event is not None and not _is_viewable(event):
|
||||
return None
|
||||
return event
|
||||
|
||||
@@ -36,7 +36,6 @@ class User(Base): # type: ignore
|
||||
git_user_email = Column(String, nullable=True)
|
||||
sandbox_grouping_strategy = Column(String, nullable=True)
|
||||
disabled_skills = Column(JSON, nullable=True)
|
||||
onboarding_completed = Column(Boolean, nullable=True, default=False)
|
||||
|
||||
# Relationships
|
||||
role = relationship('Role', back_populates='users')
|
||||
|
||||
@@ -24,7 +24,6 @@ from storage.encrypt_utils import (
|
||||
)
|
||||
from storage.org import Org
|
||||
from storage.org_member import OrgMember
|
||||
from storage.role import Role
|
||||
from storage.role_store import RoleStore
|
||||
from storage.user import User
|
||||
from storage.user_settings import UserSettings
|
||||
@@ -750,65 +749,6 @@ class UserStore:
|
||||
await session.refresh(user)
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
async def mark_onboarding_completed(user_id: str) -> Optional[User]:
|
||||
"""Mark the user's onboarding as completed.
|
||||
|
||||
Args:
|
||||
user_id: The user's ID (Keycloak user ID)
|
||||
|
||||
Returns:
|
||||
User: The updated user object, or None if user not found
|
||||
"""
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(User).filter(User.id == uuid.UUID(user_id)).with_for_update()
|
||||
)
|
||||
user = result.scalars().first()
|
||||
if not user:
|
||||
logger.warning(
|
||||
'mark_onboarding_completed:user_not_found',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
return None
|
||||
|
||||
user.onboarding_completed = True
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
logger.info(
|
||||
'mark_onboarding_completed:success',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
async def get_first_owner_in_org(org_id: UUID) -> Optional[User]:
|
||||
"""Get the first owner in an organization who accepted the Terms of Service.
|
||||
|
||||
This user is considered the super admin for that org in self-hosted deployments.
|
||||
The super admin is identified as the owner with the earliest accepted_tos timestamp.
|
||||
|
||||
Args:
|
||||
org_id: The organization UUID
|
||||
|
||||
Returns:
|
||||
User: The first owner to accept TOS in this org, or None if not found.
|
||||
"""
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(User)
|
||||
.join(OrgMember, OrgMember.user_id == User.id)
|
||||
.join(Role, Role.id == OrgMember.role_id)
|
||||
.filter(
|
||||
OrgMember.org_id == org_id,
|
||||
Role.name == 'owner',
|
||||
User.accepted_tos.isnot(None),
|
||||
)
|
||||
.order_by(User.accepted_tos.asc())
|
||||
.limit(1)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
@staticmethod
|
||||
async def backfill_contact_name(user_id: str, user_info: dict) -> None:
|
||||
"""Update contact_name on the personal org if it still has a username-style value.
|
||||
|
||||
@@ -206,7 +206,7 @@ def new_conversation_view(
|
||||
sample_webhook_payload, sample_user_auth, sample_jira_user, sample_jira_workspace
|
||||
):
|
||||
"""JiraNewConversationView instance for testing"""
|
||||
view = JiraNewConversationView(
|
||||
return JiraNewConversationView(
|
||||
payload=sample_webhook_payload,
|
||||
saas_user_auth=sample_user_auth,
|
||||
jira_user=sample_jira_user,
|
||||
@@ -215,8 +215,6 @@ def new_conversation_view(
|
||||
conversation_id='conv-123',
|
||||
_decrypted_api_key='decrypted_key',
|
||||
)
|
||||
view.v1_enabled = False
|
||||
return view
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@@ -202,10 +202,14 @@ class TestStartJob:
|
||||
)
|
||||
jira_manager._send_comment = AsyncMock()
|
||||
|
||||
await jira_manager.start_job(new_conversation_view)
|
||||
with patch(
|
||||
'integrations.jira.jira_manager.register_callback_processor'
|
||||
) as mock_register:
|
||||
await jira_manager.start_job(new_conversation_view)
|
||||
|
||||
new_conversation_view.create_or_update_conversation.assert_called_once()
|
||||
jira_manager._send_comment.assert_called_once()
|
||||
new_conversation_view.create_or_update_conversation.assert_called_once()
|
||||
mock_register.assert_called_once()
|
||||
jira_manager._send_comment.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_job_missing_settings_error(
|
||||
|
||||
@@ -1,368 +0,0 @@
|
||||
"""Tests for JiraV1CallbackProcessor.
|
||||
|
||||
This module tests the V1 callback processor that handles Jira integration
|
||||
callbacks when conversations complete.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from integrations.jira.jira_v1_callback_processor import (
|
||||
JIRA_CLOUD_API_URL,
|
||||
JiraV1CallbackProcessor,
|
||||
)
|
||||
|
||||
from openhands.app_server.event_callback.event_callback_models import EventCallback
|
||||
from openhands.app_server.event_callback.event_callback_result_models import (
|
||||
EventCallbackResultStatus,
|
||||
)
|
||||
from openhands.sdk.event import ConversationStateUpdateEvent
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def callback_processor():
|
||||
"""Create a JiraV1CallbackProcessor for testing."""
|
||||
return JiraV1CallbackProcessor(
|
||||
svc_acc_email='service@example.com',
|
||||
decrypted_api_key='test_api_key',
|
||||
issue_key='TEST-123',
|
||||
jira_cloud_id='cloud-123',
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_event_callback():
|
||||
"""Create a mock EventCallback."""
|
||||
callback = MagicMock(spec=EventCallback)
|
||||
callback.id = UUID('12345678-1234-5678-1234-567812345678')
|
||||
callback.conversation_id = UUID('87654321-4321-8765-4321-876543218765')
|
||||
return callback
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def finished_event():
|
||||
"""Create a ConversationStateUpdateEvent for finished state."""
|
||||
return ConversationStateUpdateEvent(
|
||||
id='event-123',
|
||||
key='execution_status',
|
||||
value='finished',
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def running_event():
|
||||
"""Create a ConversationStateUpdateEvent for running state."""
|
||||
return ConversationStateUpdateEvent(
|
||||
id='event-456',
|
||||
key='execution_status',
|
||||
value='running',
|
||||
)
|
||||
|
||||
|
||||
class TestJiraV1CallbackProcessor:
|
||||
"""Tests for JiraV1CallbackProcessor."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ignores_non_conversation_state_events(
|
||||
self, callback_processor, mock_event_callback
|
||||
):
|
||||
"""Test that non-ConversationStateUpdateEvent events are ignored."""
|
||||
# Use a different event type (mock)
|
||||
other_event = MagicMock()
|
||||
other_event.__class__ = object # Not ConversationStateUpdateEvent
|
||||
|
||||
result = await callback_processor(
|
||||
conversation_id=uuid4(),
|
||||
callback=mock_event_callback,
|
||||
event=other_event,
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ignores_non_execution_status_keys(
|
||||
self, callback_processor, mock_event_callback
|
||||
):
|
||||
"""Test that events with keys other than 'execution_status' are ignored."""
|
||||
event = ConversationStateUpdateEvent(
|
||||
id='event-123',
|
||||
key='agent_status', # Different key
|
||||
value='finished',
|
||||
)
|
||||
|
||||
result = await callback_processor(
|
||||
conversation_id=uuid4(),
|
||||
callback=mock_event_callback,
|
||||
event=event,
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ignores_non_finished_status(
|
||||
self, callback_processor, mock_event_callback, running_event
|
||||
):
|
||||
"""Test that non-finished execution statuses are ignored."""
|
||||
result = await callback_processor(
|
||||
conversation_id=uuid4(),
|
||||
callback=mock_event_callback,
|
||||
event=running_event,
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_only_requests_summary_once(
|
||||
self, callback_processor, mock_event_callback, finished_event
|
||||
):
|
||||
"""Test that summary is only requested once (should_request_summary flag)."""
|
||||
callback_processor.should_request_summary = False
|
||||
|
||||
result = await callback_processor(
|
||||
conversation_id=uuid4(),
|
||||
callback=mock_event_callback,
|
||||
event=finished_event,
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch.object(JiraV1CallbackProcessor, '_request_summary')
|
||||
@patch.object(JiraV1CallbackProcessor, '_post_summary_to_jira')
|
||||
async def test_successful_summary_flow(
|
||||
self,
|
||||
mock_post_summary,
|
||||
mock_request_summary,
|
||||
callback_processor,
|
||||
mock_event_callback,
|
||||
finished_event,
|
||||
):
|
||||
"""Test successful summary request and posting flow."""
|
||||
conversation_id = uuid4()
|
||||
mock_request_summary.return_value = 'Test summary content'
|
||||
mock_post_summary.return_value = None
|
||||
|
||||
result = await callback_processor(
|
||||
conversation_id=conversation_id,
|
||||
callback=mock_event_callback,
|
||||
event=finished_event,
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.status == EventCallbackResultStatus.SUCCESS
|
||||
assert result.detail == 'Test summary content'
|
||||
assert callback_processor.should_request_summary is False
|
||||
mock_request_summary.assert_called_once_with(conversation_id)
|
||||
mock_post_summary.assert_called_once_with('Test summary content')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch.object(JiraV1CallbackProcessor, '_request_summary')
|
||||
async def test_error_handling_on_summary_request_failure(
|
||||
self,
|
||||
mock_request_summary,
|
||||
callback_processor,
|
||||
mock_event_callback,
|
||||
finished_event,
|
||||
):
|
||||
"""Test error handling when summary request fails."""
|
||||
conversation_id = uuid4()
|
||||
mock_request_summary.side_effect = Exception('Agent server unavailable')
|
||||
|
||||
result = await callback_processor(
|
||||
conversation_id=conversation_id,
|
||||
callback=mock_event_callback,
|
||||
event=finished_event,
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.status == EventCallbackResultStatus.ERROR
|
||||
assert 'Agent server unavailable' in result.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch.object(JiraV1CallbackProcessor, '_request_summary')
|
||||
@patch.object(JiraV1CallbackProcessor, '_post_summary_to_jira')
|
||||
async def test_error_handling_on_post_failure(
|
||||
self,
|
||||
mock_post_summary,
|
||||
mock_request_summary,
|
||||
callback_processor,
|
||||
mock_event_callback,
|
||||
finished_event,
|
||||
):
|
||||
"""Test error handling when posting to Jira fails."""
|
||||
conversation_id = uuid4()
|
||||
mock_request_summary.return_value = 'Test summary'
|
||||
mock_post_summary.side_effect = Exception('Jira API error')
|
||||
|
||||
result = await callback_processor(
|
||||
conversation_id=conversation_id,
|
||||
callback=mock_event_callback,
|
||||
event=finished_event,
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.status == EventCallbackResultStatus.ERROR
|
||||
assert 'Jira API error' in result.detail
|
||||
|
||||
|
||||
class TestPostSummaryToJira:
|
||||
"""Tests for _post_summary_to_jira method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_when_missing_credentials(self, callback_processor):
|
||||
"""Test that posting is skipped when credentials are missing."""
|
||||
callback_processor.svc_acc_email = ''
|
||||
|
||||
# Should not raise, just log and return
|
||||
await callback_processor._post_summary_to_jira('Test summary')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_when_missing_issue_key(self, callback_processor):
|
||||
"""Test that posting is skipped when issue key is missing."""
|
||||
callback_processor.issue_key = ''
|
||||
|
||||
# Should not raise, just log and return
|
||||
await callback_processor._post_summary_to_jira('Test summary')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_when_missing_cloud_id(self, callback_processor):
|
||||
"""Test that posting is skipped when cloud ID is missing."""
|
||||
callback_processor.jira_cloud_id = ''
|
||||
|
||||
# Should not raise, just log and return
|
||||
await callback_processor._post_summary_to_jira('Test summary')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('httpx.AsyncClient')
|
||||
async def test_posts_comment_with_correct_format(
|
||||
self, mock_async_client, callback_processor
|
||||
):
|
||||
"""Test that comment is posted with correct format (plain string body)."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client_instance = AsyncMock()
|
||||
mock_client_instance.post = AsyncMock(return_value=mock_response)
|
||||
mock_async_client.return_value.__aenter__.return_value = mock_client_instance
|
||||
|
||||
await callback_processor._post_summary_to_jira('Test summary content')
|
||||
|
||||
# Verify the post was called
|
||||
mock_client_instance.post.assert_called_once()
|
||||
call_args = mock_client_instance.post.call_args
|
||||
|
||||
# Check URL
|
||||
expected_url = (
|
||||
f'{JIRA_CLOUD_API_URL}/cloud-123/rest/api/2/issue/TEST-123/comment'
|
||||
)
|
||||
assert call_args[0][0] == expected_url
|
||||
|
||||
# Check that body contains the summary message
|
||||
json_body = call_args[1]['json']
|
||||
assert 'body' in json_body
|
||||
assert 'OpenHands resolved this issue' in json_body['body']
|
||||
assert 'Test summary content' in json_body['body']
|
||||
|
||||
# Check auth
|
||||
assert call_args[1]['auth'] == ('service@example.com', 'test_api_key')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('httpx.AsyncClient')
|
||||
async def test_raises_on_http_error(self, mock_async_client, callback_processor):
|
||||
"""Test that HTTP errors are propagated."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.raise_for_status.side_effect = httpx.HTTPStatusError(
|
||||
'Bad Request',
|
||||
request=MagicMock(),
|
||||
response=MagicMock(status_code=400),
|
||||
)
|
||||
|
||||
mock_client_instance = AsyncMock()
|
||||
mock_client_instance.post = AsyncMock(return_value=mock_response)
|
||||
mock_async_client.return_value.__aenter__.return_value = mock_client_instance
|
||||
|
||||
with pytest.raises(httpx.HTTPStatusError):
|
||||
await callback_processor._post_summary_to_jira('Test summary')
|
||||
|
||||
|
||||
class TestAskQuestion:
|
||||
"""Tests for _ask_question method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sends_request_with_correct_payload(self, callback_processor):
|
||||
"""Test that ask_question sends correct request to agent server."""
|
||||
mock_httpx_client = AsyncMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {'response': 'Agent response text'}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_httpx_client.post = AsyncMock(return_value=mock_response)
|
||||
|
||||
conversation_id = uuid4()
|
||||
agent_server_url = 'http://localhost:8000'
|
||||
session_api_key = 'test_session_key'
|
||||
message_content = 'Please summarize your work'
|
||||
|
||||
result = await callback_processor._ask_question(
|
||||
httpx_client=mock_httpx_client,
|
||||
agent_server_url=agent_server_url,
|
||||
conversation_id=conversation_id,
|
||||
session_api_key=session_api_key,
|
||||
message_content=message_content,
|
||||
)
|
||||
|
||||
assert result == 'Agent response text'
|
||||
|
||||
# Verify request
|
||||
mock_httpx_client.post.assert_called_once()
|
||||
call_args = mock_httpx_client.post.call_args
|
||||
|
||||
expected_url = (
|
||||
f'{agent_server_url}/api/conversations/{conversation_id}/ask_agent'
|
||||
)
|
||||
assert call_args[0][0] == expected_url
|
||||
assert call_args[1]['headers'] == {'X-Session-API-Key': session_api_key}
|
||||
assert call_args[1]['json'] == {'question': message_content}
|
||||
assert call_args[1]['timeout'] == 30.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handles_http_error(self, callback_processor):
|
||||
"""Test that HTTP errors are handled and wrapped."""
|
||||
mock_httpx_client = AsyncMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 500
|
||||
mock_response.text = 'Internal Server Error'
|
||||
mock_response.headers = {}
|
||||
mock_error = httpx.HTTPStatusError(
|
||||
'Server Error',
|
||||
request=MagicMock(),
|
||||
response=mock_response,
|
||||
)
|
||||
mock_httpx_client.post = AsyncMock(side_effect=mock_error)
|
||||
|
||||
with pytest.raises(Exception, match='Failed to send message to agent server'):
|
||||
await callback_processor._ask_question(
|
||||
httpx_client=mock_httpx_client,
|
||||
agent_server_url='http://localhost:8000',
|
||||
conversation_id=uuid4(),
|
||||
session_api_key='test_key',
|
||||
message_content='test message',
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handles_timeout(self, callback_processor):
|
||||
"""Test that timeout errors are handled and wrapped."""
|
||||
mock_httpx_client = AsyncMock()
|
||||
mock_httpx_client.post = AsyncMock(
|
||||
side_effect=httpx.TimeoutException('Timeout')
|
||||
)
|
||||
|
||||
with pytest.raises(Exception, match='Failed to send message to agent server'):
|
||||
await callback_processor._ask_question(
|
||||
httpx_client=mock_httpx_client,
|
||||
agent_server_url='http://localhost:8000',
|
||||
conversation_id=uuid4(),
|
||||
session_api_key='test_key',
|
||||
message_content='test message',
|
||||
)
|
||||
@@ -3,6 +3,7 @@ Tests for Jira view classes and factory.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from integrations.jira.jira_payload import (
|
||||
@@ -18,6 +19,9 @@ from integrations.jira.jira_view import (
|
||||
JiraNewConversationView,
|
||||
)
|
||||
|
||||
from openhands.integrations.service_types import ProviderType
|
||||
from openhands.server.user_auth.user_auth import UserAuth
|
||||
|
||||
|
||||
class TestJiraNewConversationView:
|
||||
"""Tests for JiraNewConversationView"""
|
||||
@@ -85,6 +89,51 @@ class TestJiraNewConversationView:
|
||||
assert 'TEST-123' in user_msg
|
||||
assert 'Test Issue' in user_msg
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('integrations.jira.jira_view.resolve_org_for_repo', new_callable=AsyncMock)
|
||||
@patch('integrations.jira.jira_view.ProviderHandler')
|
||||
@patch(
|
||||
'integrations.jira.jira_view.SaasConversationStore.get_resolver_instance',
|
||||
new_callable=AsyncMock,
|
||||
)
|
||||
@patch('integrations.jira.jira_view.start_conversation', new_callable=AsyncMock)
|
||||
@patch('integrations.jira.jira_view.integration_store')
|
||||
async def test_create_or_update_conversation_success(
|
||||
self,
|
||||
mock_integration_store,
|
||||
mock_start_convo,
|
||||
mock_get_resolver_instance,
|
||||
mock_provider_handler_cls,
|
||||
mock_resolve_org,
|
||||
new_conversation_view,
|
||||
mock_jinja_env,
|
||||
):
|
||||
"""Test successful conversation creation"""
|
||||
new_conversation_view._issue_title = 'Test Issue'
|
||||
new_conversation_view._issue_description = 'Test description'
|
||||
|
||||
mock_repo = MagicMock()
|
||||
mock_repo.git_provider = ProviderType.GITHUB
|
||||
mock_handler = MagicMock()
|
||||
mock_handler.verify_repo_provider = AsyncMock(return_value=mock_repo)
|
||||
mock_provider_handler_cls.return_value = mock_handler
|
||||
|
||||
mock_resolve_org.return_value = None
|
||||
mock_store = MagicMock()
|
||||
mock_store.save_metadata = AsyncMock()
|
||||
mock_get_resolver_instance.return_value = mock_store
|
||||
mock_integration_store.create_conversation = AsyncMock()
|
||||
|
||||
result = await new_conversation_view.create_or_update_conversation(
|
||||
mock_jinja_env
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert isinstance(result, str)
|
||||
assert len(result) == 32 # uuid4().hex format
|
||||
mock_start_convo.assert_called_once()
|
||||
mock_integration_store.create_conversation.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_or_update_conversation_no_repo(
|
||||
self, new_conversation_view, mock_jinja_env
|
||||
@@ -323,6 +372,125 @@ class TestJiraFactory:
|
||||
)
|
||||
|
||||
|
||||
CLAIMING_ORG_ID = UUID('aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa')
|
||||
|
||||
|
||||
class TestJiraV0ConversationRouting:
|
||||
"""Test V0 conversation routing logic based on claimed git organizations."""
|
||||
|
||||
@pytest.fixture
|
||||
def routing_view(
|
||||
self,
|
||||
sample_webhook_payload,
|
||||
sample_jira_user,
|
||||
sample_jira_workspace,
|
||||
):
|
||||
"""View with non-empty provider tokens for routing tests."""
|
||||
user_auth = MagicMock(spec=UserAuth)
|
||||
user_auth.get_provider_tokens = AsyncMock(
|
||||
return_value={ProviderType.GITHUB: MagicMock()}
|
||||
)
|
||||
user_auth.get_secrets = AsyncMock(return_value=None)
|
||||
return JiraNewConversationView(
|
||||
payload=sample_webhook_payload,
|
||||
saas_user_auth=user_auth,
|
||||
jira_user=sample_jira_user,
|
||||
jira_workspace=sample_jira_workspace,
|
||||
selected_repo='test/repo1',
|
||||
_issue_title='Test Issue',
|
||||
_issue_description='Test description',
|
||||
_decrypted_api_key='decrypted_key',
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('integrations.jira.jira_view.resolve_org_for_repo', new_callable=AsyncMock)
|
||||
@patch('integrations.jira.jira_view.ProviderHandler')
|
||||
@patch(
|
||||
'integrations.jira.jira_view.SaasConversationStore.get_resolver_instance',
|
||||
new_callable=AsyncMock,
|
||||
)
|
||||
@patch('integrations.jira.jira_view.start_conversation', new_callable=AsyncMock)
|
||||
@patch('integrations.jira.jira_view.integration_store')
|
||||
async def test_routes_to_claimed_org_when_user_is_member(
|
||||
self,
|
||||
mock_integration_store,
|
||||
mock_start_convo,
|
||||
mock_get_resolver_instance,
|
||||
mock_provider_handler_cls,
|
||||
mock_resolve_org,
|
||||
routing_view,
|
||||
mock_jinja_env,
|
||||
):
|
||||
"""When repo belongs to a claimed org and user is a member, conversation is created in that org."""
|
||||
# Arrange
|
||||
mock_repo = MagicMock()
|
||||
mock_repo.git_provider = ProviderType.GITHUB
|
||||
mock_handler = MagicMock()
|
||||
mock_handler.verify_repo_provider = AsyncMock(return_value=mock_repo)
|
||||
mock_provider_handler_cls.return_value = mock_handler
|
||||
|
||||
mock_resolve_org.return_value = CLAIMING_ORG_ID
|
||||
mock_store = MagicMock()
|
||||
mock_store.save_metadata = AsyncMock()
|
||||
mock_get_resolver_instance.return_value = mock_store
|
||||
mock_integration_store.create_conversation = AsyncMock()
|
||||
|
||||
# Act
|
||||
await routing_view.create_or_update_conversation(mock_jinja_env)
|
||||
|
||||
# Assert
|
||||
mock_resolve_org.assert_called_once_with(
|
||||
provider='github',
|
||||
full_repo_name='test/repo1',
|
||||
keycloak_user_id='test_keycloak_id',
|
||||
)
|
||||
call_args = mock_get_resolver_instance.call_args
|
||||
assert call_args[0][1] == 'test_keycloak_id' # user_id
|
||||
assert call_args[0][2] == CLAIMING_ORG_ID # resolver_org_id
|
||||
saved_metadata = mock_store.save_metadata.call_args[0][0]
|
||||
assert saved_metadata.git_provider == ProviderType.GITHUB
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('integrations.jira.jira_view.resolve_org_for_repo', new_callable=AsyncMock)
|
||||
@patch('integrations.jira.jira_view.ProviderHandler')
|
||||
@patch(
|
||||
'integrations.jira.jira_view.SaasConversationStore.get_resolver_instance',
|
||||
new_callable=AsyncMock,
|
||||
)
|
||||
@patch('integrations.jira.jira_view.start_conversation', new_callable=AsyncMock)
|
||||
@patch('integrations.jira.jira_view.integration_store')
|
||||
async def test_falls_back_to_personal_workspace_when_no_claim(
|
||||
self,
|
||||
mock_integration_store,
|
||||
mock_start_convo,
|
||||
mock_get_resolver_instance,
|
||||
mock_provider_handler_cls,
|
||||
mock_resolve_org,
|
||||
routing_view,
|
||||
mock_jinja_env,
|
||||
):
|
||||
"""When no org has claimed the git org, conversation goes to personal workspace."""
|
||||
# Arrange
|
||||
mock_repo = MagicMock()
|
||||
mock_repo.git_provider = ProviderType.GITHUB
|
||||
mock_handler = MagicMock()
|
||||
mock_handler.verify_repo_provider = AsyncMock(return_value=mock_repo)
|
||||
mock_provider_handler_cls.return_value = mock_handler
|
||||
|
||||
mock_resolve_org.return_value = None
|
||||
mock_store = MagicMock()
|
||||
mock_store.save_metadata = AsyncMock()
|
||||
mock_get_resolver_instance.return_value = mock_store
|
||||
mock_integration_store.create_conversation = AsyncMock()
|
||||
|
||||
# Act
|
||||
await routing_view.create_or_update_conversation(mock_jinja_env)
|
||||
|
||||
# Assert
|
||||
call_args = mock_get_resolver_instance.call_args
|
||||
assert call_args[0][2] is None # resolver_org_id is None
|
||||
|
||||
|
||||
class TestJiraPayloadParser:
|
||||
"""Tests for JiraPayloadParser"""
|
||||
|
||||
@@ -438,164 +606,3 @@ class TestJiraPayloadParserStagingLabels:
|
||||
result = staging_parser.parse(payload)
|
||||
|
||||
assert isinstance(result, JiraPayloadSkipped)
|
||||
|
||||
|
||||
class TestJiraV1Conversation:
|
||||
"""Tests for V1 conversation creation and callback processor registration."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_v1_metadata_generates_conversation_id(
|
||||
self, new_conversation_view
|
||||
):
|
||||
"""Test that _create_v1_metadata generates a new conversation ID."""
|
||||
new_conversation_view.conversation_id = ''
|
||||
|
||||
with patch.object(
|
||||
new_conversation_view, '_get_resolved_org_id', new_callable=AsyncMock
|
||||
) as mock_get_org:
|
||||
mock_get_org.return_value = None
|
||||
|
||||
metadata = await new_conversation_view._create_v1_metadata()
|
||||
|
||||
# Conversation ID should be generated
|
||||
assert new_conversation_view.conversation_id != ''
|
||||
assert len(new_conversation_view.conversation_id) == 32 # UUID hex format
|
||||
assert metadata.conversation_id == new_conversation_view.conversation_id
|
||||
mock_get_org.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_v1_metadata_sets_resolved_org(self, new_conversation_view):
|
||||
"""Test that _create_v1_metadata sets resolved_org_id."""
|
||||
from uuid import UUID
|
||||
|
||||
test_org_id = UUID('12345678-1234-5678-1234-567812345678')
|
||||
|
||||
with patch.object(
|
||||
new_conversation_view, '_get_resolved_org_id', new_callable=AsyncMock
|
||||
) as mock_get_org:
|
||||
mock_get_org.return_value = test_org_id
|
||||
|
||||
await new_conversation_view._create_v1_metadata()
|
||||
|
||||
assert new_conversation_view.resolved_org_id == test_org_id
|
||||
|
||||
def test_create_jira_v1_callback_processor(
|
||||
self, new_conversation_view, sample_jira_workspace
|
||||
):
|
||||
"""Test that _create_jira_v1_callback_processor creates correctly configured processor."""
|
||||
from integrations.jira.jira_v1_callback_processor import JiraV1CallbackProcessor
|
||||
|
||||
processor = new_conversation_view._create_jira_v1_callback_processor()
|
||||
|
||||
assert isinstance(processor, JiraV1CallbackProcessor)
|
||||
assert processor.svc_acc_email == sample_jira_workspace.svc_acc_email
|
||||
assert processor.decrypted_api_key == 'decrypted_key'
|
||||
assert processor.issue_key == 'TEST-123'
|
||||
assert processor.jira_cloud_id == sample_jira_workspace.jira_cloud_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_v1_initial_user_message(
|
||||
self, new_conversation_view, mock_jinja_env
|
||||
):
|
||||
"""Test _get_v1_initial_user_message renders the template correctly."""
|
||||
new_conversation_view._issue_title = 'Test Bug'
|
||||
new_conversation_view._issue_description = 'Description of the bug'
|
||||
|
||||
message = await new_conversation_view._get_v1_initial_user_message(
|
||||
mock_jinja_env
|
||||
)
|
||||
|
||||
assert 'TEST-123' in message
|
||||
assert 'Test Bug' in message
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('integrations.jira.jira_view.get_app_conversation_service')
|
||||
@patch('integrations.jira.jira_view.integration_store')
|
||||
async def test_create_or_update_conversation_v1_flow(
|
||||
self,
|
||||
mock_store,
|
||||
mock_get_service,
|
||||
new_conversation_view,
|
||||
mock_jinja_env,
|
||||
):
|
||||
"""Test create_or_update_conversation creates V1 conversation correctly."""
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from uuid import UUID
|
||||
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
AppConversationStartTaskStatus,
|
||||
)
|
||||
|
||||
# Setup mocks
|
||||
mock_store.create_conversation = AsyncMock()
|
||||
|
||||
# Mock the app conversation service
|
||||
mock_service = AsyncMock()
|
||||
|
||||
async def mock_start_generator(*args, **kwargs):
|
||||
yield MagicMock(status=AppConversationStartTaskStatus.WORKING)
|
||||
yield MagicMock(status=AppConversationStartTaskStatus.READY)
|
||||
|
||||
mock_service.start_app_conversation = mock_start_generator
|
||||
mock_get_service.return_value.__aenter__.return_value = mock_service
|
||||
|
||||
# Set issue details to avoid fetch
|
||||
new_conversation_view._issue_title = 'Test Issue'
|
||||
new_conversation_view._issue_description = 'Test description'
|
||||
|
||||
with patch.object(
|
||||
new_conversation_view, '_get_resolved_org_id', new_callable=AsyncMock
|
||||
) as mock_get_org:
|
||||
mock_get_org.return_value = UUID('12345678-1234-5678-1234-567812345678')
|
||||
|
||||
result = await new_conversation_view.create_or_update_conversation(
|
||||
mock_jinja_env
|
||||
)
|
||||
|
||||
# Verify conversation was created
|
||||
assert result is not None
|
||||
mock_store.create_conversation.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('integrations.jira.jira_view.get_app_conversation_service')
|
||||
@patch('integrations.jira.jira_view.integration_store')
|
||||
async def test_create_or_update_conversation_handles_error(
|
||||
self,
|
||||
mock_store,
|
||||
mock_get_service,
|
||||
new_conversation_view,
|
||||
mock_jinja_env,
|
||||
):
|
||||
"""Test create_or_update_conversation handles V1 errors correctly."""
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
AppConversationStartTaskStatus,
|
||||
)
|
||||
|
||||
mock_store.create_conversation = AsyncMock()
|
||||
|
||||
# Mock the app conversation service to return error
|
||||
mock_service = AsyncMock()
|
||||
|
||||
async def mock_error_generator(*args, **kwargs):
|
||||
yield MagicMock(
|
||||
status=AppConversationStartTaskStatus.ERROR,
|
||||
detail='Sandbox allocation failed',
|
||||
)
|
||||
|
||||
mock_service.start_app_conversation = mock_error_generator
|
||||
mock_get_service.return_value.__aenter__.return_value = mock_service
|
||||
|
||||
new_conversation_view._issue_title = 'Test Issue'
|
||||
new_conversation_view._issue_description = 'Test description'
|
||||
|
||||
with patch.object(
|
||||
new_conversation_view, '_get_resolved_org_id', new_callable=AsyncMock
|
||||
) as mock_get_org:
|
||||
mock_get_org.return_value = None
|
||||
|
||||
with pytest.raises(RuntimeError, match='Failed to start V1 conversation'):
|
||||
await new_conversation_view.create_or_update_conversation(
|
||||
mock_jinja_env
|
||||
)
|
||||
|
||||
@@ -109,43 +109,3 @@ async def test_extracts_git_org_lowercase_from_repo_name(mock_stores):
|
||||
mock_claim_store.get_claim_by_provider_and_git_org.assert_called_once_with(
|
||||
'github', 'myorg'
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_org_id_without_membership_check_when_no_user_id(mock_stores):
|
||||
"""When user_id is None, skip membership check and return org_id if claim exists."""
|
||||
from enterprise.integrations.resolver_org_router import resolve_org_for_repo
|
||||
|
||||
mock_claim_store, mock_member_store = mock_stores
|
||||
|
||||
# Arrange
|
||||
claim = MagicMock()
|
||||
claim.org_id = CLAIMING_ORG_ID
|
||||
mock_claim_store.get_claim_by_provider_and_git_org.return_value = claim
|
||||
|
||||
# Act - no user_id provided
|
||||
result = await resolve_org_for_repo('github', 'OpenHands/foo')
|
||||
|
||||
# Assert
|
||||
assert result == CLAIMING_ORG_ID
|
||||
mock_claim_store.get_claim_by_provider_and_git_org.assert_called_once_with(
|
||||
'github', 'openhands'
|
||||
)
|
||||
# Membership check should NOT be called
|
||||
mock_member_store.get_org_member.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_none_when_no_claim_and_no_user_id(mock_stores):
|
||||
"""When no claim exists and no user_id, return None."""
|
||||
from enterprise.integrations.resolver_org_router import resolve_org_for_repo
|
||||
|
||||
mock_claim_store, mock_member_store = mock_stores
|
||||
mock_claim_store.get_claim_by_provider_and_git_org.return_value = None
|
||||
|
||||
# Act - no user_id provided
|
||||
result = await resolve_org_for_repo('github', 'UnclaimedOrg/repo')
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
mock_member_store.get_org_member.assert_not_called()
|
||||
|
||||
@@ -1,330 +0,0 @@
|
||||
"""Tests for onboarding-related auth routes and functions.
|
||||
|
||||
Tests for:
|
||||
- _should_redirect_to_onboarding() function
|
||||
- _get_post_auth_redirect() function
|
||||
- /complete_onboarding endpoint
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import Request, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from server.auth.saas_user_auth import SaasUserAuth
|
||||
from server.routes.auth import (
|
||||
_get_post_auth_redirect,
|
||||
_should_redirect_to_onboarding,
|
||||
complete_onboarding,
|
||||
)
|
||||
from storage.user import User
|
||||
|
||||
# --- Fixtures ---
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_request():
|
||||
"""Create a mock FastAPI Request."""
|
||||
request = MagicMock(spec=Request)
|
||||
request.url = MagicMock()
|
||||
request.url.hostname = 'localhost'
|
||||
request.url.netloc = 'localhost:8000'
|
||||
request.base_url = 'http://localhost:8000/'
|
||||
request.headers = {}
|
||||
request.cookies = {}
|
||||
return request
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user():
|
||||
"""Create a mock User object."""
|
||||
user = MagicMock(spec=User)
|
||||
user.id = uuid.uuid4()
|
||||
user.current_org_id = uuid.uuid4()
|
||||
user.onboarding_completed = False
|
||||
return user
|
||||
|
||||
|
||||
# --- Tests for _should_redirect_to_onboarding ---
|
||||
|
||||
|
||||
class TestShouldRedirectToOnboarding:
|
||||
"""Tests for the _should_redirect_to_onboarding function."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_false_when_onboarding_completed(self, mock_user):
|
||||
"""Test that completed onboarding users are not redirected."""
|
||||
mock_user.onboarding_completed = True
|
||||
|
||||
result = await _should_redirect_to_onboarding('user-123', mock_user)
|
||||
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_true_for_cloud_mode_new_user(self, mock_user):
|
||||
"""Test that cloud mode users with incomplete onboarding are redirected."""
|
||||
mock_user.onboarding_completed = False
|
||||
|
||||
with patch('server.routes.auth.DEPLOYMENT_MODE', 'cloud'):
|
||||
result = await _should_redirect_to_onboarding('user-123', mock_user)
|
||||
|
||||
assert result is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_true_for_self_hosted_super_admin(self, mock_user):
|
||||
"""Test that the super admin (first owner to accept TOS) is redirected."""
|
||||
mock_user.onboarding_completed = False
|
||||
user_id = str(mock_user.id)
|
||||
|
||||
# Mock this user as the first owner in the org (super admin)
|
||||
first_owner = MagicMock(spec=User)
|
||||
first_owner.id = mock_user.id
|
||||
|
||||
with (
|
||||
patch('server.routes.auth.DEPLOYMENT_MODE', 'self_hosted'),
|
||||
patch(
|
||||
'server.routes.auth.UserStore.get_first_owner_in_org',
|
||||
new_callable=AsyncMock,
|
||||
return_value=first_owner,
|
||||
),
|
||||
):
|
||||
result = await _should_redirect_to_onboarding(user_id, mock_user)
|
||||
|
||||
assert result is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_false_for_self_hosted_non_super_admin_owner(self, mock_user):
|
||||
"""Test that owners who aren't the super admin are NOT redirected."""
|
||||
mock_user.onboarding_completed = False
|
||||
user_id = str(mock_user.id)
|
||||
|
||||
# Mock a different user as the first owner (super admin)
|
||||
first_owner = MagicMock(spec=User)
|
||||
first_owner.id = uuid.uuid4() # Different user
|
||||
|
||||
with (
|
||||
patch('server.routes.auth.DEPLOYMENT_MODE', 'self_hosted'),
|
||||
patch(
|
||||
'server.routes.auth.UserStore.get_first_owner_in_org',
|
||||
new_callable=AsyncMock,
|
||||
return_value=first_owner,
|
||||
),
|
||||
):
|
||||
result = await _should_redirect_to_onboarding(user_id, mock_user)
|
||||
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_false_for_self_hosted_when_no_owner_found(self, mock_user):
|
||||
"""Test that users are not redirected when no owner is found."""
|
||||
mock_user.onboarding_completed = False
|
||||
user_id = str(mock_user.id)
|
||||
|
||||
with (
|
||||
patch('server.routes.auth.DEPLOYMENT_MODE', 'self_hosted'),
|
||||
patch(
|
||||
'server.routes.auth.UserStore.get_first_owner_in_org',
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
result = await _should_redirect_to_onboarding(user_id, mock_user)
|
||||
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_passes_current_org_id_to_get_first_owner(self, mock_user):
|
||||
"""Test that get_first_owner_in_org is called with user's current_org_id."""
|
||||
mock_user.onboarding_completed = False
|
||||
user_id = str(mock_user.id)
|
||||
mock_get_first_owner = AsyncMock(return_value=None)
|
||||
|
||||
with (
|
||||
patch('server.routes.auth.DEPLOYMENT_MODE', 'self_hosted'),
|
||||
patch(
|
||||
'server.routes.auth.UserStore.get_first_owner_in_org',
|
||||
mock_get_first_owner,
|
||||
),
|
||||
):
|
||||
await _should_redirect_to_onboarding(user_id, mock_user)
|
||||
|
||||
mock_get_first_owner.assert_called_once_with(mock_user.current_org_id)
|
||||
|
||||
|
||||
# --- Tests for _get_post_auth_redirect ---
|
||||
|
||||
|
||||
class TestGetPostAuthRedirect:
|
||||
"""Tests for the _get_post_auth_redirect function."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_onboarding_url_when_onboarding_needed(self, mock_user):
|
||||
"""Test that onboarding URL is returned when user needs onboarding."""
|
||||
mock_user.onboarding_completed = False
|
||||
user_id = str(mock_user.id)
|
||||
|
||||
with (
|
||||
patch('server.routes.auth.DEPLOYMENT_MODE', 'cloud'),
|
||||
patch(
|
||||
'server.routes.auth.UserStore.get_user_by_id',
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_user,
|
||||
),
|
||||
):
|
||||
result = await _get_post_auth_redirect(
|
||||
user_id, 'https://example.com/', 'https://example.com'
|
||||
)
|
||||
|
||||
assert result == 'https://example.com/onboarding'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_default_url_when_onboarding_completed(self, mock_user):
|
||||
"""Test that default URL is returned when user has completed onboarding."""
|
||||
mock_user.onboarding_completed = True
|
||||
user_id = str(mock_user.id)
|
||||
|
||||
with patch(
|
||||
'server.routes.auth.UserStore.get_user_by_id',
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_user,
|
||||
):
|
||||
result = await _get_post_auth_redirect(
|
||||
user_id, 'https://example.com/dashboard', 'https://example.com'
|
||||
)
|
||||
|
||||
assert result == 'https://example.com/dashboard'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_default_url_when_user_not_found(self):
|
||||
"""Test that default URL is returned when user is not found."""
|
||||
with patch(
|
||||
'server.routes.auth.UserStore.get_user_by_id',
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
):
|
||||
result = await _get_post_auth_redirect(
|
||||
'nonexistent-user', 'https://example.com/', 'https://example.com'
|
||||
)
|
||||
|
||||
assert result == 'https://example.com/'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_logs_when_redirecting_to_onboarding(self, mock_user):
|
||||
"""Test that a log message is emitted when redirecting to onboarding."""
|
||||
mock_user.onboarding_completed = False
|
||||
user_id = str(mock_user.id)
|
||||
|
||||
with (
|
||||
patch('server.routes.auth.DEPLOYMENT_MODE', 'cloud'),
|
||||
patch(
|
||||
'server.routes.auth.UserStore.get_user_by_id',
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_user,
|
||||
),
|
||||
patch('server.routes.auth.logger') as mock_logger,
|
||||
):
|
||||
await _get_post_auth_redirect(
|
||||
user_id, 'https://example.com/', 'https://example.com'
|
||||
)
|
||||
|
||||
mock_logger.info.assert_called_once()
|
||||
call_args = mock_logger.info.call_args
|
||||
assert call_args[0][0] == 'Redirecting user to onboarding'
|
||||
assert call_args[1]['extra']['user_id'] == user_id
|
||||
|
||||
|
||||
# --- Tests for /complete_onboarding endpoint ---
|
||||
|
||||
|
||||
class TestCompleteOnboardingEndpoint:
|
||||
"""Tests for the complete_onboarding API endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_401_when_not_authenticated(self, mock_request):
|
||||
"""Test that unauthenticated requests return 401."""
|
||||
mock_user_auth = MagicMock(spec=SaasUserAuth)
|
||||
mock_user_auth.get_user_id = AsyncMock(return_value=None)
|
||||
|
||||
with patch(
|
||||
'server.routes.auth.get_user_auth',
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_user_auth,
|
||||
):
|
||||
result = await complete_onboarding(mock_request)
|
||||
|
||||
assert isinstance(result, JSONResponse)
|
||||
assert result.status_code == status.HTTP_401_UNAUTHORIZED
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_404_when_user_not_found(self, mock_request):
|
||||
"""Test that request for non-existent user returns 404."""
|
||||
user_id = str(uuid.uuid4())
|
||||
mock_user_auth = MagicMock(spec=SaasUserAuth)
|
||||
mock_user_auth.get_user_id = AsyncMock(return_value=user_id)
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.routes.auth.get_user_auth',
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_user_auth,
|
||||
),
|
||||
patch(
|
||||
'server.routes.auth.UserStore.mark_onboarding_completed',
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
result = await complete_onboarding(mock_request)
|
||||
|
||||
assert isinstance(result, JSONResponse)
|
||||
assert result.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_200_on_success(self, mock_request, mock_user):
|
||||
"""Test successful onboarding completion returns 200."""
|
||||
user_id = str(uuid.uuid4())
|
||||
mock_user_auth = MagicMock(spec=SaasUserAuth)
|
||||
mock_user_auth.get_user_id = AsyncMock(return_value=user_id)
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.routes.auth.get_user_auth',
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_user_auth,
|
||||
),
|
||||
patch(
|
||||
'server.routes.auth.UserStore.mark_onboarding_completed',
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_user,
|
||||
),
|
||||
):
|
||||
result = await complete_onboarding(mock_request)
|
||||
|
||||
assert isinstance(result, JSONResponse)
|
||||
assert result.status_code == status.HTTP_200_OK
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calls_mark_onboarding_completed_with_user_id(
|
||||
self, mock_request, mock_user
|
||||
):
|
||||
"""Test that mark_onboarding_completed is called with the correct user_id."""
|
||||
user_id = str(uuid.uuid4())
|
||||
mock_user_auth = MagicMock(spec=SaasUserAuth)
|
||||
mock_user_auth.get_user_id = AsyncMock(return_value=user_id)
|
||||
mock_mark_completed = AsyncMock(return_value=mock_user)
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.routes.auth.get_user_auth',
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_user_auth,
|
||||
),
|
||||
patch(
|
||||
'server.routes.auth.UserStore.mark_onboarding_completed',
|
||||
mock_mark_completed,
|
||||
),
|
||||
):
|
||||
await complete_onboarding(mock_request)
|
||||
|
||||
mock_mark_completed.assert_called_once_with(user_id)
|
||||
@@ -1,670 +0,0 @@
|
||||
"""
|
||||
Unit tests for AutomationEventService.
|
||||
|
||||
Tests the service that forwards GitHub webhook events to the automation service.
|
||||
|
||||
The service is optimized for high-traffic with:
|
||||
- Redis caching for org claim lookups (1 hour TTL)
|
||||
- Redis caching for GitHub→Keycloak user ID mappings (24 hour TTL)
|
||||
- Lazy access control (membership checks deferred to execution time)
|
||||
- Separate AUTOMATION_WEBHOOK_SECRET for internal service communication
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
# Default patches for constants
|
||||
CONSTANT_PATCHES = {
|
||||
'server.services.automation_event_service.AUTOMATION_WEBHOOK_SECRET': 'test-shared-secret',
|
||||
'server.services.automation_event_service.AUTOMATION_SERVICE_TIMEOUT': 30,
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_token_manager():
|
||||
"""Create a mock TokenManager."""
|
||||
return MagicMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_org_git_claim():
|
||||
"""Create a mock OrgGitClaim."""
|
||||
claim = MagicMock()
|
||||
claim.org_id = uuid.UUID('12345678-1234-5678-1234-567812345678')
|
||||
return claim
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def github_org_payload():
|
||||
"""Create a sample GitHub webhook payload for an organization repo."""
|
||||
return {
|
||||
'repository': {
|
||||
'id': 123456,
|
||||
'full_name': 'test-org/test-repo',
|
||||
'private': False,
|
||||
'default_branch': 'main',
|
||||
'owner': {
|
||||
'login': 'test-org',
|
||||
'id': 789,
|
||||
'type': 'Organization',
|
||||
},
|
||||
},
|
||||
'sender': {
|
||||
'id': 12345,
|
||||
'login': 'testuser',
|
||||
},
|
||||
'action': 'opened',
|
||||
'installation': {
|
||||
'id': 99999,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def github_user_payload():
|
||||
"""Create a sample GitHub webhook payload for a personal/user repo."""
|
||||
return {
|
||||
'repository': {
|
||||
'id': 654321,
|
||||
'full_name': 'testuser/personal-repo',
|
||||
'private': True,
|
||||
'default_branch': 'main',
|
||||
'owner': {
|
||||
'login': 'testuser',
|
||||
'id': 12345,
|
||||
'type': 'User',
|
||||
},
|
||||
},
|
||||
'sender': {
|
||||
'id': 12345,
|
||||
'login': 'testuser',
|
||||
},
|
||||
'action': 'opened',
|
||||
'installation': {
|
||||
'id': 99999,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def create_service(mock_token_manager):
|
||||
"""Helper to create a service with mocked sio and constants."""
|
||||
with patch('server.services.automation_event_service.sio'), patch.dict(
|
||||
'os.environ', {}, clear=False
|
||||
):
|
||||
for key, value in CONSTANT_PATCHES.items():
|
||||
patch(key, value).start()
|
||||
|
||||
from server.services.automation_event_service import AutomationEventService
|
||||
|
||||
return AutomationEventService(mock_token_manager)
|
||||
|
||||
|
||||
class TestResolveGithubOrg:
|
||||
"""Tests for _resolve_github_org method with caching."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_github_org_cache_miss_found(
|
||||
self, mock_token_manager, mock_org_git_claim
|
||||
):
|
||||
"""
|
||||
GIVEN: Cache miss and org claim exists in DB
|
||||
WHEN: _resolve_github_org is called
|
||||
THEN: Org ID is returned and cached
|
||||
"""
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(return_value=None) # Cache miss
|
||||
mock_redis.setex = AsyncMock()
|
||||
|
||||
with patch(
|
||||
'server.services.automation_event_service.resolve_org_for_repo',
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_org_git_claim.org_id,
|
||||
), patch('server.services.automation_event_service.sio') as mock_sio:
|
||||
mock_sio.manager.redis = mock_redis
|
||||
|
||||
service = create_service(mock_token_manager)
|
||||
result = await service._resolve_github_org('test-org')
|
||||
|
||||
assert result == mock_org_git_claim.org_id
|
||||
# Verify result was cached
|
||||
mock_redis.setex.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_github_org_cache_hit(self, mock_token_manager):
|
||||
"""
|
||||
GIVEN: Org ID is cached in Redis
|
||||
WHEN: _resolve_github_org is called
|
||||
THEN: Cached value is returned without calling resolve_org_for_repo
|
||||
"""
|
||||
cached_org_id = '12345678-1234-5678-1234-567812345678'
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(return_value=cached_org_id.encode())
|
||||
|
||||
with patch(
|
||||
'server.services.automation_event_service.resolve_org_for_repo',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_resolver, patch(
|
||||
'server.services.automation_event_service.sio'
|
||||
) as mock_sio:
|
||||
mock_sio.manager.redis = mock_redis
|
||||
|
||||
service = create_service(mock_token_manager)
|
||||
result = await service._resolve_github_org('test-org')
|
||||
|
||||
assert result == uuid.UUID(cached_org_id)
|
||||
# resolve_org_for_repo should NOT be called
|
||||
mock_resolver.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_github_org_cache_miss_not_found(self, mock_token_manager):
|
||||
"""
|
||||
GIVEN: Cache miss and org claim does NOT exist in DB
|
||||
WHEN: _resolve_github_org is called
|
||||
THEN: None is returned and negative result is cached
|
||||
"""
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(return_value=None) # Cache miss
|
||||
mock_redis.setex = AsyncMock()
|
||||
|
||||
with patch(
|
||||
'server.services.automation_event_service.resolve_org_for_repo',
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
), patch('server.services.automation_event_service.sio') as mock_sio:
|
||||
mock_sio.manager.redis = mock_redis
|
||||
|
||||
service = create_service(mock_token_manager)
|
||||
result = await service._resolve_github_org('unclaimed-org')
|
||||
|
||||
assert result is None
|
||||
# Verify negative result was cached
|
||||
mock_redis.setex.assert_called_once()
|
||||
call_args = mock_redis.setex.call_args
|
||||
# Second positional arg is the value
|
||||
assert call_args[0][2] == 'none' # Negative cache value
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_github_org_negative_cache_hit(self, mock_token_manager):
|
||||
"""
|
||||
GIVEN: Negative result is cached (org not claimed)
|
||||
WHEN: _resolve_github_org is called
|
||||
THEN: None is returned without calling resolve_org_for_repo
|
||||
"""
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(return_value=b'none') # Cached negative
|
||||
|
||||
with patch(
|
||||
'server.services.automation_event_service.resolve_org_for_repo',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_resolver, patch(
|
||||
'server.services.automation_event_service.sio'
|
||||
) as mock_sio:
|
||||
mock_sio.manager.redis = mock_redis
|
||||
|
||||
service = create_service(mock_token_manager)
|
||||
result = await service._resolve_github_org('unclaimed-org')
|
||||
|
||||
assert result is None
|
||||
mock_resolver.assert_not_called()
|
||||
|
||||
|
||||
class TestResolvePersonalOrg:
|
||||
"""Tests for _resolve_personal_org method with caching."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_personal_org_cache_miss_found(self, mock_token_manager):
|
||||
"""
|
||||
GIVEN: Cache miss and user exists in Keycloak
|
||||
WHEN: _resolve_personal_org is called
|
||||
THEN: Keycloak ID is returned and cached
|
||||
"""
|
||||
keycloak_id = '87654321-4321-8765-4321-876543218765'
|
||||
mock_token_manager.get_user_id_from_idp_user_id = AsyncMock(
|
||||
return_value=keycloak_id
|
||||
)
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(return_value=None) # Cache miss
|
||||
mock_redis.setex = AsyncMock()
|
||||
|
||||
with patch('server.services.automation_event_service.sio') as mock_sio:
|
||||
mock_sio.manager.redis = mock_redis
|
||||
|
||||
service = create_service(mock_token_manager)
|
||||
result = await service._resolve_personal_org(12345)
|
||||
|
||||
assert result == uuid.UUID(keycloak_id)
|
||||
mock_redis.setex.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_personal_org_cache_hit(self, mock_token_manager):
|
||||
"""
|
||||
GIVEN: Keycloak ID is cached in Redis
|
||||
WHEN: _resolve_personal_org is called
|
||||
THEN: Cached value is returned without Keycloak query
|
||||
"""
|
||||
keycloak_id = '87654321-4321-8765-4321-876543218765'
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(return_value=keycloak_id.encode())
|
||||
|
||||
with patch('server.services.automation_event_service.sio') as mock_sio:
|
||||
mock_sio.manager.redis = mock_redis
|
||||
|
||||
service = create_service(mock_token_manager)
|
||||
result = await service._resolve_personal_org(12345)
|
||||
|
||||
assert result == uuid.UUID(keycloak_id)
|
||||
# Token manager should NOT be called
|
||||
mock_token_manager.get_user_id_from_idp_user_id.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_personal_org_no_github_user_id(self, mock_token_manager):
|
||||
"""
|
||||
GIVEN: No GitHub user ID provided
|
||||
WHEN: _resolve_personal_org is called
|
||||
THEN: None is returned immediately
|
||||
"""
|
||||
service = create_service(mock_token_manager)
|
||||
result = await service._resolve_personal_org(None)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestForwardGithubEvent:
|
||||
"""Tests for forward_github_event method (minimal payload, no access control)."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_forward_org_event_success(
|
||||
self, mock_token_manager, github_org_payload, mock_org_git_claim
|
||||
):
|
||||
"""
|
||||
GIVEN: A GitHub event from a claimed organization repo
|
||||
WHEN: forward_github_event is called
|
||||
THEN: Minimal payload is forwarded (no access_control)
|
||||
"""
|
||||
from server.services.automation_event_service import AutomationEventService
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(return_value=None)
|
||||
mock_redis.setex = AsyncMock()
|
||||
|
||||
with patch(
|
||||
'server.services.automation_event_service.resolve_org_for_repo',
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_org_git_claim.org_id,
|
||||
), patch(
|
||||
'server.services.automation_event_service.sio'
|
||||
) as mock_sio, patch.object(
|
||||
AutomationEventService,
|
||||
'_send_to_automation_service',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_send:
|
||||
mock_sio.manager.redis = mock_redis
|
||||
|
||||
service = AutomationEventService(mock_token_manager)
|
||||
await service.forward_github_event(
|
||||
payload=github_org_payload,
|
||||
installation_id=99999,
|
||||
)
|
||||
|
||||
mock_send.assert_called_once()
|
||||
call_args = mock_send.call_args
|
||||
assert call_args[0][0] == mock_org_git_claim.org_id
|
||||
|
||||
payload = call_args[0][1]
|
||||
assert payload['organization']['github_org'] == 'test-org'
|
||||
assert 'payload' in payload
|
||||
# access_control should NOT be in payload (lazy evaluation)
|
||||
assert 'access_control' not in payload
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_forward_personal_repo_event_success(
|
||||
self, mock_token_manager, github_user_payload
|
||||
):
|
||||
"""
|
||||
GIVEN: A GitHub event from a personal repo with linked OpenHands account
|
||||
WHEN: forward_github_event is called
|
||||
THEN: Event is forwarded using the user's personal org (keycloak ID)
|
||||
"""
|
||||
from server.services.automation_event_service import AutomationEventService
|
||||
|
||||
keycloak_id = '87654321-4321-8765-4321-876543218765'
|
||||
mock_token_manager.get_user_id_from_idp_user_id = AsyncMock(
|
||||
return_value=keycloak_id
|
||||
)
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(return_value=None)
|
||||
mock_redis.setex = AsyncMock()
|
||||
|
||||
with patch(
|
||||
'server.services.automation_event_service.resolve_org_for_repo',
|
||||
new_callable=AsyncMock,
|
||||
return_value=None, # No org claim for personal repo
|
||||
), patch(
|
||||
'server.services.automation_event_service.sio'
|
||||
) as mock_sio, patch.object(
|
||||
AutomationEventService,
|
||||
'_send_to_automation_service',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_send:
|
||||
mock_sio.manager.redis = mock_redis
|
||||
|
||||
service = AutomationEventService(mock_token_manager)
|
||||
await service.forward_github_event(
|
||||
payload=github_user_payload,
|
||||
installation_id=99999,
|
||||
)
|
||||
|
||||
mock_send.assert_called_once()
|
||||
call_args = mock_send.call_args
|
||||
# Personal org should be keycloak ID
|
||||
assert call_args[0][0] == uuid.UUID(keycloak_id)
|
||||
payload = call_args[0][1]
|
||||
assert payload['organization']['github_org'] == 'testuser'
|
||||
assert payload['organization']['openhands_org_id'] == keycloak_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_forward_event_no_owner_in_payload(self, mock_token_manager):
|
||||
"""
|
||||
GIVEN: A GitHub event with no repository owner in payload
|
||||
WHEN: forward_github_event is called
|
||||
THEN: Event is skipped with warning log
|
||||
"""
|
||||
from server.services.automation_event_service import AutomationEventService
|
||||
|
||||
payload = {
|
||||
'repository': {},
|
||||
'sender': {'id': 12345, 'login': 'testuser'},
|
||||
}
|
||||
|
||||
with patch('server.services.automation_event_service.sio'), patch(
|
||||
'server.services.automation_event_service.logger'
|
||||
) as mock_logger, patch.object(
|
||||
AutomationEventService,
|
||||
'_send_to_automation_service',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_send:
|
||||
service = AutomationEventService(mock_token_manager)
|
||||
await service.forward_github_event(
|
||||
payload=payload,
|
||||
installation_id=99999,
|
||||
)
|
||||
|
||||
mock_send.assert_not_called()
|
||||
mock_logger.warning.assert_called()
|
||||
assert 'No repository owner' in str(mock_logger.warning.call_args)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_forward_event_org_not_claimed_and_not_personal(
|
||||
self, mock_token_manager, github_org_payload
|
||||
):
|
||||
"""
|
||||
GIVEN: A GitHub event from an org that isn't claimed (and isn't personal)
|
||||
WHEN: forward_github_event is called
|
||||
THEN: Event is skipped with warning log
|
||||
"""
|
||||
from server.services.automation_event_service import AutomationEventService
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(return_value=None)
|
||||
mock_redis.setex = AsyncMock()
|
||||
|
||||
with patch(
|
||||
'server.services.automation_event_service.resolve_org_for_repo',
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
), patch('server.services.automation_event_service.sio') as mock_sio, patch(
|
||||
'server.services.automation_event_service.logger'
|
||||
) as mock_logger, patch.object(
|
||||
AutomationEventService,
|
||||
'_send_to_automation_service',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_send:
|
||||
mock_sio.manager.redis = mock_redis
|
||||
|
||||
service = AutomationEventService(mock_token_manager)
|
||||
await service.forward_github_event(
|
||||
payload=github_org_payload,
|
||||
installation_id=99999,
|
||||
)
|
||||
|
||||
mock_send.assert_not_called()
|
||||
mock_logger.warning.assert_called()
|
||||
assert 'not claimed' in str(mock_logger.warning.call_args)
|
||||
|
||||
|
||||
class TestBuildEventPayload:
|
||||
"""Tests for _build_event_payload method."""
|
||||
|
||||
def test_build_minimal_payload(self, mock_token_manager):
|
||||
"""
|
||||
GIVEN: Org context and payload
|
||||
WHEN: _build_event_payload is called
|
||||
THEN: Minimal payload with only org + payload is returned
|
||||
"""
|
||||
from server.services.automation_event_service import OrgContext
|
||||
|
||||
service = create_service(mock_token_manager)
|
||||
|
||||
org_context = OrgContext(
|
||||
org_id=uuid.UUID('12345678-1234-5678-1234-567812345678'),
|
||||
github_org='test-org',
|
||||
)
|
||||
test_payload = {'action': 'opened', 'sender': {'login': 'user'}}
|
||||
|
||||
result = service._build_event_payload(org_context, test_payload)
|
||||
|
||||
assert result == {
|
||||
'organization': {
|
||||
'github_org': 'test-org',
|
||||
'openhands_org_id': '12345678-1234-5678-1234-567812345678',
|
||||
},
|
||||
'payload': test_payload,
|
||||
}
|
||||
# Verify NO access_control in payload
|
||||
assert 'access_control' not in result
|
||||
|
||||
|
||||
class TestSendToAutomationService:
|
||||
"""Tests for _send_to_automation_service method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_success(self, mock_token_manager):
|
||||
"""
|
||||
GIVEN: AUTOMATION_SERVICE_URL is configured
|
||||
WHEN: _send_to_automation_service is called
|
||||
THEN: Request is sent with correct signature
|
||||
"""
|
||||
|
||||
org_id = uuid.UUID('12345678-1234-5678-1234-567812345678')
|
||||
payload = {'organization': {'github_org': 'test'}, 'payload': {}}
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
mock_response.json = AsyncMock(return_value={'matched': 2})
|
||||
|
||||
mock_post_context = MagicMock()
|
||||
mock_post_context.__aenter__ = AsyncMock(return_value=mock_response)
|
||||
mock_post_context.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
mock_session_instance = MagicMock()
|
||||
mock_session_instance.post = MagicMock(return_value=mock_post_context)
|
||||
|
||||
mock_session_context = MagicMock()
|
||||
mock_session_context.__aenter__ = AsyncMock(return_value=mock_session_instance)
|
||||
mock_session_context.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
with patch(
|
||||
'server.services.automation_event_service.AUTOMATION_SERVICE_URL',
|
||||
'https://automation.example.com',
|
||||
), patch('server.services.automation_event_service.sio'), patch(
|
||||
'server.services.automation_event_service.aiohttp.ClientSession',
|
||||
return_value=mock_session_context,
|
||||
):
|
||||
service = create_service(mock_token_manager)
|
||||
await service._send_to_automation_service(org_id, payload)
|
||||
|
||||
# Verify the POST was called
|
||||
mock_session_instance.post.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_no_url_configured(self, mock_token_manager):
|
||||
"""
|
||||
GIVEN: AUTOMATION_SERVICE_URL is not configured
|
||||
WHEN: _send_to_automation_service is called
|
||||
THEN: Warning is logged and nothing is sent
|
||||
"""
|
||||
org_id = uuid.UUID('12345678-1234-5678-1234-567812345678')
|
||||
payload = {}
|
||||
|
||||
with patch(
|
||||
'server.services.automation_event_service.AUTOMATION_SERVICE_URL', None
|
||||
), patch('server.services.automation_event_service.sio'), patch(
|
||||
'server.services.automation_event_service.logger'
|
||||
) as mock_logger:
|
||||
service = create_service(mock_token_manager)
|
||||
await service._send_to_automation_service(org_id, payload)
|
||||
|
||||
mock_logger.warning.assert_called()
|
||||
assert 'not configured' in str(mock_logger.warning.call_args)
|
||||
|
||||
|
||||
class TestSignPayload:
|
||||
"""Tests for _sign_payload method."""
|
||||
|
||||
def test_sign_payload(self, mock_token_manager):
|
||||
"""
|
||||
GIVEN: A payload bytes
|
||||
WHEN: _sign_payload is called
|
||||
THEN: HMAC-SHA256 signature is returned in correct format
|
||||
"""
|
||||
with patch(
|
||||
'server.services.automation_event_service.AUTOMATION_WEBHOOK_SECRET',
|
||||
'test-shared-secret',
|
||||
), patch('server.services.automation_event_service.sio'):
|
||||
service = create_service(mock_token_manager)
|
||||
payload_bytes = b'{"test": "data"}'
|
||||
|
||||
signature = service._sign_payload(payload_bytes)
|
||||
|
||||
assert signature.startswith('sha256=')
|
||||
assert len(signature) == 71 # 'sha256=' + 64 hex chars
|
||||
|
||||
def test_sign_payload_uses_dedicated_secret(self, mock_token_manager):
|
||||
"""
|
||||
GIVEN: AUTOMATION_WEBHOOK_SECRET is configured
|
||||
WHEN: _sign_payload is called
|
||||
THEN: The dedicated secret is used (not GitHub webhook secret)
|
||||
"""
|
||||
import hashlib
|
||||
import hmac
|
||||
|
||||
# Use the default test secret from CONSTANT_PATCHES
|
||||
shared_secret = 'test-shared-secret'
|
||||
payload_bytes = b'{"test": "data"}'
|
||||
|
||||
# Calculate expected signature with the shared secret
|
||||
expected_sig = hmac.new(
|
||||
shared_secret.encode('utf-8'),
|
||||
msg=payload_bytes,
|
||||
digestmod=hashlib.sha256,
|
||||
).hexdigest()
|
||||
|
||||
with patch(
|
||||
'server.services.automation_event_service.AUTOMATION_WEBHOOK_SECRET',
|
||||
shared_secret,
|
||||
), patch('server.services.automation_event_service.sio'):
|
||||
service = create_service(mock_token_manager)
|
||||
signature = service._sign_payload(payload_bytes)
|
||||
|
||||
assert signature == f'sha256={expected_sig}'
|
||||
|
||||
|
||||
class TestCacheHelpers:
|
||||
"""Tests for generic cache helper methods."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_cached_value_hit(self, mock_token_manager):
|
||||
"""
|
||||
GIVEN: Value exists in Redis cache
|
||||
WHEN: _get_cached_value is called
|
||||
THEN: Decoded string value is returned
|
||||
"""
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(return_value=b'cached-value')
|
||||
|
||||
with patch('server.services.automation_event_service.sio') as mock_sio:
|
||||
mock_sio.manager.redis = mock_redis
|
||||
|
||||
service = create_service(mock_token_manager)
|
||||
result = await service._get_cached_value('test-key')
|
||||
|
||||
assert result == 'cached-value'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_cached_value_miss(self, mock_token_manager):
|
||||
"""
|
||||
GIVEN: Value does not exist in Redis cache
|
||||
WHEN: _get_cached_value is called
|
||||
THEN: None is returned
|
||||
"""
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(return_value=None)
|
||||
|
||||
with patch('server.services.automation_event_service.sio') as mock_sio:
|
||||
mock_sio.manager.redis = mock_redis
|
||||
|
||||
service = create_service(mock_token_manager)
|
||||
result = await service._get_cached_value('test-key')
|
||||
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_cached_value_redis_unavailable(self, mock_token_manager):
|
||||
"""
|
||||
GIVEN: Redis is unavailable
|
||||
WHEN: _get_cached_value is called
|
||||
THEN: None is returned (graceful degradation)
|
||||
"""
|
||||
with patch('server.services.automation_event_service.sio') as mock_sio:
|
||||
mock_sio.manager.redis = None
|
||||
|
||||
service = create_service(mock_token_manager)
|
||||
result = await service._get_cached_value('test-key')
|
||||
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_cached_value_success(self, mock_token_manager):
|
||||
"""
|
||||
GIVEN: Redis is available
|
||||
WHEN: _set_cached_value is called
|
||||
THEN: Value is stored with TTL
|
||||
"""
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.setex = AsyncMock()
|
||||
|
||||
with patch('server.services.automation_event_service.sio') as mock_sio:
|
||||
mock_sio.manager.redis = mock_redis
|
||||
|
||||
service = create_service(mock_token_manager)
|
||||
await service._set_cached_value('test-key', 'test-value', 3600)
|
||||
|
||||
mock_redis.setex.assert_called_once_with('test-key', 3600, 'test-value')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_cached_value_redis_unavailable(self, mock_token_manager):
|
||||
"""
|
||||
GIVEN: Redis is unavailable
|
||||
WHEN: _set_cached_value is called
|
||||
THEN: No error is raised (silent failure)
|
||||
"""
|
||||
with patch('server.services.automation_event_service.sio') as mock_sio:
|
||||
mock_sio.manager.redis = None
|
||||
|
||||
service = create_service(mock_token_manager)
|
||||
# Should not raise
|
||||
await service._set_cached_value('test-key', 'test-value', 3600)
|
||||
@@ -1,105 +0,0 @@
|
||||
"""Tests for enterprise server constants, specifically DEPLOYMENT_MODE detection."""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestDeploymentMode:
|
||||
"""Tests for _get_deployment_mode() and _is_all_hands_managed_domain() functions."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'web_host,expected_mode',
|
||||
[
|
||||
# All-Hands managed domains should return 'cloud'
|
||||
('app.all-hands.dev', 'cloud'),
|
||||
('staging.all-hands.dev', 'cloud'),
|
||||
('feature-123.staging.all-hands.dev', 'cloud'),
|
||||
('pr-456.staging.all-hands.dev', 'cloud'),
|
||||
('app.openhands.ai', 'cloud'),
|
||||
# Customer domains should return 'self_hosted'
|
||||
('openhands.acme.com', 'self_hosted'),
|
||||
('internal.company.io', 'self_hosted'),
|
||||
('dev.mycompany.net', 'self_hosted'),
|
||||
('openhands.example.org', 'self_hosted'),
|
||||
('localhost', 'self_hosted'), # localhost is not a managed domain
|
||||
# Edge cases
|
||||
('all-hands.dev', 'self_hosted'), # Not a subdomain, so not managed
|
||||
('fake-all-hands.dev', 'self_hosted'),
|
||||
('app.all-hands.dev.evil.com', 'self_hosted'),
|
||||
],
|
||||
)
|
||||
def test_deployment_mode_detection(self, web_host: str, expected_mode: str):
|
||||
"""Test that DEPLOYMENT_MODE is correctly determined based on WEB_HOST."""
|
||||
with patch.dict('os.environ', {'WEB_HOST': web_host}):
|
||||
# Need to reimport to pick up the mocked environment variable
|
||||
import importlib
|
||||
|
||||
import server.constants as constants_module
|
||||
|
||||
importlib.reload(constants_module)
|
||||
|
||||
assert constants_module.DEPLOYMENT_MODE == expected_mode
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'host,expected',
|
||||
[
|
||||
('app.all-hands.dev', True),
|
||||
('staging.all-hands.dev', True),
|
||||
('feature.staging.all-hands.dev', True),
|
||||
('app.openhands.ai', True),
|
||||
('localhost', False), # localhost is not a managed domain
|
||||
('customer.example.com', False),
|
||||
('all-hands.dev', False),
|
||||
],
|
||||
)
|
||||
def test_is_all_hands_managed_domain(self, host: str, expected: bool):
|
||||
"""Test _is_all_hands_managed_domain() helper function."""
|
||||
from server.constants import _is_all_hands_managed_domain
|
||||
|
||||
assert _is_all_hands_managed_domain(host) == expected
|
||||
|
||||
def test_deployment_mode_default_is_cloud(self):
|
||||
"""Test that default WEB_HOST (app.all-hands.dev) results in 'cloud' mode."""
|
||||
with patch.dict('os.environ', {}, clear=True):
|
||||
# Remove WEB_HOST to test default
|
||||
import importlib
|
||||
import os
|
||||
|
||||
if 'WEB_HOST' in os.environ:
|
||||
del os.environ['WEB_HOST']
|
||||
|
||||
import server.constants as constants_module
|
||||
|
||||
importlib.reload(constants_module)
|
||||
|
||||
# Default WEB_HOST is 'app.all-hands.dev' which should be 'cloud'
|
||||
assert constants_module.DEPLOYMENT_MODE == 'cloud'
|
||||
|
||||
|
||||
class TestDeploymentModeInConfig:
|
||||
"""Tests for DEPLOYMENT_MODE being exposed in config API."""
|
||||
|
||||
def test_deployment_mode_included_in_feature_flags(self):
|
||||
"""Test that DEPLOYMENT_MODE is included in FEATURE_FLAGS from get_config()."""
|
||||
from server.config import SaaSServerConfig
|
||||
|
||||
with patch('server.config.DEPLOYMENT_MODE', 'cloud'):
|
||||
saas_config = SaaSServerConfig()
|
||||
config = saas_config.get_config()
|
||||
|
||||
assert 'FEATURE_FLAGS' in config
|
||||
assert 'DEPLOYMENT_MODE' in config['FEATURE_FLAGS']
|
||||
assert config['FEATURE_FLAGS']['DEPLOYMENT_MODE'] == 'cloud'
|
||||
|
||||
def test_deployment_mode_self_hosted_in_feature_flags(self):
|
||||
"""Test that self_hosted DEPLOYMENT_MODE is included in FEATURE_FLAGS."""
|
||||
from server.config import SaaSServerConfig
|
||||
|
||||
with patch('server.config.DEPLOYMENT_MODE', 'self_hosted'):
|
||||
saas_config = SaaSServerConfig()
|
||||
config = saas_config.get_config()
|
||||
|
||||
assert 'FEATURE_FLAGS' in config
|
||||
assert 'DEPLOYMENT_MODE' in config['FEATURE_FLAGS']
|
||||
assert config['FEATURE_FLAGS']['DEPLOYMENT_MODE'] == 'self_hosted'
|
||||
@@ -187,11 +187,6 @@ async def test_keycloak_callback_success_with_valid_offline_token(
|
||||
patch('server.routes.auth.set_response_cookie') as mock_set_cookie,
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
patch('server.routes.auth.posthog') as mock_posthog,
|
||||
patch(
|
||||
'server.routes.auth._should_redirect_to_onboarding',
|
||||
new_callable=AsyncMock,
|
||||
return_value=False,
|
||||
),
|
||||
):
|
||||
# Mock user with accepted_tos
|
||||
mock_user = MagicMock()
|
||||
@@ -444,11 +439,6 @@ async def test_keycloak_callback_success_without_offline_token(
|
||||
patch('server.routes.auth.KEYCLOAK_CLIENT_ID', 'test-client'),
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
patch('server.routes.auth.posthog') as mock_posthog,
|
||||
patch(
|
||||
'server.routes.auth._should_redirect_to_onboarding',
|
||||
new_callable=AsyncMock,
|
||||
return_value=False,
|
||||
),
|
||||
):
|
||||
# Mock user with accepted_tos
|
||||
mock_user = MagicMock()
|
||||
@@ -494,101 +484,19 @@ async def test_keycloak_callback_success_without_offline_token(
|
||||
mock_token_manager.store_idp_tokens.assert_called_once_with(
|
||||
ProviderType.GITHUB, 'test_user_id', 'test_access_token'
|
||||
)
|
||||
# secure is based on web_url (http://localhost:8000/), not redirect_url
|
||||
# So secure=False because web_url starts with 'http://'
|
||||
# When redirecting to Keycloak for offline token, redirect_url becomes https://keycloak...
|
||||
# so secure=True is expected
|
||||
mock_set_cookie.assert_called_once_with(
|
||||
request=mock_request,
|
||||
response=result,
|
||||
keycloak_access_token='test_access_token',
|
||||
keycloak_refresh_token='test_refresh_token',
|
||||
secure=False,
|
||||
secure=True,
|
||||
accepted_tos=True,
|
||||
)
|
||||
mock_posthog.set.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_keycloak_callback_redirects_to_keycloak_when_offline_token_invalid(
|
||||
mock_request, create_keycloak_user_info
|
||||
):
|
||||
"""Test that keycloak_callback redirects to Keycloak when offline token is invalid.
|
||||
|
||||
When a user doesn't have a valid offline token, they should be redirected
|
||||
to Keycloak to obtain one, rather than proceeding with invitation processing.
|
||||
"""
|
||||
with (
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.set_response_cookie') as mock_set_cookie,
|
||||
patch(
|
||||
'server.routes.auth.KEYCLOAK_SERVER_URL_EXT', 'https://keycloak.example.com'
|
||||
),
|
||||
patch('server.routes.auth.KEYCLOAK_REALM_NAME', 'test-realm'),
|
||||
patch('server.routes.auth.KEYCLOAK_CLIENT_ID', 'test-client'),
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
patch('server.routes.auth.posthog'),
|
||||
patch('server.routes.auth.OrgInvitationService') as mock_invitation_service,
|
||||
patch(
|
||||
'server.routes.auth._should_redirect_to_onboarding',
|
||||
new_callable=AsyncMock,
|
||||
return_value=False,
|
||||
),
|
||||
):
|
||||
# Mock user with accepted_tos
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 'test_user_id'
|
||||
mock_user.current_org_id = 'test_org_id'
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
|
||||
mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user)
|
||||
mock_user_store.backfill_contact_name = AsyncMock()
|
||||
mock_user_store.backfill_user_email = AsyncMock()
|
||||
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
)
|
||||
mock_token_manager.get_user_info = AsyncMock(
|
||||
return_value=create_keycloak_user_info(
|
||||
sub='test_user_id',
|
||||
preferred_username='test_user',
|
||||
identity_provider='github',
|
||||
email_verified=True,
|
||||
)
|
||||
)
|
||||
mock_token_manager.store_idp_tokens = AsyncMock()
|
||||
mock_token_manager.validate_offline_token = AsyncMock(return_value=False)
|
||||
|
||||
# Call with an invitation token to verify it's NOT processed
|
||||
import base64
|
||||
import json
|
||||
|
||||
state_data = {
|
||||
'redirect_url': 'https://example.com/original-page',
|
||||
'invitation_token': 'inv-test-token-123',
|
||||
}
|
||||
encoded_state = base64.urlsafe_b64encode(
|
||||
json.dumps(state_data).encode()
|
||||
).decode()
|
||||
|
||||
result = await keycloak_callback(
|
||||
code='test_code',
|
||||
state=encoded_state,
|
||||
request=mock_request,
|
||||
user_authorizer=create_mock_user_authorizer(),
|
||||
)
|
||||
|
||||
# Should redirect to Keycloak for offline token
|
||||
assert isinstance(result, RedirectResponse)
|
||||
assert 'keycloak.example.com' in result.headers['location']
|
||||
assert 'offline_access' in result.headers['location']
|
||||
|
||||
# Cookie should be set with accepted_tos=True (user has accepted TOS)
|
||||
mock_set_cookie.assert_called_once()
|
||||
assert mock_set_cookie.call_args[1]['accepted_tos'] is True
|
||||
|
||||
# Invitation service should NOT be called (early return before invitation processing)
|
||||
mock_invitation_service.accept_invitation.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_keycloak_callback_account_linking_error(mock_request):
|
||||
"""Test keycloak_callback with account linking error."""
|
||||
@@ -667,21 +575,7 @@ async def test_keycloak_offline_callback_success(
|
||||
mock_request, create_keycloak_user_info
|
||||
):
|
||||
"""Test successful keycloak_offline_callback."""
|
||||
with (
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
patch('server.routes.auth.set_response_cookie'),
|
||||
patch(
|
||||
'server.routes.auth._get_post_auth_redirect',
|
||||
new_callable=AsyncMock,
|
||||
return_value='test_state',
|
||||
),
|
||||
):
|
||||
# Mock user with accepted_tos
|
||||
mock_user = MagicMock()
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user)
|
||||
|
||||
with patch('server.routes.auth.token_manager') as mock_token_manager:
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
)
|
||||
@@ -704,43 +598,6 @@ async def test_keycloak_offline_callback_success(
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_keycloak_offline_callback_redirects_to_onboarding(
|
||||
mock_request, create_keycloak_user_info
|
||||
):
|
||||
"""Test keycloak_offline_callback redirects to onboarding when needed."""
|
||||
with (
|
||||
patch('server.routes.auth.token_manager') as mock_token_manager,
|
||||
patch('server.routes.auth.UserStore') as mock_user_store,
|
||||
patch('server.routes.auth.set_response_cookie'),
|
||||
patch(
|
||||
'server.routes.auth._get_post_auth_redirect',
|
||||
new_callable=AsyncMock,
|
||||
return_value='http://localhost:8000/onboarding',
|
||||
),
|
||||
):
|
||||
# Mock user with accepted_tos
|
||||
mock_user = MagicMock()
|
||||
mock_user.accepted_tos = '2025-01-01'
|
||||
mock_user_store.get_user_by_id = AsyncMock(return_value=mock_user)
|
||||
|
||||
mock_token_manager.get_keycloak_tokens = AsyncMock(
|
||||
return_value=('test_access_token', 'test_refresh_token')
|
||||
)
|
||||
mock_token_manager.get_user_info = AsyncMock(
|
||||
return_value=create_keycloak_user_info(sub='test_user_id')
|
||||
)
|
||||
mock_token_manager.store_offline_token = AsyncMock()
|
||||
|
||||
result = await keycloak_offline_callback(
|
||||
'test_code', 'test_state', mock_request
|
||||
)
|
||||
|
||||
assert isinstance(result, RedirectResponse)
|
||||
assert result.status_code == 302
|
||||
assert result.headers['location'] == 'http://localhost:8000/onboarding'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authenticate_success():
|
||||
"""Test successful authentication."""
|
||||
@@ -2185,20 +2042,12 @@ async def test_accept_tos_stores_timezone_naive_datetime(mock_request):
|
||||
|
||||
mock_request.json = AsyncMock(return_value={'redirect_url': 'http://example.com'})
|
||||
|
||||
# Mock user for onboarding check (user already completed onboarding)
|
||||
mock_user_for_onboarding = MagicMock()
|
||||
mock_user_for_onboarding.onboarding_completed = True
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.routes.auth.get_user_auth', AsyncMock(return_value=mock_user_auth)
|
||||
),
|
||||
patch('server.routes.auth.a_session_maker', return_value=mock_session_context),
|
||||
patch('server.routes.auth.set_response_cookie'),
|
||||
patch(
|
||||
'server.routes.auth._get_post_auth_redirect',
|
||||
AsyncMock(return_value='http://example.com'),
|
||||
),
|
||||
):
|
||||
# Act
|
||||
result = await accept_tos(mock_request)
|
||||
@@ -2209,65 +2058,3 @@ async def test_accept_tos_stores_timezone_naive_datetime(mock_request):
|
||||
# The datetime assigned to user.accepted_tos must be timezone-naive
|
||||
# (compatible with TIMESTAMP WITHOUT TIME ZONE database column)
|
||||
assert mock_user.accepted_tos.tzinfo is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_accept_tos_preserves_offline_flow_redirect(mock_request):
|
||||
"""Test that accept_tos does not override redirect_url when it's the offline token flow."""
|
||||
# Arrange
|
||||
test_user_id = '12345678-1234-5678-1234-567812345678'
|
||||
offline_redirect_url = 'https://auth.example.com/realms/test/protocol/openid-connect/auth?redirect_uri=https://example.com/oauth/keycloak/offline/callback'
|
||||
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = test_user_id
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalar_one_or_none.return_value = mock_user
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session.execute.return_value = mock_result
|
||||
mock_session.commit = AsyncMock()
|
||||
|
||||
mock_session_context = AsyncMock()
|
||||
mock_session_context.__aenter__.return_value = mock_session
|
||||
mock_session_context.__aexit__.return_value = None
|
||||
|
||||
mock_user_auth = MagicMock(spec=SaasUserAuth)
|
||||
mock_user_auth.get_access_token = AsyncMock(
|
||||
return_value=SecretStr('test_access_token')
|
||||
)
|
||||
mock_user_auth.refresh_token = SecretStr('test_refresh_token')
|
||||
mock_user_auth.get_user_id = AsyncMock(return_value=test_user_id)
|
||||
|
||||
mock_request.json = AsyncMock(return_value={'redirect_url': offline_redirect_url})
|
||||
|
||||
mock_get_post_auth_redirect = AsyncMock(
|
||||
return_value='http://example.com/onboarding'
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.routes.auth.get_user_auth', AsyncMock(return_value=mock_user_auth)
|
||||
),
|
||||
patch('server.routes.auth.a_session_maker', return_value=mock_session_context),
|
||||
patch('server.routes.auth.set_response_cookie'),
|
||||
patch(
|
||||
'server.routes.auth._get_post_auth_redirect',
|
||||
mock_get_post_auth_redirect,
|
||||
),
|
||||
):
|
||||
# Act
|
||||
result = await accept_tos(mock_request)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, JSONResponse)
|
||||
assert result.status_code == status.HTTP_200_OK
|
||||
|
||||
# _get_post_auth_redirect should NOT be called for offline flow
|
||||
mock_get_post_auth_redirect.assert_not_called()
|
||||
|
||||
# The redirect_url should be preserved (not overridden to onboarding)
|
||||
import json
|
||||
|
||||
response_body = json.loads(result.body.decode())
|
||||
assert response_body['redirect_url'] == offline_redirect_url
|
||||
|
||||
@@ -56,7 +56,6 @@ class TestPermission:
|
||||
assert Permission.VIEW_ORG_SETTINGS.value == 'view_org_settings'
|
||||
assert Permission.CHANGE_ORGANIZATION_NAME.value == 'change_organization_name'
|
||||
assert Permission.DELETE_ORGANIZATION.value == 'delete_organization'
|
||||
assert Permission.MANAGE_AUTOMATIONS.value == 'manage_automations'
|
||||
|
||||
def test_permission_from_string(self):
|
||||
"""
|
||||
@@ -143,7 +142,6 @@ class TestRolePermissions:
|
||||
assert Permission.CHANGE_USER_ROLE_OWNER in owner_perms
|
||||
assert Permission.CHANGE_ORGANIZATION_NAME in owner_perms
|
||||
assert Permission.DELETE_ORGANIZATION in owner_perms
|
||||
assert Permission.MANAGE_AUTOMATIONS in owner_perms
|
||||
|
||||
def test_admin_has_admin_permissions(self):
|
||||
"""
|
||||
@@ -161,7 +159,6 @@ class TestRolePermissions:
|
||||
assert Permission.INVITE_USER_TO_ORGANIZATION in admin_perms
|
||||
assert Permission.CHANGE_USER_ROLE_MEMBER in admin_perms
|
||||
assert Permission.CHANGE_USER_ROLE_ADMIN in admin_perms
|
||||
assert Permission.MANAGE_AUTOMATIONS in admin_perms
|
||||
# Admin should NOT have owner-only permissions
|
||||
assert Permission.CHANGE_USER_ROLE_OWNER not in admin_perms
|
||||
assert Permission.CHANGE_ORGANIZATION_NAME not in admin_perms
|
||||
@@ -180,7 +177,6 @@ class TestRolePermissions:
|
||||
assert Permission.MANAGE_INTEGRATIONS in member_perms
|
||||
assert Permission.MANAGE_APPLICATION_SETTINGS in member_perms
|
||||
assert Permission.MANAGE_API_KEYS in member_perms
|
||||
assert Permission.MANAGE_AUTOMATIONS in member_perms
|
||||
assert Permission.VIEW_LLM_SETTINGS in member_perms
|
||||
assert Permission.VIEW_ORG_SETTINGS in member_perms
|
||||
# Member should NOT have admin/owner permissions
|
||||
|
||||
@@ -1,284 +0,0 @@
|
||||
"""Tests for ConversationStateUpdateEvent filtering in shared_event_router.
|
||||
|
||||
The shared-events endpoints are unauthenticated, so internal system state
|
||||
(ConversationStateUpdateEvent) must not be returned. The frontend shared-
|
||||
conversation viewer never renders these events — it only uses messages,
|
||||
actions, observations, errors, and hook-execution events.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from server.sharing.shared_event_router import (
|
||||
_is_viewable,
|
||||
batch_get_shared_events,
|
||||
get_shared_event,
|
||||
search_shared_events,
|
||||
)
|
||||
|
||||
from openhands.agent_server.models import EventPage
|
||||
from openhands.sdk.event.conversation_state import ConversationStateUpdateEvent
|
||||
from openhands.sdk.event.llm_convertible import MessageEvent
|
||||
from openhands.sdk.llm import Message, TextContent
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures / helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_message_event() -> MessageEvent:
|
||||
return MessageEvent(
|
||||
source='user',
|
||||
llm_message=Message(role='user', content=[TextContent(text='Hello')]),
|
||||
)
|
||||
|
||||
|
||||
def _make_state_event(
|
||||
key: str = 'full_state', value: dict | str = 'idle'
|
||||
) -> ConversationStateUpdateEvent:
|
||||
return ConversationStateUpdateEvent(key=key, value=value)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _is_viewable
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestIsViewable:
|
||||
def test_message_event_is_viewable(self):
|
||||
assert _is_viewable(_make_message_event()) is True
|
||||
|
||||
def test_full_state_event_is_not_viewable(self):
|
||||
assert _is_viewable(_make_state_event('full_state', {'agent': {}})) is False
|
||||
|
||||
def test_execution_status_event_is_not_viewable(self):
|
||||
assert _is_viewable(_make_state_event('execution_status', 'running')) is False
|
||||
|
||||
def test_stats_event_is_not_viewable(self):
|
||||
assert _is_viewable(_make_state_event('stats', {})) is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# search_shared_events
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSearchSharedEvents:
|
||||
@pytest.mark.asyncio
|
||||
async def test_filters_out_state_events(self):
|
||||
msg = _make_message_event()
|
||||
state = _make_state_event()
|
||||
mock_service = AsyncMock()
|
||||
mock_service.search_shared_events.return_value = EventPage(
|
||||
items=[msg, state, msg], next_page_id=None
|
||||
)
|
||||
|
||||
result = await search_shared_events(
|
||||
conversation_id=uuid4().hex,
|
||||
shared_event_service=mock_service,
|
||||
)
|
||||
|
||||
assert len(result.items) == 2
|
||||
assert all(
|
||||
not isinstance(e, ConversationStateUpdateEvent) for e in result.items
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_page_unchanged(self):
|
||||
mock_service = AsyncMock()
|
||||
mock_service.search_shared_events.return_value = EventPage(
|
||||
items=[], next_page_id=None
|
||||
)
|
||||
|
||||
result = await search_shared_events(
|
||||
conversation_id=uuid4().hex,
|
||||
shared_event_service=mock_service,
|
||||
)
|
||||
|
||||
assert result.items == []
|
||||
assert result.next_page_id is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetches_additional_pages_when_filtering_reduces_count(self):
|
||||
"""Fetch next page when first page has only state events."""
|
||||
msg = _make_message_event()
|
||||
state = _make_state_event()
|
||||
mock_service = AsyncMock()
|
||||
mock_service.search_shared_events.side_effect = [
|
||||
# Page 1: only state events — all filtered out
|
||||
EventPage(items=[state, state, state], next_page_id='page2'),
|
||||
# Page 2: all viewable
|
||||
EventPage(items=[msg, msg, msg], next_page_id=None),
|
||||
]
|
||||
|
||||
result = await search_shared_events(
|
||||
conversation_id=uuid4().hex,
|
||||
limit=3,
|
||||
shared_event_service=mock_service,
|
||||
)
|
||||
|
||||
assert len(result.items) == 3
|
||||
assert result.next_page_id is None
|
||||
assert mock_service.search_shared_events.call_count == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_pages_until_limit_reached(self):
|
||||
"""Keep fetching mixed pages until limit viewable events accumulated."""
|
||||
msg = _make_message_event()
|
||||
state = _make_state_event()
|
||||
mock_service = AsyncMock()
|
||||
mock_service.search_shared_events.side_effect = [
|
||||
EventPage(items=[msg, state], next_page_id='p2'),
|
||||
EventPage(items=[state, msg], next_page_id='p3'),
|
||||
EventPage(items=[msg], next_page_id='p4'),
|
||||
]
|
||||
|
||||
result = await search_shared_events(
|
||||
conversation_id=uuid4().hex,
|
||||
limit=3,
|
||||
shared_event_service=mock_service,
|
||||
)
|
||||
|
||||
assert len(result.items) == 3
|
||||
assert result.next_page_id == 'p4'
|
||||
assert mock_service.search_shared_events.call_count == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stops_when_no_more_pages(self):
|
||||
"""Return partial results when no more backend pages are available."""
|
||||
msg = _make_message_event()
|
||||
state = _make_state_event()
|
||||
mock_service = AsyncMock()
|
||||
mock_service.search_shared_events.side_effect = [
|
||||
EventPage(items=[msg, state], next_page_id='p2'),
|
||||
EventPage(items=[state], next_page_id=None),
|
||||
]
|
||||
|
||||
result = await search_shared_events(
|
||||
conversation_id=uuid4().hex,
|
||||
limit=5,
|
||||
shared_event_service=mock_service,
|
||||
)
|
||||
|
||||
assert len(result.items) == 1
|
||||
assert result.next_page_id is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_passes_remaining_as_limit_to_backend(self):
|
||||
"""Pass remaining needed count as limit to each backend call."""
|
||||
msg = _make_message_event()
|
||||
state = _make_state_event()
|
||||
conv_id = uuid4().hex
|
||||
mock_service = AsyncMock()
|
||||
mock_service.search_shared_events.side_effect = [
|
||||
# First call: limit=3, returns 1 viewable
|
||||
EventPage(items=[msg, state, state], next_page_id='p2'),
|
||||
# Second call: limit should be 2 (remaining)
|
||||
EventPage(items=[msg, msg], next_page_id=None),
|
||||
]
|
||||
|
||||
await search_shared_events(
|
||||
conversation_id=conv_id,
|
||||
limit=3,
|
||||
shared_event_service=mock_service,
|
||||
)
|
||||
|
||||
calls = mock_service.search_shared_events.call_args_list
|
||||
assert calls[0].kwargs['limit'] == 3
|
||||
assert calls[1].kwargs['limit'] == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_preserves_next_page_id_when_all_filtered(self):
|
||||
"""Continue fetching when all events on a page are filtered out."""
|
||||
msg = _make_message_event()
|
||||
state = _make_state_event()
|
||||
mock_service = AsyncMock()
|
||||
mock_service.search_shared_events.side_effect = [
|
||||
EventPage(items=[state], next_page_id='p2'),
|
||||
EventPage(items=[msg], next_page_id='p3'),
|
||||
]
|
||||
|
||||
result = await search_shared_events(
|
||||
conversation_id=uuid4().hex,
|
||||
limit=1,
|
||||
shared_event_service=mock_service,
|
||||
)
|
||||
|
||||
assert len(result.items) == 1
|
||||
assert result.next_page_id == 'p3'
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# batch_get_shared_events
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBatchGetSharedEvents:
|
||||
@pytest.mark.asyncio
|
||||
async def test_replaces_state_events_with_none(self):
|
||||
msg = _make_message_event()
|
||||
state = _make_state_event()
|
||||
mock_service = AsyncMock()
|
||||
mock_service.batch_get_shared_events.return_value = [msg, state, None]
|
||||
|
||||
result = await batch_get_shared_events(
|
||||
conversation_id=uuid4().hex,
|
||||
id=[uuid4().hex, uuid4().hex, uuid4().hex],
|
||||
shared_event_service=mock_service,
|
||||
)
|
||||
|
||||
assert len(result) == 3
|
||||
assert result[0] is msg
|
||||
assert result[1] is None # state event replaced with None
|
||||
assert result[2] is None # originally None stays None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_shared_event
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetSharedEvent:
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_message_event(self):
|
||||
msg = _make_message_event()
|
||||
mock_service = AsyncMock()
|
||||
mock_service.get_shared_event.return_value = msg
|
||||
|
||||
result = await get_shared_event(
|
||||
conversation_id=uuid4().hex,
|
||||
event_id=uuid4().hex,
|
||||
shared_event_service=mock_service,
|
||||
)
|
||||
|
||||
assert result is msg
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_none_for_state_event(self):
|
||||
state = _make_state_event()
|
||||
mock_service = AsyncMock()
|
||||
mock_service.get_shared_event.return_value = state
|
||||
|
||||
result = await get_shared_event(
|
||||
conversation_id=uuid4().hex,
|
||||
event_id=uuid4().hex,
|
||||
shared_event_service=mock_service,
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_none_when_not_found(self):
|
||||
mock_service = AsyncMock()
|
||||
mock_service.get_shared_event.return_value = None
|
||||
|
||||
result = await get_shared_event(
|
||||
conversation_id=uuid4().hex,
|
||||
event_id=uuid4().hex,
|
||||
shared_event_service=mock_service,
|
||||
)
|
||||
|
||||
assert result is None
|
||||
@@ -1325,354 +1325,3 @@ async def test_migrate_user_sql_multiple_conversations(async_session_maker):
|
||||
# statements that have SQLite/UUID compatibility issues in the test environment.
|
||||
# The SQL migration tests above (test_migrate_user_sql_type_handling, etc.) verify
|
||||
# the SQL operations work correctly with proper type handling.
|
||||
|
||||
|
||||
# --- Tests for mark_onboarding_completed ---
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mark_onboarding_completed_success(async_session_maker):
|
||||
"""Test successfully marking onboarding as completed."""
|
||||
user_id = uuid.uuid4()
|
||||
org_id = uuid.uuid4()
|
||||
|
||||
# Create test data
|
||||
async with async_session_maker() as session:
|
||||
org = Org(id=org_id, name='test-org')
|
||||
session.add(org)
|
||||
user = User(id=user_id, current_org_id=org_id, onboarding_completed=False)
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
|
||||
# Test marking onboarding complete
|
||||
with patch('storage.user_store.a_session_maker', async_session_maker):
|
||||
result = await UserStore.mark_onboarding_completed(str(user_id))
|
||||
|
||||
assert result is not None
|
||||
assert result.id == user_id
|
||||
assert result.onboarding_completed is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mark_onboarding_completed_user_not_found(async_session_maker):
|
||||
"""Test that mark_onboarding_completed returns None for non-existent user."""
|
||||
non_existent_id = str(uuid.uuid4())
|
||||
|
||||
with patch('storage.user_store.a_session_maker', async_session_maker):
|
||||
result = await UserStore.mark_onboarding_completed(non_existent_id)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mark_onboarding_completed_already_completed(async_session_maker):
|
||||
"""Test marking onboarding complete for user who already completed it."""
|
||||
user_id = uuid.uuid4()
|
||||
org_id = uuid.uuid4()
|
||||
|
||||
# Create user with onboarding already completed
|
||||
async with async_session_maker() as session:
|
||||
org = Org(id=org_id, name='test-org')
|
||||
session.add(org)
|
||||
user = User(id=user_id, current_org_id=org_id, onboarding_completed=True)
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
|
||||
# Should still succeed and return user
|
||||
with patch('storage.user_store.a_session_maker', async_session_maker):
|
||||
result = await UserStore.mark_onboarding_completed(str(user_id))
|
||||
|
||||
assert result is not None
|
||||
assert result.id == user_id
|
||||
assert result.onboarding_completed is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mark_onboarding_completed_user_with_null_onboarding(async_session_maker):
|
||||
"""Test marking onboarding complete for user with null onboarding_completed value."""
|
||||
user_id = uuid.uuid4()
|
||||
org_id = uuid.uuid4()
|
||||
|
||||
# Create user with null onboarding_completed (default)
|
||||
async with async_session_maker() as session:
|
||||
org = Org(id=org_id, name='test-org')
|
||||
session.add(org)
|
||||
user = User(
|
||||
id=user_id, current_org_id=org_id
|
||||
) # onboarding_completed defaults to None
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
|
||||
with patch('storage.user_store.a_session_maker', async_session_maker):
|
||||
result = await UserStore.mark_onboarding_completed(str(user_id))
|
||||
|
||||
assert result is not None
|
||||
assert result.id == user_id
|
||||
assert result.onboarding_completed is True
|
||||
|
||||
|
||||
# --- Tests for get_first_owner_in_org ---
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_first_owner_in_org_returns_first_owner(async_session_maker):
|
||||
"""Test that get_first_owner_in_org returns the owner with earliest accepted_tos."""
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from storage.org_member import OrgMember
|
||||
from storage.role import Role
|
||||
|
||||
org_id = uuid.uuid4()
|
||||
first_owner_id = uuid.uuid4()
|
||||
second_owner_id = uuid.uuid4()
|
||||
|
||||
async with async_session_maker() as session:
|
||||
# Create org
|
||||
org = Org(id=org_id, name='test-org')
|
||||
session.add(org)
|
||||
|
||||
# Create owner role
|
||||
owner_role = Role(id=1, name='owner', rank=10)
|
||||
session.add(owner_role)
|
||||
|
||||
# Create first owner (earlier TOS acceptance)
|
||||
first_owner = User(
|
||||
id=first_owner_id,
|
||||
current_org_id=org_id,
|
||||
accepted_tos=datetime.now() - timedelta(days=10),
|
||||
)
|
||||
session.add(first_owner)
|
||||
|
||||
# Create second owner (later TOS acceptance)
|
||||
second_owner = User(
|
||||
id=second_owner_id,
|
||||
current_org_id=org_id,
|
||||
accepted_tos=datetime.now() - timedelta(days=5),
|
||||
)
|
||||
session.add(second_owner)
|
||||
|
||||
await session.flush()
|
||||
|
||||
# Add both as org members with owner role
|
||||
first_member = OrgMember(
|
||||
org_id=org_id,
|
||||
user_id=first_owner_id,
|
||||
role_id=owner_role.id,
|
||||
llm_api_key='test-key-1',
|
||||
)
|
||||
session.add(first_member)
|
||||
|
||||
second_member = OrgMember(
|
||||
org_id=org_id,
|
||||
user_id=second_owner_id,
|
||||
role_id=owner_role.id,
|
||||
llm_api_key='test-key-2',
|
||||
)
|
||||
session.add(second_member)
|
||||
|
||||
await session.commit()
|
||||
|
||||
with patch('storage.user_store.a_session_maker', async_session_maker):
|
||||
result = await UserStore.get_first_owner_in_org(org_id)
|
||||
|
||||
assert result is not None
|
||||
assert result.id == first_owner_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_first_owner_in_org_ignores_non_owners(async_session_maker):
|
||||
"""Test that get_first_owner_in_org ignores users with non-owner roles."""
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from storage.org_member import OrgMember
|
||||
from storage.role import Role
|
||||
|
||||
org_id = uuid.uuid4()
|
||||
admin_id = uuid.uuid4()
|
||||
owner_id = uuid.uuid4()
|
||||
|
||||
async with async_session_maker() as session:
|
||||
# Create org
|
||||
org = Org(id=org_id, name='test-org')
|
||||
session.add(org)
|
||||
|
||||
# Create roles
|
||||
owner_role = Role(id=1, name='owner', rank=10)
|
||||
admin_role = Role(id=2, name='admin', rank=20)
|
||||
session.add(owner_role)
|
||||
session.add(admin_role)
|
||||
|
||||
# Create admin with earlier TOS acceptance
|
||||
admin_user = User(
|
||||
id=admin_id,
|
||||
current_org_id=org_id,
|
||||
accepted_tos=datetime.now() - timedelta(days=10),
|
||||
)
|
||||
session.add(admin_user)
|
||||
|
||||
# Create owner with later TOS acceptance
|
||||
owner_user = User(
|
||||
id=owner_id,
|
||||
current_org_id=org_id,
|
||||
accepted_tos=datetime.now() - timedelta(days=5),
|
||||
)
|
||||
session.add(owner_user)
|
||||
|
||||
await session.flush()
|
||||
|
||||
# Add admin member
|
||||
admin_member = OrgMember(
|
||||
org_id=org_id,
|
||||
user_id=admin_id,
|
||||
role_id=admin_role.id,
|
||||
llm_api_key='test-key-admin',
|
||||
)
|
||||
session.add(admin_member)
|
||||
|
||||
# Add owner member
|
||||
owner_member = OrgMember(
|
||||
org_id=org_id,
|
||||
user_id=owner_id,
|
||||
role_id=owner_role.id,
|
||||
llm_api_key='test-key-owner',
|
||||
)
|
||||
session.add(owner_member)
|
||||
|
||||
await session.commit()
|
||||
|
||||
with patch('storage.user_store.a_session_maker', async_session_maker):
|
||||
result = await UserStore.get_first_owner_in_org(org_id)
|
||||
|
||||
# Should return the owner, not the admin (even though admin has earlier TOS)
|
||||
assert result is not None
|
||||
assert result.id == owner_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_first_owner_in_org_returns_none_when_no_owners(async_session_maker):
|
||||
"""Test that get_first_owner_in_org returns None when org has no owners."""
|
||||
from datetime import datetime
|
||||
|
||||
from storage.org_member import OrgMember
|
||||
from storage.role import Role
|
||||
|
||||
org_id = uuid.uuid4()
|
||||
member_id = uuid.uuid4()
|
||||
|
||||
async with async_session_maker() as session:
|
||||
# Create org
|
||||
org = Org(id=org_id, name='test-org')
|
||||
session.add(org)
|
||||
|
||||
# Create member role only
|
||||
member_role = Role(id=3, name='member', rank=100)
|
||||
session.add(member_role)
|
||||
|
||||
# Create user with member role
|
||||
member_user = User(
|
||||
id=member_id,
|
||||
current_org_id=org_id,
|
||||
accepted_tos=datetime.now(),
|
||||
)
|
||||
session.add(member_user)
|
||||
|
||||
await session.flush()
|
||||
|
||||
# Add as member
|
||||
member = OrgMember(
|
||||
org_id=org_id,
|
||||
user_id=member_id,
|
||||
role_id=member_role.id,
|
||||
llm_api_key='test-key',
|
||||
)
|
||||
session.add(member)
|
||||
|
||||
await session.commit()
|
||||
|
||||
with patch('storage.user_store.a_session_maker', async_session_maker):
|
||||
result = await UserStore.get_first_owner_in_org(org_id)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_first_owner_in_org_ignores_owners_without_tos(async_session_maker):
|
||||
"""Test that get_first_owner_in_org ignores owners who haven't accepted TOS."""
|
||||
from datetime import datetime
|
||||
|
||||
from storage.org_member import OrgMember
|
||||
from storage.role import Role
|
||||
|
||||
org_id = uuid.uuid4()
|
||||
owner_no_tos_id = uuid.uuid4()
|
||||
owner_with_tos_id = uuid.uuid4()
|
||||
|
||||
async with async_session_maker() as session:
|
||||
# Create org
|
||||
org = Org(id=org_id, name='test-org')
|
||||
session.add(org)
|
||||
|
||||
# Create owner role
|
||||
owner_role = Role(id=1, name='owner', rank=10)
|
||||
session.add(owner_role)
|
||||
|
||||
# Create owner without TOS
|
||||
owner_no_tos = User(
|
||||
id=owner_no_tos_id,
|
||||
current_org_id=org_id,
|
||||
accepted_tos=None,
|
||||
)
|
||||
session.add(owner_no_tos)
|
||||
|
||||
# Create owner with TOS
|
||||
owner_with_tos = User(
|
||||
id=owner_with_tos_id,
|
||||
current_org_id=org_id,
|
||||
accepted_tos=datetime.now(),
|
||||
)
|
||||
session.add(owner_with_tos)
|
||||
|
||||
await session.flush()
|
||||
|
||||
# Add both as owners
|
||||
member_no_tos = OrgMember(
|
||||
org_id=org_id,
|
||||
user_id=owner_no_tos_id,
|
||||
role_id=owner_role.id,
|
||||
llm_api_key='test-key-1',
|
||||
)
|
||||
session.add(member_no_tos)
|
||||
|
||||
member_with_tos = OrgMember(
|
||||
org_id=org_id,
|
||||
user_id=owner_with_tos_id,
|
||||
role_id=owner_role.id,
|
||||
llm_api_key='test-key-2',
|
||||
)
|
||||
session.add(member_with_tos)
|
||||
|
||||
await session.commit()
|
||||
|
||||
with patch('storage.user_store.a_session_maker', async_session_maker):
|
||||
result = await UserStore.get_first_owner_in_org(org_id)
|
||||
|
||||
# Should return the owner who has accepted TOS
|
||||
assert result is not None
|
||||
assert result.id == owner_with_tos_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_first_owner_in_org_returns_none_for_empty_org(async_session_maker):
|
||||
"""Test that get_first_owner_in_org returns None for org with no members."""
|
||||
org_id = uuid.uuid4()
|
||||
|
||||
async with async_session_maker() as session:
|
||||
# Create org only, no members
|
||||
org = Org(id=org_id, name='empty-org')
|
||||
session.add(org)
|
||||
await session.commit()
|
||||
|
||||
with patch('storage.user_store.a_session_maker', async_session_maker):
|
||||
result = await UserStore.get_first_owner_in_org(org_id)
|
||||
|
||||
assert result is None
|
||||
|
||||
@@ -49,34 +49,17 @@ vi.mock("#/utils/custom-toast-handlers", () => ({
|
||||
displayErrorToast: vi.fn(),
|
||||
}));
|
||||
|
||||
const mockUseAppMode = vi.fn(() => ({
|
||||
isOss: false,
|
||||
isSaas: true,
|
||||
isCloud: true,
|
||||
isSelfHosted: false,
|
||||
isEnterpriseSelfHosted: false,
|
||||
isEnterpriseCloud: true,
|
||||
appMode: "saas" as string | undefined,
|
||||
deploymentMode: "cloud" as string | undefined,
|
||||
}));
|
||||
vi.mock("#/hooks/use-app-mode", () => ({
|
||||
useAppMode: () => mockUseAppMode(),
|
||||
// Mock feature flags - we'll control the return value in each test
|
||||
const mockEnableProjUserJourney = vi.fn(() => true);
|
||||
vi.mock("#/utils/feature-flags", () => ({
|
||||
ENABLE_PROJ_USER_JOURNEY: () => mockEnableProjUserJourney(),
|
||||
}));
|
||||
|
||||
describe("LoginContent", () => {
|
||||
beforeEach(() => {
|
||||
vi.stubGlobal("location", { href: "" });
|
||||
// Reset mock to return SaaS Cloud (CTA enabled) by default
|
||||
mockUseAppMode.mockReturnValue({
|
||||
isOss: false,
|
||||
isSaas: true,
|
||||
isCloud: true,
|
||||
isSelfHosted: false,
|
||||
isEnterpriseSelfHosted: false,
|
||||
isEnterpriseCloud: true,
|
||||
appMode: "saas",
|
||||
deploymentMode: "cloud",
|
||||
});
|
||||
// Reset mock to return true by default
|
||||
mockEnableProjUserJourney.mockReturnValue(true);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
@@ -299,18 +282,7 @@ describe("LoginContent", () => {
|
||||
expect(screen.getByTestId("terms-and-privacy-notice")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should display the enterprise LoginCTA component when in SaaS Cloud mode", () => {
|
||||
mockUseAppMode.mockReturnValue({
|
||||
isOss: false,
|
||||
isSaas: true,
|
||||
isCloud: true,
|
||||
isSelfHosted: false,
|
||||
isEnterpriseSelfHosted: false,
|
||||
isEnterpriseCloud: true,
|
||||
appMode: "saas",
|
||||
deploymentMode: "cloud",
|
||||
});
|
||||
|
||||
it("should display the enterprise LoginCTA component when appMode is saas and feature flag enabled", () => {
|
||||
render(
|
||||
<MemoryRouter>
|
||||
<LoginContent
|
||||
@@ -324,18 +296,7 @@ describe("LoginContent", () => {
|
||||
expect(screen.getByTestId("login-cta")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should not display the enterprise LoginCTA component when in OSS mode", () => {
|
||||
mockUseAppMode.mockReturnValue({
|
||||
isOss: true,
|
||||
isSaas: false,
|
||||
isCloud: false,
|
||||
isSelfHosted: false,
|
||||
isEnterpriseSelfHosted: false,
|
||||
isEnterpriseCloud: false,
|
||||
appMode: "oss",
|
||||
deploymentMode: undefined,
|
||||
});
|
||||
|
||||
it("should not display the enterprise LoginCTA component when appMode is oss even with feature flag enabled", () => {
|
||||
render(
|
||||
<MemoryRouter>
|
||||
<LoginContent
|
||||
@@ -349,17 +310,23 @@ describe("LoginContent", () => {
|
||||
expect(screen.queryByTestId("login-cta")).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should not display the enterprise LoginCTA component when in SaaS Self-hosted mode", () => {
|
||||
mockUseAppMode.mockReturnValue({
|
||||
isOss: false,
|
||||
isSaas: true,
|
||||
isCloud: false,
|
||||
isSelfHosted: true,
|
||||
isEnterpriseSelfHosted: true,
|
||||
isEnterpriseCloud: false,
|
||||
appMode: "saas",
|
||||
deploymentMode: "self_hosted",
|
||||
});
|
||||
it("should not display the enterprise LoginCTA component when appMode is null", () => {
|
||||
render(
|
||||
<MemoryRouter>
|
||||
<LoginContent
|
||||
githubAuthUrl="https://github.com/oauth/authorize"
|
||||
appMode={null}
|
||||
providersConfigured={["github"]}
|
||||
/>
|
||||
</MemoryRouter>,
|
||||
);
|
||||
|
||||
expect(screen.queryByTestId("login-cta")).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should not display the enterprise LoginCTA component when feature flag is disabled", () => {
|
||||
// Disable the feature flag
|
||||
mockEnableProjUserJourney.mockReturnValue(false);
|
||||
|
||||
render(
|
||||
<MemoryRouter>
|
||||
|
||||
@@ -0,0 +1,118 @@
|
||||
import { screen } from "@testing-library/react";
|
||||
import { beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import userEvent from "@testing-library/user-event";
|
||||
import { renderWithProviders } from "test-utils";
|
||||
import { EnterpriseBanner } from "#/components/features/device-verify/enterprise-banner";
|
||||
|
||||
const mockCapture = vi.fn();
|
||||
vi.mock("posthog-js/react", () => ({
|
||||
usePostHog: () => ({
|
||||
capture: mockCapture,
|
||||
}),
|
||||
}));
|
||||
|
||||
const { ENABLE_PROJ_USER_JOURNEY_MOCK } = vi.hoisted(() => ({
|
||||
ENABLE_PROJ_USER_JOURNEY_MOCK: vi.fn(() => true),
|
||||
}));
|
||||
|
||||
vi.mock("#/utils/feature-flags", () => ({
|
||||
ENABLE_PROJ_USER_JOURNEY: () => ENABLE_PROJ_USER_JOURNEY_MOCK(),
|
||||
}));
|
||||
|
||||
describe("EnterpriseBanner", () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
ENABLE_PROJ_USER_JOURNEY_MOCK.mockReturnValue(true);
|
||||
});
|
||||
|
||||
describe("Feature Flag", () => {
|
||||
it("should not render when proj_user_journey feature flag is disabled", () => {
|
||||
ENABLE_PROJ_USER_JOURNEY_MOCK.mockReturnValue(false);
|
||||
|
||||
const { container } = renderWithProviders(<EnterpriseBanner />);
|
||||
|
||||
expect(container.firstChild).toBeNull();
|
||||
expect(screen.queryByText("ENTERPRISE$TITLE")).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should render when proj_user_journey feature flag is enabled", () => {
|
||||
ENABLE_PROJ_USER_JOURNEY_MOCK.mockReturnValue(true);
|
||||
|
||||
renderWithProviders(<EnterpriseBanner />);
|
||||
|
||||
expect(screen.getByText("ENTERPRISE$TITLE")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
describe("Rendering", () => {
|
||||
it("should render the self-hosted label", () => {
|
||||
renderWithProviders(<EnterpriseBanner />);
|
||||
|
||||
expect(screen.getByText("ENTERPRISE$SELF_HOSTED")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should render the enterprise title", () => {
|
||||
renderWithProviders(<EnterpriseBanner />);
|
||||
|
||||
expect(screen.getByText("ENTERPRISE$TITLE")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should render the enterprise description", () => {
|
||||
renderWithProviders(<EnterpriseBanner />);
|
||||
|
||||
expect(screen.getByText("ENTERPRISE$DESCRIPTION")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should render all four enterprise feature items", () => {
|
||||
renderWithProviders(<EnterpriseBanner />);
|
||||
|
||||
expect(
|
||||
screen.getByText("ENTERPRISE$FEATURE_DATA_PRIVACY"),
|
||||
).toBeInTheDocument();
|
||||
expect(
|
||||
screen.getByText("ENTERPRISE$FEATURE_DEPLOYMENT"),
|
||||
).toBeInTheDocument();
|
||||
expect(screen.getByText("ENTERPRISE$FEATURE_SSO")).toBeInTheDocument();
|
||||
expect(
|
||||
screen.getByText("ENTERPRISE$FEATURE_SUPPORT"),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should render the learn more link", () => {
|
||||
renderWithProviders(<EnterpriseBanner />);
|
||||
|
||||
const link = screen.getByRole("link", {
|
||||
name: "ENTERPRISE$LEARN_MORE_ARIA",
|
||||
});
|
||||
expect(link).toBeInTheDocument();
|
||||
expect(link).toHaveTextContent("ENTERPRISE$LEARN_MORE");
|
||||
expect(link).toHaveAttribute("href", "https://openhands.dev/enterprise");
|
||||
expect(link).toHaveAttribute("target", "_blank");
|
||||
expect(link).toHaveAttribute("rel", "noopener noreferrer");
|
||||
});
|
||||
});
|
||||
|
||||
describe("Learn More Link Interaction", () => {
|
||||
it("should capture PostHog event when learn more link is clicked", async () => {
|
||||
const user = userEvent.setup();
|
||||
renderWithProviders(<EnterpriseBanner />);
|
||||
|
||||
const link = screen.getByRole("link", {
|
||||
name: "ENTERPRISE$LEARN_MORE_ARIA",
|
||||
});
|
||||
await user.click(link);
|
||||
|
||||
expect(mockCapture).toHaveBeenCalledWith("saas_selfhosted_inquiry");
|
||||
});
|
||||
|
||||
it("should have correct href attribute for opening in new tab", () => {
|
||||
renderWithProviders(<EnterpriseBanner />);
|
||||
|
||||
const link = screen.getByRole("link", {
|
||||
name: "ENTERPRISE$LEARN_MORE_ARIA",
|
||||
});
|
||||
expect(link).toHaveAttribute("href", "https://openhands.dev/enterprise");
|
||||
expect(link).toHaveAttribute("target", "_blank");
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -90,15 +90,14 @@ describe("RepoConnector", () => {
|
||||
"retrieveUserGitRepositories",
|
||||
);
|
||||
retrieveUserGitRepositoriesSpy.mockResolvedValue({
|
||||
items: MOCK_RESPOSITORIES,
|
||||
next_page_id: null,
|
||||
data: MOCK_RESPOSITORIES,
|
||||
nextPage: null,
|
||||
});
|
||||
|
||||
// Mock the search function that's used by the dropdown
|
||||
vi.spyOn(GitService, "searchGitRepositories").mockResolvedValue({
|
||||
items: MOCK_RESPOSITORIES,
|
||||
next_page_id: null,
|
||||
});
|
||||
vi.spyOn(GitService, "searchGitRepositories").mockResolvedValue(
|
||||
MOCK_RESPOSITORIES,
|
||||
);
|
||||
|
||||
renderRepoConnector();
|
||||
|
||||
@@ -128,8 +127,8 @@ describe("RepoConnector", () => {
|
||||
"retrieveUserGitRepositories",
|
||||
);
|
||||
retrieveUserGitRepositoriesSpy.mockResolvedValue({
|
||||
items: MOCK_RESPOSITORIES,
|
||||
next_page_id: null,
|
||||
data: MOCK_RESPOSITORIES,
|
||||
nextPage: null,
|
||||
});
|
||||
|
||||
renderRepoConnector();
|
||||
@@ -139,11 +138,14 @@ describe("RepoConnector", () => {
|
||||
|
||||
// Mock the repository branches API call
|
||||
vi.spyOn(GitService, "getRepositoryBranches").mockResolvedValue({
|
||||
items: [
|
||||
branches: [
|
||||
{ name: "main", commit_sha: "123", protected: false },
|
||||
{ name: "develop", commit_sha: "456", protected: false },
|
||||
],
|
||||
next_page_id: null,
|
||||
has_next_page: false,
|
||||
current_page: 1,
|
||||
per_page: 30,
|
||||
total_count: 2,
|
||||
});
|
||||
|
||||
// First select the provider
|
||||
@@ -197,8 +199,8 @@ describe("RepoConnector", () => {
|
||||
"retrieveUserGitRepositories",
|
||||
);
|
||||
retrieveUserGitRepositoriesSpy.mockResolvedValue({
|
||||
items: MOCK_RESPOSITORIES,
|
||||
next_page_id: null,
|
||||
data: MOCK_RESPOSITORIES,
|
||||
nextPage: null,
|
||||
});
|
||||
|
||||
renderRepoConnector();
|
||||
@@ -244,8 +246,8 @@ describe("RepoConnector", () => {
|
||||
"retrieveUserGitRepositories",
|
||||
);
|
||||
retrieveUserGitRepositoriesSpy.mockResolvedValue({
|
||||
items: MOCK_RESPOSITORIES,
|
||||
next_page_id: null,
|
||||
data: MOCK_RESPOSITORIES,
|
||||
nextPage: null,
|
||||
});
|
||||
|
||||
renderRepoConnector();
|
||||
@@ -288,8 +290,8 @@ describe("RepoConnector", () => {
|
||||
"retrieveUserGitRepositories",
|
||||
);
|
||||
retrieveUserGitRepositoriesSpy.mockResolvedValue({
|
||||
items: MOCK_RESPOSITORIES,
|
||||
next_page_id: null,
|
||||
data: MOCK_RESPOSITORIES,
|
||||
nextPage: null,
|
||||
});
|
||||
|
||||
renderRepoConnector();
|
||||
@@ -345,8 +347,8 @@ describe("RepoConnector", () => {
|
||||
"retrieveUserGitRepositories",
|
||||
);
|
||||
retrieveUserGitRepositoriesSpy.mockResolvedValue({
|
||||
items: MOCK_RESPOSITORIES,
|
||||
next_page_id: null,
|
||||
data: MOCK_RESPOSITORIES,
|
||||
nextPage: null,
|
||||
});
|
||||
|
||||
renderRepoConnector();
|
||||
@@ -361,11 +363,14 @@ describe("RepoConnector", () => {
|
||||
|
||||
// Mock the repository branches API call
|
||||
vi.spyOn(GitService, "getRepositoryBranches").mockResolvedValue({
|
||||
items: [
|
||||
branches: [
|
||||
{ name: "main", commit_sha: "123", protected: false },
|
||||
{ name: "develop", commit_sha: "456", protected: false },
|
||||
],
|
||||
next_page_id: null,
|
||||
has_next_page: false,
|
||||
current_page: 1,
|
||||
per_page: 30,
|
||||
total_count: 2,
|
||||
});
|
||||
|
||||
// First select the provider
|
||||
@@ -422,17 +427,20 @@ describe("RepoConnector", () => {
|
||||
"retrieveUserGitRepositories",
|
||||
);
|
||||
retrieveUserGitRepositoriesSpy.mockResolvedValue({
|
||||
items: MOCK_RESPOSITORIES,
|
||||
next_page_id: null,
|
||||
data: MOCK_RESPOSITORIES,
|
||||
nextPage: null,
|
||||
});
|
||||
|
||||
// Mock the repository branches API call
|
||||
vi.spyOn(GitService, "getRepositoryBranches").mockResolvedValue({
|
||||
items: [
|
||||
branches: [
|
||||
{ name: "main", commit_sha: "123", protected: false },
|
||||
{ name: "develop", commit_sha: "456", protected: false },
|
||||
],
|
||||
next_page_id: null,
|
||||
has_next_page: false,
|
||||
current_page: 1,
|
||||
per_page: 30,
|
||||
total_count: 2,
|
||||
});
|
||||
|
||||
renderRepoConnector();
|
||||
|
||||
@@ -187,7 +187,7 @@ describe("RepositorySelectionForm", () => {
|
||||
},
|
||||
];
|
||||
mockUseGitRepositories.mockReturnValue({
|
||||
data: { pages: [{ items: MOCK_REPOS }] },
|
||||
data: { pages: [{ data: MOCK_REPOS }] },
|
||||
isLoading: false,
|
||||
isError: false,
|
||||
hasNextPage: false,
|
||||
@@ -229,7 +229,7 @@ describe("RepositorySelectionForm", () => {
|
||||
|
||||
// Create a spy on the API call
|
||||
const searchGitReposSpy = vi.spyOn(GitService, "searchGitRepositories");
|
||||
searchGitReposSpy.mockResolvedValue({ items: MOCK_SEARCH_REPOS, next_page_id: null });
|
||||
searchGitReposSpy.mockResolvedValue(MOCK_SEARCH_REPOS);
|
||||
|
||||
mockUseGitRepositories.mockReturnValue({
|
||||
data: { pages: [] },
|
||||
@@ -267,7 +267,7 @@ describe("RepositorySelectionForm", () => {
|
||||
];
|
||||
|
||||
mockUseGitRepositories.mockReturnValue({
|
||||
data: { pages: [{ items: MOCK_SEARCH_REPOS }] },
|
||||
data: { pages: [{ data: MOCK_SEARCH_REPOS }] },
|
||||
isLoading: false,
|
||||
isError: false,
|
||||
hasNextPage: false,
|
||||
|
||||
@@ -115,8 +115,8 @@ describe("TaskCard", () => {
|
||||
"retrieveUserGitRepositories",
|
||||
);
|
||||
retrieveUserGitRepositoriesSpy.mockResolvedValue({
|
||||
items: MOCK_RESPOSITORIES,
|
||||
next_page_id: null,
|
||||
data: MOCK_RESPOSITORIES,
|
||||
nextPage: null,
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
@@ -1,20 +1,15 @@
|
||||
import { render, screen } from "@testing-library/react";
|
||||
import { screen } from "@testing-library/react";
|
||||
import userEvent from "@testing-library/user-event";
|
||||
import { createRoutesStub } from "react-router";
|
||||
import { MemoryRouter } from "react-router";
|
||||
import { beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
||||
import { I18nextProvider } from "react-i18next";
|
||||
import i18n from "i18next";
|
||||
import { renderWithProviders } from "../../../../test-utils";
|
||||
import OnboardingForm from "#/routes/onboarding-form";
|
||||
|
||||
const mockMutate = vi.fn();
|
||||
const mockNavigate = vi.fn();
|
||||
const mockUseMe = vi.fn();
|
||||
const mockUseConfig = vi.fn();
|
||||
const mockTrackOnboardingCompleted = vi.fn();
|
||||
|
||||
// Loader data set in beforeEach for each test suite
|
||||
let loaderData: { config: { app_mode: string; feature_flags: { deployment_mode: string } } };
|
||||
|
||||
vi.mock("react-router", async (importOriginal) => {
|
||||
const original = await importOriginal<typeof import("react-router")>();
|
||||
return {
|
||||
@@ -29,8 +24,8 @@ vi.mock("#/hooks/mutation/use-submit-onboarding", () => ({
|
||||
}),
|
||||
}));
|
||||
|
||||
vi.mock("#/hooks/query/use-me", () => ({
|
||||
useMe: () => mockUseMe(),
|
||||
vi.mock("#/hooks/query/use-config", () => ({
|
||||
useConfig: () => mockUseConfig(),
|
||||
}));
|
||||
|
||||
vi.mock("#/hooks/use-tracking", () => ({
|
||||
@@ -39,71 +34,49 @@ vi.mock("#/hooks/use-tracking", () => ({
|
||||
}),
|
||||
}));
|
||||
|
||||
const renderOnboardingForm = async () => {
|
||||
const queryClient = new QueryClient({
|
||||
defaultOptions: { queries: { retry: false } },
|
||||
});
|
||||
|
||||
const RouterStub = createRoutesStub([
|
||||
{
|
||||
path: "/",
|
||||
Component: OnboardingForm,
|
||||
loader: () => loaderData,
|
||||
},
|
||||
]);
|
||||
|
||||
const result = render(
|
||||
<I18nextProvider i18n={i18n}>
|
||||
<QueryClientProvider client={queryClient}>
|
||||
<RouterStub initialEntries={["/"]} />
|
||||
</QueryClientProvider>
|
||||
</I18nextProvider>,
|
||||
const renderOnboardingForm = () => {
|
||||
return renderWithProviders(
|
||||
<MemoryRouter>
|
||||
<OnboardingForm />
|
||||
</MemoryRouter>,
|
||||
);
|
||||
|
||||
// Wait for the component to render
|
||||
await screen.findByTestId("onboarding-form");
|
||||
return result;
|
||||
};
|
||||
|
||||
describe("OnboardingForm - Cloud Mode", () => {
|
||||
describe("OnboardingForm - SaaS Mode", () => {
|
||||
beforeEach(() => {
|
||||
mockMutate.mockClear();
|
||||
mockNavigate.mockClear();
|
||||
mockTrackOnboardingCompleted.mockClear();
|
||||
loaderData = {
|
||||
config: {
|
||||
app_mode: "saas",
|
||||
feature_flags: { deployment_mode: "cloud" },
|
||||
},
|
||||
};
|
||||
// Cloud mode tracks all users, role doesn't matter
|
||||
mockUseMe.mockReturnValue({ data: { role: "member" } });
|
||||
mockUseConfig.mockReturnValue({
|
||||
data: { app_mode: "saas" },
|
||||
isLoading: false,
|
||||
});
|
||||
});
|
||||
|
||||
it("should render with the correct test id", async () => {
|
||||
await renderOnboardingForm();
|
||||
it("should render with the correct test id", () => {
|
||||
renderOnboardingForm();
|
||||
|
||||
expect(screen.getByTestId("onboarding-form")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should render the first step initially", async () => {
|
||||
await renderOnboardingForm();
|
||||
it("should render the first step initially", () => {
|
||||
renderOnboardingForm();
|
||||
|
||||
expect(screen.getByTestId("step-header")).toBeInTheDocument();
|
||||
expect(screen.getByTestId("step-content")).toBeInTheDocument();
|
||||
expect(screen.getByTestId("step-actions")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should display step progress indicator with 3 bars for cloud mode", async () => {
|
||||
await renderOnboardingForm();
|
||||
it("should display step progress indicator with 3 bars for saas mode", () => {
|
||||
renderOnboardingForm();
|
||||
|
||||
const stepHeader = screen.getByTestId("step-header");
|
||||
const progressBars = stepHeader.querySelectorAll(".rounded-full");
|
||||
expect(progressBars).toHaveLength(3);
|
||||
});
|
||||
|
||||
it("should have the Next button disabled when no option is selected", async () => {
|
||||
await renderOnboardingForm();
|
||||
it("should have the Next button disabled when no option is selected", () => {
|
||||
renderOnboardingForm();
|
||||
|
||||
const nextButton = screen.getByRole("button", { name: /next/i });
|
||||
expect(nextButton).toBeDisabled();
|
||||
@@ -111,7 +84,7 @@ describe("OnboardingForm - Cloud Mode", () => {
|
||||
|
||||
it("should enable the Next button when an option is selected", async () => {
|
||||
const user = userEvent.setup();
|
||||
await renderOnboardingForm();
|
||||
renderOnboardingForm();
|
||||
|
||||
await user.click(screen.getByTestId("step-option-solo"));
|
||||
|
||||
@@ -121,7 +94,7 @@ describe("OnboardingForm - Cloud Mode", () => {
|
||||
|
||||
it("should advance to the next step when Next is clicked", async () => {
|
||||
const user = userEvent.setup();
|
||||
await renderOnboardingForm();
|
||||
renderOnboardingForm();
|
||||
|
||||
// On step 1, first progress bar should be filled (bg-white)
|
||||
const stepHeader = screen.getByTestId("step-header");
|
||||
@@ -138,7 +111,7 @@ describe("OnboardingForm - Cloud Mode", () => {
|
||||
|
||||
it("should disable Next button again on new step until option is selected", async () => {
|
||||
const user = userEvent.setup();
|
||||
await renderOnboardingForm();
|
||||
renderOnboardingForm();
|
||||
|
||||
await user.click(screen.getByTestId("step-option-solo"));
|
||||
await user.click(screen.getByRole("button", { name: /next/i }));
|
||||
@@ -149,7 +122,7 @@ describe("OnboardingForm - Cloud Mode", () => {
|
||||
|
||||
it("should call submitOnboarding with selections when finishing the last step", async () => {
|
||||
const user = userEvent.setup();
|
||||
await renderOnboardingForm();
|
||||
renderOnboardingForm();
|
||||
|
||||
// Step 1 - select org size (first step in saas mode - single select)
|
||||
await user.click(screen.getByTestId("step-option-org_2_10"));
|
||||
@@ -173,11 +146,11 @@ describe("OnboardingForm - Cloud Mode", () => {
|
||||
});
|
||||
});
|
||||
|
||||
it("should track onboarding completion to PostHog in cloud mode", async () => {
|
||||
it("should track onboarding completion to PostHog in SaaS mode", async () => {
|
||||
const user = userEvent.setup();
|
||||
await renderOnboardingForm();
|
||||
renderOnboardingForm();
|
||||
|
||||
// Complete the full cloud onboarding flow
|
||||
// Complete the full SaaS onboarding flow
|
||||
await user.click(screen.getByTestId("step-option-org_2_10"));
|
||||
await user.click(screen.getByRole("button", { name: /next/i }));
|
||||
|
||||
@@ -195,8 +168,8 @@ describe("OnboardingForm - Cloud Mode", () => {
|
||||
});
|
||||
});
|
||||
|
||||
it("should render 5 options on step 1 (org size question)", async () => {
|
||||
await renderOnboardingForm();
|
||||
it("should render 5 options on step 1 (org size question)", () => {
|
||||
renderOnboardingForm();
|
||||
|
||||
const options = screen
|
||||
.getAllByRole("button")
|
||||
@@ -208,7 +181,7 @@ describe("OnboardingForm - Cloud Mode", () => {
|
||||
|
||||
it("should preserve selections when navigating through steps", async () => {
|
||||
const user = userEvent.setup();
|
||||
await renderOnboardingForm();
|
||||
renderOnboardingForm();
|
||||
|
||||
// Select org size on step 1 (single select)
|
||||
await user.click(screen.getByTestId("step-option-solo"));
|
||||
@@ -234,7 +207,7 @@ describe("OnboardingForm - Cloud Mode", () => {
|
||||
|
||||
it("should allow selecting multiple options on multi-select steps", async () => {
|
||||
const user = userEvent.setup();
|
||||
await renderOnboardingForm();
|
||||
renderOnboardingForm();
|
||||
|
||||
// Step 1 - select org size (single select)
|
||||
await user.click(screen.getByTestId("step-option-solo"));
|
||||
@@ -261,7 +234,7 @@ describe("OnboardingForm - Cloud Mode", () => {
|
||||
|
||||
it("should allow deselecting options on multi-select steps", async () => {
|
||||
const user = userEvent.setup();
|
||||
await renderOnboardingForm();
|
||||
renderOnboardingForm();
|
||||
|
||||
// Step 1 - select org size
|
||||
await user.click(screen.getByTestId("step-option-solo"));
|
||||
@@ -289,7 +262,7 @@ describe("OnboardingForm - Cloud Mode", () => {
|
||||
|
||||
it("should show all progress bars filled on the last step", async () => {
|
||||
const user = userEvent.setup();
|
||||
await renderOnboardingForm();
|
||||
renderOnboardingForm();
|
||||
|
||||
// Navigate to step 3
|
||||
await user.click(screen.getByTestId("step-option-solo"));
|
||||
@@ -304,8 +277,8 @@ describe("OnboardingForm - Cloud Mode", () => {
|
||||
expect(progressBars).toHaveLength(3);
|
||||
});
|
||||
|
||||
it("should not render the Back button on the first step", async () => {
|
||||
await renderOnboardingForm();
|
||||
it("should not render the Back button on the first step", () => {
|
||||
renderOnboardingForm();
|
||||
|
||||
const backButton = screen.queryByRole("button", { name: /back/i });
|
||||
expect(backButton).not.toBeInTheDocument();
|
||||
@@ -313,7 +286,7 @@ describe("OnboardingForm - Cloud Mode", () => {
|
||||
|
||||
it("should render the Back button on step 2", async () => {
|
||||
const user = userEvent.setup();
|
||||
await renderOnboardingForm();
|
||||
renderOnboardingForm();
|
||||
|
||||
await user.click(screen.getByTestId("step-option-solo"));
|
||||
await user.click(screen.getByRole("button", { name: /next/i }));
|
||||
@@ -324,7 +297,7 @@ describe("OnboardingForm - Cloud Mode", () => {
|
||||
|
||||
it("should go back to the previous step when Back is clicked", async () => {
|
||||
const user = userEvent.setup();
|
||||
await renderOnboardingForm();
|
||||
renderOnboardingForm();
|
||||
|
||||
// Navigate to step 2
|
||||
await user.click(screen.getByTestId("step-option-solo"));
|
||||
@@ -343,197 +316,3 @@ describe("OnboardingForm - Cloud Mode", () => {
|
||||
expect(progressBars).toHaveLength(1);
|
||||
});
|
||||
});
|
||||
|
||||
describe("OnboardingForm - Self-Hosted Mode", () => {
|
||||
// Self-hosted mode has 3 steps: org_name, org_size, use_case
|
||||
// The role question is cloud-only and not shown in self-hosted mode
|
||||
|
||||
beforeEach(() => {
|
||||
mockMutate.mockClear();
|
||||
mockNavigate.mockClear();
|
||||
mockTrackOnboardingCompleted.mockClear();
|
||||
loaderData = {
|
||||
config: {
|
||||
app_mode: "saas",
|
||||
feature_flags: { deployment_mode: "self_hosted" },
|
||||
},
|
||||
};
|
||||
// Self-hosted mode only tracks org owners
|
||||
mockUseMe.mockReturnValue({ data: { role: "owner" } });
|
||||
});
|
||||
|
||||
it("should render with the correct test id", async () => {
|
||||
await renderOnboardingForm();
|
||||
|
||||
expect(screen.getByTestId("onboarding-form")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should display step progress indicator with 3 bars for self-hosted mode", async () => {
|
||||
await renderOnboardingForm();
|
||||
|
||||
// Self-hosted has 3 steps: org_name, org_size, use_case (role is cloud-only)
|
||||
const stepHeader = screen.getByTestId("step-header");
|
||||
const progressBars = stepHeader.querySelectorAll(".rounded-full");
|
||||
expect(progressBars).toHaveLength(3);
|
||||
});
|
||||
|
||||
it("should start with org_name question as first step with two input fields", async () => {
|
||||
await renderOnboardingForm();
|
||||
|
||||
// The first step in self-hosted mode should be org_name with two inputs
|
||||
const orgNameInput = screen.getByTestId("form-input-org_name");
|
||||
const orgDomainInput = screen.getByTestId("form-input-org_domain");
|
||||
expect(orgNameInput).toBeInTheDocument();
|
||||
expect(orgDomainInput).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should call submitOnboarding with all selections including org_name when finishing", async () => {
|
||||
const user = userEvent.setup();
|
||||
await renderOnboardingForm();
|
||||
|
||||
// Step 1 - enter org name and domain (input fields)
|
||||
const orgNameInput = screen.getByTestId("form-input-org_name");
|
||||
const orgDomainInput = screen.getByTestId("form-input-org_domain");
|
||||
await user.type(orgNameInput, "Acme Corp");
|
||||
await user.type(orgDomainInput, "acme.com");
|
||||
await user.click(screen.getByRole("button", { name: /next/i }));
|
||||
|
||||
// Step 2 - select org size (single select)
|
||||
await user.click(screen.getByTestId("step-option-org_2_10"));
|
||||
await user.click(screen.getByRole("button", { name: /next/i }));
|
||||
|
||||
// Step 3 - select use case (multi-select) - this is the last step in self-hosted mode
|
||||
await user.click(screen.getByTestId("step-option-new_features"));
|
||||
await user.click(screen.getByRole("button", { name: /finish/i }));
|
||||
|
||||
expect(mockMutate).toHaveBeenCalledTimes(1);
|
||||
expect(mockMutate).toHaveBeenCalledWith({
|
||||
selections: {
|
||||
org_name: "Acme Corp",
|
||||
org_domain: "acme.com",
|
||||
org_size: "org_2_10",
|
||||
use_case: ["new_features"],
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it("should track onboarding completion in self-hosted mode", async () => {
|
||||
const user = userEvent.setup();
|
||||
await renderOnboardingForm();
|
||||
|
||||
// Complete the full self-hosted onboarding flow (3 steps)
|
||||
const orgNameInput = screen.getByTestId("form-input-org_name");
|
||||
const orgDomainInput = screen.getByTestId("form-input-org_domain");
|
||||
await user.type(orgNameInput, "Test Company");
|
||||
await user.type(orgDomainInput, "test.com");
|
||||
await user.click(screen.getByRole("button", { name: /next/i }));
|
||||
|
||||
await user.click(screen.getByTestId("step-option-org_2_10"));
|
||||
await user.click(screen.getByRole("button", { name: /next/i }));
|
||||
|
||||
await user.click(screen.getByTestId("step-option-new_features"));
|
||||
await user.click(screen.getByRole("button", { name: /finish/i }));
|
||||
|
||||
expect(mockTrackOnboardingCompleted).toHaveBeenCalledTimes(1);
|
||||
// Note: role is not included since role question is cloud-only
|
||||
expect(mockTrackOnboardingCompleted).toHaveBeenCalledWith({
|
||||
role: undefined,
|
||||
orgSize: "org_2_10",
|
||||
useCase: ["new_features"],
|
||||
});
|
||||
});
|
||||
|
||||
it("should show all 3 progress bars filled on the last step", async () => {
|
||||
const user = userEvent.setup();
|
||||
await renderOnboardingForm();
|
||||
|
||||
// Navigate through all 3 steps
|
||||
const orgNameInput = screen.getByTestId("form-input-org_name");
|
||||
const orgDomainInput = screen.getByTestId("form-input-org_domain");
|
||||
await user.type(orgNameInput, "Test Company");
|
||||
await user.type(orgDomainInput, "test.com");
|
||||
await user.click(screen.getByRole("button", { name: /next/i }));
|
||||
|
||||
await user.click(screen.getByTestId("step-option-org_2_10"));
|
||||
await user.click(screen.getByRole("button", { name: /next/i }));
|
||||
|
||||
// On step 3, all three progress bars should be filled
|
||||
const stepHeader = screen.getByTestId("step-header");
|
||||
const progressBars = stepHeader.querySelectorAll(".bg-white");
|
||||
expect(progressBars).toHaveLength(3);
|
||||
});
|
||||
|
||||
it("should have Next button disabled when both org_name inputs are empty", async () => {
|
||||
await renderOnboardingForm();
|
||||
|
||||
const nextButton = screen.getByRole("button", { name: /next/i });
|
||||
expect(nextButton).toBeDisabled();
|
||||
});
|
||||
|
||||
it("should enable Next button when both org_name and org_domain are entered", async () => {
|
||||
const user = userEvent.setup();
|
||||
await renderOnboardingForm();
|
||||
|
||||
const orgNameInput = screen.getByTestId("form-input-org_name");
|
||||
const orgDomainInput = screen.getByTestId("form-input-org_domain");
|
||||
await user.type(orgNameInput, "My Company");
|
||||
await user.type(orgDomainInput, "mycompany.com");
|
||||
|
||||
const nextButton = screen.getByRole("button", { name: /next/i });
|
||||
expect(nextButton).not.toBeDisabled();
|
||||
});
|
||||
|
||||
it("should NOT track onboarding completion for non-owners in self-hosted mode", async () => {
|
||||
// Override the mock to return a member (non-owner) role
|
||||
mockUseMe.mockReturnValue({ data: { role: "member" } });
|
||||
|
||||
const user = userEvent.setup();
|
||||
await renderOnboardingForm();
|
||||
|
||||
// Complete the full self-hosted onboarding flow (3 steps)
|
||||
const orgNameInput = screen.getByTestId("form-input-org_name");
|
||||
const orgDomainInput = screen.getByTestId("form-input-org_domain");
|
||||
await user.type(orgNameInput, "Test Company");
|
||||
await user.type(orgDomainInput, "test.com");
|
||||
await user.click(screen.getByRole("button", { name: /next/i }));
|
||||
|
||||
await user.click(screen.getByTestId("step-option-org_2_10"));
|
||||
await user.click(screen.getByRole("button", { name: /next/i }));
|
||||
|
||||
await user.click(screen.getByTestId("step-option-new_features"));
|
||||
await user.click(screen.getByRole("button", { name: /finish/i }));
|
||||
|
||||
// Tracking should NOT be called for non-owners in self-hosted mode
|
||||
expect(mockTrackOnboardingCompleted).not.toHaveBeenCalled();
|
||||
|
||||
// But onboarding submission should still work
|
||||
expect(mockMutate).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it("should NOT track onboarding completion for admins in self-hosted mode", async () => {
|
||||
// Override the mock to return an admin role
|
||||
mockUseMe.mockReturnValue({ data: { role: "admin" } });
|
||||
|
||||
const user = userEvent.setup();
|
||||
await renderOnboardingForm();
|
||||
|
||||
// Complete the full self-hosted onboarding flow (3 steps)
|
||||
const orgNameInput = screen.getByTestId("form-input-org_name");
|
||||
const orgDomainInput = screen.getByTestId("form-input-org_domain");
|
||||
await user.type(orgNameInput, "Test Company");
|
||||
await user.type(orgDomainInput, "test.com");
|
||||
await user.click(screen.getByRole("button", { name: /next/i }));
|
||||
|
||||
await user.click(screen.getByTestId("step-option-org_2_10"));
|
||||
await user.click(screen.getByRole("button", { name: /next/i }));
|
||||
|
||||
await user.click(screen.getByTestId("step-option-new_features"));
|
||||
await user.click(screen.getByRole("button", { name: /finish/i }));
|
||||
|
||||
// Tracking should NOT be called for admins in self-hosted mode (only owners)
|
||||
expect(mockTrackOnboardingCompleted).not.toHaveBeenCalled();
|
||||
|
||||
// But onboarding submission should still work
|
||||
expect(mockMutate).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -11,7 +11,6 @@ import OptionService from "#/api/option-service/option-service.api";
|
||||
import { MOCK_DEFAULT_USER_SETTINGS } from "#/mocks/handlers";
|
||||
import { WebClientConfig } from "#/api/option-service/option.types";
|
||||
import { useSelectedOrganizationStore } from "#/stores/selected-organization-store";
|
||||
import * as FeatureFlags from "#/utils/feature-flags";
|
||||
|
||||
// Helper to create mock config with sensible defaults
|
||||
const createMockConfig = (
|
||||
@@ -186,36 +185,4 @@ describe("Sidebar", () => {
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("Automations button visibility", () => {
|
||||
let enableAutomationsSpy: ReturnType<typeof vi.spyOn>;
|
||||
|
||||
beforeEach(() => {
|
||||
enableAutomationsSpy = vi.spyOn(FeatureFlags, "ENABLE_AUTOMATIONS");
|
||||
});
|
||||
|
||||
it("should show automations button when ENABLE_AUTOMATIONS flag is on", async () => {
|
||||
enableAutomationsSpy.mockReturnValue(true);
|
||||
|
||||
renderSidebar();
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByTestId("automations-button"),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
it("should hide automations button when ENABLE_AUTOMATIONS flag is off", async () => {
|
||||
enableAutomationsSpy.mockReturnValue(false);
|
||||
|
||||
renderSidebar();
|
||||
|
||||
await waitFor(() => expect(getSettingsSpy).toHaveBeenCalled());
|
||||
|
||||
expect(
|
||||
screen.queryByTestId("automations-button"),
|
||||
).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -23,6 +23,12 @@ vi.mock("#/hooks/use-breakpoint", () => ({
|
||||
useBreakpoint: vi.fn(() => false), // Default to desktop (not mobile)
|
||||
}));
|
||||
|
||||
// Mock feature flags
|
||||
const mockEnableProjUserJourney = vi.fn(() => true);
|
||||
vi.mock("#/utils/feature-flags", () => ({
|
||||
ENABLE_PROJ_USER_JOURNEY: () => mockEnableProjUserJourney(),
|
||||
}));
|
||||
|
||||
// Mock useTracking hook for CTA
|
||||
vi.mock("#/hooks/use-tracking", () => ({
|
||||
useTracking: () => ({
|
||||
@@ -138,8 +144,9 @@ describe("UserContextMenu", () => {
|
||||
// Ensure clean state at the start of each test
|
||||
vi.restoreAllMocks();
|
||||
useSelectedOrganizationStore.setState({ organizationId: null });
|
||||
// Reset breakpoint mock to desktop by default
|
||||
vi.mocked(breakpoint.useBreakpoint).mockReturnValue(false);
|
||||
// Reset feature flag and breakpoint mocks to defaults
|
||||
mockEnableProjUserJourney.mockReturnValue(true);
|
||||
vi.mocked(breakpoint.useBreakpoint).mockReturnValue(false); // Desktop by default
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
@@ -743,26 +750,15 @@ describe("UserContextMenu", () => {
|
||||
});
|
||||
|
||||
describe("Context Menu CTA", () => {
|
||||
it("should render the CTA component in SaaS Cloud mode on desktop", async () => {
|
||||
it("should render the CTA component in SaaS mode on desktop with feature flag enabled", async () => {
|
||||
// Set SaaS mode
|
||||
vi.spyOn(OptionService, "getConfig").mockResolvedValue(
|
||||
createMockWebClientConfig({
|
||||
app_mode: "saas",
|
||||
feature_flags: {
|
||||
deployment_mode: "cloud",
|
||||
enable_billing: false,
|
||||
hide_llm_settings: false,
|
||||
enable_jira: false,
|
||||
enable_jira_dc: false,
|
||||
enable_linear: false,
|
||||
hide_users_page: false,
|
||||
hide_billing_page: false,
|
||||
hide_integrations_page: false,
|
||||
},
|
||||
}),
|
||||
createMockWebClientConfig({ app_mode: "saas" }),
|
||||
);
|
||||
|
||||
renderUserContextMenu({ type: "member", onClose: vi.fn, onOpenInviteModal: vi.fn });
|
||||
|
||||
// Wait for config to load
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId("context-menu-cta")).toBeInTheDocument();
|
||||
});
|
||||
@@ -770,73 +766,59 @@ describe("UserContextMenu", () => {
|
||||
expect(screen.getByText("CTA$LEARN_MORE")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should not render the CTA component in OSS mode", async () => {
|
||||
it("should not render the CTA component in OSS mode even with feature flag enabled", async () => {
|
||||
// Set OSS mode
|
||||
vi.spyOn(OptionService, "getConfig").mockResolvedValue(
|
||||
createMockWebClientConfig({ app_mode: "oss" }),
|
||||
);
|
||||
|
||||
renderUserContextMenu({ type: "member", onClose: vi.fn, onOpenInviteModal: vi.fn });
|
||||
|
||||
// Wait for config to load
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId("user-context-menu")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
expect(screen.queryByTestId("context-menu-cta")).not.toBeInTheDocument();
|
||||
expect(screen.queryByText("CTA$ENTERPRISE_TITLE")).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should not render the CTA component on mobile even in SaaS Cloud mode", async () => {
|
||||
it("should not render the CTA component on mobile even in SaaS mode with feature flag enabled", async () => {
|
||||
// Set SaaS mode
|
||||
vi.spyOn(OptionService, "getConfig").mockResolvedValue(
|
||||
createMockWebClientConfig({
|
||||
app_mode: "saas",
|
||||
feature_flags: {
|
||||
deployment_mode: "cloud",
|
||||
enable_billing: false,
|
||||
hide_llm_settings: false,
|
||||
enable_jira: false,
|
||||
enable_jira_dc: false,
|
||||
enable_linear: false,
|
||||
hide_users_page: false,
|
||||
hide_billing_page: false,
|
||||
hide_integrations_page: false,
|
||||
},
|
||||
}),
|
||||
createMockWebClientConfig({ app_mode: "saas" }),
|
||||
);
|
||||
// Set mobile mode
|
||||
vi.mocked(breakpoint.useBreakpoint).mockReturnValue(true);
|
||||
|
||||
renderUserContextMenu({ type: "member", onClose: vi.fn, onOpenInviteModal: vi.fn });
|
||||
|
||||
// Wait for config to load
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId("user-context-menu")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
expect(screen.queryByTestId("context-menu-cta")).not.toBeInTheDocument();
|
||||
expect(screen.queryByText("CTA$ENTERPRISE_TITLE")).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should not render the CTA component in SaaS Self-hosted mode", async () => {
|
||||
it("should not render the CTA component when feature flag is disabled in SaaS mode", async () => {
|
||||
// Set SaaS mode
|
||||
vi.spyOn(OptionService, "getConfig").mockResolvedValue(
|
||||
createMockWebClientConfig({
|
||||
app_mode: "saas",
|
||||
feature_flags: {
|
||||
deployment_mode: "self_hosted",
|
||||
enable_billing: false,
|
||||
hide_llm_settings: false,
|
||||
enable_jira: false,
|
||||
enable_jira_dc: false,
|
||||
enable_linear: false,
|
||||
hide_users_page: false,
|
||||
hide_billing_page: false,
|
||||
hide_integrations_page: false,
|
||||
},
|
||||
}),
|
||||
createMockWebClientConfig({ app_mode: "saas" }),
|
||||
);
|
||||
// Disable the feature flag
|
||||
mockEnableProjUserJourney.mockReturnValue(false);
|
||||
|
||||
renderUserContextMenu({ type: "member", onClose: vi.fn, onOpenInviteModal: vi.fn });
|
||||
|
||||
// Wait for config to load
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId("user-context-menu")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
expect(screen.queryByTestId("context-menu-cta")).not.toBeInTheDocument();
|
||||
expect(screen.queryByText("CTA$ENTERPRISE_TITLE")).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,45 +1,12 @@
|
||||
import { describe, it, expect, vi } from "vitest";
|
||||
import { render, screen, waitFor } from "@testing-library/react";
|
||||
import { render, screen } from "@testing-library/react";
|
||||
import userEvent from "@testing-library/user-event";
|
||||
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
||||
import { ModelSelector } from "#/components/shared/modals/settings/model-selector";
|
||||
import type { LLMProvider, LLMModel } from "#/api/config-service/config-service.types";
|
||||
|
||||
const mockProviders: LLMProvider[] = [
|
||||
{ name: "openai", verified: true },
|
||||
{ name: "azure", verified: false },
|
||||
{ name: "vertex_ai", verified: false },
|
||||
];
|
||||
|
||||
const mockModelsByProvider: Record<string, LLMModel[]> = {
|
||||
openai: [
|
||||
{ provider: "openai", name: "gpt-4o", verified: true },
|
||||
{ provider: "openai", name: "gpt-4o-mini", verified: true },
|
||||
],
|
||||
azure: [
|
||||
{ provider: "azure", name: "ada", verified: false },
|
||||
{ provider: "azure", name: "gpt-35-turbo", verified: false },
|
||||
],
|
||||
vertex_ai: [
|
||||
{ provider: "vertex_ai", name: "chat-bison", verified: false },
|
||||
{ provider: "vertex_ai", name: "chat-bison-32k", verified: false },
|
||||
],
|
||||
};
|
||||
|
||||
vi.mock("#/hooks/query/use-search-providers", () => ({
|
||||
useSearchProviders: () => ({ data: mockProviders }),
|
||||
}));
|
||||
|
||||
vi.mock("#/hooks/query/use-provider-models", () => ({
|
||||
useProviderModels: (provider: string | null) => ({
|
||||
data: provider ? (mockModelsByProvider[provider] ?? []) : [],
|
||||
}),
|
||||
}));
|
||||
|
||||
vi.mock("react-i18next", () => ({
|
||||
useTranslation: () => ({
|
||||
t: (key: string) => {
|
||||
const translations: Record<string, string> = {
|
||||
const translations: { [key: string]: string } = {
|
||||
LLM$PROVIDER: "LLM Provider",
|
||||
LLM$MODEL: "LLM Model",
|
||||
LLM$SELECT_PROVIDER_PLACEHOLDER: "Select a provider",
|
||||
@@ -50,19 +17,34 @@ vi.mock("react-i18next", () => ({
|
||||
}),
|
||||
}));
|
||||
|
||||
function renderWithQuery(ui: React.ReactElement) {
|
||||
const queryClient = new QueryClient({
|
||||
defaultOptions: { queries: { retry: false } },
|
||||
});
|
||||
return render(
|
||||
<QueryClientProvider client={queryClient}>{ui}</QueryClientProvider>,
|
||||
);
|
||||
}
|
||||
|
||||
describe("ModelSelector", () => {
|
||||
const models = {
|
||||
openai: {
|
||||
separator: "/",
|
||||
models: ["gpt-4o", "gpt-4o-mini"],
|
||||
},
|
||||
azure: {
|
||||
separator: "/",
|
||||
models: ["ada", "gpt-35-turbo"],
|
||||
},
|
||||
vertex_ai: {
|
||||
separator: "/",
|
||||
models: ["chat-bison", "chat-bison-32k"],
|
||||
},
|
||||
};
|
||||
|
||||
const verifiedModels = ["gpt-4o", "gpt-4o-mini"];
|
||||
const verifiedProviders = ["openai"];
|
||||
|
||||
it("should display the provider selector", async () => {
|
||||
const user = userEvent.setup();
|
||||
renderWithQuery(<ModelSelector />);
|
||||
render(
|
||||
<ModelSelector
|
||||
models={models}
|
||||
verifiedModels={verifiedModels}
|
||||
verifiedProviders={verifiedProviders}
|
||||
/>,
|
||||
);
|
||||
|
||||
const selector = screen.getByLabelText("LLM Provider");
|
||||
expect(selector).toBeInTheDocument();
|
||||
@@ -76,7 +58,13 @@ describe("ModelSelector", () => {
|
||||
|
||||
it("should disable the model selector if the provider is not selected", async () => {
|
||||
const user = userEvent.setup();
|
||||
renderWithQuery(<ModelSelector />);
|
||||
render(
|
||||
<ModelSelector
|
||||
models={models}
|
||||
verifiedModels={verifiedModels}
|
||||
verifiedProviders={verifiedProviders}
|
||||
/>,
|
||||
);
|
||||
|
||||
const modelSelector = screen.getByLabelText("LLM Model");
|
||||
expect(modelSelector).toBeDisabled();
|
||||
@@ -92,7 +80,13 @@ describe("ModelSelector", () => {
|
||||
|
||||
it("should display the model selector", async () => {
|
||||
const user = userEvent.setup();
|
||||
renderWithQuery(<ModelSelector />);
|
||||
render(
|
||||
<ModelSelector
|
||||
models={models}
|
||||
verifiedModels={verifiedModels}
|
||||
verifiedProviders={verifiedProviders}
|
||||
/>,
|
||||
);
|
||||
|
||||
const providerSelector = screen.getByLabelText("LLM Provider");
|
||||
await user.click(providerSelector);
|
||||
@@ -111,7 +105,14 @@ describe("ModelSelector", () => {
|
||||
const user = userEvent.setup();
|
||||
const onChange = vi.fn();
|
||||
|
||||
renderWithQuery(<ModelSelector onChange={onChange} />);
|
||||
render(
|
||||
<ModelSelector
|
||||
models={models}
|
||||
verifiedModels={verifiedModels}
|
||||
verifiedProviders={verifiedProviders}
|
||||
onChange={onChange}
|
||||
/>,
|
||||
);
|
||||
|
||||
const providerSelector = screen.getByLabelText("LLM Provider");
|
||||
await user.click(providerSelector);
|
||||
@@ -125,12 +126,18 @@ describe("ModelSelector", () => {
|
||||
expect(onChange).toHaveBeenNthCalledWith(2, "azure", "ada");
|
||||
});
|
||||
|
||||
it("should have a default value if passed", async () => {
|
||||
renderWithQuery(<ModelSelector currentModel="azure/ada" />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByLabelText("LLM Provider")).toHaveValue("Azure");
|
||||
expect(screen.getByLabelText("LLM Model")).toHaveValue("ada");
|
||||
});
|
||||
it("should have a default value if passed", async () => {
|
||||
render(
|
||||
<ModelSelector
|
||||
models={models}
|
||||
verifiedModels={verifiedModels}
|
||||
verifiedProviders={verifiedProviders}
|
||||
currentModel="azure/ada"
|
||||
/>,
|
||||
);
|
||||
|
||||
expect(screen.getByLabelText("LLM Provider")).toHaveValue("Azure");
|
||||
expect(screen.getByLabelText("LLM Model")).toHaveValue("ada");
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,32 +0,0 @@
|
||||
import { createEvent, fireEvent, render, screen } from "@testing-library/react";
|
||||
import { describe, expect, it } from "vitest";
|
||||
import { AutomationsButton } from "#/components/shared/buttons/automations-button";
|
||||
|
||||
describe("AutomationsButton", () => {
|
||||
it("should render a link to /automations", () => {
|
||||
render(<AutomationsButton />);
|
||||
|
||||
const link = screen.getByTestId("automations-button");
|
||||
expect(link).toBeInTheDocument();
|
||||
expect(link).toHaveAttribute("href", "/automations");
|
||||
});
|
||||
|
||||
it("should be focusable and accessible when enabled", () => {
|
||||
render(<AutomationsButton />);
|
||||
|
||||
const link = screen.getByTestId("automations-button");
|
||||
expect(link).toHaveAttribute("tabIndex", "0");
|
||||
expect(link).toHaveAttribute("aria-label", "SIDEBAR$AUTOMATIONS");
|
||||
});
|
||||
|
||||
it("should prevent navigation and remove from tab order when disabled", () => {
|
||||
render(<AutomationsButton disabled />);
|
||||
|
||||
const link = screen.getByTestId("automations-button");
|
||||
expect(link).toHaveAttribute("tabIndex", "-1");
|
||||
|
||||
const clickEvent = createEvent.click(link);
|
||||
fireEvent(link, clickEvent);
|
||||
expect(clickEvent.defaultPrevented).toBe(true);
|
||||
});
|
||||
});
|
||||
@@ -14,7 +14,13 @@ describe("SettingsForm", () => {
|
||||
const RouteStub = createRoutesStub([
|
||||
{
|
||||
Component: () => (
|
||||
<SettingsForm settings={DEFAULT_SETTINGS} onClose={onCloseMock} />
|
||||
<SettingsForm
|
||||
settings={DEFAULT_SETTINGS}
|
||||
models={[DEFAULT_SETTINGS.llm_model]}
|
||||
verifiedModels={[]}
|
||||
verifiedProviders={["openhands"]}
|
||||
onClose={onCloseMock}
|
||||
/>
|
||||
),
|
||||
path: "/",
|
||||
},
|
||||
|
||||
@@ -5,21 +5,12 @@ import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
||||
import { createRoutesStub } from "react-router";
|
||||
import DeviceVerify from "#/routes/device-verify";
|
||||
|
||||
const { useIsAuthedMock, mockUseAppMode } = vi.hoisted(() => ({
|
||||
const { useIsAuthedMock, ENABLE_PROJ_USER_JOURNEY_MOCK } = vi.hoisted(() => ({
|
||||
useIsAuthedMock: vi.fn(() => ({
|
||||
data: false as boolean | undefined,
|
||||
isLoading: false,
|
||||
})),
|
||||
mockUseAppMode: vi.fn(() => ({
|
||||
isOss: false,
|
||||
isSaas: true,
|
||||
isCloud: true,
|
||||
isSelfHosted: false,
|
||||
isEnterpriseSelfHosted: false,
|
||||
isEnterpriseCloud: true,
|
||||
appMode: "saas" as string | undefined,
|
||||
deploymentMode: "cloud" as string | undefined,
|
||||
})),
|
||||
ENABLE_PROJ_USER_JOURNEY_MOCK: vi.fn(() => true),
|
||||
}));
|
||||
|
||||
vi.mock("#/hooks/query/use-is-authed", () => ({
|
||||
@@ -32,8 +23,8 @@ vi.mock("posthog-js/react", () => ({
|
||||
}),
|
||||
}));
|
||||
|
||||
vi.mock("#/hooks/use-app-mode", () => ({
|
||||
useAppMode: () => mockUseAppMode(),
|
||||
vi.mock("#/utils/feature-flags", () => ({
|
||||
ENABLE_PROJ_USER_JOURNEY: () => ENABLE_PROJ_USER_JOURNEY_MOCK(),
|
||||
}));
|
||||
|
||||
const RouterStub = createRoutesStub([
|
||||
@@ -75,17 +66,8 @@ describe("DeviceVerify", () => {
|
||||
}),
|
||||
),
|
||||
);
|
||||
// Reset useAppMode to SaaS Cloud (CTA enabled) by default
|
||||
mockUseAppMode.mockReturnValue({
|
||||
isOss: false,
|
||||
isSaas: true,
|
||||
isCloud: true,
|
||||
isSelfHosted: false,
|
||||
isEnterpriseSelfHosted: false,
|
||||
isEnterpriseCloud: true,
|
||||
appMode: "saas",
|
||||
deploymentMode: "cloud",
|
||||
});
|
||||
// Enable feature flag by default
|
||||
ENABLE_PROJ_USER_JOURNEY_MOCK.mockReturnValue(true);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
@@ -253,17 +235,7 @@ describe("DeviceVerify", () => {
|
||||
});
|
||||
});
|
||||
|
||||
it("should include the LoginCTA component when in SaaS Cloud mode", async () => {
|
||||
mockUseAppMode.mockReturnValue({
|
||||
isOss: false,
|
||||
isSaas: true,
|
||||
isCloud: true,
|
||||
isSelfHosted: false,
|
||||
isEnterpriseSelfHosted: false,
|
||||
isEnterpriseCloud: true,
|
||||
appMode: "saas",
|
||||
deploymentMode: "cloud",
|
||||
});
|
||||
it("should include the LoginCTA component when feature flag is enabled", async () => {
|
||||
useIsAuthedMock.mockReturnValue({
|
||||
data: true,
|
||||
isLoading: false,
|
||||
@@ -281,45 +253,8 @@ describe("DeviceVerify", () => {
|
||||
});
|
||||
});
|
||||
|
||||
it("should not include the LoginCTA component when in OSS mode", async () => {
|
||||
mockUseAppMode.mockReturnValue({
|
||||
isOss: true,
|
||||
isSaas: false,
|
||||
isCloud: false,
|
||||
isSelfHosted: false,
|
||||
isEnterpriseSelfHosted: false,
|
||||
isEnterpriseCloud: false,
|
||||
appMode: "oss",
|
||||
deploymentMode: undefined,
|
||||
});
|
||||
useIsAuthedMock.mockReturnValue({
|
||||
data: true,
|
||||
isLoading: false,
|
||||
});
|
||||
|
||||
render(
|
||||
<RouterStub initialEntries={["/device-verify?user_code=ABC-123"]} />,
|
||||
{
|
||||
wrapper: createWrapper(),
|
||||
},
|
||||
);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.queryByTestId("login-cta")).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
it("should not include the LoginCTA and be center-aligned when in SaaS Self-hosted mode", async () => {
|
||||
mockUseAppMode.mockReturnValue({
|
||||
isOss: false,
|
||||
isSaas: true,
|
||||
isCloud: false,
|
||||
isSelfHosted: true,
|
||||
isEnterpriseSelfHosted: true,
|
||||
isEnterpriseCloud: false,
|
||||
appMode: "saas",
|
||||
deploymentMode: "self_hosted",
|
||||
});
|
||||
it("should not include the LoginCTA and be center-aligned when feature flag is disabled", async () => {
|
||||
ENABLE_PROJ_USER_JOURNEY_MOCK.mockReturnValue(false);
|
||||
useIsAuthedMock.mockReturnValue({
|
||||
data: true,
|
||||
isLoading: false,
|
||||
|
||||
@@ -13,7 +13,7 @@ import AuthService from "#/api/auth-service/auth-service.api";
|
||||
import MainApp from "#/routes/root-layout";
|
||||
import { MOCK_DEFAULT_USER_SETTINGS } from "#/mocks/handlers";
|
||||
|
||||
const { DEFAULT_FEATURE_FLAGS, useIsAuthedMock, useConfigMock, mockUseAppMode } = vi.hoisted(
|
||||
const { DEFAULT_FEATURE_FLAGS, useIsAuthedMock, useConfigMock } = vi.hoisted(
|
||||
() => {
|
||||
const defaultFeatureFlags = {
|
||||
enable_billing: false,
|
||||
@@ -41,16 +41,6 @@ const { DEFAULT_FEATURE_FLAGS, useIsAuthedMock, useConfigMock, mockUseAppMode }
|
||||
},
|
||||
isLoading: false,
|
||||
}),
|
||||
mockUseAppMode: vi.fn().mockReturnValue({
|
||||
isOss: true,
|
||||
isSaas: false,
|
||||
isCloud: false,
|
||||
isSelfHosted: false,
|
||||
isEnterpriseSelfHosted: false,
|
||||
isEnterpriseCloud: false,
|
||||
appMode: "oss",
|
||||
deploymentMode: undefined,
|
||||
}),
|
||||
};
|
||||
},
|
||||
);
|
||||
@@ -63,19 +53,6 @@ vi.mock("#/hooks/query/use-config", () => ({
|
||||
useConfig: () => useConfigMock(),
|
||||
}));
|
||||
|
||||
vi.mock("#/hooks/use-app-mode", () => ({
|
||||
useAppMode: () => mockUseAppMode() ?? {
|
||||
isOss: true,
|
||||
isSaas: false,
|
||||
isCloud: false,
|
||||
isSelfHosted: false,
|
||||
isEnterpriseSelfHosted: false,
|
||||
isEnterpriseCloud: false,
|
||||
appMode: "oss",
|
||||
deploymentMode: undefined,
|
||||
},
|
||||
}));
|
||||
|
||||
const RouterStub = createRoutesStub([
|
||||
{
|
||||
Component: MainApp,
|
||||
@@ -251,17 +228,20 @@ describe("HomeScreen", () => {
|
||||
"retrieveUserGitRepositories",
|
||||
);
|
||||
retrieveUserGitRepositoriesSpy.mockResolvedValue({
|
||||
items: MOCK_RESPOSITORIES,
|
||||
next_page_id: null,
|
||||
data: MOCK_RESPOSITORIES,
|
||||
nextPage: null,
|
||||
});
|
||||
|
||||
// Mock the repository branches API call
|
||||
vi.spyOn(GitService, "getRepositoryBranches").mockResolvedValue({
|
||||
items: [
|
||||
branches: [
|
||||
{ name: "main", commit_sha: "123", protected: false },
|
||||
{ name: "develop", commit_sha: "456", protected: false },
|
||||
],
|
||||
next_page_id: null,
|
||||
has_next_page: false,
|
||||
current_page: 1,
|
||||
per_page: 30,
|
||||
total_count: 2,
|
||||
});
|
||||
|
||||
renderHomeScreen();
|
||||
@@ -292,17 +272,20 @@ describe("HomeScreen", () => {
|
||||
"retrieveUserGitRepositories",
|
||||
);
|
||||
retrieveUserGitRepositoriesSpy.mockResolvedValue({
|
||||
items: MOCK_RESPOSITORIES,
|
||||
next_page_id: null,
|
||||
data: MOCK_RESPOSITORIES,
|
||||
nextPage: null,
|
||||
});
|
||||
|
||||
// Mock the repository branches API call
|
||||
vi.spyOn(GitService, "getRepositoryBranches").mockResolvedValue({
|
||||
items: [
|
||||
branches: [
|
||||
{ name: "main", commit_sha: "123", protected: false },
|
||||
{ name: "develop", commit_sha: "456", protected: false },
|
||||
],
|
||||
next_page_id: null,
|
||||
has_next_page: false,
|
||||
current_page: 1,
|
||||
per_page: 30,
|
||||
total_count: 2,
|
||||
});
|
||||
|
||||
renderHomeScreen();
|
||||
@@ -349,11 +332,14 @@ describe("HomeScreen", () => {
|
||||
|
||||
// Mock the repository branches API call
|
||||
vi.spyOn(GitService, "getRepositoryBranches").mockResolvedValue({
|
||||
items: [
|
||||
branches: [
|
||||
{ name: "main", commit_sha: "123", protected: false },
|
||||
{ name: "develop", commit_sha: "456", protected: false },
|
||||
],
|
||||
next_page_id: null,
|
||||
has_next_page: false,
|
||||
current_page: 1,
|
||||
per_page: 30,
|
||||
total_count: 2,
|
||||
});
|
||||
|
||||
// Select a repository to enable the repo launch button
|
||||
@@ -385,8 +371,8 @@ describe("HomeScreen", () => {
|
||||
"retrieveUserGitRepositories",
|
||||
);
|
||||
retrieveUserGitRepositoriesSpy.mockResolvedValue({
|
||||
items: MOCK_RESPOSITORIES,
|
||||
next_page_id: null,
|
||||
data: MOCK_RESPOSITORIES,
|
||||
nextPage: null,
|
||||
});
|
||||
});
|
||||
|
||||
@@ -635,9 +621,14 @@ describe("HomepageCTA visibility", () => {
|
||||
|
||||
getSettingsSpy.mockResolvedValue(MOCK_DEFAULT_USER_SETTINGS);
|
||||
|
||||
// Mock localStorage for CTA dismissal
|
||||
// Mock localStorage to enable the PROJ_USER_JOURNEY feature flag (CTA dismissal also uses localStorage)
|
||||
vi.stubGlobal("localStorage", {
|
||||
getItem: vi.fn(() => null),
|
||||
getItem: vi.fn((key: string) => {
|
||||
if (key === "FEATURE_PROJ_USER_JOURNEY") {
|
||||
return "true";
|
||||
}
|
||||
return null;
|
||||
}),
|
||||
setItem: vi.fn(),
|
||||
removeItem: vi.fn(),
|
||||
clear: vi.fn(),
|
||||
@@ -649,7 +640,7 @@ describe("HomepageCTA visibility", () => {
|
||||
vi.unstubAllGlobals();
|
||||
});
|
||||
|
||||
it("should show HomepageCTA in SaaS Cloud mode when not dismissed", async () => {
|
||||
it("should show HomepageCTA in SaaS mode when not dismissed and feature flag enabled", async () => {
|
||||
useIsAuthedMock.mockReturnValue({
|
||||
data: true,
|
||||
isLoading: false,
|
||||
@@ -657,26 +648,16 @@ describe("HomepageCTA visibility", () => {
|
||||
isError: false,
|
||||
});
|
||||
useConfigMock.mockReturnValue({
|
||||
data: { app_mode: "saas", feature_flags: { ...DEFAULT_FEATURE_FLAGS, deployment_mode: "cloud" } },
|
||||
data: { app_mode: "saas", feature_flags: DEFAULT_FEATURE_FLAGS },
|
||||
isLoading: false,
|
||||
});
|
||||
mockUseAppMode.mockReturnValue({
|
||||
isOss: false,
|
||||
isSaas: true,
|
||||
isCloud: true,
|
||||
isSelfHosted: false,
|
||||
isEnterpriseSelfHosted: false,
|
||||
isEnterpriseCloud: true,
|
||||
appMode: "saas",
|
||||
deploymentMode: "cloud",
|
||||
});
|
||||
|
||||
getConfigSpy.mockResolvedValue({
|
||||
app_mode: "saas",
|
||||
posthog_client_key: "test-posthog-key",
|
||||
providers_configured: ["github"],
|
||||
auth_url: "https://auth.example.com",
|
||||
feature_flags: { ...DEFAULT_FEATURE_FLAGS, deployment_mode: "cloud" },
|
||||
feature_flags: DEFAULT_FEATURE_FLAGS,
|
||||
maintenance_start_time: null,
|
||||
recaptcha_site_key: null,
|
||||
faulty_models: [],
|
||||
@@ -693,7 +674,7 @@ describe("HomepageCTA visibility", () => {
|
||||
expect(ctaLink).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should not show HomepageCTA in OSS mode", async () => {
|
||||
it("should not show HomepageCTA in OSS mode even with feature flag enabled", async () => {
|
||||
useIsAuthedMock.mockReturnValue({
|
||||
data: true,
|
||||
isLoading: false,
|
||||
@@ -704,16 +685,6 @@ describe("HomepageCTA visibility", () => {
|
||||
data: { app_mode: "oss", feature_flags: DEFAULT_FEATURE_FLAGS },
|
||||
isLoading: false,
|
||||
});
|
||||
mockUseAppMode.mockReturnValue({
|
||||
isOss: true,
|
||||
isSaas: false,
|
||||
isCloud: false,
|
||||
isSelfHosted: false,
|
||||
isEnterpriseSelfHosted: false,
|
||||
isEnterpriseCloud: false,
|
||||
appMode: "oss",
|
||||
deploymentMode: undefined,
|
||||
});
|
||||
|
||||
getConfigSpy.mockResolvedValue({
|
||||
app_mode: "oss",
|
||||
@@ -733,10 +704,18 @@ describe("HomepageCTA visibility", () => {
|
||||
|
||||
await screen.findByTestId("home-screen");
|
||||
|
||||
expect(screen.queryByTestId("homepage-cta-learn-more")).not.toBeInTheDocument();
|
||||
expect(screen.queryByText("CTA$ENTERPRISE_TITLE")).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("should not show HomepageCTA in SaaS Self-hosted mode", async () => {
|
||||
it("should not show HomepageCTA when feature flag is disabled", async () => {
|
||||
// Override localStorage to disable the feature flag
|
||||
vi.stubGlobal("localStorage", {
|
||||
getItem: vi.fn(() => null), // No feature flags set
|
||||
setItem: vi.fn(),
|
||||
removeItem: vi.fn(),
|
||||
clear: vi.fn(),
|
||||
});
|
||||
|
||||
useIsAuthedMock.mockReturnValue({
|
||||
data: true,
|
||||
isLoading: false,
|
||||
@@ -744,26 +723,16 @@ describe("HomepageCTA visibility", () => {
|
||||
isError: false,
|
||||
});
|
||||
useConfigMock.mockReturnValue({
|
||||
data: { app_mode: "saas", feature_flags: { ...DEFAULT_FEATURE_FLAGS, deployment_mode: "self_hosted" } },
|
||||
data: { app_mode: "saas", feature_flags: DEFAULT_FEATURE_FLAGS },
|
||||
isLoading: false,
|
||||
});
|
||||
mockUseAppMode.mockReturnValue({
|
||||
isOss: false,
|
||||
isSaas: true,
|
||||
isCloud: false,
|
||||
isSelfHosted: true,
|
||||
isEnterpriseSelfHosted: true,
|
||||
isEnterpriseCloud: false,
|
||||
appMode: "saas",
|
||||
deploymentMode: "self_hosted",
|
||||
});
|
||||
|
||||
getConfigSpy.mockResolvedValue({
|
||||
app_mode: "saas",
|
||||
posthog_client_key: "test-posthog-key",
|
||||
providers_configured: ["github"],
|
||||
auth_url: "https://auth.example.com",
|
||||
feature_flags: { ...DEFAULT_FEATURE_FLAGS, deployment_mode: "self_hosted" },
|
||||
feature_flags: DEFAULT_FEATURE_FLAGS,
|
||||
maintenance_start_time: null,
|
||||
recaptcha_site_key: null,
|
||||
faulty_models: [],
|
||||
@@ -776,7 +745,7 @@ describe("HomepageCTA visibility", () => {
|
||||
|
||||
await screen.findByTestId("home-screen");
|
||||
|
||||
expect(screen.queryByTestId("homepage-cta-learn-more")).not.toBeInTheDocument();
|
||||
expect(screen.queryByText("CTA$ENTERPRISE_TITLE")).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
});
|
||||
|
||||
@@ -73,18 +73,9 @@ vi.mock("#/hooks/use-invitation", () => ({
|
||||
useInvitation: () => useInvitationMock(),
|
||||
}));
|
||||
|
||||
// Mock useAppMode hook - enable CTA by default (SaaS Cloud mode)
|
||||
vi.mock("#/hooks/use-app-mode", () => ({
|
||||
useAppMode: () => ({
|
||||
isOss: false,
|
||||
isSaas: true,
|
||||
isCloud: true,
|
||||
isSelfHosted: false,
|
||||
isEnterpriseSelfHosted: false,
|
||||
isEnterpriseCloud: true,
|
||||
appMode: "saas",
|
||||
deploymentMode: "cloud",
|
||||
}),
|
||||
// Mock feature flags - enable by default for tests
|
||||
vi.mock("#/utils/feature-flags", () => ({
|
||||
ENABLE_PROJ_USER_JOURNEY: () => true,
|
||||
}));
|
||||
|
||||
const RouterStub = createRoutesStub([
|
||||
|
||||
@@ -0,0 +1,9 @@
|
||||
import { test, expect } from "vitest";
|
||||
import { isNumber } from "../../src/utils/is-number";
|
||||
|
||||
test("isNumber", () => {
|
||||
expect(isNumber(1)).toBe(true);
|
||||
expect(isNumber(0)).toBe(true);
|
||||
expect(isNumber("3")).toBe(true);
|
||||
expect(isNumber("0")).toBe(true);
|
||||
});
|
||||
@@ -38,7 +38,6 @@ vi.mock("#/query-client-config", () => ({
|
||||
queryClient: {
|
||||
getQueryData: vi.fn(() => mockConfig),
|
||||
setQueryData: vi.fn(),
|
||||
fetchQuery: vi.fn(() => Promise.resolve(mockConfig)),
|
||||
},
|
||||
}));
|
||||
|
||||
|
||||
@@ -1,4 +1,9 @@
|
||||
import { describe, it, expect, vi, test } from "vitest";
|
||||
import {
|
||||
formatTimestamp,
|
||||
getExtension,
|
||||
removeApiKey,
|
||||
} from "../../src/utils/utils";
|
||||
import { getStatusText } from "#/utils/utils";
|
||||
import { AgentState } from "#/types/agent-state";
|
||||
import { I18nKey } from "#/i18n/declaration";
|
||||
@@ -17,6 +22,25 @@ const t = (key: string) => {
|
||||
return translations[key] || key;
|
||||
};
|
||||
|
||||
test("removeApiKey", () => {
|
||||
const data = [{ args: { LLM_API_KEY: "key", LANGUAGE: "en" } }];
|
||||
expect(removeApiKey(data)).toEqual([{ args: { LANGUAGE: "en" } }]);
|
||||
});
|
||||
|
||||
test("getExtension", () => {
|
||||
expect(getExtension("main.go")).toBe("go");
|
||||
expect(getExtension("get-extension.test.ts")).toBe("ts");
|
||||
expect(getExtension("directory")).toBe("");
|
||||
});
|
||||
|
||||
test("formatTimestamp", () => {
|
||||
const morningDate = new Date("2021-10-10T10:10:10.000").toISOString();
|
||||
expect(formatTimestamp(morningDate)).toBe("10/10/2021, 10:10:10");
|
||||
|
||||
const eveningDate = new Date("2021-10-10T22:10:10.000").toISOString();
|
||||
expect(formatTimestamp(eveningDate)).toBe("10/10/2021, 22:10:10");
|
||||
});
|
||||
|
||||
describe("getStatusText", () => {
|
||||
it("returns STOPPING when pausing", () => {
|
||||
const result = getStatusText({
|
||||
|
||||
@@ -1,43 +0,0 @@
|
||||
import { openHands } from "../open-hands-axios";
|
||||
import type {
|
||||
LLMModelPage,
|
||||
ProviderPage,
|
||||
SearchModelsParams,
|
||||
SearchProvidersParams,
|
||||
} from "./config-service.types";
|
||||
|
||||
function toSearchParams(
|
||||
params: SearchModelsParams | SearchProvidersParams,
|
||||
): string {
|
||||
const searchParams = new URLSearchParams();
|
||||
for (const [key, value] of Object.entries(params)) {
|
||||
if (value !== undefined && value !== null) {
|
||||
searchParams.append(key, String(value));
|
||||
}
|
||||
}
|
||||
return searchParams.toString();
|
||||
}
|
||||
|
||||
class ConfigService {
|
||||
static async searchModels(
|
||||
params: SearchModelsParams = {},
|
||||
): Promise<LLMModelPage> {
|
||||
const qs = toSearchParams(params);
|
||||
const { data } = await openHands.get<LLMModelPage>(
|
||||
`/api/v1/config/models/search?${qs}`,
|
||||
);
|
||||
return data;
|
||||
}
|
||||
|
||||
static async searchProviders(
|
||||
params: SearchProvidersParams = {},
|
||||
): Promise<ProviderPage> {
|
||||
const qs = toSearchParams(params);
|
||||
const { data } = await openHands.get<ProviderPage>(
|
||||
`/api/v1/config/providers/search?${qs}`,
|
||||
);
|
||||
return data;
|
||||
}
|
||||
}
|
||||
|
||||
export default ConfigService;
|
||||
@@ -1,37 +0,0 @@
|
||||
/** V1 Config API types for models and providers */
|
||||
|
||||
export interface LLMModel {
|
||||
provider: string | null;
|
||||
name: string;
|
||||
verified: boolean;
|
||||
}
|
||||
|
||||
export interface LLMModelPage {
|
||||
items: LLMModel[];
|
||||
next_page_id: string | null;
|
||||
}
|
||||
|
||||
export interface SearchModelsParams {
|
||||
page_id?: string;
|
||||
limit?: number;
|
||||
query?: string;
|
||||
verified__eq?: boolean;
|
||||
provider__eq?: string;
|
||||
}
|
||||
|
||||
export interface LLMProvider {
|
||||
name: string;
|
||||
verified: boolean;
|
||||
}
|
||||
|
||||
export interface ProviderPage {
|
||||
items: LLMProvider[];
|
||||
next_page_id: string | null;
|
||||
}
|
||||
|
||||
export interface SearchProvidersParams {
|
||||
page_id?: string;
|
||||
limit?: number;
|
||||
query?: string;
|
||||
verified__eq?: boolean;
|
||||
}
|
||||
@@ -1,5 +1,7 @@
|
||||
import { openHands } from "../open-hands-axios";
|
||||
import { RepositoryPage, BranchPage, InstallationPage } from "#/types/git";
|
||||
import { Provider } from "#/types/settings";
|
||||
import { GitRepository, PaginatedBranchesResponse, Branch } from "#/types/git";
|
||||
import { extractNextPageFromLink } from "#/utils/extract-next-page-from-link";
|
||||
import { GitChange, GitChangeDiff } from "../open-hands.types";
|
||||
import ConversationService from "../conversation-service/conversation-service.api";
|
||||
|
||||
@@ -10,122 +12,130 @@ class GitService {
|
||||
/**
|
||||
* Search for Git repositories
|
||||
* @param query Search query
|
||||
* @param provider Git provider to search in (required)
|
||||
* @param limit Number of results per page
|
||||
* @param pageId Cursor for pagination
|
||||
* @param installationId Filter by installation ID
|
||||
* @param sortOrder Sort order (asc or desc)
|
||||
* @returns Paginated repository response
|
||||
* @param per_page Number of results per page
|
||||
* @param selected_provider Git provider to search in
|
||||
* @returns List of matching repositories
|
||||
*/
|
||||
static async searchGitRepositories(
|
||||
query: string,
|
||||
provider: string,
|
||||
limit = 100,
|
||||
pageId?: string,
|
||||
installationId?: string,
|
||||
): Promise<RepositoryPage> {
|
||||
const { data } = await openHands.get<RepositoryPage>(
|
||||
"/api/v1/git/repositories/search",
|
||||
per_page = 100,
|
||||
selected_provider?: Provider,
|
||||
): Promise<GitRepository[]> {
|
||||
const response = await openHands.get<GitRepository[]>(
|
||||
"/api/user/search/repositories",
|
||||
{
|
||||
params: {
|
||||
provider,
|
||||
query,
|
||||
limit,
|
||||
page_id: pageId,
|
||||
installation_id: installationId,
|
||||
per_page,
|
||||
selected_provider,
|
||||
},
|
||||
},
|
||||
);
|
||||
|
||||
return data;
|
||||
return response.data;
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieve user's Git repositories
|
||||
* @param provider Git provider
|
||||
* @param pageId Cursor for pagination
|
||||
* @param limit Number of results per page
|
||||
* @param installationId Filter by installation ID
|
||||
* @param sortOrder Sort order (asc or desc)
|
||||
* @param selected_provider Git provider
|
||||
* @param page Page number
|
||||
* @param per_page Number of results per page
|
||||
* @returns User's repositories with pagination info
|
||||
*/
|
||||
static async retrieveUserGitRepositories(
|
||||
provider: string,
|
||||
pageId?: string,
|
||||
limit = 30,
|
||||
installationId?: string,
|
||||
): Promise<RepositoryPage> {
|
||||
const { data } = await openHands.get<RepositoryPage>(
|
||||
"/api/v1/git/repositories/search",
|
||||
selected_provider: Provider,
|
||||
page = 1,
|
||||
per_page = 30,
|
||||
) {
|
||||
const { data } = await openHands.get<GitRepository[]>(
|
||||
"/api/user/repositories",
|
||||
{
|
||||
params: {
|
||||
provider,
|
||||
limit,
|
||||
page_id: pageId,
|
||||
installation_id: installationId,
|
||||
selected_provider,
|
||||
sort: "pushed",
|
||||
page,
|
||||
per_page,
|
||||
},
|
||||
},
|
||||
);
|
||||
|
||||
return data;
|
||||
const link =
|
||||
data.length > 0 && data[0].link_header ? data[0].link_header : "";
|
||||
const nextPage = extractNextPageFromLink(link);
|
||||
|
||||
return { data, nextPage };
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieve repositories from a specific installation
|
||||
* @param provider Git provider
|
||||
* @param selected_provider Git provider
|
||||
* @param installationIndex Current installation index
|
||||
* @param installations List of installation IDs
|
||||
* @param pageId Cursor for pagination
|
||||
* @param limit Number of results per page
|
||||
* @param page Page number
|
||||
* @param per_page Number of results per page
|
||||
* @returns Installation repositories with pagination info
|
||||
*/
|
||||
static async retrieveInstallationRepositories(
|
||||
provider: string,
|
||||
selected_provider: Provider,
|
||||
installationIndex: number,
|
||||
installations: string[],
|
||||
pageId?: string,
|
||||
limit = 30,
|
||||
): Promise<RepositoryPage> {
|
||||
page = 1,
|
||||
per_page = 30,
|
||||
) {
|
||||
const installationId = installations[installationIndex];
|
||||
const { data } = await openHands.get<RepositoryPage>(
|
||||
"/api/v1/git/repositories/search",
|
||||
const response = await openHands.get<GitRepository[]>(
|
||||
"/api/user/repositories",
|
||||
{
|
||||
params: {
|
||||
provider,
|
||||
limit,
|
||||
page_id: pageId,
|
||||
selected_provider,
|
||||
sort: "pushed",
|
||||
page,
|
||||
per_page,
|
||||
installation_id: installationId,
|
||||
},
|
||||
},
|
||||
);
|
||||
return data;
|
||||
const link =
|
||||
response.data.length > 0 && response.data[0].link_header
|
||||
? response.data[0].link_header
|
||||
: "";
|
||||
const nextPage = extractNextPageFromLink(link);
|
||||
let nextInstallation: number | null;
|
||||
if (nextPage) {
|
||||
nextInstallation = installationIndex;
|
||||
} else if (installationIndex + 1 < installations.length) {
|
||||
nextInstallation = installationIndex + 1;
|
||||
} else {
|
||||
nextInstallation = null;
|
||||
}
|
||||
return {
|
||||
data: response.data,
|
||||
nextPage,
|
||||
installationIndex: nextInstallation,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Get repository branches
|
||||
* @param repository Repository name
|
||||
* @param provider Git provider (required)
|
||||
* @param query Search query (required - can be empty string)
|
||||
* @param pageId Cursor for pagination
|
||||
* @param limit Number of results per page
|
||||
* @param page Page number
|
||||
* @param perPage Number of results per page
|
||||
* @returns Paginated branches response
|
||||
*/
|
||||
static async getRepositoryBranches(
|
||||
repository: string,
|
||||
provider: string,
|
||||
query: string = "",
|
||||
pageId?: string,
|
||||
limit = 30,
|
||||
): Promise<BranchPage> {
|
||||
const { data } = await openHands.get<BranchPage>(
|
||||
"/api/v1/git/branches/search",
|
||||
page: number = 1,
|
||||
perPage: number = 30,
|
||||
selectedProvider?: Provider,
|
||||
): Promise<PaginatedBranchesResponse> {
|
||||
const { data } = await openHands.get<PaginatedBranchesResponse>(
|
||||
`/api/user/repository/branches`,
|
||||
{
|
||||
params: {
|
||||
provider,
|
||||
repository,
|
||||
query,
|
||||
page_id: pageId,
|
||||
limit,
|
||||
page,
|
||||
per_page: perPage,
|
||||
selected_provider: selectedProvider,
|
||||
},
|
||||
},
|
||||
);
|
||||
@@ -135,51 +145,40 @@ class GitService {
|
||||
|
||||
/**
|
||||
* Search repository branches
|
||||
* @deprecated Use getRepositoryBranches instead - this method is identical
|
||||
* @param repository Repository name
|
||||
* @param provider Git provider (required)
|
||||
* @param query Search query
|
||||
* @param pageId Cursor for pagination
|
||||
* @param limit Number of results per page
|
||||
* @param perPage Number of results per page
|
||||
* @param selectedProvider Git provider
|
||||
* @returns List of matching branches
|
||||
*/
|
||||
static async searchRepositoryBranches(
|
||||
repository: string,
|
||||
provider: string,
|
||||
query: string,
|
||||
pageId?: string,
|
||||
limit = 30,
|
||||
): Promise<BranchPage> {
|
||||
return this.getRepositoryBranches(
|
||||
repository,
|
||||
provider,
|
||||
query,
|
||||
pageId,
|
||||
limit,
|
||||
perPage: number = 30,
|
||||
selectedProvider?: Provider,
|
||||
): Promise<Branch[]> {
|
||||
const { data } = await openHands.get<Branch[]>(
|
||||
`/api/user/search/branches`,
|
||||
{
|
||||
params: {
|
||||
repository,
|
||||
query,
|
||||
per_page: perPage,
|
||||
selected_provider: selectedProvider,
|
||||
},
|
||||
},
|
||||
);
|
||||
return data;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the user installation IDs
|
||||
* @param provider The provider to get installation IDs for (github, bitbucket, etc.)
|
||||
* @param pageId Cursor for pagination
|
||||
* @param limit Max number of results
|
||||
* @returns Paginated installation response
|
||||
* @returns List of installation IDs
|
||||
*/
|
||||
static async getUserInstallations(
|
||||
provider: string,
|
||||
pageId?: string,
|
||||
limit = 100,
|
||||
): Promise<InstallationPage> {
|
||||
const { data } = await openHands.get<InstallationPage>(
|
||||
"/api/v1/git/installations/search",
|
||||
{
|
||||
params: {
|
||||
provider,
|
||||
page_id: pageId,
|
||||
limit,
|
||||
},
|
||||
},
|
||||
static async getUserInstallationIds(provider: Provider): Promise<string[]> {
|
||||
const { data } = await openHands.get<string[]>(
|
||||
`/api/user/installations?provider=${provider}`,
|
||||
);
|
||||
return data;
|
||||
}
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
import { Provider } from "#/types/settings";
|
||||
|
||||
export type DeploymentMode = "cloud" | "self_hosted";
|
||||
|
||||
/**
|
||||
* Structured response from ``GET /api/options/models``.
|
||||
*
|
||||
@@ -28,7 +26,6 @@ export interface WebClientFeatureFlags {
|
||||
hide_users_page: boolean;
|
||||
hide_billing_page: boolean;
|
||||
hide_integrations_page: boolean;
|
||||
deployment_mode?: DeploymentMode;
|
||||
}
|
||||
|
||||
export interface WebClientConfig {
|
||||
|
||||
@@ -9,7 +9,7 @@ class SettingsService {
|
||||
* Get the settings from the server or use the default settings if not found
|
||||
*/
|
||||
static async getSettings(): Promise<Settings> {
|
||||
const { data } = await openHands.get<Settings>("/api/v1/settings");
|
||||
const { data } = await openHands.get<Settings>("/api/settings");
|
||||
return data;
|
||||
}
|
||||
|
||||
@@ -18,7 +18,7 @@ class SettingsService {
|
||||
* @param settings - the settings to save
|
||||
*/
|
||||
static async saveSettings(settings: Partial<Settings>): Promise<boolean> {
|
||||
const data = await openHands.post("/api/v1/settings", settings);
|
||||
const data = await openHands.post("/api/settings", settings);
|
||||
return data.status === 200;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,8 +14,8 @@ import { useRecaptcha } from "#/hooks/use-recaptcha";
|
||||
import { useConfig } from "#/hooks/query/use-config";
|
||||
import { displayErrorToast } from "#/utils/custom-toast-handlers";
|
||||
import { cn } from "#/utils/utils";
|
||||
import { ENABLE_PROJ_USER_JOURNEY } from "#/utils/feature-flags";
|
||||
import { LoginCTA } from "./login-cta";
|
||||
import { useAppMode } from "#/hooks/use-app-mode";
|
||||
|
||||
export interface LoginContentProps {
|
||||
githubAuthUrl: string | null;
|
||||
@@ -45,7 +45,6 @@ export function LoginContent({
|
||||
const { t } = useTranslation();
|
||||
const { trackLoginButtonClick } = useTracking();
|
||||
const { data: config } = useConfig();
|
||||
const { isEnterpriseCloud } = useAppMode();
|
||||
|
||||
// reCAPTCHA - only need token generation, verification happens at backend callback
|
||||
const { isReady: recaptchaReady, executeRecaptcha } = useRecaptcha({
|
||||
@@ -307,7 +306,7 @@ export function LoginContent({
|
||||
<TermsAndPrivacyNotice className="max-w-[320px] text-[#A3A3A3]" />
|
||||
</div>
|
||||
|
||||
{isEnterpriseCloud && <LoginCTA />}
|
||||
{appMode === "saas" && ENABLE_PROJ_USER_JOURNEY() && <LoginCTA />}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -0,0 +1,69 @@
|
||||
import { useTranslation } from "react-i18next";
|
||||
import { usePostHog } from "posthog-js/react";
|
||||
import { I18nKey } from "#/i18n/declaration";
|
||||
import { H2, Text } from "#/ui/typography";
|
||||
import CheckCircleFillIcon from "#/icons/check-circle-fill.svg?react";
|
||||
import { ENABLE_PROJ_USER_JOURNEY } from "#/utils/feature-flags";
|
||||
|
||||
const ENTERPRISE_FEATURE_KEYS: I18nKey[] = [
|
||||
I18nKey.ENTERPRISE$FEATURE_DATA_PRIVACY,
|
||||
I18nKey.ENTERPRISE$FEATURE_DEPLOYMENT,
|
||||
I18nKey.ENTERPRISE$FEATURE_SSO,
|
||||
I18nKey.ENTERPRISE$FEATURE_SUPPORT,
|
||||
];
|
||||
|
||||
export function EnterpriseBanner() {
|
||||
const { t } = useTranslation();
|
||||
const posthog = usePostHog();
|
||||
|
||||
if (!ENABLE_PROJ_USER_JOURNEY()) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const handleLearnMore = () => {
|
||||
posthog?.capture("saas_selfhosted_inquiry");
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="w-full max-w-md mx-auto lg:mx-0 lg:w-80 p-6 rounded-lg bg-gradient-to-b from-slate-800 to-slate-900 border border-slate-700 h-fit">
|
||||
{/* Self-Hosted Label */}
|
||||
<div className="flex justify-center mb-4">
|
||||
<div className="px-8 py-0.5 rounded-full bg-gradient-to-r from-blue-900 to-blue-950 border border-blue-800">
|
||||
<Text className="text-xs font-medium text-blue-400 tracking-wider uppercase">
|
||||
{t(I18nKey.ENTERPRISE$SELF_HOSTED)}
|
||||
</Text>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Title */}
|
||||
<H2 className="text-center mb-3">{t(I18nKey.ENTERPRISE$TITLE)}</H2>
|
||||
|
||||
{/* Description */}
|
||||
<Text className="text-sm text-gray-400 text-center mb-6 block">
|
||||
{t(I18nKey.ENTERPRISE$DESCRIPTION)}
|
||||
</Text>
|
||||
|
||||
{/* Features List */}
|
||||
<ul className="space-y-3 mb-6">
|
||||
{ENTERPRISE_FEATURE_KEYS.map((featureKey) => (
|
||||
<li key={featureKey} className="flex items-center gap-2">
|
||||
<CheckCircleFillIcon className="w-4 h-4 text-blue-400 flex-shrink-0" />
|
||||
<Text className="text-sm text-gray-300">{t(featureKey)}</Text>
|
||||
</li>
|
||||
))}
|
||||
</ul>
|
||||
|
||||
{/* Learn More Button */}
|
||||
<a
|
||||
href="https://openhands.dev/enterprise"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
onClick={handleLearnMore}
|
||||
aria-label={t(I18nKey.ENTERPRISE$LEARN_MORE_ARIA)}
|
||||
className="block w-full py-2.5 px-4 rounded-lg bg-blue-600 hover:bg-blue-700 text-white font-medium transition-colors text-center"
|
||||
>
|
||||
{t(I18nKey.ENTERPRISE$LEARN_MORE)}
|
||||
</a>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -38,7 +38,7 @@ export function useRepositoryData(
|
||||
|
||||
// Combine all repositories from paginated data
|
||||
const allRepositories = useMemo(
|
||||
() => repoData?.pages?.flatMap((page) => page.items) || [],
|
||||
() => repoData?.pages?.flatMap((page) => page.data) || [],
|
||||
[repoData],
|
||||
);
|
||||
|
||||
|
||||
@@ -18,11 +18,11 @@ export function useUrlSearch(inputValue: string, provider: Provider) {
|
||||
try {
|
||||
const repositories = await GitService.searchGitRepositories(
|
||||
repoName,
|
||||
provider,
|
||||
3,
|
||||
provider,
|
||||
);
|
||||
|
||||
setUrlSearchResults(repositories.items);
|
||||
setUrlSearchResults(repositories);
|
||||
} catch {
|
||||
setUrlSearchResults([]);
|
||||
} finally {
|
||||
|
||||
@@ -6,7 +6,6 @@ import { UserActions } from "./user-actions";
|
||||
import { OpenHandsLogoButton } from "#/components/shared/buttons/openhands-logo-button";
|
||||
import { NewProjectButton } from "#/components/shared/buttons/new-project-button";
|
||||
import { ConversationPanelButton } from "#/components/shared/buttons/conversation-panel-button";
|
||||
import { AutomationsButton } from "#/components/shared/buttons/automations-button";
|
||||
import { SettingsModal } from "#/components/shared/modals/settings/settings-modal";
|
||||
import { useSettings } from "#/hooks/query/use-settings";
|
||||
import { ConversationPanel } from "../conversation-panel/conversation-panel";
|
||||
@@ -15,7 +14,6 @@ import { useConfig } from "#/hooks/query/use-config";
|
||||
import { displayErrorToast } from "#/utils/custom-toast-handlers";
|
||||
import { I18nKey } from "#/i18n/declaration";
|
||||
import { cn } from "#/utils/utils";
|
||||
import { ENABLE_AUTOMATIONS } from "#/utils/feature-flags";
|
||||
|
||||
export function Sidebar() {
|
||||
const { t } = useTranslation();
|
||||
@@ -89,11 +87,6 @@ export function Sidebar() {
|
||||
}
|
||||
disabled={settings?.email_verified === false}
|
||||
/>
|
||||
{ENABLE_AUTOMATIONS() && (
|
||||
<AutomationsButton
|
||||
disabled={settings?.email_verified === false}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div className="flex flex-row md:flex-col md:items-center gap-[26px]">
|
||||
|
||||
@@ -15,9 +15,10 @@ import { ContextMenuCTA } from "../context-menu/context-menu-cta";
|
||||
import { ContextMenuNavLink } from "../context-menu/context-menu-nav-link";
|
||||
import { useShouldHideOrgSelector } from "#/hooks/use-should-hide-org-selector";
|
||||
import { useBreakpoint } from "#/hooks/use-breakpoint";
|
||||
import { useConfig } from "#/hooks/query/use-config";
|
||||
import { ENABLE_PROJ_USER_JOURNEY } from "#/utils/feature-flags";
|
||||
import { SettingsNavHeader } from "../settings/settings-nav-header";
|
||||
import { SettingsNavDivider } from "../settings/settings-nav-divider";
|
||||
import { useAppMode } from "#/hooks/use-app-mode";
|
||||
|
||||
// Shared className for context menu list items in the user context menu
|
||||
const contextMenuListItemClassName = cn(
|
||||
@@ -41,12 +42,13 @@ export function UserContextMenu({
|
||||
const settingsNavItems = useSettingsNavItems();
|
||||
const shouldHideSelector = useShouldHideOrgSelector();
|
||||
const isMobile = useBreakpoint(768);
|
||||
const { isSaas, isEnterpriseCloud } = useAppMode();
|
||||
const { data: config } = useConfig();
|
||||
|
||||
// Keep all nav items including headers and dividers for proper section grouping
|
||||
const navItems = settingsNavItems;
|
||||
|
||||
const isMember = type === "member";
|
||||
const isSaasMode = config?.app_mode === "saas";
|
||||
|
||||
// Check if the ORG SETTINGS header exists in nav items
|
||||
const hasOrgHeader = navItems.some(
|
||||
@@ -58,9 +60,8 @@ export function UserContextMenu({
|
||||
// Show invite button for admin/owner in team orgs
|
||||
const showInviteButton = !isMember && !isPersonalOrg;
|
||||
|
||||
// CTA only renders in SaaS Cloud desktop mode
|
||||
const isCTAEnabled = isEnterpriseCloud && !isMobile;
|
||||
|
||||
// CTA only renders in SaaS desktop with feature flag enabled
|
||||
const showCta = isSaasMode && !isMobile && ENABLE_PROJ_USER_JOURNEY();
|
||||
const handleLogout = () => {
|
||||
logout();
|
||||
onClose();
|
||||
@@ -154,7 +155,7 @@ export function UserContextMenu({
|
||||
</a>
|
||||
|
||||
{/* Only show logout in saas mode - oss mode has no session to invalidate */}
|
||||
{isSaas && (
|
||||
{isSaasMode && (
|
||||
<ContextMenuListItem
|
||||
onClick={handleLogout}
|
||||
className={contextMenuListItemClassName}
|
||||
@@ -166,7 +167,7 @@ export function UserContextMenu({
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{isCTAEnabled && <ContextMenuCTA />}
|
||||
{showCta && <ContextMenuCTA />}
|
||||
</ContextMenuContainer>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
import React from "react";
|
||||
import { PostHogProvider } from "posthog-js/react";
|
||||
import { queryClient } from "#/query-client-config";
|
||||
import OptionService from "#/api/option-service/option-service.api";
|
||||
import { QUERY_KEYS, CONFIG_CACHE_OPTIONS } from "#/hooks/query/query-keys";
|
||||
import { displayErrorToast } from "#/utils/custom-toast-handlers";
|
||||
|
||||
const POSTHOG_BOOTSTRAP_KEY = "posthog_bootstrap";
|
||||
@@ -49,12 +47,7 @@ export function PostHogWrapper({ children }: { children: React.ReactNode }) {
|
||||
React.useEffect(() => {
|
||||
(async () => {
|
||||
try {
|
||||
// Use fetchQuery for automatic caching and deduplication
|
||||
const config = await queryClient.fetchQuery({
|
||||
queryKey: QUERY_KEYS.WEB_CLIENT_CONFIG,
|
||||
queryFn: OptionService.getConfig,
|
||||
...CONFIG_CACHE_OPTIONS,
|
||||
});
|
||||
const config = await OptionService.getConfig();
|
||||
setPosthogClientKey(config.posthog_client_key);
|
||||
} catch {
|
||||
displayErrorToast("Error fetching PostHog client key");
|
||||
|
||||
@@ -1,38 +0,0 @@
|
||||
import { useTranslation } from "react-i18next";
|
||||
import { I18nKey } from "#/i18n/declaration";
|
||||
import { StyledTooltip } from "#/components/shared/buttons/styled-tooltip";
|
||||
import AutomationsIcon from "#/icons/automations.svg?react";
|
||||
import { cn } from "#/utils/utils";
|
||||
|
||||
interface AutomationsButtonProps {
|
||||
disabled?: boolean;
|
||||
}
|
||||
|
||||
export function AutomationsButton({
|
||||
disabled = false,
|
||||
}: AutomationsButtonProps) {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const label = t(I18nKey.SIDEBAR$AUTOMATIONS);
|
||||
|
||||
return (
|
||||
<StyledTooltip content={label} placement="right">
|
||||
<a
|
||||
href="/automations"
|
||||
data-testid="automations-button"
|
||||
aria-label={label}
|
||||
tabIndex={disabled ? -1 : 0}
|
||||
onClick={(e) => {
|
||||
if (disabled) {
|
||||
e.preventDefault();
|
||||
}
|
||||
}}
|
||||
className={cn("inline-flex items-center justify-center", {
|
||||
"pointer-events-none opacity-50": disabled,
|
||||
})}
|
||||
>
|
||||
<AutomationsIcon width={24} height={24} />
|
||||
</a>
|
||||
</StyledTooltip>
|
||||
);
|
||||
}
|
||||
@@ -11,11 +11,14 @@ import { extractModelAndProvider } from "#/utils/extract-model-and-provider";
|
||||
import { cn } from "#/utils/utils";
|
||||
import { HelpLink } from "#/ui/help-link";
|
||||
import { PRODUCT_URL } from "#/utils/constants";
|
||||
import { useSearchProviders } from "#/hooks/query/use-search-providers";
|
||||
import { useProviderModels } from "#/hooks/query/use-provider-models";
|
||||
|
||||
interface ModelSelectorProps {
|
||||
isDisabled?: boolean;
|
||||
models: Record<string, { separator: string; models: string[] }>;
|
||||
/** Model names (no provider prefix) the backend considers verified. */
|
||||
verifiedModels: string[];
|
||||
/** Provider names the backend considers verified. */
|
||||
verifiedProviders: string[];
|
||||
currentModel?: string;
|
||||
onChange?: (provider: string | null, model: string | null) => void;
|
||||
onDefaultValuesChanged?: (
|
||||
@@ -28,6 +31,9 @@ interface ModelSelectorProps {
|
||||
|
||||
export function ModelSelector({
|
||||
isDisabled,
|
||||
models,
|
||||
verifiedModels,
|
||||
verifiedProviders,
|
||||
currentModel,
|
||||
onChange,
|
||||
onDefaultValuesChanged,
|
||||
@@ -40,51 +46,30 @@ export function ModelSelector({
|
||||
);
|
||||
const [selectedModel, setSelectedModel] = React.useState<string | null>(null);
|
||||
|
||||
const { data: providers = [] } = useSearchProviders();
|
||||
const {
|
||||
data: providerModels = [],
|
||||
isLoading: isLoadingModels,
|
||||
error: modelsError,
|
||||
} = useProviderModels(selectedProvider);
|
||||
|
||||
const verifiedProviders = React.useMemo(
|
||||
() => providers.filter((p) => p.verified),
|
||||
[providers],
|
||||
);
|
||||
const unverifiedProviders = React.useMemo(
|
||||
() => providers.filter((p) => !p.verified),
|
||||
[providers],
|
||||
);
|
||||
|
||||
const verifiedModels = React.useMemo(
|
||||
() => providerModels.filter((m) => m.verified),
|
||||
[providerModels],
|
||||
);
|
||||
const unverifiedModels = React.useMemo(
|
||||
() => providerModels.filter((m) => !m.verified),
|
||||
[providerModels],
|
||||
);
|
||||
|
||||
React.useEffect(() => {
|
||||
if (currentModel) {
|
||||
// runs when resetting to defaults
|
||||
const { provider, model } = extractModelAndProvider(currentModel);
|
||||
|
||||
setLitellmId(currentModel);
|
||||
setSelectedProvider(provider || null);
|
||||
setSelectedProvider(provider);
|
||||
setSelectedModel(model);
|
||||
onDefaultValuesChanged?.(provider || null, model);
|
||||
onDefaultValuesChanged?.(provider, model);
|
||||
}
|
||||
}, [currentModel]);
|
||||
|
||||
const handleChangeProvider = (provider: string) => {
|
||||
setSelectedProvider(provider);
|
||||
setSelectedModel(null);
|
||||
setLitellmId(`${provider}/`);
|
||||
|
||||
const separator = models[provider]?.separator || "";
|
||||
setLitellmId(provider + separator);
|
||||
onChange?.(provider, null);
|
||||
};
|
||||
|
||||
const handleChangeModel = (model: string) => {
|
||||
let fullModel = `${selectedProvider}/${model}`;
|
||||
const separator = models[selectedProvider || ""]?.separator || "";
|
||||
let fullModel = selectedProvider + separator + model;
|
||||
if (selectedProvider === "openai") {
|
||||
// LiteLLM lists OpenAI models without the openai/ prefix
|
||||
fullModel = model;
|
||||
@@ -138,22 +123,28 @@ export function ModelSelector({
|
||||
}}
|
||||
>
|
||||
<AutocompleteSection title={t(I18nKey.MODEL_SELECTOR$VERIFIED)}>
|
||||
{verifiedProviders.map((provider) => (
|
||||
<AutocompleteItem
|
||||
data-testid={`provider-item-${provider.name}`}
|
||||
key={provider.name}
|
||||
>
|
||||
{mapProvider(provider.name)}
|
||||
</AutocompleteItem>
|
||||
))}
|
||||
</AutocompleteSection>
|
||||
{unverifiedProviders.length > 0 ? (
|
||||
<AutocompleteSection title={t(I18nKey.MODEL_SELECTOR$OTHERS)}>
|
||||
{unverifiedProviders.map((provider) => (
|
||||
<AutocompleteItem key={provider.name}>
|
||||
{mapProvider(provider.name)}
|
||||
{verifiedProviders
|
||||
.filter((provider) => models[provider])
|
||||
.map((provider) => (
|
||||
<AutocompleteItem
|
||||
data-testid={`provider-item-${provider}`}
|
||||
key={provider}
|
||||
>
|
||||
{mapProvider(provider)}
|
||||
</AutocompleteItem>
|
||||
))}
|
||||
</AutocompleteSection>
|
||||
{Object.keys(models).some(
|
||||
(provider) => !verifiedProviders.includes(provider),
|
||||
) ? (
|
||||
<AutocompleteSection title={t(I18nKey.MODEL_SELECTOR$OTHERS)}>
|
||||
{Object.keys(models)
|
||||
.filter((provider) => !verifiedProviders.includes(provider))
|
||||
.map((provider) => (
|
||||
<AutocompleteItem key={provider}>
|
||||
{mapProvider(provider)}
|
||||
</AutocompleteItem>
|
||||
))}
|
||||
</AutocompleteSection>
|
||||
) : null}
|
||||
</Autocomplete>
|
||||
@@ -178,7 +169,6 @@ export function ModelSelector({
|
||||
data-testid="llm-model-input"
|
||||
isRequired
|
||||
isVirtualized={false}
|
||||
isLoading={isLoadingModels}
|
||||
name="llm-model-input"
|
||||
aria-label={t(I18nKey.LLM$MODEL)}
|
||||
placeholder={t(I18nKey.LLM$SELECT_MODEL_PLACEHOLDER)}
|
||||
@@ -200,28 +190,31 @@ export function ModelSelector({
|
||||
}}
|
||||
>
|
||||
<AutocompleteSection title={t(I18nKey.MODEL_SELECTOR$VERIFIED)}>
|
||||
{verifiedModels.map((model) => (
|
||||
<AutocompleteItem key={model.name}>{model.name}</AutocompleteItem>
|
||||
))}
|
||||
</AutocompleteSection>
|
||||
{unverifiedModels.length > 0 ? (
|
||||
<AutocompleteSection title={t(I18nKey.MODEL_SELECTOR$OTHERS)}>
|
||||
{unverifiedModels.map((model) => (
|
||||
<AutocompleteItem
|
||||
data-testid={`model-item-${model.name}`}
|
||||
key={model.name}
|
||||
>
|
||||
{model.name}
|
||||
</AutocompleteItem>
|
||||
{verifiedModels
|
||||
.filter((model) =>
|
||||
models[selectedProvider || ""]?.models?.includes(model),
|
||||
)
|
||||
.map((model) => (
|
||||
<AutocompleteItem key={model}>{model}</AutocompleteItem>
|
||||
))}
|
||||
</AutocompleteSection>
|
||||
{models[selectedProvider || ""]?.models?.some(
|
||||
(model) => !verifiedModels.includes(model),
|
||||
) ? (
|
||||
<AutocompleteSection title={t(I18nKey.MODEL_SELECTOR$OTHERS)}>
|
||||
{models[selectedProvider || ""]?.models
|
||||
.filter((model) => !verifiedModels.includes(model))
|
||||
.map((model) => (
|
||||
<AutocompleteItem
|
||||
data-testid={`model-item-${model}`}
|
||||
key={model}
|
||||
>
|
||||
{model}
|
||||
</AutocompleteItem>
|
||||
))}
|
||||
</AutocompleteSection>
|
||||
) : null}
|
||||
</Autocomplete>
|
||||
{modelsError && (
|
||||
<p data-testid="models-error" className="text-danger text-xs">
|
||||
{t(I18nKey.CONFIGURATION$ERROR_FETCH_MODELS)}
|
||||
</p>
|
||||
)}
|
||||
</fieldset>
|
||||
</div>
|
||||
);
|
||||
|
||||
@@ -3,6 +3,7 @@ import { useTranslation } from "react-i18next";
|
||||
import React from "react";
|
||||
import { usePostHog } from "posthog-js/react";
|
||||
import { I18nKey } from "#/i18n/declaration";
|
||||
import { organizeModelsAndProviders } from "#/utils/organize-models-and-providers";
|
||||
import { DangerModal } from "../confirmation-modals/danger-modal";
|
||||
import { extractSettings } from "#/utils/settings-utils";
|
||||
import { ModalBackdrop } from "../modal-backdrop";
|
||||
@@ -16,10 +17,19 @@ import { SETTINGS_FORM } from "#/utils/constants";
|
||||
|
||||
interface SettingsFormProps {
|
||||
settings: Settings;
|
||||
models: string[];
|
||||
verifiedModels: string[];
|
||||
verifiedProviders: string[];
|
||||
onClose: () => void;
|
||||
}
|
||||
|
||||
export function SettingsForm({ settings, onClose }: SettingsFormProps) {
|
||||
export function SettingsForm({
|
||||
settings,
|
||||
models,
|
||||
verifiedModels,
|
||||
verifiedProviders,
|
||||
onClose,
|
||||
}: SettingsFormProps) {
|
||||
const posthog = usePostHog();
|
||||
const { mutate: saveUserSettings } = useSaveSettings();
|
||||
|
||||
@@ -77,6 +87,9 @@ export function SettingsForm({ settings, onClose }: SettingsFormProps) {
|
||||
>
|
||||
<div className="flex flex-col gap-[17px]">
|
||||
<ModelSelector
|
||||
models={organizeModelsAndProviders(models)}
|
||||
verifiedModels={verifiedModels}
|
||||
verifiedProviders={verifiedProviders}
|
||||
currentModel={settings.llm_model}
|
||||
wrapperClassName="!flex-col !gap-[17px]"
|
||||
labelClassName={SETTINGS_FORM.LABEL_CLASSNAME}
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import { useTranslation } from "react-i18next";
|
||||
import { useAIConfigOptions } from "#/hooks/query/use-ai-config-options";
|
||||
import { I18nKey } from "#/i18n/declaration";
|
||||
import { LoadingSpinner } from "../../loading-spinner";
|
||||
import { ModalBackdrop } from "../modal-backdrop";
|
||||
import { SettingsForm } from "./settings-form";
|
||||
import { Settings } from "#/types/settings";
|
||||
@@ -12,6 +14,7 @@ interface SettingsModalProps {
|
||||
}
|
||||
|
||||
export function SettingsModal({ onClose, settings }: SettingsModalProps) {
|
||||
const aiConfigOptions = useAIConfigOptions();
|
||||
const { t } = useTranslation();
|
||||
|
||||
return (
|
||||
@@ -20,6 +23,9 @@ export function SettingsModal({ onClose, settings }: SettingsModalProps) {
|
||||
data-testid="ai-config-modal"
|
||||
className="bg-[#25272D] min-w-full max-w-[475px] m-4 p-6 rounded-xl flex flex-col gap-[17px] border border-tertiary api-configuration-modal"
|
||||
>
|
||||
{aiConfigOptions.error && (
|
||||
<p className="text-danger text-xs">{aiConfigOptions.error.message}</p>
|
||||
)}
|
||||
<span className="text-5 leading-6 font-semibold -tracking-[0.2px]">
|
||||
{t(I18nKey.AI_SETTINGS$TITLE)}
|
||||
</span>
|
||||
@@ -34,10 +40,20 @@ export function SettingsModal({ onClose, settings }: SettingsModalProps) {
|
||||
suffixClassName="text-white"
|
||||
/>
|
||||
|
||||
<SettingsForm
|
||||
settings={settings || DEFAULT_SETTINGS}
|
||||
onClose={onClose}
|
||||
/>
|
||||
{aiConfigOptions.isLoading && (
|
||||
<div className="flex justify-center">
|
||||
<LoadingSpinner size="small" />
|
||||
</div>
|
||||
)}
|
||||
{aiConfigOptions.data && (
|
||||
<SettingsForm
|
||||
settings={settings || DEFAULT_SETTINGS}
|
||||
models={aiConfigOptions.data?.models}
|
||||
verifiedModels={aiConfigOptions.data?.verifiedModels ?? []}
|
||||
verifiedProviders={aiConfigOptions.data?.verifiedProviders ?? []}
|
||||
onClose={onClose}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
</ModalBackdrop>
|
||||
);
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { I18nKey } from "#/i18n/declaration";
|
||||
|
||||
export type OnboardingAppMode = "cloud" | "self-hosted" | "oss";
|
||||
export type OnboardingAppMode = "saas" | "self-hosted";
|
||||
|
||||
interface BaseOnboardingQuestion {
|
||||
id: string;
|
||||
@@ -43,9 +43,9 @@ export const ONBOARDING_FORM: OnboardingQuestion[] = [
|
||||
{
|
||||
id: "org_size",
|
||||
type: "single",
|
||||
app_mode: ["cloud", "self-hosted"],
|
||||
app_mode: ["saas", "self-hosted"],
|
||||
questionKey: I18nKey.ONBOARDING$ORG_SIZE_QUESTION,
|
||||
subtitleKey: I18nKey.ONBOARDING$SELECT_ONE_SUBTITLE,
|
||||
subtitleKey: I18nKey.ONBOARDING$ORG_SIZE_SUBTITLE,
|
||||
answerOptions: [
|
||||
{ key: I18nKey.ONBOARDING$ORG_SIZE_SOLO, id: "solo" },
|
||||
{ key: I18nKey.ONBOARDING$ORG_SIZE_2_10, id: "org_2_10" },
|
||||
@@ -57,9 +57,9 @@ export const ONBOARDING_FORM: OnboardingQuestion[] = [
|
||||
{
|
||||
id: "use_case",
|
||||
type: "multi",
|
||||
app_mode: ["cloud", "self-hosted"],
|
||||
app_mode: ["saas", "self-hosted"],
|
||||
questionKey: I18nKey.ONBOARDING$USE_CASE_QUESTION,
|
||||
subtitleKey: I18nKey.ONBOARDING$SELECT_MULTIPLE_SUBTITLE,
|
||||
subtitleKey: I18nKey.ONBOARDING$USE_CASE_SUBTITLE,
|
||||
answerOptions: [
|
||||
{ key: I18nKey.ONBOARDING$USE_CASE_NEW_FEATURES, id: "new_features" },
|
||||
{
|
||||
@@ -78,9 +78,8 @@ export const ONBOARDING_FORM: OnboardingQuestion[] = [
|
||||
{
|
||||
id: "role",
|
||||
type: "single",
|
||||
app_mode: ["cloud"],
|
||||
app_mode: ["saas"],
|
||||
questionKey: I18nKey.ONBOARDING$ROLE_QUESTION,
|
||||
subtitleKey: I18nKey.ONBOARDING$SELECT_ONE_SUBTITLE,
|
||||
answerOptions: [
|
||||
{
|
||||
key: I18nKey.ONBOARDING$ROLE_SOFTWARE_ENGINEER,
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import { useMutation, useQueryClient } from "@tanstack/react-query";
|
||||
import { useMutation } from "@tanstack/react-query";
|
||||
import { useNavigate } from "react-router";
|
||||
import { openHands } from "#/api/open-hands-axios";
|
||||
import { displayErrorToast } from "#/utils/custom-toast-handlers";
|
||||
|
||||
type SubmitOnboardingArgs = {
|
||||
@@ -9,18 +8,14 @@ type SubmitOnboardingArgs = {
|
||||
|
||||
export const useSubmitOnboarding = () => {
|
||||
const navigate = useNavigate();
|
||||
const queryClient = useQueryClient();
|
||||
|
||||
return useMutation({
|
||||
mutationFn: async ({ selections }: SubmitOnboardingArgs) => {
|
||||
// Mark onboarding as complete
|
||||
await openHands.post("/api/complete_onboarding");
|
||||
return { selections };
|
||||
},
|
||||
mutationFn: async ({ selections }: SubmitOnboardingArgs) =>
|
||||
// TODO: mark onboarding as complete
|
||||
// TODO: persist user responses
|
||||
({ selections }),
|
||||
onSuccess: () => {
|
||||
queryClient.invalidateQueries({ queryKey: ["settings"] });
|
||||
|
||||
const finalRedirectUrl = "/";
|
||||
const finalRedirectUrl = "/"; // TODO: use redirect url from api response
|
||||
// Check if the redirect URL is an external URL (starts with http or https)
|
||||
if (
|
||||
finalRedirectUrl.startsWith("http://") ||
|
||||
|
||||
@@ -5,7 +5,6 @@ import { organizationService } from "#/api/organization-service/organization-ser
|
||||
import { useSelectedOrganizationId } from "#/context/use-selected-organization";
|
||||
import { I18nKey } from "#/i18n/declaration";
|
||||
import { displaySuccessToast } from "#/utils/custom-toast-handlers";
|
||||
import { setSelectedOrg } from "#/utils/local-storage";
|
||||
|
||||
export const useSwitchOrganization = () => {
|
||||
const { t } = useTranslation();
|
||||
@@ -34,8 +33,6 @@ export const useSwitchOrganization = () => {
|
||||
// Update local state - this triggers automatic refetch for all org-scoped queries
|
||||
// since their query keys include organizationId (e.g., ["settings", orgId], ["secrets", orgId])
|
||||
setOrganizationId(orgId);
|
||||
// Broadcast org change to other apps (e.g. Automations) via localStorage
|
||||
setSelectedOrg(orgId);
|
||||
// Invalidate conversations to fetch data for the new org context
|
||||
queryClient.invalidateQueries({ queryKey: ["user", "conversations"] });
|
||||
// Remove all individual conversation queries to clear any stale/null data
|
||||
|
||||
@@ -1,17 +0,0 @@
|
||||
/**
|
||||
* Centralized query keys and cache configuration for TanStack Query.
|
||||
* Using constants ensures type safety and prevents typos.
|
||||
*/
|
||||
|
||||
export const QUERY_KEYS = {
|
||||
/** Web client configuration from the server */
|
||||
WEB_CLIENT_CONFIG: ["web-client-config"] as const,
|
||||
} as const;
|
||||
|
||||
/** Cache configuration shared across all config-related queries */
|
||||
export const CONFIG_CACHE_OPTIONS = {
|
||||
staleTime: 1000 * 60 * 5, // 5 minutes
|
||||
gcTime: 1000 * 60 * 15, // 15 minutes
|
||||
} as const;
|
||||
|
||||
export type QueryKeys = (typeof QUERY_KEYS)[keyof typeof QUERY_KEYS];
|
||||
@@ -1,10 +1,21 @@
|
||||
import { useQuery } from "@tanstack/react-query";
|
||||
import OptionService from "#/api/option-service/option-service.api";
|
||||
|
||||
const fetchAiConfigOptions = async () => {
|
||||
const modelsResponse = await OptionService.getModels();
|
||||
return {
|
||||
models: modelsResponse.models,
|
||||
verifiedModels: modelsResponse.verified_models,
|
||||
verifiedProviders: modelsResponse.verified_providers,
|
||||
defaultModel: modelsResponse.default_model,
|
||||
securityAnalyzers: await OptionService.getSecurityAnalyzers(),
|
||||
};
|
||||
};
|
||||
|
||||
export const useAIConfigOptions = () =>
|
||||
useQuery({
|
||||
queryKey: ["ai-config-options"],
|
||||
queryFn: OptionService.getSecurityAnalyzers,
|
||||
staleTime: 1000 * 60 * 5,
|
||||
gcTime: 1000 * 60 * 15,
|
||||
queryFn: fetchAiConfigOptions,
|
||||
staleTime: 1000 * 60 * 5, // 5 minutes
|
||||
gcTime: 1000 * 60 * 15, // 15 minutes
|
||||
});
|
||||
|
||||
@@ -6,9 +6,6 @@ import { useUserProviders } from "../use-user-providers";
|
||||
import { Provider } from "#/types/settings";
|
||||
import { shouldUseInstallationRepos } from "#/utils/utils";
|
||||
|
||||
/**
|
||||
* Get the first page of app installations for the provider given.
|
||||
*/
|
||||
export const useAppInstallations = (selectedProvider: Provider | null) => {
|
||||
const { data: config } = useConfig();
|
||||
const { data: userIsAuthenticated } = useIsAuthed();
|
||||
@@ -16,7 +13,7 @@ export const useAppInstallations = (selectedProvider: Provider | null) => {
|
||||
|
||||
return useQuery({
|
||||
queryKey: ["installations", providers || [], selectedProvider],
|
||||
queryFn: () => GitService.getUserInstallations(selectedProvider!),
|
||||
queryFn: () => GitService.getUserInstallationIds(selectedProvider!),
|
||||
enabled:
|
||||
userIsAuthenticated &&
|
||||
!!selectedProvider &&
|
||||
|
||||
@@ -30,11 +30,9 @@ export function useBranchData(
|
||||
provider,
|
||||
);
|
||||
|
||||
// Combine all branches from paginated data - use .items for V1 response
|
||||
// Combine all branches from paginated data
|
||||
const allBranches = useMemo(
|
||||
() =>
|
||||
branchData?.pages?.flatMap((page: { items: Branch[] }) => page.items) ||
|
||||
[],
|
||||
() => branchData?.pages?.flatMap((page) => page.branches) || [],
|
||||
[branchData],
|
||||
);
|
||||
|
||||
@@ -42,7 +40,7 @@ export function useBranchData(
|
||||
const defaultBranchInLoaded = useMemo(
|
||||
() =>
|
||||
defaultBranch
|
||||
? allBranches.find((branch: Branch) => branch.name === defaultBranch)
|
||||
? allBranches.find((branch) => branch.name === defaultBranch)
|
||||
: null,
|
||||
[allBranches, defaultBranch],
|
||||
);
|
||||
@@ -77,7 +75,7 @@ export function useBranchData(
|
||||
if (defaultBranch) {
|
||||
// Use the already computed defaultBranchInLoaded or check in current branches
|
||||
let defaultBranchObj = shouldUseSearch
|
||||
? branchesToUse.find((branch: Branch) => branch.name === defaultBranch)
|
||||
? branchesToUse.find((branch) => branch.name === defaultBranch)
|
||||
: defaultBranchInLoaded;
|
||||
|
||||
// If not found in current branches, check if we have it from the default branch search
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import { useQuery } from "@tanstack/react-query";
|
||||
import OptionService from "#/api/option-service/option-service.api";
|
||||
import { useIsOnIntermediatePage } from "#/hooks/use-is-on-intermediate-page";
|
||||
import { QUERY_KEYS, CONFIG_CACHE_OPTIONS } from "./query-keys";
|
||||
|
||||
interface UseConfigOptions {
|
||||
enabled?: boolean;
|
||||
@@ -11,9 +10,10 @@ export const useConfig = (options?: UseConfigOptions) => {
|
||||
const isOnIntermediatePage = useIsOnIntermediatePage();
|
||||
|
||||
return useQuery({
|
||||
queryKey: QUERY_KEYS.WEB_CLIENT_CONFIG,
|
||||
queryKey: ["web-client-config"],
|
||||
queryFn: OptionService.getConfig,
|
||||
...CONFIG_CACHE_OPTIONS,
|
||||
staleTime: 1000 * 60 * 5, // 5 minutes
|
||||
gcTime: 1000 * 60 * 15, // 15 minutes,
|
||||
enabled: options?.enabled ?? !isOnIntermediatePage,
|
||||
});
|
||||
};
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import { useInfiniteQuery, InfiniteData } from "@tanstack/react-query";
|
||||
import { useInfiniteQuery } from "@tanstack/react-query";
|
||||
import { useConfig } from "./use-config";
|
||||
import { useUserProviders } from "../use-user-providers";
|
||||
import { useAppInstallations } from "./use-app-installations";
|
||||
import { RepositoryPage } from "../../types/git";
|
||||
import { GitRepository } from "../../types/git";
|
||||
import { Provider } from "../../types/settings";
|
||||
import GitService from "#/api/git-service/git-service.api";
|
||||
import { shouldUseInstallationRepos } from "#/utils/utils";
|
||||
@@ -13,27 +13,29 @@ interface UseGitRepositoriesOptions {
|
||||
enabled?: boolean;
|
||||
}
|
||||
|
||||
type InstallationCursor = { installationIndex: number; pageId: string | null };
|
||||
type UserCursor = string | null;
|
||||
type Cursor = InstallationCursor | UserCursor;
|
||||
interface UserRepositoriesResponse {
|
||||
data: GitRepository[];
|
||||
nextPage: number | null;
|
||||
}
|
||||
|
||||
interface InstallationRepositoriesResponse {
|
||||
data: GitRepository[];
|
||||
nextPage: number | null;
|
||||
installationIndex: number | null;
|
||||
}
|
||||
|
||||
export function useGitRepositories(options: UseGitRepositoriesOptions) {
|
||||
const { provider, pageSize = 30, enabled = true } = options;
|
||||
const { providers } = useUserProviders();
|
||||
const { data: config } = useConfig();
|
||||
const { data: page } = useAppInstallations(provider);
|
||||
const installations = page?.items;
|
||||
const { data: installations } = useAppInstallations(provider);
|
||||
|
||||
const useInstallationRepos = provider
|
||||
? shouldUseInstallationRepos(provider, config?.app_mode)
|
||||
: false;
|
||||
|
||||
const repos = useInfiniteQuery<
|
||||
RepositoryPage,
|
||||
Error,
|
||||
InfiniteData<RepositoryPage>,
|
||||
[string, string[], Provider | null, boolean, number, ...unknown[]],
|
||||
Cursor
|
||||
UserRepositoriesResponse | InstallationRepositoriesResponse
|
||||
>({
|
||||
queryKey: [
|
||||
"repositories",
|
||||
@@ -49,52 +51,56 @@ export function useGitRepositories(options: UseGitRepositoriesOptions) {
|
||||
}
|
||||
|
||||
if (useInstallationRepos) {
|
||||
const { repoPage, installationIndex } = pageParam as {
|
||||
installationIndex: number | null;
|
||||
repoPage: number | null;
|
||||
};
|
||||
|
||||
if (!installations) {
|
||||
throw new Error("Missing installation list");
|
||||
}
|
||||
|
||||
const cursor = pageParam as InstallationCursor;
|
||||
const result = await GitService.retrieveInstallationRepositories(
|
||||
return GitService.retrieveInstallationRepositories(
|
||||
provider,
|
||||
cursor.installationIndex,
|
||||
installationIndex || 0,
|
||||
installations,
|
||||
cursor.pageId ?? undefined,
|
||||
repoPage || 1,
|
||||
pageSize,
|
||||
);
|
||||
return result;
|
||||
}
|
||||
|
||||
const cursor = pageParam as UserCursor;
|
||||
const result = await GitService.retrieveUserGitRepositories(
|
||||
return GitService.retrieveUserGitRepositories(
|
||||
provider,
|
||||
cursor ?? undefined,
|
||||
pageParam as number,
|
||||
pageSize,
|
||||
);
|
||||
return result;
|
||||
},
|
||||
getNextPageParam: (lastPage, allPages, lastPageParam) => {
|
||||
if (useInstallationRepos && installations) {
|
||||
// Installation-based pagination
|
||||
const currentCursor = lastPageParam as InstallationCursor;
|
||||
if (lastPage.next_page_id) {
|
||||
getNextPageParam: (lastPage) => {
|
||||
if (useInstallationRepos) {
|
||||
const installationPage = lastPage as InstallationRepositoriesResponse;
|
||||
if (installationPage.nextPage) {
|
||||
return {
|
||||
installationIndex: currentCursor.installationIndex,
|
||||
pageId: lastPage.next_page_id,
|
||||
installationIndex: installationPage.installationIndex,
|
||||
repoPage: installationPage.nextPage,
|
||||
};
|
||||
}
|
||||
// Advance to next installation
|
||||
const nextInstallationIndex = currentCursor.installationIndex + 1;
|
||||
if (nextInstallationIndex < installations.length) {
|
||||
return { installationIndex: nextInstallationIndex, pageId: null };
|
||||
|
||||
if (installationPage.installationIndex !== null) {
|
||||
return {
|
||||
installationIndex: installationPage.installationIndex,
|
||||
repoPage: 1,
|
||||
};
|
||||
}
|
||||
return undefined;
|
||||
|
||||
return null;
|
||||
}
|
||||
// User repositories pagination
|
||||
return lastPage.next_page_id;
|
||||
|
||||
const userPage = lastPage as UserRepositoriesResponse;
|
||||
return userPage.nextPage;
|
||||
},
|
||||
initialPageParam: useInstallationRepos
|
||||
? { installationIndex: 0, pageId: null }
|
||||
: null,
|
||||
? { installationIndex: 0, repoPage: 1 }
|
||||
: 1,
|
||||
enabled:
|
||||
enabled &&
|
||||
(providers || []).length > 0 &&
|
||||
|
||||
@@ -1,36 +0,0 @@
|
||||
import { useQuery } from "@tanstack/react-query";
|
||||
import ConfigService from "#/api/config-service/config-service.api";
|
||||
import type { LLMModel } from "#/api/config-service/config-service.types";
|
||||
|
||||
const MAX_PAGINATION_DEPTH = 10;
|
||||
|
||||
async function fetchPage(
|
||||
provider: string,
|
||||
pageId?: string,
|
||||
depth = 0,
|
||||
): Promise<LLMModel[]> {
|
||||
if (depth >= MAX_PAGINATION_DEPTH) {
|
||||
throw new Error(`Too many pagination requests for provider ${provider}`);
|
||||
}
|
||||
|
||||
const page = await ConfigService.searchModels({
|
||||
provider__eq: provider,
|
||||
limit: 100,
|
||||
page_id: pageId,
|
||||
});
|
||||
|
||||
if (page.next_page_id) {
|
||||
const rest = await fetchPage(provider, page.next_page_id, depth + 1);
|
||||
return [...page.items, ...rest];
|
||||
}
|
||||
return page.items;
|
||||
}
|
||||
|
||||
export const useProviderModels = (provider: string | null) =>
|
||||
useQuery({
|
||||
queryKey: ["config", "models", provider],
|
||||
queryFn: () => fetchPage(provider!),
|
||||
enabled: !!provider,
|
||||
staleTime: 1000 * 60 * 5,
|
||||
gcTime: 1000 * 60 * 15,
|
||||
});
|
||||
@@ -1,20 +1,35 @@
|
||||
import { useInfiniteQuery, InfiniteData } from "@tanstack/react-query";
|
||||
import { useQuery, useInfiniteQuery } from "@tanstack/react-query";
|
||||
import GitService from "#/api/git-service/git-service.api";
|
||||
import { BranchPage } from "#/types/git";
|
||||
import { Branch, PaginatedBranchesResponse } from "#/types/git";
|
||||
import { Provider } from "#/types/settings";
|
||||
|
||||
export const useRepositoryBranches = (
|
||||
repository: string | null,
|
||||
selectedProvider?: Provider,
|
||||
) =>
|
||||
useQuery<Branch[]>({
|
||||
queryKey: ["repository", repository, "branches", selectedProvider],
|
||||
queryFn: async () => {
|
||||
if (!repository) return [];
|
||||
const response = await GitService.getRepositoryBranches(
|
||||
repository,
|
||||
1,
|
||||
30,
|
||||
selectedProvider,
|
||||
);
|
||||
// Ensure we return an array even if the response is malformed
|
||||
return Array.isArray(response.branches) ? response.branches : [];
|
||||
},
|
||||
enabled: !!repository,
|
||||
staleTime: 1000 * 60 * 5, // 5 minutes
|
||||
});
|
||||
|
||||
export const useRepositoryBranchesPaginated = (
|
||||
repository: string | null,
|
||||
perPage: number = 30,
|
||||
selectedProvider?: Provider,
|
||||
) => {
|
||||
const result = useInfiniteQuery<
|
||||
BranchPage,
|
||||
Error,
|
||||
InfiniteData<BranchPage>,
|
||||
[string, string | null, ...unknown[]],
|
||||
string | null
|
||||
>({
|
||||
) =>
|
||||
useInfiniteQuery<PaginatedBranchesResponse, Error>({
|
||||
queryKey: [
|
||||
"repository",
|
||||
repository,
|
||||
@@ -23,29 +38,27 @@ export const useRepositoryBranchesPaginated = (
|
||||
perPage,
|
||||
selectedProvider,
|
||||
],
|
||||
queryFn: async ({ pageParam }) => {
|
||||
if (!repository || !selectedProvider) {
|
||||
queryFn: async ({ pageParam = 1 }) => {
|
||||
if (!repository) {
|
||||
return {
|
||||
items: [],
|
||||
next_page_id: null,
|
||||
branches: [],
|
||||
has_next_page: false,
|
||||
current_page: 1,
|
||||
per_page: perPage,
|
||||
total_count: 0,
|
||||
};
|
||||
}
|
||||
return GitService.getRepositoryBranches(
|
||||
repository,
|
||||
selectedProvider,
|
||||
"", // query (empty = list all)
|
||||
pageParam ?? undefined,
|
||||
pageParam as number,
|
||||
perPage,
|
||||
selectedProvider,
|
||||
);
|
||||
},
|
||||
enabled: !!repository && !!selectedProvider,
|
||||
enabled: !!repository,
|
||||
staleTime: 1000 * 60 * 5, // 5 minutes
|
||||
getNextPageParam: (lastPage) =>
|
||||
lastPage.next_page_id ? lastPage.next_page_id : undefined,
|
||||
initialPageParam: null,
|
||||
// Use the has_next_page flag from the API response
|
||||
lastPage.has_next_page ? lastPage.current_page + 1 : undefined,
|
||||
initialPageParam: 1,
|
||||
});
|
||||
|
||||
return {
|
||||
...result,
|
||||
};
|
||||
};
|
||||
|
||||
@@ -20,17 +20,15 @@ export function useSearchBranches(
|
||||
selectedProvider,
|
||||
],
|
||||
queryFn: async () => {
|
||||
if (!repository || !query || !selectedProvider) return [];
|
||||
const response = await GitService.searchRepositoryBranches(
|
||||
if (!repository || !query) return [];
|
||||
return GitService.searchRepositoryBranches(
|
||||
repository,
|
||||
selectedProvider,
|
||||
query,
|
||||
undefined, // pageId
|
||||
perPage,
|
||||
selectedProvider,
|
||||
);
|
||||
return response.items;
|
||||
},
|
||||
enabled: !!repository && !!query && !!selectedProvider,
|
||||
enabled: !!repository && !!query,
|
||||
staleTime: 1000 * 60 * 5,
|
||||
gcTime: 1000 * 60 * 15,
|
||||
});
|
||||
|
||||
@@ -1,17 +0,0 @@
|
||||
import { useQuery } from "@tanstack/react-query";
|
||||
import ConfigService from "#/api/config-service/config-service.api";
|
||||
import type { LLMProvider } from "#/api/config-service/config-service.types";
|
||||
|
||||
async function fetchAllProviders(): Promise<LLMProvider[]> {
|
||||
// Providers are a small set; fetch all in one call with a high limit.
|
||||
const page = await ConfigService.searchProviders({ limit: 100 });
|
||||
return page.items;
|
||||
}
|
||||
|
||||
export const useSearchProviders = () =>
|
||||
useQuery({
|
||||
queryKey: ["config", "providers"],
|
||||
queryFn: fetchAllProviders,
|
||||
staleTime: 1000 * 60 * 5,
|
||||
gcTime: 1000 * 60 * 15,
|
||||
});
|
||||
@@ -1,6 +1,5 @@
|
||||
import { useQuery } from "@tanstack/react-query";
|
||||
import GitService from "#/api/git-service/git-service.api";
|
||||
import { GitRepository } from "#/types/git";
|
||||
import { Provider } from "#/types/settings";
|
||||
|
||||
export function useSearchRepositories(
|
||||
@@ -9,20 +8,14 @@ export function useSearchRepositories(
|
||||
disabled?: boolean,
|
||||
pageSize: number = 100,
|
||||
) {
|
||||
// For backward compatibility, return the items array directly
|
||||
return useQuery<GitRepository[]>({
|
||||
return useQuery({
|
||||
queryKey: ["repositories", "search", query, selectedProvider, pageSize],
|
||||
queryFn: async () => {
|
||||
if (!selectedProvider) {
|
||||
return [];
|
||||
}
|
||||
const response = await GitService.searchGitRepositories(
|
||||
queryFn: () =>
|
||||
GitService.searchGitRepositories(
|
||||
query,
|
||||
selectedProvider, // provider (required)
|
||||
pageSize,
|
||||
);
|
||||
return response.items;
|
||||
},
|
||||
selectedProvider || undefined,
|
||||
),
|
||||
enabled: !!query && !!selectedProvider && !disabled,
|
||||
staleTime: 1000 * 60 * 5, // 5 minutes
|
||||
gcTime: 1000 * 60 * 15, // 15 minutes
|
||||
|
||||
@@ -1,43 +0,0 @@
|
||||
import { useMemo } from "react";
|
||||
import { useConfig } from "#/hooks/query/use-config";
|
||||
|
||||
/**
|
||||
* Hook that provides boolean checks for app mode deployment mode.
|
||||
*
|
||||
* App Mode (app_mode):
|
||||
* - "oss": Open source version running locally/self-hosted
|
||||
* - "saas": All-Hands managed SaaS version
|
||||
*
|
||||
* Deployment Mode (deployment_mode):
|
||||
* - "cloud": Enterprise customers running on All-Hands managed infrastructure (*.all-hands.dev, *.openhands.ai)
|
||||
* - "self_hosted": Enterprise customers running on their own infrastructure
|
||||
*
|
||||
* Note: SaaS mode can have either cloud or self_hosted deployment mode.
|
||||
*/
|
||||
export function useAppMode() {
|
||||
const { data: config } = useConfig();
|
||||
|
||||
return useMemo(() => {
|
||||
const appMode = config?.app_mode;
|
||||
const deploymentMode = config?.feature_flags?.deployment_mode;
|
||||
|
||||
return {
|
||||
// App Mode checks
|
||||
isOss: appMode === "oss",
|
||||
isSaas: appMode === "saas",
|
||||
|
||||
// Deployment Mode checks
|
||||
isCloud: deploymentMode === "cloud",
|
||||
isSelfHosted: deploymentMode === "self_hosted",
|
||||
|
||||
/** Enterprise checks */
|
||||
isEnterpriseSelfHosted:
|
||||
appMode === "saas" && deploymentMode === "self_hosted",
|
||||
isEnterpriseCloud: appMode === "saas" && deploymentMode === "cloud",
|
||||
|
||||
// Raw values (for cases where the actual value is needed)
|
||||
appMode,
|
||||
deploymentMode,
|
||||
};
|
||||
}, [config?.app_mode, config?.feature_flags?.deployment_mode]);
|
||||
}
|
||||
@@ -1,7 +1,6 @@
|
||||
import React from "react";
|
||||
import { useSelectedOrganizationId } from "#/context/use-selected-organization";
|
||||
import { useOrganizations } from "#/hooks/query/use-organizations";
|
||||
import { setSelectedOrg } from "#/utils/local-storage";
|
||||
|
||||
/**
|
||||
* Hook that automatically selects an organization when:
|
||||
@@ -29,8 +28,6 @@ export function useAutoSelectOrganization() {
|
||||
// Revalidation is only needed when user explicitly switches organizations
|
||||
// to redirect away from admin-only pages they may no longer have access to.
|
||||
setOrganizationId(initialOrgId, { skipRevalidation: true });
|
||||
// Broadcast org selection to other apps (e.g. Automations) via localStorage
|
||||
setSelectedOrg(initialOrgId);
|
||||
}
|
||||
}, [organizationId, organizations, currentOrgId, setOrganizationId]);
|
||||
}
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
import { useMemo } from "react";
|
||||
|
||||
/**
|
||||
* Hook to check if the current domain is an All Hands SaaS environment
|
||||
* @returns True if the current domain contains "all-hands.dev" or "openhands.dev" postfix
|
||||
*/
|
||||
export const useIsAllHandsSaaSEnvironment = (): boolean =>
|
||||
useMemo(() => {
|
||||
const { hostname } = window.location;
|
||||
return (
|
||||
hostname.endsWith("all-hands.dev") || hostname.endsWith("openhands.dev")
|
||||
);
|
||||
}, []);
|
||||
@@ -470,7 +470,6 @@ export enum I18nKey {
|
||||
SIDEBAR$NAVIGATION_LABEL = "SIDEBAR$NAVIGATION_LABEL",
|
||||
FEEDBACK$PUBLIC_LABEL = "FEEDBACK$PUBLIC_LABEL",
|
||||
FEEDBACK$PRIVATE_LABEL = "FEEDBACK$PRIVATE_LABEL",
|
||||
SIDEBAR$AUTOMATIONS = "SIDEBAR$AUTOMATIONS",
|
||||
SIDEBAR$CONVERSATIONS = "SIDEBAR$CONVERSATIONS",
|
||||
STATUS$CONNECTING_TO_RUNTIME = "STATUS$CONNECTING_TO_RUNTIME",
|
||||
STATUS$STARTING_RUNTIME = "STATUS$STARTING_RUNTIME",
|
||||
@@ -1134,14 +1133,14 @@ export enum I18nKey {
|
||||
ONBOARDING$ORG_NAME_INPUT_NAME = "ONBOARDING$ORG_NAME_INPUT_NAME",
|
||||
ONBOARDING$ORG_NAME_INPUT_DOMAIN = "ONBOARDING$ORG_NAME_INPUT_DOMAIN",
|
||||
ONBOARDING$ORG_SIZE_QUESTION = "ONBOARDING$ORG_SIZE_QUESTION",
|
||||
ONBOARDING$SELECT_ONE_SUBTITLE = "ONBOARDING$SELECT_ONE_SUBTITLE",
|
||||
ONBOARDING$ORG_SIZE_SUBTITLE = "ONBOARDING$ORG_SIZE_SUBTITLE",
|
||||
ONBOARDING$ORG_SIZE_SOLO = "ONBOARDING$ORG_SIZE_SOLO",
|
||||
ONBOARDING$ORG_SIZE_2_10 = "ONBOARDING$ORG_SIZE_2_10",
|
||||
ONBOARDING$ORG_SIZE_11_50 = "ONBOARDING$ORG_SIZE_11_50",
|
||||
ONBOARDING$ORG_SIZE_51_200 = "ONBOARDING$ORG_SIZE_51_200",
|
||||
ONBOARDING$ORG_SIZE_200_PLUS = "ONBOARDING$ORG_SIZE_200_PLUS",
|
||||
ONBOARDING$USE_CASE_QUESTION = "ONBOARDING$USE_CASE_QUESTION",
|
||||
ONBOARDING$SELECT_MULTIPLE_SUBTITLE = "ONBOARDING$SELECT_MULTIPLE_SUBTITLE",
|
||||
ONBOARDING$USE_CASE_SUBTITLE = "ONBOARDING$USE_CASE_SUBTITLE",
|
||||
ONBOARDING$USE_CASE_NEW_FEATURES = "ONBOARDING$USE_CASE_NEW_FEATURES",
|
||||
ONBOARDING$USE_CASE_APP_FROM_SCRATCH = "ONBOARDING$USE_CASE_APP_FROM_SCRATCH",
|
||||
ONBOARDING$USE_CASE_FIXING_BUGS = "ONBOARDING$USE_CASE_FIXING_BUGS",
|
||||
|
||||
@@ -7989,23 +7989,6 @@
|
||||
"uk": "Приватний",
|
||||
"ca": "Privat"
|
||||
},
|
||||
"SIDEBAR$AUTOMATIONS": {
|
||||
"en": "Automations",
|
||||
"zh-CN": "自动化",
|
||||
"zh-TW": "自動化",
|
||||
"de": "Automatisierungen",
|
||||
"ko-KR": "자동화",
|
||||
"no": "Automatiseringer",
|
||||
"it": "Automazioni",
|
||||
"pt": "Automatizações",
|
||||
"es": "Automatizaciones",
|
||||
"ar": "الأتمتة",
|
||||
"fr": "Automatisations",
|
||||
"tr": "Otomasyonlar",
|
||||
"ja": "自動化",
|
||||
"uk": "Автоматизації",
|
||||
"ca": "Automatitzacions"
|
||||
},
|
||||
"SIDEBAR$CONVERSATIONS": {
|
||||
"en": "Conversations",
|
||||
"ja": "会話",
|
||||
@@ -19280,7 +19263,7 @@
|
||||
"uk": "Якого розміру організація, в якій ви працюєте?",
|
||||
"ca": "De quina mida és la vostra organització?"
|
||||
},
|
||||
"ONBOARDING$SELECT_ONE_SUBTITLE": {
|
||||
"ONBOARDING$ORG_SIZE_SUBTITLE": {
|
||||
"en": "Select one",
|
||||
"ja": "1つ選択",
|
||||
"zh-CN": "选择一个",
|
||||
@@ -19399,7 +19382,7 @@
|
||||
"uk": "Для яких випадків використання ви хочете використовувати OpenHands?",
|
||||
"ca": "Per a quins casos d'ús voleu fer servir OpenHands?"
|
||||
},
|
||||
"ONBOARDING$SELECT_MULTIPLE_SUBTITLE": {
|
||||
"ONBOARDING$USE_CASE_SUBTITLE": {
|
||||
"en": "Check all that apply",
|
||||
"ja": "該当するものをすべて選択",
|
||||
"zh-CN": "选择所有适用的",
|
||||
|
||||
@@ -1,21 +0,0 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 20 20" fill="none">
|
||||
<mask id="mask0_18339_1445" style="mask-type:luminance" maskUnits="userSpaceOnUse" x="1" y="1" width="18" height="18">
|
||||
<path d="M19 1H1V19H19V1Z" fill="white"/>
|
||||
</mask>
|
||||
<g mask="url(#mask0_18339_1445)">
|
||||
<path d="M10 18.1818V16.5454" stroke="currentColor" stroke-width="1.67" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<path d="M10 1.81836V3.45472" stroke="currentColor" stroke-width="1.67" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<path d="M1.81824 10H3.4546" stroke="currentColor" stroke-width="1.67" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<path d="M18.1818 10H16.5454" stroke="currentColor" stroke-width="1.67" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<path d="M5.93359 17.1019L6.74359 15.6782" stroke="currentColor" stroke-width="1.67" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<path d="M14.0663 2.89844L13.2563 4.32208" stroke="currentColor" stroke-width="1.67" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<path d="M2.89819 5.93359L4.32183 6.74359" stroke="currentColor" stroke-width="1.67" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<path d="M17.1019 14.0663L15.6782 13.2563" stroke="currentColor" stroke-width="1.67" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<path d="M5.92542 2.90625L6.7436 4.3217" stroke="currentColor" stroke-width="1.67" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<path d="M14.0909 17.0854L13.2727 15.6699" stroke="currentColor" stroke-width="1.67" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<path d="M17.0855 5.90918L15.67 6.72736" stroke="currentColor" stroke-width="1.67" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<path d="M2.91455 14.0911L4.33001 13.2729" stroke="currentColor" stroke-width="1.67" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<path d="M10 16.5455C13.615 16.5455 16.5455 13.615 16.5455 10C16.5455 6.38509 13.615 3.45459 10 3.45459C6.38509 3.45459 3.45459 6.38509 3.45459 10C3.45459 13.615 6.38509 16.5455 10 16.5455Z" stroke="currentColor" stroke-width="1.67" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<path d="M12.5854 9.77916L8.80545 7.59461C8.63363 7.49643 8.41272 7.61916 8.41272 7.81552V12.1846C8.41272 12.381 8.62545 12.5119 8.80545 12.4055L12.5854 10.221C12.7573 10.1228 12.7573 9.86916 12.5854 9.77098V9.77916Z" fill="currentColor" stroke="currentColor" stroke-width="1.67" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
</g>
|
||||
</svg>
|
||||
|
Before Width: | Height: | Size: 2.5 KiB |
@@ -68,120 +68,47 @@ export const resetTestHandlersMockSettings = () => {
|
||||
MOCK_USER_PREFERENCES.settings = MOCK_DEFAULT_USER_SETTINGS;
|
||||
};
|
||||
|
||||
// Mock model data used by both V0 and V1 endpoints
|
||||
const MOCK_MODELS = [
|
||||
"anthropic/claude-3.5",
|
||||
"anthropic/claude-sonnet-4-20250514",
|
||||
"anthropic/claude-sonnet-4-5-20250929",
|
||||
"anthropic/claude-haiku-4-5-20251001",
|
||||
"anthropic/claude-opus-4-5-20251101",
|
||||
"openai/gpt-3.5-turbo",
|
||||
"openai/gpt-4o",
|
||||
"openai/gpt-4o-mini",
|
||||
"openhands/claude-sonnet-4-20250514",
|
||||
"openhands/claude-sonnet-4-5-20250929",
|
||||
"openhands/claude-haiku-4-5-20251001",
|
||||
"openhands/claude-opus-4-5-20251101",
|
||||
"openhands/minimax-m2.7",
|
||||
"sambanova/Meta-Llama-3.1-8B-Instruct",
|
||||
];
|
||||
|
||||
const MOCK_VERIFIED_MODELS = new Set([
|
||||
"anthropic/claude-opus-4-5-20251101",
|
||||
"anthropic/claude-sonnet-4-5-20250929",
|
||||
"openhands/claude-opus-4-5-20251101",
|
||||
"openhands/claude-sonnet-4-5-20250929",
|
||||
"openhands/minimax-m2.7",
|
||||
]);
|
||||
|
||||
const MOCK_VERIFIED_PROVIDERS = [
|
||||
"openhands",
|
||||
"anthropic",
|
||||
"openai",
|
||||
"mistral",
|
||||
"gemini",
|
||||
"deepseek",
|
||||
"moonshot",
|
||||
"minimax",
|
||||
];
|
||||
|
||||
// --- Handlers for options/config/settings ---
|
||||
|
||||
export const SETTINGS_HANDLERS = [
|
||||
// V0 (legacy) models endpoint – still used for default_model
|
||||
http.get("/api/options/models", async () =>
|
||||
HttpResponse.json({
|
||||
models: MOCK_MODELS,
|
||||
models: [
|
||||
"anthropic/claude-3.5",
|
||||
"anthropic/claude-sonnet-4-20250514",
|
||||
"anthropic/claude-sonnet-4-5-20250929",
|
||||
"anthropic/claude-haiku-4-5-20251001",
|
||||
"anthropic/claude-opus-4-5-20251101",
|
||||
"openai/gpt-3.5-turbo",
|
||||
"openai/gpt-4o",
|
||||
"openai/gpt-4o-mini",
|
||||
"openhands/claude-sonnet-4-20250514",
|
||||
"openhands/claude-sonnet-4-5-20250929",
|
||||
"openhands/claude-haiku-4-5-20251001",
|
||||
"openhands/claude-opus-4-5-20251101",
|
||||
"sambanova/Meta-Llama-3.1-8B-Instruct",
|
||||
],
|
||||
verified_models: [
|
||||
"claude-opus-4-5-20251101",
|
||||
"claude-sonnet-4-5-20250929",
|
||||
],
|
||||
verified_providers: MOCK_VERIFIED_PROVIDERS,
|
||||
verified_providers: [
|
||||
"openhands",
|
||||
"anthropic",
|
||||
"openai",
|
||||
"mistral",
|
||||
"gemini",
|
||||
"deepseek",
|
||||
"moonshot",
|
||||
"minimax",
|
||||
],
|
||||
default_model: "openhands/claude-opus-4-5-20251101",
|
||||
}),
|
||||
),
|
||||
|
||||
// V1 providers search
|
||||
http.get("/api/v1/config/providers/search", async ({ request }) => {
|
||||
const url = new URL(request.url);
|
||||
const query = url.searchParams.get("query")?.toLowerCase();
|
||||
const verifiedEq = url.searchParams.get("verified__eq");
|
||||
|
||||
// Build unique provider list from models
|
||||
const seen = new Set<string>();
|
||||
let providers: { name: string; verified: boolean }[] = [];
|
||||
for (const model of MOCK_MODELS) {
|
||||
const [providerName] = model.split("/");
|
||||
if (providerName && !seen.has(providerName)) {
|
||||
seen.add(providerName);
|
||||
providers.push({
|
||||
name: providerName,
|
||||
verified: MOCK_VERIFIED_PROVIDERS.includes(providerName),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if (query) {
|
||||
providers = providers.filter((p) => p.name.toLowerCase().includes(query));
|
||||
}
|
||||
if (verifiedEq !== null && verifiedEq !== undefined) {
|
||||
const wantVerified = verifiedEq === "true";
|
||||
providers = providers.filter((p) => p.verified === wantVerified);
|
||||
}
|
||||
|
||||
return HttpResponse.json({ items: providers, next_page_id: null });
|
||||
}),
|
||||
|
||||
// V1 models search
|
||||
http.get("/api/v1/config/models/search", async ({ request }) => {
|
||||
const url = new URL(request.url);
|
||||
const query = url.searchParams.get("query")?.toLowerCase();
|
||||
const verifiedEq = url.searchParams.get("verified__eq");
|
||||
const providerEq = url.searchParams.get("provider__eq");
|
||||
|
||||
let models = MOCK_MODELS.map((m) => {
|
||||
const [provider, ...rest] = m.split("/");
|
||||
const name = rest.join("/");
|
||||
return {
|
||||
provider: provider || null,
|
||||
name,
|
||||
verified: MOCK_VERIFIED_MODELS.has(m),
|
||||
};
|
||||
});
|
||||
|
||||
if (providerEq) {
|
||||
models = models.filter((m) => m.provider === providerEq);
|
||||
}
|
||||
if (query) {
|
||||
models = models.filter((m) => m.name.toLowerCase().includes(query));
|
||||
}
|
||||
if (verifiedEq !== null && verifiedEq !== undefined) {
|
||||
const wantVerified = verifiedEq === "true";
|
||||
models = models.filter((m) => m.verified === wantVerified);
|
||||
}
|
||||
|
||||
return HttpResponse.json({ items: models, next_page_id: null });
|
||||
}),
|
||||
http.get("/api/options/agents", async () =>
|
||||
HttpResponse.json(["CodeActAgent", "CoActAgent"]),
|
||||
),
|
||||
|
||||
http.get("/api/options/security-analyzers", async () =>
|
||||
HttpResponse.json(["llm", "none"]),
|
||||
@@ -218,7 +145,7 @@ export const SETTINGS_HANDLERS = [
|
||||
return HttpResponse.json(config);
|
||||
}),
|
||||
|
||||
http.get("/api/v1/settings", async () => {
|
||||
http.get("/api/settings", async () => {
|
||||
await delay();
|
||||
const { settings } = MOCK_USER_PREFERENCES;
|
||||
|
||||
@@ -227,7 +154,7 @@ export const SETTINGS_HANDLERS = [
|
||||
return HttpResponse.json(settings);
|
||||
}),
|
||||
|
||||
http.post("/api/v1/settings", async ({ request }) => {
|
||||
http.post("/api/settings", async ({ request }) => {
|
||||
await delay();
|
||||
const body = await request.json();
|
||||
|
||||
|
||||
@@ -16,15 +16,14 @@ import { isBillingHidden } from "#/utils/org/billing-visibility";
|
||||
import { queryClient } from "#/query-client-config";
|
||||
import OptionService from "#/api/option-service/option-service.api";
|
||||
import { WebClientConfig } from "#/api/option-service/option.types";
|
||||
import { QUERY_KEYS, CONFIG_CACHE_OPTIONS } from "#/hooks/query/query-keys";
|
||||
import { getFirstAvailablePath } from "#/utils/settings-utils";
|
||||
|
||||
export const clientLoader = async () => {
|
||||
const config = await queryClient.fetchQuery<WebClientConfig>({
|
||||
queryKey: QUERY_KEYS.WEB_CLIENT_CONFIG,
|
||||
queryFn: OptionService.getConfig,
|
||||
...CONFIG_CACHE_OPTIONS,
|
||||
});
|
||||
let config = queryClient.getQueryData<WebClientConfig>(["web-client-config"]);
|
||||
if (!config) {
|
||||
config = await OptionService.getConfig();
|
||||
queryClient.setQueryData<WebClientConfig>(["web-client-config"], config);
|
||||
}
|
||||
|
||||
const isSaas = config?.app_mode === "saas";
|
||||
const featureFlags = config?.feature_flags;
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user