mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-04-29 03:00:45 -04:00
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 269e27e734 |
@@ -1,202 +0,0 @@
|
||||
---
|
||||
name: cross-repo-testing
|
||||
description: This skill should be used when the user asks to "test a cross-repo feature", "deploy a feature branch to staging", "test SDK against OH Cloud", "e2e test a cloud workspace feature", "test provider tokens", "test secrets inheritance", or when changes span the SDK and OpenHands server repos and need end-to-end validation against a staging deployment.
|
||||
triggers:
|
||||
- cross-repo
|
||||
- staging deployment
|
||||
- feature branch deploy
|
||||
- test against cloud
|
||||
- e2e cloud
|
||||
---
|
||||
|
||||
# Cross-Repo Testing: SDK ↔ OpenHands Cloud
|
||||
|
||||
How to end-to-end test features that span `OpenHands/software-agent-sdk` and `OpenHands/OpenHands` (the Cloud backend).
|
||||
|
||||
## Repository Map
|
||||
|
||||
| Repo | Role | What lives here |
|
||||
|------|------|-----------------|
|
||||
| [`software-agent-sdk`](https://github.com/OpenHands/software-agent-sdk) | Agent core | `openhands-sdk`, `openhands-workspace`, `openhands-tools` packages. `OpenHandsCloudWorkspace` lives here. |
|
||||
| [`OpenHands`](https://github.com/OpenHands/OpenHands) | Cloud backend | FastAPI server (`openhands/app_server/`), sandbox management, auth, enterprise integrations. Deployed as OH Cloud. |
|
||||
| [`deploy`](https://github.com/OpenHands/deploy) | Infrastructure | Helm charts + GitHub Actions that build the enterprise Docker image and deploy to staging/production. |
|
||||
|
||||
**Data flow:** SDK client → OH Cloud API (`/api/v1/...`) → sandbox agent-server (inside runtime container)
|
||||
|
||||
## When You Need This
|
||||
|
||||
There are **two flows** depending on which direction the dependency goes:
|
||||
|
||||
| Flow | When | Example |
|
||||
|------|------|---------|
|
||||
| **A — SDK client → new Cloud API** | The SDK calls an API that doesn't exist yet on production | `workspace.get_llm()` calling `GET /api/v1/users/me?expose_secrets=true` |
|
||||
| **B — OH server → new SDK code** | The Cloud server needs unreleased SDK packages or a new agent-server image | Server consumes a new tool, agent behavior, or workspace method from the SDK |
|
||||
|
||||
Flow A only requires deploying the server PR. Flow B requires pinning the SDK to an unreleased commit in the server PR **and** using the SDK PR's agent-server image. Both flows may apply simultaneously.
|
||||
|
||||
---
|
||||
|
||||
## Flow A: SDK Client Tests Against New Cloud API
|
||||
|
||||
Use this when the SDK calls an endpoint that only exists on the server PR branch.
|
||||
|
||||
### A1. Write and test the server-side changes
|
||||
|
||||
In the `OpenHands` repo, implement the new API endpoint(s). Run unit tests:
|
||||
|
||||
```bash
|
||||
cd OpenHands
|
||||
poetry run pytest tests/unit/app_server/test_<relevant>.py -v
|
||||
```
|
||||
|
||||
Push a PR. Wait for the **"Push Enterprise Image" (Docker) CI job** to succeed — this builds `ghcr.io/openhands/enterprise-server:sha-<COMMIT>`.
|
||||
|
||||
### A2. Write the SDK-side changes
|
||||
|
||||
In `software-agent-sdk`, implement the client code (e.g., new methods on `OpenHandsCloudWorkspace`). Run SDK unit tests:
|
||||
|
||||
```bash
|
||||
cd software-agent-sdk
|
||||
pip install -e openhands-sdk -e openhands-workspace
|
||||
pytest tests/ -v
|
||||
```
|
||||
|
||||
Push a PR. SDK CI is independent — it doesn't need the server changes to pass unit tests.
|
||||
|
||||
### A3. Deploy the server PR to staging
|
||||
|
||||
See [Deploying to a Staging Feature Environment](#deploying-to-a-staging-feature-environment) below.
|
||||
|
||||
### A4. Run the SDK e2e test against staging
|
||||
|
||||
See [Running E2E Tests Against Staging](#running-e2e-tests-against-staging) below.
|
||||
|
||||
---
|
||||
|
||||
## Flow B: OH Server Needs Unreleased SDK Code
|
||||
|
||||
Use this when the Cloud server depends on SDK changes that haven't been released to PyPI yet. The server's runtime containers run the `agent-server` image built from the SDK repo, so the server PR must be configured to use the SDK PR's image and packages.
|
||||
|
||||
### B1. Get the SDK PR merged (or identify the commit)
|
||||
|
||||
The SDK PR must have CI pass so its agent-server Docker image is built. The image is tagged with the **merge-commit SHA** from GitHub Actions — NOT the head-commit SHA shown in the PR.
|
||||
|
||||
Find the correct image tag:
|
||||
- Check the SDK PR description for an `AGENT_SERVER_IMAGES` section
|
||||
- Or check the "Consolidate Build Information" CI job for `"short_sha": "<tag>"`
|
||||
|
||||
### B2. Pin SDK packages to the commit in the OpenHands PR
|
||||
|
||||
In the `OpenHands` repo PR, pin all 3 SDK packages (`openhands-sdk`, `openhands-agent-server`, `openhands-tools`) to the unreleased commit and update the agent-server image tag. This involves editing 3 files and regenerating 3 lock files.
|
||||
|
||||
Follow the **`update-sdk` skill** → "Development: Pin SDK to an Unreleased Commit" section for the full procedure and file-by-file instructions.
|
||||
|
||||
### B3. Wait for the OpenHands enterprise image to build
|
||||
|
||||
Push the pinned changes. The OpenHands CI will build a new enterprise Docker image (`ghcr.io/openhands/enterprise-server:sha-<OH_COMMIT>`) that bundles the unreleased SDK. Wait for the "Push Enterprise Image" job to succeed.
|
||||
|
||||
### B4. Deploy and test
|
||||
|
||||
Follow [Deploying to a Staging Feature Environment](#deploying-to-a-staging-feature-environment) using the new OpenHands commit SHA.
|
||||
|
||||
### B5. Before merging: remove the pin
|
||||
|
||||
**CI guard:** `check-package-versions.yml` blocks merge to `main` if `[tool.poetry.dependencies]` contains `rev` fields. Before the OpenHands PR can merge, the SDK PR must be merged and released to PyPI, then the pin must be replaced with the released version number.
|
||||
|
||||
---
|
||||
|
||||
## Deploying to a Staging Feature Environment
|
||||
|
||||
The `deploy` repo creates preview environments from OpenHands PRs.
|
||||
|
||||
**Option A — GitHub Actions UI (preferred):**
|
||||
Go to `OpenHands/deploy` → Actions → "Create OpenHands preview PR" → enter the OpenHands PR number. This creates a branch `ohpr-<PR>-<random>` and opens a deploy PR.
|
||||
|
||||
**Option B — Update an existing feature branch:**
|
||||
```bash
|
||||
cd deploy
|
||||
git checkout ohpr-<PR>-<random>
|
||||
# In .github/workflows/deploy.yaml, update BOTH:
|
||||
# OPENHANDS_SHA: "<full-40-char-commit>"
|
||||
# OPENHANDS_RUNTIME_IMAGE_TAG: "<same-commit>-nikolaik"
|
||||
git commit -am "Update OPENHANDS_SHA to <commit>" && git push
|
||||
```
|
||||
|
||||
**Before updating the SHA**, verify the enterprise Docker image exists:
|
||||
```bash
|
||||
gh api repos/OpenHands/OpenHands/actions/runs \
|
||||
--jq '.workflow_runs[] | select(.head_sha=="<COMMIT>") | "\(.name): \(.conclusion)"' \
|
||||
| grep Docker
|
||||
# Must show: "Docker: success"
|
||||
```
|
||||
|
||||
The deploy CI auto-triggers and creates the environment at:
|
||||
```
|
||||
https://ohpr-<PR>-<random>.staging.all-hands.dev
|
||||
```
|
||||
|
||||
**Wait for it to be live:**
|
||||
```bash
|
||||
curl -s -o /dev/null -w "%{http_code}" https://ohpr-<PR>-<random>.staging.all-hands.dev/api/v1/health
|
||||
# 401 = server is up (auth required). DNS may take 1-2 min on first deploy.
|
||||
```
|
||||
|
||||
## Running E2E Tests Against Staging
|
||||
|
||||
**Critical: Feature deployments have their own Keycloak instance.** API keys from `app.all-hands.dev` or `$OPENHANDS_API_KEY` will NOT work. You need a test API key issued by the specific feature deployment's Keycloak.
|
||||
|
||||
**You (the agent) cannot obtain this key yourself** — the feature environment requires interactive browser login with credentials you do not have. You must **ask the user** to:
|
||||
1. Log in to the feature deployment at `https://ohpr-<PR>-<random>.staging.all-hands.dev` in their browser
|
||||
2. Generate a test API key from the UI
|
||||
3. Provide the key to you so you can proceed with e2e testing
|
||||
|
||||
Do **not** attempt to log in via the browser or guess credentials. Wait for the user to supply the key before running any e2e tests.
|
||||
|
||||
```python
|
||||
from openhands.workspace import OpenHandsCloudWorkspace
|
||||
|
||||
STAGING = "https://ohpr-<PR>-<random>.staging.all-hands.dev"
|
||||
|
||||
with OpenHandsCloudWorkspace(
|
||||
cloud_api_url=STAGING,
|
||||
cloud_api_key="<test-api-key-for-this-deployment>",
|
||||
) as workspace:
|
||||
# Test the new feature
|
||||
llm = workspace.get_llm()
|
||||
secrets = workspace.get_secrets()
|
||||
print(f"LLM: {llm.model}, secrets: {list(secrets.keys())}")
|
||||
```
|
||||
|
||||
Or run an example script:
|
||||
```bash
|
||||
OPENHANDS_CLOUD_API_KEY="<key>" \
|
||||
OPENHANDS_CLOUD_API_URL="https://ohpr-<PR>-<random>.staging.all-hands.dev" \
|
||||
python examples/02_remote_agent_server/10_cloud_workspace_saas_credentials.py
|
||||
```
|
||||
|
||||
### Recording results
|
||||
|
||||
Both repos support a `.pr/` directory for temporary PR artifacts (design docs, test logs, scripts). These files are automatically removed when the PR is approved — see `.github/workflows/pr-artifacts.yml` and the "PR-Specific Artifacts" section in each repo's `AGENTS.md`.
|
||||
|
||||
Push test output to the `.pr/logs/` directory of whichever repo you're working in:
|
||||
```bash
|
||||
mkdir -p .pr/logs
|
||||
python test_script.py 2>&1 | tee .pr/logs/<test_name>.log
|
||||
git add -f .pr/logs/
|
||||
git commit -m "docs: add e2e test results" && git push
|
||||
```
|
||||
|
||||
Comment on **both PRs** with pass/fail summary and link to logs.
|
||||
|
||||
## Key Gotchas
|
||||
|
||||
| Gotcha | Details |
|
||||
|--------|---------|
|
||||
| **Feature env auth is isolated** | Each `ohpr-*` deployment has its own Keycloak. Production API keys don't work. Agents cannot log in — you must ask the user to provide a test API key from the feature deployment's UI. |
|
||||
| **Two SHAs in deploy.yaml** | `OPENHANDS_SHA` and `OPENHANDS_RUNTIME_IMAGE_TAG` must both be updated. The runtime tag is `<sha>-nikolaik`. |
|
||||
| **Enterprise image must exist** | The Docker CI job on the OpenHands PR must succeed before you can deploy. If it hasn't run, push an empty commit to trigger it. |
|
||||
| **DNS propagation** | First deployment of a new branch takes 1-2 min for DNS. Subsequent deploys are instant. |
|
||||
| **Merge-commit SHA ≠ head SHA** | SDK CI tags Docker images with GitHub Actions' merge-commit SHA, not the PR head SHA. Check the SDK PR description or CI logs for the correct tag. |
|
||||
| **SDK pin blocks merge** | `check-package-versions.yml` prevents merging an OpenHands PR that has `rev` fields in `[tool.poetry.dependencies]`. The SDK must be released to PyPI first. |
|
||||
| **Flow A: stock agent-server is fine** | When only the Cloud API changes, `OpenHandsCloudWorkspace` talks to the Cloud server, not the agent-server. No custom image needed. |
|
||||
| **Flow B: agent-server image is required** | When the server needs new SDK code inside runtime containers, you must pin to the SDK PR's agent-server image. |
|
||||
+4
-3
@@ -1,7 +1,8 @@
|
||||
# CODEOWNERS file for OpenHands repository
|
||||
# See https://docs.github.com/en/repositories/managing-your-repositorys-settings-and-features/customizing-your-repository/about-code-owners
|
||||
|
||||
/frontend/ @hieptl
|
||||
/openhands-ui/ @hieptl
|
||||
/frontend/ @amanape @hieptl
|
||||
/openhands-ui/ @amanape @hieptl
|
||||
/openhands/ @tofarr @malhotra5 @hieptl
|
||||
/enterprise/ @chuckbutkus @tofarr @malhotra5 @jlav @aivong-openhands
|
||||
/enterprise/ @chuckbutkus @tofarr @malhotra5
|
||||
/evaluation/ @xingyaoww @neubig
|
||||
|
||||
@@ -4,7 +4,7 @@ updates:
|
||||
directory: "/"
|
||||
schedule:
|
||||
interval: "daily"
|
||||
open-pull-requests-limit: 5
|
||||
open-pull-requests-limit: 1
|
||||
groups:
|
||||
# put packages in their own group if they have a history of breaking the build or needing to be reverted
|
||||
pre-commit:
|
||||
@@ -29,7 +29,7 @@ updates:
|
||||
directory: "/frontend"
|
||||
schedule:
|
||||
interval: "daily"
|
||||
open-pull-requests-limit: 5
|
||||
open-pull-requests-limit: 1
|
||||
groups:
|
||||
docusaurus:
|
||||
patterns:
|
||||
@@ -51,7 +51,7 @@ updates:
|
||||
schedule:
|
||||
interval: "weekly"
|
||||
day: "wednesday"
|
||||
open-pull-requests-limit: 5
|
||||
open-pull-requests-limit: 1
|
||||
groups:
|
||||
docusaurus:
|
||||
patterns:
|
||||
@@ -72,11 +72,9 @@ updates:
|
||||
directory: "/"
|
||||
schedule:
|
||||
interval: "weekly"
|
||||
open-pull-requests-limit: 5
|
||||
|
||||
- package-ecosystem: "docker"
|
||||
directories:
|
||||
- "containers/*"
|
||||
schedule:
|
||||
interval: "weekly"
|
||||
open-pull-requests-limit: 5
|
||||
|
||||
@@ -12,7 +12,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v6
|
||||
|
||||
@@ -19,7 +19,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install poetry via pipx
|
||||
uses: abatilo/actions-poetry@v4
|
||||
|
||||
@@ -10,7 +10,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout PR branch
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.event.pull_request.head.sha }}
|
||||
fetch-depth: 0
|
||||
@@ -34,7 +34,7 @@ jobs:
|
||||
fi
|
||||
|
||||
- name: Find Comment
|
||||
uses: peter-evans/find-comment@v4
|
||||
uses: peter-evans/find-comment@v3
|
||||
id: find-comment
|
||||
with:
|
||||
issue-number: ${{ github.event.pull_request.number }}
|
||||
|
||||
@@ -24,7 +24,7 @@ jobs:
|
||||
fail-fast: true
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v4
|
||||
- name: Set up Node.js
|
||||
uses: useblacksmith/setup-node@v5
|
||||
with:
|
||||
|
||||
@@ -28,7 +28,7 @@ jobs:
|
||||
fail-fast: true
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v4
|
||||
- name: Set up Node.js
|
||||
uses: useblacksmith/setup-node@v5
|
||||
with:
|
||||
|
||||
@@ -33,39 +33,34 @@ jobs:
|
||||
runs-on: blacksmith
|
||||
outputs:
|
||||
base_image: ${{ steps.define-base-images.outputs.base_image }}
|
||||
platforms: ${{ steps.define-base-images.outputs.platforms }}
|
||||
steps:
|
||||
- name: Define base images
|
||||
shell: bash
|
||||
id: define-base-images
|
||||
run: |
|
||||
if [[ "$GITHUB_EVENT_NAME" == "pull_request" ]]; then
|
||||
platforms="linux/amd64"
|
||||
json=$(jq -n -c --arg platforms "$platforms" '[
|
||||
{ image: "nikolaik/python-nodejs:python3.12-nodejs22-slim", tag: "nikolaik", platforms: $platforms }
|
||||
json=$(jq -n -c '[
|
||||
{ image: "nikolaik/python-nodejs:python3.12-nodejs22", tag: "nikolaik" }
|
||||
]')
|
||||
else
|
||||
platforms="linux/amd64,linux/arm64"
|
||||
json=$(jq -n -c --arg platforms "$platforms" '[
|
||||
{ image: "nikolaik/python-nodejs:python3.12-nodejs22-slim", tag: "nikolaik", platforms: $platforms },
|
||||
{ image: "ubuntu:24.04", tag: "ubuntu", platforms: $platforms }
|
||||
json=$(jq -n -c '[
|
||||
{ image: "nikolaik/python-nodejs:python3.12-nodejs22", tag: "nikolaik" },
|
||||
{ image: "ubuntu:24.04", tag: "ubuntu" }
|
||||
]')
|
||||
fi
|
||||
echo "base_image=$json" >> "$GITHUB_OUTPUT"
|
||||
echo "platforms=$platforms" >> "$GITHUB_OUTPUT"
|
||||
|
||||
# Builds the OpenHands Docker images
|
||||
ghcr_build_app:
|
||||
name: Build App Image
|
||||
runs-on: blacksmith-4vcpu-ubuntu-2204
|
||||
if: "!(github.event_name == 'push' && startsWith(github.ref, 'refs/tags/ext-v'))"
|
||||
needs: define-matrix
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.event.pull_request.head.sha }}
|
||||
- name: Set up QEMU
|
||||
@@ -87,7 +82,7 @@ jobs:
|
||||
- name: Build and push app image
|
||||
if: "!github.event.pull_request.head.repo.fork"
|
||||
run: |
|
||||
./containers/build.sh -i openhands -o ${{ env.REPO_OWNER }} --push -p ${{ needs.define-matrix.outputs.platforms }}
|
||||
./containers/build.sh -i openhands -o ${{ env.REPO_OWNER }} --push
|
||||
|
||||
# Builds the runtime Docker images
|
||||
ghcr_build_runtime:
|
||||
@@ -103,7 +98,7 @@ jobs:
|
||||
base_image: ${{ fromJson(needs.define-matrix.outputs.base_image) }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.event.pull_request.head.sha }}
|
||||
- name: Set up QEMU
|
||||
@@ -141,7 +136,7 @@ jobs:
|
||||
shell: bash
|
||||
run: |
|
||||
|
||||
./containers/build.sh -i runtime -o ${{ env.REPO_OWNER }} -t ${{ matrix.base_image.tag }} --dry -p ${{ matrix.base_image.platforms }}
|
||||
./containers/build.sh -i runtime -o ${{ env.REPO_OWNER }} -t ${{ matrix.base_image.tag }} --dry
|
||||
|
||||
DOCKER_BUILD_JSON=$(jq -c . < docker-build-dry.json)
|
||||
echo "DOCKER_TAGS=$(echo "$DOCKER_BUILD_JSON" | jq -r '.tags | join(",")')" >> $GITHUB_ENV
|
||||
@@ -185,7 +180,7 @@ jobs:
|
||||
if: github.event.pull_request.head.repo.fork != true
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.event.pull_request.head.sha }}
|
||||
|
||||
@@ -215,7 +210,6 @@ jobs:
|
||||
type=semver,pattern={{version}}
|
||||
type=semver,pattern={{major}}.{{minor}}
|
||||
type=semver,pattern={{major}}
|
||||
type=match,pattern=cloud-\d+\.\d+\.\d+
|
||||
flavor: |
|
||||
latest=auto
|
||||
prefix=
|
||||
@@ -225,9 +219,11 @@ jobs:
|
||||
- name: Determine app image tag
|
||||
shell: bash
|
||||
run: |
|
||||
# Use the commit SHA to pin the exact app image built by ghcr_build_app,
|
||||
# rather than a mutable branch tag like "main" which can serve stale cached layers.
|
||||
echo "OPENHANDS_DOCKER_TAG=${RELEVANT_SHA}" >> $GITHUB_ENV
|
||||
# Duplicated with build.sh
|
||||
sanitized_ref_name=$(echo "$GITHUB_REF_NAME" | sed 's/[^a-zA-Z0-9.-]\+/-/g')
|
||||
OPENHANDS_BUILD_VERSION=$sanitized_ref_name
|
||||
sanitized_ref_name=$(echo "$sanitized_ref_name" | tr '[:upper:]' '[:lower:]') # lower case is required in tagging
|
||||
echo "OPENHANDS_DOCKER_TAG=${sanitized_ref_name}" >> $GITHUB_ENV
|
||||
- name: Build and push Docker image
|
||||
uses: useblacksmith/build-push-action@v1
|
||||
with:
|
||||
@@ -260,7 +256,7 @@ jobs:
|
||||
runs-on: blacksmith-4vcpu-ubuntu-2204
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Get short SHA
|
||||
id: short_sha
|
||||
|
||||
@@ -14,7 +14,7 @@ jobs:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.head_ref }}
|
||||
repository: ${{ github.event.pull_request.head.repo.full_name }}
|
||||
@@ -63,7 +63,7 @@ jobs:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.head_ref }}
|
||||
repository: ${{ github.event.pull_request.head.repo.full_name }}
|
||||
|
||||
@@ -21,7 +21,7 @@ jobs:
|
||||
name: Lint frontend
|
||||
runs-on: blacksmith-4vcpu-ubuntu-2204
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
- name: Install Node.js 22
|
||||
uses: useblacksmith/setup-node@v5
|
||||
with:
|
||||
@@ -42,7 +42,7 @@ jobs:
|
||||
name: Lint python
|
||||
runs-on: blacksmith-4vcpu-ubuntu-2204
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Set up python
|
||||
@@ -59,7 +59,7 @@ jobs:
|
||||
name: Lint enterprise python
|
||||
runs-on: blacksmith-4vcpu-ubuntu-2204
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Set up python
|
||||
|
||||
@@ -27,7 +27,7 @@ jobs:
|
||||
current-version: ${{ steps.version-check.outputs.current-version }}
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 2 # Need previous commit to compare
|
||||
|
||||
@@ -63,7 +63,7 @@ jobs:
|
||||
shell: bash
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Bun
|
||||
uses: oven-sh/setup-bun@v2
|
||||
|
||||
@@ -86,7 +86,7 @@ jobs:
|
||||
runs-on: "${{ inputs.runner || 'ubuntu-latest' }}"
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v6
|
||||
|
||||
@@ -1,136 +0,0 @@
|
||||
---
|
||||
name: PR Artifacts
|
||||
|
||||
on:
|
||||
workflow_dispatch: # Manual trigger for testing
|
||||
pull_request:
|
||||
types: [opened, synchronize, reopened]
|
||||
branches: [main]
|
||||
pull_request_review:
|
||||
types: [submitted]
|
||||
|
||||
jobs:
|
||||
# Auto-remove .pr/ directory when a reviewer approves
|
||||
cleanup-on-approval:
|
||||
concurrency:
|
||||
group: cleanup-pr-artifacts-${{ github.event.pull_request.number }}
|
||||
cancel-in-progress: false
|
||||
if: github.event_name == 'pull_request_review' && github.event.review.state == 'approved'
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
steps:
|
||||
- name: Check if fork PR
|
||||
id: check-fork
|
||||
run: |
|
||||
if [ "${{ github.event.pull_request.head.repo.full_name }}" != "${{ github.event.pull_request.base.repo.full_name }}" ]; then
|
||||
echo "is_fork=true" >> $GITHUB_OUTPUT
|
||||
echo "::notice::Fork PR detected - skipping auto-cleanup (manual removal required)"
|
||||
else
|
||||
echo "is_fork=false" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
- uses: actions/checkout@v5
|
||||
if: steps.check-fork.outputs.is_fork == 'false'
|
||||
with:
|
||||
ref: ${{ github.event.pull_request.head.ref }}
|
||||
token: ${{ secrets.ALLHANDS_BOT_GITHUB_PAT }}
|
||||
|
||||
- name: Remove .pr/ directory
|
||||
id: remove
|
||||
if: steps.check-fork.outputs.is_fork == 'false'
|
||||
run: |
|
||||
if [ -d ".pr" ]; then
|
||||
git config user.name "allhands-bot"
|
||||
git config user.email "allhands-bot@users.noreply.github.com"
|
||||
git rm -rf .pr/
|
||||
git commit -m "chore: Remove PR-only artifacts [automated]"
|
||||
git push || {
|
||||
echo "::error::Failed to push cleanup commit. Check branch protection rules."
|
||||
exit 1
|
||||
}
|
||||
echo "removed=true" >> $GITHUB_OUTPUT
|
||||
echo "::notice::Removed .pr/ directory"
|
||||
else
|
||||
echo "removed=false" >> $GITHUB_OUTPUT
|
||||
echo "::notice::No .pr/ directory to remove"
|
||||
fi
|
||||
|
||||
- name: Update PR comment after cleanup
|
||||
if: steps.check-fork.outputs.is_fork == 'false' && steps.remove.outputs.removed == 'true'
|
||||
uses: actions/github-script@v7
|
||||
with:
|
||||
script: |
|
||||
const marker = '<!-- pr-artifacts-notice -->';
|
||||
const body = `${marker}
|
||||
✅ **PR Artifacts Cleaned Up**
|
||||
|
||||
The \`.pr/\` directory has been automatically removed.
|
||||
`;
|
||||
|
||||
const { data: comments } = await github.rest.issues.listComments({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: context.issue.number,
|
||||
});
|
||||
|
||||
const existing = comments.find(c => c.body.includes(marker));
|
||||
if (existing) {
|
||||
await github.rest.issues.updateComment({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
comment_id: existing.id,
|
||||
body: body,
|
||||
});
|
||||
}
|
||||
|
||||
# Warn if .pr/ directory exists (will be auto-removed on approval)
|
||||
check-pr-artifacts:
|
||||
if: github.event_name == 'pull_request'
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: write
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
|
||||
- name: Check for .pr/ directory
|
||||
id: check
|
||||
run: |
|
||||
if [ -d ".pr" ]; then
|
||||
echo "exists=true" >> $GITHUB_OUTPUT
|
||||
echo "::warning::.pr/ directory exists and will be automatically removed when the PR is approved. For fork PRs, manual removal is required before merging."
|
||||
else
|
||||
echo "exists=false" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
- name: Post or update PR comment
|
||||
if: steps.check.outputs.exists == 'true'
|
||||
uses: actions/github-script@v7
|
||||
with:
|
||||
script: |
|
||||
const marker = '<!-- pr-artifacts-notice -->';
|
||||
const body = `${marker}
|
||||
📁 **PR Artifacts Notice**
|
||||
|
||||
This PR contains a \`.pr/\` directory with PR-specific documents. This directory will be **automatically removed** when the PR is approved.
|
||||
|
||||
> For fork PRs: Manual removal is required before merging.
|
||||
`;
|
||||
|
||||
const { data: comments } = await github.rest.issues.listComments({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: context.issue.number,
|
||||
});
|
||||
|
||||
const existing = comments.find(c => c.body.includes(marker));
|
||||
if (!existing) {
|
||||
await github.rest.issues.createComment({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: context.issue.number,
|
||||
body: body,
|
||||
});
|
||||
}
|
||||
@@ -28,7 +28,7 @@ jobs:
|
||||
steps:
|
||||
- name: Download review trace artifact
|
||||
id: download-trace
|
||||
uses: dawidd6/action-download-artifact@v15
|
||||
uses: dawidd6/action-download-artifact@v6
|
||||
continue-on-error: true
|
||||
with:
|
||||
workflow: pr-review-by-openhands.yml
|
||||
|
||||
@@ -30,7 +30,7 @@ jobs:
|
||||
pull-requests: write
|
||||
contents: write
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Docker Buildx
|
||||
id: buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
@@ -78,7 +78,7 @@ jobs:
|
||||
matrix:
|
||||
python-version: ["3.12"]
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
- name: Install poetry via pipx
|
||||
run: pipx install poetry
|
||||
- name: Set up Python
|
||||
@@ -111,9 +111,9 @@ jobs:
|
||||
pull-requests: write
|
||||
contents: write
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- uses: actions/download-artifact@v7
|
||||
- uses: actions/download-artifact@v6
|
||||
id: download
|
||||
with:
|
||||
pattern: coverage-*
|
||||
|
||||
@@ -18,12 +18,12 @@ on:
|
||||
jobs:
|
||||
release:
|
||||
runs-on: blacksmith-4vcpu-ubuntu-2204
|
||||
# Run when manually dispatched for "app server" OR for tag pushes that don't contain '-cli' and don't start with 'cloud-'
|
||||
# Run when manually dispatched for "app server" OR for tag pushes that don't contain '-cli'
|
||||
if: |
|
||||
(github.event_name == 'workflow_dispatch' && github.event.inputs.reason == 'app server')
|
||||
|| (github.event_name == 'push' && startsWith(github.ref, 'refs/tags/') && !contains(github.ref, '-cli') && !startsWith(github.ref, 'refs/tags/cloud-'))
|
||||
|| (github.event_name == 'push' && startsWith(github.ref, 'refs/tags/') && !contains(github.ref, '-cli'))
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v4
|
||||
- uses: useblacksmith/setup-python@v6
|
||||
with:
|
||||
python-version: 3.12
|
||||
|
||||
@@ -11,7 +11,7 @@ jobs:
|
||||
runs-on: blacksmith-4vcpu-ubuntu-2204
|
||||
if: github.repository == 'OpenHands/OpenHands'
|
||||
steps:
|
||||
- uses: actions/stale@v10
|
||||
- uses: actions/stale@v9
|
||||
with:
|
||||
stale-issue-message: 'This issue is stale because it has been open for 40 days with no activity. Remove the stale label or leave a comment, otherwise it will be closed in 10 days.'
|
||||
stale-pr-message: 'This PR is stale because it has been open for 40 days with no activity. Remove the stale label or leave a comment, otherwise it will be closed in 10 days.'
|
||||
|
||||
@@ -22,7 +22,7 @@ jobs:
|
||||
runs-on: blacksmith-4vcpu-ubuntu-2204
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v4
|
||||
- uses: oven-sh/setup-bun@v2
|
||||
with:
|
||||
bun-version-file: "openhands-ui/.bun-version"
|
||||
|
||||
@@ -36,76 +36,6 @@ then re-run the command to ensure it passes. Common issues include:
|
||||
- Be especially careful with `git reset --hard` after staging files, as it will remove accidentally staged files
|
||||
- When remote has new changes, use `git fetch upstream && git rebase upstream/<branch>` on the same branch
|
||||
|
||||
## Lockfile Regeneration (Preserve Original Tool Versions)
|
||||
|
||||
When regenerating lockfiles (poetry.lock, uv.lock, etc.), you MUST use the same tool version that originally generated the lockfile to avoid unnecessary diff noise. Each lockfile contains a version header indicating which tool version was used.
|
||||
|
||||
### Poetry (poetry.lock)
|
||||
|
||||
1. Extract the version from the lockfile header:
|
||||
```bash
|
||||
POETRY_VERSION=$(grep -m1 "^# This file is automatically @generated by Poetry" poetry.lock | sed 's/.*Poetry \([0-9.]*\).*/\1/')
|
||||
```
|
||||
2. If a version is found, install that specific version:
|
||||
```bash
|
||||
pipx install poetry==$POETRY_VERSION --force
|
||||
```
|
||||
3. Then regenerate the lockfile:
|
||||
```bash
|
||||
poetry lock --no-update
|
||||
```
|
||||
|
||||
### uv (uv.lock)
|
||||
|
||||
1. Extract the version from the lockfile header:
|
||||
```bash
|
||||
UV_VERSION=$(grep -m1 "^# This file was autogenerated by uv" uv.lock | sed 's/.*uv version \([0-9.]*\).*/\1/')
|
||||
```
|
||||
2. If a version is found, install that specific version:
|
||||
```bash
|
||||
pipx install uv==$UV_VERSION --force
|
||||
```
|
||||
3. Then regenerate the lockfile:
|
||||
```bash
|
||||
uv lock
|
||||
```
|
||||
|
||||
This ensures that lockfile updates only contain actual dependency changes, not tool version migration artifacts.
|
||||
|
||||
## PR-Specific Artifacts (`.pr/` directory)
|
||||
|
||||
When working on a PR that requires design documents, scripts meant for development-only, or other temporary artifacts that should NOT be merged to main, store them in a `.pr/` directory at the repository root.
|
||||
|
||||
### Usage
|
||||
|
||||
```
|
||||
.pr/
|
||||
├── design.md # Design decisions and architecture notes
|
||||
├── analysis.md # Investigation or debugging notes
|
||||
├── logs/ # Test output or CI logs for reviewer reference
|
||||
└── notes.md # Any other PR-specific content
|
||||
```
|
||||
|
||||
### How It Works
|
||||
|
||||
1. **Notification**: When `.pr/` exists, a comment is posted to the PR conversation alerting reviewers
|
||||
2. **Auto-cleanup**: When the PR is approved, the `.pr/` directory is automatically removed via `.github/workflows/pr-artifacts.yml`
|
||||
3. **Fork PRs**: Auto-cleanup cannot push to forks, so manual removal is required before merging
|
||||
|
||||
### Important Notes
|
||||
|
||||
- Do NOT put anything in `.pr/` that needs to be preserved after merge
|
||||
- The `.pr/` check passes (green ✅) during development — it only posts a notification, not a blocking error
|
||||
- For fork PRs: You must manually remove `.pr/` before the PR can be merged
|
||||
|
||||
### When to Use
|
||||
|
||||
- Complex refactoring that benefits from written design rationale
|
||||
- Debugging sessions where you want to document your investigation
|
||||
- E2E test results or logs that demonstrate a cross-repo feature works
|
||||
- Feature implementations that need temporary planning docs
|
||||
- Any analysis that helps reviewers understand the PR but isn't needed long-term
|
||||
|
||||
## Repository Structure
|
||||
Backend:
|
||||
- Located in the `openhands` directory
|
||||
|
||||
@@ -125,17 +125,6 @@ For example, a PR title could be:
|
||||
- If your changes are user-facing (e.g. a new feature in the UI, a change in behavior, or a bugfix),
|
||||
please include a short message that we can add to our changelog
|
||||
|
||||
## Becoming a Maintainer
|
||||
|
||||
For contributors who have made significant and sustained contributions to the project, there is a possibility of joining the maintainer team.
|
||||
The process for this is as follows:
|
||||
|
||||
1. Any contributor who has made sustained and high-quality contributions to the codebase can be nominated by any maintainer. If you feel that you may qualify you can reach out to any of the maintainers that have reviewed your PRs and ask if you can be nominated.
|
||||
2. Once a maintainer nominates a new maintainer, there will be a discussion period among the maintainers for at least 3 days.
|
||||
3. If no concerns are raised the nomination will be accepted by acclamation, and if concerns are raised there will be a discussion and possible vote.
|
||||
|
||||
Note that just making many PRs does not immediately imply that you will become a maintainer. We will be looking at sustained high-quality contributions over a period of time, as well as good teamwork and adherence to our [Code of Conduct](./CODE_OF_CONDUCT.md).
|
||||
|
||||
## Need Help?
|
||||
|
||||
- **Slack**: [Join our community](https://openhands.dev/joinslack)
|
||||
|
||||
@@ -23,6 +23,7 @@
|
||||
<a href="https://www.readme-i18n.com/OpenHands/OpenHands?lang=pt">Português</a> |
|
||||
<a href="https://www.readme-i18n.com/OpenHands/OpenHands?lang=ru">Русский</a> |
|
||||
<a href="https://www.readme-i18n.com/OpenHands/OpenHands?lang=zh">中文</a>
|
||||
|
||||
</div>
|
||||
|
||||
<hr>
|
||||
@@ -83,71 +84,3 @@ All our work is available under the MIT license, except for the `enterprise/` di
|
||||
The core `openhands` and `agent-server` Docker images are fully MIT-licensed as well.
|
||||
|
||||
If you need help with anything, or just want to chat, [come find us on Slack](https://dub.sh/openhands).
|
||||
|
||||
<hr>
|
||||
|
||||
### Thank You to Our Contributors
|
||||
|
||||
<div align="center">
|
||||
|
||||
[](https://github.com/OpenHands/OpenHands/graphs/contributors)
|
||||
|
||||
</div>
|
||||
|
||||
<hr>
|
||||
|
||||
### Trusted by Engineers at
|
||||
|
||||
<div align="center">
|
||||
<br/><br/>
|
||||
<picture>
|
||||
<source media="(prefers-color-scheme: dark)" srcset="https://assets.openhands.dev/logos/external/white/tiktok.svg">
|
||||
<img src="https://assets.openhands.dev/logos/external/black/tiktok.svg" alt="TikTok" height="17" hspace="5">
|
||||
</picture>
|
||||
<picture>
|
||||
<source media="(prefers-color-scheme: dark)" srcset="https://assets.openhands.dev/logos/external/white/vmware.svg">
|
||||
<img src="https://assets.openhands.dev/logos/external/black/vmware.svg" alt="VMware" height="17" hspace="5">
|
||||
</picture>
|
||||
<picture>
|
||||
<source media="(prefers-color-scheme: dark)" srcset="https://assets.openhands.dev/logos/external/white/roche.svg">
|
||||
<img src="https://assets.openhands.dev/logos/external/black/roche.svg" alt="Roche" height="17" hspace="5">
|
||||
</picture>
|
||||
<picture>
|
||||
<source media="(prefers-color-scheme: dark)" srcset="https://assets.openhands.dev/logos/external/white/amazon.svg">
|
||||
<img src="https://assets.openhands.dev/logos/external/black/amazon.svg" alt="Amazon" height="17" hspace="5">
|
||||
</picture>
|
||||
<picture>
|
||||
<source media="(prefers-color-scheme: dark)" srcset="https://assets.openhands.dev/logos/external/white/c3-ai.svg">
|
||||
<img src="https://assets.openhands.dev/logos/external/black/c3-ai.svg" alt="C3 AI" height="17" hspace="5">
|
||||
</picture>
|
||||
<picture>
|
||||
<source media="(prefers-color-scheme: dark)" srcset="https://assets.openhands.dev/logos/external/white/netflix.svg">
|
||||
<img src="https://assets.openhands.dev/logos/external/black/netflix.svg" alt="Netflix" height="17" hspace="5">
|
||||
</picture>
|
||||
<picture>
|
||||
<source media="(prefers-color-scheme: dark)" srcset="https://assets.openhands.dev/logos/external/white/mastercard.svg">
|
||||
<img src="https://assets.openhands.dev/logos/external/black/mastercard.svg" alt="Mastercard" height="17" hspace="5">
|
||||
</picture>
|
||||
<picture>
|
||||
<source media="(prefers-color-scheme: dark)" srcset="https://assets.openhands.dev/logos/external/white/red-hat.svg">
|
||||
<img src="https://assets.openhands.dev/logos/external/black/red-hat.svg" alt="Red Hat" height="17" hspace="5">
|
||||
</picture>
|
||||
<picture>
|
||||
<source media="(prefers-color-scheme: dark)" srcset="https://assets.openhands.dev/logos/external/white/mongodb.svg">
|
||||
<img src="https://assets.openhands.dev/logos/external/black/mongodb.svg" alt="MongoDB" height="17" hspace="5">
|
||||
</picture>
|
||||
<picture>
|
||||
<source media="(prefers-color-scheme: dark)" srcset="https://assets.openhands.dev/logos/external/white/apple.svg">
|
||||
<img src="https://assets.openhands.dev/logos/external/black/apple.svg" alt="Apple" height="17" hspace="5">
|
||||
</picture>
|
||||
<picture>
|
||||
<source media="(prefers-color-scheme: dark)" srcset="https://assets.openhands.dev/logos/external/white/nvidia.svg">
|
||||
<img src="https://assets.openhands.dev/logos/external/black/nvidia.svg" alt="NVIDIA" height="17" hspace="5">
|
||||
</picture>
|
||||
<picture>
|
||||
<source media="(prefers-color-scheme: dark)" srcset="https://assets.openhands.dev/logos/external/white/google.svg">
|
||||
<img src="https://assets.openhands.dev/logos/external/black/google.svg" alt="Google" height="17" hspace="5">
|
||||
</picture>
|
||||
</div>
|
||||
|
||||
</div>
|
||||
|
||||
@@ -296,7 +296,7 @@ classpath = "my_package.my_module.MyCustomAgent"
|
||||
#user_id = 1000
|
||||
|
||||
# Container image to use for the sandbox
|
||||
#base_container_image = "nikolaik/python-nodejs:python3.12-nodejs22-slim"
|
||||
#base_container_image = "nikolaik/python-nodejs:python3.12-nodejs22"
|
||||
|
||||
# Use host network
|
||||
#use_host_network = false
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
ARG OPENHANDS_BUILD_VERSION=dev
|
||||
FROM node:25.8-trixie-slim AS frontend-builder
|
||||
FROM node:25.2-trixie-slim AS frontend-builder
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
@@ -50,7 +50,7 @@ RUN mkdir -p $FILE_STORE_PATH
|
||||
RUN mkdir -p $WORKSPACE_BASE
|
||||
|
||||
RUN apt-get update -y \
|
||||
&& apt-get install -y curl git ssh sudo \
|
||||
&& apt-get install -y curl ssh sudo \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Default is 1000, but OSX is often 501
|
||||
@@ -73,17 +73,6 @@ ENV VIRTUAL_ENV=/app/.venv \
|
||||
|
||||
COPY --chown=openhands:openhands --chmod=770 --from=backend-builder ${VIRTUAL_ENV} ${VIRTUAL_ENV}
|
||||
|
||||
# Pin pip to a known-good version (reproducible builds) and fix CVE-2025-8869
|
||||
# Pin both venv pip and system pip (Trivy scans both)
|
||||
# - `python -m pip` uses the venv because `PATH` is prefixed with `${VIRTUAL_ENV}/bin`
|
||||
# - `/usr/local/bin/python3 -m pip` uses the system interpreter regardless of `PATH`
|
||||
ARG PIP_VERSION=26.0.1
|
||||
RUN python -m pip install --no-cache-dir "pip==${PIP_VERSION}"
|
||||
|
||||
USER root
|
||||
RUN /usr/local/bin/python3 -m pip install --no-cache-dir "pip==${PIP_VERSION}" --break-system-packages
|
||||
USER openhands
|
||||
|
||||
COPY --chown=openhands:openhands --chmod=770 ./skills ./skills
|
||||
COPY --chown=openhands:openhands --chmod=770 ./openhands ./openhands
|
||||
COPY --chown=openhands:openhands --chmod=777 ./openhands/runtime/plugins ./openhands/runtime/plugins
|
||||
|
||||
+3
-8
@@ -8,17 +8,15 @@ push=0
|
||||
load=0
|
||||
tag_suffix=""
|
||||
dry_run=0
|
||||
platform_override=""
|
||||
|
||||
# Function to display usage information
|
||||
usage() {
|
||||
echo "Usage: $0 -i <image_name> [-o <org_name>] [--push] [--load] [-t <tag_suffix>] [-p <platform>] [--dry]"
|
||||
echo "Usage: $0 -i <image_name> [-o <org_name>] [--push] [--load] [-t <tag_suffix>] [--dry]"
|
||||
echo " -i: Image name (required)"
|
||||
echo " -o: Organization name"
|
||||
echo " --push: Push the image"
|
||||
echo " --load: Load the image"
|
||||
echo " -t: Tag suffix"
|
||||
echo " -p: Platform(s) to build for (e.g. linux/amd64 or linux/amd64,linux/arm64)"
|
||||
echo " --dry: Don't build, only create build-args.json"
|
||||
exit 1
|
||||
}
|
||||
@@ -31,7 +29,6 @@ while [[ $# -gt 0 ]]; do
|
||||
--push) push=1; shift ;;
|
||||
--load) load=1; shift ;;
|
||||
-t) tag_suffix="$2"; shift 2 ;;
|
||||
-p) platform_override="$2"; shift 2 ;;
|
||||
--dry) dry_run=1; shift ;;
|
||||
*) usage ;;
|
||||
esac
|
||||
@@ -137,10 +134,8 @@ fi
|
||||
|
||||
echo "Args: $args"
|
||||
|
||||
# Determine the platform(s) to build for
|
||||
if [[ -n "$platform_override" ]]; then
|
||||
platform="$platform_override"
|
||||
elif [[ $load -eq 1 ]]; then
|
||||
# Modify the platform selection based on --load flag
|
||||
if [[ $load -eq 1 ]]; then
|
||||
# When loading, build only for the current platform
|
||||
platform=$(docker version -f '{{.Server.Os}}/{{.Server.Arch}}')
|
||||
else
|
||||
|
||||
@@ -13,7 +13,7 @@ services:
|
||||
- DOCKER_HOST_ADDR=host.docker.internal
|
||||
#
|
||||
- AGENT_SERVER_IMAGE_REPOSITORY=${AGENT_SERVER_IMAGE_REPOSITORY:-ghcr.io/openhands/agent-server}
|
||||
- AGENT_SERVER_IMAGE_TAG=${AGENT_SERVER_IMAGE_TAG:-1.15.0-python}
|
||||
- AGENT_SERVER_IMAGE_TAG=${AGENT_SERVER_IMAGE_TAG:-1.12.0-python}
|
||||
- SANDBOX_USER_ID=${SANDBOX_USER_ID:-1234}
|
||||
- WORKSPACE_MOUNT_PATH=${WORKSPACE_BASE:-$PWD/workspace}
|
||||
ports:
|
||||
|
||||
+1
-1
@@ -8,7 +8,7 @@ services:
|
||||
container_name: openhands-app-${DATE:-}
|
||||
environment:
|
||||
- AGENT_SERVER_IMAGE_REPOSITORY=${AGENT_SERVER_IMAGE_REPOSITORY:-ghcr.io/openhands/agent-server}
|
||||
- AGENT_SERVER_IMAGE_TAG=${AGENT_SERVER_IMAGE_TAG:-1.15.0-python}
|
||||
- AGENT_SERVER_IMAGE_TAG=${AGENT_SERVER_IMAGE_TAG:-1.12.0-python}
|
||||
#- SANDBOX_USER_ID=${SANDBOX_USER_ID:-1234} # enable this only if you want a specific non-root sandbox user but you will have to manually adjust permissions of ~/.openhands for this user
|
||||
- WORKSPACE_MOUNT_PATH=${WORKSPACE_BASE:-$PWD/workspace}
|
||||
ports:
|
||||
|
||||
@@ -10,7 +10,7 @@ LABEL com.datadoghq.tags.env="${DD_ENV}"
|
||||
# Apply security updates to fix CVEs
|
||||
RUN apt-get update && \
|
||||
apt-get install -y curl && \
|
||||
curl -fsSL https://deb.nodesource.com/setup_24.x | bash - && \
|
||||
curl -fsSL https://deb.nodesource.com/setup_20.x | bash - && \
|
||||
apt-get install -y nodejs && \
|
||||
apt-get install -y jq gettext && \
|
||||
# Apply security updates for packages with available fixes
|
||||
|
||||
@@ -723,13 +723,11 @@
|
||||
"https://$WEB_HOST/slack/keycloak-callback",
|
||||
"https://$WEB_HOST/oauth/device/keycloak-callback",
|
||||
"https://$WEB_HOST/api/email/verified",
|
||||
"/realms/$KEYCLOAK_REALM_NAME/$KEYCLOAK_CLIENT_ID/*",
|
||||
"https://laminar.$WEB_HOST/api/auth/callback/keycloak"
|
||||
"/realms/$KEYCLOAK_REALM_NAME/$KEYCLOAK_CLIENT_ID/*"
|
||||
],
|
||||
"webOrigins": [
|
||||
"https://$WEB_HOST",
|
||||
"https://$AUTH_WEB_HOST",
|
||||
"https://laminar.$WEB_HOST"
|
||||
"https://$AUTH_WEB_HOST"
|
||||
],
|
||||
"notBefore": 0,
|
||||
"bearerOnly": false,
|
||||
|
||||
@@ -43,20 +43,15 @@ class GithubV1CallbackProcessor(EventCallbackProcessor):
|
||||
event: Event,
|
||||
) -> EventCallbackResult | None:
|
||||
"""Process events for GitHub V1 integration."""
|
||||
# Only handle ConversationStateUpdateEvent for execution_status
|
||||
# Only handle ConversationStateUpdateEvent
|
||||
if not isinstance(event, ConversationStateUpdateEvent):
|
||||
return None
|
||||
|
||||
if event.key != 'execution_status':
|
||||
# Only act when execution has finished
|
||||
if not (event.key == 'execution_status' and event.value == 'finished'):
|
||||
return None
|
||||
|
||||
# Log ALL terminal states for monitoring (finished, error, stuck)
|
||||
_logger.info('[GitHub V1] Callback agent state was %s', event)
|
||||
|
||||
# Only request summary when execution has finished successfully
|
||||
if event.value != 'finished':
|
||||
return None
|
||||
|
||||
_logger.info(
|
||||
'[GitHub V1] Should request summary: %s', self.should_request_summary
|
||||
)
|
||||
|
||||
@@ -10,7 +10,6 @@ from integrations.github.github_types import (
|
||||
)
|
||||
from integrations.models import Message
|
||||
from integrations.resolver_context import ResolverUserContext
|
||||
from integrations.resolver_org_router import resolve_org_for_repo
|
||||
from integrations.types import ResolverViewInterface, UserData
|
||||
from integrations.utils import (
|
||||
ENABLE_PROACTIVE_CONVERSATION_STARTERS,
|
||||
@@ -27,7 +26,6 @@ from server.auth.token_manager import TokenManager
|
||||
from server.config import get_config
|
||||
from storage.org_store import OrgStore
|
||||
from storage.proactive_conversation_store import ProactiveConversationStore
|
||||
from storage.saas_conversation_store import SaasConversationStore
|
||||
from storage.saas_secrets_store import SaasSecretsStore
|
||||
|
||||
from openhands.agent_server.models import SendMessageRequest
|
||||
@@ -43,14 +41,16 @@ from openhands.integrations.github.github_service import GithubServiceImpl
|
||||
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE, ProviderType
|
||||
from openhands.integrations.service_types import Comment
|
||||
from openhands.sdk import TextContent
|
||||
from openhands.server.services.conversation_service import start_conversation
|
||||
from openhands.server.services.conversation_service import (
|
||||
initialize_conversation,
|
||||
start_conversation,
|
||||
)
|
||||
from openhands.server.user_auth.user_auth import UserAuth
|
||||
from openhands.storage.data_models.conversation_metadata import (
|
||||
ConversationMetadata,
|
||||
ConversationTrigger,
|
||||
)
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
from openhands.utils.conversation_summary import get_default_conversation_title
|
||||
|
||||
OH_LABEL, INLINE_OH_LABEL = get_oh_labels(HOST)
|
||||
|
||||
@@ -154,17 +154,12 @@ class GithubIssue(ResolverViewInterface):
|
||||
return user_secrets.custom_secrets if user_secrets else None
|
||||
|
||||
async def initialize_new_conversation(self) -> ConversationMetadata:
|
||||
# FIXME: Handle if initialize_conversation returns None
|
||||
|
||||
self.v1_enabled = await is_v1_enabled_for_github_resolver(
|
||||
self.user_info.keycloak_user_id
|
||||
)
|
||||
|
||||
# Resolve target org based on claimed git organizations
|
||||
self.resolved_org_id = await resolve_org_for_repo(
|
||||
provider='github',
|
||||
full_repo_name=self.full_repo_name,
|
||||
keycloak_user_id=self.user_info.keycloak_user_id,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f'[GitHub V1]: User flag found for {self.user_info.keycloak_user_id} is {self.v1_enabled}'
|
||||
)
|
||||
@@ -178,28 +173,16 @@ class GithubIssue(ResolverViewInterface):
|
||||
selected_repository=self.full_repo_name,
|
||||
)
|
||||
|
||||
# Create the conversation store with resolver org routing
|
||||
# (bypasses initialize_conversation to avoid threading enterprise-only
|
||||
# resolver_org_id through the generic OSS interface)
|
||||
store = await SaasConversationStore.get_resolver_instance(
|
||||
get_config(),
|
||||
self.user_info.keycloak_user_id,
|
||||
self.resolved_org_id,
|
||||
)
|
||||
|
||||
conversation_id = uuid4().hex
|
||||
conversation_metadata = ConversationMetadata(
|
||||
trigger=ConversationTrigger.RESOLVER,
|
||||
conversation_id=conversation_id,
|
||||
title=get_default_conversation_title(conversation_id),
|
||||
conversation_metadata: ConversationMetadata = await initialize_conversation( # type: ignore[assignment]
|
||||
user_id=self.user_info.keycloak_user_id,
|
||||
conversation_id=None,
|
||||
selected_repository=self.full_repo_name,
|
||||
selected_branch=self._get_branch_name(),
|
||||
conversation_trigger=ConversationTrigger.RESOLVER,
|
||||
git_provider=ProviderType.GITHUB,
|
||||
)
|
||||
await store.save_metadata(conversation_metadata)
|
||||
|
||||
self.conversation_id = conversation_id
|
||||
self.conversation_id = conversation_metadata.conversation_id
|
||||
return conversation_metadata
|
||||
|
||||
async def create_new_conversation(
|
||||
@@ -311,10 +294,7 @@ class GithubIssue(ResolverViewInterface):
|
||||
)
|
||||
|
||||
# Set up the GitHub user context for the V1 system
|
||||
github_user_context = ResolverUserContext(
|
||||
saas_user_auth=saas_user_auth,
|
||||
resolver_org_id=self.resolved_org_id,
|
||||
)
|
||||
github_user_context = ResolverUserContext(saas_user_auth=saas_user_auth)
|
||||
setattr(injector_state, USER_CONTEXT_ATTR, github_user_context)
|
||||
|
||||
async with get_app_conversation_service(
|
||||
@@ -342,7 +322,7 @@ class GithubIssue(ResolverViewInterface):
|
||||
'full_repo_name': self.full_repo_name,
|
||||
'installation_id': self.installation_id,
|
||||
},
|
||||
should_request_summary=self.send_summary_instruction,
|
||||
send_summary_instruction=self.send_summary_instruction,
|
||||
)
|
||||
|
||||
|
||||
@@ -496,7 +476,7 @@ class GithubInlinePRComment(GithubPRComment):
|
||||
'comment_id': self.comment_id,
|
||||
},
|
||||
inline_pr_comment=True,
|
||||
should_request_summary=self.send_summary_instruction,
|
||||
send_summary_instruction=self.send_summary_instruction,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -41,20 +41,15 @@ class GitlabV1CallbackProcessor(EventCallbackProcessor):
|
||||
event: Event,
|
||||
) -> EventCallbackResult | None:
|
||||
"""Process events for GitLab V1 integration."""
|
||||
# Only handle ConversationStateUpdateEvent for execution_status
|
||||
# Only handle ConversationStateUpdateEvent
|
||||
if not isinstance(event, ConversationStateUpdateEvent):
|
||||
return None
|
||||
|
||||
if event.key != 'execution_status':
|
||||
# Only act when execution has finished
|
||||
if not (event.key == 'execution_status' and event.value == 'finished'):
|
||||
return None
|
||||
|
||||
# Log ALL terminal states for monitoring (finished, error, stuck)
|
||||
_logger.info('[GitLab V1] Callback agent state was %s', event)
|
||||
|
||||
# Only request summary when execution has finished successfully
|
||||
if event.value != 'finished':
|
||||
return None
|
||||
|
||||
_logger.info(
|
||||
'[GitLab V1] Should request summary: %s', self.should_request_summary
|
||||
)
|
||||
|
||||
@@ -3,7 +3,6 @@ from uuid import UUID, uuid4
|
||||
|
||||
from integrations.models import Message
|
||||
from integrations.resolver_context import ResolverUserContext
|
||||
from integrations.resolver_org_router import resolve_org_for_repo
|
||||
from integrations.types import ResolverViewInterface, UserData
|
||||
from integrations.utils import (
|
||||
ENABLE_V1_GITLAB_RESOLVER,
|
||||
@@ -15,7 +14,6 @@ from integrations.utils import (
|
||||
from jinja2 import Environment
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.config import get_config
|
||||
from storage.saas_conversation_store import SaasConversationStore
|
||||
from storage.saas_secrets_store import SaasSecretsStore
|
||||
|
||||
from openhands.agent_server.models import SendMessageRequest
|
||||
@@ -31,13 +29,15 @@ from openhands.integrations.gitlab.gitlab_service import GitLabServiceImpl
|
||||
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE, ProviderType
|
||||
from openhands.integrations.service_types import Comment
|
||||
from openhands.sdk import TextContent
|
||||
from openhands.server.services.conversation_service import start_conversation
|
||||
from openhands.server.services.conversation_service import (
|
||||
initialize_conversation,
|
||||
start_conversation,
|
||||
)
|
||||
from openhands.server.user_auth.user_auth import UserAuth
|
||||
from openhands.storage.data_models.conversation_metadata import (
|
||||
ConversationMetadata,
|
||||
ConversationTrigger,
|
||||
)
|
||||
from openhands.utils.conversation_summary import get_default_conversation_title
|
||||
|
||||
OH_LABEL, INLINE_OH_LABEL = get_oh_labels(HOST)
|
||||
CONFIDENTIAL_NOTE = 'confidential_note'
|
||||
@@ -118,14 +118,6 @@ class GitlabIssue(ResolverViewInterface):
|
||||
async def initialize_new_conversation(self) -> ConversationMetadata:
|
||||
# v1_enabled is already set at construction time in the factory method
|
||||
# This is the source of truth for the conversation type
|
||||
|
||||
# Resolve target org based on claimed git organizations
|
||||
self.resolved_org_id = await resolve_org_for_repo(
|
||||
provider='gitlab',
|
||||
full_repo_name=self.full_repo_name,
|
||||
keycloak_user_id=self.user_info.keycloak_user_id,
|
||||
)
|
||||
|
||||
if self.v1_enabled:
|
||||
# Create dummy conversation metadata
|
||||
# Don't save to conversation store
|
||||
@@ -136,28 +128,16 @@ class GitlabIssue(ResolverViewInterface):
|
||||
selected_repository=self.full_repo_name,
|
||||
)
|
||||
|
||||
# Create the conversation store with resolver org routing
|
||||
# (bypasses initialize_conversation to avoid threading enterprise-only
|
||||
# resolver_org_id through the generic OSS interface)
|
||||
store = await SaasConversationStore.get_resolver_instance(
|
||||
get_config(),
|
||||
self.user_info.keycloak_user_id,
|
||||
self.resolved_org_id,
|
||||
)
|
||||
|
||||
conversation_id = uuid4().hex
|
||||
conversation_metadata = ConversationMetadata(
|
||||
trigger=ConversationTrigger.RESOLVER,
|
||||
conversation_id=conversation_id,
|
||||
title=get_default_conversation_title(conversation_id),
|
||||
conversation_metadata: ConversationMetadata = await initialize_conversation( # type: ignore[assignment]
|
||||
user_id=self.user_info.keycloak_user_id,
|
||||
conversation_id=None,
|
||||
selected_repository=self.full_repo_name,
|
||||
selected_branch=self._get_branch_name(),
|
||||
conversation_trigger=ConversationTrigger.RESOLVER,
|
||||
git_provider=ProviderType.GITLAB,
|
||||
)
|
||||
await store.save_metadata(conversation_metadata)
|
||||
|
||||
self.conversation_id = conversation_id
|
||||
self.conversation_id = conversation_metadata.conversation_id
|
||||
return conversation_metadata
|
||||
|
||||
async def create_new_conversation(
|
||||
@@ -248,10 +228,7 @@ class GitlabIssue(ResolverViewInterface):
|
||||
)
|
||||
|
||||
# Set up the GitLab user context for the V1 system
|
||||
gitlab_user_context = ResolverUserContext(
|
||||
saas_user_auth=saas_user_auth,
|
||||
resolver_org_id=self.resolved_org_id,
|
||||
)
|
||||
gitlab_user_context = ResolverUserContext(saas_user_auth=saas_user_auth)
|
||||
setattr(injector_state, USER_CONTEXT_ATTR, gitlab_user_context)
|
||||
|
||||
async with get_app_conversation_service(
|
||||
@@ -283,7 +260,7 @@ class GitlabIssue(ResolverViewInterface):
|
||||
'is_mr': self.is_mr,
|
||||
'discussion_id': getattr(self, 'discussion_id', None),
|
||||
},
|
||||
should_request_summary=self.send_summary_instruction,
|
||||
send_summary_instruction=self.send_summary_instruction,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from uuid import UUID
|
||||
|
||||
from openhands.app_server.user.user_context import UserContext
|
||||
from openhands.app_server.user.user_models import UserInfo
|
||||
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE, ProviderHandler
|
||||
@@ -14,10 +12,8 @@ class ResolverUserContext(UserContext):
|
||||
def __init__(
|
||||
self,
|
||||
saas_user_auth: UserAuth,
|
||||
resolver_org_id: UUID | None = None,
|
||||
):
|
||||
self.saas_user_auth = saas_user_auth
|
||||
self.resolver_org_id = resolver_org_id
|
||||
self._provider_handler: ProviderHandler | None = None
|
||||
|
||||
async def get_user_id(self) -> str | None:
|
||||
|
||||
@@ -1,78 +0,0 @@
|
||||
"""Resolve which OpenHands organization workspace a resolver conversation should be created in.
|
||||
|
||||
This module provides a reusable utility for routing resolver conversations
|
||||
(GitHub, GitLab, Bitbucket, Slack, etc.) to the correct OpenHands organization
|
||||
workspace based on claimed Git organizations.
|
||||
"""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from storage.org_git_claim_store import OrgGitClaimStore
|
||||
from storage.org_member_store import OrgMemberStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
|
||||
async def resolve_org_for_repo(
|
||||
provider: str,
|
||||
full_repo_name: str,
|
||||
keycloak_user_id: str | None = None,
|
||||
) -> UUID | None:
|
||||
"""Determine the OpenHands org_id for a resolver conversation.
|
||||
|
||||
If the repo's git organization is claimed by an OpenHands org, returns the
|
||||
claiming org's ID. When keycloak_user_id is provided, also verifies the user
|
||||
is a member of that org.
|
||||
|
||||
Args:
|
||||
provider: Git provider name ("github", "gitlab", "bitbucket")
|
||||
full_repo_name: Full repository name (e.g., "OpenHands/foo")
|
||||
keycloak_user_id: The user's Keycloak UUID string (optional). If provided,
|
||||
membership is verified before returning the org_id.
|
||||
|
||||
Returns:
|
||||
The org_id if the repo's org is claimed (and user is a member when
|
||||
keycloak_user_id is provided), else None
|
||||
"""
|
||||
git_org = full_repo_name.split('/')[0].lower()
|
||||
|
||||
try:
|
||||
claim = await OrgGitClaimStore.get_claim_by_provider_and_git_org(
|
||||
provider, git_org
|
||||
)
|
||||
if not claim:
|
||||
logger.debug(
|
||||
f'[OrgResolver] No claim found for {provider}/{git_org}',
|
||||
)
|
||||
return None
|
||||
|
||||
# Skip membership check if no user_id provided
|
||||
if keycloak_user_id is None:
|
||||
logger.info(
|
||||
f'[OrgResolver] Resolved org {claim.org_id} '
|
||||
f'for {provider}/{git_org} (no user membership check)',
|
||||
)
|
||||
return claim.org_id
|
||||
|
||||
member = await OrgMemberStore.get_org_member(
|
||||
claim.org_id, UUID(keycloak_user_id)
|
||||
)
|
||||
if not member:
|
||||
logger.debug(
|
||||
f'[OrgResolver] User {keycloak_user_id} is not a member of org '
|
||||
f'{claim.org_id} (claimed {provider}/{git_org}). '
|
||||
f'Falling back to personal workspace.',
|
||||
)
|
||||
return None
|
||||
|
||||
logger.info(
|
||||
f'[OrgResolver] Routing conversation to org {claim.org_id} '
|
||||
f'for {provider}/{git_org} (user {keycloak_user_id})',
|
||||
)
|
||||
return claim.org_id
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f'[OrgResolver] Error resolving org for {provider}/{git_org}: {e}',
|
||||
exc_info=True,
|
||||
)
|
||||
return None
|
||||
@@ -40,20 +40,16 @@ class SlackV1CallbackProcessor(EventCallbackProcessor):
|
||||
event: Event,
|
||||
) -> EventCallbackResult | None:
|
||||
"""Process events for Slack V1 integration."""
|
||||
# Only handle ConversationStateUpdateEvent for execution_status
|
||||
# Only handle ConversationStateUpdateEvent
|
||||
if not isinstance(event, ConversationStateUpdateEvent):
|
||||
return None
|
||||
|
||||
if event.key != 'execution_status':
|
||||
# Only act when execution has finished
|
||||
if not (event.key == 'execution_status' and event.value == 'finished'):
|
||||
return None
|
||||
|
||||
# Log ALL terminal states for monitoring (finished, error, stuck)
|
||||
_logger.info('[Slack V1] Callback agent state was %s', event)
|
||||
|
||||
# Only request summary when execution has finished successfully
|
||||
if event.value != 'finished':
|
||||
return None
|
||||
|
||||
try:
|
||||
summary = await self._request_summary(conversation_id)
|
||||
await self._post_summary_to_slack(summary)
|
||||
|
||||
@@ -4,7 +4,6 @@ from uuid import UUID, uuid4
|
||||
|
||||
from integrations.models import Message
|
||||
from integrations.resolver_context import ResolverUserContext
|
||||
from integrations.resolver_org_router import resolve_org_for_repo
|
||||
from integrations.slack.slack_types import (
|
||||
SlackMessageView,
|
||||
SlackViewInterface,
|
||||
@@ -18,9 +17,7 @@ from integrations.utils import (
|
||||
get_user_v1_enabled_setting,
|
||||
)
|
||||
from jinja2 import Environment
|
||||
from server.config import get_config
|
||||
from slack_sdk import WebClient
|
||||
from storage.saas_conversation_store import SaasConversationStore
|
||||
from storage.slack_conversation import SlackConversation
|
||||
from storage.slack_conversation_store import SlackConversationStore
|
||||
from storage.slack_team_store import SlackTeamStore
|
||||
@@ -39,20 +36,18 @@ from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.schema.agent import AgentState
|
||||
from openhands.events.action import MessageAction
|
||||
from openhands.events.serialization.event import event_to_dict
|
||||
from openhands.integrations.provider import ProviderHandler
|
||||
from openhands.integrations.provider import ProviderHandler, ProviderType
|
||||
from openhands.sdk import TextContent
|
||||
from openhands.server.services.conversation_service import (
|
||||
create_new_conversation,
|
||||
setup_init_conversation_settings,
|
||||
start_conversation,
|
||||
)
|
||||
from openhands.server.shared import ConversationStoreImpl, config, conversation_manager
|
||||
from openhands.server.user_auth.user_auth import UserAuth
|
||||
from openhands.storage.data_models.conversation_metadata import (
|
||||
ConversationMetadata,
|
||||
ConversationTrigger,
|
||||
)
|
||||
from openhands.utils.async_utils import GENERAL_TIMEOUT
|
||||
from openhands.utils.conversation_summary import get_default_conversation_title
|
||||
|
||||
# =================================================
|
||||
# SECTION: Slack view types
|
||||
@@ -207,22 +202,6 @@ class SlackNewConversationView(SlackViewInterface):
|
||||
provider_tokens = await self.saas_user_auth.get_provider_tokens()
|
||||
user_secrets = await self.saas_user_auth.get_secrets()
|
||||
|
||||
# Determine git provider from repository (needed for both org routing and conversation creation)
|
||||
self._resolved_git_provider = None
|
||||
if self.selected_repo and provider_tokens:
|
||||
provider_handler = ProviderHandler(provider_tokens)
|
||||
repository = await provider_handler.verify_repo_provider(self.selected_repo)
|
||||
self._resolved_git_provider = repository.git_provider
|
||||
|
||||
# Resolve target org based on claimed git organizations
|
||||
self.resolved_org_id = None
|
||||
if self._resolved_git_provider and self.selected_repo:
|
||||
self.resolved_org_id = await resolve_org_for_repo(
|
||||
provider=self._resolved_git_provider.value,
|
||||
full_repo_name=self.selected_repo,
|
||||
keycloak_user_id=self.slack_to_openhands_user.keycloak_user_id,
|
||||
)
|
||||
|
||||
# Check if V1 conversations are enabled for this user
|
||||
self.v1_enabled = await is_v1_enabled_for_slack_resolver(
|
||||
self.slack_to_openhands_user.keycloak_user_id
|
||||
@@ -245,44 +224,30 @@ class SlackNewConversationView(SlackViewInterface):
|
||||
jinja
|
||||
)
|
||||
|
||||
user_id = self.slack_to_openhands_user.keycloak_user_id
|
||||
# Determine git provider from repository
|
||||
git_provider = None
|
||||
if self.selected_repo and provider_tokens:
|
||||
provider_handler = ProviderHandler(provider_tokens)
|
||||
repository = await provider_handler.verify_repo_provider(self.selected_repo)
|
||||
git_provider = repository.git_provider
|
||||
|
||||
# Create the conversation store with resolver org routing
|
||||
# (bypasses initialize_conversation to avoid threading enterprise-only
|
||||
# resolver_org_id through the generic OSS interface)
|
||||
store = await SaasConversationStore.get_resolver_instance(
|
||||
get_config(),
|
||||
user_id,
|
||||
self.resolved_org_id,
|
||||
)
|
||||
|
||||
conversation_id = uuid4().hex
|
||||
conversation_metadata = ConversationMetadata(
|
||||
trigger=ConversationTrigger.SLACK,
|
||||
conversation_id=conversation_id,
|
||||
title=get_default_conversation_title(conversation_id),
|
||||
user_id=user_id,
|
||||
agent_loop_info = await create_new_conversation(
|
||||
user_id=self.slack_to_openhands_user.keycloak_user_id,
|
||||
git_provider_tokens=provider_tokens,
|
||||
selected_repository=self.selected_repo,
|
||||
selected_branch=None,
|
||||
git_provider=self._resolved_git_provider,
|
||||
)
|
||||
await store.save_metadata(conversation_metadata)
|
||||
|
||||
await start_conversation(
|
||||
user_id=user_id,
|
||||
git_provider_tokens=provider_tokens,
|
||||
custom_secrets=user_secrets.custom_secrets if user_secrets else None,
|
||||
initial_user_msg=user_instructions,
|
||||
image_urls=None,
|
||||
replay_json=None,
|
||||
conversation_id=conversation_id,
|
||||
conversation_metadata=conversation_metadata,
|
||||
conversation_instructions=(
|
||||
conversation_instructions if conversation_instructions else None
|
||||
),
|
||||
image_urls=None,
|
||||
replay_json=None,
|
||||
conversation_trigger=ConversationTrigger.SLACK,
|
||||
custom_secrets=user_secrets.custom_secrets if user_secrets else None,
|
||||
git_provider=git_provider,
|
||||
)
|
||||
|
||||
self.conversation_id = conversation_id
|
||||
self.conversation_id = agent_loop_info.conversation_id
|
||||
logger.info(f'[Slack]: Created V0 conversation: {self.conversation_id}')
|
||||
await self.save_slack_convo(v1_enabled=False)
|
||||
|
||||
@@ -300,8 +265,13 @@ class SlackNewConversationView(SlackViewInterface):
|
||||
# Create the Slack V1 callback processor
|
||||
slack_callback_processor = self._create_slack_v1_callback_processor()
|
||||
|
||||
# Use git provider resolved in create_or_update_conversation
|
||||
git_provider = self._resolved_git_provider
|
||||
# Determine git provider from repository
|
||||
git_provider = None
|
||||
provider_tokens = await self.saas_user_auth.get_provider_tokens()
|
||||
if self.selected_repo and provider_tokens:
|
||||
provider_handler = ProviderHandler(provider_tokens)
|
||||
repository = await provider_handler.verify_repo_provider(self.selected_repo)
|
||||
git_provider = ProviderType(repository.git_provider.value)
|
||||
|
||||
# Get the app conversation service and start the conversation
|
||||
injector_state = InjectorState()
|
||||
@@ -322,10 +292,7 @@ class SlackNewConversationView(SlackViewInterface):
|
||||
)
|
||||
|
||||
# Set up the Slack user context for the V1 system
|
||||
slack_user_context = ResolverUserContext(
|
||||
saas_user_auth=self.saas_user_auth,
|
||||
resolver_org_id=self.resolved_org_id,
|
||||
)
|
||||
slack_user_context = ResolverUserContext(saas_user_auth=self.saas_user_auth)
|
||||
setattr(injector_state, USER_CONTEXT_ATTR, slack_user_context)
|
||||
|
||||
async with get_app_conversation_service(
|
||||
|
||||
@@ -100,25 +100,27 @@ async def has_payment_method_by_user_id(user_id: str) -> bool:
|
||||
return bool(payment_methods.data)
|
||||
|
||||
|
||||
async def migrate_customer(session, user_id: str, org: Org):
|
||||
result = await session.execute(
|
||||
select(StripeCustomer).where(StripeCustomer.keycloak_user_id == user_id)
|
||||
)
|
||||
stripe_customer = result.scalar_one_or_none()
|
||||
if stripe_customer is None:
|
||||
return
|
||||
stripe_customer.org_id = org.id
|
||||
customer = await stripe.Customer.modify_async(
|
||||
id=stripe_customer.stripe_customer_id,
|
||||
email=org.contact_email,
|
||||
metadata={'user_id': '', 'org_id': str(org.id)},
|
||||
)
|
||||
async def migrate_customer(user_id: str, org: Org):
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(StripeCustomer).where(StripeCustomer.keycloak_user_id == user_id)
|
||||
)
|
||||
stripe_customer = result.scalar_one_or_none()
|
||||
if stripe_customer is None:
|
||||
return
|
||||
stripe_customer.org_id = org.id
|
||||
customer = await stripe.Customer.modify_async(
|
||||
id=stripe_customer.stripe_customer_id,
|
||||
email=org.contact_email,
|
||||
metadata={'user_id': '', 'org_id': str(org.id)},
|
||||
)
|
||||
|
||||
logger.info(
|
||||
'migrated_customer',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org.id),
|
||||
'stripe_customer_id': customer.id,
|
||||
},
|
||||
)
|
||||
logger.info(
|
||||
'migrated_customer',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org.id),
|
||||
'stripe_customer_id': customer.id,
|
||||
},
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
@@ -8,7 +8,7 @@ logging.getLogger('alembic.runtime.plugins').setLevel(logging.WARNING)
|
||||
|
||||
from alembic import context # noqa: E402
|
||||
from google.cloud.sql.connector import Connector # noqa: E402
|
||||
from sqlalchemy import create_engine, text # noqa: E402
|
||||
from sqlalchemy import create_engine # noqa: E402
|
||||
from storage.base import Base # noqa: E402
|
||||
|
||||
target_metadata = Base.metadata
|
||||
@@ -109,10 +109,6 @@ def run_migrations_online() -> None:
|
||||
version_table_schema=target_metadata.schema,
|
||||
)
|
||||
|
||||
# Lock number must be unique — md5 hash of 'openhands_enterprise_migrations'
|
||||
# Lock is released when the connection context manager exits
|
||||
connection.execute(text('SELECT pg_advisory_lock(3617572382373537863)'))
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
@@ -1,28 +0,0 @@
|
||||
"""Add disabled_skills to user_settings.
|
||||
|
||||
Revision ID: 102
|
||||
Revises: 101
|
||||
Create Date: 2026-02-25
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '102'
|
||||
down_revision: Union[str, None] = '101'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
'user_settings', sa.Column('disabled_skills', sa.JSON(), nullable=True)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column('user_settings', 'disabled_skills')
|
||||
@@ -1,42 +0,0 @@
|
||||
"""Add mcp_config to org_member for user-specific MCP settings.
|
||||
|
||||
Revision ID: 103
|
||||
Revises: 102
|
||||
Create Date: 2026-03-26
|
||||
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '103'
|
||||
down_revision: Union[str, None] = '102'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column('org_member', sa.Column('mcp_config', sa.JSON(), nullable=True))
|
||||
|
||||
# Migrate existing org-level MCP configs to all members in each org.
|
||||
# This preserves existing configurations while transitioning to user-specific settings.
|
||||
conn = op.get_bind()
|
||||
orgs_with_config = conn.execute(
|
||||
sa.text('SELECT id, mcp_config FROM org WHERE mcp_config IS NOT NULL')
|
||||
).fetchall()
|
||||
|
||||
for org_id, mcp_config in orgs_with_config:
|
||||
conn.execute(
|
||||
sa.text(
|
||||
'UPDATE org_member SET mcp_config = :config WHERE org_id = :org_id'
|
||||
),
|
||||
{'config': json.dumps(mcp_config), 'org_id': str(org_id)},
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column('org_member', 'mcp_config')
|
||||
@@ -1,29 +0,0 @@
|
||||
"""Add disabled_skills column to user table.
|
||||
|
||||
Migration 102 added disabled_skills to the legacy user_settings table,
|
||||
but the active SaaS flow (SaasSettingsStore) reads from/writes to the
|
||||
user table. This migration adds the column where it is actually needed.
|
||||
|
||||
Revision ID: 104
|
||||
Revises: 103
|
||||
Create Date: 2026-03-31
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '104'
|
||||
down_revision: Union[str, None] = '103'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column('user', sa.Column('disabled_skills', sa.JSON(), nullable=True))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column('user', 'disabled_skills')
|
||||
@@ -1,37 +0,0 @@
|
||||
"""Create org_git_claim table for tracking Git organization claims.
|
||||
|
||||
Revision ID: 105
|
||||
Revises: 104
|
||||
Create Date: 2026-04-01
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '105'
|
||||
down_revision: Union[str, None] = '104'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
'org_git_claim',
|
||||
sa.Column('id', sa.UUID(), nullable=False),
|
||||
sa.Column('org_id', sa.UUID(), nullable=False),
|
||||
sa.Column('provider', sa.String(), nullable=False),
|
||||
sa.Column('git_organization', sa.String(), nullable=False),
|
||||
sa.Column('claimed_by', sa.UUID(), nullable=False),
|
||||
sa.Column('claimed_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.ForeignKeyConstraint(['org_id'], ['org.id'], ondelete='CASCADE'),
|
||||
sa.ForeignKeyConstraint(['claimed_by'], ['user.id']),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('provider', 'git_organization', name='uq_provider_git_org'),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table('org_git_claim')
|
||||
@@ -1,32 +0,0 @@
|
||||
"""Add tags column to conversation_metadata table.
|
||||
|
||||
Tags store key-value pairs for automation context (trigger type, automation_id),
|
||||
skills used, and other metadata. This enables querying conversations by
|
||||
automation source and associating SDK-provided context with conversations.
|
||||
|
||||
Revision ID: 106
|
||||
Revises: 105
|
||||
Create Date: 2026-03-31
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '106'
|
||||
down_revision: Union[str, None] = '105'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
'conversation_metadata',
|
||||
sa.Column('tags', sa.JSON(), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column('conversation_metadata', 'tags')
|
||||
Generated
+2286
-2552
File diff suppressed because it is too large
Load Diff
@@ -64,7 +64,6 @@ pytest-asyncio = "*"
|
||||
pytest-forked = "*"
|
||||
pytest-xdist = "*"
|
||||
flake8 = "*"
|
||||
freezegun = "^1.5.1"
|
||||
openai = "*"
|
||||
opencv-python = "*"
|
||||
pandas = "*"
|
||||
|
||||
@@ -46,7 +46,6 @@ from server.routes.org_invitations import ( # noqa: E402
|
||||
)
|
||||
from server.routes.orgs import org_router # noqa: E402
|
||||
from server.routes.readiness import readiness_router # noqa: E402
|
||||
from server.routes.service import service_router # noqa: E402
|
||||
from server.routes.user import saas_user_router # noqa: E402
|
||||
from server.routes.user_app_settings import user_app_settings_router # noqa: E402
|
||||
from server.sharing.shared_conversation_router import ( # noqa: E402
|
||||
@@ -113,7 +112,6 @@ if GITLAB_APP_CLIENT_ID:
|
||||
base_app.include_router(gitlab_integration_router)
|
||||
|
||||
base_app.include_router(api_keys_router) # Add routes for API key management
|
||||
base_app.include_router(service_router) # Add routes for internal service API
|
||||
base_app.include_router(org_router) # Add routes for organization management
|
||||
base_app.include_router(
|
||||
verified_models_router
|
||||
|
||||
@@ -35,13 +35,13 @@ Usage:
|
||||
from enum import Enum
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import Depends, HTTPException, Request, status
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from storage.org_member_store import OrgMemberStore
|
||||
from storage.role import Role
|
||||
from storage.role_store import RoleStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.server.user_auth import get_user_auth, get_user_id
|
||||
from openhands.server.user_auth import get_user_id
|
||||
|
||||
|
||||
class Permission(str, Enum):
|
||||
@@ -84,9 +84,6 @@ class Permission(str, Enum):
|
||||
# Temporary permissions until we finish the API updates.
|
||||
EDIT_ORG_SETTINGS = 'edit_org_settings'
|
||||
|
||||
# Git organization claims
|
||||
MANAGE_ORG_CLAIMS = 'manage_org_claims'
|
||||
|
||||
|
||||
class RoleName(str, Enum):
|
||||
"""Role names used in the system."""
|
||||
@@ -121,8 +118,6 @@ ROLE_PERMISSIONS: dict[RoleName, frozenset[Permission]] = {
|
||||
# Organization Management (Owner only)
|
||||
Permission.CHANGE_ORGANIZATION_NAME,
|
||||
Permission.DELETE_ORGANIZATION,
|
||||
# Git organization claims
|
||||
Permission.MANAGE_ORG_CLAIMS,
|
||||
]
|
||||
),
|
||||
RoleName.ADMIN: frozenset(
|
||||
@@ -144,8 +139,6 @@ ROLE_PERMISSIONS: dict[RoleName, frozenset[Permission]] = {
|
||||
# Organization Management
|
||||
Permission.VIEW_ORG_SETTINGS,
|
||||
Permission.EDIT_ORG_SETTINGS,
|
||||
# Git organization claims
|
||||
Permission.MANAGE_ORG_CLAIMS,
|
||||
]
|
||||
),
|
||||
RoleName.MEMBER: frozenset(
|
||||
@@ -221,19 +214,6 @@ def has_permission(user_role: Role, permission: Permission) -> bool:
|
||||
return permission in permissions
|
||||
|
||||
|
||||
async def get_api_key_org_id_from_request(request: Request) -> UUID | None:
|
||||
"""Get the org_id bound to the API key used for authentication.
|
||||
|
||||
Returns None if:
|
||||
- Not authenticated via API key (cookie auth)
|
||||
- API key is a legacy key without org binding
|
||||
"""
|
||||
user_auth = getattr(request.state, 'user_auth', None)
|
||||
if user_auth and hasattr(user_auth, 'get_api_key_org_id'):
|
||||
return user_auth.get_api_key_org_id()
|
||||
return None
|
||||
|
||||
|
||||
def require_permission(permission: Permission):
|
||||
"""
|
||||
Factory function that creates a dependency to require a specific permission.
|
||||
@@ -241,9 +221,8 @@ def require_permission(permission: Permission):
|
||||
This creates a FastAPI dependency that:
|
||||
1. Extracts org_id from the path parameter
|
||||
2. Gets the authenticated user_id
|
||||
3. Validates API key org binding (if using API key auth)
|
||||
4. Checks if the user has the required permission in the organization
|
||||
5. Returns the user_id if authorized, raises HTTPException otherwise
|
||||
3. Checks if the user has the required permission in the organization
|
||||
4. Returns the user_id if authorized, raises HTTPException otherwise
|
||||
|
||||
Usage:
|
||||
@router.get('/{org_id}/settings')
|
||||
@@ -261,7 +240,6 @@ def require_permission(permission: Permission):
|
||||
"""
|
||||
|
||||
async def permission_checker(
|
||||
request: Request,
|
||||
org_id: UUID | None = None,
|
||||
user_id: str | None = Depends(get_user_id),
|
||||
) -> str:
|
||||
@@ -271,23 +249,6 @@ def require_permission(permission: Permission):
|
||||
detail='User not authenticated',
|
||||
)
|
||||
|
||||
# Validate API key organization binding
|
||||
api_key_org_id = await get_api_key_org_id_from_request(request)
|
||||
if api_key_org_id is not None and org_id is not None:
|
||||
if api_key_org_id != org_id:
|
||||
logger.warning(
|
||||
'API key organization mismatch',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'api_key_org_id': str(api_key_org_id),
|
||||
'target_org_id': str(org_id),
|
||||
},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail='API key is not authorized for this organization',
|
||||
)
|
||||
|
||||
user_role = await get_user_org_role(user_id, org_id)
|
||||
|
||||
if not user_role:
|
||||
@@ -318,96 +279,3 @@ def require_permission(permission: Permission):
|
||||
return user_id
|
||||
|
||||
return permission_checker
|
||||
|
||||
|
||||
async def require_financial_data_access(
|
||||
request: Request,
|
||||
org_id: UUID,
|
||||
user_id: str | None = Depends(get_user_id),
|
||||
) -> str:
|
||||
"""
|
||||
Authorization dependency for accessing organization financial data.
|
||||
|
||||
Allows access if ANY of these conditions are met:
|
||||
1. User has Admin or Owner role in the organization
|
||||
2. User has @openhands.dev email domain
|
||||
|
||||
This is used for the organization members financial data endpoint.
|
||||
|
||||
Args:
|
||||
request: FastAPI request object
|
||||
org_id: Organization UUID from path parameter
|
||||
user_id: User ID from authentication
|
||||
|
||||
Returns:
|
||||
str: User ID if authorized
|
||||
|
||||
Raises:
|
||||
HTTPException: 401 if not authenticated, 403 if not authorized
|
||||
"""
|
||||
if not user_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail='User not authenticated',
|
||||
)
|
||||
|
||||
# Validate API key organization binding
|
||||
api_key_org_id = await get_api_key_org_id_from_request(request)
|
||||
if api_key_org_id is not None:
|
||||
if api_key_org_id != org_id:
|
||||
logger.warning(
|
||||
'API key organization mismatch for financial data access',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'api_key_org_id': str(api_key_org_id),
|
||||
'target_org_id': str(org_id),
|
||||
},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail='API key is not authorized for this organization',
|
||||
)
|
||||
|
||||
# Check if user has @openhands.dev email
|
||||
user_auth = await get_user_auth(request)
|
||||
user_email = await user_auth.get_user_email()
|
||||
|
||||
if user_email and user_email.endswith('@openhands.dev'):
|
||||
logger.debug(
|
||||
'Financial data access granted via @openhands.dev email',
|
||||
extra={'user_id': user_id, 'org_id': str(org_id)},
|
||||
)
|
||||
return user_id
|
||||
|
||||
# Check if user has Admin or Owner role in the organization
|
||||
user_role = await get_user_org_role(user_id, org_id)
|
||||
|
||||
if not user_role:
|
||||
logger.warning(
|
||||
'Financial data access denied - user not a member of organization',
|
||||
extra={'user_id': user_id, 'org_id': str(org_id)},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail='User is not a member of this organization',
|
||||
)
|
||||
|
||||
if user_role.name not in (RoleName.OWNER.value, RoleName.ADMIN.value):
|
||||
logger.warning(
|
||||
'Financial data access denied - insufficient role',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
'user_role': user_role.name,
|
||||
},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail='Access restricted to organization admins, owners, or OpenHands members',
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
'Financial data access granted via admin/owner role',
|
||||
extra={'user_id': user_id, 'org_id': str(org_id), 'role': user_role.name},
|
||||
)
|
||||
return user_id
|
||||
|
||||
@@ -6,6 +6,7 @@ GITHUB_APP_WEBHOOK_SECRET = os.getenv('GITHUB_APP_WEBHOOK_SECRET', '')
|
||||
GITHUB_APP_PRIVATE_KEY = os.getenv('GITHUB_APP_PRIVATE_KEY', '').replace('\\n', '\n')
|
||||
KEYCLOAK_SERVER_URL = os.getenv('KEYCLOAK_SERVER_URL', '').rstrip('/')
|
||||
KEYCLOAK_REALM_NAME = os.getenv('KEYCLOAK_REALM_NAME', '')
|
||||
KEYCLOAK_PROVIDER_NAME = os.getenv('KEYCLOAK_PROVIDER_NAME', '')
|
||||
KEYCLOAK_CLIENT_ID = os.getenv('KEYCLOAK_CLIENT_ID', '')
|
||||
KEYCLOAK_CLIENT_SECRET = os.getenv('KEYCLOAK_CLIENT_SECRET', '')
|
||||
KEYCLOAK_SERVER_URL_EXT = os.getenv(
|
||||
@@ -56,23 +57,6 @@ RECAPTCHA_SITE_KEY = os.getenv('RECAPTCHA_SITE_KEY', '').strip()
|
||||
RECAPTCHA_HMAC_SECRET = os.getenv('RECAPTCHA_HMAC_SECRET', '').strip()
|
||||
RECAPTCHA_BLOCK_THRESHOLD = float(os.getenv('RECAPTCHA_BLOCK_THRESHOLD', '0.3'))
|
||||
|
||||
# Automation Service
|
||||
AUTOMATION_SERVICE_URL = os.getenv('AUTOMATION_SERVICE_URL', '').strip()
|
||||
if AUTOMATION_SERVICE_URL and not AUTOMATION_SERVICE_URL.startswith(
|
||||
('http://', 'https://')
|
||||
):
|
||||
raise ValueError(
|
||||
f'AUTOMATION_SERVICE_URL must start with http:// or https://, '
|
||||
f'got: {AUTOMATION_SERVICE_URL}'
|
||||
)
|
||||
AUTOMATION_EVENT_FORWARDING_ENABLED = os.getenv(
|
||||
'AUTOMATION_EVENT_FORWARDING_ENABLED', 'false'
|
||||
) in ('1', 'true')
|
||||
# Shared secret for signing payloads sent to automation service (separate from GitHub webhook secret)
|
||||
AUTOMATION_WEBHOOK_SECRET = os.getenv('AUTOMATION_WEBHOOK_SECRET', '').strip()
|
||||
# Default HTTP timeout for automation service requests (seconds)
|
||||
AUTOMATION_SERVICE_TIMEOUT = int(os.getenv('AUTOMATION_SERVICE_TIMEOUT', '30'))
|
||||
|
||||
# Account Defender labels that indicate suspicious activity
|
||||
SUSPICIOUS_LABELS = {
|
||||
'SUSPICIOUS_LOGIN_ACTIVITY',
|
||||
|
||||
@@ -4,6 +4,7 @@ from server.auth.constants import (
|
||||
KEYCLOAK_ADMIN_PASSWORD,
|
||||
KEYCLOAK_CLIENT_ID,
|
||||
KEYCLOAK_CLIENT_SECRET,
|
||||
KEYCLOAK_PROVIDER_NAME,
|
||||
KEYCLOAK_REALM_NAME,
|
||||
KEYCLOAK_SERVER_URL,
|
||||
KEYCLOAK_SERVER_URL_EXT,
|
||||
@@ -11,7 +12,7 @@ from server.auth.constants import (
|
||||
from server.logger import logger
|
||||
|
||||
logger.debug(
|
||||
f'KEYCLOAK_SERVER_URL:{KEYCLOAK_SERVER_URL}, KEYCLOAK_SERVER_URL_EXT:{KEYCLOAK_SERVER_URL_EXT}, KEYCLOAK_CLIENT_ID:{KEYCLOAK_CLIENT_ID}'
|
||||
f'KEYCLOAK_SERVER_URL:{KEYCLOAK_SERVER_URL}, KEYCLOAK_SERVER_URL_EXT:{KEYCLOAK_SERVER_URL_EXT}, KEYCLOAK_PROVIDER_NAME:{KEYCLOAK_PROVIDER_NAME}, KEYCLOAK_CLIENT_ID:{KEYCLOAK_CLIENT_ID}'
|
||||
)
|
||||
|
||||
_keycloak_instances = {}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from types import MappingProxyType
|
||||
from uuid import UUID
|
||||
|
||||
import jwt
|
||||
from fastapi import Request
|
||||
@@ -60,19 +59,6 @@ class SaasUserAuth(UserAuth):
|
||||
_secrets: Secrets | None = None
|
||||
accepted_tos: bool | None = None
|
||||
auth_type: AuthType = AuthType.COOKIE
|
||||
# API key context fields - populated when authenticated via API key
|
||||
api_key_org_id: UUID | None = None # Org bound to the API key used for auth
|
||||
api_key_id: int | None = None
|
||||
api_key_name: str | None = None
|
||||
|
||||
def get_api_key_org_id(self) -> UUID | None:
|
||||
"""Get the organization ID bound to the API key used for authentication.
|
||||
|
||||
Returns:
|
||||
The org_id if authenticated via API key with org binding, None otherwise
|
||||
(cookie auth or legacy API keys without org binding).
|
||||
"""
|
||||
return self.api_key_org_id
|
||||
|
||||
async def get_user_id(self) -> str | None:
|
||||
return self.user_id
|
||||
@@ -297,19 +283,14 @@ async def saas_user_auth_from_bearer(request: Request) -> SaasUserAuth | None:
|
||||
return None
|
||||
|
||||
api_key_store = ApiKeyStore.get_instance()
|
||||
validation_result = await api_key_store.validate_api_key(api_key)
|
||||
if not validation_result:
|
||||
user_id = await api_key_store.validate_api_key(api_key)
|
||||
if not user_id:
|
||||
return None
|
||||
offline_token = await token_manager.load_offline_token(
|
||||
validation_result.user_id
|
||||
)
|
||||
offline_token = await token_manager.load_offline_token(user_id)
|
||||
saas_user_auth = SaasUserAuth(
|
||||
user_id=validation_result.user_id,
|
||||
user_id=user_id,
|
||||
refresh_token=SecretStr(offline_token),
|
||||
auth_type=AuthType.BEARER,
|
||||
api_key_org_id=validation_result.org_id,
|
||||
api_key_id=validation_result.key_id,
|
||||
api_key_name=validation_result.key_name,
|
||||
)
|
||||
await saas_user_auth.refresh()
|
||||
return saas_user_auth
|
||||
|
||||
@@ -80,11 +80,10 @@ def setup_json_logger(
|
||||
handler.setLevel(level)
|
||||
|
||||
formatter = JsonFormatter(
|
||||
'%(message)s%(levelname)s%(module)s%(funcName)s%(lineno)d',
|
||||
'{message}{levelname}',
|
||||
style='{',
|
||||
rename_fields={'levelname': 'severity'},
|
||||
json_serializer=custom_json_serializer,
|
||||
# Use 'ts' for consistency with LOG_JSON_FOR_CONSOLE mode (skip when console mode to avoid duplicates)
|
||||
timestamp='ts' if not LOG_JSON_FOR_CONSOLE else False,
|
||||
)
|
||||
|
||||
handler.setFormatter(formatter)
|
||||
|
||||
@@ -182,10 +182,6 @@ class SetAuthCookieMiddleware:
|
||||
if path.startswith('/api/v1/webhooks/'):
|
||||
return False
|
||||
|
||||
# Service API uses its own authentication (X-Service-API-Key header)
|
||||
if path.startswith('/api/service/'):
|
||||
return False
|
||||
|
||||
is_mcp = path.startswith('/mcp')
|
||||
is_api_route = path.startswith('/api')
|
||||
return is_api_route or is_mcp
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
from datetime import UTC, datetime
|
||||
from typing import cast
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from pydantic import BaseModel, field_validator
|
||||
from server.auth.saas_user_auth import SaasUserAuth
|
||||
from storage.api_key import ApiKey
|
||||
from storage.api_key_store import ApiKeyStore
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
@@ -13,8 +11,7 @@ from storage.org_service import OrgService
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.server.user_auth import get_user_auth, get_user_id
|
||||
from openhands.server.user_auth.user_auth import AuthType
|
||||
from openhands.server.user_auth import get_user_id
|
||||
|
||||
|
||||
# Helper functions for BYOR API key management
|
||||
@@ -153,16 +150,6 @@ class MessageResponse(BaseModel):
|
||||
message: str
|
||||
|
||||
|
||||
class CurrentApiKeyResponse(BaseModel):
|
||||
"""Response model for the current API key endpoint."""
|
||||
|
||||
id: int
|
||||
name: str | None
|
||||
org_id: str
|
||||
user_id: str
|
||||
auth_type: str
|
||||
|
||||
|
||||
def api_key_to_response(key: ApiKey) -> ApiKeyResponse:
|
||||
"""Convert an ApiKey model to an ApiKeyResponse."""
|
||||
return ApiKeyResponse(
|
||||
@@ -275,46 +262,6 @@ async def delete_api_key(
|
||||
)
|
||||
|
||||
|
||||
@api_router.get('/current', tags=['Keys'])
|
||||
async def get_current_api_key(
|
||||
request: Request,
|
||||
user_id: str = Depends(get_user_id),
|
||||
) -> CurrentApiKeyResponse:
|
||||
"""Get information about the currently authenticated API key.
|
||||
|
||||
This endpoint returns metadata about the API key used for the current request,
|
||||
including the org_id associated with the key. This is useful for API key
|
||||
callers who need to know which organization context their key operates in.
|
||||
|
||||
Returns 400 if not authenticated via API key (e.g., using cookie auth).
|
||||
"""
|
||||
user_auth = await get_user_auth(request)
|
||||
|
||||
# Check if authenticated via API key
|
||||
if user_auth.get_auth_type() != AuthType.BEARER:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail='This endpoint requires API key authentication. Not available for cookie-based auth.',
|
||||
)
|
||||
|
||||
# In SaaS context, bearer auth always produces SaasUserAuth
|
||||
saas_user_auth = cast(SaasUserAuth, user_auth)
|
||||
|
||||
if saas_user_auth.api_key_org_id is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail='This API key was created before organization support. Please regenerate your API key to use this endpoint.',
|
||||
)
|
||||
|
||||
return CurrentApiKeyResponse(
|
||||
id=saas_user_auth.api_key_id,
|
||||
name=saas_user_auth.api_key_name,
|
||||
org_id=str(saas_user_auth.api_key_org_id),
|
||||
user_id=user_id,
|
||||
auth_type=saas_user_auth.auth_type.value,
|
||||
)
|
||||
|
||||
|
||||
@api_router.get('/llm/byor', tags=['Keys'])
|
||||
async def get_llm_api_key_for_byor(
|
||||
user_id: str = Depends(get_user_id),
|
||||
|
||||
@@ -172,23 +172,6 @@ async def keycloak_callback(
|
||||
|
||||
authorization = await user_authorizer.authorize_user(user_info)
|
||||
if not authorization.success:
|
||||
# For duplicate_email errors, clean up the newly created Keycloak user
|
||||
# (only if they're not already in our UserStore, i.e., they're a new user)
|
||||
if authorization.error_detail == 'duplicate_email':
|
||||
try:
|
||||
existing_user = await UserStore.get_user_by_id(user_info.sub)
|
||||
if not existing_user:
|
||||
# New user created during OAuth should be deleted from Keycloak
|
||||
await token_manager.delete_keycloak_user(user_info.sub)
|
||||
logger.info(
|
||||
f'Deleted orphaned Keycloak user {user_info.sub} '
|
||||
'after duplicate_email rejection'
|
||||
)
|
||||
except Exception as e:
|
||||
# Log but don't fail - user should still get 401 response
|
||||
logger.warning(
|
||||
f'Failed to clean up orphaned Keycloak user {user_info.sub}: {e}'
|
||||
)
|
||||
# Return unauthorized
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
|
||||
@@ -7,8 +7,8 @@ from storage.database import a_session_maker
|
||||
from storage.feedback import ConversationFeedback
|
||||
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
|
||||
|
||||
from openhands.app_server.utils.dependencies import get_dependencies
|
||||
from openhands.events.event_store import EventStore
|
||||
from openhands.server.dependencies import get_dependencies
|
||||
from openhands.server.shared import file_store
|
||||
from openhands.server.user_auth import get_user_id
|
||||
|
||||
|
||||
@@ -3,17 +3,13 @@ import hashlib
|
||||
import hmac
|
||||
import os
|
||||
|
||||
from fastapi import APIRouter, BackgroundTasks, Header, HTTPException, Request
|
||||
from fastapi import APIRouter, Header, HTTPException, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from integrations.github.data_collector import GitHubDataCollector
|
||||
from integrations.github.github_manager import GithubManager
|
||||
from integrations.models import Message, SourceType
|
||||
from server.auth.constants import (
|
||||
AUTOMATION_EVENT_FORWARDING_ENABLED,
|
||||
GITHUB_APP_WEBHOOK_SECRET,
|
||||
)
|
||||
from server.auth.constants import GITHUB_APP_WEBHOOK_SECRET
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.services.automation_event_service import AutomationEventService
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
@@ -26,7 +22,6 @@ github_integration_router = APIRouter(prefix='/integration')
|
||||
token_manager = TokenManager()
|
||||
data_collector = GitHubDataCollector()
|
||||
github_manager = GithubManager(token_manager, data_collector)
|
||||
automation_event_service = AutomationEventService(token_manager)
|
||||
|
||||
|
||||
def verify_github_signature(payload: bytes, signature: str):
|
||||
@@ -51,9 +46,7 @@ def verify_github_signature(payload: bytes, signature: str):
|
||||
@github_integration_router.post('/github/events')
|
||||
async def github_events(
|
||||
request: Request,
|
||||
background_tasks: BackgroundTasks,
|
||||
x_hub_signature_256: str = Header(None),
|
||||
x_github_event: str = Header(None),
|
||||
):
|
||||
# Check if GitHub webhooks are enabled
|
||||
if not GITHUB_WEBHOOKS_ENABLED:
|
||||
@@ -79,15 +72,6 @@ async def github_events(
|
||||
content={'error': 'Installation ID is missing in the payload.'},
|
||||
)
|
||||
|
||||
# Forward to automation service (fire-and-forget background task)
|
||||
if AUTOMATION_EVENT_FORWARDING_ENABLED:
|
||||
background_tasks.add_task(
|
||||
automation_event_service.forward_github_event,
|
||||
payload=payload_data,
|
||||
installation_id=installation_id,
|
||||
)
|
||||
|
||||
# Existing resolver bot processing
|
||||
message_payload = {'payload': payload_data, 'installation': installation_id}
|
||||
message = Message(source=SourceType.GITHUB, message=message_payload)
|
||||
await github_manager.receive_message(message)
|
||||
|
||||
@@ -120,18 +120,3 @@ class BatchInvitationResponse(BaseModel):
|
||||
|
||||
successful: list[InvitationResponse]
|
||||
failed: list[InvitationFailure]
|
||||
|
||||
|
||||
class AcceptInvitationRequest(BaseModel):
|
||||
"""Request model for accepting an invitation via POST."""
|
||||
|
||||
token: str
|
||||
|
||||
|
||||
class AcceptInvitationResponse(BaseModel):
|
||||
"""Response model for successful invitation acceptance."""
|
||||
|
||||
success: bool
|
||||
org_id: str
|
||||
org_name: str
|
||||
role: str
|
||||
|
||||
@@ -5,8 +5,6 @@ from uuid import UUID
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from fastapi.responses import RedirectResponse
|
||||
from server.routes.org_invitation_models import (
|
||||
AcceptInvitationRequest,
|
||||
AcceptInvitationResponse,
|
||||
BatchInvitationResponse,
|
||||
EmailMismatchError,
|
||||
InsufficientPermissionError,
|
||||
@@ -19,11 +17,10 @@ from server.routes.org_invitation_models import (
|
||||
)
|
||||
from server.services.org_invitation_service import OrgInvitationService
|
||||
from server.utils.rate_limit_utils import check_rate_limit_by_user_id
|
||||
from storage.org_store import OrgStore
|
||||
from storage.role_store import RoleStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.server.user_auth import get_user_id
|
||||
from openhands.server.user_auth.user_auth import get_user_auth
|
||||
|
||||
# Router for invitation operations on an organization (requires org_id)
|
||||
invitation_router = APIRouter(prefix='/api/organizations/{org_id}/members')
|
||||
@@ -126,93 +123,70 @@ async def create_invitation(
|
||||
|
||||
|
||||
@accept_router.get('/accept')
|
||||
async def accept_invitation_redirect(
|
||||
async def accept_invitation(
|
||||
token: str,
|
||||
request: Request,
|
||||
):
|
||||
"""Redirect invitation acceptance to frontend.
|
||||
"""Accept an organization invitation via token.
|
||||
|
||||
This endpoint is accessed via the link in the invitation email.
|
||||
It always redirects to the home page with the invitation token,
|
||||
allowing the frontend to handle the acceptance flow via a modal.
|
||||
|
||||
This approach works with SameSite='strict' cookies because:
|
||||
- Cross-site navigation (clicking email link) doesn't send cookies
|
||||
- But same-origin POST requests (from frontend) DO send cookies
|
||||
Flow:
|
||||
1. If user is authenticated: Accept invitation directly and redirect to home
|
||||
2. If user is not authenticated: Redirect to login page with invitation token
|
||||
- Frontend stores token and includes it in OAuth state during login
|
||||
- After authentication, keycloak_callback processes the invitation
|
||||
|
||||
Args:
|
||||
token: The invitation token from the email link
|
||||
request: FastAPI request
|
||||
|
||||
Returns:
|
||||
RedirectResponse: Redirect to home page with invitation_token query param
|
||||
RedirectResponse: Redirect to home page on success, or login page if not authenticated,
|
||||
or home page with error query params on failure
|
||||
"""
|
||||
base_url = str(request.base_url).rstrip('/')
|
||||
|
||||
logger.info(
|
||||
'Invitation accept: redirecting to frontend for acceptance',
|
||||
extra={'token_prefix': token[:10] + '...'},
|
||||
)
|
||||
|
||||
return RedirectResponse(f'{base_url}/?invitation_token={token}', status_code=302)
|
||||
|
||||
|
||||
@accept_router.post('/accept', response_model=AcceptInvitationResponse)
|
||||
async def accept_invitation(
|
||||
request_data: AcceptInvitationRequest,
|
||||
user_id: str = Depends(get_user_id),
|
||||
):
|
||||
"""Accept an organization invitation via authenticated POST request.
|
||||
|
||||
This endpoint is called by the frontend after displaying the acceptance modal.
|
||||
Requires authentication - cookies are sent because this is a same-origin request.
|
||||
|
||||
Args:
|
||||
request_data: Contains the invitation token
|
||||
user_id: Authenticated user ID (from dependency)
|
||||
|
||||
Returns:
|
||||
AcceptInvitationResponse: Success response with organization details
|
||||
|
||||
Raises:
|
||||
HTTPException 400: Invalid or expired token
|
||||
HTTPException 403: Email mismatch
|
||||
HTTPException 409: User already a member
|
||||
"""
|
||||
token = request_data.token
|
||||
|
||||
# Try to get user_id from auth (may not be authenticated)
|
||||
user_id = None
|
||||
try:
|
||||
invitation = await OrgInvitationService.accept_invitation(token, UUID(user_id))
|
||||
user_auth = await get_user_auth(request)
|
||||
if user_auth:
|
||||
user_id = await user_auth.get_user_id()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Get organization and role details for response
|
||||
org = await OrgStore.get_org_by_id(invitation.org_id)
|
||||
role = await RoleStore.get_role_by_id(invitation.role_id)
|
||||
if not user_id:
|
||||
# User not authenticated - redirect to login page with invitation token
|
||||
# Frontend will store the token and include it in OAuth state during login
|
||||
logger.info(
|
||||
'Invitation accept: redirecting unauthenticated user to login',
|
||||
extra={'token_prefix': token[:10] + '...'},
|
||||
)
|
||||
login_url = f'{base_url}/login?invitation_token={token}'
|
||||
return RedirectResponse(login_url, status_code=302)
|
||||
|
||||
# User is authenticated - process the invitation directly
|
||||
try:
|
||||
await OrgInvitationService.accept_invitation(token, UUID(user_id))
|
||||
|
||||
logger.info(
|
||||
'Invitation accepted via API',
|
||||
'Invitation accepted successfully',
|
||||
extra={
|
||||
'token_prefix': token[:10] + '...',
|
||||
'user_id': user_id,
|
||||
'org_id': str(invitation.org_id),
|
||||
},
|
||||
)
|
||||
|
||||
return AcceptInvitationResponse(
|
||||
success=True,
|
||||
org_id=str(invitation.org_id),
|
||||
org_name=org.name if org else '',
|
||||
role=role.name if role else '',
|
||||
)
|
||||
# Redirect to home page on success
|
||||
return RedirectResponse(f'{base_url}/', status_code=302)
|
||||
|
||||
except InvitationExpiredError:
|
||||
logger.warning(
|
||||
'Invitation accept failed: expired',
|
||||
extra={'token_prefix': token[:10] + '...', 'user_id': user_id},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail='invitation_expired',
|
||||
)
|
||||
return RedirectResponse(f'{base_url}/?invitation_expired=true', status_code=302)
|
||||
|
||||
except InvitationInvalidError as e:
|
||||
logger.warning(
|
||||
@@ -223,20 +197,14 @@ async def accept_invitation(
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail='invitation_invalid',
|
||||
)
|
||||
return RedirectResponse(f'{base_url}/?invitation_invalid=true', status_code=302)
|
||||
|
||||
except UserAlreadyMemberError:
|
||||
logger.info(
|
||||
'Invitation accept: user already member',
|
||||
extra={'token_prefix': token[:10] + '...', 'user_id': user_id},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail='already_member',
|
||||
)
|
||||
return RedirectResponse(f'{base_url}/?already_member=true', status_code=302)
|
||||
|
||||
except EmailMismatchError as e:
|
||||
logger.warning(
|
||||
@@ -247,21 +215,15 @@ async def accept_invitation(
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail='email_mismatch',
|
||||
)
|
||||
return RedirectResponse(f'{base_url}/?email_mismatch=true', status_code=302)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'Unexpected error accepting invitation via API',
|
||||
'Unexpected error accepting invitation',
|
||||
extra={
|
||||
'token_prefix': token[:10] + '...',
|
||||
'user_id': user_id,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='An unexpected error occurred',
|
||||
)
|
||||
return RedirectResponse(f'{base_url}/?invitation_error=true', status_code=302)
|
||||
|
||||
@@ -241,6 +241,7 @@ class OrgUpdate(BaseModel):
|
||||
enable_proactive_conversation_starters: bool | None = None
|
||||
sandbox_base_container_image: str | None = None
|
||||
sandbox_runtime_container_image: str | None = None
|
||||
mcp_config: dict | None = None
|
||||
sandbox_api_key: str | None = None
|
||||
max_budget_per_task: float | None = Field(default=None, gt=0)
|
||||
enable_solvability_analysis: bool | None = None
|
||||
@@ -483,72 +484,3 @@ class OrgAppSettingsUpdate(BaseModel):
|
||||
if v is not None and v <= 0:
|
||||
raise ValueError('max_budget_per_task must be greater than 0')
|
||||
return v
|
||||
|
||||
|
||||
VALID_GIT_PROVIDERS = {'github', 'gitlab', 'bitbucket'}
|
||||
|
||||
|
||||
class GitOrgClaimRequest(BaseModel):
|
||||
"""Request model for claiming a Git organization."""
|
||||
|
||||
provider: str
|
||||
git_organization: str
|
||||
|
||||
@field_validator('provider')
|
||||
@classmethod
|
||||
def validate_provider(cls, v: str) -> str:
|
||||
v = v.lower().strip()
|
||||
if v not in VALID_GIT_PROVIDERS:
|
||||
raise ValueError(
|
||||
f'Invalid provider: "{v}". Must be one of: {", ".join(sorted(VALID_GIT_PROVIDERS))}'
|
||||
)
|
||||
return v
|
||||
|
||||
@field_validator('git_organization')
|
||||
@classmethod
|
||||
def validate_git_organization(cls, v: str) -> str:
|
||||
v = v.strip().lower()
|
||||
if not v:
|
||||
raise ValueError('git_organization must not be empty')
|
||||
return v
|
||||
|
||||
|
||||
class GitOrgClaimResponse(BaseModel):
|
||||
"""Response model for a Git organization claim."""
|
||||
|
||||
id: str
|
||||
org_id: str
|
||||
provider: str
|
||||
git_organization: str
|
||||
claimed_by: str
|
||||
claimed_at: str
|
||||
|
||||
|
||||
class GitOrgAlreadyClaimedError(Exception):
|
||||
"""Raised when a Git organization is already claimed by another OpenHands org."""
|
||||
|
||||
def __init__(self, provider: str, git_organization: str):
|
||||
self.provider = provider
|
||||
self.git_organization = git_organization
|
||||
super().__init__(
|
||||
f'Git organization "{git_organization}" on {provider} is already claimed by another organization'
|
||||
)
|
||||
|
||||
|
||||
class OrgMemberFinancialResponse(BaseModel):
|
||||
"""Financial data for a single organization member."""
|
||||
|
||||
user_id: str
|
||||
email: str | None
|
||||
lifetime_spend: float # Total amount spent (from LiteLLM)
|
||||
current_budget: float # Remaining budget (max_budget - spend)
|
||||
max_budget: float | None # Total allocated budget (None = unlimited)
|
||||
|
||||
|
||||
class OrgMemberFinancialPage(BaseModel):
|
||||
"""Paginated response for organization member financial data."""
|
||||
|
||||
items: list[OrgMemberFinancialResponse]
|
||||
current_page: int = 1
|
||||
per_page: int = 10
|
||||
next_page_id: str | None = None
|
||||
|
||||
@@ -4,15 +4,11 @@ from uuid import UUID
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from server.auth.authorization import (
|
||||
Permission,
|
||||
require_financial_data_access,
|
||||
require_permission,
|
||||
)
|
||||
from server.email_validation import get_admin_user_id
|
||||
from server.routes.org_models import (
|
||||
CannotModifySelfError,
|
||||
GitOrgAlreadyClaimedError,
|
||||
GitOrgClaimRequest,
|
||||
GitOrgClaimResponse,
|
||||
InsufficientPermissionError,
|
||||
InvalidRoleError,
|
||||
LastOwnerError,
|
||||
@@ -26,7 +22,6 @@ from server.routes.org_models import (
|
||||
OrgDatabaseError,
|
||||
OrgLLMSettingsResponse,
|
||||
OrgLLMSettingsUpdate,
|
||||
OrgMemberFinancialPage,
|
||||
OrgMemberNotFoundError,
|
||||
OrgMemberPage,
|
||||
OrgMemberResponse,
|
||||
@@ -47,10 +42,7 @@ from server.services.org_llm_settings_service import (
|
||||
OrgLLMSettingsService,
|
||||
OrgLLMSettingsServiceInjector,
|
||||
)
|
||||
from server.services.org_member_financial_service import OrgMemberFinancialService
|
||||
from server.services.org_member_service import OrgMemberService
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from storage.org_git_claim_store import OrgGitClaimStore
|
||||
from storage.org_service import OrgService
|
||||
from storage.user_store import UserStore
|
||||
|
||||
@@ -76,7 +68,7 @@ async def list_user_orgs(
|
||||
] = None,
|
||||
limit: Annotated[
|
||||
int,
|
||||
Query(title='The max number of results in the page', gt=0, le=100),
|
||||
Query(title='The max number of results in the page', gt=0, lte=100),
|
||||
] = 100,
|
||||
user_id: str = Depends(get_user_id),
|
||||
) -> OrgPage:
|
||||
@@ -742,7 +734,7 @@ async def get_org_members(
|
||||
Query(
|
||||
title='The max number of results in the page',
|
||||
gt=0,
|
||||
le=100,
|
||||
lte=100,
|
||||
),
|
||||
] = 10,
|
||||
email: Annotated[
|
||||
@@ -891,104 +883,6 @@ async def get_org_members_count(
|
||||
)
|
||||
|
||||
|
||||
@org_router.get(
|
||||
'/{org_id}/members/financial',
|
||||
response_model=OrgMemberFinancialPage,
|
||||
)
|
||||
async def get_org_members_financial(
|
||||
org_id: UUID,
|
||||
page_id: Annotated[
|
||||
str | None,
|
||||
Query(
|
||||
title='Pagination offset encoded as string',
|
||||
description='Offset for pagination (e.g., "0", "10", "20")',
|
||||
),
|
||||
] = None,
|
||||
limit: Annotated[
|
||||
int,
|
||||
Query(
|
||||
title='Maximum items per page',
|
||||
gt=0,
|
||||
le=100,
|
||||
),
|
||||
] = 10,
|
||||
email: Annotated[
|
||||
str | None,
|
||||
Query(
|
||||
title='Filter members by email (case-insensitive partial match)',
|
||||
min_length=1,
|
||||
max_length=255,
|
||||
),
|
||||
] = None,
|
||||
user_id: str = Depends(require_financial_data_access),
|
||||
) -> OrgMemberFinancialPage:
|
||||
"""Get paginated financial data for organization members.
|
||||
|
||||
Returns financial information (lifetime spend, current budget) for all members
|
||||
within the specified organization. Access is restricted to:
|
||||
- Organization Admins
|
||||
- Organization Owners
|
||||
- OpenHands members (users with @openhands.dev emails)
|
||||
|
||||
Args:
|
||||
org_id: Organization ID (UUID)
|
||||
page_id: Optional pagination offset encoded as string
|
||||
limit: Maximum items per page (1-100, default 10)
|
||||
email: Optional email filter (case-insensitive partial match)
|
||||
user_id: Authenticated user ID (injected by require_financial_data_access)
|
||||
|
||||
Returns:
|
||||
OrgMemberFinancialPage: Paginated response with member financial data
|
||||
- items: List of members with user_id, email, lifetime_spend,
|
||||
current_budget, and max_budget
|
||||
- current_page: Current page number (1-indexed)
|
||||
- per_page: Items per page
|
||||
- next_page_id: Offset for next page, or None if no more pages
|
||||
|
||||
Raises:
|
||||
HTTPException: 401 if user is not authenticated
|
||||
HTTPException: 403 if user lacks access (not admin/owner and not @openhands.dev)
|
||||
HTTPException: 400 if page_id is invalid
|
||||
HTTPException: 500 if retrieval fails
|
||||
"""
|
||||
logger.info(
|
||||
'Getting financial data for organization members',
|
||||
extra={
|
||||
'org_id': str(org_id),
|
||||
'user_id': user_id,
|
||||
'page_id': page_id,
|
||||
'limit': limit,
|
||||
'email_filter': email,
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
return await OrgMemberFinancialService.get_org_members_financial_data(
|
||||
org_id=org_id,
|
||||
page_id=page_id,
|
||||
limit=limit,
|
||||
email_filter=email,
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.warning(
|
||||
'Invalid page_id for financial data request',
|
||||
extra={'org_id': str(org_id), 'page_id': page_id, 'error': str(e)},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e),
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
'Error retrieving organization member financial data',
|
||||
extra={'org_id': str(org_id)},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to retrieve member financial data',
|
||||
)
|
||||
|
||||
|
||||
@org_router.delete('/{org_id}/members/{user_id}')
|
||||
async def remove_org_member(
|
||||
org_id: UUID,
|
||||
@@ -1217,181 +1111,3 @@ async def update_org_member(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to update member',
|
||||
)
|
||||
|
||||
|
||||
@org_router.get(
|
||||
'/{org_id}/git-claims',
|
||||
response_model=list[GitOrgClaimResponse],
|
||||
)
|
||||
async def get_git_claims(
|
||||
org_id: UUID,
|
||||
user_id: str = Depends(require_permission(Permission.MANAGE_ORG_CLAIMS)),
|
||||
) -> list[GitOrgClaimResponse]:
|
||||
"""Get all Git organization claims for an OpenHands organization.
|
||||
|
||||
Only admin and owner roles can view Git organization claims.
|
||||
|
||||
Args:
|
||||
org_id: OpenHands organization UUID
|
||||
user_id: Authenticated user ID (injected by permission check)
|
||||
|
||||
Returns:
|
||||
List of GitOrgClaimResponse with claim details
|
||||
"""
|
||||
try:
|
||||
claims = await OrgGitClaimStore.get_claims_by_org_id(org_id=org_id)
|
||||
return [
|
||||
GitOrgClaimResponse(
|
||||
id=str(claim.id),
|
||||
org_id=str(claim.org_id),
|
||||
provider=claim.provider,
|
||||
git_organization=claim.git_organization,
|
||||
claimed_by=str(claim.claimed_by),
|
||||
claimed_at=claim.claimed_at.isoformat(),
|
||||
)
|
||||
for claim in claims
|
||||
]
|
||||
except Exception:
|
||||
logger.exception('Error fetching Git organization claims')
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to fetch Git organization claims',
|
||||
)
|
||||
|
||||
|
||||
@org_router.post(
|
||||
'/{org_id}/git-claims',
|
||||
response_model=GitOrgClaimResponse,
|
||||
status_code=status.HTTP_201_CREATED,
|
||||
)
|
||||
async def claim_git_organization(
|
||||
org_id: UUID,
|
||||
request: GitOrgClaimRequest,
|
||||
user_id: str = Depends(require_permission(Permission.MANAGE_ORG_CLAIMS)),
|
||||
) -> GitOrgClaimResponse:
|
||||
"""Claim a Git organization for an OpenHands organization.
|
||||
|
||||
Only admin and owner roles can claim Git organizations.
|
||||
A Git organization can only be claimed by one OpenHands organization at a time.
|
||||
|
||||
Args:
|
||||
org_id: OpenHands organization UUID
|
||||
request: Claim request with provider and git_organization
|
||||
user_id: Authenticated user ID (injected by permission check)
|
||||
|
||||
Returns:
|
||||
GitOrgClaimResponse with the created claim details
|
||||
|
||||
Raises:
|
||||
HTTPException 409: If the Git organization is already claimed
|
||||
HTTPException 403: If user lacks permission
|
||||
"""
|
||||
try:
|
||||
# Check if this Git org is already claimed (early feedback for the common case)
|
||||
existing_claim = await OrgGitClaimStore.get_claim_by_provider_and_git_org(
|
||||
provider=request.provider,
|
||||
git_organization=request.git_organization,
|
||||
)
|
||||
|
||||
if existing_claim:
|
||||
raise GitOrgAlreadyClaimedError(
|
||||
provider=request.provider,
|
||||
git_organization=request.git_organization,
|
||||
)
|
||||
|
||||
# Create the claim — the DB unique constraint handles the race condition
|
||||
# where two concurrent requests both pass the check above.
|
||||
claim = await OrgGitClaimStore.create_claim(
|
||||
org_id=org_id,
|
||||
provider=request.provider,
|
||||
git_organization=request.git_organization,
|
||||
claimed_by=UUID(user_id),
|
||||
)
|
||||
|
||||
return GitOrgClaimResponse(
|
||||
id=str(claim.id),
|
||||
org_id=str(claim.org_id),
|
||||
provider=claim.provider,
|
||||
git_organization=claim.git_organization,
|
||||
claimed_by=str(claim.claimed_by),
|
||||
claimed_at=claim.claimed_at.isoformat(),
|
||||
)
|
||||
|
||||
except GitOrgAlreadyClaimedError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail=str(e),
|
||||
)
|
||||
except IntegrityError as e:
|
||||
# Only treat the unique constraint violation as a duplicate claim.
|
||||
# Other integrity errors (e.g. FK violations) should surface as 500s.
|
||||
if 'uq_provider_git_org' in str(e.orig):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail=str(
|
||||
GitOrgAlreadyClaimedError(
|
||||
provider=request.provider,
|
||||
git_organization=request.git_organization,
|
||||
)
|
||||
),
|
||||
)
|
||||
logger.exception('Integrity error claiming Git organization')
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to claim Git organization',
|
||||
)
|
||||
except Exception:
|
||||
logger.exception('Error claiming Git organization')
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to claim Git organization',
|
||||
)
|
||||
|
||||
|
||||
@org_router.delete(
|
||||
'/{org_id}/git-claims/{claim_id}',
|
||||
status_code=status.HTTP_200_OK,
|
||||
)
|
||||
async def disconnect_git_organization(
|
||||
org_id: UUID,
|
||||
claim_id: UUID,
|
||||
user_id: str = Depends(require_permission(Permission.MANAGE_ORG_CLAIMS)),
|
||||
) -> dict:
|
||||
"""Remove a Git organization claim from an OpenHands organization.
|
||||
|
||||
Only admin and owner roles can disconnect Git organization claims.
|
||||
|
||||
Args:
|
||||
org_id: OpenHands organization UUID
|
||||
claim_id: Claim UUID to remove
|
||||
user_id: Authenticated user ID (injected by permission check)
|
||||
|
||||
Returns:
|
||||
dict: Confirmation message on successful deletion
|
||||
|
||||
Raises:
|
||||
HTTPException 404: If the claim is not found for this organization
|
||||
HTTPException 403: If user lacks permission
|
||||
"""
|
||||
try:
|
||||
deleted = await OrgGitClaimStore.delete_claim(
|
||||
claim_id=claim_id,
|
||||
org_id=org_id,
|
||||
)
|
||||
|
||||
if not deleted:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail='Git organization claim not found',
|
||||
)
|
||||
|
||||
return {'message': 'Git organization claim removed successfully'}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception('Error disconnecting Git organization')
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to disconnect Git organization',
|
||||
)
|
||||
|
||||
@@ -1,270 +0,0 @@
|
||||
"""
|
||||
Service API routes for internal service-to-service communication.
|
||||
|
||||
This module provides endpoints for trusted internal services (e.g., automations service)
|
||||
to perform privileged operations like creating API keys on behalf of users.
|
||||
|
||||
Authentication is via a shared secret (X-Service-API-Key header) configured
|
||||
through the AUTOMATIONS_SERVICE_KEY environment variable.
|
||||
"""
|
||||
|
||||
import os
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Header, HTTPException, status
|
||||
from pydantic import BaseModel, field_validator
|
||||
from storage.api_key_store import ApiKeyStore
|
||||
from storage.org_member_store import OrgMemberStore
|
||||
from storage.user_store import UserStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
# Environment variable for the service API key
|
||||
AUTOMATIONS_SERVICE_KEY = os.getenv('AUTOMATIONS_SERVICE_KEY', '').strip()
|
||||
|
||||
service_router = APIRouter(prefix='/api/service', tags=['Service'])
|
||||
|
||||
|
||||
class CreateUserApiKeyRequest(BaseModel):
|
||||
"""Request model for creating an API key on behalf of a user."""
|
||||
|
||||
name: str # Required - used to identify the key
|
||||
|
||||
@field_validator('name')
|
||||
@classmethod
|
||||
def validate_name(cls, v: str) -> str:
|
||||
if not v or not v.strip():
|
||||
raise ValueError('name is required and cannot be empty')
|
||||
return v.strip()
|
||||
|
||||
|
||||
class CreateUserApiKeyResponse(BaseModel):
|
||||
"""Response model for created API key."""
|
||||
|
||||
key: str
|
||||
user_id: str
|
||||
org_id: str
|
||||
name: str
|
||||
|
||||
|
||||
class ServiceInfoResponse(BaseModel):
|
||||
"""Response model for service info endpoint."""
|
||||
|
||||
service: str
|
||||
authenticated: bool
|
||||
|
||||
|
||||
async def validate_service_api_key(
|
||||
x_service_api_key: str | None = Header(default=None, alias='X-Service-API-Key'),
|
||||
) -> str:
|
||||
"""
|
||||
Validate the service API key from the request header.
|
||||
|
||||
Args:
|
||||
x_service_api_key: The service API key from the X-Service-API-Key header
|
||||
|
||||
Returns:
|
||||
str: Service identifier for audit logging
|
||||
|
||||
Raises:
|
||||
HTTPException: 401 if key is missing or invalid
|
||||
HTTPException: 503 if service auth is not configured
|
||||
"""
|
||||
if not AUTOMATIONS_SERVICE_KEY:
|
||||
logger.warning(
|
||||
'Service authentication not configured (AUTOMATIONS_SERVICE_KEY not set)'
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail='Service authentication not configured',
|
||||
)
|
||||
|
||||
if not x_service_api_key:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail='X-Service-API-Key header is required',
|
||||
)
|
||||
|
||||
if x_service_api_key != AUTOMATIONS_SERVICE_KEY:
|
||||
logger.warning('Invalid service API key attempted')
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail='Invalid service API key',
|
||||
)
|
||||
|
||||
return 'automations-service'
|
||||
|
||||
|
||||
@service_router.get('/health')
|
||||
async def service_health() -> dict:
|
||||
"""Health check endpoint for the service API.
|
||||
|
||||
This endpoint does not require authentication and can be used
|
||||
to verify the service routes are accessible.
|
||||
"""
|
||||
return {
|
||||
'status': 'ok',
|
||||
'service_auth_configured': bool(AUTOMATIONS_SERVICE_KEY),
|
||||
}
|
||||
|
||||
|
||||
@service_router.post('/users/{user_id}/orgs/{org_id}/api-keys')
|
||||
async def get_or_create_api_key_for_user(
|
||||
user_id: str,
|
||||
org_id: UUID,
|
||||
request: CreateUserApiKeyRequest,
|
||||
x_service_api_key: str | None = Header(default=None, alias='X-Service-API-Key'),
|
||||
) -> CreateUserApiKeyResponse:
|
||||
"""
|
||||
Get or create an API key for a user on behalf of the automations service.
|
||||
|
||||
If a key with the given name already exists for the user/org and is not expired,
|
||||
returns the existing key. Otherwise, creates a new key.
|
||||
|
||||
The created/returned keys are system keys and are:
|
||||
- Not visible to the user in their API keys list
|
||||
- Not deletable by the user
|
||||
- Never expire
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
org_id: The organization ID
|
||||
request: Request body containing name (required)
|
||||
x_service_api_key: Service API key header for authentication
|
||||
|
||||
Returns:
|
||||
CreateUserApiKeyResponse: The API key and metadata
|
||||
|
||||
Raises:
|
||||
HTTPException: 401 if service key is invalid
|
||||
HTTPException: 404 if user not found
|
||||
HTTPException: 403 if user is not a member of the specified org
|
||||
"""
|
||||
# Validate service API key
|
||||
service_id = await validate_service_api_key(x_service_api_key)
|
||||
|
||||
# Verify user exists
|
||||
user = await UserStore.get_user_by_id(user_id)
|
||||
if not user:
|
||||
logger.warning(
|
||||
'Service attempted to create key for non-existent user',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f'User {user_id} not found',
|
||||
)
|
||||
|
||||
# Verify user is a member of the specified org
|
||||
org_member = await OrgMemberStore.get_org_member(org_id, UUID(user_id))
|
||||
if not org_member:
|
||||
logger.warning(
|
||||
'Service attempted to create key for user not in org',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f'User {user_id} is not a member of org {org_id}',
|
||||
)
|
||||
|
||||
# Get or create the system API key
|
||||
api_key_store = ApiKeyStore.get_instance()
|
||||
|
||||
try:
|
||||
api_key = await api_key_store.get_or_create_system_api_key(
|
||||
user_id=user_id,
|
||||
org_id=org_id,
|
||||
name=request.name,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'Failed to get or create system API key',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail='Failed to get or create API key',
|
||||
)
|
||||
|
||||
logger.info(
|
||||
'Service created API key for user',
|
||||
extra={
|
||||
'service_id': service_id,
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
'key_name': request.name,
|
||||
},
|
||||
)
|
||||
|
||||
return CreateUserApiKeyResponse(
|
||||
key=api_key,
|
||||
user_id=user_id,
|
||||
org_id=str(org_id),
|
||||
name=request.name,
|
||||
)
|
||||
|
||||
|
||||
@service_router.delete('/users/{user_id}/orgs/{org_id}/api-keys/{key_name}')
|
||||
async def delete_user_api_key(
|
||||
user_id: str,
|
||||
org_id: UUID,
|
||||
key_name: str,
|
||||
x_service_api_key: str | None = Header(default=None, alias='X-Service-API-Key'),
|
||||
) -> dict:
|
||||
"""
|
||||
Delete a system API key created by the service.
|
||||
|
||||
This endpoint allows the automations service to clean up API keys
|
||||
it previously created for users.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
org_id: The organization ID
|
||||
key_name: The name of the key to delete (without __SYSTEM__: prefix)
|
||||
x_service_api_key: Service API key header for authentication
|
||||
|
||||
Returns:
|
||||
dict: Success message
|
||||
|
||||
Raises:
|
||||
HTTPException: 401 if service key is invalid
|
||||
HTTPException: 404 if key not found
|
||||
"""
|
||||
# Validate service API key
|
||||
service_id = await validate_service_api_key(x_service_api_key)
|
||||
|
||||
api_key_store = ApiKeyStore.get_instance()
|
||||
|
||||
# Delete the key by name (wrap with system key prefix since service creates system keys)
|
||||
system_key_name = api_key_store.make_system_key_name(key_name)
|
||||
success = await api_key_store.delete_api_key_by_name(
|
||||
user_id=user_id,
|
||||
org_id=org_id,
|
||||
name=system_key_name,
|
||||
allow_system=True,
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f'API key with name "{key_name}" not found for user {user_id} in org {org_id}',
|
||||
)
|
||||
|
||||
logger.info(
|
||||
'Service deleted API key for user',
|
||||
extra={
|
||||
'service_id': service_id,
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
'key_name': key_name,
|
||||
},
|
||||
)
|
||||
|
||||
return {'message': 'API key deleted successfully'}
|
||||
@@ -7,10 +7,8 @@ from server.auth.token_manager import TokenManager
|
||||
from storage.user_store import UserStore
|
||||
from utils.identity import resolve_display_name
|
||||
|
||||
from openhands.app_server.utils.dependencies import get_dependencies
|
||||
from openhands.integrations.provider import (
|
||||
PROVIDER_TOKEN_TYPE,
|
||||
ProviderHandler,
|
||||
)
|
||||
from openhands.integrations.service_types import (
|
||||
Branch,
|
||||
@@ -24,6 +22,7 @@ from openhands.microagent.types import (
|
||||
MicroagentContentResponse,
|
||||
MicroagentResponse,
|
||||
)
|
||||
from openhands.server.dependencies import get_dependencies
|
||||
from openhands.server.routes.git import (
|
||||
get_repository_branches,
|
||||
get_repository_microagent_content,
|
||||
@@ -68,53 +67,6 @@ async def saas_get_user_installations(
|
||||
)
|
||||
|
||||
|
||||
@saas_user_router.get('/git-organizations')
|
||||
async def saas_get_user_git_organizations(
|
||||
provider_tokens: PROVIDER_TOKEN_TYPE | None = Depends(get_provider_tokens),
|
||||
access_token: SecretStr | None = Depends(get_access_token),
|
||||
user_id: str | None = Depends(get_user_id),
|
||||
):
|
||||
if not provider_tokens:
|
||||
retval = await _check_idp(
|
||||
access_token=access_token,
|
||||
default_value={},
|
||||
)
|
||||
if retval is not None:
|
||||
return retval
|
||||
# _check_idp returned None (tokens refreshed on Keycloak side),
|
||||
# but provider_tokens is still None for this request.
|
||||
return JSONResponse(
|
||||
content='Git provider token required.',
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
)
|
||||
|
||||
client = ProviderHandler(
|
||||
provider_tokens=provider_tokens,
|
||||
external_auth_token=access_token,
|
||||
external_auth_id=user_id,
|
||||
)
|
||||
|
||||
# SaaS users sign in with one provider at a time
|
||||
provider = next(iter(provider_tokens))
|
||||
|
||||
if provider == ProviderType.GITHUB:
|
||||
orgs = await client.get_github_organizations()
|
||||
elif provider == ProviderType.GITLAB:
|
||||
orgs = await client.get_gitlab_groups()
|
||||
elif provider == ProviderType.BITBUCKET:
|
||||
orgs = await client.get_bitbucket_workspaces()
|
||||
else:
|
||||
return JSONResponse(
|
||||
content=f"Provider {provider.value} doesn't support git organizations",
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
return {
|
||||
'provider': provider.value,
|
||||
'organizations': orgs,
|
||||
}
|
||||
|
||||
|
||||
@saas_user_router.get('/repositories', response_model=list[Repository])
|
||||
async def saas_get_user_repositories(
|
||||
sort: str = 'pushed',
|
||||
|
||||
@@ -1,445 +0,0 @@
|
||||
"""
|
||||
Service for forwarding GitHub webhook events to the automation service.
|
||||
|
||||
This service is optimized for high-traffic scenarios:
|
||||
1. Resolves GitHub org → OpenHands org_id (via cached OrgGitClaim lookup)
|
||||
2. For personal repos, resolves to personal org (via cached GitHub→Keycloak mapping)
|
||||
3. Forwards minimal payload to automation service (just org_id + payload)
|
||||
4. Access control checks are deferred to automation execution time
|
||||
|
||||
The lazy access control approach means:
|
||||
- Most webhooks only do cached lookups + HTTP forward
|
||||
- Membership checks only happen when an automation actually matches
|
||||
|
||||
Security notes:
|
||||
- Uses AUTOMATION_WEBHOOK_SECRET (not GitHub webhook secret) for internal service signing
|
||||
- Negative results are cached to prevent DoS via repeated lookups for unclaimed orgs
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
import aiohttp
|
||||
from integrations.resolver_org_router import resolve_org_for_repo
|
||||
from server.auth.constants import (
|
||||
AUTOMATION_SERVICE_TIMEOUT,
|
||||
AUTOMATION_SERVICE_URL,
|
||||
AUTOMATION_WEBHOOK_SECRET,
|
||||
)
|
||||
from server.auth.token_manager import TokenManager
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.provider import ProviderType
|
||||
from openhands.server.shared import sio
|
||||
|
||||
# Cache TTL constants
|
||||
ORG_CLAIM_CACHE_TTL_SECONDS = 3600 # 1 hour for org claims (rarely change)
|
||||
USER_ID_CACHE_TTL_SECONDS = 86400 # 24 hours for user ID mappings (never change)
|
||||
|
||||
# Cache key prefixes
|
||||
ORG_CLAIM_CACHE_PREFIX = 'automation:org_claim'
|
||||
USER_ID_CACHE_PREFIX = 'automation:gh_to_kc_user'
|
||||
|
||||
|
||||
@dataclass
|
||||
class OrgContext:
|
||||
"""Context for the resolved organization."""
|
||||
|
||||
org_id: UUID
|
||||
github_org: str
|
||||
|
||||
|
||||
class AutomationEventService:
|
||||
"""
|
||||
Service for forwarding webhook events to the automation service.
|
||||
|
||||
Optimized for high traffic with:
|
||||
- Redis caching for org claim lookups (1 hour TTL)
|
||||
- Redis caching for GitHub→Keycloak user ID mappings (24 hour TTL)
|
||||
- Lazy access control (membership checks deferred to execution time)
|
||||
"""
|
||||
|
||||
def __init__(self, token_manager: TokenManager):
|
||||
from server.auth.constants import AUTOMATION_EVENT_FORWARDING_ENABLED
|
||||
|
||||
self.token_manager = token_manager
|
||||
|
||||
# Fail fast if forwarding is enabled but misconfigured
|
||||
if AUTOMATION_EVENT_FORWARDING_ENABLED:
|
||||
if not AUTOMATION_SERVICE_URL:
|
||||
raise ValueError(
|
||||
'AUTOMATION_EVENT_FORWARDING_ENABLED=true but '
|
||||
'AUTOMATION_SERVICE_URL is not configured'
|
||||
)
|
||||
if not AUTOMATION_WEBHOOK_SECRET:
|
||||
raise ValueError(
|
||||
'AUTOMATION_EVENT_FORWARDING_ENABLED=true but '
|
||||
'AUTOMATION_WEBHOOK_SECRET is not configured'
|
||||
)
|
||||
|
||||
async def forward_github_event(
|
||||
self,
|
||||
payload: dict[str, Any],
|
||||
installation_id: int,
|
||||
) -> None:
|
||||
"""
|
||||
Forward a GitHub webhook event to the automation service.
|
||||
|
||||
This is designed to be called as a fire-and-forget background task.
|
||||
The forward path is optimized for speed - only org resolution is done here.
|
||||
Access control checks are deferred to automation execution time.
|
||||
|
||||
Args:
|
||||
payload: The raw GitHub webhook payload
|
||||
installation_id: The GitHub App installation ID
|
||||
"""
|
||||
org_id: UUID | None = None
|
||||
try:
|
||||
# Resolve org context (org_id and github_org name) - uses Redis cache
|
||||
org_context = await self._resolve_org_context(payload)
|
||||
if not org_context:
|
||||
return
|
||||
|
||||
org_id = org_context.org_id
|
||||
|
||||
# Build minimal payload and forward immediately
|
||||
# Access control is NOT computed here - it's deferred to execution time
|
||||
event_payload = self._build_event_payload(org_context, payload)
|
||||
await self._send_to_automation_service(org_id, event_payload)
|
||||
|
||||
except (aiohttp.ClientError, asyncio.TimeoutError) as e:
|
||||
# Network errors are expected and recoverable
|
||||
logger.error(
|
||||
f'[AutomationEventService] Network error forwarding event '
|
||||
f'(org_id={org_id}): {e}',
|
||||
exc_info=True,
|
||||
extra={'installation_id': installation_id},
|
||||
)
|
||||
except Exception as e:
|
||||
# Log unexpected errors. Note: This is a background task, so exceptions
|
||||
# won't surface to the HTTP caller - they're logged for debugging only.
|
||||
logger.error(
|
||||
f'[AutomationEventService] Unexpected error forwarding event '
|
||||
f'(org_id={org_id}): {e}',
|
||||
exc_info=True,
|
||||
extra={'installation_id': installation_id},
|
||||
)
|
||||
# Don't re-raise in background task - just log for debugging
|
||||
|
||||
async def _resolve_org_context(self, payload: dict[str, Any]) -> OrgContext | None:
|
||||
"""
|
||||
Resolve the organization context from the webhook payload.
|
||||
|
||||
Uses Redis caching for both org claims and user ID mappings.
|
||||
Returns None if the org cannot be resolved (not claimed, no personal org).
|
||||
"""
|
||||
repo = payload.get('repository', {})
|
||||
owner = repo.get('owner', {})
|
||||
git_org_name = owner.get('login')
|
||||
owner_type = owner.get('type') # 'User' or 'Organization'
|
||||
|
||||
if not git_org_name:
|
||||
logger.warning(
|
||||
'[AutomationEventService] No repository owner in payload, skipping'
|
||||
)
|
||||
return None
|
||||
|
||||
# Try to resolve via OrgGitClaim
|
||||
org_id = await self._resolve_github_org(git_org_name)
|
||||
|
||||
# Fallback for personal repos
|
||||
if not org_id and owner_type == 'User':
|
||||
org_id = await self._resolve_personal_org(owner.get('id'))
|
||||
if org_id:
|
||||
logger.info(
|
||||
f'[AutomationEventService] Resolved personal repo owner '
|
||||
f'{git_org_name} to personal org {org_id}'
|
||||
)
|
||||
|
||||
if not org_id:
|
||||
logger.warning(
|
||||
f'[AutomationEventService] GitHub org {git_org_name} not claimed '
|
||||
f'and no personal org found, skipping'
|
||||
)
|
||||
return None
|
||||
|
||||
return OrgContext(org_id=org_id, github_org=git_org_name)
|
||||
|
||||
def _build_event_payload(
|
||||
self,
|
||||
org_context: OrgContext,
|
||||
payload: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Build the minimal event payload to forward to the automation service.
|
||||
|
||||
Access control is NOT included here - it's deferred to execution time.
|
||||
This keeps the forward path fast for high-traffic scenarios.
|
||||
"""
|
||||
return {
|
||||
'organization': {
|
||||
'github_org': org_context.github_org,
|
||||
'openhands_org_id': str(org_context.org_id),
|
||||
},
|
||||
'payload': payload,
|
||||
}
|
||||
|
||||
# =========================================================================
|
||||
# Cached Org Resolution Methods
|
||||
# =========================================================================
|
||||
|
||||
async def _resolve_github_org(self, git_org_name: str) -> UUID | None:
|
||||
"""
|
||||
Resolve a GitHub organization name to an OpenHands org_id.
|
||||
|
||||
Uses Redis caching with 1-hour TTL. Caches both positive and negative
|
||||
results to avoid repeated DB queries for unclaimed orgs.
|
||||
|
||||
Note: GitHub org names are case-insensitive. We normalize to lowercase
|
||||
for both cache keys and DB queries. This matches the OrgGitClaim schema
|
||||
which stores git_organization as lowercase (enforced by GitOrgClaimRequest
|
||||
validator in org_models.py).
|
||||
"""
|
||||
normalized_org = git_org_name.lower()
|
||||
cache_key = f'{ORG_CLAIM_CACHE_PREFIX}:{normalized_org}'
|
||||
|
||||
# Check cache first
|
||||
cached = await self._get_cached_value(cache_key)
|
||||
if cached is not None:
|
||||
if cached == 'none':
|
||||
logger.debug(
|
||||
f'[AutomationEventService] Cache hit (negative): org {git_org_name} not claimed'
|
||||
)
|
||||
return None
|
||||
logger.debug(
|
||||
f'[AutomationEventService] Cache hit: org {git_org_name} -> {cached}'
|
||||
)
|
||||
return UUID(cached)
|
||||
|
||||
# Cache miss - use resolve_org_for_repo without user_id (no membership check)
|
||||
# Construct a minimal repo name since resolve_org_for_repo extracts the org
|
||||
org_id = await resolve_org_for_repo(
|
||||
provider='github',
|
||||
full_repo_name=f'{normalized_org}/',
|
||||
)
|
||||
|
||||
# Cache the result (including negative results)
|
||||
if org_id:
|
||||
await self._set_cached_value(
|
||||
cache_key, str(org_id), ORG_CLAIM_CACHE_TTL_SECONDS
|
||||
)
|
||||
return org_id
|
||||
else:
|
||||
# Cache negative result to avoid repeated DB queries
|
||||
await self._set_cached_value(cache_key, 'none', ORG_CLAIM_CACHE_TTL_SECONDS)
|
||||
return None
|
||||
|
||||
async def _resolve_personal_org(self, github_user_id: int | None) -> UUID | None:
|
||||
"""
|
||||
Resolve a GitHub user to their personal OpenHands org.
|
||||
|
||||
For personal repos (owner type is 'User'), the OpenHands org_id
|
||||
is the user's keycloak user ID. This allows users to set up
|
||||
automations on their personal repos without needing an OrgGitClaim.
|
||||
|
||||
Uses Redis caching for the GitHub→Keycloak user ID mapping (24h TTL).
|
||||
"""
|
||||
if not github_user_id:
|
||||
return None
|
||||
|
||||
keycloak_id = await self._get_keycloak_user_id_cached(github_user_id)
|
||||
if keycloak_id:
|
||||
return UUID(keycloak_id)
|
||||
return None
|
||||
|
||||
async def _get_keycloak_user_id_cached(self, github_user_id: int) -> str | None:
|
||||
"""
|
||||
Convert a GitHub user ID to a Keycloak user ID.
|
||||
|
||||
Uses Redis caching with 24-hour TTL since this mapping never changes.
|
||||
Caches negative results to avoid repeated Keycloak queries.
|
||||
"""
|
||||
cache_key = f'{USER_ID_CACHE_PREFIX}:{github_user_id}'
|
||||
|
||||
# Check cache first
|
||||
cached = await self._get_cached_value(cache_key)
|
||||
if cached is not None:
|
||||
if cached == 'none':
|
||||
logger.debug(
|
||||
f'[AutomationEventService] Cache hit (negative): GitHub user {github_user_id} not in Keycloak'
|
||||
)
|
||||
return None
|
||||
logger.debug(
|
||||
f'[AutomationEventService] Cache hit: GitHub user {github_user_id} -> Keycloak {cached}'
|
||||
)
|
||||
return cached
|
||||
|
||||
# Cache miss - query Keycloak
|
||||
try:
|
||||
keycloak_id = await self.token_manager.get_user_id_from_idp_user_id(
|
||||
str(github_user_id), ProviderType.GITHUB
|
||||
)
|
||||
|
||||
# Cache the result (including negative results)
|
||||
if keycloak_id:
|
||||
await self._set_cached_value(
|
||||
cache_key, keycloak_id, USER_ID_CACHE_TTL_SECONDS
|
||||
)
|
||||
else:
|
||||
# Cache negative result to prevent repeated Keycloak queries (DoS mitigation)
|
||||
await self._set_cached_value(
|
||||
cache_key, 'none', USER_ID_CACHE_TTL_SECONDS
|
||||
)
|
||||
|
||||
return keycloak_id
|
||||
except Exception as e:
|
||||
# Log at warning level to surface programmer errors and API issues
|
||||
logger.warning(
|
||||
f'[AutomationEventService] Failed to get keycloak ID for GitHub user {github_user_id}: {e}'
|
||||
)
|
||||
return None
|
||||
|
||||
# =========================================================================
|
||||
# Generic Redis Cache Helpers
|
||||
# =========================================================================
|
||||
|
||||
async def _get_cached_value(self, cache_key: str) -> str | None:
|
||||
"""
|
||||
Get a cached value from Redis.
|
||||
|
||||
Returns the cached string value, or None if not cached or Redis unavailable.
|
||||
Falls back to DB/API queries if Redis is unavailable (graceful degradation).
|
||||
|
||||
Warning: When Redis is unavailable, every webhook will hit the DB directly.
|
||||
Monitor logs for 'Redis unavailable' warnings to detect degradation.
|
||||
"""
|
||||
try:
|
||||
redis = getattr(sio.manager, 'redis', None)
|
||||
if not redis:
|
||||
# Log at warning level - this is a significant degradation that
|
||||
# will cause DB load. Monitor these logs for alerting.
|
||||
logger.warning(
|
||||
'[AutomationEventService] Redis unavailable for cache read, '
|
||||
'falling back to direct DB queries (this will increase DB load)'
|
||||
)
|
||||
return None
|
||||
|
||||
cached = await redis.get(cache_key)
|
||||
if cached is None:
|
||||
return None
|
||||
|
||||
# Redis returns bytes, decode to string
|
||||
return cached.decode('utf-8') if isinstance(cached, bytes) else cached
|
||||
except Exception as e:
|
||||
# Log at warning level - cache errors cause DB fallback
|
||||
logger.warning(
|
||||
f'[AutomationEventService] Redis cache read error (falling back to DB): {e}'
|
||||
)
|
||||
return None
|
||||
|
||||
async def _set_cached_value(
|
||||
self, cache_key: str, value: str, ttl_seconds: int
|
||||
) -> None:
|
||||
"""
|
||||
Set a cached value in Redis with TTL.
|
||||
|
||||
Fails silently if Redis is unavailable (graceful degradation).
|
||||
"""
|
||||
try:
|
||||
redis = getattr(sio.manager, 'redis', None)
|
||||
if not redis:
|
||||
# Silent failure - read path already logs the warning
|
||||
return
|
||||
|
||||
await redis.setex(cache_key, ttl_seconds, value)
|
||||
except Exception as e:
|
||||
# Log at warning level for visibility
|
||||
logger.warning(f'[AutomationEventService] Redis cache write error: {e}')
|
||||
|
||||
def _sign_payload(self, payload_bytes: bytes) -> str:
|
||||
"""
|
||||
Sign a payload using the dedicated automation shared secret.
|
||||
|
||||
Uses AUTOMATION_WEBHOOK_SECRET (not GitHub webhook secret) to maintain
|
||||
separate trust boundaries between GitHub webhooks and internal services.
|
||||
|
||||
Returns the signature in the format 'sha256=<hex_digest>'.
|
||||
"""
|
||||
signature = hmac.new(
|
||||
AUTOMATION_WEBHOOK_SECRET.encode('utf-8'),
|
||||
msg=payload_bytes,
|
||||
digestmod=hashlib.sha256,
|
||||
).hexdigest()
|
||||
return f'sha256={signature}'
|
||||
|
||||
async def _send_to_automation_service(
|
||||
self,
|
||||
org_id: UUID,
|
||||
payload: dict[str, Any],
|
||||
) -> None:
|
||||
"""
|
||||
Send the normalized payload to the automation service.
|
||||
|
||||
The payload is signed using AUTOMATION_WEBHOOK_SECRET so the
|
||||
automation service can verify it came from the OpenHands server.
|
||||
"""
|
||||
if not AUTOMATION_SERVICE_URL:
|
||||
logger.warning(
|
||||
'[AutomationEventService] AUTOMATION_SERVICE_URL not configured'
|
||||
)
|
||||
return
|
||||
|
||||
# Build endpoint URL. AUTOMATION_SERVICE_URL may include path segments
|
||||
# (e.g., https://example.com/api/automation), so we strip trailing slash
|
||||
# and append our path.
|
||||
url = f'{AUTOMATION_SERVICE_URL.rstrip("/")}/v1/events/{org_id}/github'
|
||||
|
||||
# Serialize payload to JSON bytes for signing
|
||||
payload_bytes = json.dumps(payload, separators=(',', ':')).encode('utf-8')
|
||||
signature = self._sign_payload(payload_bytes)
|
||||
|
||||
headers = {
|
||||
'Content-Type': 'application/json',
|
||||
'X-Hub-Signature-256': signature,
|
||||
}
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
url,
|
||||
data=payload_bytes,
|
||||
headers=headers,
|
||||
timeout=aiohttp.ClientTimeout(total=AUTOMATION_SERVICE_TIMEOUT),
|
||||
) as resp:
|
||||
if resp.status >= 400:
|
||||
# Try JSON first (expected interface), fall back to text
|
||||
# for infrastructure errors (502/503 from load balancer)
|
||||
try:
|
||||
body = await resp.json()
|
||||
except (aiohttp.ContentTypeError, ValueError):
|
||||
body = await resp.text()
|
||||
logger.warning(
|
||||
f'[AutomationEventService] Automation service returned '
|
||||
f'{resp.status} for org {org_id}: {body}'
|
||||
)
|
||||
else:
|
||||
data = await resp.json()
|
||||
matched = data.get('matched', 0)
|
||||
logger.info(
|
||||
f'[AutomationEventService] Forwarded event to org {org_id}: '
|
||||
f'{matched} automations matched'
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(
|
||||
f'[AutomationEventService] Timeout ({AUTOMATION_SERVICE_TIMEOUT}s) '
|
||||
'forwarding to automation service'
|
||||
)
|
||||
except aiohttp.ClientError as e:
|
||||
logger.warning(
|
||||
f'[AutomationEventService] HTTP error forwarding to automation service: {e}'
|
||||
)
|
||||
@@ -1,171 +0,0 @@
|
||||
"""Service for managing organization member financial data."""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
import httpx
|
||||
from server.routes.org_models import (
|
||||
OrgMemberFinancialPage,
|
||||
OrgMemberFinancialResponse,
|
||||
)
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
from storage.org_member_store import OrgMemberStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
|
||||
class OrgMemberFinancialService:
|
||||
"""Service for organization member financial data operations."""
|
||||
|
||||
@staticmethod
|
||||
async def get_org_members_financial_data(
|
||||
org_id: UUID,
|
||||
page_id: str | None = None,
|
||||
limit: int = 10,
|
||||
email_filter: str | None = None,
|
||||
) -> OrgMemberFinancialPage:
|
||||
"""Get paginated financial data for organization members.
|
||||
|
||||
Fetches member list from database and joins with financial data from LiteLLM.
|
||||
|
||||
Args:
|
||||
org_id: Organization UUID
|
||||
page_id: Offset encoded as string (e.g., "0", "10", "20")
|
||||
limit: Maximum items per page (default 10)
|
||||
email_filter: Optional case-insensitive partial email match
|
||||
|
||||
Returns:
|
||||
OrgMemberFinancialPage: Paginated response with financial data
|
||||
|
||||
Raises:
|
||||
ValueError: If page_id is invalid
|
||||
"""
|
||||
# Parse page_id to get offset
|
||||
offset = 0
|
||||
if page_id is not None:
|
||||
try:
|
||||
offset = int(page_id)
|
||||
if offset < 0:
|
||||
raise ValueError('page_id must be non-negative')
|
||||
except ValueError as e:
|
||||
raise ValueError(f'Invalid page_id: {page_id}') from e
|
||||
|
||||
# Fetch paginated members from database
|
||||
members, total_count = await OrgMemberStore.get_org_members_paginated(
|
||||
org_id=org_id,
|
||||
offset=offset,
|
||||
limit=limit,
|
||||
email_filter=email_filter,
|
||||
)
|
||||
|
||||
if not members:
|
||||
return OrgMemberFinancialPage(
|
||||
items=[],
|
||||
current_page=(offset // limit) + 1,
|
||||
per_page=limit,
|
||||
next_page_id=None,
|
||||
)
|
||||
|
||||
# Fetch financial data from LiteLLM for the entire team
|
||||
# This is a single API call that returns all team members' data
|
||||
try:
|
||||
financial_data = await LiteLlmManager.get_team_members_financial_data(
|
||||
str(org_id)
|
||||
)
|
||||
except httpx.HTTPStatusError as e:
|
||||
# Re-raise auth errors - these indicate configuration issues that need fixing
|
||||
if e.response.status_code in (401, 403):
|
||||
logger.error(
|
||||
'LiteLLM authentication/authorization failed',
|
||||
extra={
|
||||
'org_id': str(org_id),
|
||||
'status_code': e.response.status_code,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
raise
|
||||
# For other HTTP errors (404, 500, etc.), use graceful degradation
|
||||
logger.warning(
|
||||
'Failed to fetch financial data from LiteLLM',
|
||||
extra={
|
||||
'org_id': str(org_id),
|
||||
'status_code': e.response.status_code,
|
||||
'error_type': type(e).__name__,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
financial_data = {}
|
||||
except Exception as e:
|
||||
# For network errors, timeouts, etc., use graceful degradation
|
||||
logger.warning(
|
||||
'Failed to fetch financial data from LiteLLM',
|
||||
extra={
|
||||
'org_id': str(org_id),
|
||||
'error_type': type(e).__name__,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
financial_data = {}
|
||||
|
||||
# Extract team-level data for shared budget calculation
|
||||
team_spend = financial_data.get('team_spend', 0) or 0
|
||||
members_financial = financial_data.get('members', {})
|
||||
|
||||
# Build response items by joining DB members with LiteLLM financial data
|
||||
items: list[OrgMemberFinancialResponse] = []
|
||||
for member in members:
|
||||
user = member.user
|
||||
user_id_str = str(member.user_id)
|
||||
|
||||
# Get financial data for this user (or defaults if not found)
|
||||
user_financial = members_financial.get(user_id_str, {})
|
||||
individual_spend = user_financial.get('spend', 0) or 0
|
||||
max_budget = user_financial.get('max_budget')
|
||||
uses_shared_budget = user_financial.get('uses_shared_budget', False)
|
||||
|
||||
# Calculate current budget (remaining)
|
||||
# For shared team budgets, use team_spend to calculate remaining budget
|
||||
# This ensures all members see the same remaining budget
|
||||
if max_budget is not None:
|
||||
if uses_shared_budget:
|
||||
# Shared budget - use team's total spend
|
||||
current_budget = max(max_budget - team_spend, 0)
|
||||
else:
|
||||
# Individual budget - use individual spend
|
||||
current_budget = max(max_budget - individual_spend, 0)
|
||||
else:
|
||||
# If no max_budget, current_budget is unlimited (represented as 0)
|
||||
current_budget = 0
|
||||
|
||||
items.append(
|
||||
OrgMemberFinancialResponse(
|
||||
user_id=user_id_str,
|
||||
email=user.email if user else None,
|
||||
lifetime_spend=individual_spend,
|
||||
current_budget=current_budget,
|
||||
max_budget=max_budget,
|
||||
)
|
||||
)
|
||||
|
||||
# Calculate current page (1-indexed)
|
||||
current_page = (offset // limit) + 1
|
||||
|
||||
# Calculate next_page_id
|
||||
next_offset = offset + limit
|
||||
next_page_id = str(next_offset) if next_offset < total_count else None
|
||||
|
||||
logger.debug(
|
||||
'OrgMemberFinancialService:get_org_members_financial_data:success',
|
||||
extra={
|
||||
'org_id': str(org_id),
|
||||
'items_count': len(items),
|
||||
'current_page': current_page,
|
||||
'total_count': total_count,
|
||||
},
|
||||
)
|
||||
|
||||
return OrgMemberFinancialPage(
|
||||
items=items,
|
||||
current_page=current_page,
|
||||
per_page=limit,
|
||||
next_page_id=next_page_id,
|
||||
)
|
||||
@@ -1,143 +0,0 @@
|
||||
"""Implementation of SharedEventService.
|
||||
|
||||
This implementation provides read-only access to events from shared conversations:
|
||||
- Validates that the conversation is shared before returning events
|
||||
- Uses existing EventService for actual event retrieval
|
||||
- Uses SharedConversationInfoService for shared conversation validation
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import AsyncGenerator
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import Request
|
||||
from server.sharing.shared_conversation_info_service import (
|
||||
SharedConversationInfoService,
|
||||
)
|
||||
from server.sharing.shared_event_service import (
|
||||
SharedEventService,
|
||||
SharedEventServiceInjector,
|
||||
)
|
||||
from server.sharing.sql_shared_conversation_info_service import (
|
||||
SQLSharedConversationInfoService,
|
||||
)
|
||||
|
||||
from openhands.agent_server.models import EventPage, EventSortOrder
|
||||
from openhands.app_server.config import get_global_config
|
||||
from openhands.app_server.event.event_service import EventService
|
||||
from openhands.app_server.event.filesystem_event_service import FilesystemEventService
|
||||
from openhands.app_server.event_callback.event_callback_models import EventKind
|
||||
from openhands.app_server.services.injector import InjectorState
|
||||
from openhands.sdk import Event
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FilesystemSharedEventService(SharedEventService):
|
||||
"""Implementation of SharedEventService that validates shared access."""
|
||||
|
||||
shared_conversation_info_service: SharedConversationInfoService
|
||||
persistence_dir: Path
|
||||
|
||||
async def get_event_service(self, conversation_id: UUID) -> EventService | None:
|
||||
shared_conversation_info = (
|
||||
await self.shared_conversation_info_service.get_shared_conversation_info(
|
||||
conversation_id
|
||||
)
|
||||
)
|
||||
if shared_conversation_info is None:
|
||||
return None
|
||||
|
||||
return FilesystemEventService(
|
||||
prefix=self.persistence_dir,
|
||||
user_id=shared_conversation_info.created_by_user_id,
|
||||
app_conversation_info_service=None,
|
||||
app_conversation_info_load_tasks={},
|
||||
)
|
||||
|
||||
async def get_shared_event(
|
||||
self, conversation_id: UUID, event_id: UUID
|
||||
) -> Event | None:
|
||||
"""Given a conversation_id and event_id, retrieve an event if the conversation is shared."""
|
||||
# First check if the conversation is shared
|
||||
event_service = await self.get_event_service(conversation_id)
|
||||
if event_service is None:
|
||||
return None
|
||||
|
||||
# If conversation is shared, get the event
|
||||
return await event_service.get_event(conversation_id, event_id)
|
||||
|
||||
async def search_shared_events(
|
||||
self,
|
||||
conversation_id: UUID,
|
||||
kind__eq: EventKind | None = None,
|
||||
timestamp__gte: datetime | None = None,
|
||||
timestamp__lt: datetime | None = None,
|
||||
sort_order: EventSortOrder = EventSortOrder.TIMESTAMP,
|
||||
page_id: str | None = None,
|
||||
limit: int = 100,
|
||||
) -> EventPage:
|
||||
"""Search events for a specific shared conversation."""
|
||||
# First check if the conversation is shared
|
||||
event_service = await self.get_event_service(conversation_id)
|
||||
if event_service is None:
|
||||
# Return empty page if conversation is not shared
|
||||
return EventPage(items=[], next_page_id=None)
|
||||
|
||||
# If conversation is shared, search events for this conversation
|
||||
return await event_service.search_events(
|
||||
conversation_id=conversation_id,
|
||||
kind__eq=kind__eq,
|
||||
timestamp__gte=timestamp__gte,
|
||||
timestamp__lt=timestamp__lt,
|
||||
sort_order=sort_order,
|
||||
page_id=page_id,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
async def count_shared_events(
|
||||
self,
|
||||
conversation_id: UUID,
|
||||
kind__eq: EventKind | None = None,
|
||||
timestamp__gte: datetime | None = None,
|
||||
timestamp__lt: datetime | None = None,
|
||||
) -> int:
|
||||
"""Count events for a specific shared conversation."""
|
||||
# First check if the conversation is shared
|
||||
event_service = await self.get_event_service(conversation_id)
|
||||
if event_service is None:
|
||||
# Return empty page if conversation is not shared
|
||||
return 0
|
||||
|
||||
# If conversation is shared, count events for this conversation
|
||||
return await event_service.count_events(
|
||||
conversation_id=conversation_id,
|
||||
kind__eq=kind__eq,
|
||||
timestamp__gte=timestamp__gte,
|
||||
timestamp__lt=timestamp__lt,
|
||||
)
|
||||
|
||||
|
||||
class FilesystemSharedEventServiceInjector(SharedEventServiceInjector):
|
||||
async def inject(
|
||||
self, state: InjectorState, request: Request | None = None
|
||||
) -> AsyncGenerator[SharedEventService, None]:
|
||||
# Define inline to prevent circular lookup
|
||||
from openhands.app_server.config import get_db_session
|
||||
|
||||
async with get_db_session(state, request) as db_session:
|
||||
shared_conversation_info_service = SQLSharedConversationInfoService(
|
||||
db_session=db_session
|
||||
)
|
||||
|
||||
service = FilesystemSharedEventService(
|
||||
shared_conversation_info_service=shared_conversation_info_service,
|
||||
persistence_dir=get_global_config().persistence_dir,
|
||||
)
|
||||
yield service
|
||||
@@ -4,7 +4,7 @@ from datetime import datetime
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from server.sharing.shared_conversation_info_service import (
|
||||
SharedConversationInfoService,
|
||||
)
|
||||
@@ -60,7 +60,7 @@ async def search_shared_conversations(
|
||||
Query(
|
||||
title='The max number of results in the page',
|
||||
gt=0,
|
||||
le=100,
|
||||
lte=100,
|
||||
),
|
||||
] = 100,
|
||||
include_sub_conversations: Annotated[
|
||||
@@ -72,6 +72,8 @@ async def search_shared_conversations(
|
||||
shared_conversation_service: SharedConversationInfoService = shared_conversation_info_service_dependency,
|
||||
) -> SharedConversationPage:
|
||||
"""Search / List shared conversations."""
|
||||
assert limit > 0
|
||||
assert limit <= 100
|
||||
return await shared_conversation_service.search_shared_conversation_info(
|
||||
title__contains=title__contains,
|
||||
created_at__gte=created_at__gte,
|
||||
@@ -125,11 +127,7 @@ async def batch_get_shared_conversations(
|
||||
shared_conversation_service: SharedConversationInfoService = shared_conversation_info_service_dependency,
|
||||
) -> list[SharedConversation | None]:
|
||||
"""Get a batch of shared conversations given their ids. Return None for any missing or non-shared."""
|
||||
if len(ids) > 100:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f'Cannot request more than 100 conversations at once, got {len(ids)}',
|
||||
)
|
||||
assert len(ids) <= 100
|
||||
uuids = [UUID(id_) for id_ in ids]
|
||||
shared_conversation_info = (
|
||||
await shared_conversation_service.batch_get_shared_conversation_info(uuids)
|
||||
|
||||
@@ -4,7 +4,7 @@ from datetime import datetime
|
||||
from typing import Annotated
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from server.sharing.shared_event_service import (
|
||||
SharedEventService,
|
||||
SharedEventServiceInjector,
|
||||
@@ -33,12 +33,6 @@ def get_shared_event_service_injector() -> SharedEventServiceInjector:
|
||||
)
|
||||
|
||||
return AwsSharedEventServiceInjector()
|
||||
elif provider == StorageProvider.FILESYSTEM:
|
||||
from server.sharing.filesystem_shared_event_service import (
|
||||
FilesystemSharedEventServiceInjector,
|
||||
)
|
||||
|
||||
return FilesystemSharedEventServiceInjector()
|
||||
else:
|
||||
# GCP is the default for shared events (including filesystem fallback)
|
||||
from server.sharing.google_cloud_shared_event_service import (
|
||||
@@ -83,11 +77,13 @@ async def search_shared_events(
|
||||
] = None,
|
||||
limit: Annotated[
|
||||
int,
|
||||
Query(title='The max number of results in the page', gt=0, le=100),
|
||||
Query(title='The max number of results in the page', gt=0, lte=100),
|
||||
] = 100,
|
||||
shared_event_service: SharedEventService = shared_event_service_dependency,
|
||||
) -> EventPage:
|
||||
"""Search / List events for a shared conversation."""
|
||||
assert limit > 0
|
||||
assert limit <= 100
|
||||
return await shared_event_service.search_shared_events(
|
||||
conversation_id=UUID(conversation_id),
|
||||
kind__eq=kind__eq,
|
||||
@@ -138,11 +134,7 @@ async def batch_get_shared_events(
|
||||
shared_event_service: SharedEventService = shared_event_service_dependency,
|
||||
) -> list[Event | None]:
|
||||
"""Get a batch of events for a shared conversation given their ids, returning null for any missing event."""
|
||||
if len(id) > 100:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f'Cannot request more than 100 events at once, got {len(id)}',
|
||||
)
|
||||
assert len(id) <= 100
|
||||
event_ids = [UUID(id_) for id_ in id]
|
||||
events = await shared_event_service.batch_get_shared_events(
|
||||
UUID(conversation_id), event_ids
|
||||
|
||||
@@ -354,20 +354,6 @@ class SaasSQLAppConversationInfoService(SQLAppConversationInfoService):
|
||||
user = result.scalar_one_or_none()
|
||||
assert user
|
||||
|
||||
# Determine org_id: prefer API key's org_id if authenticated via API key
|
||||
org_id = user.current_org_id # Default fallback
|
||||
if hasattr(self.user_context, 'user_auth'):
|
||||
user_auth = self.user_context.user_auth
|
||||
if hasattr(user_auth, 'get_api_key_org_id'):
|
||||
api_key_org_id = user_auth.get_api_key_org_id()
|
||||
if api_key_org_id is not None:
|
||||
org_id = api_key_org_id
|
||||
|
||||
# Override with resolver org_id if set (from git org claim resolution)
|
||||
resolver_org_id = getattr(self.user_context, 'resolver_org_id', None)
|
||||
if resolver_org_id is not None:
|
||||
org_id = resolver_org_id
|
||||
|
||||
# Check if SAAS metadata already exists
|
||||
saas_query = select(StoredConversationMetadataSaas).where(
|
||||
StoredConversationMetadataSaas.conversation_id == str(info.id)
|
||||
@@ -376,15 +362,16 @@ class SaasSQLAppConversationInfoService(SQLAppConversationInfoService):
|
||||
existing_saas_metadata = result.scalar_one_or_none()
|
||||
assert existing_saas_metadata is None or (
|
||||
existing_saas_metadata.user_id == user_id_uuid
|
||||
and existing_saas_metadata.org_id == org_id
|
||||
and existing_saas_metadata.org_id == user.current_org_id
|
||||
)
|
||||
|
||||
if not existing_saas_metadata:
|
||||
# Create new SAAS metadata with the determined org_id
|
||||
# Create new SAAS metadata
|
||||
# Set org_id to user_id as specified in requirements
|
||||
saas_metadata = StoredConversationMetadataSaas(
|
||||
conversation_id=str(info.id),
|
||||
user_id=user_id_uuid,
|
||||
org_id=org_id,
|
||||
org_id=user.current_org_id,
|
||||
)
|
||||
self.db_session.add(saas_metadata)
|
||||
|
||||
|
||||
@@ -29,10 +29,7 @@ def get_cookie_domain() -> str | None:
|
||||
|
||||
|
||||
def get_cookie_samesite() -> Literal['lax', 'strict']:
|
||||
# Use 'strict' in production for maximum CSRF protection
|
||||
# Use 'lax' for local development and staging environments
|
||||
# Note: For invitation links from emails, the frontend handles acceptance via
|
||||
# an authenticated POST request (same-origin), which works with 'strict' cookies
|
||||
# for localhost and feature/staging stacks we set it to 'lax' as the cookie domain won't allow 'strict'
|
||||
web_url = get_global_config().web_url
|
||||
return (
|
||||
'strict'
|
||||
|
||||
@@ -17,7 +17,7 @@ from server.verified_models.verified_model_service import (
|
||||
|
||||
from openhands.app_server.config import get_db_session
|
||||
from openhands.server.routes import public
|
||||
from openhands.utils.llm import ModelsResponse, get_supported_llm_models
|
||||
from openhands.utils.llm import get_supported_llm_models
|
||||
|
||||
api_router = APIRouter(prefix='/api/admin/verified-models', tags=['Verified Models'])
|
||||
|
||||
@@ -117,7 +117,7 @@ async def delete_verified_model(
|
||||
)
|
||||
|
||||
|
||||
async def get_saas_llm_models_dependency(request: Request) -> ModelsResponse:
|
||||
async def get_saas_llm_models_dependency(request: Request) -> list[str]:
|
||||
"""SaaS implementation for the LLM models endpoint."""
|
||||
async with get_db_session(request.state, request) as db_session:
|
||||
# Prevent circular import
|
||||
|
||||
@@ -19,7 +19,6 @@ from storage.linear_workspace import LinearWorkspace
|
||||
from storage.maintenance_task import MaintenanceTask, MaintenanceTaskStatus
|
||||
from storage.openhands_pr import OpenhandsPR
|
||||
from storage.org import Org
|
||||
from storage.org_git_claim import OrgGitClaim
|
||||
from storage.org_invitation import OrgInvitation
|
||||
from storage.org_member import OrgMember
|
||||
from storage.proactive_convos import ProactiveConversation
|
||||
@@ -66,7 +65,6 @@ __all__ = [
|
||||
'MaintenanceTaskStatus',
|
||||
'OpenhandsPR',
|
||||
'Org',
|
||||
'OrgGitClaim',
|
||||
'OrgInvitation',
|
||||
'OrgMember',
|
||||
'ProactiveConversation',
|
||||
|
||||
@@ -4,7 +4,6 @@ import secrets
|
||||
import string
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select, update
|
||||
from storage.api_key import ApiKey
|
||||
@@ -14,22 +13,9 @@ from storage.user_store import UserStore
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
|
||||
@dataclass
|
||||
class ApiKeyValidationResult:
|
||||
"""Result of API key validation containing user and organization info."""
|
||||
|
||||
user_id: str
|
||||
org_id: UUID | None # None for legacy API keys without org binding
|
||||
key_id: int
|
||||
key_name: str | None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ApiKeyStore:
|
||||
API_KEY_PREFIX = 'sk-oh-'
|
||||
# Prefix for system keys created by internal services (e.g., automations)
|
||||
# Keys with this prefix are hidden from users and cannot be deleted by users
|
||||
SYSTEM_KEY_NAME_PREFIX = '__SYSTEM__:'
|
||||
|
||||
def generate_api_key(self, length: int = 32) -> str:
|
||||
"""Generate a random API key with the sk-oh- prefix."""
|
||||
@@ -37,19 +23,6 @@ class ApiKeyStore:
|
||||
random_part = ''.join(secrets.choice(alphabet) for _ in range(length))
|
||||
return f'{self.API_KEY_PREFIX}{random_part}'
|
||||
|
||||
@classmethod
|
||||
def is_system_key_name(cls, name: str | None) -> bool:
|
||||
"""Check if a key name indicates a system key."""
|
||||
return name is not None and name.startswith(cls.SYSTEM_KEY_NAME_PREFIX)
|
||||
|
||||
@classmethod
|
||||
def make_system_key_name(cls, name: str) -> str:
|
||||
"""Create a system key name with the appropriate prefix.
|
||||
|
||||
Format: __SYSTEM__:<name>
|
||||
"""
|
||||
return f'{cls.SYSTEM_KEY_NAME_PREFIX}{name}'
|
||||
|
||||
async def create_api_key(
|
||||
self, user_id: str, name: str | None = None, expires_at: datetime | None = None
|
||||
) -> str:
|
||||
@@ -87,120 +60,8 @@ class ApiKeyStore:
|
||||
|
||||
return api_key
|
||||
|
||||
async def get_or_create_system_api_key(
|
||||
self,
|
||||
user_id: str,
|
||||
org_id: UUID,
|
||||
name: str,
|
||||
) -> str:
|
||||
"""Get or create a system API key for a user on behalf of an internal service.
|
||||
|
||||
If a key with the given name already exists for this user/org and is not expired,
|
||||
returns the existing key. Otherwise, creates a new key (and deletes any expired one).
|
||||
|
||||
System keys are:
|
||||
- Not visible to users in their API keys list (filtered by name prefix)
|
||||
- Not deletable by users (protected by name prefix check)
|
||||
- Associated with a specific org (not the user's current org)
|
||||
- Never expire (no expiration date)
|
||||
|
||||
Args:
|
||||
user_id: The ID of the user to create the key for
|
||||
org_id: The organization ID to associate the key with
|
||||
name: Required name for the key (will be prefixed with __SYSTEM__:)
|
||||
|
||||
Returns:
|
||||
The API key (existing or newly created)
|
||||
"""
|
||||
# Create system key name with prefix
|
||||
system_key_name = self.make_system_key_name(name)
|
||||
|
||||
async with a_session_maker() as session:
|
||||
# Check if key already exists for this user/org/name
|
||||
result = await session.execute(
|
||||
select(ApiKey).filter(
|
||||
ApiKey.user_id == user_id,
|
||||
ApiKey.org_id == org_id,
|
||||
ApiKey.name == system_key_name,
|
||||
)
|
||||
)
|
||||
existing_key = result.scalars().first()
|
||||
|
||||
if existing_key:
|
||||
# Check if expired
|
||||
if existing_key.expires_at:
|
||||
now = datetime.now(UTC)
|
||||
expires_at = existing_key.expires_at
|
||||
if expires_at.tzinfo is None:
|
||||
expires_at = expires_at.replace(tzinfo=UTC)
|
||||
|
||||
if expires_at < now:
|
||||
# Key is expired, delete it and create new one
|
||||
logger.info(
|
||||
'System API key expired, re-issuing',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
'key_name': system_key_name,
|
||||
},
|
||||
)
|
||||
await session.delete(existing_key)
|
||||
await session.commit()
|
||||
else:
|
||||
# Key exists and is not expired, return it
|
||||
logger.debug(
|
||||
'Returning existing system API key',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
'key_name': system_key_name,
|
||||
},
|
||||
)
|
||||
return existing_key.key
|
||||
else:
|
||||
# Key exists and has no expiration, return it
|
||||
logger.debug(
|
||||
'Returning existing system API key',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
'key_name': system_key_name,
|
||||
},
|
||||
)
|
||||
return existing_key.key
|
||||
|
||||
# Create new key (no expiration)
|
||||
api_key = self.generate_api_key()
|
||||
|
||||
async with a_session_maker() as session:
|
||||
key_record = ApiKey(
|
||||
key=api_key,
|
||||
user_id=user_id,
|
||||
org_id=org_id,
|
||||
name=system_key_name,
|
||||
expires_at=None, # System keys never expire
|
||||
)
|
||||
session.add(key_record)
|
||||
await session.commit()
|
||||
|
||||
logger.info(
|
||||
'Created system API key',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
'key_name': system_key_name,
|
||||
},
|
||||
)
|
||||
|
||||
return api_key
|
||||
|
||||
async def validate_api_key(self, api_key: str) -> ApiKeyValidationResult | None:
|
||||
"""Validate an API key and return the associated user_id and org_id if valid.
|
||||
|
||||
Returns:
|
||||
ApiKeyValidationResult if the key is valid, None otherwise.
|
||||
The org_id may be None for legacy API keys that weren't bound to an organization.
|
||||
"""
|
||||
async def validate_api_key(self, api_key: str) -> str | None:
|
||||
"""Validate an API key and return the associated user_id if valid."""
|
||||
now = datetime.now(UTC)
|
||||
|
||||
async with a_session_maker() as session:
|
||||
@@ -228,12 +89,7 @@ class ApiKeyStore:
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
return ApiKeyValidationResult(
|
||||
user_id=key_record.user_id,
|
||||
org_id=key_record.org_id,
|
||||
key_id=key_record.id,
|
||||
key_name=key_record.name,
|
||||
)
|
||||
return key_record.user_id
|
||||
|
||||
async def delete_api_key(self, api_key: str) -> bool:
|
||||
"""Delete an API key by the key value."""
|
||||
@@ -249,18 +105,8 @@ class ApiKeyStore:
|
||||
|
||||
return True
|
||||
|
||||
async def delete_api_key_by_id(
|
||||
self, key_id: int, allow_system: bool = False
|
||||
) -> bool:
|
||||
"""Delete an API key by its ID.
|
||||
|
||||
Args:
|
||||
key_id: The ID of the key to delete
|
||||
allow_system: If False (default), system keys cannot be deleted
|
||||
|
||||
Returns:
|
||||
True if the key was deleted, False if not found or is a protected system key
|
||||
"""
|
||||
async def delete_api_key_by_id(self, key_id: int) -> bool:
|
||||
"""Delete an API key by its ID."""
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(select(ApiKey).filter(ApiKey.id == key_id))
|
||||
key_record = result.scalars().first()
|
||||
@@ -268,26 +114,13 @@ class ApiKeyStore:
|
||||
if not key_record:
|
||||
return False
|
||||
|
||||
# Protect system keys from deletion unless explicitly allowed
|
||||
if self.is_system_key_name(key_record.name) and not allow_system:
|
||||
logger.warning(
|
||||
'Attempted to delete system API key',
|
||||
extra={'key_id': key_id, 'user_id': key_record.user_id},
|
||||
)
|
||||
return False
|
||||
|
||||
await session.delete(key_record)
|
||||
await session.commit()
|
||||
|
||||
return True
|
||||
|
||||
async def list_api_keys(self, user_id: str) -> list[ApiKey]:
|
||||
"""List all user-visible API keys for a user.
|
||||
|
||||
This excludes:
|
||||
- System keys (name starts with __SYSTEM__:) - created by internal services
|
||||
- MCP_API_KEY - internal MCP key
|
||||
"""
|
||||
"""List all API keys for a user."""
|
||||
user = await UserStore.get_user_by_id(user_id)
|
||||
if user is None:
|
||||
raise ValueError(f'User not found: {user_id}')
|
||||
@@ -296,17 +129,11 @@ class ApiKeyStore:
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(ApiKey).filter(
|
||||
ApiKey.user_id == user_id,
|
||||
ApiKey.org_id == org_id,
|
||||
ApiKey.user_id == user_id, ApiKey.org_id == org_id
|
||||
)
|
||||
)
|
||||
keys = result.scalars().all()
|
||||
# Filter out system keys and MCP_API_KEY
|
||||
return [
|
||||
key
|
||||
for key in keys
|
||||
if key.name != 'MCP_API_KEY' and not self.is_system_key_name(key.name)
|
||||
]
|
||||
return [key for key in keys if key.name != 'MCP_API_KEY']
|
||||
|
||||
async def retrieve_mcp_api_key(self, user_id: str) -> str | None:
|
||||
user = await UserStore.get_user_by_id(user_id)
|
||||
@@ -336,44 +163,17 @@ class ApiKeyStore:
|
||||
key_record = result.scalars().first()
|
||||
return key_record.key if key_record else None
|
||||
|
||||
async def delete_api_key_by_name(
|
||||
self,
|
||||
user_id: str,
|
||||
name: str,
|
||||
org_id: UUID | None = None,
|
||||
allow_system: bool = False,
|
||||
) -> bool:
|
||||
"""Delete an API key by name for a specific user.
|
||||
|
||||
Args:
|
||||
user_id: The ID of the user whose key to delete
|
||||
name: The name of the key to delete
|
||||
org_id: Optional organization ID to filter by (required for system keys)
|
||||
allow_system: If False (default), system keys cannot be deleted
|
||||
|
||||
Returns:
|
||||
True if the key was deleted, False if not found or is a protected system key
|
||||
"""
|
||||
async def delete_api_key_by_name(self, user_id: str, name: str) -> bool:
|
||||
"""Delete an API key by name for a specific user."""
|
||||
async with a_session_maker() as session:
|
||||
# Build the query filters
|
||||
filters = [ApiKey.user_id == user_id, ApiKey.name == name]
|
||||
if org_id is not None:
|
||||
filters.append(ApiKey.org_id == org_id)
|
||||
|
||||
result = await session.execute(select(ApiKey).filter(*filters))
|
||||
result = await session.execute(
|
||||
select(ApiKey).filter(ApiKey.user_id == user_id, ApiKey.name == name)
|
||||
)
|
||||
key_record = result.scalars().first()
|
||||
|
||||
if not key_record:
|
||||
return False
|
||||
|
||||
# Protect system keys from deletion unless explicitly allowed
|
||||
if self.is_system_key_name(key_record.name) and not allow_system:
|
||||
logger.warning(
|
||||
'Attempted to delete system API key',
|
||||
extra={'user_id': user_id, 'key_name': name},
|
||||
)
|
||||
return False
|
||||
|
||||
await session.delete(key_record)
|
||||
await session.commit()
|
||||
|
||||
|
||||
@@ -29,37 +29,14 @@ KEY_VERIFICATION_TIMEOUT = 5.0
|
||||
# A very large number to represent "unlimited" until LiteLLM fixes their unlimited update bug.
|
||||
UNLIMITED_BUDGET_SETTING = 1000000000.0
|
||||
|
||||
# Check if billing is enabled (defaults to false for enterprise deployments)
|
||||
ENABLE_BILLING = os.environ.get('ENABLE_BILLING', 'false').lower() == 'true'
|
||||
|
||||
|
||||
def _get_default_initial_budget() -> float | None:
|
||||
"""Get the default initial budget for new teams.
|
||||
|
||||
When billing is disabled (ENABLE_BILLING=false), returns None to disable
|
||||
budget enforcement in LiteLLM. When billing is enabled, returns the
|
||||
DEFAULT_INITIAL_BUDGET environment variable value (default 0.0).
|
||||
|
||||
Returns:
|
||||
float | None: The default budget, or None to disable budget enforcement.
|
||||
"""
|
||||
if not ENABLE_BILLING:
|
||||
return None
|
||||
|
||||
try:
|
||||
budget = float(os.environ.get('DEFAULT_INITIAL_BUDGET', 0.0))
|
||||
if budget < 0:
|
||||
raise ValueError(
|
||||
f'DEFAULT_INITIAL_BUDGET must be non-negative, got {budget}'
|
||||
)
|
||||
return budget
|
||||
except ValueError as e:
|
||||
try:
|
||||
DEFAULT_INITIAL_BUDGET = float(os.environ.get('DEFAULT_INITIAL_BUDGET', 0.0))
|
||||
if DEFAULT_INITIAL_BUDGET < 0:
|
||||
raise ValueError(
|
||||
f'Invalid DEFAULT_INITIAL_BUDGET environment variable: {e}'
|
||||
) from e
|
||||
|
||||
|
||||
DEFAULT_INITIAL_BUDGET: float | None = _get_default_initial_budget()
|
||||
f'DEFAULT_INITIAL_BUDGET must be non-negative, got {DEFAULT_INITIAL_BUDGET}'
|
||||
)
|
||||
except ValueError as e:
|
||||
raise ValueError(f'Invalid DEFAULT_INITIAL_BUDGET environment variable: {e}') from e
|
||||
|
||||
|
||||
def get_openhands_cloud_key_alias(keycloak_user_id: str, org_id: str) -> str:
|
||||
@@ -133,15 +110,12 @@ class LiteLlmManager:
|
||||
) as client:
|
||||
# Check if team already exists and get its budget
|
||||
# New users joining existing orgs should inherit the team's budget
|
||||
# When billing is disabled, DEFAULT_INITIAL_BUDGET is None
|
||||
team_budget: float | None = DEFAULT_INITIAL_BUDGET
|
||||
team_budget: float = DEFAULT_INITIAL_BUDGET
|
||||
try:
|
||||
existing_team = await LiteLlmManager._get_team(client, org_id)
|
||||
if existing_team:
|
||||
team_info = existing_team.get('team_info', {})
|
||||
# Preserve None from existing team (no budget enforcement)
|
||||
existing_budget = team_info.get('max_budget')
|
||||
team_budget = existing_budget
|
||||
team_budget = team_info.get('max_budget', 0.0) or 0.0
|
||||
logger.info(
|
||||
'LiteLlmManager:create_entries:existing_team_budget',
|
||||
extra={
|
||||
@@ -164,33 +138,9 @@ class LiteLlmManager:
|
||||
)
|
||||
|
||||
if create_user:
|
||||
user_created = await LiteLlmManager._create_user(
|
||||
await LiteLlmManager._create_user(
|
||||
client, keycloak_user_info.get('email'), keycloak_user_id
|
||||
)
|
||||
if not user_created:
|
||||
logger.error(
|
||||
'create_entries_failed_user_creation',
|
||||
extra={
|
||||
'org_id': org_id,
|
||||
'user_id': keycloak_user_id,
|
||||
},
|
||||
)
|
||||
return None
|
||||
|
||||
# Verify user exists before proceeding with key generation
|
||||
user_exists = await LiteLlmManager._user_exists(
|
||||
client, keycloak_user_id
|
||||
)
|
||||
if not user_exists:
|
||||
logger.error(
|
||||
'create_entries_user_not_found_before_key_generation',
|
||||
extra={
|
||||
'org_id': org_id,
|
||||
'user_id': keycloak_user_id,
|
||||
'create_user_flag': create_user,
|
||||
},
|
||||
)
|
||||
return None
|
||||
|
||||
await LiteLlmManager._add_user_to_team(
|
||||
client, keycloak_user_id, org_id, team_budget
|
||||
@@ -575,40 +525,25 @@ class LiteLlmManager:
|
||||
client: httpx.AsyncClient,
|
||||
team_alias: str,
|
||||
team_id: str,
|
||||
max_budget: float | None,
|
||||
max_budget: float,
|
||||
):
|
||||
"""Create a new team in LiteLLM.
|
||||
|
||||
Args:
|
||||
client: The HTTP client to use.
|
||||
team_alias: The alias for the team.
|
||||
team_id: The ID for the team.
|
||||
max_budget: The maximum budget for the team. When None, budget
|
||||
enforcement is disabled (unlimited usage).
|
||||
"""
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return
|
||||
|
||||
json_data: dict[str, Any] = {
|
||||
'team_id': team_id,
|
||||
'team_alias': team_alias,
|
||||
'models': [],
|
||||
'spend': 0,
|
||||
'metadata': {
|
||||
'version': ORG_SETTINGS_VERSION,
|
||||
'model': get_default_litellm_model(),
|
||||
},
|
||||
}
|
||||
|
||||
if max_budget is not None:
|
||||
json_data['max_budget'] = max_budget
|
||||
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/team/new',
|
||||
json=json_data,
|
||||
json={
|
||||
'team_id': team_id,
|
||||
'team_alias': team_alias,
|
||||
'models': [],
|
||||
'max_budget': max_budget,
|
||||
'spend': 0,
|
||||
'metadata': {
|
||||
'version': ORG_SETTINGS_VERSION,
|
||||
'model': get_default_litellm_model(),
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
# Team failed to create in litellm - this is an unforseen error state...
|
||||
if not response.is_success:
|
||||
if (
|
||||
@@ -685,48 +620,15 @@ class LiteLlmManager:
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
async def _user_exists(
|
||||
client: httpx.AsyncClient,
|
||||
user_id: str,
|
||||
) -> bool:
|
||||
"""Check if a user exists in LiteLLM.
|
||||
|
||||
Returns True if the user exists, False otherwise.
|
||||
"""
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
return False
|
||||
try:
|
||||
response = await client.get(
|
||||
f'{LITE_LLM_API_URL}/user/info?user_id={user_id}',
|
||||
)
|
||||
if response.is_success:
|
||||
user_data = response.json()
|
||||
# Check that user_info exists and has the user_id
|
||||
user_info = user_data.get('user_info', {})
|
||||
return user_info.get('user_id') == user_id
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
'litellm_user_exists_check_failed',
|
||||
extra={'user_id': user_id, 'error': str(e)},
|
||||
)
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
async def _create_user(
|
||||
client: httpx.AsyncClient,
|
||||
email: str | None,
|
||||
keycloak_user_id: str,
|
||||
) -> bool:
|
||||
"""Create a user in LiteLLM.
|
||||
|
||||
Returns True if the user was created or already exists and is verified,
|
||||
False if creation failed and user does not exist.
|
||||
"""
|
||||
):
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return False
|
||||
return
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/user/new',
|
||||
json={
|
||||
@@ -779,33 +681,17 @@ class LiteLlmManager:
|
||||
'user_id': keycloak_user_id,
|
||||
},
|
||||
)
|
||||
# Verify the user actually exists before returning success
|
||||
user_exists = await LiteLlmManager._user_exists(
|
||||
client, keycloak_user_id
|
||||
)
|
||||
if not user_exists:
|
||||
logger.error(
|
||||
'litellm_user_claimed_exists_but_not_found',
|
||||
extra={
|
||||
'user_id': keycloak_user_id,
|
||||
'status_code': response.status_code,
|
||||
'text': response.text,
|
||||
},
|
||||
)
|
||||
return False
|
||||
return True
|
||||
return
|
||||
logger.error(
|
||||
'error_creating_litellm_user',
|
||||
extra={
|
||||
'status_code': response.status_code,
|
||||
'text': response.text,
|
||||
'user_id': keycloak_user_id,
|
||||
'user_id': [keycloak_user_id],
|
||||
'email': None,
|
||||
},
|
||||
)
|
||||
return False
|
||||
response.raise_for_status()
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
async def _get_user(client: httpx.AsyncClient, user_id: str) -> dict | None:
|
||||
@@ -1032,34 +918,19 @@ class LiteLlmManager:
|
||||
client: httpx.AsyncClient,
|
||||
keycloak_user_id: str,
|
||||
team_id: str,
|
||||
max_budget: float | None,
|
||||
max_budget: float,
|
||||
):
|
||||
"""Add a user to a team in LiteLLM.
|
||||
|
||||
Args:
|
||||
client: The HTTP client to use.
|
||||
keycloak_user_id: The user's Keycloak ID.
|
||||
team_id: The team ID.
|
||||
max_budget: The maximum budget for the user in the team. When None,
|
||||
budget enforcement is disabled (unlimited usage).
|
||||
"""
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return
|
||||
|
||||
json_data: dict[str, Any] = {
|
||||
'team_id': team_id,
|
||||
'member': {'user_id': keycloak_user_id, 'role': 'user'},
|
||||
}
|
||||
|
||||
if max_budget is not None:
|
||||
json_data['max_budget_in_team'] = max_budget
|
||||
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/team/member_add',
|
||||
json=json_data,
|
||||
json={
|
||||
'team_id': team_id,
|
||||
'member': {'user_id': keycloak_user_id, 'role': 'user'},
|
||||
'max_budget_in_team': max_budget,
|
||||
},
|
||||
)
|
||||
|
||||
# Failed to add user to team - this is an unforseen error state...
|
||||
if not response.is_success:
|
||||
if (
|
||||
@@ -1127,34 +998,19 @@ class LiteLlmManager:
|
||||
client: httpx.AsyncClient,
|
||||
keycloak_user_id: str,
|
||||
team_id: str,
|
||||
max_budget: float | None,
|
||||
max_budget: float,
|
||||
):
|
||||
"""Update a user's budget in a team.
|
||||
|
||||
Args:
|
||||
client: The HTTP client to use.
|
||||
keycloak_user_id: The user's Keycloak ID.
|
||||
team_id: The team ID.
|
||||
max_budget: The maximum budget for the user in the team. When None,
|
||||
budget enforcement is disabled (unlimited usage).
|
||||
"""
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return
|
||||
|
||||
json_data: dict[str, Any] = {
|
||||
'team_id': team_id,
|
||||
'user_id': keycloak_user_id,
|
||||
}
|
||||
|
||||
if max_budget is not None:
|
||||
json_data['max_budget_in_team'] = max_budget
|
||||
|
||||
response = await client.post(
|
||||
f'{LITE_LLM_API_URL}/team/member_update',
|
||||
json=json_data,
|
||||
json={
|
||||
'team_id': team_id,
|
||||
'user_id': keycloak_user_id,
|
||||
'max_budget_in_team': max_budget,
|
||||
},
|
||||
)
|
||||
|
||||
# Failed to update user in team - this is an unforseen error state...
|
||||
if not response.is_success:
|
||||
logger.error(
|
||||
@@ -1524,83 +1380,6 @@ class LiteLlmManager:
|
||||
'LiteLlmManager:_delete_key:key_deleted',
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _get_team_members_financial_data(
|
||||
client: httpx.AsyncClient,
|
||||
team_id: str,
|
||||
) -> dict:
|
||||
"""
|
||||
Get financial data for all members in a team.
|
||||
|
||||
Fetches team info from LiteLLM and extracts spending/budget data for each member.
|
||||
|
||||
Args:
|
||||
client: HTTP client for LiteLLM API
|
||||
team_id: The team/organization ID
|
||||
|
||||
Returns:
|
||||
Dict with structure:
|
||||
{
|
||||
"team_max_budget": float | None, # Team's shared budget
|
||||
"team_spend": float, # Team's total spend (for shared budget calc)
|
||||
"members": {
|
||||
user_id: {
|
||||
"spend": float,
|
||||
"max_budget": float | None,
|
||||
"uses_shared_budget": bool # True if using team budget
|
||||
},
|
||||
...
|
||||
}
|
||||
}
|
||||
Returns empty dict if team not found or LiteLLM is not configured.
|
||||
"""
|
||||
if LITE_LLM_API_KEY is None or LITE_LLM_API_URL is None:
|
||||
logger.warning('LiteLLM API configuration not found')
|
||||
return {}
|
||||
|
||||
team_info = await LiteLlmManager._get_team(client, team_id)
|
||||
if not team_info:
|
||||
logger.warning(
|
||||
'LiteLlmManager:_get_team_members_financial_data:team_not_found',
|
||||
extra={'team_id': team_id},
|
||||
)
|
||||
return {}
|
||||
|
||||
members: dict[str, dict] = {}
|
||||
team_memberships = team_info.get('team_memberships', [])
|
||||
|
||||
# Get team-level budget info (shared across all members in team orgs)
|
||||
team_data = team_info.get('team_info', {})
|
||||
team_max_budget = team_data.get('max_budget')
|
||||
team_spend = team_data.get('spend', 0) or 0
|
||||
|
||||
for membership in team_memberships:
|
||||
user_id = membership.get('user_id')
|
||||
if not user_id:
|
||||
continue
|
||||
|
||||
# Use individual max_budget_in_team if set, otherwise fall back to team budget
|
||||
member_max_budget = membership.get('max_budget_in_team')
|
||||
uses_shared_budget = member_max_budget is None
|
||||
if uses_shared_budget:
|
||||
member_max_budget = team_max_budget
|
||||
|
||||
members[user_id] = {
|
||||
'spend': membership.get('spend', 0) or 0,
|
||||
'max_budget': member_max_budget,
|
||||
'uses_shared_budget': uses_shared_budget,
|
||||
}
|
||||
|
||||
logger.debug(
|
||||
'LiteLlmManager:_get_team_members_financial_data:success',
|
||||
extra={'team_id': team_id, 'member_count': len(members)},
|
||||
)
|
||||
return {
|
||||
'team_max_budget': team_max_budget,
|
||||
'team_spend': team_spend,
|
||||
'members': members,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def with_http_client(
|
||||
internal_fn: Callable[..., Awaitable[Any]],
|
||||
@@ -1608,8 +1387,7 @@ class LiteLlmManager:
|
||||
@functools.wraps(internal_fn)
|
||||
async def wrapper(*args, **kwargs):
|
||||
async with httpx.AsyncClient(
|
||||
headers={'x-goog-api-key': LITE_LLM_API_KEY},
|
||||
timeout=httpx.Timeout(30.0),
|
||||
headers={'x-goog-api-key': LITE_LLM_API_KEY}
|
||||
) as client:
|
||||
return await internal_fn(client, *args, **kwargs)
|
||||
|
||||
@@ -1619,7 +1397,6 @@ class LiteLlmManager:
|
||||
create_team = staticmethod(with_http_client(_create_team))
|
||||
get_team = staticmethod(with_http_client(_get_team))
|
||||
update_team = staticmethod(with_http_client(_update_team))
|
||||
user_exists = staticmethod(with_http_client(_user_exists))
|
||||
create_user = staticmethod(with_http_client(_create_user))
|
||||
get_user = staticmethod(with_http_client(_get_user))
|
||||
update_user = staticmethod(with_http_client(_update_user))
|
||||
@@ -1636,6 +1413,3 @@ class LiteLlmManager:
|
||||
get_user_keys = staticmethod(with_http_client(_get_user_keys))
|
||||
delete_key_by_alias = staticmethod(with_http_client(_delete_key_by_alias))
|
||||
update_user_keys = staticmethod(with_http_client(_update_user_keys))
|
||||
get_team_members_financial_data = staticmethod(
|
||||
with_http_client(_get_team_members_financial_data)
|
||||
)
|
||||
|
||||
@@ -64,7 +64,6 @@ class Org(Base): # type: ignore
|
||||
slack_conversations = relationship('SlackConversation', back_populates='org')
|
||||
slack_users = relationship('SlackUser', back_populates='org')
|
||||
stripe_customers = relationship('StripeCustomer', back_populates='org')
|
||||
git_claims = relationship('OrgGitClaim', back_populates='org')
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
# Handle known SQLAlchemy columns directly
|
||||
|
||||
@@ -1,30 +0,0 @@
|
||||
"""
|
||||
SQLAlchemy model for Git Organization Claims.
|
||||
"""
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import UUID, Column, DateTime, ForeignKey, String, UniqueConstraint
|
||||
from sqlalchemy.orm import relationship
|
||||
from storage.base import Base
|
||||
|
||||
|
||||
class OrgGitClaim(Base): # type: ignore
|
||||
"""Model for tracking which OpenHands org has claimed a Git organization."""
|
||||
|
||||
__tablename__ = 'org_git_claim'
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid4)
|
||||
org_id = Column(
|
||||
UUID(as_uuid=True), ForeignKey('org.id', ondelete='CASCADE'), nullable=False
|
||||
)
|
||||
provider = Column(String, nullable=False)
|
||||
git_organization = Column(String, nullable=False)
|
||||
claimed_by = Column(UUID(as_uuid=True), ForeignKey('user.id'), nullable=False)
|
||||
claimed_at = Column(DateTime(timezone=True), nullable=False)
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint('provider', 'git_organization', name='uq_provider_git_org'),
|
||||
)
|
||||
|
||||
org = relationship('Org', back_populates='git_claims')
|
||||
@@ -1,141 +0,0 @@
|
||||
"""
|
||||
Store class for managing Git organization claims.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import and_, select
|
||||
from storage.database import a_session_maker
|
||||
from storage.org_git_claim import OrgGitClaim
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
|
||||
class OrgGitClaimStore:
|
||||
"""Store for managing Git organization claims."""
|
||||
|
||||
@staticmethod
|
||||
async def create_claim(
|
||||
org_id: UUID,
|
||||
provider: str,
|
||||
git_organization: str,
|
||||
claimed_by: UUID,
|
||||
) -> OrgGitClaim:
|
||||
"""Create a new Git organization claim.
|
||||
|
||||
Args:
|
||||
org_id: OpenHands organization UUID
|
||||
provider: Git provider ('github', 'gitlab', 'bitbucket')
|
||||
git_organization: Name of the Git organization being claimed
|
||||
claimed_by: User UUID who is making the claim
|
||||
|
||||
Returns:
|
||||
OrgGitClaim: The created claim record
|
||||
"""
|
||||
async with a_session_maker() as session:
|
||||
claim = OrgGitClaim(
|
||||
org_id=org_id,
|
||||
provider=provider,
|
||||
git_organization=git_organization,
|
||||
claimed_by=claimed_by,
|
||||
claimed_at=datetime.now(timezone.utc),
|
||||
)
|
||||
session.add(claim)
|
||||
await session.commit()
|
||||
await session.refresh(claim)
|
||||
|
||||
logger.info(
|
||||
'Created Git organization claim',
|
||||
extra={
|
||||
'claim_id': str(claim.id),
|
||||
'org_id': str(org_id),
|
||||
'provider': provider,
|
||||
'git_organization': git_organization,
|
||||
'claimed_by': str(claimed_by),
|
||||
},
|
||||
)
|
||||
|
||||
return claim
|
||||
|
||||
@staticmethod
|
||||
async def get_claim_by_provider_and_git_org(
|
||||
provider: str,
|
||||
git_organization: str,
|
||||
) -> Optional[OrgGitClaim]:
|
||||
"""Check if a Git organization is already claimed.
|
||||
|
||||
Args:
|
||||
provider: Git provider name
|
||||
git_organization: Name of the Git organization
|
||||
|
||||
Returns:
|
||||
OrgGitClaim or None if not claimed
|
||||
"""
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(OrgGitClaim).filter(
|
||||
and_(
|
||||
OrgGitClaim.provider == provider,
|
||||
OrgGitClaim.git_organization == git_organization,
|
||||
)
|
||||
)
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
@staticmethod
|
||||
async def get_claims_by_org_id(org_id: UUID) -> list[OrgGitClaim]:
|
||||
"""Get all Git organization claims for an OpenHands organization.
|
||||
|
||||
Args:
|
||||
org_id: OpenHands organization UUID
|
||||
|
||||
Returns:
|
||||
List of OrgGitClaim records
|
||||
"""
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(OrgGitClaim).filter(OrgGitClaim.org_id == org_id)
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
|
||||
@staticmethod
|
||||
async def delete_claim(claim_id: UUID, org_id: UUID) -> bool:
|
||||
"""Delete a Git organization claim.
|
||||
|
||||
Args:
|
||||
claim_id: Claim UUID to delete
|
||||
org_id: OpenHands organization UUID (for ownership verification)
|
||||
|
||||
Returns:
|
||||
True if deleted, False if not found
|
||||
"""
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(OrgGitClaim).filter(
|
||||
and_(
|
||||
OrgGitClaim.id == claim_id,
|
||||
OrgGitClaim.org_id == org_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
claim = result.scalars().first()
|
||||
|
||||
if not claim:
|
||||
return False
|
||||
|
||||
await session.delete(claim)
|
||||
await session.commit()
|
||||
|
||||
logger.info(
|
||||
'Deleted Git organization claim',
|
||||
extra={
|
||||
'claim_id': str(claim_id),
|
||||
'org_id': str(org_id),
|
||||
'provider': claim.provider,
|
||||
'git_organization': claim.git_organization,
|
||||
},
|
||||
)
|
||||
|
||||
return True
|
||||
@@ -3,7 +3,7 @@ SQLAlchemy model for Organization-Member relationship.
|
||||
"""
|
||||
|
||||
from pydantic import SecretStr
|
||||
from sqlalchemy import JSON, UUID, Column, ForeignKey, Integer, String
|
||||
from sqlalchemy import UUID, Column, ForeignKey, Integer, String
|
||||
from sqlalchemy.orm import relationship
|
||||
from storage.base import Base
|
||||
from storage.encrypt_utils import decrypt_value, encrypt_value
|
||||
@@ -23,7 +23,6 @@ class OrgMember(Base): # type: ignore
|
||||
_llm_api_key_for_byor = Column(String, nullable=True)
|
||||
llm_base_url = Column(String, nullable=True)
|
||||
status = Column(String, nullable=True)
|
||||
mcp_config = Column(JSON, nullable=True)
|
||||
|
||||
# Relationships
|
||||
org = relationship('Org', back_populates='org_members')
|
||||
|
||||
@@ -34,17 +34,10 @@ class SaasConversationStore(ConversationStore):
|
||||
session_maker: sessionmaker
|
||||
org_id: UUID | None = None # will be fetched automatically
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
user_id: str,
|
||||
org_id: UUID,
|
||||
session_maker: sessionmaker,
|
||||
resolver_org_id: UUID | None = None,
|
||||
):
|
||||
def __init__(self, user_id: str, org_id: UUID, session_maker: sessionmaker):
|
||||
self.user_id = user_id
|
||||
self.org_id = org_id
|
||||
self.session_maker = session_maker
|
||||
self.resolver_org_id = resolver_org_id
|
||||
|
||||
def _select_by_id(self, session, conversation_id: str):
|
||||
# Join StoredConversationMetadata with ConversationMetadataSaas to filter by user/org
|
||||
@@ -110,13 +103,6 @@ class SaasConversationStore(ConversationStore):
|
||||
|
||||
stored_metadata = StoredConversationMetadata(**kwargs)
|
||||
|
||||
# Override with resolver org_id if set (from git org claim resolution),
|
||||
# same pattern as V1's save_app_conversation_info in
|
||||
# saas_app_conversation_info_injector.py
|
||||
org_id = self.org_id
|
||||
if self.resolver_org_id is not None:
|
||||
org_id = self.resolver_org_id
|
||||
|
||||
def _save_metadata():
|
||||
with self.session_maker() as session:
|
||||
# Save the main conversation metadata
|
||||
@@ -136,13 +122,13 @@ class SaasConversationStore(ConversationStore):
|
||||
saas_metadata = StoredConversationMetadataSaas(
|
||||
conversation_id=stored_metadata.conversation_id,
|
||||
user_id=UUID(self.user_id),
|
||||
org_id=org_id,
|
||||
org_id=self.org_id,
|
||||
)
|
||||
session.add(saas_metadata)
|
||||
else:
|
||||
# Validate
|
||||
expected_user_id = UUID(self.user_id)
|
||||
expected_org_id = org_id
|
||||
expected_org_id = self.org_id
|
||||
|
||||
if saas_metadata.user_id != expected_user_id:
|
||||
raise ValueError(
|
||||
@@ -254,19 +240,3 @@ class SaasConversationStore(ConversationStore):
|
||||
user = await UserStore.get_user_by_id(user_id)
|
||||
org_id = user.current_org_id if user else None
|
||||
return SaasConversationStore(user_id, org_id, session_maker)
|
||||
|
||||
@classmethod
|
||||
async def get_resolver_instance(
|
||||
cls,
|
||||
config: OpenHandsConfig,
|
||||
user_id: str,
|
||||
resolver_org_id: UUID | None = None,
|
||||
) -> 'SaasConversationStore':
|
||||
"""Get a store for resolver conversations with explicit org routing.
|
||||
|
||||
Unlike get_instance, this accepts a resolver_org_id that overrides
|
||||
the user's default org when saving conversation metadata.
|
||||
"""
|
||||
user = await UserStore.get_user_by_id(user_id)
|
||||
org_id = user.current_org_id if user else None
|
||||
return SaasConversationStore(user_id, org_id, session_maker, resolver_org_id)
|
||||
|
||||
@@ -15,27 +15,25 @@ class SaasConversationValidator(ConversationValidator):
|
||||
|
||||
async def _validate_api_key(self, api_key: str) -> str | None:
|
||||
"""
|
||||
Validate an API key and return the user_id if valid.
|
||||
Validate an API key and return the user_id and github_user_id if valid.
|
||||
|
||||
Args:
|
||||
api_key: The API key to validate
|
||||
|
||||
Returns:
|
||||
The user_id if the API key is valid, None otherwise
|
||||
A tuple of (user_id, github_user_id) if the API key is valid, None otherwise
|
||||
"""
|
||||
try:
|
||||
token_manager = TokenManager()
|
||||
|
||||
# Validate the API key and get the user_id
|
||||
api_key_store = ApiKeyStore.get_instance()
|
||||
validation_result = await api_key_store.validate_api_key(api_key)
|
||||
user_id = await api_key_store.validate_api_key(api_key)
|
||||
|
||||
if not validation_result:
|
||||
if not user_id:
|
||||
logger.warning('Invalid API key')
|
||||
return None
|
||||
|
||||
user_id = validation_result.user_id
|
||||
|
||||
# Get the offline token for the user
|
||||
offline_token = await token_manager.load_offline_token(user_id)
|
||||
if not offline_token:
|
||||
|
||||
@@ -59,15 +59,12 @@ class SaasSecretsStore(SecretsStore):
|
||||
|
||||
async with a_session_maker() as session:
|
||||
# Incoming secrets are always the most updated ones
|
||||
# Delete existing records for this user AND organization only
|
||||
delete_query = delete(StoredCustomSecrets).filter(
|
||||
StoredCustomSecrets.keycloak_user_id == self.user_id
|
||||
# Delete all existing records and override with incoming ones
|
||||
await session.execute(
|
||||
delete(StoredCustomSecrets).filter(
|
||||
StoredCustomSecrets.keycloak_user_id == self.user_id
|
||||
)
|
||||
)
|
||||
if org_id is not None:
|
||||
delete_query = delete_query.filter(StoredCustomSecrets.org_id == org_id)
|
||||
else:
|
||||
delete_query = delete_query.filter(StoredCustomSecrets.org_id.is_(None))
|
||||
await session.execute(delete_query)
|
||||
|
||||
# Prepare the new secrets data
|
||||
kwargs = item.model_dump(context={'expose_secrets': True})
|
||||
|
||||
@@ -115,9 +115,6 @@ class SaasSettingsStore(SettingsStore):
|
||||
kwargs['llm_api_key_for_byor'] = org_member.llm_api_key_for_byor
|
||||
if org_member.llm_base_url:
|
||||
kwargs['llm_base_url'] = org_member.llm_base_url
|
||||
# MCP config is user-specific (stored on org_member, not org)
|
||||
if org_member.mcp_config is not None:
|
||||
kwargs['mcp_config'] = org_member.mcp_config
|
||||
if org.v1_enabled is None:
|
||||
kwargs['v1_enabled'] = True
|
||||
# Apply default if sandbox_grouping_strategy is None in the database
|
||||
@@ -182,13 +179,7 @@ class SaasSettingsStore(SettingsStore):
|
||||
return None
|
||||
|
||||
# Check if we need to generate an LLM key.
|
||||
# Only generate/verify proxy keys when the base URL is explicitly the
|
||||
# LiteLLM proxy, or when it's unset and the model is an OpenHands model
|
||||
# (which always needs a proxy key). For non-OpenHands models with no
|
||||
# base URL (e.g. basic view BYOR), preserve the user's own API key.
|
||||
if item.llm_base_url == LITE_LLM_API_URL or (
|
||||
not item.llm_base_url and is_openhands_model(item.llm_model)
|
||||
):
|
||||
if item.llm_base_url == LITE_LLM_API_URL:
|
||||
await self._ensure_api_key(
|
||||
item, str(org_id), openhands_type=is_openhands_model(item.llm_model)
|
||||
)
|
||||
@@ -196,9 +187,6 @@ class SaasSettingsStore(SettingsStore):
|
||||
kwargs = item.model_dump(context={'expose_secrets': True})
|
||||
for model in (user, org, org_member):
|
||||
for key, value in kwargs.items():
|
||||
# Skip mcp_config for org - it should only be stored on org_member (user-specific)
|
||||
if key == 'mcp_config' and model is org:
|
||||
continue
|
||||
if hasattr(model, key):
|
||||
setattr(model, key, value)
|
||||
|
||||
|
||||
@@ -5,7 +5,6 @@ SQLAlchemy model for User.
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import (
|
||||
JSON,
|
||||
UUID,
|
||||
Boolean,
|
||||
Column,
|
||||
@@ -35,7 +34,6 @@ class User(Base): # type: ignore
|
||||
git_user_name = Column(String, nullable=True)
|
||||
git_user_email = Column(String, nullable=True)
|
||||
sandbox_grouping_strategy = Column(String, nullable=True)
|
||||
disabled_skills = Column(JSON, nullable=True)
|
||||
|
||||
# Relationships
|
||||
role = relationship('Role', back_populates='users')
|
||||
|
||||
@@ -31,7 +31,6 @@ class UserSettings(Base): # type: ignore
|
||||
user_version = Column(Integer, nullable=False, default=0)
|
||||
accepted_tos = Column(DateTime, nullable=True)
|
||||
mcp_config = Column(JSON, nullable=True)
|
||||
disabled_skills = Column(JSON, nullable=True)
|
||||
search_api_key = Column(String, nullable=True)
|
||||
sandbox_api_key = Column(String, nullable=True)
|
||||
max_budget_per_task = Column(Float, nullable=True)
|
||||
|
||||
@@ -214,15 +214,14 @@ class UserStore:
|
||||
decrypted_user_settings, user_settings.user_version
|
||||
)
|
||||
|
||||
# Migrate stripe customer (pass session to avoid FK violation)
|
||||
# avoids circular reference. This migrate method is temporary until all users are migrated.
|
||||
# avoids circular reference. This migrate method is temprorary until all users are migrated.
|
||||
from integrations.stripe_service import migrate_customer
|
||||
|
||||
logger.debug(
|
||||
'user_store:migrate_user:calling_stripe_migrate_customer',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
await migrate_customer(session, user_id, org)
|
||||
await migrate_customer(user_id, org)
|
||||
logger.debug(
|
||||
'user_store:migrate_user:done_stripe_migrate_customer',
|
||||
extra={'user_id': user_id},
|
||||
|
||||
@@ -13,6 +13,7 @@ Required environment variables:
|
||||
- RESEND_AUDIENCE_ID: ID of the Resend audience to add users to
|
||||
|
||||
Optional environment variables:
|
||||
- KEYCLOAK_PROVIDER_NAME: Provider name for Keycloak
|
||||
- KEYCLOAK_CLIENT_ID: Client ID for Keycloak
|
||||
- KEYCLOAK_CLIENT_SECRET: Client secret for Keycloak
|
||||
- RESEND_FROM_EMAIL: Email address to use as the sender (default: "OpenHands Team <no-reply@welcome.openhands.dev>")
|
||||
@@ -48,6 +49,7 @@ from openhands.core.logger import openhands_logger as logger
|
||||
# Get Keycloak configuration from environment variables
|
||||
KEYCLOAK_SERVER_URL = os.environ.get('KEYCLOAK_SERVER_URL', '')
|
||||
KEYCLOAK_REALM_NAME = os.environ.get('KEYCLOAK_REALM_NAME', '')
|
||||
KEYCLOAK_PROVIDER_NAME = os.environ.get('KEYCLOAK_PROVIDER_NAME', '')
|
||||
KEYCLOAK_CLIENT_ID = os.environ.get('KEYCLOAK_CLIENT_ID', '')
|
||||
KEYCLOAK_CLIENT_SECRET = os.environ.get('KEYCLOAK_CLIENT_SECRET', '')
|
||||
KEYCLOAK_ADMIN_PASSWORD = os.environ.get('KEYCLOAK_ADMIN_PASSWORD', '')
|
||||
|
||||
@@ -25,7 +25,6 @@ from storage.device_code import DeviceCode # noqa: F401
|
||||
from storage.feedback import Feedback
|
||||
from storage.github_app_installation import GithubAppInstallation
|
||||
from storage.org import Org
|
||||
from storage.org_git_claim import OrgGitClaim # noqa: F401
|
||||
from storage.org_invitation import OrgInvitation # noqa: F401
|
||||
from storage.org_member import OrgMember
|
||||
from storage.role import Role
|
||||
|
||||
@@ -88,7 +88,6 @@ class TestGithubViewV1InitialUserMessage:
|
||||
view.previous_comments = [MagicMock(author='alice', body='old comment 1')]
|
||||
|
||||
view._load_resolver_context = AsyncMock(side_effect=_load_context) # type: ignore[method-assign]
|
||||
view.resolved_org_id = None
|
||||
|
||||
fake_service = _FakeAppConversationService()
|
||||
mock_get_app_conversation_service.return_value = (
|
||||
@@ -145,7 +144,6 @@ class TestGithubViewV1InitialUserMessage:
|
||||
]
|
||||
|
||||
view._load_resolver_context = AsyncMock(side_effect=_load_context) # type: ignore[method-assign]
|
||||
view.resolved_org_id = None
|
||||
|
||||
fake_service = _FakeAppConversationService()
|
||||
mock_get_app_conversation_service.return_value = (
|
||||
@@ -202,7 +200,6 @@ class TestGithubViewV1InitialUserMessage:
|
||||
view.previous_comments = []
|
||||
|
||||
view._load_resolver_context = AsyncMock(side_effect=_load_context) # type: ignore[method-assign]
|
||||
view.resolved_org_id = None
|
||||
|
||||
fake_service = _FakeAppConversationService()
|
||||
mock_get_service.return_value = _fake_app_conversation_service_ctx(fake_service)
|
||||
|
||||
@@ -32,28 +32,6 @@ def resolver_context(mock_saas_user_auth):
|
||||
return ResolverUserContext(saas_user_auth=mock_saas_user_auth)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests for resolver_org_id - org routing for resolver conversations
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_resolver_org_id_defaults_to_none(mock_saas_user_auth):
|
||||
"""Test that resolver_org_id defaults to None when not provided."""
|
||||
ctx = ResolverUserContext(saas_user_auth=mock_saas_user_auth)
|
||||
assert ctx.resolver_org_id is None
|
||||
|
||||
|
||||
def test_resolver_org_id_can_be_set_via_constructor(mock_saas_user_auth):
|
||||
"""Test that resolver_org_id can be set via constructor for org routing."""
|
||||
from uuid import UUID
|
||||
|
||||
org_id = UUID('aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa')
|
||||
ctx = ResolverUserContext(
|
||||
saas_user_auth=mock_saas_user_auth, resolver_org_id=org_id
|
||||
)
|
||||
assert ctx.resolver_org_id == org_id
|
||||
|
||||
|
||||
def create_custom_secret(value: str, description: str = 'Test secret') -> CustomSecret:
|
||||
"""Helper to create CustomSecret instances."""
|
||||
return CustomSecret(secret=SecretStr(value), description=description)
|
||||
|
||||
@@ -1,151 +0,0 @@
|
||||
"""Tests for resolver org routing logic.
|
||||
|
||||
Tests the resolve_org_for_repo function which determines which OpenHands
|
||||
organization workspace a resolver conversation should be created in.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
|
||||
CLAIMING_ORG_ID = UUID('aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa')
|
||||
USER_ID = 'bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb'
|
||||
|
||||
# Patch at module level where the names are looked up
|
||||
_CLAIM_STORE = 'enterprise.integrations.resolver_org_router.OrgGitClaimStore'
|
||||
_MEMBER_STORE = 'enterprise.integrations.resolver_org_router.OrgMemberStore'
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_stores():
|
||||
"""Mock OrgGitClaimStore and OrgMemberStore for all tests."""
|
||||
with (
|
||||
patch(_CLAIM_STORE) as mock_claim_store,
|
||||
patch(_MEMBER_STORE) as mock_member_store,
|
||||
):
|
||||
mock_claim_store.get_claim_by_provider_and_git_org = AsyncMock(
|
||||
return_value=None
|
||||
)
|
||||
mock_member_store.get_org_member = AsyncMock(return_value=None)
|
||||
yield mock_claim_store, mock_member_store
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_org_id_when_claimed_and_user_is_member(mock_stores):
|
||||
"""When the git org is claimed and the user is a member, return the claiming org's ID."""
|
||||
from enterprise.integrations.resolver_org_router import resolve_org_for_repo
|
||||
|
||||
mock_claim_store, mock_member_store = mock_stores
|
||||
|
||||
# Arrange
|
||||
claim = MagicMock()
|
||||
claim.org_id = CLAIMING_ORG_ID
|
||||
mock_claim_store.get_claim_by_provider_and_git_org.return_value = claim
|
||||
mock_member_store.get_org_member.return_value = MagicMock() # member exists
|
||||
|
||||
# Act
|
||||
result = await resolve_org_for_repo('github', 'OpenHands/foo', USER_ID)
|
||||
|
||||
# Assert
|
||||
assert result == CLAIMING_ORG_ID
|
||||
mock_claim_store.get_claim_by_provider_and_git_org.assert_called_once_with(
|
||||
'github', 'openhands'
|
||||
)
|
||||
mock_member_store.get_org_member.assert_called_once_with(
|
||||
CLAIMING_ORG_ID, UUID(USER_ID)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_none_when_claimed_but_user_not_member(mock_stores):
|
||||
"""When the git org is claimed but user is not a member, return None."""
|
||||
from enterprise.integrations.resolver_org_router import resolve_org_for_repo
|
||||
|
||||
mock_claim_store, mock_member_store = mock_stores
|
||||
|
||||
# Arrange
|
||||
claim = MagicMock()
|
||||
claim.org_id = CLAIMING_ORG_ID
|
||||
mock_claim_store.get_claim_by_provider_and_git_org.return_value = claim
|
||||
mock_member_store.get_org_member.return_value = None
|
||||
|
||||
# Act
|
||||
result = await resolve_org_for_repo('github', 'OpenHands/foo', USER_ID)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_none_when_no_claim_exists(mock_stores):
|
||||
"""When no org has claimed the git organization, return None."""
|
||||
from enterprise.integrations.resolver_org_router import resolve_org_for_repo
|
||||
|
||||
mock_claim_store, _ = mock_stores
|
||||
mock_claim_store.get_claim_by_provider_and_git_org.return_value = None
|
||||
|
||||
# Act
|
||||
result = await resolve_org_for_repo('github', 'UnclaimedOrg/repo', USER_ID)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
mock_claim_store.get_claim_by_provider_and_git_org.assert_called_once_with(
|
||||
'github', 'unclaimedorg'
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extracts_git_org_lowercase_from_repo_name(mock_stores):
|
||||
"""The git org is extracted from repo name and lowercased for claim lookup."""
|
||||
from enterprise.integrations.resolver_org_router import resolve_org_for_repo
|
||||
|
||||
mock_claim_store, _ = mock_stores
|
||||
|
||||
# Act
|
||||
await resolve_org_for_repo('github', 'MyOrg/some-repo', USER_ID)
|
||||
|
||||
# Assert
|
||||
mock_claim_store.get_claim_by_provider_and_git_org.assert_called_once_with(
|
||||
'github', 'myorg'
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_org_id_without_membership_check_when_no_user_id(mock_stores):
|
||||
"""When user_id is None, skip membership check and return org_id if claim exists."""
|
||||
from enterprise.integrations.resolver_org_router import resolve_org_for_repo
|
||||
|
||||
mock_claim_store, mock_member_store = mock_stores
|
||||
|
||||
# Arrange
|
||||
claim = MagicMock()
|
||||
claim.org_id = CLAIMING_ORG_ID
|
||||
mock_claim_store.get_claim_by_provider_and_git_org.return_value = claim
|
||||
|
||||
# Act - no user_id provided
|
||||
result = await resolve_org_for_repo('github', 'OpenHands/foo')
|
||||
|
||||
# Assert
|
||||
assert result == CLAIMING_ORG_ID
|
||||
mock_claim_store.get_claim_by_provider_and_git_org.assert_called_once_with(
|
||||
'github', 'openhands'
|
||||
)
|
||||
# Membership check should NOT be called
|
||||
mock_member_store.get_org_member.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_none_when_no_claim_and_no_user_id(mock_stores):
|
||||
"""When no claim exists and no user_id, return None."""
|
||||
from enterprise.integrations.resolver_org_router import resolve_org_for_repo
|
||||
|
||||
mock_claim_store, mock_member_store = mock_stores
|
||||
mock_claim_store.get_claim_by_provider_and_git_org.return_value = None
|
||||
|
||||
# Act - no user_id provided
|
||||
result = await resolve_org_for_repo('github', 'UnclaimedOrg/repo')
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
mock_member_store.get_org_member.assert_not_called()
|
||||
@@ -1,325 +0,0 @@
|
||||
"""Unit tests for service API routes."""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
from server.routes.service import (
|
||||
CreateUserApiKeyRequest,
|
||||
delete_user_api_key,
|
||||
get_or_create_api_key_for_user,
|
||||
validate_service_api_key,
|
||||
)
|
||||
|
||||
|
||||
class TestValidateServiceApiKey:
|
||||
"""Test cases for validate_service_api_key."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_valid_service_key(self):
|
||||
"""Test validation with valid service API key."""
|
||||
with patch('server.routes.service.AUTOMATIONS_SERVICE_KEY', 'test-service-key'):
|
||||
result = await validate_service_api_key('test-service-key')
|
||||
assert result == 'automations-service'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_service_key(self):
|
||||
"""Test validation with missing service API key header."""
|
||||
with patch('server.routes.service.AUTOMATIONS_SERVICE_KEY', 'test-service-key'):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await validate_service_api_key(None)
|
||||
assert exc_info.value.status_code == 401
|
||||
assert 'X-Service-API-Key header is required' in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_service_key(self):
|
||||
"""Test validation with invalid service API key."""
|
||||
with patch('server.routes.service.AUTOMATIONS_SERVICE_KEY', 'test-service-key'):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await validate_service_api_key('wrong-key')
|
||||
assert exc_info.value.status_code == 401
|
||||
assert 'Invalid service API key' in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_service_auth_not_configured(self):
|
||||
"""Test validation when service auth is not configured."""
|
||||
with patch('server.routes.service.AUTOMATIONS_SERVICE_KEY', ''):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await validate_service_api_key('any-key')
|
||||
assert exc_info.value.status_code == 503
|
||||
assert 'Service authentication not configured' in exc_info.value.detail
|
||||
|
||||
|
||||
class TestCreateUserApiKeyRequest:
|
||||
"""Test cases for CreateUserApiKeyRequest validation."""
|
||||
|
||||
def test_valid_request(self):
|
||||
"""Test valid request with all fields."""
|
||||
request = CreateUserApiKeyRequest(
|
||||
name='automation',
|
||||
)
|
||||
assert request.name == 'automation'
|
||||
|
||||
def test_name_is_required(self):
|
||||
"""Test that name field is required."""
|
||||
with pytest.raises(ValueError):
|
||||
CreateUserApiKeyRequest(
|
||||
name='', # Empty name should fail
|
||||
)
|
||||
|
||||
def test_name_is_stripped(self):
|
||||
"""Test that name field is stripped of whitespace."""
|
||||
request = CreateUserApiKeyRequest(
|
||||
name=' automation ',
|
||||
)
|
||||
assert request.name == 'automation'
|
||||
|
||||
def test_whitespace_only_name_fails(self):
|
||||
"""Test that whitespace-only name fails validation."""
|
||||
with pytest.raises(ValueError):
|
||||
CreateUserApiKeyRequest(
|
||||
name=' ',
|
||||
)
|
||||
|
||||
|
||||
class TestGetOrCreateApiKeyForUser:
|
||||
"""Test cases for get_or_create_api_key_for_user endpoint."""
|
||||
|
||||
@pytest.fixture
|
||||
def valid_user_id(self):
|
||||
"""Return a valid user ID."""
|
||||
return '5594c7b6-f959-4b81-92e9-b09c206f5081'
|
||||
|
||||
@pytest.fixture
|
||||
def valid_org_id(self):
|
||||
"""Return a valid org ID."""
|
||||
return uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081')
|
||||
|
||||
@pytest.fixture
|
||||
def valid_request(self):
|
||||
"""Create a valid request object."""
|
||||
return CreateUserApiKeyRequest(
|
||||
name='automation',
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_not_found(self, valid_user_id, valid_org_id, valid_request):
|
||||
"""Test error when user doesn't exist."""
|
||||
with patch('server.routes.service.AUTOMATIONS_SERVICE_KEY', 'test-key'):
|
||||
with patch(
|
||||
'server.routes.service.UserStore.get_user_by_id', new_callable=AsyncMock
|
||||
) as mock_get_user:
|
||||
mock_get_user.return_value = None
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_or_create_api_key_for_user(
|
||||
user_id=valid_user_id,
|
||||
org_id=valid_org_id,
|
||||
request=valid_request,
|
||||
x_service_api_key='test-key',
|
||||
)
|
||||
assert exc_info.value.status_code == 404
|
||||
assert 'not found' in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_not_in_org(self, valid_user_id, valid_org_id, valid_request):
|
||||
"""Test error when user is not a member of the org."""
|
||||
mock_user = MagicMock()
|
||||
|
||||
with patch('server.routes.service.AUTOMATIONS_SERVICE_KEY', 'test-key'):
|
||||
with patch(
|
||||
'server.routes.service.UserStore.get_user_by_id', new_callable=AsyncMock
|
||||
) as mock_get_user:
|
||||
with patch(
|
||||
'server.routes.service.OrgMemberStore.get_org_member',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_member:
|
||||
mock_get_user.return_value = mock_user
|
||||
mock_get_member.return_value = None
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_or_create_api_key_for_user(
|
||||
user_id=valid_user_id,
|
||||
org_id=valid_org_id,
|
||||
request=valid_request,
|
||||
x_service_api_key='test-key',
|
||||
)
|
||||
assert exc_info.value.status_code == 403
|
||||
assert 'not a member of org' in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_successful_key_creation(
|
||||
self, valid_user_id, valid_org_id, valid_request
|
||||
):
|
||||
"""Test successful API key creation."""
|
||||
mock_user = MagicMock()
|
||||
mock_org_member = MagicMock()
|
||||
mock_api_key_store = MagicMock()
|
||||
mock_api_key_store.get_or_create_system_api_key = AsyncMock(
|
||||
return_value='sk-oh-test-key-12345678901234567890'
|
||||
)
|
||||
|
||||
with patch('server.routes.service.AUTOMATIONS_SERVICE_KEY', 'test-key'):
|
||||
with patch(
|
||||
'server.routes.service.UserStore.get_user_by_id', new_callable=AsyncMock
|
||||
) as mock_get_user:
|
||||
with patch(
|
||||
'server.routes.service.OrgMemberStore.get_org_member',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_member:
|
||||
with patch(
|
||||
'server.routes.service.ApiKeyStore.get_instance'
|
||||
) as mock_get_store:
|
||||
mock_get_user.return_value = mock_user
|
||||
mock_get_member.return_value = mock_org_member
|
||||
mock_get_store.return_value = mock_api_key_store
|
||||
|
||||
response = await get_or_create_api_key_for_user(
|
||||
user_id=valid_user_id,
|
||||
org_id=valid_org_id,
|
||||
request=valid_request,
|
||||
x_service_api_key='test-key',
|
||||
)
|
||||
|
||||
assert response.key == 'sk-oh-test-key-12345678901234567890'
|
||||
assert response.user_id == valid_user_id
|
||||
assert response.org_id == str(valid_org_id)
|
||||
assert response.name == 'automation'
|
||||
|
||||
# Verify the store was called with correct arguments
|
||||
mock_api_key_store.get_or_create_system_api_key.assert_called_once_with(
|
||||
user_id=valid_user_id,
|
||||
org_id=valid_org_id,
|
||||
name='automation',
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_exception_handling(
|
||||
self, valid_user_id, valid_org_id, valid_request
|
||||
):
|
||||
"""Test error handling when store raises exception."""
|
||||
mock_user = MagicMock()
|
||||
mock_org_member = MagicMock()
|
||||
mock_api_key_store = MagicMock()
|
||||
mock_api_key_store.get_or_create_system_api_key = AsyncMock(
|
||||
side_effect=Exception('Database error')
|
||||
)
|
||||
|
||||
with patch('server.routes.service.AUTOMATIONS_SERVICE_KEY', 'test-key'):
|
||||
with patch(
|
||||
'server.routes.service.UserStore.get_user_by_id', new_callable=AsyncMock
|
||||
) as mock_get_user:
|
||||
with patch(
|
||||
'server.routes.service.OrgMemberStore.get_org_member',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_member:
|
||||
with patch(
|
||||
'server.routes.service.ApiKeyStore.get_instance'
|
||||
) as mock_get_store:
|
||||
mock_get_user.return_value = mock_user
|
||||
mock_get_member.return_value = mock_org_member
|
||||
mock_get_store.return_value = mock_api_key_store
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_or_create_api_key_for_user(
|
||||
user_id=valid_user_id,
|
||||
org_id=valid_org_id,
|
||||
request=valid_request,
|
||||
x_service_api_key='test-key',
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 500
|
||||
assert 'Failed to get or create API key' in exc_info.value.detail
|
||||
|
||||
|
||||
class TestDeleteUserApiKey:
|
||||
"""Test cases for delete_user_api_key endpoint."""
|
||||
|
||||
@pytest.fixture
|
||||
def valid_org_id(self):
|
||||
"""Return a valid org ID."""
|
||||
return uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_successful_delete(self, valid_org_id):
|
||||
"""Test successful deletion of a system API key."""
|
||||
mock_api_key_store = MagicMock()
|
||||
mock_api_key_store.make_system_key_name.return_value = '__SYSTEM__:automation'
|
||||
mock_api_key_store.delete_api_key_by_name = AsyncMock(return_value=True)
|
||||
|
||||
with patch('server.routes.service.AUTOMATIONS_SERVICE_KEY', 'test-key'):
|
||||
with patch(
|
||||
'server.routes.service.ApiKeyStore.get_instance'
|
||||
) as mock_get_store:
|
||||
mock_get_store.return_value = mock_api_key_store
|
||||
|
||||
response = await delete_user_api_key(
|
||||
user_id='user-123',
|
||||
org_id=valid_org_id,
|
||||
key_name='automation',
|
||||
x_service_api_key='test-key',
|
||||
)
|
||||
|
||||
assert response == {'message': 'API key deleted successfully'}
|
||||
|
||||
# Verify the store was called with correct arguments
|
||||
mock_api_key_store.make_system_key_name.assert_called_once_with('automation')
|
||||
mock_api_key_store.delete_api_key_by_name.assert_called_once_with(
|
||||
user_id='user-123',
|
||||
org_id=valid_org_id,
|
||||
name='__SYSTEM__:automation',
|
||||
allow_system=True,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_key_not_found(self, valid_org_id):
|
||||
"""Test error when key to delete is not found."""
|
||||
mock_api_key_store = MagicMock()
|
||||
mock_api_key_store.make_system_key_name.return_value = '__SYSTEM__:nonexistent'
|
||||
mock_api_key_store.delete_api_key_by_name = AsyncMock(return_value=False)
|
||||
|
||||
with patch('server.routes.service.AUTOMATIONS_SERVICE_KEY', 'test-key'):
|
||||
with patch(
|
||||
'server.routes.service.ApiKeyStore.get_instance'
|
||||
) as mock_get_store:
|
||||
mock_get_store.return_value = mock_api_key_store
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await delete_user_api_key(
|
||||
user_id='user-123',
|
||||
org_id=valid_org_id,
|
||||
key_name='nonexistent',
|
||||
x_service_api_key='test-key',
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 404
|
||||
assert 'not found' in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_invalid_service_key(self, valid_org_id):
|
||||
"""Test error when service API key is invalid."""
|
||||
with patch('server.routes.service.AUTOMATIONS_SERVICE_KEY', 'test-key'):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await delete_user_api_key(
|
||||
user_id='user-123',
|
||||
org_id=valid_org_id,
|
||||
key_name='automation',
|
||||
x_service_api_key='wrong-key',
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert 'Invalid service API key' in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_missing_service_key(self, valid_org_id):
|
||||
"""Test error when service API key header is missing."""
|
||||
with patch('server.routes.service.AUTOMATIONS_SERVICE_KEY', 'test-key'):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await delete_user_api_key(
|
||||
user_id='user-123',
|
||||
org_id=valid_org_id,
|
||||
key_name='automation',
|
||||
x_service_api_key=None,
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert 'X-Service-API-Key header is required' in exc_info.value.detail
|
||||
@@ -1,26 +1,19 @@
|
||||
"""Unit tests for API keys routes, focusing on BYOR key validation and retrieval."""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
from pydantic import SecretStr
|
||||
from server.auth.saas_user_auth import SaasUserAuth
|
||||
from server.routes.api_keys import (
|
||||
ByorPermittedResponse,
|
||||
CurrentApiKeyResponse,
|
||||
LlmApiKeyResponse,
|
||||
check_byor_permitted,
|
||||
delete_byor_key_from_litellm,
|
||||
get_current_api_key,
|
||||
get_llm_api_key_for_byor,
|
||||
)
|
||||
from storage.lite_llm_manager import LiteLlmManager
|
||||
|
||||
from openhands.server.user_auth.user_auth import AuthType
|
||||
|
||||
|
||||
class TestVerifyByorKeyInLitellm:
|
||||
"""Test the verify_byor_key_in_litellm function."""
|
||||
@@ -519,81 +512,3 @@ class TestCheckByorPermitted:
|
||||
|
||||
assert exc_info.value.status_code == 500
|
||||
assert 'Failed to check BYOR export permission' in exc_info.value.detail
|
||||
|
||||
|
||||
class TestGetCurrentApiKey:
|
||||
"""Test the get_current_api_key endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('server.routes.api_keys.get_user_auth')
|
||||
async def test_returns_api_key_info_for_bearer_auth(self, mock_get_user_auth):
|
||||
"""Test that API key metadata including org_id is returned for bearer token auth."""
|
||||
# Arrange
|
||||
user_id = 'user-123'
|
||||
org_id = uuid.uuid4()
|
||||
mock_request = MagicMock()
|
||||
|
||||
user_auth = SaasUserAuth(
|
||||
refresh_token=SecretStr('mock-token'),
|
||||
user_id=user_id,
|
||||
auth_type=AuthType.BEARER,
|
||||
api_key_org_id=org_id,
|
||||
api_key_id=42,
|
||||
api_key_name='My Production Key',
|
||||
)
|
||||
mock_get_user_auth.return_value = user_auth
|
||||
|
||||
# Act
|
||||
result = await get_current_api_key(request=mock_request, user_id=user_id)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, CurrentApiKeyResponse)
|
||||
assert result.org_id == str(org_id)
|
||||
assert result.id == 42
|
||||
assert result.name == 'My Production Key'
|
||||
assert result.user_id == user_id
|
||||
assert result.auth_type == 'bearer'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('server.routes.api_keys.get_user_auth')
|
||||
async def test_returns_400_for_cookie_auth(self, mock_get_user_auth):
|
||||
"""Test that 400 Bad Request is returned when using cookie authentication."""
|
||||
# Arrange
|
||||
user_id = 'user-123'
|
||||
mock_request = MagicMock()
|
||||
|
||||
mock_user_auth = MagicMock()
|
||||
mock_user_auth.get_auth_type.return_value = AuthType.COOKIE
|
||||
mock_get_user_auth.return_value = mock_user_auth
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_current_api_key(request=mock_request, user_id=user_id)
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
assert 'API key authentication' in exc_info.value.detail
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch('server.routes.api_keys.get_user_auth')
|
||||
async def test_returns_400_when_api_key_org_id_is_none(self, mock_get_user_auth):
|
||||
"""Test that 400 is returned when API key has no org_id (legacy key)."""
|
||||
# Arrange
|
||||
user_id = 'user-123'
|
||||
mock_request = MagicMock()
|
||||
|
||||
user_auth = SaasUserAuth(
|
||||
refresh_token=SecretStr('mock-token'),
|
||||
user_id=user_id,
|
||||
auth_type=AuthType.BEARER,
|
||||
api_key_org_id=None, # No org_id - legacy key
|
||||
api_key_id=42,
|
||||
api_key_name='Legacy Key',
|
||||
)
|
||||
mock_get_user_auth.return_value = user_auth
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await get_current_api_key(request=mock_request, user_id=user_id)
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
assert 'created before organization support' in exc_info.value.detail
|
||||
|
||||
@@ -1,603 +0,0 @@
|
||||
"""Tests for Git organization claim API endpoints.
|
||||
|
||||
Tests the following endpoints:
|
||||
- GET /api/organizations/{org_id}/git-claims (list claims)
|
||||
- POST /api/organizations/{org_id}/git-claims (claim)
|
||||
- DELETE /api/organizations/{org_id}/git-claims/{claim_id} (disconnect)
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI, status
|
||||
from fastapi.testclient import TestClient
|
||||
from server.routes.orgs import (
|
||||
claim_git_organization,
|
||||
disconnect_git_organization,
|
||||
get_git_claims,
|
||||
org_router,
|
||||
)
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from storage.org_git_claim import OrgGitClaim
|
||||
|
||||
from openhands.server.user_auth import get_user_id
|
||||
|
||||
TEST_USER_ID = str(uuid.uuid4())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def org_id():
|
||||
return uuid.uuid4()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def user_id():
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def make_claim():
|
||||
"""Factory to create mock OrgGitClaim objects."""
|
||||
|
||||
def _make(org_id, provider='github', git_organization='OpenHands', claimed_by=None):
|
||||
claim = MagicMock(spec=OrgGitClaim)
|
||||
claim.id = uuid.uuid4()
|
||||
claim.org_id = org_id
|
||||
claim.provider = provider
|
||||
claim.git_organization = git_organization
|
||||
claim.claimed_by = claimed_by or uuid.uuid4()
|
||||
claim.claimed_at = datetime(2026, 4, 1, 12, 0, 0)
|
||||
return claim
|
||||
|
||||
return _make
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# GET /api/organizations/{org_id}/git-claims
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestGetGitClaims:
|
||||
"""Tests for the get Git organization claims endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_empty_list_when_no_claims(self, org_id, user_id):
|
||||
"""
|
||||
GIVEN: An organization with no Git claims
|
||||
WHEN: GET /api/organizations/{org_id}/git-claims is called
|
||||
THEN: An empty list is returned
|
||||
"""
|
||||
with patch(
|
||||
'server.routes.orgs.OrgGitClaimStore.get_claims_by_org_id',
|
||||
AsyncMock(return_value=[]),
|
||||
) as mock_get:
|
||||
result = await get_git_claims(org_id=org_id, user_id=user_id)
|
||||
|
||||
assert result == []
|
||||
mock_get.assert_called_once_with(org_id=org_id)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_claims_for_organization(self, org_id, user_id, make_claim):
|
||||
"""
|
||||
GIVEN: An organization with multiple Git claims
|
||||
WHEN: GET /api/organizations/{org_id}/git-claims is called
|
||||
THEN: All claims are returned with correct details
|
||||
"""
|
||||
claim1 = make_claim(org_id, provider='github', git_organization='OpenHands')
|
||||
claim2 = make_claim(org_id, provider='gitlab', git_organization='AcmeCo')
|
||||
|
||||
with patch(
|
||||
'server.routes.orgs.OrgGitClaimStore.get_claims_by_org_id',
|
||||
AsyncMock(return_value=[claim1, claim2]),
|
||||
):
|
||||
result = await get_git_claims(org_id=org_id, user_id=user_id)
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0].id == str(claim1.id)
|
||||
assert result[0].org_id == str(org_id)
|
||||
assert result[0].provider == 'github'
|
||||
assert result[0].git_organization == 'OpenHands'
|
||||
assert result[0].claimed_by == str(claim1.claimed_by)
|
||||
assert result[0].claimed_at == '2026-04-01T12:00:00'
|
||||
assert result[1].id == str(claim2.id)
|
||||
assert result[1].provider == 'gitlab'
|
||||
assert result[1].git_organization == 'AcmeCo'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_500_on_unexpected_error(self, org_id, user_id):
|
||||
"""
|
||||
GIVEN: An unexpected error occurs when fetching claims
|
||||
WHEN: GET /api/organizations/{org_id}/git-claims is called
|
||||
THEN: A 500 Internal Server Error is returned
|
||||
"""
|
||||
with patch(
|
||||
'server.routes.orgs.OrgGitClaimStore.get_claims_by_org_id',
|
||||
AsyncMock(side_effect=RuntimeError('db connection failed')),
|
||||
):
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
await get_git_claims(org_id=org_id, user_id=user_id)
|
||||
|
||||
assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# POST /api/organizations/{org_id}/git-claims
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestClaimGitOrganization:
|
||||
"""Tests for the claim Git organization endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_claim_succeeds_for_unclaimed_org(self, org_id, user_id, make_claim):
|
||||
"""
|
||||
GIVEN: A Git organization that has not been claimed
|
||||
WHEN: POST /api/organizations/{org_id}/git-claims is called
|
||||
THEN: The claim is created and returned with correct details
|
||||
"""
|
||||
# Arrange
|
||||
mock_claim = make_claim(org_id, claimed_by=uuid.UUID(user_id))
|
||||
request = MagicMock()
|
||||
request.provider = 'github'
|
||||
request.git_organization = 'OpenHands'
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.routes.orgs.OrgGitClaimStore.get_claim_by_provider_and_git_org',
|
||||
AsyncMock(return_value=None),
|
||||
),
|
||||
patch(
|
||||
'server.routes.orgs.OrgGitClaimStore.create_claim',
|
||||
AsyncMock(return_value=mock_claim),
|
||||
) as mock_create,
|
||||
):
|
||||
# Act
|
||||
response = await claim_git_organization(
|
||||
org_id=org_id, request=request, user_id=user_id
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert response.id == str(mock_claim.id)
|
||||
assert response.org_id == str(org_id)
|
||||
assert response.provider == 'github'
|
||||
assert response.git_organization == 'OpenHands'
|
||||
assert response.claimed_by == user_id
|
||||
mock_create.assert_called_once_with(
|
||||
org_id=org_id,
|
||||
provider='github',
|
||||
git_organization='OpenHands',
|
||||
claimed_by=uuid.UUID(user_id),
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_claim_fails_when_already_claimed(self, org_id, user_id, make_claim):
|
||||
"""
|
||||
GIVEN: A Git organization already claimed by another OpenHands org
|
||||
WHEN: POST /api/organizations/{org_id}/git-claims is called
|
||||
THEN: A 409 Conflict error is returned
|
||||
"""
|
||||
# Arrange
|
||||
other_org_id = uuid.uuid4()
|
||||
existing_claim = make_claim(
|
||||
other_org_id, provider='github', git_organization='AlreadyClaimed'
|
||||
)
|
||||
request = MagicMock()
|
||||
request.provider = 'github'
|
||||
request.git_organization = 'AlreadyClaimed'
|
||||
|
||||
with patch(
|
||||
'server.routes.orgs.OrgGitClaimStore.get_claim_by_provider_and_git_org',
|
||||
AsyncMock(return_value=existing_claim),
|
||||
):
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
await claim_git_organization(
|
||||
org_id=org_id, request=request, user_id=user_id
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == status.HTTP_409_CONFLICT
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_claim_returns_500_on_unexpected_error(self, org_id, user_id):
|
||||
"""
|
||||
GIVEN: An unexpected error occurs during claim creation
|
||||
WHEN: POST /api/organizations/{org_id}/git-claims is called
|
||||
THEN: A 500 Internal Server Error is returned
|
||||
"""
|
||||
# Arrange
|
||||
request = MagicMock()
|
||||
request.provider = 'github'
|
||||
request.git_organization = 'OpenHands'
|
||||
|
||||
with patch(
|
||||
'server.routes.orgs.OrgGitClaimStore.get_claim_by_provider_and_git_org',
|
||||
AsyncMock(side_effect=RuntimeError('db connection failed')),
|
||||
):
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
await claim_git_organization(
|
||||
org_id=org_id, request=request, user_id=user_id
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_claim_race_condition_returns_409(self, org_id, user_id):
|
||||
"""
|
||||
GIVEN: Pre-check passes but a concurrent request claims the org first
|
||||
WHEN: create_claim raises IntegrityError (DB unique constraint)
|
||||
THEN: A 409 Conflict error is returned instead of 500
|
||||
"""
|
||||
# Arrange
|
||||
request = MagicMock()
|
||||
request.provider = 'github'
|
||||
request.git_organization = 'RaceOrg'
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.routes.orgs.OrgGitClaimStore.get_claim_by_provider_and_git_org',
|
||||
AsyncMock(return_value=None),
|
||||
),
|
||||
patch(
|
||||
'server.routes.orgs.OrgGitClaimStore.create_claim',
|
||||
AsyncMock(
|
||||
side_effect=IntegrityError(
|
||||
'duplicate',
|
||||
'',
|
||||
Exception('uq_provider_git_org'),
|
||||
)
|
||||
),
|
||||
),
|
||||
):
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
await claim_git_organization(
|
||||
org_id=org_id, request=request, user_id=user_id
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == status.HTTP_409_CONFLICT
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# DELETE /api/organizations/{org_id}/git-claims/{claim_id}
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestDisconnectGitOrganization:
|
||||
"""Tests for the disconnect Git organization endpoint."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disconnect_succeeds_for_existing_claim(self, org_id, user_id):
|
||||
"""
|
||||
GIVEN: A valid claim belonging to the organization
|
||||
WHEN: DELETE /api/organizations/{org_id}/git-claims/{claim_id} is called
|
||||
THEN: The claim is deleted and a success message is returned
|
||||
"""
|
||||
# Arrange
|
||||
claim_id = uuid.uuid4()
|
||||
|
||||
with patch(
|
||||
'server.routes.orgs.OrgGitClaimStore.delete_claim',
|
||||
AsyncMock(return_value=True),
|
||||
) as mock_delete:
|
||||
# Act
|
||||
result = await disconnect_git_organization(
|
||||
org_id=org_id, claim_id=claim_id, user_id=user_id
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == {'message': 'Git organization claim removed successfully'}
|
||||
mock_delete.assert_called_once_with(claim_id=claim_id, org_id=org_id)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disconnect_fails_when_claim_not_found(self, org_id, user_id):
|
||||
"""
|
||||
GIVEN: A claim_id that does not exist for this organization
|
||||
WHEN: DELETE /api/organizations/{org_id}/git-claims/{claim_id} is called
|
||||
THEN: A 404 Not Found error is returned
|
||||
"""
|
||||
# Arrange
|
||||
claim_id = uuid.uuid4()
|
||||
|
||||
with patch(
|
||||
'server.routes.orgs.OrgGitClaimStore.delete_claim',
|
||||
AsyncMock(return_value=False),
|
||||
):
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
await disconnect_git_organization(
|
||||
org_id=org_id, claim_id=claim_id, user_id=user_id
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == status.HTTP_404_NOT_FOUND
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disconnect_returns_500_on_unexpected_error(self, org_id, user_id):
|
||||
"""
|
||||
GIVEN: An unexpected error occurs during claim deletion
|
||||
WHEN: DELETE /api/organizations/{org_id}/git-claims/{claim_id} is called
|
||||
THEN: A 500 Internal Server Error is returned
|
||||
"""
|
||||
# Arrange
|
||||
claim_id = uuid.uuid4()
|
||||
|
||||
with patch(
|
||||
'server.routes.orgs.OrgGitClaimStore.delete_claim',
|
||||
AsyncMock(side_effect=RuntimeError('db connection failed')),
|
||||
):
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
await disconnect_git_organization(
|
||||
org_id=org_id, claim_id=claim_id, user_id=user_id
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Validation tests for GitOrgClaimRequest
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestGitOrgClaimRequestValidation:
|
||||
"""Tests for request model validation."""
|
||||
|
||||
def test_valid_providers_are_accepted(self):
|
||||
"""Each supported provider is accepted and normalized to lowercase."""
|
||||
from server.routes.org_models import GitOrgClaimRequest
|
||||
|
||||
for provider in ['github', 'GitLab', 'BITBUCKET']:
|
||||
req = GitOrgClaimRequest(provider=provider, git_organization='test-org')
|
||||
assert req.provider == provider.lower().strip()
|
||||
|
||||
def test_invalid_provider_is_rejected(self):
|
||||
"""An unsupported provider raises a validation error."""
|
||||
from pydantic import ValidationError
|
||||
from server.routes.org_models import GitOrgClaimRequest
|
||||
|
||||
with pytest.raises(ValidationError, match='Invalid provider'):
|
||||
GitOrgClaimRequest(provider='azure_devops', git_organization='test-org')
|
||||
|
||||
def test_empty_git_organization_is_rejected(self):
|
||||
"""An empty git_organization raises a validation error."""
|
||||
from pydantic import ValidationError
|
||||
from server.routes.org_models import GitOrgClaimRequest
|
||||
|
||||
with pytest.raises(ValidationError, match='git_organization must not be empty'):
|
||||
GitOrgClaimRequest(provider='github', git_organization=' ')
|
||||
|
||||
def test_git_organization_is_normalized_to_lowercase(self):
|
||||
"""git_organization is lowercased to prevent case-sensitive duplicates."""
|
||||
from server.routes.org_models import GitOrgClaimRequest
|
||||
|
||||
req = GitOrgClaimRequest(provider='github', git_organization='OpenHands')
|
||||
assert req.git_organization == 'openhands'
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Integration tests — TestClient with real HTTP, auth, and Pydantic validation
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app():
|
||||
"""FastAPI app with org routes and mocked user authentication."""
|
||||
app = FastAPI()
|
||||
app.include_router(org_router)
|
||||
|
||||
app.dependency_overrides[get_user_id] = lambda: TEST_USER_ID
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_owner_role():
|
||||
role = MagicMock()
|
||||
role.name = 'owner'
|
||||
return role
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_member_role():
|
||||
role = MagicMock()
|
||||
role.name = 'member'
|
||||
return role
|
||||
|
||||
|
||||
class TestGitClaimsAuthorization:
|
||||
"""Integration tests verifying authorization through the real HTTP cycle."""
|
||||
|
||||
def test_non_member_gets_403_on_get(self, mock_app):
|
||||
"""
|
||||
GIVEN: A user who is not a member of the target organization
|
||||
WHEN: GET /api/organizations/{org_id}/git-claims via HTTP
|
||||
THEN: 403 is returned by require_permission
|
||||
"""
|
||||
org_id = uuid.uuid4()
|
||||
|
||||
with patch(
|
||||
'server.auth.authorization.get_user_org_role',
|
||||
AsyncMock(return_value=None),
|
||||
):
|
||||
client = TestClient(mock_app)
|
||||
response = client.get(f'/api/organizations/{org_id}/git-claims')
|
||||
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
assert 'not a member' in response.json()['detail']
|
||||
|
||||
def test_member_without_permission_gets_403_on_post(
|
||||
self, mock_app, mock_member_role
|
||||
):
|
||||
"""
|
||||
GIVEN: A user with member role (lacks MANAGE_ORG_CLAIMS)
|
||||
WHEN: POST /api/organizations/{org_id}/git-claims via HTTP
|
||||
THEN: 403 is returned by require_permission
|
||||
"""
|
||||
org_id = uuid.uuid4()
|
||||
|
||||
with patch(
|
||||
'server.auth.authorization.get_user_org_role',
|
||||
AsyncMock(return_value=mock_member_role),
|
||||
):
|
||||
client = TestClient(mock_app)
|
||||
response = client.post(
|
||||
f'/api/organizations/{org_id}/git-claims',
|
||||
json={'provider': 'github', 'git_organization': 'SomeOrg'},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
assert 'manage_org_claims' in response.json()['detail']
|
||||
|
||||
def test_member_without_permission_gets_403_on_delete(
|
||||
self, mock_app, mock_member_role
|
||||
):
|
||||
"""
|
||||
GIVEN: A user with member role (lacks MANAGE_ORG_CLAIMS)
|
||||
WHEN: DELETE /api/organizations/{org_id}/git-claims/{claim_id} via HTTP
|
||||
THEN: 403 is returned by require_permission
|
||||
"""
|
||||
org_id = uuid.uuid4()
|
||||
claim_id = uuid.uuid4()
|
||||
|
||||
with patch(
|
||||
'server.auth.authorization.get_user_org_role',
|
||||
AsyncMock(return_value=mock_member_role),
|
||||
):
|
||||
client = TestClient(mock_app)
|
||||
response = client.delete(
|
||||
f'/api/organizations/{org_id}/git-claims/{claim_id}'
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
assert 'manage_org_claims' in response.json()['detail']
|
||||
|
||||
|
||||
class TestGitClaimsHTTPIntegration:
|
||||
"""Integration tests for the full request/response cycle via TestClient."""
|
||||
|
||||
def test_post_claim_with_invalid_provider_returns_422(
|
||||
self, mock_app, mock_owner_role
|
||||
):
|
||||
"""
|
||||
GIVEN: A request with an unsupported provider
|
||||
WHEN: POST /api/organizations/{org_id}/git-claims via HTTP
|
||||
THEN: 422 is returned by Pydantic validation
|
||||
"""
|
||||
org_id = uuid.uuid4()
|
||||
|
||||
with patch(
|
||||
'server.auth.authorization.get_user_org_role',
|
||||
AsyncMock(return_value=mock_owner_role),
|
||||
):
|
||||
client = TestClient(mock_app)
|
||||
response = client.post(
|
||||
f'/api/organizations/{org_id}/git-claims',
|
||||
json={'provider': 'azure_devops', 'git_organization': 'test'},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
|
||||
|
||||
def test_post_claim_success_returns_201(self, mock_app, mock_owner_role):
|
||||
"""
|
||||
GIVEN: A valid claim request by an authorized admin/owner
|
||||
WHEN: POST /api/organizations/{org_id}/git-claims via HTTP
|
||||
THEN: 201 is returned with the claim details
|
||||
"""
|
||||
org_id = uuid.uuid4()
|
||||
mock_claim = MagicMock(spec=OrgGitClaim)
|
||||
mock_claim.id = uuid.uuid4()
|
||||
mock_claim.org_id = org_id
|
||||
mock_claim.provider = 'github'
|
||||
mock_claim.git_organization = 'openhands'
|
||||
mock_claim.claimed_by = uuid.UUID(TEST_USER_ID)
|
||||
mock_claim.claimed_at = datetime(2026, 4, 1, 12, 0, 0)
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.auth.authorization.get_user_org_role',
|
||||
AsyncMock(return_value=mock_owner_role),
|
||||
),
|
||||
patch(
|
||||
'server.routes.orgs.OrgGitClaimStore.get_claim_by_provider_and_git_org',
|
||||
AsyncMock(return_value=None),
|
||||
),
|
||||
patch(
|
||||
'server.routes.orgs.OrgGitClaimStore.create_claim',
|
||||
AsyncMock(return_value=mock_claim),
|
||||
),
|
||||
):
|
||||
client = TestClient(mock_app)
|
||||
response = client.post(
|
||||
f'/api/organizations/{org_id}/git-claims',
|
||||
json={'provider': 'github', 'git_organization': 'OpenHands'},
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_201_CREATED
|
||||
data = response.json()
|
||||
assert data['org_id'] == str(org_id)
|
||||
assert data['provider'] == 'github'
|
||||
assert data['git_organization'] == 'openhands'
|
||||
|
||||
def test_delete_claim_success_returns_200(self, mock_app, mock_owner_role):
|
||||
"""
|
||||
GIVEN: A valid disconnect request by an authorized admin/owner
|
||||
WHEN: DELETE /api/organizations/{org_id}/git-claims/{claim_id} via HTTP
|
||||
THEN: 200 is returned with a success message
|
||||
"""
|
||||
org_id = uuid.uuid4()
|
||||
claim_id = uuid.uuid4()
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.auth.authorization.get_user_org_role',
|
||||
AsyncMock(return_value=mock_owner_role),
|
||||
),
|
||||
patch(
|
||||
'server.routes.orgs.OrgGitClaimStore.delete_claim',
|
||||
AsyncMock(return_value=True),
|
||||
),
|
||||
):
|
||||
client = TestClient(mock_app)
|
||||
response = client.delete(
|
||||
f'/api/organizations/{org_id}/git-claims/{claim_id}'
|
||||
)
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
assert (
|
||||
response.json()['message'] == 'Git organization claim removed successfully'
|
||||
)
|
||||
|
||||
def test_get_claims_success_returns_200(self, mock_app, mock_owner_role):
|
||||
"""
|
||||
GIVEN: An authorized user requests claims for their organization
|
||||
WHEN: GET /api/organizations/{org_id}/git-claims via HTTP
|
||||
THEN: 200 is returned with the list of claims
|
||||
"""
|
||||
org_id = uuid.uuid4()
|
||||
mock_claim = MagicMock(spec=OrgGitClaim)
|
||||
mock_claim.id = uuid.uuid4()
|
||||
mock_claim.org_id = org_id
|
||||
mock_claim.provider = 'github'
|
||||
mock_claim.git_organization = 'openhands'
|
||||
mock_claim.claimed_by = uuid.UUID(TEST_USER_ID)
|
||||
mock_claim.claimed_at = datetime(2026, 4, 1, 12, 0, 0)
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.auth.authorization.get_user_org_role',
|
||||
AsyncMock(return_value=mock_owner_role),
|
||||
),
|
||||
patch(
|
||||
'server.routes.orgs.OrgGitClaimStore.get_claims_by_org_id',
|
||||
AsyncMock(return_value=[mock_claim]),
|
||||
),
|
||||
):
|
||||
client = TestClient(mock_app)
|
||||
response = client.get(f'/api/organizations/{org_id}/git-claims')
|
||||
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
data = response.json()
|
||||
assert len(data) == 1
|
||||
assert data[0]['provider'] == 'github'
|
||||
assert data[0]['git_organization'] == 'openhands'
|
||||
@@ -1,670 +0,0 @@
|
||||
"""
|
||||
Unit tests for AutomationEventService.
|
||||
|
||||
Tests the service that forwards GitHub webhook events to the automation service.
|
||||
|
||||
The service is optimized for high-traffic with:
|
||||
- Redis caching for org claim lookups (1 hour TTL)
|
||||
- Redis caching for GitHub→Keycloak user ID mappings (24 hour TTL)
|
||||
- Lazy access control (membership checks deferred to execution time)
|
||||
- Separate AUTOMATION_WEBHOOK_SECRET for internal service communication
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
# Default patches for constants
|
||||
CONSTANT_PATCHES = {
|
||||
'server.services.automation_event_service.AUTOMATION_WEBHOOK_SECRET': 'test-shared-secret',
|
||||
'server.services.automation_event_service.AUTOMATION_SERVICE_TIMEOUT': 30,
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_token_manager():
|
||||
"""Create a mock TokenManager."""
|
||||
return MagicMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_org_git_claim():
|
||||
"""Create a mock OrgGitClaim."""
|
||||
claim = MagicMock()
|
||||
claim.org_id = uuid.UUID('12345678-1234-5678-1234-567812345678')
|
||||
return claim
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def github_org_payload():
|
||||
"""Create a sample GitHub webhook payload for an organization repo."""
|
||||
return {
|
||||
'repository': {
|
||||
'id': 123456,
|
||||
'full_name': 'test-org/test-repo',
|
||||
'private': False,
|
||||
'default_branch': 'main',
|
||||
'owner': {
|
||||
'login': 'test-org',
|
||||
'id': 789,
|
||||
'type': 'Organization',
|
||||
},
|
||||
},
|
||||
'sender': {
|
||||
'id': 12345,
|
||||
'login': 'testuser',
|
||||
},
|
||||
'action': 'opened',
|
||||
'installation': {
|
||||
'id': 99999,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def github_user_payload():
|
||||
"""Create a sample GitHub webhook payload for a personal/user repo."""
|
||||
return {
|
||||
'repository': {
|
||||
'id': 654321,
|
||||
'full_name': 'testuser/personal-repo',
|
||||
'private': True,
|
||||
'default_branch': 'main',
|
||||
'owner': {
|
||||
'login': 'testuser',
|
||||
'id': 12345,
|
||||
'type': 'User',
|
||||
},
|
||||
},
|
||||
'sender': {
|
||||
'id': 12345,
|
||||
'login': 'testuser',
|
||||
},
|
||||
'action': 'opened',
|
||||
'installation': {
|
||||
'id': 99999,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def create_service(mock_token_manager):
|
||||
"""Helper to create a service with mocked sio and constants."""
|
||||
with patch('server.services.automation_event_service.sio'), patch.dict(
|
||||
'os.environ', {}, clear=False
|
||||
):
|
||||
for key, value in CONSTANT_PATCHES.items():
|
||||
patch(key, value).start()
|
||||
|
||||
from server.services.automation_event_service import AutomationEventService
|
||||
|
||||
return AutomationEventService(mock_token_manager)
|
||||
|
||||
|
||||
class TestResolveGithubOrg:
|
||||
"""Tests for _resolve_github_org method with caching."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_github_org_cache_miss_found(
|
||||
self, mock_token_manager, mock_org_git_claim
|
||||
):
|
||||
"""
|
||||
GIVEN: Cache miss and org claim exists in DB
|
||||
WHEN: _resolve_github_org is called
|
||||
THEN: Org ID is returned and cached
|
||||
"""
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(return_value=None) # Cache miss
|
||||
mock_redis.setex = AsyncMock()
|
||||
|
||||
with patch(
|
||||
'server.services.automation_event_service.resolve_org_for_repo',
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_org_git_claim.org_id,
|
||||
), patch('server.services.automation_event_service.sio') as mock_sio:
|
||||
mock_sio.manager.redis = mock_redis
|
||||
|
||||
service = create_service(mock_token_manager)
|
||||
result = await service._resolve_github_org('test-org')
|
||||
|
||||
assert result == mock_org_git_claim.org_id
|
||||
# Verify result was cached
|
||||
mock_redis.setex.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_github_org_cache_hit(self, mock_token_manager):
|
||||
"""
|
||||
GIVEN: Org ID is cached in Redis
|
||||
WHEN: _resolve_github_org is called
|
||||
THEN: Cached value is returned without calling resolve_org_for_repo
|
||||
"""
|
||||
cached_org_id = '12345678-1234-5678-1234-567812345678'
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(return_value=cached_org_id.encode())
|
||||
|
||||
with patch(
|
||||
'server.services.automation_event_service.resolve_org_for_repo',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_resolver, patch(
|
||||
'server.services.automation_event_service.sio'
|
||||
) as mock_sio:
|
||||
mock_sio.manager.redis = mock_redis
|
||||
|
||||
service = create_service(mock_token_manager)
|
||||
result = await service._resolve_github_org('test-org')
|
||||
|
||||
assert result == uuid.UUID(cached_org_id)
|
||||
# resolve_org_for_repo should NOT be called
|
||||
mock_resolver.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_github_org_cache_miss_not_found(self, mock_token_manager):
|
||||
"""
|
||||
GIVEN: Cache miss and org claim does NOT exist in DB
|
||||
WHEN: _resolve_github_org is called
|
||||
THEN: None is returned and negative result is cached
|
||||
"""
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(return_value=None) # Cache miss
|
||||
mock_redis.setex = AsyncMock()
|
||||
|
||||
with patch(
|
||||
'server.services.automation_event_service.resolve_org_for_repo',
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
), patch('server.services.automation_event_service.sio') as mock_sio:
|
||||
mock_sio.manager.redis = mock_redis
|
||||
|
||||
service = create_service(mock_token_manager)
|
||||
result = await service._resolve_github_org('unclaimed-org')
|
||||
|
||||
assert result is None
|
||||
# Verify negative result was cached
|
||||
mock_redis.setex.assert_called_once()
|
||||
call_args = mock_redis.setex.call_args
|
||||
# Second positional arg is the value
|
||||
assert call_args[0][2] == 'none' # Negative cache value
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_github_org_negative_cache_hit(self, mock_token_manager):
|
||||
"""
|
||||
GIVEN: Negative result is cached (org not claimed)
|
||||
WHEN: _resolve_github_org is called
|
||||
THEN: None is returned without calling resolve_org_for_repo
|
||||
"""
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(return_value=b'none') # Cached negative
|
||||
|
||||
with patch(
|
||||
'server.services.automation_event_service.resolve_org_for_repo',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_resolver, patch(
|
||||
'server.services.automation_event_service.sio'
|
||||
) as mock_sio:
|
||||
mock_sio.manager.redis = mock_redis
|
||||
|
||||
service = create_service(mock_token_manager)
|
||||
result = await service._resolve_github_org('unclaimed-org')
|
||||
|
||||
assert result is None
|
||||
mock_resolver.assert_not_called()
|
||||
|
||||
|
||||
class TestResolvePersonalOrg:
|
||||
"""Tests for _resolve_personal_org method with caching."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_personal_org_cache_miss_found(self, mock_token_manager):
|
||||
"""
|
||||
GIVEN: Cache miss and user exists in Keycloak
|
||||
WHEN: _resolve_personal_org is called
|
||||
THEN: Keycloak ID is returned and cached
|
||||
"""
|
||||
keycloak_id = '87654321-4321-8765-4321-876543218765'
|
||||
mock_token_manager.get_user_id_from_idp_user_id = AsyncMock(
|
||||
return_value=keycloak_id
|
||||
)
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(return_value=None) # Cache miss
|
||||
mock_redis.setex = AsyncMock()
|
||||
|
||||
with patch('server.services.automation_event_service.sio') as mock_sio:
|
||||
mock_sio.manager.redis = mock_redis
|
||||
|
||||
service = create_service(mock_token_manager)
|
||||
result = await service._resolve_personal_org(12345)
|
||||
|
||||
assert result == uuid.UUID(keycloak_id)
|
||||
mock_redis.setex.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_personal_org_cache_hit(self, mock_token_manager):
|
||||
"""
|
||||
GIVEN: Keycloak ID is cached in Redis
|
||||
WHEN: _resolve_personal_org is called
|
||||
THEN: Cached value is returned without Keycloak query
|
||||
"""
|
||||
keycloak_id = '87654321-4321-8765-4321-876543218765'
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(return_value=keycloak_id.encode())
|
||||
|
||||
with patch('server.services.automation_event_service.sio') as mock_sio:
|
||||
mock_sio.manager.redis = mock_redis
|
||||
|
||||
service = create_service(mock_token_manager)
|
||||
result = await service._resolve_personal_org(12345)
|
||||
|
||||
assert result == uuid.UUID(keycloak_id)
|
||||
# Token manager should NOT be called
|
||||
mock_token_manager.get_user_id_from_idp_user_id.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_personal_org_no_github_user_id(self, mock_token_manager):
|
||||
"""
|
||||
GIVEN: No GitHub user ID provided
|
||||
WHEN: _resolve_personal_org is called
|
||||
THEN: None is returned immediately
|
||||
"""
|
||||
service = create_service(mock_token_manager)
|
||||
result = await service._resolve_personal_org(None)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestForwardGithubEvent:
|
||||
"""Tests for forward_github_event method (minimal payload, no access control)."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_forward_org_event_success(
|
||||
self, mock_token_manager, github_org_payload, mock_org_git_claim
|
||||
):
|
||||
"""
|
||||
GIVEN: A GitHub event from a claimed organization repo
|
||||
WHEN: forward_github_event is called
|
||||
THEN: Minimal payload is forwarded (no access_control)
|
||||
"""
|
||||
from server.services.automation_event_service import AutomationEventService
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(return_value=None)
|
||||
mock_redis.setex = AsyncMock()
|
||||
|
||||
with patch(
|
||||
'server.services.automation_event_service.resolve_org_for_repo',
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_org_git_claim.org_id,
|
||||
), patch(
|
||||
'server.services.automation_event_service.sio'
|
||||
) as mock_sio, patch.object(
|
||||
AutomationEventService,
|
||||
'_send_to_automation_service',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_send:
|
||||
mock_sio.manager.redis = mock_redis
|
||||
|
||||
service = AutomationEventService(mock_token_manager)
|
||||
await service.forward_github_event(
|
||||
payload=github_org_payload,
|
||||
installation_id=99999,
|
||||
)
|
||||
|
||||
mock_send.assert_called_once()
|
||||
call_args = mock_send.call_args
|
||||
assert call_args[0][0] == mock_org_git_claim.org_id
|
||||
|
||||
payload = call_args[0][1]
|
||||
assert payload['organization']['github_org'] == 'test-org'
|
||||
assert 'payload' in payload
|
||||
# access_control should NOT be in payload (lazy evaluation)
|
||||
assert 'access_control' not in payload
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_forward_personal_repo_event_success(
|
||||
self, mock_token_manager, github_user_payload
|
||||
):
|
||||
"""
|
||||
GIVEN: A GitHub event from a personal repo with linked OpenHands account
|
||||
WHEN: forward_github_event is called
|
||||
THEN: Event is forwarded using the user's personal org (keycloak ID)
|
||||
"""
|
||||
from server.services.automation_event_service import AutomationEventService
|
||||
|
||||
keycloak_id = '87654321-4321-8765-4321-876543218765'
|
||||
mock_token_manager.get_user_id_from_idp_user_id = AsyncMock(
|
||||
return_value=keycloak_id
|
||||
)
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(return_value=None)
|
||||
mock_redis.setex = AsyncMock()
|
||||
|
||||
with patch(
|
||||
'server.services.automation_event_service.resolve_org_for_repo',
|
||||
new_callable=AsyncMock,
|
||||
return_value=None, # No org claim for personal repo
|
||||
), patch(
|
||||
'server.services.automation_event_service.sio'
|
||||
) as mock_sio, patch.object(
|
||||
AutomationEventService,
|
||||
'_send_to_automation_service',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_send:
|
||||
mock_sio.manager.redis = mock_redis
|
||||
|
||||
service = AutomationEventService(mock_token_manager)
|
||||
await service.forward_github_event(
|
||||
payload=github_user_payload,
|
||||
installation_id=99999,
|
||||
)
|
||||
|
||||
mock_send.assert_called_once()
|
||||
call_args = mock_send.call_args
|
||||
# Personal org should be keycloak ID
|
||||
assert call_args[0][0] == uuid.UUID(keycloak_id)
|
||||
payload = call_args[0][1]
|
||||
assert payload['organization']['github_org'] == 'testuser'
|
||||
assert payload['organization']['openhands_org_id'] == keycloak_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_forward_event_no_owner_in_payload(self, mock_token_manager):
|
||||
"""
|
||||
GIVEN: A GitHub event with no repository owner in payload
|
||||
WHEN: forward_github_event is called
|
||||
THEN: Event is skipped with warning log
|
||||
"""
|
||||
from server.services.automation_event_service import AutomationEventService
|
||||
|
||||
payload = {
|
||||
'repository': {},
|
||||
'sender': {'id': 12345, 'login': 'testuser'},
|
||||
}
|
||||
|
||||
with patch('server.services.automation_event_service.sio'), patch(
|
||||
'server.services.automation_event_service.logger'
|
||||
) as mock_logger, patch.object(
|
||||
AutomationEventService,
|
||||
'_send_to_automation_service',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_send:
|
||||
service = AutomationEventService(mock_token_manager)
|
||||
await service.forward_github_event(
|
||||
payload=payload,
|
||||
installation_id=99999,
|
||||
)
|
||||
|
||||
mock_send.assert_not_called()
|
||||
mock_logger.warning.assert_called()
|
||||
assert 'No repository owner' in str(mock_logger.warning.call_args)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_forward_event_org_not_claimed_and_not_personal(
|
||||
self, mock_token_manager, github_org_payload
|
||||
):
|
||||
"""
|
||||
GIVEN: A GitHub event from an org that isn't claimed (and isn't personal)
|
||||
WHEN: forward_github_event is called
|
||||
THEN: Event is skipped with warning log
|
||||
"""
|
||||
from server.services.automation_event_service import AutomationEventService
|
||||
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(return_value=None)
|
||||
mock_redis.setex = AsyncMock()
|
||||
|
||||
with patch(
|
||||
'server.services.automation_event_service.resolve_org_for_repo',
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
), patch('server.services.automation_event_service.sio') as mock_sio, patch(
|
||||
'server.services.automation_event_service.logger'
|
||||
) as mock_logger, patch.object(
|
||||
AutomationEventService,
|
||||
'_send_to_automation_service',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_send:
|
||||
mock_sio.manager.redis = mock_redis
|
||||
|
||||
service = AutomationEventService(mock_token_manager)
|
||||
await service.forward_github_event(
|
||||
payload=github_org_payload,
|
||||
installation_id=99999,
|
||||
)
|
||||
|
||||
mock_send.assert_not_called()
|
||||
mock_logger.warning.assert_called()
|
||||
assert 'not claimed' in str(mock_logger.warning.call_args)
|
||||
|
||||
|
||||
class TestBuildEventPayload:
|
||||
"""Tests for _build_event_payload method."""
|
||||
|
||||
def test_build_minimal_payload(self, mock_token_manager):
|
||||
"""
|
||||
GIVEN: Org context and payload
|
||||
WHEN: _build_event_payload is called
|
||||
THEN: Minimal payload with only org + payload is returned
|
||||
"""
|
||||
from server.services.automation_event_service import OrgContext
|
||||
|
||||
service = create_service(mock_token_manager)
|
||||
|
||||
org_context = OrgContext(
|
||||
org_id=uuid.UUID('12345678-1234-5678-1234-567812345678'),
|
||||
github_org='test-org',
|
||||
)
|
||||
test_payload = {'action': 'opened', 'sender': {'login': 'user'}}
|
||||
|
||||
result = service._build_event_payload(org_context, test_payload)
|
||||
|
||||
assert result == {
|
||||
'organization': {
|
||||
'github_org': 'test-org',
|
||||
'openhands_org_id': '12345678-1234-5678-1234-567812345678',
|
||||
},
|
||||
'payload': test_payload,
|
||||
}
|
||||
# Verify NO access_control in payload
|
||||
assert 'access_control' not in result
|
||||
|
||||
|
||||
class TestSendToAutomationService:
|
||||
"""Tests for _send_to_automation_service method."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_success(self, mock_token_manager):
|
||||
"""
|
||||
GIVEN: AUTOMATION_SERVICE_URL is configured
|
||||
WHEN: _send_to_automation_service is called
|
||||
THEN: Request is sent with correct signature
|
||||
"""
|
||||
|
||||
org_id = uuid.UUID('12345678-1234-5678-1234-567812345678')
|
||||
payload = {'organization': {'github_org': 'test'}, 'payload': {}}
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
mock_response.json = AsyncMock(return_value={'matched': 2})
|
||||
|
||||
mock_post_context = MagicMock()
|
||||
mock_post_context.__aenter__ = AsyncMock(return_value=mock_response)
|
||||
mock_post_context.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
mock_session_instance = MagicMock()
|
||||
mock_session_instance.post = MagicMock(return_value=mock_post_context)
|
||||
|
||||
mock_session_context = MagicMock()
|
||||
mock_session_context.__aenter__ = AsyncMock(return_value=mock_session_instance)
|
||||
mock_session_context.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
with patch(
|
||||
'server.services.automation_event_service.AUTOMATION_SERVICE_URL',
|
||||
'https://automation.example.com',
|
||||
), patch('server.services.automation_event_service.sio'), patch(
|
||||
'server.services.automation_event_service.aiohttp.ClientSession',
|
||||
return_value=mock_session_context,
|
||||
):
|
||||
service = create_service(mock_token_manager)
|
||||
await service._send_to_automation_service(org_id, payload)
|
||||
|
||||
# Verify the POST was called
|
||||
mock_session_instance.post.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_no_url_configured(self, mock_token_manager):
|
||||
"""
|
||||
GIVEN: AUTOMATION_SERVICE_URL is not configured
|
||||
WHEN: _send_to_automation_service is called
|
||||
THEN: Warning is logged and nothing is sent
|
||||
"""
|
||||
org_id = uuid.UUID('12345678-1234-5678-1234-567812345678')
|
||||
payload = {}
|
||||
|
||||
with patch(
|
||||
'server.services.automation_event_service.AUTOMATION_SERVICE_URL', None
|
||||
), patch('server.services.automation_event_service.sio'), patch(
|
||||
'server.services.automation_event_service.logger'
|
||||
) as mock_logger:
|
||||
service = create_service(mock_token_manager)
|
||||
await service._send_to_automation_service(org_id, payload)
|
||||
|
||||
mock_logger.warning.assert_called()
|
||||
assert 'not configured' in str(mock_logger.warning.call_args)
|
||||
|
||||
|
||||
class TestSignPayload:
|
||||
"""Tests for _sign_payload method."""
|
||||
|
||||
def test_sign_payload(self, mock_token_manager):
|
||||
"""
|
||||
GIVEN: A payload bytes
|
||||
WHEN: _sign_payload is called
|
||||
THEN: HMAC-SHA256 signature is returned in correct format
|
||||
"""
|
||||
with patch(
|
||||
'server.services.automation_event_service.AUTOMATION_WEBHOOK_SECRET',
|
||||
'test-shared-secret',
|
||||
), patch('server.services.automation_event_service.sio'):
|
||||
service = create_service(mock_token_manager)
|
||||
payload_bytes = b'{"test": "data"}'
|
||||
|
||||
signature = service._sign_payload(payload_bytes)
|
||||
|
||||
assert signature.startswith('sha256=')
|
||||
assert len(signature) == 71 # 'sha256=' + 64 hex chars
|
||||
|
||||
def test_sign_payload_uses_dedicated_secret(self, mock_token_manager):
|
||||
"""
|
||||
GIVEN: AUTOMATION_WEBHOOK_SECRET is configured
|
||||
WHEN: _sign_payload is called
|
||||
THEN: The dedicated secret is used (not GitHub webhook secret)
|
||||
"""
|
||||
import hashlib
|
||||
import hmac
|
||||
|
||||
# Use the default test secret from CONSTANT_PATCHES
|
||||
shared_secret = 'test-shared-secret'
|
||||
payload_bytes = b'{"test": "data"}'
|
||||
|
||||
# Calculate expected signature with the shared secret
|
||||
expected_sig = hmac.new(
|
||||
shared_secret.encode('utf-8'),
|
||||
msg=payload_bytes,
|
||||
digestmod=hashlib.sha256,
|
||||
).hexdigest()
|
||||
|
||||
with patch(
|
||||
'server.services.automation_event_service.AUTOMATION_WEBHOOK_SECRET',
|
||||
shared_secret,
|
||||
), patch('server.services.automation_event_service.sio'):
|
||||
service = create_service(mock_token_manager)
|
||||
signature = service._sign_payload(payload_bytes)
|
||||
|
||||
assert signature == f'sha256={expected_sig}'
|
||||
|
||||
|
||||
class TestCacheHelpers:
|
||||
"""Tests for generic cache helper methods."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_cached_value_hit(self, mock_token_manager):
|
||||
"""
|
||||
GIVEN: Value exists in Redis cache
|
||||
WHEN: _get_cached_value is called
|
||||
THEN: Decoded string value is returned
|
||||
"""
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(return_value=b'cached-value')
|
||||
|
||||
with patch('server.services.automation_event_service.sio') as mock_sio:
|
||||
mock_sio.manager.redis = mock_redis
|
||||
|
||||
service = create_service(mock_token_manager)
|
||||
result = await service._get_cached_value('test-key')
|
||||
|
||||
assert result == 'cached-value'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_cached_value_miss(self, mock_token_manager):
|
||||
"""
|
||||
GIVEN: Value does not exist in Redis cache
|
||||
WHEN: _get_cached_value is called
|
||||
THEN: None is returned
|
||||
"""
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.get = AsyncMock(return_value=None)
|
||||
|
||||
with patch('server.services.automation_event_service.sio') as mock_sio:
|
||||
mock_sio.manager.redis = mock_redis
|
||||
|
||||
service = create_service(mock_token_manager)
|
||||
result = await service._get_cached_value('test-key')
|
||||
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_cached_value_redis_unavailable(self, mock_token_manager):
|
||||
"""
|
||||
GIVEN: Redis is unavailable
|
||||
WHEN: _get_cached_value is called
|
||||
THEN: None is returned (graceful degradation)
|
||||
"""
|
||||
with patch('server.services.automation_event_service.sio') as mock_sio:
|
||||
mock_sio.manager.redis = None
|
||||
|
||||
service = create_service(mock_token_manager)
|
||||
result = await service._get_cached_value('test-key')
|
||||
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_cached_value_success(self, mock_token_manager):
|
||||
"""
|
||||
GIVEN: Redis is available
|
||||
WHEN: _set_cached_value is called
|
||||
THEN: Value is stored with TTL
|
||||
"""
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.setex = AsyncMock()
|
||||
|
||||
with patch('server.services.automation_event_service.sio') as mock_sio:
|
||||
mock_sio.manager.redis = mock_redis
|
||||
|
||||
service = create_service(mock_token_manager)
|
||||
await service._set_cached_value('test-key', 'test-value', 3600)
|
||||
|
||||
mock_redis.setex.assert_called_once_with('test-key', 3600, 'test-value')
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_cached_value_redis_unavailable(self, mock_token_manager):
|
||||
"""
|
||||
GIVEN: Redis is unavailable
|
||||
WHEN: _set_cached_value is called
|
||||
THEN: No error is raised (silent failure)
|
||||
"""
|
||||
with patch('server.services.automation_event_service.sio') as mock_sio:
|
||||
mock_sio.manager.redis = None
|
||||
|
||||
service = create_service(mock_token_manager)
|
||||
# Should not raise
|
||||
await service._set_cached_value('test-key', 'test-value', 3600)
|
||||
@@ -1,420 +0,0 @@
|
||||
"""Tests for OrgMemberFinancialService."""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from server.routes.org_models import OrgMemberFinancialPage
|
||||
from server.services.org_member_financial_service import OrgMemberFinancialService
|
||||
from storage.org_member import OrgMember
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def org_id():
|
||||
"""Create a test organization ID."""
|
||||
return uuid.uuid4()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user():
|
||||
"""Create a mock user."""
|
||||
user = MagicMock()
|
||||
user.email = 'test@example.com'
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_role():
|
||||
"""Create a mock role."""
|
||||
role = MagicMock()
|
||||
role.id = 1
|
||||
role.name = 'member'
|
||||
role.rank = 1000
|
||||
return role
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_org_member(org_id, mock_user, mock_role):
|
||||
"""Create a mock org member with user and role."""
|
||||
member = MagicMock(spec=OrgMember)
|
||||
member.org_id = org_id
|
||||
member.user_id = uuid.uuid4()
|
||||
member.role_id = mock_role.id
|
||||
member.status = 'active'
|
||||
member.user = mock_user
|
||||
member.role = mock_role
|
||||
return member
|
||||
|
||||
|
||||
class TestOrgMemberFinancialServiceGetFinancialData:
|
||||
"""Test cases for OrgMemberFinancialService.get_org_members_financial_data."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_paginated_financial_data_with_individual_budget(
|
||||
self, org_id, mock_org_member
|
||||
):
|
||||
"""
|
||||
GIVEN: Organization with members having individual budget limits
|
||||
WHEN: get_org_members_financial_data is called
|
||||
THEN: Returns financial data using individual spend for current_budget calc
|
||||
"""
|
||||
# Arrange
|
||||
user_id_str = str(mock_org_member.user_id)
|
||||
litellm_data = {
|
||||
'team_max_budget': 1000.0,
|
||||
'team_spend': 200.0,
|
||||
'members': {
|
||||
user_id_str: {'spend': 125.50, 'max_budget': 500.0} # Individual budget
|
||||
},
|
||||
}
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.services.org_member_financial_service.OrgMemberStore.get_org_members_paginated',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_paginated,
|
||||
patch(
|
||||
'server.services.org_member_financial_service.LiteLlmManager.get_team_members_financial_data',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_financial,
|
||||
):
|
||||
mock_get_paginated.return_value = ([mock_org_member], 1)
|
||||
mock_get_financial.return_value = litellm_data
|
||||
|
||||
# Act
|
||||
result = await OrgMemberFinancialService.get_org_members_financial_data(
|
||||
org_id=org_id,
|
||||
page_id=None,
|
||||
limit=10,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, OrgMemberFinancialPage)
|
||||
assert len(result.items) == 1
|
||||
assert result.items[0].user_id == user_id_str
|
||||
assert result.items[0].email == 'test@example.com'
|
||||
assert result.items[0].lifetime_spend == 125.50
|
||||
assert result.items[0].max_budget == 500.0
|
||||
# Individual budget: 500 - 125.50 = 374.50
|
||||
assert result.items[0].current_budget == 374.50
|
||||
assert result.current_page == 1
|
||||
assert result.per_page == 10
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_shared_budget_using_team_spend(
|
||||
self, org_id, mock_org_member
|
||||
):
|
||||
"""
|
||||
GIVEN: Organization with shared team budget
|
||||
WHEN: get_org_members_financial_data is called
|
||||
THEN: Uses team_spend (not individual spend) for current_budget calculation
|
||||
"""
|
||||
# Arrange
|
||||
user_id_str = str(mock_org_member.user_id)
|
||||
litellm_data = {
|
||||
'team_max_budget': 500.0,
|
||||
'team_spend': 150.0, # Total team spend
|
||||
'members': {
|
||||
user_id_str: {
|
||||
'spend': 50.0,
|
||||
'max_budget': 500.0,
|
||||
'uses_shared_budget': True, # Explicitly using shared budget
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.services.org_member_financial_service.OrgMemberStore.get_org_members_paginated',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_paginated,
|
||||
patch(
|
||||
'server.services.org_member_financial_service.LiteLlmManager.get_team_members_financial_data',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_financial,
|
||||
):
|
||||
mock_get_paginated.return_value = ([mock_org_member], 1)
|
||||
mock_get_financial.return_value = litellm_data
|
||||
|
||||
# Act
|
||||
result = await OrgMemberFinancialService.get_org_members_financial_data(
|
||||
org_id=org_id,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(result.items) == 1
|
||||
assert result.items[0].lifetime_spend == 50.0 # Individual spend
|
||||
assert result.items[0].max_budget == 500.0
|
||||
# Shared budget: 500 - 150 (team_spend) = 350
|
||||
assert result.items[0].current_budget == 350.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_defaults_when_litellm_data_missing(
|
||||
self, org_id, mock_org_member
|
||||
):
|
||||
"""
|
||||
GIVEN: Organization with members but no LiteLLM data for them
|
||||
WHEN: get_org_members_financial_data is called
|
||||
THEN: Returns financial data with default values (spend=0, budget=None)
|
||||
"""
|
||||
# Arrange
|
||||
with (
|
||||
patch(
|
||||
'server.services.org_member_financial_service.OrgMemberStore.get_org_members_paginated',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_paginated,
|
||||
patch(
|
||||
'server.services.org_member_financial_service.LiteLlmManager.get_team_members_financial_data',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_financial,
|
||||
):
|
||||
mock_get_paginated.return_value = ([mock_org_member], 1)
|
||||
mock_get_financial.return_value = {
|
||||
'team_max_budget': None,
|
||||
'team_spend': 0,
|
||||
'members': {},
|
||||
}
|
||||
|
||||
# Act
|
||||
result = await OrgMemberFinancialService.get_org_members_financial_data(
|
||||
org_id=org_id,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(result.items) == 1
|
||||
assert result.items[0].lifetime_spend == 0
|
||||
assert result.items[0].max_budget is None
|
||||
assert result.items[0].current_budget == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handles_litellm_failure_gracefully(self, org_id, mock_org_member):
|
||||
"""
|
||||
GIVEN: LiteLLM service throws an exception
|
||||
WHEN: get_org_members_financial_data is called
|
||||
THEN: Returns financial data with default values (doesn't fail)
|
||||
"""
|
||||
# Arrange
|
||||
with (
|
||||
patch(
|
||||
'server.services.org_member_financial_service.OrgMemberStore.get_org_members_paginated',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_paginated,
|
||||
patch(
|
||||
'server.services.org_member_financial_service.LiteLlmManager.get_team_members_financial_data',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_financial,
|
||||
):
|
||||
mock_get_paginated.return_value = ([mock_org_member], 1)
|
||||
mock_get_financial.side_effect = Exception('LiteLLM unavailable')
|
||||
|
||||
# Act
|
||||
result = await OrgMemberFinancialService.get_org_members_financial_data(
|
||||
org_id=org_id,
|
||||
)
|
||||
|
||||
# Assert - should not raise, returns defaults
|
||||
assert len(result.items) == 1
|
||||
assert result.items[0].lifetime_spend == 0
|
||||
assert result.items[0].max_budget is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pagination_returns_next_page_id(self, org_id, mock_org_member):
|
||||
"""
|
||||
GIVEN: Organization with more members than limit
|
||||
WHEN: get_org_members_financial_data is called
|
||||
THEN: Returns next_page_id for pagination
|
||||
"""
|
||||
# Arrange
|
||||
with (
|
||||
patch(
|
||||
'server.services.org_member_financial_service.OrgMemberStore.get_org_members_paginated',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_paginated,
|
||||
patch(
|
||||
'server.services.org_member_financial_service.LiteLlmManager.get_team_members_financial_data',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_financial,
|
||||
):
|
||||
mock_get_paginated.return_value = ([mock_org_member], 25) # 25 total
|
||||
mock_get_financial.return_value = {
|
||||
'team_max_budget': None,
|
||||
'team_spend': 0,
|
||||
'members': {},
|
||||
}
|
||||
|
||||
# Act
|
||||
result = await OrgMemberFinancialService.get_org_members_financial_data(
|
||||
org_id=org_id,
|
||||
page_id='0',
|
||||
limit=10,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result.current_page == 1
|
||||
assert result.next_page_id == '10'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pagination_no_next_page_on_last_page(self, org_id, mock_org_member):
|
||||
"""
|
||||
GIVEN: Organization on last page of results
|
||||
WHEN: get_org_members_financial_data is called
|
||||
THEN: Returns next_page_id as None
|
||||
"""
|
||||
# Arrange
|
||||
with (
|
||||
patch(
|
||||
'server.services.org_member_financial_service.OrgMemberStore.get_org_members_paginated',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_paginated,
|
||||
patch(
|
||||
'server.services.org_member_financial_service.LiteLlmManager.get_team_members_financial_data',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_financial,
|
||||
):
|
||||
mock_get_paginated.return_value = ([mock_org_member], 5) # 5 total
|
||||
mock_get_financial.return_value = {
|
||||
'team_max_budget': None,
|
||||
'team_spend': 0,
|
||||
'members': {},
|
||||
}
|
||||
|
||||
# Act
|
||||
result = await OrgMemberFinancialService.get_org_members_financial_data(
|
||||
org_id=org_id,
|
||||
page_id='0',
|
||||
limit=10,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result.next_page_id is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_organization_returns_empty_items(self, org_id):
|
||||
"""
|
||||
GIVEN: Organization with no members
|
||||
WHEN: get_org_members_financial_data is called
|
||||
THEN: Returns empty items list
|
||||
"""
|
||||
# Arrange
|
||||
with patch(
|
||||
'server.services.org_member_financial_service.OrgMemberStore.get_org_members_paginated',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_paginated:
|
||||
mock_get_paginated.return_value = ([], 0)
|
||||
|
||||
# Act
|
||||
result = await OrgMemberFinancialService.get_org_members_financial_data(
|
||||
org_id=org_id,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(result.items) == 0
|
||||
assert result.next_page_id is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_page_id_raises_value_error(self, org_id):
|
||||
"""
|
||||
GIVEN: Invalid page_id format
|
||||
WHEN: get_org_members_financial_data is called
|
||||
THEN: Raises ValueError
|
||||
"""
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await OrgMemberFinancialService.get_org_members_financial_data(
|
||||
org_id=org_id,
|
||||
page_id='invalid',
|
||||
)
|
||||
|
||||
assert 'Invalid page_id' in str(exc_info.value)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_negative_page_id_raises_value_error(self, org_id):
|
||||
"""
|
||||
GIVEN: Negative page_id
|
||||
WHEN: get_org_members_financial_data is called
|
||||
THEN: Raises ValueError
|
||||
"""
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await OrgMemberFinancialService.get_org_members_financial_data(
|
||||
org_id=org_id,
|
||||
page_id='-5',
|
||||
)
|
||||
|
||||
assert 'Invalid page_id' in str(exc_info.value)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_passes_email_filter_to_store(self, org_id, mock_org_member):
|
||||
"""
|
||||
GIVEN: Email filter parameter
|
||||
WHEN: get_org_members_financial_data is called
|
||||
THEN: Passes email filter to the store
|
||||
"""
|
||||
# Arrange
|
||||
with (
|
||||
patch(
|
||||
'server.services.org_member_financial_service.OrgMemberStore.get_org_members_paginated',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_paginated,
|
||||
patch(
|
||||
'server.services.org_member_financial_service.LiteLlmManager.get_team_members_financial_data',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_financial,
|
||||
):
|
||||
mock_get_paginated.return_value = ([mock_org_member], 1)
|
||||
mock_get_financial.return_value = {
|
||||
'team_max_budget': None,
|
||||
'team_spend': 0,
|
||||
'members': {},
|
||||
}
|
||||
|
||||
# Act
|
||||
await OrgMemberFinancialService.get_org_members_financial_data(
|
||||
org_id=org_id,
|
||||
email_filter='alice',
|
||||
)
|
||||
|
||||
# Assert
|
||||
mock_get_paginated.assert_called_once_with(
|
||||
org_id=org_id, offset=0, limit=10, email_filter='alice'
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handles_missing_user_relationship(self, org_id, mock_role):
|
||||
"""
|
||||
GIVEN: Member with no user relationship loaded
|
||||
WHEN: get_org_members_financial_data is called
|
||||
THEN: Returns None for email
|
||||
"""
|
||||
# Arrange
|
||||
member_no_user = MagicMock(spec=OrgMember)
|
||||
member_no_user.org_id = org_id
|
||||
member_no_user.user_id = uuid.uuid4()
|
||||
member_no_user.role_id = mock_role.id
|
||||
member_no_user.user = None # No user relationship
|
||||
|
||||
with (
|
||||
patch(
|
||||
'server.services.org_member_financial_service.OrgMemberStore.get_org_members_paginated',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_paginated,
|
||||
patch(
|
||||
'server.services.org_member_financial_service.LiteLlmManager.get_team_members_financial_data',
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_financial,
|
||||
):
|
||||
mock_get_paginated.return_value = ([member_no_user], 1)
|
||||
mock_get_financial.return_value = {
|
||||
'team_max_budget': None,
|
||||
'team_spend': 0,
|
||||
'members': {},
|
||||
}
|
||||
|
||||
# Act
|
||||
result = await OrgMemberFinancialService.get_org_members_financial_data(
|
||||
org_id=org_id,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(result.items) == 1
|
||||
assert result.items[0].email is None
|
||||
@@ -1,314 +0,0 @@
|
||||
"""Unit tests for ApiKeyStore system key functionality."""
|
||||
|
||||
import uuid
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
from storage.api_key import ApiKey
|
||||
from storage.api_key_store import ApiKeyStore
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def api_key_store():
|
||||
"""Create ApiKeyStore instance."""
|
||||
return ApiKeyStore()
|
||||
|
||||
|
||||
class TestApiKeyStoreSystemKeys:
|
||||
"""Test cases for system API key functionality."""
|
||||
|
||||
def test_is_system_key_name_with_prefix(self, api_key_store):
|
||||
"""Test that names with __SYSTEM__: prefix are identified as system keys."""
|
||||
assert api_key_store.is_system_key_name('__SYSTEM__:automation') is True
|
||||
assert api_key_store.is_system_key_name('__SYSTEM__:test-key') is True
|
||||
assert api_key_store.is_system_key_name('__SYSTEM__:') is True
|
||||
|
||||
def test_is_system_key_name_without_prefix(self, api_key_store):
|
||||
"""Test that names without __SYSTEM__: prefix are not system keys."""
|
||||
assert api_key_store.is_system_key_name('my-key') is False
|
||||
assert api_key_store.is_system_key_name('automation') is False
|
||||
assert api_key_store.is_system_key_name('MCP_API_KEY') is False
|
||||
assert api_key_store.is_system_key_name('') is False
|
||||
|
||||
def test_is_system_key_name_none(self, api_key_store):
|
||||
"""Test that None is not a system key."""
|
||||
assert api_key_store.is_system_key_name(None) is False
|
||||
|
||||
def test_make_system_key_name(self, api_key_store):
|
||||
"""Test system key name generation."""
|
||||
assert (
|
||||
api_key_store.make_system_key_name('automation') == '__SYSTEM__:automation'
|
||||
)
|
||||
assert api_key_store.make_system_key_name('test-key') == '__SYSTEM__:test-key'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_or_create_system_api_key_creates_new(
|
||||
self, api_key_store, async_session_maker
|
||||
):
|
||||
"""Test creating a new system API key when none exists."""
|
||||
user_id = '5594c7b6-f959-4b81-92e9-b09c206f5081'
|
||||
org_id = uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081')
|
||||
key_name = 'automation'
|
||||
|
||||
with patch('storage.api_key_store.a_session_maker', async_session_maker):
|
||||
api_key = await api_key_store.get_or_create_system_api_key(
|
||||
user_id=user_id,
|
||||
org_id=org_id,
|
||||
name=key_name,
|
||||
)
|
||||
|
||||
assert api_key.startswith('sk-oh-')
|
||||
assert len(api_key) == len('sk-oh-') + 32
|
||||
|
||||
# Verify the key was created in the database
|
||||
async with async_session_maker() as session:
|
||||
result = await session.execute(select(ApiKey).filter(ApiKey.key == api_key))
|
||||
key_record = result.scalars().first()
|
||||
assert key_record is not None
|
||||
assert key_record.user_id == user_id
|
||||
assert key_record.org_id == org_id
|
||||
assert key_record.name == '__SYSTEM__:automation'
|
||||
assert key_record.expires_at is None # System keys never expire
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_or_create_system_api_key_returns_existing(
|
||||
self, api_key_store, async_session_maker
|
||||
):
|
||||
"""Test that existing valid system key is returned."""
|
||||
user_id = '5594c7b6-f959-4b81-92e9-b09c206f5081'
|
||||
org_id = uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081')
|
||||
key_name = 'automation'
|
||||
|
||||
with patch('storage.api_key_store.a_session_maker', async_session_maker):
|
||||
# Create the first key
|
||||
first_key = await api_key_store.get_or_create_system_api_key(
|
||||
user_id=user_id,
|
||||
org_id=org_id,
|
||||
name=key_name,
|
||||
)
|
||||
|
||||
# Request again - should return the same key
|
||||
second_key = await api_key_store.get_or_create_system_api_key(
|
||||
user_id=user_id,
|
||||
org_id=org_id,
|
||||
name=key_name,
|
||||
)
|
||||
|
||||
assert first_key == second_key
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_or_create_system_api_key_different_names(
|
||||
self, api_key_store, async_session_maker
|
||||
):
|
||||
"""Test that different names create different keys."""
|
||||
user_id = '5594c7b6-f959-4b81-92e9-b09c206f5081'
|
||||
org_id = uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081')
|
||||
|
||||
with patch('storage.api_key_store.a_session_maker', async_session_maker):
|
||||
key1 = await api_key_store.get_or_create_system_api_key(
|
||||
user_id=user_id,
|
||||
org_id=org_id,
|
||||
name='automation-1',
|
||||
)
|
||||
|
||||
key2 = await api_key_store.get_or_create_system_api_key(
|
||||
user_id=user_id,
|
||||
org_id=org_id,
|
||||
name='automation-2',
|
||||
)
|
||||
|
||||
assert key1 != key2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_or_create_system_api_key_reissues_expired(
|
||||
self, api_key_store, async_session_maker
|
||||
):
|
||||
"""Test that expired system key is replaced with a new one."""
|
||||
user_id = '5594c7b6-f959-4b81-92e9-b09c206f5081'
|
||||
org_id = uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081')
|
||||
key_name = 'automation'
|
||||
system_key_name = '__SYSTEM__:automation'
|
||||
|
||||
# First, manually create an expired key
|
||||
expired_time = datetime.now(UTC) - timedelta(hours=1)
|
||||
async with async_session_maker() as session:
|
||||
expired_key = ApiKey(
|
||||
key='sk-oh-expired-key-12345678901234567890',
|
||||
user_id=user_id,
|
||||
org_id=org_id,
|
||||
name=system_key_name,
|
||||
expires_at=expired_time.replace(tzinfo=None),
|
||||
)
|
||||
session.add(expired_key)
|
||||
await session.commit()
|
||||
|
||||
with patch('storage.api_key_store.a_session_maker', async_session_maker):
|
||||
# Request the key - should create a new one
|
||||
new_key = await api_key_store.get_or_create_system_api_key(
|
||||
user_id=user_id,
|
||||
org_id=org_id,
|
||||
name=key_name,
|
||||
)
|
||||
|
||||
assert new_key != 'sk-oh-expired-key-12345678901234567890'
|
||||
assert new_key.startswith('sk-oh-')
|
||||
|
||||
# Verify old key was deleted and new key exists
|
||||
async with async_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(ApiKey).filter(ApiKey.name == system_key_name)
|
||||
)
|
||||
keys = result.scalars().all()
|
||||
assert len(keys) == 1
|
||||
assert keys[0].key == new_key
|
||||
assert keys[0].expires_at is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_api_keys_excludes_system_keys(
|
||||
self, api_key_store, async_session_maker
|
||||
):
|
||||
"""Test that list_api_keys excludes system keys."""
|
||||
user_id = '5594c7b6-f959-4b81-92e9-b09c206f5081'
|
||||
org_id = uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081')
|
||||
|
||||
# Create a user key and a system key
|
||||
async with async_session_maker() as session:
|
||||
user_key = ApiKey(
|
||||
key='sk-oh-user-key-123456789012345678901',
|
||||
user_id=user_id,
|
||||
org_id=org_id,
|
||||
name='my-user-key',
|
||||
)
|
||||
system_key = ApiKey(
|
||||
key='sk-oh-system-key-12345678901234567890',
|
||||
user_id=user_id,
|
||||
org_id=org_id,
|
||||
name='__SYSTEM__:automation',
|
||||
)
|
||||
mcp_key = ApiKey(
|
||||
key='sk-oh-mcp-key-1234567890123456789012',
|
||||
user_id=user_id,
|
||||
org_id=org_id,
|
||||
name='MCP_API_KEY',
|
||||
)
|
||||
session.add(user_key)
|
||||
session.add(system_key)
|
||||
session.add(mcp_key)
|
||||
await session.commit()
|
||||
|
||||
# Mock UserStore.get_user_by_id to return a user with the correct org
|
||||
mock_user = MagicMock()
|
||||
mock_user.current_org_id = org_id
|
||||
|
||||
with patch('storage.api_key_store.a_session_maker', async_session_maker):
|
||||
with patch(
|
||||
'storage.api_key_store.UserStore.get_user_by_id', new_callable=AsyncMock
|
||||
) as mock_get_user:
|
||||
mock_get_user.return_value = mock_user
|
||||
keys = await api_key_store.list_api_keys(user_id)
|
||||
|
||||
# Should only return the user key
|
||||
assert len(keys) == 1
|
||||
assert keys[0].name == 'my-user-key'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_api_key_by_id_protects_system_keys(
|
||||
self, api_key_store, async_session_maker
|
||||
):
|
||||
"""Test that system keys cannot be deleted by users."""
|
||||
user_id = '5594c7b6-f959-4b81-92e9-b09c206f5081'
|
||||
org_id = uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081')
|
||||
|
||||
# Create a system key
|
||||
async with async_session_maker() as session:
|
||||
system_key = ApiKey(
|
||||
key='sk-oh-system-key-12345678901234567890',
|
||||
user_id=user_id,
|
||||
org_id=org_id,
|
||||
name='__SYSTEM__:automation',
|
||||
)
|
||||
session.add(system_key)
|
||||
await session.commit()
|
||||
key_id = system_key.id
|
||||
|
||||
with patch('storage.api_key_store.a_session_maker', async_session_maker):
|
||||
# Attempt to delete without allow_system flag
|
||||
result = await api_key_store.delete_api_key_by_id(
|
||||
key_id, allow_system=False
|
||||
)
|
||||
|
||||
assert result is False
|
||||
|
||||
# Verify the key still exists
|
||||
async with async_session_maker() as session:
|
||||
result = await session.execute(select(ApiKey).filter(ApiKey.id == key_id))
|
||||
key_record = result.scalars().first()
|
||||
assert key_record is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_api_key_by_id_allows_system_with_flag(
|
||||
self, api_key_store, async_session_maker
|
||||
):
|
||||
"""Test that system keys can be deleted with allow_system=True."""
|
||||
user_id = '5594c7b6-f959-4b81-92e9-b09c206f5081'
|
||||
org_id = uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081')
|
||||
|
||||
# Create a system key
|
||||
async with async_session_maker() as session:
|
||||
system_key = ApiKey(
|
||||
key='sk-oh-system-key-12345678901234567890',
|
||||
user_id=user_id,
|
||||
org_id=org_id,
|
||||
name='__SYSTEM__:automation',
|
||||
)
|
||||
session.add(system_key)
|
||||
await session.commit()
|
||||
key_id = system_key.id
|
||||
|
||||
with patch('storage.api_key_store.a_session_maker', async_session_maker):
|
||||
# Delete with allow_system=True
|
||||
result = await api_key_store.delete_api_key_by_id(key_id, allow_system=True)
|
||||
|
||||
assert result is True
|
||||
|
||||
# Verify the key was deleted
|
||||
async with async_session_maker() as session:
|
||||
result = await session.execute(select(ApiKey).filter(ApiKey.id == key_id))
|
||||
key_record = result.scalars().first()
|
||||
assert key_record is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_api_key_by_id_allows_regular_keys(
|
||||
self, api_key_store, async_session_maker
|
||||
):
|
||||
"""Test that regular keys can be deleted normally."""
|
||||
user_id = '5594c7b6-f959-4b81-92e9-b09c206f5081'
|
||||
org_id = uuid.UUID('5594c7b6-f959-4b81-92e9-b09c206f5081')
|
||||
|
||||
# Create a regular key
|
||||
async with async_session_maker() as session:
|
||||
regular_key = ApiKey(
|
||||
key='sk-oh-regular-key-1234567890123456789',
|
||||
user_id=user_id,
|
||||
org_id=org_id,
|
||||
name='my-regular-key',
|
||||
)
|
||||
session.add(regular_key)
|
||||
await session.commit()
|
||||
key_id = regular_key.id
|
||||
|
||||
with patch('storage.api_key_store.a_session_maker', async_session_maker):
|
||||
# Delete without allow_system flag - should work for regular keys
|
||||
result = await api_key_store.delete_api_key_by_id(
|
||||
key_id, allow_system=False
|
||||
)
|
||||
|
||||
assert result is True
|
||||
|
||||
# Verify the key was deleted
|
||||
async with async_session_maker() as session:
|
||||
result = await session.execute(select(ApiKey).filter(ApiKey.id == key_id))
|
||||
key_record = result.scalars().first()
|
||||
assert key_record is None
|
||||
@@ -1,210 +0,0 @@
|
||||
"""Tests for OrgGitClaimStore with real in-memory SQLite database.
|
||||
|
||||
Covers CRUD operations and unique constraint enforcement.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from storage.org import Org
|
||||
from storage.org_git_claim_store import OrgGitClaimStore
|
||||
from storage.org_member import OrgMember
|
||||
from storage.role import Role
|
||||
from storage.user import User
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def seed_org_and_user(async_session_maker):
|
||||
"""Create a minimal org, role, user, and org_member for FK satisfaction."""
|
||||
org_id = uuid.uuid4()
|
||||
user_id = uuid.uuid4()
|
||||
role_id = 1
|
||||
|
||||
async with async_session_maker() as session:
|
||||
session.add(Role(id=role_id, name='owner', rank=10))
|
||||
session.add(Org(id=org_id, name='test-org'))
|
||||
session.add(User(id=user_id, current_org_id=org_id, role_id=role_id))
|
||||
session.add(
|
||||
OrgMember(
|
||||
org_id=org_id,
|
||||
user_id=user_id,
|
||||
role_id=role_id,
|
||||
status='active',
|
||||
llm_api_key='test-key',
|
||||
)
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
return org_id, user_id
|
||||
|
||||
|
||||
class TestOrgGitClaimStoreCreate:
|
||||
"""Tests for OrgGitClaimStore.create_claim."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_claim_persists_and_returns(
|
||||
self, async_session_maker, seed_org_and_user
|
||||
):
|
||||
"""A new claim is persisted with correct fields and returned."""
|
||||
org_id, user_id = seed_org_and_user
|
||||
|
||||
with patch('storage.org_git_claim_store.a_session_maker', async_session_maker):
|
||||
claim = await OrgGitClaimStore.create_claim(
|
||||
org_id=org_id,
|
||||
provider='github',
|
||||
git_organization='OpenHands',
|
||||
claimed_by=user_id,
|
||||
)
|
||||
|
||||
assert claim.org_id == org_id
|
||||
assert claim.provider == 'github'
|
||||
assert claim.git_organization == 'OpenHands'
|
||||
assert claim.claimed_by == user_id
|
||||
assert claim.claimed_at is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_duplicate_raises_integrity_error(
|
||||
self, async_session_maker, seed_org_and_user
|
||||
):
|
||||
"""Creating a duplicate (provider, git_organization) violates the unique constraint."""
|
||||
org_id, user_id = seed_org_and_user
|
||||
|
||||
with patch('storage.org_git_claim_store.a_session_maker', async_session_maker):
|
||||
await OrgGitClaimStore.create_claim(
|
||||
org_id=org_id,
|
||||
provider='github',
|
||||
git_organization='DuplicateOrg',
|
||||
claimed_by=user_id,
|
||||
)
|
||||
|
||||
with pytest.raises(IntegrityError):
|
||||
await OrgGitClaimStore.create_claim(
|
||||
org_id=org_id,
|
||||
provider='github',
|
||||
git_organization='DuplicateOrg',
|
||||
claimed_by=user_id,
|
||||
)
|
||||
|
||||
|
||||
class TestOrgGitClaimStoreLookup:
|
||||
"""Tests for OrgGitClaimStore lookup methods."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_claim_by_provider_and_git_org_found(
|
||||
self, async_session_maker, seed_org_and_user
|
||||
):
|
||||
"""Returns the claim when provider+git_organization exists."""
|
||||
org_id, user_id = seed_org_and_user
|
||||
|
||||
with patch('storage.org_git_claim_store.a_session_maker', async_session_maker):
|
||||
await OrgGitClaimStore.create_claim(
|
||||
org_id=org_id,
|
||||
provider='gitlab',
|
||||
git_organization='MyGroup',
|
||||
claimed_by=user_id,
|
||||
)
|
||||
|
||||
result = await OrgGitClaimStore.get_claim_by_provider_and_git_org(
|
||||
provider='gitlab', git_organization='MyGroup'
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.provider == 'gitlab'
|
||||
assert result.git_organization == 'MyGroup'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_claim_by_provider_and_git_org_not_found(
|
||||
self, async_session_maker
|
||||
):
|
||||
"""Returns None when no matching claim exists."""
|
||||
with patch('storage.org_git_claim_store.a_session_maker', async_session_maker):
|
||||
result = await OrgGitClaimStore.get_claim_by_provider_and_git_org(
|
||||
provider='github', git_organization='NonExistent'
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_claims_by_org_id(self, async_session_maker, seed_org_and_user):
|
||||
"""Returns all claims belonging to the given org."""
|
||||
org_id, user_id = seed_org_and_user
|
||||
|
||||
with patch('storage.org_git_claim_store.a_session_maker', async_session_maker):
|
||||
await OrgGitClaimStore.create_claim(
|
||||
org_id=org_id,
|
||||
provider='github',
|
||||
git_organization='Org1',
|
||||
claimed_by=user_id,
|
||||
)
|
||||
await OrgGitClaimStore.create_claim(
|
||||
org_id=org_id,
|
||||
provider='gitlab',
|
||||
git_organization='Org2',
|
||||
claimed_by=user_id,
|
||||
)
|
||||
|
||||
claims = await OrgGitClaimStore.get_claims_by_org_id(org_id)
|
||||
|
||||
assert len(claims) == 2
|
||||
|
||||
|
||||
class TestOrgGitClaimStoreDelete:
|
||||
"""Tests for OrgGitClaimStore.delete_claim."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_existing_claim_returns_true(
|
||||
self, async_session_maker, seed_org_and_user
|
||||
):
|
||||
"""Deleting an existing claim returns True and removes it from the DB."""
|
||||
org_id, user_id = seed_org_and_user
|
||||
|
||||
with patch('storage.org_git_claim_store.a_session_maker', async_session_maker):
|
||||
claim = await OrgGitClaimStore.create_claim(
|
||||
org_id=org_id,
|
||||
provider='github',
|
||||
git_organization='ToDelete',
|
||||
claimed_by=user_id,
|
||||
)
|
||||
|
||||
result = await OrgGitClaimStore.delete_claim(
|
||||
claim_id=claim.id, org_id=org_id
|
||||
)
|
||||
|
||||
assert result is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_nonexistent_claim_returns_false(
|
||||
self, async_session_maker, seed_org_and_user
|
||||
):
|
||||
"""Deleting a claim that doesn't exist returns False."""
|
||||
org_id, _ = seed_org_and_user
|
||||
|
||||
with patch('storage.org_git_claim_store.a_session_maker', async_session_maker):
|
||||
result = await OrgGitClaimStore.delete_claim(
|
||||
claim_id=uuid.uuid4(), org_id=org_id
|
||||
)
|
||||
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_claim_wrong_org_returns_false(
|
||||
self, async_session_maker, seed_org_and_user
|
||||
):
|
||||
"""Deleting a claim with a mismatched org_id returns False."""
|
||||
org_id, user_id = seed_org_and_user
|
||||
|
||||
with patch('storage.org_git_claim_store.a_session_maker', async_session_maker):
|
||||
claim = await OrgGitClaimStore.create_claim(
|
||||
org_id=org_id,
|
||||
provider='github',
|
||||
git_organization='WrongOrg',
|
||||
claimed_by=user_id,
|
||||
)
|
||||
|
||||
result = await OrgGitClaimStore.delete_claim(
|
||||
claim_id=claim.id, org_id=uuid.uuid4()
|
||||
)
|
||||
|
||||
assert result is False
|
||||
@@ -280,8 +280,6 @@ class TestSaasSQLAppConversationInfoService:
|
||||
stored_metadata.reasoning_tokens = 0
|
||||
stored_metadata.context_window = 0
|
||||
stored_metadata.per_turn_token = 0
|
||||
stored_metadata.public = None
|
||||
stored_metadata.tags = {}
|
||||
|
||||
saas_metadata = MagicMock(spec=StoredConversationMetadataSaas)
|
||||
saas_metadata.user_id = UUID('a1111111-1111-1111-1111-111111111111')
|
||||
@@ -992,414 +990,3 @@ class TestSandboxIdFilterSaas:
|
||||
sandbox_id__eq=shared_sandbox_id
|
||||
)
|
||||
assert count == 1
|
||||
|
||||
|
||||
class TestApiKeyOrgIdHandling:
|
||||
"""Test suite for API key organization ID handling in save_app_conversation_info.
|
||||
|
||||
These tests verify that when a conversation is created using API key authentication,
|
||||
the conversation is associated with the API key's bound organization, not the user's
|
||||
currently selected organization.
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_key_org_id_used_when_available(
|
||||
self,
|
||||
async_session_with_users: AsyncSession,
|
||||
):
|
||||
"""Test that API key's org_id is used when saving conversation via API key auth.
|
||||
|
||||
This tests the main bug fix: when a user creates an API key in Personal Workspace,
|
||||
then switches to OpenHands org in browser, and uses the API key to create a
|
||||
conversation, the conversation should be saved in Personal Workspace (API key's org),
|
||||
not OpenHands (user's current org).
|
||||
"""
|
||||
from dataclasses import dataclass
|
||||
|
||||
from storage.stored_conversation_metadata_saas import (
|
||||
StoredConversationMetadataSaas,
|
||||
)
|
||||
|
||||
# Create a mock UserAuth with API key org_id
|
||||
@dataclass
|
||||
class MockUserAuth:
|
||||
user_id: str
|
||||
api_key_org_id: UUID | None = None
|
||||
|
||||
async def get_user_id(self) -> str:
|
||||
return self.user_id
|
||||
|
||||
def get_api_key_org_id(self) -> UUID | None:
|
||||
return self.api_key_org_id
|
||||
|
||||
# Create a mock UserContext that wraps the MockUserAuth
|
||||
@dataclass
|
||||
class MockAuthUserContext:
|
||||
user_auth: MockUserAuth
|
||||
|
||||
async def get_user_id(self) -> str | None:
|
||||
return await self.user_auth.get_user_id()
|
||||
|
||||
# Simulate: User1's current org is ORG2, but API key is bound to ORG1
|
||||
# First, update user1's current_org_id to ORG2
|
||||
result = await async_session_with_users.execute(
|
||||
select(User).where(User.id == USER1_ID)
|
||||
)
|
||||
user_to_update = result.scalars().first()
|
||||
user_to_update.current_org_id = ORG2_ID # User is viewing ORG2
|
||||
await async_session_with_users.commit()
|
||||
async_session_with_users.expire_all()
|
||||
|
||||
# Create service with mock auth context where API key is bound to ORG1
|
||||
mock_user_auth = MockUserAuth(
|
||||
user_id=str(USER1_ID),
|
||||
api_key_org_id=ORG1_ID, # API key created in ORG1
|
||||
)
|
||||
mock_context = MockAuthUserContext(user_auth=mock_user_auth)
|
||||
|
||||
service = SaasSQLAppConversationInfoService(
|
||||
db_session=async_session_with_users,
|
||||
user_context=mock_context,
|
||||
)
|
||||
|
||||
# Create and save a conversation
|
||||
conv_id = uuid4()
|
||||
conv_info = AppConversationInfo(
|
||||
id=conv_id,
|
||||
created_by_user_id=str(USER1_ID),
|
||||
sandbox_id='sandbox_api_key_test',
|
||||
title='API Key Created Conversation',
|
||||
)
|
||||
await service.save_app_conversation_info(conv_info)
|
||||
|
||||
# Verify: SAAS metadata should have ORG1 (API key's org), not ORG2 (user's current org)
|
||||
saas_query = select(StoredConversationMetadataSaas).where(
|
||||
StoredConversationMetadataSaas.conversation_id == str(conv_id)
|
||||
)
|
||||
result = await async_session_with_users.execute(saas_query)
|
||||
saas_metadata = result.scalar_one_or_none()
|
||||
|
||||
assert saas_metadata is not None, 'SAAS metadata should be created'
|
||||
assert saas_metadata.user_id == USER1_ID
|
||||
assert (
|
||||
saas_metadata.org_id == ORG1_ID
|
||||
), 'Conversation should be in API key org (ORG1), not user current org (ORG2)'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_legacy_api_key_without_org_uses_user_current_org(
|
||||
self,
|
||||
async_session_with_users: AsyncSession,
|
||||
):
|
||||
"""Test that legacy API keys (without org_id) fall back to user's current org.
|
||||
|
||||
Legacy API keys created before the org_id feature was added will have
|
||||
api_key_org_id = None. In this case, we should fall back to the user's
|
||||
current_org_id.
|
||||
"""
|
||||
from dataclasses import dataclass
|
||||
|
||||
from storage.stored_conversation_metadata_saas import (
|
||||
StoredConversationMetadataSaas,
|
||||
)
|
||||
|
||||
# Create a mock UserAuth with API key but NO org_id (legacy key)
|
||||
@dataclass
|
||||
class MockUserAuth:
|
||||
user_id: str
|
||||
api_key_org_id: UUID | None = None
|
||||
|
||||
async def get_user_id(self) -> str:
|
||||
return self.user_id
|
||||
|
||||
def get_api_key_org_id(self) -> UUID | None:
|
||||
return self.api_key_org_id
|
||||
|
||||
@dataclass
|
||||
class MockAuthUserContext:
|
||||
user_auth: MockUserAuth
|
||||
|
||||
async def get_user_id(self) -> str | None:
|
||||
return await self.user_auth.get_user_id()
|
||||
|
||||
# Create service with mock auth context where API key has NO org_id
|
||||
mock_user_auth = MockUserAuth(
|
||||
user_id=str(USER1_ID),
|
||||
api_key_org_id=None, # Legacy key without org binding
|
||||
)
|
||||
mock_context = MockAuthUserContext(user_auth=mock_user_auth)
|
||||
|
||||
service = SaasSQLAppConversationInfoService(
|
||||
db_session=async_session_with_users,
|
||||
user_context=mock_context,
|
||||
)
|
||||
|
||||
# Create and save a conversation
|
||||
conv_id = uuid4()
|
||||
conv_info = AppConversationInfo(
|
||||
id=conv_id,
|
||||
created_by_user_id=str(USER1_ID),
|
||||
sandbox_id='sandbox_legacy_key_test',
|
||||
title='Legacy API Key Conversation',
|
||||
)
|
||||
await service.save_app_conversation_info(conv_info)
|
||||
|
||||
# Verify: SAAS metadata should use user's current org (ORG1) as fallback
|
||||
saas_query = select(StoredConversationMetadataSaas).where(
|
||||
StoredConversationMetadataSaas.conversation_id == str(conv_id)
|
||||
)
|
||||
result = await async_session_with_users.execute(saas_query)
|
||||
saas_metadata = result.scalar_one_or_none()
|
||||
|
||||
assert saas_metadata is not None, 'SAAS metadata should be created'
|
||||
assert saas_metadata.user_id == USER1_ID
|
||||
assert (
|
||||
saas_metadata.org_id == ORG1_ID
|
||||
), 'Legacy key should fall back to user current org (ORG1)'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cookie_auth_without_api_key_uses_user_current_org(
|
||||
self,
|
||||
async_session_with_users: AsyncSession,
|
||||
):
|
||||
"""Test that cookie auth (no API key) uses user's current org.
|
||||
|
||||
When authenticated via browser cookie (no API key), there's no
|
||||
get_api_key_org_id method, so we use user's current_org_id.
|
||||
This is already tested by other tests using SpecifyUserContext,
|
||||
but we explicitly test the case where user_context doesn't have user_auth.
|
||||
"""
|
||||
from storage.stored_conversation_metadata_saas import (
|
||||
StoredConversationMetadataSaas,
|
||||
)
|
||||
|
||||
# Use SpecifyUserContext which doesn't have user_auth attribute
|
||||
service = SaasSQLAppConversationInfoService(
|
||||
db_session=async_session_with_users,
|
||||
user_context=SpecifyUserContext(user_id=str(USER1_ID)),
|
||||
)
|
||||
|
||||
# Create and save a conversation
|
||||
conv_id = uuid4()
|
||||
conv_info = AppConversationInfo(
|
||||
id=conv_id,
|
||||
created_by_user_id=str(USER1_ID),
|
||||
sandbox_id='sandbox_cookie_auth_test',
|
||||
title='Cookie Auth Conversation',
|
||||
)
|
||||
await service.save_app_conversation_info(conv_info)
|
||||
|
||||
# Verify: SAAS metadata should use user's current org (ORG1)
|
||||
saas_query = select(StoredConversationMetadataSaas).where(
|
||||
StoredConversationMetadataSaas.conversation_id == str(conv_id)
|
||||
)
|
||||
result = await async_session_with_users.execute(saas_query)
|
||||
saas_metadata = result.scalar_one_or_none()
|
||||
|
||||
assert saas_metadata is not None, 'SAAS metadata should be created'
|
||||
assert saas_metadata.user_id == USER1_ID
|
||||
assert (
|
||||
saas_metadata.org_id == ORG1_ID
|
||||
), 'Cookie auth should use user current org (ORG1)'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_key_org_isolation_cross_org_visibility(
|
||||
self,
|
||||
async_session_with_users: AsyncSession,
|
||||
):
|
||||
"""Test end-to-end: conversation created via API key is visible in correct org.
|
||||
|
||||
Simulates the full bug scenario:
|
||||
1. Create conversation via API key (bound to ORG1)
|
||||
2. User switches to ORG2
|
||||
3. User should NOT see the conversation in ORG2
|
||||
4. User switches back to ORG1
|
||||
5. User should see the conversation in ORG1
|
||||
"""
|
||||
from dataclasses import dataclass
|
||||
|
||||
@dataclass
|
||||
class MockUserAuth:
|
||||
user_id: str
|
||||
api_key_org_id: UUID | None = None
|
||||
|
||||
async def get_user_id(self) -> str:
|
||||
return self.user_id
|
||||
|
||||
def get_api_key_org_id(self) -> UUID | None:
|
||||
return self.api_key_org_id
|
||||
|
||||
@dataclass
|
||||
class MockAuthUserContext:
|
||||
user_auth: MockUserAuth
|
||||
|
||||
async def get_user_id(self) -> str | None:
|
||||
return await self.user_auth.get_user_id()
|
||||
|
||||
# Step 1: Create conversation via API key bound to ORG1
|
||||
mock_user_auth = MockUserAuth(
|
||||
user_id=str(USER1_ID),
|
||||
api_key_org_id=ORG1_ID,
|
||||
)
|
||||
mock_context = MockAuthUserContext(user_auth=mock_user_auth)
|
||||
|
||||
api_key_service = SaasSQLAppConversationInfoService(
|
||||
db_session=async_session_with_users,
|
||||
user_context=mock_context,
|
||||
)
|
||||
|
||||
conv_id = uuid4()
|
||||
conv_info = AppConversationInfo(
|
||||
id=conv_id,
|
||||
created_by_user_id=str(USER1_ID),
|
||||
sandbox_id='sandbox_e2e_api_key',
|
||||
title='E2E API Key Conversation',
|
||||
)
|
||||
await api_key_service.save_app_conversation_info(conv_info)
|
||||
|
||||
# Step 2: Switch user to ORG2 in browser session
|
||||
result = await async_session_with_users.execute(
|
||||
select(User).where(User.id == USER1_ID)
|
||||
)
|
||||
user_to_update = result.scalars().first()
|
||||
user_to_update.current_org_id = ORG2_ID
|
||||
await async_session_with_users.commit()
|
||||
async_session_with_users.expire_all()
|
||||
|
||||
# Step 3: User in ORG2 should NOT see the conversation
|
||||
user_service_org2 = SaasSQLAppConversationInfoService(
|
||||
db_session=async_session_with_users,
|
||||
user_context=SpecifyUserContext(user_id=str(USER1_ID)),
|
||||
)
|
||||
page_org2 = await user_service_org2.search_app_conversation_info()
|
||||
assert (
|
||||
len(page_org2.items) == 0
|
||||
), 'User in ORG2 should not see conversation created via API key in ORG1'
|
||||
|
||||
# Also verify get_app_conversation_info returns None
|
||||
conv_from_org2 = await user_service_org2.get_app_conversation_info(conv_id)
|
||||
assert (
|
||||
conv_from_org2 is None
|
||||
), 'User in ORG2 should not access conversation from ORG1'
|
||||
|
||||
# Step 4: Switch user back to ORG1
|
||||
result = await async_session_with_users.execute(
|
||||
select(User).where(User.id == USER1_ID)
|
||||
)
|
||||
user_to_update = result.scalars().first()
|
||||
user_to_update.current_org_id = ORG1_ID
|
||||
await async_session_with_users.commit()
|
||||
async_session_with_users.expire_all()
|
||||
|
||||
# Step 5: User in ORG1 should see the conversation
|
||||
user_service_org1 = SaasSQLAppConversationInfoService(
|
||||
db_session=async_session_with_users,
|
||||
user_context=SpecifyUserContext(user_id=str(USER1_ID)),
|
||||
)
|
||||
page_org1 = await user_service_org1.search_app_conversation_info()
|
||||
assert (
|
||||
len(page_org1.items) == 1
|
||||
), 'User in ORG1 should see conversation created via API key in ORG1'
|
||||
assert page_org1.items[0].id == conv_id
|
||||
assert page_org1.items[0].title == 'E2E API Key Conversation'
|
||||
|
||||
# Also verify get_app_conversation_info works
|
||||
conv_from_org1 = await user_service_org1.get_app_conversation_info(conv_id)
|
||||
assert conv_from_org1 is not None
|
||||
assert conv_from_org1.id == conv_id
|
||||
|
||||
|
||||
class TestResolverOrgIdRouting:
|
||||
"""Test that resolver_org_id on user_context overrides the default org_id."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_uses_resolver_org_id_when_set_on_context(
|
||||
self,
|
||||
async_session_with_users: AsyncSession,
|
||||
):
|
||||
"""When user_context has resolver_org_id, conversation is saved in that org."""
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from storage.stored_conversation_metadata_saas import (
|
||||
StoredConversationMetadataSaas,
|
||||
)
|
||||
|
||||
from enterprise.integrations.resolver_context import ResolverUserContext
|
||||
|
||||
# Arrange: user1 is in ORG1, but resolver routes to ORG2
|
||||
# Use spec to prevent MagicMock from auto-creating undefined attributes
|
||||
mock_context = MagicMock(spec=ResolverUserContext)
|
||||
mock_context.get_user_id = AsyncMock(return_value=str(USER1_ID))
|
||||
mock_context.resolver_org_id = ORG2_ID
|
||||
|
||||
service = SaasSQLAppConversationInfoService(
|
||||
db_session=async_session_with_users,
|
||||
user_context=mock_context,
|
||||
)
|
||||
|
||||
conv_id = uuid4()
|
||||
conv_info = AppConversationInfo(
|
||||
id=conv_id,
|
||||
created_by_user_id=str(USER1_ID),
|
||||
sandbox_id='sandbox_resolver',
|
||||
title='Resolver Routed Conversation',
|
||||
)
|
||||
|
||||
# Act
|
||||
await service.save_app_conversation_info(conv_info)
|
||||
|
||||
# Assert: conversation is stored in ORG2, not user's default ORG1
|
||||
saas_query = select(StoredConversationMetadataSaas).where(
|
||||
StoredConversationMetadataSaas.conversation_id == str(conv_id)
|
||||
)
|
||||
result = await async_session_with_users.execute(saas_query)
|
||||
saas_metadata = result.scalar_one_or_none()
|
||||
|
||||
assert saas_metadata is not None
|
||||
assert saas_metadata.org_id == ORG2_ID
|
||||
assert saas_metadata.user_id == USER1_ID
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_uses_default_org_when_resolver_org_id_is_none(
|
||||
self,
|
||||
async_session_with_users: AsyncSession,
|
||||
):
|
||||
"""When resolver_org_id is None, conversation uses user's default org."""
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from storage.stored_conversation_metadata_saas import (
|
||||
StoredConversationMetadataSaas,
|
||||
)
|
||||
|
||||
from enterprise.integrations.resolver_context import ResolverUserContext
|
||||
|
||||
# Arrange: user1 in ORG1 with no resolver override
|
||||
# Use spec to prevent MagicMock from auto-creating undefined attributes
|
||||
mock_context = MagicMock(spec=ResolverUserContext)
|
||||
mock_context.get_user_id = AsyncMock(return_value=str(USER1_ID))
|
||||
mock_context.resolver_org_id = None
|
||||
|
||||
service = SaasSQLAppConversationInfoService(
|
||||
db_session=async_session_with_users,
|
||||
user_context=mock_context,
|
||||
)
|
||||
|
||||
conv_id = uuid4()
|
||||
conv_info = AppConversationInfo(
|
||||
id=conv_id,
|
||||
created_by_user_id=str(USER1_ID),
|
||||
sandbox_id='sandbox_default',
|
||||
title='Default Org Conversation',
|
||||
)
|
||||
|
||||
# Act
|
||||
await service.save_app_conversation_info(conv_info)
|
||||
|
||||
# Assert: conversation stored in user's default ORG1
|
||||
saas_query = select(StoredConversationMetadataSaas).where(
|
||||
StoredConversationMetadataSaas.conversation_id == str(conv_id)
|
||||
)
|
||||
result = await async_session_with_users.execute(saas_query)
|
||||
saas_metadata = result.scalar_one_or_none()
|
||||
|
||||
assert saas_metadata is not None
|
||||
assert saas_metadata.org_id == ORG1_ID
|
||||
|
||||
@@ -5,7 +5,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
from storage.api_key import ApiKey
|
||||
from storage.api_key_store import ApiKeyStore, ApiKeyValidationResult
|
||||
from storage.api_key_store import ApiKeyStore
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -110,8 +110,8 @@ async def test_create_api_key(
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_api_key_valid(api_key_store, async_session_maker):
|
||||
"""Test validating a valid API key returns user_id and org_id."""
|
||||
# Arrange
|
||||
"""Test validating a valid API key."""
|
||||
# Setup - create an API key in the database
|
||||
user_id = str(uuid.uuid4())
|
||||
org_id = uuid.uuid4()
|
||||
api_key_value = 'test-api-key'
|
||||
@@ -126,19 +126,13 @@ async def test_validate_api_key_valid(api_key_store, async_session_maker):
|
||||
)
|
||||
session.add(key_record)
|
||||
await session.commit()
|
||||
key_id = key_record.id
|
||||
|
||||
# Act
|
||||
# Execute - patch a_session_maker to use test's async session maker
|
||||
with patch('storage.api_key_store.a_session_maker', async_session_maker):
|
||||
result = await api_key_store.validate_api_key(api_key_value)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, ApiKeyValidationResult)
|
||||
assert result is not None
|
||||
assert result.user_id == user_id
|
||||
assert result.org_id == org_id
|
||||
assert result.key_id == key_id
|
||||
assert result.key_name == 'Test Key'
|
||||
# Verify
|
||||
assert result == user_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -203,7 +197,7 @@ async def test_validate_api_key_valid_timezone_naive(
|
||||
api_key_store, async_session_maker
|
||||
):
|
||||
"""Test validating a valid API key with timezone-naive datetime from database."""
|
||||
# Arrange
|
||||
# Setup - create a valid API key with timezone-naive datetime (future date)
|
||||
user_id = str(uuid.uuid4())
|
||||
org_id = uuid.uuid4()
|
||||
api_key_value = 'test-valid-naive-key'
|
||||
@@ -220,44 +214,12 @@ async def test_validate_api_key_valid_timezone_naive(
|
||||
session.add(key_record)
|
||||
await session.commit()
|
||||
|
||||
# Act
|
||||
# Execute - patch a_session_maker to use test's async session maker
|
||||
with patch('storage.api_key_store.a_session_maker', async_session_maker):
|
||||
result = await api_key_store.validate_api_key(api_key_value)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, ApiKeyValidationResult)
|
||||
assert result.user_id == user_id
|
||||
assert result.org_id == org_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validate_api_key_legacy_without_org_id(
|
||||
api_key_store, async_session_maker
|
||||
):
|
||||
"""Test validating a legacy API key without org_id returns None for org_id."""
|
||||
# Arrange
|
||||
user_id = str(uuid.uuid4())
|
||||
api_key_value = 'test-legacy-key-no-org'
|
||||
|
||||
async with async_session_maker() as session:
|
||||
key_record = ApiKey(
|
||||
key=api_key_value,
|
||||
user_id=user_id,
|
||||
org_id=None, # Legacy key without org binding
|
||||
name='Legacy Key',
|
||||
)
|
||||
session.add(key_record)
|
||||
await session.commit()
|
||||
|
||||
# Act
|
||||
with patch('storage.api_key_store.a_session_maker', async_session_maker):
|
||||
result = await api_key_store.validate_api_key(api_key_value)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, ApiKeyValidationResult)
|
||||
assert result is not None
|
||||
assert result.user_id == user_id
|
||||
assert result.org_id is None
|
||||
# Verify
|
||||
assert result == user_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user