mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-04-29 03:00:45 -04:00
Compare commits
245 Commits
fix-basic-
...
feat/chat-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
087669601a | ||
|
|
8b8ed5be96 | ||
|
|
c1328f512d | ||
|
|
e2805dea75 | ||
|
|
127e611706 | ||
|
|
a176a135da | ||
|
|
ab78d7d6e8 | ||
|
|
4eb6e4da09 | ||
|
|
7e66304746 | ||
|
|
a8b12e8eb8 | ||
|
|
53bb82fe2e | ||
|
|
db40eb1e94 | ||
|
|
debbaae385 | ||
|
|
5e5950b091 | ||
|
|
c7ff560465 | ||
|
|
3432bbbb88 | ||
|
|
fc24be2627 | ||
|
|
bc72b38d6e | ||
|
|
145f1266e6 | ||
|
|
e12dd924ce | ||
|
|
598b381e3d | ||
|
|
698cfc2520 | ||
|
|
8356170193 | ||
|
|
fe2e50fc7d | ||
|
|
ef840b046a | ||
|
|
c8fe39b176 | ||
|
|
8c46df6b59 | ||
|
|
b37adbc1e6 | ||
|
|
3ec999e88a | ||
|
|
d1c2185d99 | ||
|
|
ede203add3 | ||
|
|
b0cdd0358f | ||
|
|
6186685ebc | ||
|
|
2d7362bf26 | ||
|
|
1f1fb5a954 | ||
|
|
41d8bd28e9 | ||
|
|
6c394cc415 | ||
|
|
4c380e5a58 | ||
|
|
ded0363e36 | ||
|
|
d8444ef626 | ||
|
|
64e96b7c3c | ||
|
|
dcef5ae1f1 | ||
|
|
cfbf29f6e8 | ||
|
|
59b369047f | ||
|
|
07468e39f7 | ||
|
|
0b0bfdff05 | ||
|
|
42b0a89366 | ||
|
|
e78d7de0c0 | ||
|
|
6751bba939 | ||
|
|
039e966dad | ||
|
|
a1f73bb4c6 | ||
|
|
bf769d1744 | ||
|
|
15e9435b35 | ||
|
|
3e15b849a3 | ||
|
|
c32934ed2f | ||
|
|
518fb2ee24 | ||
|
|
eeac9f14a3 | ||
|
|
039e208167 | ||
|
|
6f8bf24226 | ||
|
|
6e9e906946 | ||
|
|
30245dedef | ||
|
|
3bf019b045 | ||
|
|
ab02c73c7c | ||
|
|
b8db9ecd53 | ||
|
|
b86b2f16af | ||
|
|
a11435b061 | ||
|
|
f01c8dd955 | ||
|
|
baae3780e5 | ||
|
|
1fb28604e6 | ||
|
|
8dac1095d7 | ||
|
|
222e8bd03d | ||
|
|
0ae9128ed7 | ||
|
|
4fc5351ed7 | ||
|
|
a1271dc129 | ||
|
|
45b970c0dd | ||
|
|
4688741324 | ||
|
|
79a0cee7d9 | ||
|
|
d19ba0d166 | ||
|
|
63654c4643 | ||
|
|
2f11f6a39a | ||
|
|
5cad59a661 | ||
|
|
6dff07ea35 | ||
|
|
117ea0466d | ||
|
|
6822169594 | ||
|
|
35024aeffe | ||
|
|
a051f7d6f6 | ||
|
|
4fe3da498a | ||
|
|
b890e53a6e | ||
|
|
8aa730105a | ||
|
|
e7934ea6e5 | ||
|
|
a927b9dc73 | ||
|
|
0b9fd442bd | ||
|
|
501bf64312 | ||
|
|
6f1a7ddadd | ||
|
|
f3026583d7 | ||
|
|
4a3a42c858 | ||
|
|
2d057bb7b4 | ||
|
|
a7a4eb2664 | ||
|
|
0c7ce4ad48 | ||
|
|
4dab34e7b0 | ||
|
|
f8bbd352a9 | ||
|
|
17347a95f8 | ||
|
|
01ef87aaaa | ||
|
|
8059c18b57 | ||
|
|
c82ee4c7db | ||
|
|
7fdb423f99 | ||
|
|
530065dfa7 | ||
|
|
a4cd2d81a5 | ||
|
|
003b430e96 | ||
|
|
d63565186e | ||
|
|
5f42d03ec5 | ||
|
|
62241e2e00 | ||
|
|
f5197bd76a | ||
|
|
e1408f7b15 | ||
|
|
d6b8d80026 | ||
|
|
1e6a92b454 | ||
|
|
b4a3e5db2f | ||
|
|
f9d553d0bb | ||
|
|
f6f6c1ab25 | ||
|
|
c511a89426 | ||
|
|
1f82ff04d9 | ||
|
|
eec17311c7 | ||
|
|
c34fdf4b37 | ||
|
|
25076ee44c | ||
|
|
baaec8473a | ||
|
|
402fa47422 | ||
|
|
8dde385843 | ||
|
|
a905e35531 | ||
|
|
1f185173b7 | ||
|
|
ddc7a78723 | ||
|
|
a29ed4d926 | ||
|
|
b8ab4bb44e | ||
|
|
ddd544f8d6 | ||
|
|
3804b66e32 | ||
|
|
b97adf392a | ||
|
|
dcb584913a | ||
|
|
d2fd54a083 | ||
|
|
112d863287 | ||
|
|
c8680caec3 | ||
|
|
d4b9fb1d03 | ||
|
|
409df1287d | ||
|
|
a92bfe6cc0 | ||
|
|
f93e3254d3 | ||
|
|
0476d57451 | ||
|
|
a4cd21e155 | ||
|
|
7f3af371d1 | ||
|
|
1421794c1b | ||
|
|
2fc689457c | ||
|
|
3161b365a8 | ||
|
|
18ab56ef4e | ||
|
|
a9c0df778c | ||
|
|
51b989b5f8 | ||
|
|
dc039d81d6 | ||
|
|
8e4559b14a | ||
|
|
b84f352b63 | ||
|
|
a0dba6124a | ||
|
|
951739f3eb | ||
|
|
0f1ad46a47 | ||
|
|
5367bef43a | ||
|
|
3afeccfe7f | ||
|
|
0677c035ff | ||
|
|
68165b52d9 | ||
|
|
dcc8217317 | ||
|
|
d1410949ff | ||
|
|
a6c0d80fe1 | ||
|
|
0efb1db85d | ||
|
|
8e0f74c92c | ||
|
|
6e1ba3d836 | ||
|
|
0ec97893d1 | ||
|
|
ddb809bc43 | ||
|
|
872f2b87f2 | ||
|
|
ee86005a3a | ||
|
|
d4aa30580b | ||
|
|
2f0e879129 | ||
|
|
3bc2ef954e | ||
|
|
32ab2a24c6 | ||
|
|
a6e148d1e6 | ||
|
|
3fc977eddd | ||
|
|
89a6890269 | ||
|
|
8927ac2230 | ||
|
|
f3429e33ca | ||
|
|
7cd219792b | ||
|
|
2aabe2ed8c | ||
|
|
731a9a813e | ||
|
|
123e556fed | ||
|
|
6676cae249 | ||
|
|
fede37b496 | ||
|
|
3bcd6f18df | ||
|
|
0da18440c2 | ||
|
|
ac76e10048 | ||
|
|
b98bae8b5f | ||
|
|
516721d1ee | ||
|
|
4d6f66ca28 | ||
|
|
b18568da0b | ||
|
|
83dd3c169c | ||
|
|
35bddb14f1 | ||
|
|
e8425218e2 | ||
|
|
0a879fa781 | ||
|
|
41e142bbab | ||
|
|
b06b9eedac | ||
|
|
a9afafa991 | ||
|
|
663ace4b39 | ||
|
|
2d085a6e0a | ||
|
|
8b7112abe8 | ||
|
|
34547ba947 | ||
|
|
5f958ab60d | ||
|
|
d7656bf1c9 | ||
|
|
2bc107564c | ||
|
|
85eb1e1504 | ||
|
|
cd235cc8c7 | ||
|
|
40f52dfabc | ||
|
|
bab7bf85e8 | ||
|
|
c856537f65 | ||
|
|
736f5b2255 | ||
|
|
c1d9d11772 | ||
|
|
85244499fe | ||
|
|
c55084e223 | ||
|
|
e3bb75deb4 | ||
|
|
1948200762 | ||
|
|
affe0af361 | ||
|
|
f20c956196 | ||
|
|
4a089a3a0d | ||
|
|
aa0b2d0b74 | ||
|
|
bef9b80b9d | ||
|
|
c4a90b1f89 | ||
|
|
0d13c57d9f | ||
|
|
b3422f1275 | ||
|
|
f139a9970b | ||
|
|
54d156122c | ||
|
|
ac072bf686 | ||
|
|
a53812c029 | ||
|
|
1d1c0925b5 | ||
|
|
872f41e3c0 | ||
|
|
d43ff82534 | ||
|
|
8cd8c011b2 | ||
|
|
5c68b10983 | ||
|
|
a97fad1976 | ||
|
|
4c3542a91c | ||
|
|
f460057f58 | ||
|
|
4fa2ad0f47 | ||
|
|
dd8be12809 | ||
|
|
89475095d9 | ||
|
|
05d5f8848a | ||
|
|
ee2885eb0b | ||
|
|
545257f870 |
37
.agents/skills/upcoming-release/SKILL.md
Normal file
37
.agents/skills/upcoming-release/SKILL.md
Normal file
@@ -0,0 +1,37 @@
|
||||
---
|
||||
name: upcoming-release
|
||||
description: This skill should be used when the user asks to "generate release notes", "list upcoming release PRs", "summarize upcoming release", "/upcoming-release", or needs to know what changes are part of an upcoming release.
|
||||
---
|
||||
|
||||
# Upcoming Release Summary
|
||||
|
||||
Generate a concise summary of PRs included in the upcoming release.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
Two commit SHAs are required:
|
||||
- **First SHA**: The older commit (current release)
|
||||
- **Second SHA**: The newer commit (what's being released)
|
||||
|
||||
If the user does not provide both SHAs, ask for them before proceeding.
|
||||
|
||||
## Workflow
|
||||
|
||||
1. Run the script from the repository root with the `--json` flag:
|
||||
```bash
|
||||
.github/scripts/find_prs_between_commits.py <older-sha> <newer-sha> --json
|
||||
```
|
||||
|
||||
2. Filter out PRs that are:
|
||||
- Chores
|
||||
- Dependency updates
|
||||
- Adding logs
|
||||
- Refactors
|
||||
|
||||
3. Categorize the remaining PRs:
|
||||
- **Features** - New functionality
|
||||
- **Bug fixes** - Corrections to existing behavior
|
||||
- **Security/CVE fixes** - Security-related changes
|
||||
- **Other** - Everything else
|
||||
|
||||
4. Format the output with PRs listed under their category, including the PR number and a brief description.
|
||||
123
.agents/skills/update-sdk/SKILL.md
Normal file
123
.agents/skills/update-sdk/SKILL.md
Normal file
@@ -0,0 +1,123 @@
|
||||
---
|
||||
name: update-sdk
|
||||
description: This skill should be used when the user asks to "update SDK", "bump SDK version", "pin SDK to a commit", "test unreleased SDK", "update agent-server image", "bump the version", "prepare a release", "what files change for a release", or needs to know how SDK packages are managed in the OpenHands repository. For detailed reference material, see references/docker-image-locations.md and references/sdk-pinning-examples.md in this skill directory.
|
||||
---
|
||||
|
||||
# Update SDK
|
||||
|
||||
Bump SDK packages (`openhands-sdk`, `openhands-agent-server`, `openhands-tools`), pin them to unreleased commits for testing, and cut an OpenHands release.
|
||||
|
||||
## Quick Summary — How Many Files Change?
|
||||
|
||||
| Activity | Manual edits | Auto-regenerated | Total |
|
||||
|----------|:------------:|:----------------:|:-----:|
|
||||
| **SDK bump** (released PyPI version) | 2 | 3 | **5** |
|
||||
| **SDK pin** (unreleased git commit) | 3 | 3 | **6** |
|
||||
| **Release commit** (version bump) | 3 | 0 | **3** |
|
||||
|
||||
The 3 auto-regenerated files are always: `poetry.lock`, `uv.lock`, `enterprise/poetry.lock`.
|
||||
|
||||
## SDK Package Bump — 2 Files + 3 Lock Files
|
||||
|
||||
Land as a separate PR before the release. Examples: `929dcc3` (SDK 1.11.5), `cd235cc` (SDK 1.11.4).
|
||||
|
||||
| File | What to change |
|
||||
|------|----------------|
|
||||
| `pyproject.toml` | `openhands-sdk`, `openhands-agent-server`, `openhands-tools` in **two** sections: the `dependencies` array (PEP 508) **and** `[tool.poetry.dependencies]` |
|
||||
| `openhands/app_server/sandbox/sandbox_spec_service.py` | `AGENT_SERVER_IMAGE` constant — set to `ghcr.io/openhands/agent-server:<version>-python` |
|
||||
|
||||
Then regenerate lock files:
|
||||
```bash
|
||||
poetry lock && uv lock && cd enterprise && poetry lock && cd ..
|
||||
```
|
||||
|
||||
## Docker Image Locations — All Hardcoded References
|
||||
|
||||
For the complete inventory of every file containing a hardcoded Docker image tag or repository, see `references/docker-image-locations.md`. Key files that must stay in sync during an SDK bump:
|
||||
|
||||
| File | Image reference | Updated during SDK bump? |
|
||||
|------|----------------|:------------------------:|
|
||||
| `openhands/app_server/sandbox/sandbox_spec_service.py` | `AGENT_SERVER_IMAGE = 'ghcr.io/openhands/agent-server:<tag>-python'` | ✅ Yes |
|
||||
| `docker-compose.yml` | `AGENT_SERVER_IMAGE_TAG` default | ✅ Should be |
|
||||
| `containers/dev/compose.yml` | `AGENT_SERVER_IMAGE_REPOSITORY` + `_TAG` defaults | ✅ Should be |
|
||||
|
||||
> **CI enforcement:** `.github/workflows/check-version-consistency.yml` validates version consistency and compose file image references on every PR and push to main.
|
||||
|
||||
### ⚠️ Docker Image Tag Gotcha (merge-commit SHA)
|
||||
|
||||
The SDK CI in `software-agent-sdk` repo tags Docker images with the **GitHub Actions merge-commit SHA**, NOT the PR head-commit SHA. When pinning to an SDK PR branch:
|
||||
|
||||
1. Check the SDK PR description for the actual image tag (look for the `AGENT_SERVER_IMAGES` section)
|
||||
2. Or query the CI logs: the "Consolidate Build Information" job prints `"short_sha": "<tag>"`
|
||||
3. The merge-commit SHA differs from the head SHA shown in the PR
|
||||
|
||||
For released SDK versions, images use a version tag (e.g., `1.12.0-python`) — no merge-commit ambiguity.
|
||||
|
||||
## Cutting a Release — 3 Files
|
||||
|
||||
A release commit updates the version string across 3 files. Gold-standard examples: 1.3.0 (`d063c8c`), 1.4.0 (`495f48b`).
|
||||
|
||||
| File | What to change |
|
||||
|------|----------------|
|
||||
| `pyproject.toml` | `version = "X.Y.Z"` under `[tool.poetry]` |
|
||||
| `frontend/package.json` | `"version": "X.Y.Z"` |
|
||||
| `frontend/package-lock.json` | `"version": "X.Y.Z"` in **two** places (root object and `packages[""]`) |
|
||||
|
||||
> **Note:** `openhands/version.py` reads the version from `pyproject.toml` at runtime — no manual edit needed there.
|
||||
|
||||
### Compose Files (2 files)
|
||||
|
||||
Both compose files should use `ghcr.io/openhands/agent-server` with the current SDK version tag.
|
||||
|
||||
| File | What to verify |
|
||||
|------|----------------|
|
||||
| `docker-compose.yml` | `AGENT_SERVER_IMAGE_REPOSITORY` defaults to agent-server, `AGENT_SERVER_IMAGE_TAG` is current |
|
||||
| `containers/dev/compose.yml` | Same — must use agent-server, not runtime |
|
||||
|
||||
### Release Workflow
|
||||
|
||||
#### Step 1: Verify the SDK bump has landed
|
||||
|
||||
```bash
|
||||
grep -n "openhands-sdk\|openhands-agent-server\|openhands-tools" pyproject.toml
|
||||
grep -n "AGENT_SERVER_IMAGE" openhands/app_server/sandbox/sandbox_spec_service.py
|
||||
grep "AGENT_SERVER_IMAGE_TAG" docker-compose.yml containers/dev/compose.yml
|
||||
```
|
||||
|
||||
#### Step 2: Bump version numbers
|
||||
|
||||
```bash
|
||||
# Edit pyproject.toml, frontend/package.json, frontend/package-lock.json
|
||||
git add pyproject.toml frontend/package.json frontend/package-lock.json
|
||||
git commit -m "Release X.Y.Z"
|
||||
git tag X.Y.Z
|
||||
```
|
||||
|
||||
Create a `saas-rel-X.Y.Z` branch from the tagged commit for the SaaS deployment pipeline.
|
||||
|
||||
#### Step 3: CI builds Docker images automatically
|
||||
|
||||
The `ghcr-build.yml` workflow triggers on tag pushes and produces:
|
||||
- `ghcr.io/openhands/openhands:X.Y.Z`, `X.Y`, `X`, `latest`
|
||||
- `ghcr.io/openhands/runtime:X.Y.Z-nikolaik`, `X.Y-nikolaik`
|
||||
|
||||
The tagging logic lives in `containers/build.sh` — when `GITHUB_REF_NAME` matches a semver pattern (`^[0-9]+\.[0-9]+\.[0-9]+$`), it auto-generates major, major.minor, and `latest` tags.
|
||||
|
||||
## Development: Pin SDK to an Unreleased Commit
|
||||
|
||||
For detailed examples of all pinning formats (commit, branch, uv-only), see `references/sdk-pinning-examples.md`.
|
||||
|
||||
### Files to change (3 manual + 3 lock files)
|
||||
|
||||
| File | What to change |
|
||||
|------|----------------|
|
||||
| `pyproject.toml` | Pin all 3 SDK packages in **both** `dependencies` and `[tool.poetry.dependencies]` |
|
||||
| `openhands/app_server/sandbox/sandbox_spec_service.py` | `AGENT_SERVER_IMAGE` — use the merge-commit SHA tag, NOT the head-commit SHA |
|
||||
| `docker-compose.yml` | `AGENT_SERVER_IMAGE_TAG` default (for local development) |
|
||||
| `poetry.lock` | Auto-regenerated via `poetry lock` |
|
||||
| `uv.lock` | Auto-regenerated via `uv lock` |
|
||||
| `enterprise/poetry.lock` | Auto-regenerated via `cd enterprise && poetry lock` |
|
||||
|
||||
### CI guard
|
||||
|
||||
The `check-package-versions.yml` workflow blocks merging to `main` if `[tool.poetry.dependencies]` contains any `rev` fields. This ensures unreleased SDK pins do not accidentally ship in a release.
|
||||
@@ -0,0 +1,84 @@
|
||||
# Docker Image Locations — Complete Inventory
|
||||
|
||||
Every file in the OpenHands repository containing a hardcoded Docker image tag, repository, or version-pinned image reference. Organized by update cadence.
|
||||
|
||||
## Updated During SDK Bump (must change)
|
||||
|
||||
These files contain image tags that **must** be updated whenever the SDK version or pinned commit changes.
|
||||
|
||||
### `openhands/app_server/sandbox/sandbox_spec_service.py`
|
||||
- **Line:** `AGENT_SERVER_IMAGE = 'ghcr.io/openhands/agent-server:<tag>-python'`
|
||||
- **Format:** `<sdk-version>-python` for releases (e.g., `1.12.0-python`), `<7-char-commit-hash>-python` for dev pins
|
||||
- **Source of truth** for which agent-server image the app server pulls at runtime
|
||||
- **⚠️ Gotcha:** When pinning to an SDK PR, the image tag is the **merge-commit SHA** from GitHub Actions, not the PR head-commit SHA. Check the SDK PR description or CI logs for the correct tag.
|
||||
|
||||
### `docker-compose.yml`
|
||||
- **Lines:**
|
||||
```yaml
|
||||
- AGENT_SERVER_IMAGE_REPOSITORY=${AGENT_SERVER_IMAGE_REPOSITORY:-ghcr.io/openhands/agent-server}
|
||||
- AGENT_SERVER_IMAGE_TAG=${AGENT_SERVER_IMAGE_TAG:-<tag>-python}
|
||||
```
|
||||
- Used by `docker compose up` for local development
|
||||
|
||||
### `containers/dev/compose.yml`
|
||||
- **Lines:**
|
||||
```yaml
|
||||
- AGENT_SERVER_IMAGE_REPOSITORY=${AGENT_SERVER_IMAGE_REPOSITORY:-ghcr.io/openhands/agent-server}
|
||||
- AGENT_SERVER_IMAGE_TAG=${AGENT_SERVER_IMAGE_TAG:-<tag>-python}
|
||||
```
|
||||
- Used by the dev container setup
|
||||
- **Known issue:** On main as of 1.4.0, this file still points to `ghcr.io/openhands/runtime` instead of `agent-server`, and the tag is `1.2-nikolaik` (stale from the V0 era). The `check-version-consistency.yml` CI workflow catches this.
|
||||
|
||||
## Updated During Release Commit (version string only)
|
||||
|
||||
### `pyproject.toml`
|
||||
- **Line:** `version = "X.Y.Z"` under `[tool.poetry]`
|
||||
- The Python version is derived from this at runtime via `openhands/version.py`
|
||||
|
||||
### `frontend/package.json`
|
||||
- **Line:** `"version": "X.Y.Z"`
|
||||
|
||||
### `frontend/package-lock.json`
|
||||
- **Two places:** root `"version": "X.Y.Z"` and `packages[""].version`
|
||||
|
||||
## Dynamic References (auto-derived, no manual update)
|
||||
|
||||
### `openhands/version.py`
|
||||
- Reads version from `pyproject.toml` at runtime → `openhands.__version__`
|
||||
|
||||
### `openhands/resolver/issue_resolver.py`
|
||||
- Builds `ghcr.io/openhands/runtime:{openhands.__version__}-nikolaik` dynamically
|
||||
|
||||
### `openhands/runtime/utils/runtime_build.py`
|
||||
- Base repo URL `ghcr.io/openhands/runtime` is a constant; version comes from elsewhere
|
||||
|
||||
### `.github/scripts/update_pr_description.sh`
|
||||
- Uses `${SHORT_SHA}` variable at CI runtime, not hardcoded
|
||||
|
||||
### `enterprise/Dockerfile`
|
||||
- `ARG BASE="ghcr.io/openhands/openhands"` — base image, version supplied at build time
|
||||
|
||||
## V0 Legacy Files (separate update cadence)
|
||||
|
||||
These reference the V0 runtime image (`ghcr.io/openhands/runtime:X.Y-nikolaik`) for local Docker/Kubernetes paths. They are **not** updated as part of a V1 release but may be updated independently.
|
||||
|
||||
### `Development.md`
|
||||
- `export SANDBOX_RUNTIME_CONTAINER_IMAGE=ghcr.io/openhands/runtime:X.Y-nikolaik`
|
||||
|
||||
### `openhands/runtime/impl/kubernetes/README.md`
|
||||
- `runtime_container_image = "docker.openhands.dev/openhands/runtime:X.Y-nikolaik"`
|
||||
|
||||
### `enterprise/enterprise_local/README.md`
|
||||
- Uses `ghcr.io/openhands/runtime:main-nikolaik` (points to `main`, not versioned)
|
||||
|
||||
### `third_party/runtime/impl/daytona/README.md`
|
||||
- Uses `${OPENHANDS_VERSION}` variable, not hardcoded
|
||||
|
||||
## Image Registries
|
||||
|
||||
| Registry | Usage |
|
||||
|----------|-------|
|
||||
| `ghcr.io/openhands/agent-server` | V1 agent-server (sandbox) — built by SDK repo CI |
|
||||
| `ghcr.io/openhands/openhands` | Main app image — built by `ghcr-build.yml` |
|
||||
| `ghcr.io/openhands/runtime` | V0 runtime sandbox — built by `ghcr-build.yml` |
|
||||
| `docker.openhands.dev/openhands/*` | Mirror/CDN for the above images |
|
||||
103
.agents/skills/update-sdk/references/sdk-pinning-examples.md
Normal file
103
.agents/skills/update-sdk/references/sdk-pinning-examples.md
Normal file
@@ -0,0 +1,103 @@
|
||||
# SDK Pinning Examples
|
||||
|
||||
Examples from real commits showing how to pin SDK packages to unreleased commits, branches, or released versions.
|
||||
|
||||
## Pin to a Specific Commit
|
||||
|
||||
Example from commit `169fb76` (pinning all 3 packages to SDK commit `100e9af`):
|
||||
|
||||
### `dependencies` array (PEP 508 format)
|
||||
|
||||
```toml
|
||||
"openhands-agent-server @ git+https://github.com/OpenHands/software-agent-sdk.git@100e9af#subdirectory=openhands-agent-server",
|
||||
"openhands-sdk @ git+https://github.com/OpenHands/software-agent-sdk.git@100e9af#subdirectory=openhands-sdk",
|
||||
"openhands-tools @ git+https://github.com/OpenHands/software-agent-sdk.git@100e9af#subdirectory=openhands-tools",
|
||||
```
|
||||
|
||||
### `[tool.poetry.dependencies]` (Poetry format)
|
||||
|
||||
```toml
|
||||
openhands-sdk = { git = "https://github.com/OpenHands/software-agent-sdk.git", rev = "100e9af", subdirectory = "openhands-sdk" }
|
||||
openhands-agent-server = { git = "https://github.com/OpenHands/software-agent-sdk.git", rev = "100e9af", subdirectory = "openhands-agent-server" }
|
||||
openhands-tools = { git = "https://github.com/OpenHands/software-agent-sdk.git", rev = "100e9af", subdirectory = "openhands-tools" }
|
||||
```
|
||||
|
||||
### `openhands/app_server/sandbox/sandbox_spec_service.py`
|
||||
|
||||
```python
|
||||
AGENT_SERVER_IMAGE = 'ghcr.io/openhands/agent-server:<merge-commit-sha>-python'
|
||||
```
|
||||
|
||||
**⚠️ Important:** The image tag is the **merge-commit SHA** from the SDK CI, not the commit hash used in `pyproject.toml`. Look up the correct tag from the SDK PR description or CI logs.
|
||||
|
||||
## Pin to a Branch
|
||||
|
||||
Example from commit `430ee1c` (pinning to branch `openhands/issue-2228-sdk-settings-schema`):
|
||||
|
||||
### `[tool.poetry.dependencies]`
|
||||
|
||||
```toml
|
||||
openhands-sdk = { git = "https://github.com/OpenHands/software-agent-sdk.git", branch = "openhands/issue-2228-sdk-settings-schema", subdirectory = "openhands-sdk" }
|
||||
openhands-agent-server = { git = "https://github.com/OpenHands/software-agent-sdk.git", branch = "openhands/issue-2228-sdk-settings-schema", subdirectory = "openhands-agent-server" }
|
||||
openhands-tools = { git = "https://github.com/OpenHands/software-agent-sdk.git", branch = "openhands/issue-2228-sdk-settings-schema", subdirectory = "openhands-tools" }
|
||||
```
|
||||
|
||||
## Using `[tool.uv.sources]` Override
|
||||
|
||||
When only `uv` needs the override (keep PyPI versions in the main arrays), add a `[tool.uv.sources]` section. Example from commit `1daca49`:
|
||||
|
||||
```toml
|
||||
[tool.uv.sources]
|
||||
openhands-sdk = { git = "https://github.com/OpenHands/software-agent-sdk.git", subdirectory = "openhands-sdk", rev = "4170cca" }
|
||||
openhands-agent-server = { git = "https://github.com/OpenHands/software-agent-sdk.git", subdirectory = "openhands-agent-server", rev = "4170cca" }
|
||||
openhands-tools = { git = "https://github.com/OpenHands/software-agent-sdk.git", subdirectory = "openhands-tools", rev = "4170cca" }
|
||||
```
|
||||
|
||||
## Released PyPI Version (standard release)
|
||||
|
||||
Example from commit `929dcc3` (SDK 1.11.5):
|
||||
|
||||
### `dependencies` array
|
||||
|
||||
```toml
|
||||
"openhands-agent-server==1.11.5",
|
||||
"openhands-sdk==1.11.5",
|
||||
"openhands-tools==1.11.5",
|
||||
```
|
||||
|
||||
### `[tool.poetry.dependencies]`
|
||||
|
||||
```toml
|
||||
openhands-sdk = "1.11.5"
|
||||
openhands-agent-server = "1.11.5"
|
||||
openhands-tools = "1.11.5"
|
||||
```
|
||||
|
||||
### `openhands/app_server/sandbox/sandbox_spec_service.py`
|
||||
|
||||
For released versions, the image tag uses the version number:
|
||||
|
||||
```python
|
||||
AGENT_SERVER_IMAGE = 'ghcr.io/openhands/agent-server:1.11.5-python'
|
||||
```
|
||||
|
||||
However, **some releases use a commit-hash tag** even for the released version. Check which tag format exists on GHCR. Example from `929dcc3`:
|
||||
|
||||
```python
|
||||
AGENT_SERVER_IMAGE = 'ghcr.io/openhands/agent-server:010e847-python'
|
||||
```
|
||||
|
||||
## Regenerate Lock Files
|
||||
|
||||
After any change to `pyproject.toml`, always regenerate:
|
||||
|
||||
```bash
|
||||
poetry lock
|
||||
uv lock
|
||||
cd enterprise && poetry lock && cd ..
|
||||
```
|
||||
|
||||
## CI Guards
|
||||
|
||||
- **`check-package-versions.yml`**: Blocks merge to `main` if `[tool.poetry.dependencies]` contains `rev` fields (prevents shipping unreleased SDK pins)
|
||||
- **`check-version-consistency.yml`**: Validates version strings match across `pyproject.toml`, `package.json`, `package-lock.json`, and verifies compose files use `agent-server` images
|
||||
2
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
2
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
# disable blank issue creation
|
||||
blank_issues_enabled: false
|
||||
330
.github/scripts/find_prs_between_commits.py
vendored
Normal file
330
.github/scripts/find_prs_between_commits.py
vendored
Normal file
@@ -0,0 +1,330 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Find all PRs that went in between two commits in the OpenHands/OpenHands repository.
|
||||
Handles cherry-picks and different merge strategies.
|
||||
|
||||
This script is designed to run from within the OpenHands repository under .github/scripts:
|
||||
.github/scripts/find_prs_between_commits.py
|
||||
|
||||
Usage: find_prs_between_commits <older_commit> <newer_commit> [--repo <path>]
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def find_openhands_repo() -> Optional[Path]:
|
||||
"""
|
||||
Find the OpenHands repository.
|
||||
Since this script is designed to live in .github/scripts/, it assumes
|
||||
the repository root is two levels up from the script location.
|
||||
Tries:
|
||||
1. Repository root (../../ from script location)
|
||||
2. Current directory
|
||||
3. Environment variable OPENHANDS_REPO
|
||||
"""
|
||||
# Check repository root (assuming script is in .github/scripts/)
|
||||
script_dir = Path(__file__).parent.absolute()
|
||||
repo_root = (
|
||||
script_dir.parent.parent
|
||||
) # Go up two levels: scripts -> .github -> repo root
|
||||
if (repo_root / '.git').exists():
|
||||
return repo_root
|
||||
|
||||
# Check current directory
|
||||
if (Path.cwd() / '.git').exists():
|
||||
return Path.cwd()
|
||||
|
||||
# Check environment variable
|
||||
if 'OPENHANDS_REPO' in os.environ:
|
||||
repo_path = Path(os.environ['OPENHANDS_REPO'])
|
||||
if (repo_path / '.git').exists():
|
||||
return repo_path
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def run_git_command(cmd: list[str], repo_path: Path) -> str:
|
||||
"""Run a git command in the repository directory and return its output."""
|
||||
try:
|
||||
result = subprocess.run(
|
||||
cmd, capture_output=True, text=True, check=True, cwd=str(repo_path)
|
||||
)
|
||||
return result.stdout.strip()
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f'Error running git command: {" ".join(cmd)}', file=sys.stderr)
|
||||
print(f'Error: {e.stderr}', file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def extract_pr_numbers_from_message(message: str) -> set[int]:
|
||||
"""Extract PR numbers from commit message in any common format."""
|
||||
# Match #12345 anywhere, including in patterns like (#12345) or "Merge pull request #12345"
|
||||
matches = re.findall(r'#(\d+)', message)
|
||||
return set(int(m) for m in matches)
|
||||
|
||||
|
||||
def get_commit_info(commit_hash: str, repo_path: Path) -> tuple[str, str, str]:
|
||||
"""Get commit subject, body, and author from a commit hash."""
|
||||
subject = run_git_command(
|
||||
['git', 'log', '-1', '--format=%s', commit_hash], repo_path
|
||||
)
|
||||
body = run_git_command(['git', 'log', '-1', '--format=%b', commit_hash], repo_path)
|
||||
author = run_git_command(
|
||||
['git', 'log', '-1', '--format=%an <%ae>', commit_hash], repo_path
|
||||
)
|
||||
return subject, body, author
|
||||
|
||||
|
||||
def get_commits_between(
|
||||
older_commit: str, newer_commit: str, repo_path: Path
|
||||
) -> list[str]:
|
||||
"""Get all commit hashes between two commits."""
|
||||
commits_output = run_git_command(
|
||||
['git', 'rev-list', f'{older_commit}..{newer_commit}'], repo_path
|
||||
)
|
||||
|
||||
if not commits_output:
|
||||
return []
|
||||
|
||||
return commits_output.split('\n')
|
||||
|
||||
|
||||
def get_pr_info_from_github(pr_number: int, repo_path: Path) -> Optional[dict]:
|
||||
"""Get PR information from GitHub API if GITHUB_TOKEN is available."""
|
||||
try:
|
||||
# Set up environment with GitHub token
|
||||
env = os.environ.copy()
|
||||
if 'GITHUB_TOKEN' in env:
|
||||
env['GH_TOKEN'] = env['GITHUB_TOKEN']
|
||||
|
||||
result = subprocess.run(
|
||||
[
|
||||
'gh',
|
||||
'pr',
|
||||
'view',
|
||||
str(pr_number),
|
||||
'--json',
|
||||
'number,title,author,mergedAt,baseRefName,headRefName,url',
|
||||
],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=True,
|
||||
env=env,
|
||||
cwd=str(repo_path),
|
||||
)
|
||||
return json.loads(result.stdout)
|
||||
except (subprocess.CalledProcessError, FileNotFoundError, json.JSONDecodeError):
|
||||
return None
|
||||
|
||||
|
||||
def find_prs_between_commits(
|
||||
older_commit: str, newer_commit: str, repo_path: Path
|
||||
) -> dict[int, dict]:
|
||||
"""
|
||||
Find all PRs that went in between two commits.
|
||||
Returns a dictionary mapping PR numbers to their information.
|
||||
"""
|
||||
print(f'Repository: {repo_path}', file=sys.stderr)
|
||||
print('Finding PRs between commits:', file=sys.stderr)
|
||||
print(f' Older: {older_commit}', file=sys.stderr)
|
||||
print(f' Newer: {newer_commit}', file=sys.stderr)
|
||||
print(file=sys.stderr)
|
||||
|
||||
# Verify commits exist
|
||||
try:
|
||||
run_git_command(['git', 'rev-parse', '--verify', older_commit], repo_path)
|
||||
run_git_command(['git', 'rev-parse', '--verify', newer_commit], repo_path)
|
||||
except SystemExit:
|
||||
print('Error: One or both commits not found in repository', file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
# Extract PRs from the older commit itself (to exclude from results)
|
||||
# These PRs are already included at or before the older commit
|
||||
older_subject, older_body, _ = get_commit_info(older_commit, repo_path)
|
||||
older_message = f'{older_subject}\n{older_body}'
|
||||
excluded_prs = extract_pr_numbers_from_message(older_message)
|
||||
|
||||
if excluded_prs:
|
||||
print(
|
||||
f'Excluding PRs already in older commit: {", ".join(f"#{pr}" for pr in sorted(excluded_prs))}',
|
||||
file=sys.stderr,
|
||||
)
|
||||
print(file=sys.stderr)
|
||||
|
||||
# Get all commits between the two
|
||||
commits = get_commits_between(older_commit, newer_commit, repo_path)
|
||||
print(f'Found {len(commits)} commits to analyze', file=sys.stderr)
|
||||
print(file=sys.stderr)
|
||||
|
||||
# Extract PR numbers from all commits
|
||||
pr_info: dict[int, dict] = {}
|
||||
commits_by_pr: dict[int, list[str]] = defaultdict(list)
|
||||
|
||||
for commit_hash in commits:
|
||||
subject, body, author = get_commit_info(commit_hash, repo_path)
|
||||
full_message = f'{subject}\n{body}'
|
||||
|
||||
pr_numbers = extract_pr_numbers_from_message(full_message)
|
||||
|
||||
for pr_num in pr_numbers:
|
||||
# Skip PRs that are already in the older commit
|
||||
if pr_num in excluded_prs:
|
||||
continue
|
||||
|
||||
commits_by_pr[pr_num].append(commit_hash)
|
||||
|
||||
if pr_num not in pr_info:
|
||||
pr_info[pr_num] = {
|
||||
'number': pr_num,
|
||||
'first_commit': commit_hash[:8],
|
||||
'first_commit_subject': subject,
|
||||
'commits': [],
|
||||
'github_info': None,
|
||||
}
|
||||
|
||||
pr_info[pr_num]['commits'].append(
|
||||
{'hash': commit_hash[:8], 'subject': subject, 'author': author}
|
||||
)
|
||||
|
||||
# Try to get additional info from GitHub API
|
||||
print('Fetching additional info from GitHub API...', file=sys.stderr)
|
||||
for pr_num in pr_info.keys():
|
||||
github_info = get_pr_info_from_github(pr_num, repo_path)
|
||||
if github_info:
|
||||
pr_info[pr_num]['github_info'] = github_info
|
||||
|
||||
print(file=sys.stderr)
|
||||
|
||||
return pr_info
|
||||
|
||||
|
||||
def print_results(pr_info: dict[int, dict]):
|
||||
"""Print the results in a readable format."""
|
||||
sorted_prs = sorted(pr_info.items(), key=lambda x: x[0])
|
||||
|
||||
print(f'{"=" * 80}')
|
||||
print(f'Found {len(sorted_prs)} PRs')
|
||||
print(f'{"=" * 80}')
|
||||
print()
|
||||
|
||||
for pr_num, info in sorted_prs:
|
||||
print(f'PR #{pr_num}')
|
||||
|
||||
if info['github_info']:
|
||||
gh = info['github_info']
|
||||
print(f' Title: {gh["title"]}')
|
||||
print(f' Author: {gh["author"]["login"]}')
|
||||
print(f' URL: {gh["url"]}')
|
||||
if gh.get('mergedAt'):
|
||||
print(f' Merged: {gh["mergedAt"]}')
|
||||
if gh.get('baseRefName'):
|
||||
print(f' Base: {gh["baseRefName"]} ← {gh["headRefName"]}')
|
||||
else:
|
||||
print(f' Subject: {info["first_commit_subject"]}')
|
||||
|
||||
# Show if this PR has multiple commits (cherry-picked or multiple commits)
|
||||
commit_count = len(info['commits'])
|
||||
if commit_count > 1:
|
||||
print(
|
||||
f' ⚠️ Found {commit_count} commits (possible cherry-pick or multi-commit PR):'
|
||||
)
|
||||
for commit in info['commits'][:3]: # Show first 3
|
||||
print(f' {commit["hash"]}: {commit["subject"][:60]}')
|
||||
if commit_count > 3:
|
||||
print(f' ... and {commit_count - 3} more')
|
||||
else:
|
||||
print(f' Commit: {info["first_commit"]}')
|
||||
|
||||
print()
|
||||
|
||||
|
||||
def main():
|
||||
if len(sys.argv) < 3:
|
||||
print('Usage: find_prs_between_commits <older_commit> <newer_commit> [options]')
|
||||
print()
|
||||
print('Arguments:')
|
||||
print(' <older_commit> The older commit hash (or ref)')
|
||||
print(' <newer_commit> The newer commit hash (or ref)')
|
||||
print()
|
||||
print('Options:')
|
||||
print(' --json Output results in JSON format')
|
||||
print(' --repo <path> Path to OpenHands repository (default: auto-detect)')
|
||||
print()
|
||||
print('Example:')
|
||||
print(
|
||||
' find_prs_between_commits c79e0cd3c7a2501a719c9296828d7a31e4030585 35bddb14f15124a3dc448a74651a6592911d99e9'
|
||||
)
|
||||
print()
|
||||
print('Repository Detection:')
|
||||
print(' The script will try to find the OpenHands repository in this order:')
|
||||
print(' 1. --repo argument')
|
||||
print(' 2. Repository root (../../ from script location)')
|
||||
print(' 3. Current directory')
|
||||
print(' 4. OPENHANDS_REPO environment variable')
|
||||
print()
|
||||
print('Environment variables:')
|
||||
print(
|
||||
' GITHUB_TOKEN Optional. If set, will fetch additional PR info from GitHub API'
|
||||
)
|
||||
print(' OPENHANDS_REPO Optional. Path to OpenHands repository')
|
||||
sys.exit(1)
|
||||
|
||||
older_commit = sys.argv[1]
|
||||
newer_commit = sys.argv[2]
|
||||
json_output = '--json' in sys.argv
|
||||
|
||||
# Check for --repo argument
|
||||
repo_path = None
|
||||
if '--repo' in sys.argv:
|
||||
repo_idx = sys.argv.index('--repo')
|
||||
if repo_idx + 1 < len(sys.argv):
|
||||
repo_path = Path(sys.argv[repo_idx + 1])
|
||||
if not (repo_path / '.git').exists():
|
||||
print(f'Error: {repo_path} is not a git repository', file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
# Auto-detect repository if not specified
|
||||
if repo_path is None:
|
||||
repo_path = find_openhands_repo()
|
||||
if repo_path is None:
|
||||
print('Error: Could not find OpenHands repository', file=sys.stderr)
|
||||
print('Please either:', file=sys.stderr)
|
||||
print(
|
||||
' 1. Place this script in .github/scripts/ within the OpenHands repository',
|
||||
file=sys.stderr,
|
||||
)
|
||||
print(' 2. Run from the OpenHands repository directory', file=sys.stderr)
|
||||
print(
|
||||
' 3. Use --repo <path> to specify the repository location',
|
||||
file=sys.stderr,
|
||||
)
|
||||
print(' 4. Set OPENHANDS_REPO environment variable', file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
# Find PRs
|
||||
pr_info = find_prs_between_commits(older_commit, newer_commit, repo_path)
|
||||
|
||||
if json_output:
|
||||
# Output as JSON
|
||||
print(json.dumps(pr_info, indent=2))
|
||||
else:
|
||||
# Print results in human-readable format
|
||||
print_results(pr_info)
|
||||
|
||||
# Also print a simple list for easy copying
|
||||
print(f'{"=" * 80}')
|
||||
print('PR Numbers (for easy copying):')
|
||||
print(f'{"=" * 80}')
|
||||
sorted_pr_nums = sorted(pr_info.keys())
|
||||
print(', '.join(f'#{pr}' for pr in sorted_pr_nums))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
122
.github/workflows/check-version-consistency.yml
vendored
Normal file
122
.github/workflows/check-version-consistency.yml
vendored
Normal file
@@ -0,0 +1,122 @@
|
||||
name: Check Version Consistency
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [main]
|
||||
pull_request:
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
check-version-consistency:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: "3.12"
|
||||
|
||||
- name: Check version and Docker image tag consistency
|
||||
run: |
|
||||
python - <<'PY'
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
import tomllib
|
||||
|
||||
errors = []
|
||||
warnings = []
|
||||
|
||||
# ── 1. Extract the canonical version from pyproject.toml ──────────
|
||||
with open("pyproject.toml", "rb") as f:
|
||||
pyproject = tomllib.load(f)
|
||||
version = pyproject["tool"]["poetry"]["version"]
|
||||
major_minor = ".".join(version.split(".")[:2])
|
||||
print(f"📦 pyproject.toml version: {version} (major.minor: {major_minor})")
|
||||
|
||||
# ── 2. Check frontend/package.json ────────────────────────────────
|
||||
with open("frontend/package.json") as f:
|
||||
pkg = json.load(f)
|
||||
if pkg["version"] != version:
|
||||
errors.append(
|
||||
f"frontend/package.json version is '{pkg['version']}', expected '{version}'"
|
||||
)
|
||||
else:
|
||||
print(f" ✔ frontend/package.json: {pkg['version']}")
|
||||
|
||||
# ── 3. Check frontend/package-lock.json (2 places) ───────────────
|
||||
with open("frontend/package-lock.json") as f:
|
||||
lock = json.load(f)
|
||||
for key, val in [
|
||||
("root.version", lock.get("version")),
|
||||
('packages[""].version', lock.get("packages", {}).get("", {}).get("version")),
|
||||
]:
|
||||
if val != version:
|
||||
errors.append(
|
||||
f"frontend/package-lock.json {key} is '{val}', expected '{version}'"
|
||||
)
|
||||
else:
|
||||
print(f" ✔ frontend/package-lock.json {key}: {val}")
|
||||
|
||||
# ── 4. Check compose files use agent-server images ─────────────────
|
||||
# Both compose files should use ghcr.io/.../agent-server (not runtime).
|
||||
# Agent-server tags use SDK version (e.g. "1.12.0-python") or commit
|
||||
# hashes (e.g. "31536c8-python") — both are acceptable.
|
||||
repo_pattern = re.compile(r"AGENT_SERVER_IMAGE_REPOSITORY[^}]*:-([^}]+)")
|
||||
tag_pattern = re.compile(r"AGENT_SERVER_IMAGE_TAG:-([^}]+)")
|
||||
|
||||
for filepath in ["docker-compose.yml", "containers/dev/compose.yml"]:
|
||||
try:
|
||||
with open(filepath) as f:
|
||||
content = f.read()
|
||||
except FileNotFoundError:
|
||||
warnings.append(f"{filepath}: file not found")
|
||||
continue
|
||||
|
||||
repos = repo_pattern.findall(content)
|
||||
tags = tag_pattern.findall(content)
|
||||
|
||||
if not repos:
|
||||
warnings.append(f"{filepath}: no AGENT_SERVER_IMAGE_REPOSITORY default found")
|
||||
else:
|
||||
repo = repos[0]
|
||||
if "agent-server" not in repo:
|
||||
errors.append(
|
||||
f"{filepath}: AGENT_SERVER_IMAGE_REPOSITORY defaults to '{repo}', "
|
||||
f"expected an agent-server image (not runtime)"
|
||||
)
|
||||
else:
|
||||
print(f" ✔ {filepath} image repository: {repo}")
|
||||
|
||||
if not tags:
|
||||
warnings.append(f"{filepath}: no AGENT_SERVER_IMAGE_TAG default found")
|
||||
else:
|
||||
tag = tags[0]
|
||||
if not tag:
|
||||
errors.append(f"{filepath}: AGENT_SERVER_IMAGE_TAG default is empty")
|
||||
else:
|
||||
print(f" ✔ {filepath} image tag: {tag}")
|
||||
|
||||
# ── 5. Report ─────────────────────────────────────────────────────
|
||||
print()
|
||||
if warnings:
|
||||
print("⚠ Warnings:")
|
||||
for w in warnings:
|
||||
print(f" {w}")
|
||||
print()
|
||||
|
||||
if errors:
|
||||
print("❌ FAILED: Version inconsistencies found:\n")
|
||||
for e in errors:
|
||||
print(f" ✖ {e}")
|
||||
print(
|
||||
"\nAll version numbers and Docker image tags must be consistent."
|
||||
"\nSee .agents/skills/update-sdk/SKILL.md for the full checklist."
|
||||
)
|
||||
sys.exit(1)
|
||||
else:
|
||||
print("✅ All version numbers and Docker image tags are consistent.")
|
||||
PY
|
||||
29
.github/workflows/enterprise-preview.yml
vendored
29
.github/workflows/enterprise-preview.yml
vendored
@@ -1,29 +0,0 @@
|
||||
# Feature branch preview for enterprise code
|
||||
name: Enterprise Preview
|
||||
|
||||
# Run on PRs labeled
|
||||
on:
|
||||
pull_request:
|
||||
types: [labeled]
|
||||
|
||||
# Match ghcr-build.yml, but don't interrupt it.
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ (github.head_ref && github.ref) || github.run_id }}
|
||||
cancel-in-progress: false
|
||||
|
||||
jobs:
|
||||
# This must happen for the PR Docker workflow when the label is present,
|
||||
# and also if it's added after the fact. Thus, it exists in both places.
|
||||
enterprise-preview:
|
||||
name: Enterprise preview
|
||||
if: github.event.label.name == 'deploy'
|
||||
runs-on: blacksmith-4vcpu-ubuntu-2204
|
||||
steps:
|
||||
# This should match the version in ghcr-build.yml
|
||||
- name: Trigger remote job
|
||||
run: |
|
||||
curl --fail-with-body -sS -X POST \
|
||||
-H "Authorization: Bearer ${{ secrets.ALLHANDS_BOT_GITHUB_PAT }}" \
|
||||
-H "Accept: application/vnd.github+json" \
|
||||
-d "{\"ref\": \"main\", \"inputs\": {\"openhandsPrNumber\": \"${{ github.event.pull_request.number }}\", \"deployEnvironment\": \"feature\", \"enterpriseImageTag\": \"pr-${{ github.event.pull_request.number }}\" }}" \
|
||||
https://api.github.com/repos/OpenHands/deploy/actions/workflows/deploy.yaml/dispatches
|
||||
16
.github/workflows/ghcr-build.yml
vendored
16
.github/workflows/ghcr-build.yml
vendored
@@ -9,6 +9,7 @@ on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
- "saas-rel-*"
|
||||
tags:
|
||||
- "*"
|
||||
pull_request:
|
||||
@@ -239,21 +240,6 @@ jobs:
|
||||
# Add build attestations for better security
|
||||
sbom: true
|
||||
|
||||
enterprise-preview:
|
||||
name: Enterprise preview
|
||||
if: github.event_name == 'pull_request' && contains(github.event.pull_request.labels.*.name, 'deploy')
|
||||
runs-on: blacksmith-4vcpu-ubuntu-2204
|
||||
needs: [ghcr_build_enterprise]
|
||||
steps:
|
||||
# This should match the version in enterprise-preview.yml
|
||||
- name: Trigger remote job
|
||||
run: |
|
||||
curl --fail-with-body -sS -X POST \
|
||||
-H "Authorization: Bearer ${{ secrets.ALLHANDS_BOT_GITHUB_PAT }}" \
|
||||
-H "Accept: application/vnd.github+json" \
|
||||
-d "{\"ref\": \"main\", \"inputs\": {\"openhandsPrNumber\": \"${{ github.event.pull_request.number }}\", \"deployEnvironment\": \"feature\", \"enterpriseImageTag\": \"pr-${{ github.event.pull_request.number }}\" }}" \
|
||||
https://api.github.com/repos/OpenHands/deploy/actions/workflows/deploy.yaml/dispatches
|
||||
|
||||
# "All Runtime Tests Passed" is a required job for PRs to merge
|
||||
# We can remove this once the config changes
|
||||
runtime_tests_check_success:
|
||||
|
||||
48
.github/workflows/pr-review-by-openhands.yml
vendored
Normal file
48
.github/workflows/pr-review-by-openhands.yml
vendored
Normal file
@@ -0,0 +1,48 @@
|
||||
---
|
||||
name: PR Review by OpenHands
|
||||
|
||||
on:
|
||||
# TEMPORARY MITIGATION (Clinejection hardening)
|
||||
#
|
||||
# We temporarily avoid `pull_request_target` here. We'll restore it after the PR review
|
||||
# workflow is fully hardened for untrusted execution.
|
||||
pull_request:
|
||||
types: [opened, ready_for_review, labeled, review_requested]
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: write
|
||||
issues: write
|
||||
|
||||
jobs:
|
||||
pr-review:
|
||||
# Note: fork PRs will not have access to repository secrets under `pull_request`.
|
||||
# Skip forks to avoid noisy failures until we restore a hardened `pull_request_target` flow.
|
||||
if: |
|
||||
github.event.pull_request.head.repo.full_name == github.repository &&
|
||||
(
|
||||
(github.event.action == 'opened' && github.event.pull_request.draft == false) ||
|
||||
github.event.action == 'ready_for_review' ||
|
||||
(github.event.action == 'labeled' && github.event.label.name == 'review-this') ||
|
||||
(
|
||||
github.event.action == 'review_requested' &&
|
||||
(
|
||||
github.event.requested_reviewer.login == 'openhands-agent' ||
|
||||
github.event.requested_reviewer.login == 'all-hands-bot'
|
||||
)
|
||||
)
|
||||
)
|
||||
concurrency:
|
||||
group: pr-review-${{ github.event.pull_request.number }}
|
||||
cancel-in-progress: true
|
||||
runs-on: ubuntu-24.04
|
||||
steps:
|
||||
- name: Run PR Review
|
||||
uses: OpenHands/extensions/plugins/pr-review@main
|
||||
with:
|
||||
llm-model: litellm_proxy/claude-sonnet-4-5-20250929
|
||||
llm-base-url: https://llm-proxy.app.all-hands.dev
|
||||
review-style: roasted
|
||||
llm-api-key: ${{ secrets.LLM_API_KEY }}
|
||||
github-token: ${{ secrets.ALLHANDS_BOT_GITHUB_PAT }}
|
||||
lmnr-api-key: ${{ secrets.LMNR_SKILLS_API_KEY }}
|
||||
85
.github/workflows/pr-review-evaluation.yml
vendored
Normal file
85
.github/workflows/pr-review-evaluation.yml
vendored
Normal file
@@ -0,0 +1,85 @@
|
||||
---
|
||||
name: PR Review Evaluation
|
||||
|
||||
# This workflow evaluates how well PR review comments were addressed.
|
||||
# It runs when a PR is closed to assess review effectiveness.
|
||||
#
|
||||
# Security note: pull_request_target is safe here because:
|
||||
# 1. Only triggers on PR close (not on code changes)
|
||||
# 2. Does not checkout PR code - only downloads artifacts from trusted workflow runs
|
||||
# 3. Runs evaluation scripts from the extensions repo, not from the PR
|
||||
|
||||
on:
|
||||
pull_request_target:
|
||||
types: [closed]
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: read
|
||||
|
||||
jobs:
|
||||
evaluate:
|
||||
runs-on: ubuntu-24.04
|
||||
env:
|
||||
PR_NUMBER: ${{ github.event.pull_request.number }}
|
||||
REPO_NAME: ${{ github.repository }}
|
||||
PR_MERGED: ${{ github.event.pull_request.merged }}
|
||||
|
||||
steps:
|
||||
- name: Download review trace artifact
|
||||
id: download-trace
|
||||
uses: dawidd6/action-download-artifact@v6
|
||||
continue-on-error: true
|
||||
with:
|
||||
workflow: pr-review-by-openhands.yml
|
||||
name: pr-review-trace-${{ github.event.pull_request.number }}
|
||||
path: trace-info
|
||||
search_artifacts: true
|
||||
if_no_artifact_found: warn
|
||||
|
||||
- name: Check if trace file exists
|
||||
id: check-trace
|
||||
run: |
|
||||
if [ -f "trace-info/laminar_trace_info.json" ]; then
|
||||
echo "trace_exists=true" >> $GITHUB_OUTPUT
|
||||
echo "Found trace file for PR #$PR_NUMBER"
|
||||
else
|
||||
echo "trace_exists=false" >> $GITHUB_OUTPUT
|
||||
echo "No trace file found for PR #$PR_NUMBER - skipping evaluation"
|
||||
fi
|
||||
|
||||
# Always checkout main branch for security - cannot test script changes in PRs
|
||||
- name: Checkout extensions repository
|
||||
if: steps.check-trace.outputs.trace_exists == 'true'
|
||||
uses: actions/checkout@v5
|
||||
with:
|
||||
repository: OpenHands/extensions
|
||||
path: extensions
|
||||
|
||||
- name: Set up Python
|
||||
if: steps.check-trace.outputs.trace_exists == 'true'
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: '3.12'
|
||||
|
||||
- name: Install dependencies
|
||||
if: steps.check-trace.outputs.trace_exists == 'true'
|
||||
run: pip install lmnr
|
||||
|
||||
- name: Run evaluation
|
||||
if: steps.check-trace.outputs.trace_exists == 'true'
|
||||
env:
|
||||
# Script expects LMNR_PROJECT_API_KEY; org secret is named LMNR_SKILLS_API_KEY
|
||||
LMNR_PROJECT_API_KEY: ${{ secrets.LMNR_SKILLS_API_KEY }}
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
run: |
|
||||
python extensions/plugins/pr-review/scripts/evaluate_review.py \
|
||||
--trace-file trace-info/laminar_trace_info.json
|
||||
|
||||
- name: Upload evaluation logs
|
||||
uses: actions/upload-artifact@v5
|
||||
if: always() && steps.check-trace.outputs.trace_exists == 'true'
|
||||
with:
|
||||
name: pr-review-evaluation-${{ github.event.pull_request.number }}
|
||||
path: '*.log'
|
||||
retention-days: 30
|
||||
1156
.pr/chat-message-persistence-design.md
Normal file
1156
.pr/chat-message-persistence-design.md
Normal file
File diff suppressed because it is too large
Load Diff
@@ -165,7 +165,7 @@ Each integration follows a consistent pattern with service classes, storage mode
|
||||
|
||||
**Import Patterns:**
|
||||
- Use relative imports without `enterprise.` prefix in enterprise code
|
||||
- Example: `from storage.database import session_maker` not `from enterprise.storage.database import session_maker`
|
||||
- Example: `from storage.database import a_session_maker` not `from enterprise.storage.database import a_session_maker`
|
||||
- This ensures code works in both OpenHands and enterprise contexts
|
||||
|
||||
**Test Structure:**
|
||||
|
||||
@@ -54,7 +54,7 @@ The experience will be familiar to anyone who has used Devin or Jules.
|
||||
### OpenHands Cloud
|
||||
This is a deployment of OpenHands GUI, running on hosted infrastructure.
|
||||
|
||||
You can try it with a free $10 credit by [signing in with your GitHub or GitLab account](https://app.all-hands.dev).
|
||||
You can try it for free using the Minimax model by [signing in with your GitHub or GitLab account](https://app.all-hands.dev).
|
||||
|
||||
OpenHands Cloud comes with source-available features and integrations:
|
||||
- Integrations with Slack, Jira, and Linear
|
||||
|
||||
@@ -12,8 +12,8 @@ services:
|
||||
- SANDBOX_API_HOSTNAME=host.docker.internal
|
||||
- DOCKER_HOST_ADDR=host.docker.internal
|
||||
#
|
||||
- AGENT_SERVER_IMAGE_REPOSITORY=${AGENT_SERVER_IMAGE_REPOSITORY:-ghcr.io/openhands/runtime}
|
||||
- AGENT_SERVER_IMAGE_TAG=${AGENT_SERVER_IMAGE_TAG:-1.2-nikolaik}
|
||||
- AGENT_SERVER_IMAGE_REPOSITORY=${AGENT_SERVER_IMAGE_REPOSITORY:-ghcr.io/openhands/agent-server}
|
||||
- AGENT_SERVER_IMAGE_TAG=${AGENT_SERVER_IMAGE_TAG:-1.12.0-python}
|
||||
- SANDBOX_USER_ID=${SANDBOX_USER_ID:-1234}
|
||||
- WORKSPACE_MOUNT_PATH=${WORKSPACE_BASE:-$PWD/workspace}
|
||||
ports:
|
||||
|
||||
@@ -8,7 +8,7 @@ services:
|
||||
container_name: openhands-app-${DATE:-}
|
||||
environment:
|
||||
- AGENT_SERVER_IMAGE_REPOSITORY=${AGENT_SERVER_IMAGE_REPOSITORY:-ghcr.io/openhands/agent-server}
|
||||
- AGENT_SERVER_IMAGE_TAG=${AGENT_SERVER_IMAGE_TAG:-31536c8-python}
|
||||
- AGENT_SERVER_IMAGE_TAG=${AGENT_SERVER_IMAGE_TAG:-1.12.0-python}
|
||||
#- SANDBOX_USER_ID=${SANDBOX_USER_ID:-1234} # enable this only if you want a specific non-root sandbox user but you will have to manually adjust permissions of ~/.openhands for this user
|
||||
- WORKSPACE_MOUNT_PATH=${WORKSPACE_BASE:-$PWD/workspace}
|
||||
ports:
|
||||
|
||||
@@ -23,12 +23,23 @@ RUN apt-get update && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Install Python packages with security fixes
|
||||
RUN /app/.venv/bin/pip install alembic psycopg2-binary cloud-sql-python-connector pg8000 gspread stripe python-keycloak asyncpg sqlalchemy[asyncio] resend tenacity slack-sdk ddtrace "posthog>=6.0.0" "limits==5.2.0" coredis prometheus-client shap scikit-learn pandas numpy google-cloud-recaptcha-enterprise && \
|
||||
# Update packages with known CVE fixes
|
||||
/app/.venv/bin/pip install --upgrade \
|
||||
"mcp>=1.10.0" \
|
||||
"pillow>=11.3.0"
|
||||
# Install poetry and export before importing current code.
|
||||
RUN /app/.venv/bin/pip install poetry poetry-plugin-export
|
||||
|
||||
# Install Python dependencies from poetry.lock for reproducible builds
|
||||
# Copy lock files first for better Docker layer caching
|
||||
COPY --chown=openhands:openhands enterprise/pyproject.toml enterprise/poetry.lock /tmp/enterprise/
|
||||
RUN cd /tmp/enterprise && \
|
||||
# Export only main dependencies with hashes for supply chain security
|
||||
/app/.venv/bin/poetry export --only main -o requirements.txt && \
|
||||
# Remove the local path dependency (openhands-ai is already in base image)
|
||||
sed -i '/^-e /d; /openhands-ai/d' requirements.txt && \
|
||||
# Install pinned dependencies from lock file
|
||||
/app/.venv/bin/pip install -r requirements.txt && \
|
||||
# Cleanup - return to /app before removing /tmp/enterprise
|
||||
cd /app && \
|
||||
rm -rf /tmp/enterprise && \
|
||||
/app/.venv/bin/pip uninstall -y poetry poetry-plugin-export
|
||||
|
||||
WORKDIR /app
|
||||
COPY --chown=openhands:openhands --chmod=770 enterprise .
|
||||
|
||||
@@ -1772,6 +1772,40 @@
|
||||
"sendIdTokenOnLogout": "true",
|
||||
"passMaxAge": "false"
|
||||
}
|
||||
},
|
||||
{
|
||||
"alias": "bitbucket_data_center",
|
||||
"displayName": "Bitbucket Data Center",
|
||||
"internalId": "b77b4ead-20e8-451c-ad27-99f92d561616",
|
||||
"providerId": "oauth2",
|
||||
"enabled": true,
|
||||
"updateProfileFirstLoginMode": "on",
|
||||
"trustEmail": true,
|
||||
"storeToken": true,
|
||||
"addReadTokenRoleOnCreate": false,
|
||||
"authenticateByDefault": false,
|
||||
"linkOnly": false,
|
||||
"hideOnLogin": false,
|
||||
"config": {
|
||||
"givenNameClaim": "given_name",
|
||||
"userInfoUrl": "https://${WEB_HOST}/bitbucket-dc-proxy/oauth2/userinfo",
|
||||
"clientId": "$BITBUCKET_DATA_CENTER_CLIENT_ID",
|
||||
"tokenUrl": "https://${BITBUCKET_DATA_CENTER_HOST}/rest/oauth2/latest/token",
|
||||
"acceptsPromptNoneForwardFromClient": "false",
|
||||
"fullNameClaim": "name",
|
||||
"userIDClaim": "sub",
|
||||
"emailClaim": "email",
|
||||
"userNameClaim": "preferred_username",
|
||||
"caseSensitiveOriginalUsername": "false",
|
||||
"familyNameClaim": "family_name",
|
||||
"pkceEnabled": "false",
|
||||
"authorizationUrl": "https://${BITBUCKET_DATA_CENTER_HOST}/rest/oauth2/latest/authorize",
|
||||
"clientAuthMethod": "client_secret_post",
|
||||
"syncMode": "IMPORT",
|
||||
"clientSecret": "$BITBUCKET_DATA_CENTER_CLIENT_SECRET",
|
||||
"allowedClockSkew": "0",
|
||||
"defaultScope": "REPO_WRITE"
|
||||
}
|
||||
}
|
||||
],
|
||||
"identityProviderMappers": [
|
||||
@@ -1829,6 +1863,26 @@
|
||||
"syncMode": "FORCE",
|
||||
"attribute": "identity_provider"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "id-mapper",
|
||||
"identityProviderAlias": "bitbucket_data_center",
|
||||
"identityProviderMapper": "oidc-user-attribute-idp-mapper",
|
||||
"config": {
|
||||
"syncMode": "FORCE",
|
||||
"claim": "sub",
|
||||
"user.attribute": "bitbucket_data_center_id"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "identity-provider",
|
||||
"identityProviderAlias": "bitbucket_data_center",
|
||||
"identityProviderMapper": "hardcoded-attribute-idp-mapper",
|
||||
"config": {
|
||||
"attribute.value": "bitbucket_data_center",
|
||||
"syncMode": "FORCE",
|
||||
"attribute": "identity_provider"
|
||||
}
|
||||
}
|
||||
],
|
||||
"components": {
|
||||
|
||||
@@ -50,8 +50,10 @@ repos:
|
||||
- ./
|
||||
- stripe==11.5.0
|
||||
- pygithub==2.6.1
|
||||
# To see gaps add `--html-report mypy-report/`
|
||||
entry: mypy --config-file enterprise/dev_config/python/mypy.ini enterprise/
|
||||
# Use -p (package) to avoid dual module name conflict when using MYPYPATH
|
||||
# MYPYPATH=enterprise allows resolving bare imports like "from integrations.xxx"
|
||||
# Note: tests package excluded to avoid conflict with core openhands tests
|
||||
entry: bash -c 'MYPYPATH=enterprise mypy --config-file enterprise/dev_config/python/mypy.ini -p integrations -p server -p storage -p sync'
|
||||
always_run: true
|
||||
pass_filenames: false
|
||||
files: ^enterprise/
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
warn_unused_configs = True
|
||||
ignore_missing_imports = True
|
||||
check_untyped_defs = True
|
||||
explicit_package_bases = True
|
||||
warn_unreachable = True
|
||||
warn_redundant_casts = True
|
||||
no_implicit_optional = True
|
||||
|
||||
@@ -200,7 +200,7 @@ class MetricsCollector(ABC):
|
||||
"""Base class for metrics collectors."""
|
||||
|
||||
@abstractmethod
|
||||
def collect(self) -> List[MetricResult]:
|
||||
async def collect(self) -> List[MetricResult]:
|
||||
"""Collect metrics and return results."""
|
||||
pass
|
||||
|
||||
@@ -264,12 +264,13 @@ class SystemMetricsCollector(MetricsCollector):
|
||||
def collector_name(self) -> str:
|
||||
return "system_metrics"
|
||||
|
||||
def collect(self) -> List[MetricResult]:
|
||||
async def collect(self) -> List[MetricResult]:
|
||||
results = []
|
||||
|
||||
# Collect user count
|
||||
with session_maker() as session:
|
||||
user_count = session.query(UserSettings).count()
|
||||
async with a_session_maker() as session:
|
||||
user_count_result = await session.execute(select(func.count()).select_from(UserSettings))
|
||||
user_count = user_count_result.scalar()
|
||||
results.append(MetricResult(
|
||||
key="total_users",
|
||||
value=user_count
|
||||
@@ -277,9 +278,11 @@ class SystemMetricsCollector(MetricsCollector):
|
||||
|
||||
# Collect conversation count (last 30 days)
|
||||
thirty_days_ago = datetime.now(timezone.utc) - timedelta(days=30)
|
||||
conversation_count = session.query(StoredConversationMetadata)\
|
||||
.filter(StoredConversationMetadata.created_at >= thirty_days_ago)\
|
||||
.count()
|
||||
conversation_count_result = await session.execute(
|
||||
select(func.count()).select_from(StoredConversationMetadata)
|
||||
.where(StoredConversationMetadata.created_at >= thirty_days_ago)
|
||||
)
|
||||
conversation_count = conversation_count_result.scalar()
|
||||
|
||||
results.append(MetricResult(
|
||||
key="conversations_30d",
|
||||
@@ -303,7 +306,7 @@ class TelemetryCollectionProcessor(MaintenanceTaskProcessor):
|
||||
"""Collect metrics from all registered collectors."""
|
||||
|
||||
# Check if collection is needed
|
||||
if not self._should_collect():
|
||||
if not await self._should_collect():
|
||||
return {"status": "skipped", "reason": "too_recent"}
|
||||
|
||||
# Collect metrics from all registered collectors
|
||||
@@ -313,7 +316,7 @@ class TelemetryCollectionProcessor(MaintenanceTaskProcessor):
|
||||
for collector in collector_registry.get_all_collectors():
|
||||
try:
|
||||
if collector.should_collect():
|
||||
results = collector.collect()
|
||||
results = await collector.collect()
|
||||
for result in results:
|
||||
all_metrics[result.key] = result.value
|
||||
collector_results[collector.collector_name] = len(results)
|
||||
@@ -322,13 +325,13 @@ class TelemetryCollectionProcessor(MaintenanceTaskProcessor):
|
||||
collector_results[collector.collector_name] = f"error: {e}"
|
||||
|
||||
# Store metrics in database
|
||||
with session_maker() as session:
|
||||
async with a_session_maker() as session:
|
||||
telemetry_record = TelemetryMetrics(
|
||||
metrics_data=all_metrics,
|
||||
collected_at=datetime.now(timezone.utc)
|
||||
)
|
||||
session.add(telemetry_record)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
# Note: No need to track last_collection_at separately
|
||||
# Can be derived from MAX(collected_at) in telemetry_metrics
|
||||
@@ -339,11 +342,12 @@ class TelemetryCollectionProcessor(MaintenanceTaskProcessor):
|
||||
"collectors_run": collector_results
|
||||
}
|
||||
|
||||
def _should_collect(self) -> bool:
|
||||
async def _should_collect(self) -> bool:
|
||||
"""Check if collection is needed based on interval."""
|
||||
with session_maker() as session:
|
||||
async with a_session_maker() as session:
|
||||
# Get last collection time from metrics table
|
||||
last_collected = session.query(func.max(TelemetryMetrics.collected_at)).scalar()
|
||||
result = await session.execute(select(func.max(TelemetryMetrics.collected_at)))
|
||||
last_collected = result.scalar()
|
||||
if not last_collected:
|
||||
return True
|
||||
|
||||
@@ -366,17 +370,19 @@ class TelemetryUploadProcessor(MaintenanceTaskProcessor):
|
||||
"""Upload pending metrics to Replicated."""
|
||||
|
||||
# Get pending metrics
|
||||
with session_maker() as session:
|
||||
pending_metrics = session.query(TelemetryMetrics)\
|
||||
.filter(TelemetryMetrics.uploaded_at.is_(None))\
|
||||
.order_by(TelemetryMetrics.collected_at)\
|
||||
.all()
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(TelemetryMetrics)
|
||||
.where(TelemetryMetrics.uploaded_at.is_(None))
|
||||
.order_by(TelemetryMetrics.collected_at)
|
||||
)
|
||||
pending_metrics = result.scalars().all()
|
||||
|
||||
if not pending_metrics:
|
||||
return {"status": "no_pending_metrics"}
|
||||
|
||||
# Get admin email - skip if not available
|
||||
admin_email = self._get_admin_email()
|
||||
admin_email = await self._get_admin_email()
|
||||
if not admin_email:
|
||||
logger.info("Skipping telemetry upload - no admin email available")
|
||||
return {
|
||||
@@ -413,13 +419,15 @@ class TelemetryUploadProcessor(MaintenanceTaskProcessor):
|
||||
await instance.set_status(InstanceStatus.RUNNING)
|
||||
|
||||
# Mark as uploaded
|
||||
with session_maker() as session:
|
||||
record = session.query(TelemetryMetrics)\
|
||||
.filter(TelemetryMetrics.id == metric_record.id)\
|
||||
.first()
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(TelemetryMetrics)
|
||||
.where(TelemetryMetrics.id == metric_record.id)
|
||||
)
|
||||
record = result.scalar_one_or_none()
|
||||
if record:
|
||||
record.uploaded_at = datetime.now(timezone.utc)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
uploaded_count += 1
|
||||
|
||||
@@ -427,14 +435,16 @@ class TelemetryUploadProcessor(MaintenanceTaskProcessor):
|
||||
logger.error(f"Failed to upload metrics {metric_record.id}: {e}")
|
||||
|
||||
# Update error info
|
||||
with session_maker() as session:
|
||||
record = session.query(TelemetryMetrics)\
|
||||
.filter(TelemetryMetrics.id == metric_record.id)\
|
||||
.first()
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(TelemetryMetrics)
|
||||
.where(TelemetryMetrics.id == metric_record.id)
|
||||
)
|
||||
record = result.scalar_one_or_none()
|
||||
if record:
|
||||
record.upload_attempts += 1
|
||||
record.last_upload_error = str(e)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
failed_count += 1
|
||||
|
||||
@@ -448,7 +458,7 @@ class TelemetryUploadProcessor(MaintenanceTaskProcessor):
|
||||
"total_processed": len(pending_metrics)
|
||||
}
|
||||
|
||||
def _get_admin_email(self) -> str | None:
|
||||
async def _get_admin_email(self) -> str | None:
|
||||
"""Get administrator email for customer identification."""
|
||||
# 1. Check environment variable first
|
||||
env_admin_email = os.getenv('OPENHANDS_ADMIN_EMAIL')
|
||||
@@ -457,12 +467,15 @@ class TelemetryUploadProcessor(MaintenanceTaskProcessor):
|
||||
return env_admin_email
|
||||
|
||||
# 2. Use first active user's email (earliest accepted_tos)
|
||||
with session_maker() as session:
|
||||
first_user = session.query(UserSettings)\
|
||||
.filter(UserSettings.email.isnot(None))\
|
||||
.filter(UserSettings.accepted_tos.isnot(None))\
|
||||
.order_by(UserSettings.accepted_tos.asc())\
|
||||
.first()
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(UserSettings)
|
||||
.where(UserSettings.email.isnot(None))
|
||||
.where(UserSettings.accepted_tos.isnot(None))
|
||||
.order_by(UserSettings.accepted_tos.asc())
|
||||
.limit(1)
|
||||
)
|
||||
first_user = result.scalar_one_or_none()
|
||||
|
||||
if first_user and first_user.email:
|
||||
logger.info(f"Using first active user email: {first_user.email}")
|
||||
@@ -474,15 +487,16 @@ class TelemetryUploadProcessor(MaintenanceTaskProcessor):
|
||||
|
||||
async def _update_telemetry_identity(self, customer_id: str, instance_id: str) -> None:
|
||||
"""Update or create telemetry identity record."""
|
||||
with session_maker() as session:
|
||||
identity = session.query(TelemetryIdentity).first()
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(select(TelemetryIdentity).limit(1))
|
||||
identity = result.scalar_one_or_none()
|
||||
if not identity:
|
||||
identity = TelemetryIdentity()
|
||||
session.add(identity)
|
||||
|
||||
identity.customer_id = customer_id
|
||||
identity.instance_id = instance_id
|
||||
session.commit()
|
||||
await session.commit()
|
||||
```
|
||||
|
||||
### 4.4 License Warning System
|
||||
@@ -503,11 +517,13 @@ async def get_license_status():
|
||||
if not _is_openhands_enterprise():
|
||||
return {"warn": False, "message": ""}
|
||||
|
||||
with session_maker() as session:
|
||||
async with a_session_maker() as session:
|
||||
# Get last successful upload time from metrics table
|
||||
last_upload = session.query(func.max(TelemetryMetrics.uploaded_at))\
|
||||
.filter(TelemetryMetrics.uploaded_at.isnot(None))\
|
||||
.scalar()
|
||||
result = await session.execute(
|
||||
select(func.max(TelemetryMetrics.uploaded_at))
|
||||
.where(TelemetryMetrics.uploaded_at.isnot(None))
|
||||
)
|
||||
last_upload = result.scalar()
|
||||
|
||||
if not last_upload:
|
||||
# No successful uploads yet - show warning after 4 days
|
||||
@@ -521,10 +537,13 @@ async def get_license_status():
|
||||
|
||||
if days_since_upload > 4:
|
||||
# Find oldest unsent batch
|
||||
oldest_unsent = session.query(TelemetryMetrics)\
|
||||
.filter(TelemetryMetrics.uploaded_at.is_(None))\
|
||||
.order_by(TelemetryMetrics.collected_at)\
|
||||
.first()
|
||||
result = await session.execute(
|
||||
select(TelemetryMetrics)
|
||||
.where(TelemetryMetrics.uploaded_at.is_(None))
|
||||
.order_by(TelemetryMetrics.collected_at)
|
||||
.limit(1)
|
||||
)
|
||||
oldest_unsent = result.scalar_one_or_none()
|
||||
|
||||
if oldest_unsent:
|
||||
# Calculate expiration date (oldest unsent + 34 days)
|
||||
@@ -630,19 +649,23 @@ spec:
|
||||
- python
|
||||
- -c
|
||||
- |
|
||||
import asyncio
|
||||
from enterprise.storage.maintenance_task import MaintenanceTask, MaintenanceTaskStatus
|
||||
from enterprise.storage.database import session_maker
|
||||
from enterprise.storage.database import a_session_maker
|
||||
from enterprise.server.telemetry.collection_processor import TelemetryCollectionProcessor
|
||||
|
||||
# Create collection task
|
||||
processor = TelemetryCollectionProcessor()
|
||||
task = MaintenanceTask()
|
||||
task.set_processor(processor)
|
||||
task.status = MaintenanceTaskStatus.PENDING
|
||||
async def main():
|
||||
# Create collection task
|
||||
processor = TelemetryCollectionProcessor()
|
||||
task = MaintenanceTask()
|
||||
task.set_processor(processor)
|
||||
task.status = MaintenanceTaskStatus.PENDING
|
||||
|
||||
with session_maker() as session:
|
||||
session.add(task)
|
||||
session.commit()
|
||||
async with a_session_maker() as session:
|
||||
session.add(task)
|
||||
await session.commit()
|
||||
|
||||
asyncio.run(main())
|
||||
restartPolicy: OnFailure
|
||||
```
|
||||
|
||||
@@ -680,23 +703,27 @@ spec:
|
||||
- python
|
||||
- -c
|
||||
- |
|
||||
import asyncio
|
||||
from enterprise.storage.maintenance_task import MaintenanceTask, MaintenanceTaskStatus
|
||||
from enterprise.storage.database import session_maker
|
||||
from enterprise.storage.database import a_session_maker
|
||||
from enterprise.server.telemetry.upload_processor import TelemetryUploadProcessor
|
||||
import os
|
||||
|
||||
# Create upload task
|
||||
processor = TelemetryUploadProcessor(
|
||||
replicated_publishable_key=os.getenv('REPLICATED_PUBLISHABLE_KEY'),
|
||||
replicated_app_slug=os.getenv('REPLICATED_APP_SLUG', 'openhands-enterprise')
|
||||
)
|
||||
task = MaintenanceTask()
|
||||
task.set_processor(processor)
|
||||
task.status = MaintenanceTaskStatus.PENDING
|
||||
async def main():
|
||||
# Create upload task
|
||||
processor = TelemetryUploadProcessor(
|
||||
replicated_publishable_key=os.getenv('REPLICATED_PUBLISHABLE_KEY'),
|
||||
replicated_app_slug=os.getenv('REPLICATED_APP_SLUG', 'openhands-enterprise')
|
||||
)
|
||||
task = MaintenanceTask()
|
||||
task.set_processor(processor)
|
||||
task.status = MaintenanceTaskStatus.PENDING
|
||||
|
||||
with session_maker() as session:
|
||||
session.add(task)
|
||||
session.commit()
|
||||
async with a_session_maker() as session:
|
||||
session.add(task)
|
||||
await session.commit()
|
||||
|
||||
asyncio.run(main())
|
||||
restartPolicy: OnFailure
|
||||
```
|
||||
|
||||
|
||||
131
enterprise/doc/design-doc/plugin-launch-flow.md
Normal file
131
enterprise/doc/design-doc/plugin-launch-flow.md
Normal file
@@ -0,0 +1,131 @@
|
||||
# Plugin Launch Flow
|
||||
|
||||
This document describes how plugins are launched in OpenHands Saas / Enterprise, from the plugin directory through to agent execution.
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
```
|
||||
Plugin Directory ──▶ Frontend /launch ──▶ App Server ──▶ Agent Server ──▶ SDK
|
||||
(external) (modal) (API) (in sandbox) (plugin loading)
|
||||
```
|
||||
|
||||
| Component | Responsibility |
|
||||
|-----------|---------------|
|
||||
| **Plugin Directory** | Index plugins, present to user, construct launch URLs |
|
||||
| **Frontend** | Display confirmation modal, collect parameters, call API |
|
||||
| **App Server** | Validate request, pass plugin specs to agent server |
|
||||
| **Agent Server** | Run inside sandbox, delegate plugin loading to SDK |
|
||||
| **SDK** | Fetch plugins, load contents, merge skills/hooks/MCP into agent |
|
||||
|
||||
## User Experience
|
||||
|
||||
### Plugin Directory
|
||||
|
||||
The plugin directory presents users with a catalog of available plugins. For each plugin, users see:
|
||||
- Plugin name and description (from `plugin.json`)
|
||||
- Author and version information
|
||||
- A "Launch" button
|
||||
|
||||
When a user clicks "Launch", the plugin directory:
|
||||
1. Reads the plugin's `entry_command` to know which slash command to invoke
|
||||
2. Determines what parameters the plugin accepts (if any)
|
||||
3. Redirects to OpenHands with this information encoded in the URL
|
||||
|
||||
### Parameter Collection
|
||||
|
||||
If a plugin requires user input (API keys, configuration values, etc.), the frontend displays a form modal before starting the conversation. Parameters are passed in the launch URL and rendered as form fields based on their type:
|
||||
|
||||
- **String values** → Text input
|
||||
- **Number values** → Number input
|
||||
- **Boolean values** → Checkbox
|
||||
|
||||
Only primitive types are supported. Complex types (arrays, objects) are not currently supported for parameter input.
|
||||
|
||||
The user fills in required values, then clicks "Start Conversation" to proceed.
|
||||
|
||||
## Launch Flow
|
||||
|
||||
1. **Plugin Directory** (external) constructs a launch URL to the OpenHands app server when user clicks "Launch":
|
||||
```
|
||||
/launch?plugins=BASE64_JSON&message=/city-weather:now%20Tokyo
|
||||
```
|
||||
|
||||
The `plugins` parameter includes any parameter definitions with default values:
|
||||
```json
|
||||
[{
|
||||
"source": "github:owner/repo",
|
||||
"repo_path": "plugins/my-plugin",
|
||||
"parameters": {"api_key": "", "timeout": 30, "debug": false}
|
||||
}]
|
||||
```
|
||||
|
||||
2. **OpenHands Frontend** (`/launch` route, [PR #12699](https://github.com/OpenHands/OpenHands/pull/12699)) displays modal with parameter form, collects user input
|
||||
|
||||
3. **OpenHands App Server** ([PR #12338](https://github.com/OpenHands/OpenHands/pull/12338)) receives the API call:
|
||||
```
|
||||
POST /api/v1/app-conversations
|
||||
{
|
||||
"plugins": [{"source": "github:owner/repo", "repo_path": "plugins/city-weather"}],
|
||||
"initial_message": {"content": [{"type": "text", "text": "/city-weather:now Tokyo"}]}
|
||||
}
|
||||
```
|
||||
|
||||
Call stack:
|
||||
- `AppConversationRouter` receives request with `PluginSpec` list
|
||||
- `LiveStatusAppConversationService._finalize_conversation_request()` converts `PluginSpec` → `PluginSource`
|
||||
- Creates `StartConversationRequest(plugins=sdk_plugins, ...)` and sends to agent server
|
||||
|
||||
4. **Agent Server** (inside sandbox, [SDK PR #1651](https://github.com/OpenHands/software-agent-sdk/pull/1651)) stores specs, defers loading:
|
||||
|
||||
Call stack:
|
||||
- `ConversationService.start_conversation()` receives `StartConversationRequest`
|
||||
- Creates `StoredConversation` with plugin specs
|
||||
- Creates `LocalConversation(plugins=request.plugins, ...)`
|
||||
- Plugin loading deferred until first `run()` or `send_message()`
|
||||
|
||||
5. **SDK** fetches and loads plugins on first use:
|
||||
|
||||
Call stack:
|
||||
- `LocalConversation._ensure_plugins_loaded()` triggered by first message
|
||||
- For each plugin spec:
|
||||
- `Plugin.fetch(source, ref, repo_path)` → clones/caches git repo
|
||||
- `Plugin.load(path)` → parses `plugin.json`, loads commands/skills/hooks
|
||||
- `plugin.add_skills_to(context)` → merges skills into agent
|
||||
- `plugin.add_mcp_config_to(config)` → merges MCP servers
|
||||
|
||||
6. **Agent** receives message, `/city-weather:now` triggers the skill
|
||||
|
||||
## Key Design Decisions
|
||||
|
||||
### Plugin Loading in Sandbox
|
||||
|
||||
Plugins load **inside the sandbox** because:
|
||||
- Plugin hooks and scripts need isolated execution
|
||||
- MCP servers run inside the sandbox
|
||||
- Skills may reference sandbox filesystem
|
||||
|
||||
### Entry Command Handling
|
||||
|
||||
The `entry_command` field in `plugin.json` allows plugin authors to declare a default command:
|
||||
|
||||
```json
|
||||
{
|
||||
"name": "city-weather",
|
||||
"entry_command": "now"
|
||||
}
|
||||
```
|
||||
|
||||
This flows through the system:
|
||||
1. Plugin author declares `entry_command` in plugin.json
|
||||
2. Plugin directory reads it when indexing
|
||||
3. Plugin directory includes `/city-weather:now` in the launch URL's `message` parameter
|
||||
4. Message passes through to agent as `initial_message`
|
||||
|
||||
The SDK exposes this field but does not auto-invoke it—callers control the initial message.
|
||||
|
||||
## Related
|
||||
|
||||
- [OpenHands PR #12338](https://github.com/OpenHands/OpenHands/pull/12338) - App server plugin support
|
||||
- [OpenHands PR #12699](https://github.com/OpenHands/OpenHands/pull/12699) - Frontend `/launch` route
|
||||
- [SDK PR #1651](https://github.com/OpenHands/software-agent-sdk/pull/1651) - Agent server plugin loading
|
||||
- [SDK PR #1647](https://github.com/OpenHands/software-agent-sdk/pull/1647) - Plugin.fetch() for remote plugin fetching
|
||||
@@ -1,207 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
This script can be removed once orgs is established - probably after Feb 15 2026
|
||||
|
||||
Downgrade script for migrated users.
|
||||
|
||||
This script identifies users who have been migrated (already_migrated=True)
|
||||
and reverts them back to the pre-migration state.
|
||||
|
||||
Usage:
|
||||
# Dry run - just list the users that would be downgraded
|
||||
python downgrade_migrated_users.py --dry-run
|
||||
|
||||
# Downgrade a specific user by their keycloak_user_id
|
||||
python downgrade_migrated_users.py --user-id <user_id>
|
||||
|
||||
# Downgrade all migrated users (with confirmation)
|
||||
python downgrade_migrated_users.py --all
|
||||
|
||||
# Downgrade all migrated users without confirmation (dangerous!)
|
||||
python downgrade_migrated_users.py --all --no-confirm
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import sys
|
||||
|
||||
# Add the enterprise directory to the path
|
||||
sys.path.insert(0, '/workspace/project/OpenHands/enterprise')
|
||||
|
||||
from server.logger import logger
|
||||
from sqlalchemy import select, text
|
||||
from storage.database import session_maker
|
||||
from storage.user_settings import UserSettings
|
||||
from storage.user_store import UserStore
|
||||
|
||||
|
||||
def get_migrated_users() -> list[str]:
|
||||
"""Get list of keycloak_user_ids for users who have been migrated.
|
||||
|
||||
This includes:
|
||||
1. Users with already_migrated=True in user_settings (migrated users)
|
||||
2. Users in the 'user' table who don't have a user_settings entry (new sign-ups)
|
||||
"""
|
||||
with session_maker() as session:
|
||||
# Get users from user_settings with already_migrated=True
|
||||
migrated_result = session.execute(
|
||||
select(UserSettings.keycloak_user_id).where(
|
||||
UserSettings.already_migrated.is_(True)
|
||||
)
|
||||
)
|
||||
migrated_users = {row[0] for row in migrated_result.fetchall() if row[0]}
|
||||
|
||||
# Get users from the 'user' table (new sign-ups won't have user_settings)
|
||||
# These are users who signed up after the migration was deployed
|
||||
new_signup_result = session.execute(
|
||||
text("""
|
||||
SELECT CAST(u.id AS VARCHAR)
|
||||
FROM "user" u
|
||||
WHERE NOT EXISTS (
|
||||
SELECT 1 FROM user_settings us
|
||||
WHERE us.keycloak_user_id = CAST(u.id AS VARCHAR)
|
||||
)
|
||||
""")
|
||||
)
|
||||
new_signups = {row[0] for row in new_signup_result.fetchall() if row[0]}
|
||||
|
||||
# Combine both sets
|
||||
all_users = migrated_users | new_signups
|
||||
return list(all_users)
|
||||
|
||||
|
||||
async def downgrade_user(user_id: str) -> bool:
|
||||
"""Downgrade a single user.
|
||||
|
||||
Args:
|
||||
user_id: The keycloak_user_id to downgrade
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
result = await UserStore.downgrade_user(user_id)
|
||||
if result:
|
||||
print(f'✓ Successfully downgraded user: {user_id}')
|
||||
return True
|
||||
else:
|
||||
print(f'✗ Failed to downgrade user: {user_id}')
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f'✗ Error downgrading user {user_id}: {e}')
|
||||
logger.exception(
|
||||
'downgrade_script:error',
|
||||
extra={'user_id': user_id, 'error': str(e)},
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
async def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Downgrade migrated users back to pre-migration state'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--dry-run',
|
||||
action='store_true',
|
||||
help='Just list users that would be downgraded, without making changes',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--user-id',
|
||||
type=str,
|
||||
help='Downgrade a specific user by keycloak_user_id',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--all',
|
||||
action='store_true',
|
||||
help='Downgrade all migrated users',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--no-confirm',
|
||||
action='store_true',
|
||||
help='Skip confirmation prompt (use with caution!)',
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Get list of migrated users
|
||||
migrated_users = get_migrated_users()
|
||||
print(f'\nFound {len(migrated_users)} migrated user(s).')
|
||||
|
||||
if args.dry_run:
|
||||
print('\n--- DRY RUN MODE ---')
|
||||
print('The following users would be downgraded:')
|
||||
for user_id in migrated_users:
|
||||
print(f' - {user_id}')
|
||||
print('\nNo changes were made.')
|
||||
return
|
||||
|
||||
if args.user_id:
|
||||
# Downgrade a specific user
|
||||
if args.user_id not in migrated_users:
|
||||
print(f'\nUser {args.user_id} is not in the migrated users list.')
|
||||
print('Either the user was not migrated, or the user_id is incorrect.')
|
||||
return
|
||||
|
||||
print(f'\nDowngrading user: {args.user_id}')
|
||||
if not args.no_confirm:
|
||||
confirm = input('Are you sure? (yes/no): ')
|
||||
if confirm.lower() != 'yes':
|
||||
print('Cancelled.')
|
||||
return
|
||||
|
||||
success = await downgrade_user(args.user_id)
|
||||
if success:
|
||||
print('\nDowngrade completed successfully.')
|
||||
else:
|
||||
print('\nDowngrade failed. Check logs for details.')
|
||||
sys.exit(1)
|
||||
|
||||
elif args.all:
|
||||
# Downgrade all migrated users
|
||||
if not migrated_users:
|
||||
print('\nNo migrated users to downgrade.')
|
||||
return
|
||||
|
||||
print(f'\n⚠️ About to downgrade {len(migrated_users)} user(s).')
|
||||
if not args.no_confirm:
|
||||
print('\nThis will:')
|
||||
print(' - Revert LiteLLM team/user budget settings')
|
||||
print(' - Delete organization entries')
|
||||
print(' - Delete user entries in the new schema')
|
||||
print(' - Reset the already_migrated flag')
|
||||
print('\nUsers to downgrade:')
|
||||
for user_id in migrated_users[:10]: # Show first 10
|
||||
print(f' - {user_id}')
|
||||
if len(migrated_users) > 10:
|
||||
print(f' ... and {len(migrated_users) - 10} more')
|
||||
|
||||
confirm = input('\nType "yes" to proceed: ')
|
||||
if confirm.lower() != 'yes':
|
||||
print('Cancelled.')
|
||||
return
|
||||
|
||||
print('\nStarting downgrade...\n')
|
||||
success_count = 0
|
||||
fail_count = 0
|
||||
|
||||
for user_id in migrated_users:
|
||||
success = await downgrade_user(user_id)
|
||||
if success:
|
||||
success_count += 1
|
||||
else:
|
||||
fail_count += 1
|
||||
|
||||
print('\n--- Summary ---')
|
||||
print(f'Successful: {success_count}')
|
||||
print(f'Failed: {fail_count}')
|
||||
|
||||
if fail_count > 0:
|
||||
sys.exit(1)
|
||||
|
||||
else:
|
||||
parser.print_help()
|
||||
print('\nPlease specify --dry-run, --user-id, or --all')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
asyncio.run(main())
|
||||
@@ -109,6 +109,9 @@ lines.append(
|
||||
lines.append(
|
||||
'OPENHANDS_BITBUCKET_SERVICE_CLS=integrations.bitbucket.bitbucket_service.SaaSBitBucketService'
|
||||
)
|
||||
lines.append(
|
||||
'OPENHANDS_BITBUCKET_DATA_CENTER_SERVICE_CLS=integrations.bitbucket_data_center.bitbucket_dc_service.SaaSBitbucketDCService'
|
||||
)
|
||||
lines.append(
|
||||
'OPENHANDS_CONVERSATION_VALIDATOR_CLS=storage.saas_conversation_validator.SaasConversationValidator'
|
||||
)
|
||||
|
||||
@@ -1,47 +0,0 @@
|
||||
import os
|
||||
|
||||
import posthog
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
# Initialize PostHog
|
||||
posthog.api_key = os.environ.get('POSTHOG_CLIENT_KEY', 'phc_placeholder')
|
||||
posthog.host = os.environ.get('POSTHOG_HOST', 'https://us.i.posthog.com')
|
||||
|
||||
# Log PostHog configuration with masked API key for security
|
||||
api_key = posthog.api_key
|
||||
if api_key and len(api_key) > 8:
|
||||
masked_key = f'{api_key[:4]}...{api_key[-4:]}'
|
||||
else:
|
||||
masked_key = 'not_set_or_too_short'
|
||||
logger.info('posthog_configuration', extra={'posthog_api_key_masked': masked_key})
|
||||
|
||||
# Global toggle for the experiment manager
|
||||
ENABLE_EXPERIMENT_MANAGER = (
|
||||
os.environ.get('ENABLE_EXPERIMENT_MANAGER', 'false').lower() == 'true'
|
||||
)
|
||||
|
||||
# Get the current experiment type from environment variable
|
||||
# If None, no experiment is running
|
||||
EXPERIMENT_LITELLM_DEFAULT_MODEL_EXPERIMENT = os.environ.get(
|
||||
'EXPERIMENT_LITELLM_DEFAULT_MODEL_EXPERIMENT', ''
|
||||
)
|
||||
# System prompt experiment toggle
|
||||
EXPERIMENT_SYSTEM_PROMPT_EXPERIMENT = os.environ.get(
|
||||
'EXPERIMENT_SYSTEM_PROMPT_EXPERIMENT', ''
|
||||
)
|
||||
|
||||
EXPERIMENT_CLAUDE4_VS_GPT5 = os.environ.get('EXPERIMENT_CLAUDE4_VS_GPT5', '')
|
||||
|
||||
EXPERIMENT_CONDENSER_MAX_STEP = os.environ.get('EXPERIMENT_CONDENSER_MAX_STEP', '')
|
||||
|
||||
logger.info(
|
||||
'experiment_manager:run_conversation_variant_test:experiment_config',
|
||||
extra={
|
||||
'enable_experiment_manager': ENABLE_EXPERIMENT_MANAGER,
|
||||
'experiment_litellm_default_model_experiment': EXPERIMENT_LITELLM_DEFAULT_MODEL_EXPERIMENT,
|
||||
'experiment_system_prompt_experiment': EXPERIMENT_SYSTEM_PROMPT_EXPERIMENT,
|
||||
'experiment_claude4_vs_gpt5_experiment': EXPERIMENT_CLAUDE4_VS_GPT5,
|
||||
'experiment_condenser_max_step': EXPERIMENT_CONDENSER_MAX_STEP,
|
||||
},
|
||||
)
|
||||
@@ -1,97 +0,0 @@
|
||||
from uuid import UUID
|
||||
|
||||
from experiments.constants import (
|
||||
ENABLE_EXPERIMENT_MANAGER,
|
||||
EXPERIMENT_SYSTEM_PROMPT_EXPERIMENT,
|
||||
)
|
||||
from experiments.experiment_versions import (
|
||||
handle_system_prompt_experiment,
|
||||
)
|
||||
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.experiments.experiment_manager import ExperimentManager
|
||||
from openhands.sdk import Agent
|
||||
from openhands.server.session.conversation_init_data import ConversationInitData
|
||||
|
||||
|
||||
class SaaSExperimentManager(ExperimentManager):
|
||||
@staticmethod
|
||||
def run_agent_variant_tests__v1(
|
||||
user_id: str | None, conversation_id: UUID, agent: Agent
|
||||
) -> Agent:
|
||||
if not ENABLE_EXPERIMENT_MANAGER:
|
||||
logger.info(
|
||||
'experiment_manager:run_conversation_variant_test:skipped',
|
||||
extra={'reason': 'experiment_manager_disabled'},
|
||||
)
|
||||
return agent
|
||||
|
||||
if EXPERIMENT_SYSTEM_PROMPT_EXPERIMENT:
|
||||
agent = agent.model_copy(
|
||||
update={'system_prompt_filename': 'system_prompt_long_horizon.j2'}
|
||||
)
|
||||
|
||||
return agent
|
||||
|
||||
@staticmethod
|
||||
def run_conversation_variant_test(
|
||||
user_id, conversation_id, conversation_settings
|
||||
) -> ConversationInitData:
|
||||
"""
|
||||
Run conversation variant test and potentially modify the conversation settings
|
||||
based on the PostHog feature flags.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
conversation_id: The conversation ID
|
||||
conversation_settings: The conversation settings that may include convo_id and llm_model
|
||||
|
||||
Returns:
|
||||
The modified conversation settings
|
||||
"""
|
||||
logger.debug(
|
||||
'experiment_manager:run_conversation_variant_test:started',
|
||||
extra={'user_id': user_id, 'conversation_id': conversation_id},
|
||||
)
|
||||
|
||||
return conversation_settings
|
||||
|
||||
@staticmethod
|
||||
def run_config_variant_test(
|
||||
user_id: str | None, conversation_id: str, config: OpenHandsConfig
|
||||
) -> OpenHandsConfig:
|
||||
"""
|
||||
Run agent config variant test and potentially modify the OpenHands config
|
||||
based on the current experiment type and PostHog feature flags.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
conversation_id: The conversation ID
|
||||
config: The OpenHands configuration
|
||||
|
||||
Returns:
|
||||
The modified OpenHands configuration
|
||||
"""
|
||||
logger.info(
|
||||
'experiment_manager:run_config_variant_test:started',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
|
||||
# Skip all experiment processing if the experiment manager is disabled
|
||||
if not ENABLE_EXPERIMENT_MANAGER:
|
||||
logger.info(
|
||||
'experiment_manager:run_config_variant_test:skipped',
|
||||
extra={'reason': 'experiment_manager_disabled'},
|
||||
)
|
||||
return config
|
||||
|
||||
# Pass the entire OpenHands config to the system prompt experiment
|
||||
# Let the experiment handler directly modify the config as needed
|
||||
modified_config = handle_system_prompt_experiment(
|
||||
user_id, conversation_id, config
|
||||
)
|
||||
|
||||
# Condenser max step experiment is applied via conversation variant test,
|
||||
# not config variant test. Return modified config from system prompt only.
|
||||
return modified_config
|
||||
@@ -1,107 +0,0 @@
|
||||
"""
|
||||
LiteLLM model experiment handler.
|
||||
|
||||
This module contains the handler for the LiteLLM model experiment.
|
||||
"""
|
||||
|
||||
import posthog
|
||||
from experiments.constants import EXPERIMENT_LITELLM_DEFAULT_MODEL_EXPERIMENT
|
||||
from server.constants import (
|
||||
IS_FEATURE_ENV,
|
||||
build_litellm_proxy_model_path,
|
||||
get_default_litellm_model,
|
||||
)
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
|
||||
def handle_litellm_default_model_experiment(
|
||||
user_id, conversation_id, conversation_settings
|
||||
):
|
||||
"""
|
||||
Handle the LiteLLM model experiment.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
conversation_id: The conversation ID
|
||||
conversation_settings: The conversation settings
|
||||
|
||||
Returns:
|
||||
Modified conversation settings
|
||||
"""
|
||||
# No-op if the specific experiment is not enabled
|
||||
if not EXPERIMENT_LITELLM_DEFAULT_MODEL_EXPERIMENT:
|
||||
logger.info(
|
||||
'experiment_manager:ab_testing:skipped',
|
||||
extra={
|
||||
'convo_id': conversation_id,
|
||||
'reason': 'experiment_not_enabled',
|
||||
'experiment': EXPERIMENT_LITELLM_DEFAULT_MODEL_EXPERIMENT,
|
||||
},
|
||||
)
|
||||
return conversation_settings
|
||||
|
||||
# Use experiment name as the flag key
|
||||
try:
|
||||
enabled_variant = posthog.get_feature_flag(
|
||||
EXPERIMENT_LITELLM_DEFAULT_MODEL_EXPERIMENT, conversation_id
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'experiment_manager:get_feature_flag:failed',
|
||||
extra={
|
||||
'convo_id': conversation_id,
|
||||
'experiment': EXPERIMENT_LITELLM_DEFAULT_MODEL_EXPERIMENT,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
return conversation_settings
|
||||
|
||||
# Log the experiment event
|
||||
# If this is a feature environment, add "FEATURE_" prefix to user_id for PostHog
|
||||
posthog_user_id = f'FEATURE_{user_id}' if IS_FEATURE_ENV else user_id
|
||||
|
||||
try:
|
||||
posthog.capture(
|
||||
distinct_id=posthog_user_id,
|
||||
event='model_set',
|
||||
properties={
|
||||
'conversation_id': conversation_id,
|
||||
'variant': enabled_variant,
|
||||
'original_user_id': user_id,
|
||||
'is_feature_env': IS_FEATURE_ENV,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'experiment_manager:posthog_capture:failed',
|
||||
extra={
|
||||
'convo_id': conversation_id,
|
||||
'experiment': EXPERIMENT_LITELLM_DEFAULT_MODEL_EXPERIMENT,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
# Continue execution as this is not critical
|
||||
|
||||
logger.info(
|
||||
'posthog_capture',
|
||||
extra={
|
||||
'event': 'model_set',
|
||||
'posthog_user_id': posthog_user_id,
|
||||
'is_feature_env': IS_FEATURE_ENV,
|
||||
'conversation_id': conversation_id,
|
||||
'variant': enabled_variant,
|
||||
},
|
||||
)
|
||||
|
||||
# Set the model based on the feature flag variant
|
||||
if enabled_variant == 'claude37':
|
||||
# Use the shared utility to construct the LiteLLM proxy model path
|
||||
model = build_litellm_proxy_model_path('claude-3-7-sonnet-20250219')
|
||||
# Update the conversation settings with the selected model
|
||||
conversation_settings.llm_model = model
|
||||
else:
|
||||
# Update the conversation settings with the default model for the current version
|
||||
conversation_settings.llm_model = get_default_litellm_model()
|
||||
|
||||
return conversation_settings
|
||||
@@ -1,181 +0,0 @@
|
||||
"""
|
||||
System prompt experiment handler.
|
||||
|
||||
This module contains the handler for the system prompt experiment that uses
|
||||
the PostHog variant as the system prompt filename.
|
||||
"""
|
||||
|
||||
import copy
|
||||
|
||||
import posthog
|
||||
from experiments.constants import EXPERIMENT_SYSTEM_PROMPT_EXPERIMENT
|
||||
from server.constants import IS_FEATURE_ENV
|
||||
from storage.experiment_assignment_store import ExperimentAssignmentStore
|
||||
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
|
||||
def _get_system_prompt_variant(user_id, conversation_id):
|
||||
"""
|
||||
Get the system prompt variant for the experiment.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
conversation_id: The conversation ID
|
||||
|
||||
Returns:
|
||||
str or None: The PostHog variant name or None if experiment is not enabled or error occurs
|
||||
"""
|
||||
# No-op if the specific experiment is not enabled
|
||||
if not EXPERIMENT_SYSTEM_PROMPT_EXPERIMENT:
|
||||
logger.info(
|
||||
'experiment_manager_002:ab_testing:skipped',
|
||||
extra={
|
||||
'convo_id': conversation_id,
|
||||
'reason': 'experiment_not_enabled',
|
||||
'experiment': EXPERIMENT_SYSTEM_PROMPT_EXPERIMENT,
|
||||
},
|
||||
)
|
||||
return None
|
||||
|
||||
# Use experiment name as the flag key
|
||||
try:
|
||||
enabled_variant = posthog.get_feature_flag(
|
||||
EXPERIMENT_SYSTEM_PROMPT_EXPERIMENT, conversation_id
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'experiment_manager:get_feature_flag:failed',
|
||||
extra={
|
||||
'convo_id': conversation_id,
|
||||
'experiment': EXPERIMENT_SYSTEM_PROMPT_EXPERIMENT,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
return None
|
||||
|
||||
# Store the experiment assignment in the database
|
||||
try:
|
||||
experiment_store = ExperimentAssignmentStore()
|
||||
experiment_store.update_experiment_variant(
|
||||
conversation_id=conversation_id,
|
||||
experiment_name='system_prompt_experiment',
|
||||
variant=enabled_variant,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'experiment_manager:store_assignment:failed',
|
||||
extra={
|
||||
'convo_id': conversation_id,
|
||||
'experiment': EXPERIMENT_SYSTEM_PROMPT_EXPERIMENT,
|
||||
'variant': enabled_variant,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
# Fail the experiment if we cannot track the splits - results would not be explainable
|
||||
return None
|
||||
|
||||
# Log the experiment event
|
||||
# If this is a feature environment, add "FEATURE_" prefix to user_id for PostHog
|
||||
posthog_user_id = f'FEATURE_{user_id}' if IS_FEATURE_ENV else user_id
|
||||
|
||||
try:
|
||||
posthog.capture(
|
||||
distinct_id=posthog_user_id,
|
||||
event='system_prompt_set',
|
||||
properties={
|
||||
'conversation_id': conversation_id,
|
||||
'variant': enabled_variant,
|
||||
'original_user_id': user_id,
|
||||
'is_feature_env': IS_FEATURE_ENV,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'experiment_manager:posthog_capture:failed',
|
||||
extra={
|
||||
'convo_id': conversation_id,
|
||||
'experiment': EXPERIMENT_SYSTEM_PROMPT_EXPERIMENT,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
# Continue execution as this is not critical
|
||||
|
||||
logger.info(
|
||||
'posthog_capture',
|
||||
extra={
|
||||
'event': 'system_prompt_set',
|
||||
'posthog_user_id': posthog_user_id,
|
||||
'is_feature_env': IS_FEATURE_ENV,
|
||||
'conversation_id': conversation_id,
|
||||
'variant': enabled_variant,
|
||||
},
|
||||
)
|
||||
|
||||
return enabled_variant
|
||||
|
||||
|
||||
def handle_system_prompt_experiment(
|
||||
user_id, conversation_id, config: OpenHandsConfig
|
||||
) -> OpenHandsConfig:
|
||||
"""
|
||||
Handle the system prompt experiment for OpenHands config.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
conversation_id: The conversation ID
|
||||
config: The OpenHands configuration
|
||||
|
||||
Returns:
|
||||
Modified OpenHands configuration
|
||||
"""
|
||||
enabled_variant = _get_system_prompt_variant(user_id, conversation_id)
|
||||
|
||||
# If variant is None, experiment is not enabled or there was an error
|
||||
if enabled_variant is None:
|
||||
return config
|
||||
|
||||
# Deep copy the config to avoid modifying the original
|
||||
modified_config = copy.deepcopy(config)
|
||||
|
||||
# Set the system prompt filename based on the variant
|
||||
if enabled_variant == 'control':
|
||||
# Use the long-horizon system prompt for the control variant
|
||||
agent_config = modified_config.get_agent_config(modified_config.default_agent)
|
||||
agent_config.system_prompt_filename = 'system_prompt_long_horizon.j2'
|
||||
agent_config.enable_plan_mode = True
|
||||
elif enabled_variant == 'interactive':
|
||||
modified_config.get_agent_config(
|
||||
modified_config.default_agent
|
||||
).system_prompt_filename = 'system_prompt_interactive.j2'
|
||||
elif enabled_variant == 'no_tools':
|
||||
modified_config.get_agent_config(
|
||||
modified_config.default_agent
|
||||
).system_prompt_filename = 'system_prompt.j2'
|
||||
else:
|
||||
logger.error(
|
||||
'system_prompt_experiment:unknown_variant',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'convo_id': conversation_id,
|
||||
'variant': enabled_variant,
|
||||
'reason': 'no explicit mapping; returning original config',
|
||||
},
|
||||
)
|
||||
return config
|
||||
|
||||
# Log which prompt is being used
|
||||
logger.info(
|
||||
'system_prompt_experiment:prompt_selected',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'convo_id': conversation_id,
|
||||
'system_prompt_filename': modified_config.get_agent_config(
|
||||
modified_config.default_agent
|
||||
).system_prompt_filename,
|
||||
'variant': enabled_variant,
|
||||
},
|
||||
)
|
||||
|
||||
return modified_config
|
||||
@@ -1,137 +0,0 @@
|
||||
"""
|
||||
LiteLLM model experiment handler.
|
||||
|
||||
This module contains the handler for the LiteLLM model experiment.
|
||||
"""
|
||||
|
||||
import posthog
|
||||
from experiments.constants import EXPERIMENT_CLAUDE4_VS_GPT5
|
||||
from server.constants import (
|
||||
IS_FEATURE_ENV,
|
||||
build_litellm_proxy_model_path,
|
||||
get_default_litellm_model,
|
||||
)
|
||||
from storage.experiment_assignment_store import ExperimentAssignmentStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.server.session.conversation_init_data import ConversationInitData
|
||||
|
||||
|
||||
def _get_model_variant(user_id: str | None, conversation_id: str) -> str | None:
|
||||
if not EXPERIMENT_CLAUDE4_VS_GPT5:
|
||||
logger.info(
|
||||
'experiment_manager:ab_testing:skipped',
|
||||
extra={
|
||||
'convo_id': conversation_id,
|
||||
'reason': 'experiment_not_enabled',
|
||||
'experiment': EXPERIMENT_CLAUDE4_VS_GPT5,
|
||||
},
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
enabled_variant = posthog.get_feature_flag(
|
||||
EXPERIMENT_CLAUDE4_VS_GPT5, conversation_id
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'experiment_manager:get_feature_flag:failed',
|
||||
extra={
|
||||
'convo_id': conversation_id,
|
||||
'experiment': EXPERIMENT_CLAUDE4_VS_GPT5,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
return None
|
||||
|
||||
# Store the experiment assignment in the database
|
||||
try:
|
||||
experiment_store = ExperimentAssignmentStore()
|
||||
experiment_store.update_experiment_variant(
|
||||
conversation_id=conversation_id,
|
||||
experiment_name='claude4_vs_gpt5_experiment',
|
||||
variant=enabled_variant,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'experiment_manager:store_assignment:failed',
|
||||
extra={
|
||||
'convo_id': conversation_id,
|
||||
'experiment': EXPERIMENT_CLAUDE4_VS_GPT5,
|
||||
'variant': enabled_variant,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
# Fail the experiment if we cannot track the splits - results would not be explainable
|
||||
return None
|
||||
|
||||
# Log the experiment event
|
||||
# If this is a feature environment, add "FEATURE_" prefix to user_id for PostHog
|
||||
posthog_user_id = f'FEATURE_{user_id}' if IS_FEATURE_ENV else user_id
|
||||
|
||||
try:
|
||||
posthog.capture(
|
||||
distinct_id=posthog_user_id,
|
||||
event='claude4_or_gpt5_set',
|
||||
properties={
|
||||
'conversation_id': conversation_id,
|
||||
'variant': enabled_variant,
|
||||
'original_user_id': user_id,
|
||||
'is_feature_env': IS_FEATURE_ENV,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'experiment_manager:posthog_capture:failed',
|
||||
extra={
|
||||
'convo_id': conversation_id,
|
||||
'experiment': EXPERIMENT_CLAUDE4_VS_GPT5,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
# Continue execution as this is not critical
|
||||
|
||||
logger.info(
|
||||
'posthog_capture',
|
||||
extra={
|
||||
'event': 'claude4_or_gpt5_set',
|
||||
'posthog_user_id': posthog_user_id,
|
||||
'is_feature_env': IS_FEATURE_ENV,
|
||||
'conversation_id': conversation_id,
|
||||
'variant': enabled_variant,
|
||||
},
|
||||
)
|
||||
|
||||
return enabled_variant
|
||||
|
||||
|
||||
def handle_claude4_vs_gpt5_experiment(
|
||||
user_id: str | None,
|
||||
conversation_id: str,
|
||||
conversation_settings: ConversationInitData,
|
||||
) -> ConversationInitData:
|
||||
"""
|
||||
Handle the LiteLLM model experiment.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
conversation_id: The conversation ID
|
||||
conversation_settings: The conversation settings
|
||||
|
||||
Returns:
|
||||
Modified conversation settings
|
||||
"""
|
||||
|
||||
enabled_variant = _get_model_variant(user_id, conversation_id)
|
||||
|
||||
if not enabled_variant:
|
||||
return conversation_settings
|
||||
|
||||
# Set the model based on the feature flag variant
|
||||
if enabled_variant == 'gpt5':
|
||||
model = build_litellm_proxy_model_path('gpt-5-2025-08-07')
|
||||
conversation_settings.llm_model = model
|
||||
else:
|
||||
conversation_settings.llm_model = get_default_litellm_model()
|
||||
|
||||
return conversation_settings
|
||||
@@ -1,232 +0,0 @@
|
||||
"""
|
||||
Condenser max step experiment handler.
|
||||
|
||||
This module contains the handler for the condenser max step experiment that tests
|
||||
different max_size values for the condenser configuration.
|
||||
"""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
import posthog
|
||||
from experiments.constants import EXPERIMENT_CONDENSER_MAX_STEP
|
||||
from server.constants import IS_FEATURE_ENV
|
||||
from storage.experiment_assignment_store import ExperimentAssignmentStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.sdk import Agent
|
||||
from openhands.sdk.context.condenser import (
|
||||
LLMSummarizingCondenser,
|
||||
)
|
||||
from openhands.server.session.conversation_init_data import ConversationInitData
|
||||
|
||||
|
||||
def _get_condenser_max_step_variant(user_id, conversation_id):
|
||||
"""
|
||||
Get the condenser max step variant for the experiment.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
conversation_id: The conversation ID
|
||||
|
||||
Returns:
|
||||
str or None: The PostHog variant name or None if experiment is not enabled or error occurs
|
||||
"""
|
||||
# No-op if the specific experiment is not enabled
|
||||
if not EXPERIMENT_CONDENSER_MAX_STEP:
|
||||
logger.info(
|
||||
'experiment_manager_004:ab_testing:skipped',
|
||||
extra={
|
||||
'convo_id': conversation_id,
|
||||
'reason': 'experiment_not_enabled',
|
||||
'experiment': EXPERIMENT_CONDENSER_MAX_STEP,
|
||||
},
|
||||
)
|
||||
return None
|
||||
|
||||
# Use experiment name as the flag key
|
||||
try:
|
||||
enabled_variant = posthog.get_feature_flag(
|
||||
EXPERIMENT_CONDENSER_MAX_STEP, conversation_id
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'experiment_manager:get_feature_flag:failed',
|
||||
extra={
|
||||
'convo_id': conversation_id,
|
||||
'experiment': EXPERIMENT_CONDENSER_MAX_STEP,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
return None
|
||||
|
||||
# Store the experiment assignment in the database
|
||||
try:
|
||||
experiment_store = ExperimentAssignmentStore()
|
||||
experiment_store.update_experiment_variant(
|
||||
conversation_id=conversation_id,
|
||||
experiment_name='condenser_max_step_experiment',
|
||||
variant=enabled_variant,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'experiment_manager:store_assignment:failed',
|
||||
extra={
|
||||
'convo_id': conversation_id,
|
||||
'experiment': EXPERIMENT_CONDENSER_MAX_STEP,
|
||||
'variant': enabled_variant,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
# Fail the experiment if we cannot track the splits - results would not be explainable
|
||||
return None
|
||||
|
||||
# Log the experiment event
|
||||
# If this is a feature environment, add "FEATURE_" prefix to user_id for PostHog
|
||||
posthog_user_id = f'FEATURE_{user_id}' if IS_FEATURE_ENV else user_id
|
||||
|
||||
try:
|
||||
posthog.capture(
|
||||
distinct_id=posthog_user_id,
|
||||
event='condenser_max_step_set',
|
||||
properties={
|
||||
'conversation_id': conversation_id,
|
||||
'variant': enabled_variant,
|
||||
'original_user_id': user_id,
|
||||
'is_feature_env': IS_FEATURE_ENV,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'experiment_manager:posthog_capture:failed',
|
||||
extra={
|
||||
'convo_id': conversation_id,
|
||||
'experiment': EXPERIMENT_CONDENSER_MAX_STEP,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
# Continue execution as this is not critical
|
||||
|
||||
logger.info(
|
||||
'posthog_capture',
|
||||
extra={
|
||||
'event': 'condenser_max_step_set',
|
||||
'posthog_user_id': posthog_user_id,
|
||||
'is_feature_env': IS_FEATURE_ENV,
|
||||
'conversation_id': conversation_id,
|
||||
'variant': enabled_variant,
|
||||
},
|
||||
)
|
||||
|
||||
return enabled_variant
|
||||
|
||||
|
||||
def handle_condenser_max_step_experiment(
|
||||
user_id: str | None,
|
||||
conversation_id: str,
|
||||
conversation_settings: ConversationInitData,
|
||||
) -> ConversationInitData:
|
||||
"""
|
||||
Handle the condenser max step experiment for conversation settings.
|
||||
|
||||
We should not modify persistent user settings. Instead, apply the experiment
|
||||
variant to the conversation's in-memory settings object for this session only.
|
||||
|
||||
Variants:
|
||||
- control -> condenser_max_size = 120
|
||||
- treatment -> condenser_max_size = 80
|
||||
|
||||
Returns the (potentially) modified conversation_settings.
|
||||
"""
|
||||
|
||||
enabled_variant = _get_condenser_max_step_variant(user_id, conversation_id)
|
||||
|
||||
if enabled_variant is None:
|
||||
return conversation_settings
|
||||
|
||||
if enabled_variant == 'control':
|
||||
condenser_max_size = 120
|
||||
elif enabled_variant == 'treatment':
|
||||
condenser_max_size = 80
|
||||
else:
|
||||
logger.error(
|
||||
'condenser_max_step_experiment:unknown_variant',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'convo_id': conversation_id,
|
||||
'variant': enabled_variant,
|
||||
'reason': 'unknown variant; returning original conversation settings',
|
||||
},
|
||||
)
|
||||
return conversation_settings
|
||||
|
||||
try:
|
||||
# Apply the variant to this conversation only; do not persist to DB.
|
||||
# Not all OpenHands versions expose `condenser_max_size` on settings.
|
||||
if hasattr(conversation_settings, 'condenser_max_size'):
|
||||
conversation_settings.condenser_max_size = condenser_max_size
|
||||
logger.info(
|
||||
'condenser_max_step_experiment:conversation_settings_applied',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'convo_id': conversation_id,
|
||||
'variant': enabled_variant,
|
||||
'condenser_max_size': condenser_max_size,
|
||||
},
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
'condenser_max_step_experiment:field_missing_on_settings',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'convo_id': conversation_id,
|
||||
'variant': enabled_variant,
|
||||
'reason': 'condenser_max_size not present on ConversationInitData',
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'condenser_max_step_experiment:apply_failed',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'convo_id': conversation_id,
|
||||
'variant': enabled_variant,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
return conversation_settings
|
||||
|
||||
return conversation_settings
|
||||
|
||||
|
||||
def handle_condenser_max_step_experiment__v1(
|
||||
user_id: str | None,
|
||||
conversation_id: UUID,
|
||||
agent: Agent,
|
||||
) -> Agent:
|
||||
enabled_variant = _get_condenser_max_step_variant(user_id, str(conversation_id))
|
||||
|
||||
if enabled_variant is None:
|
||||
return agent
|
||||
|
||||
if enabled_variant == 'control':
|
||||
condenser_max_size = 120
|
||||
elif enabled_variant == 'treatment':
|
||||
condenser_max_size = 80
|
||||
else:
|
||||
logger.error(
|
||||
'condenser_max_step_experiment:unknown_variant',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'convo_id': conversation_id,
|
||||
'variant': enabled_variant,
|
||||
'reason': 'unknown variant; returning original conversation settings',
|
||||
},
|
||||
)
|
||||
return agent
|
||||
|
||||
condenser_llm = agent.llm.model_copy(update={'usage_id': 'condenser'})
|
||||
condenser = LLMSummarizingCondenser(
|
||||
llm=condenser_llm, max_size=condenser_max_size, keep_first=4
|
||||
)
|
||||
|
||||
return agent.model_copy(update={'condenser': condenser})
|
||||
@@ -1,25 +0,0 @@
|
||||
"""
|
||||
Experiment versions package.
|
||||
|
||||
This package contains handlers for different experiment versions.
|
||||
"""
|
||||
|
||||
from experiments.experiment_versions._001_litellm_default_model_experiment import (
|
||||
handle_litellm_default_model_experiment,
|
||||
)
|
||||
from experiments.experiment_versions._002_system_prompt_experiment import (
|
||||
handle_system_prompt_experiment,
|
||||
)
|
||||
from experiments.experiment_versions._003_llm_claude4_vs_gpt5_experiment import (
|
||||
handle_claude4_vs_gpt5_experiment,
|
||||
)
|
||||
from experiments.experiment_versions._004_condenser_max_step_experiment import (
|
||||
handle_condenser_max_step_experiment,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
'handle_litellm_default_model_experiment',
|
||||
'handle_system_prompt_experiment',
|
||||
'handle_claude4_vs_gpt5_experiment',
|
||||
'handle_condenser_max_step_experiment',
|
||||
]
|
||||
@@ -0,0 +1,65 @@
|
||||
from pydantic import SecretStr
|
||||
from server.auth.token_manager import TokenManager
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.bitbucket_data_center.bitbucket_dc_service import (
|
||||
BitbucketDCService,
|
||||
)
|
||||
from openhands.integrations.service_types import ProviderType
|
||||
|
||||
|
||||
class SaaSBitbucketDCService(BitbucketDCService):
|
||||
def __init__(
|
||||
self,
|
||||
user_id: str | None = None,
|
||||
external_auth_token: SecretStr | None = None,
|
||||
external_auth_id: str | None = None,
|
||||
token: SecretStr | None = None,
|
||||
external_token_manager: bool = False,
|
||||
base_domain: str | None = None,
|
||||
):
|
||||
logger.debug(
|
||||
f'SaaSBitbucketDCService created with user_id {user_id}, external_auth_id {external_auth_id}, external_auth_token {'set' if external_auth_token else 'None'}, token {'set' if token else 'None'}, external_token_manager {external_token_manager}'
|
||||
)
|
||||
super().__init__(
|
||||
user_id=user_id,
|
||||
external_auth_token=external_auth_token,
|
||||
external_auth_id=external_auth_id,
|
||||
token=token,
|
||||
external_token_manager=external_token_manager,
|
||||
base_domain=base_domain,
|
||||
)
|
||||
|
||||
self.token_manager = TokenManager(external=external_token_manager)
|
||||
self.refresh = True
|
||||
|
||||
async def get_latest_token(self) -> SecretStr | None:
|
||||
bitbucket_dc_token = None
|
||||
if self.external_auth_token:
|
||||
bitbucket_dc_token = SecretStr(
|
||||
await self.token_manager.get_idp_token(
|
||||
self.external_auth_token.get_secret_value(),
|
||||
idp=ProviderType.BITBUCKET_DATA_CENTER,
|
||||
)
|
||||
)
|
||||
logger.debug('Got Bitbucket DC token via external_auth_token')
|
||||
elif self.external_auth_id:
|
||||
offline_token = await self.token_manager.load_offline_token(
|
||||
self.external_auth_id
|
||||
)
|
||||
bitbucket_dc_token = SecretStr(
|
||||
await self.token_manager.get_idp_token_from_offline_token(
|
||||
offline_token, ProviderType.BITBUCKET_DATA_CENTER
|
||||
)
|
||||
)
|
||||
logger.debug('Got Bitbucket DC token via external_auth_id')
|
||||
elif self.user_id:
|
||||
bitbucket_dc_token = SecretStr(
|
||||
await self.token_manager.get_idp_token_from_idp_user_id(
|
||||
self.user_id, ProviderType.BITBUCKET_DATA_CENTER
|
||||
)
|
||||
)
|
||||
logger.debug('Got Bitbucket DC token via user_id')
|
||||
else:
|
||||
logger.warning('external_auth_token and user_id not set!')
|
||||
return bitbucket_dc_token
|
||||
@@ -116,10 +116,8 @@ class GitHubDataCollector:
|
||||
|
||||
return suffix
|
||||
|
||||
def _get_installation_access_token(self, installation_id: str) -> str:
|
||||
token_data = self.github_integration.get_access_token(
|
||||
installation_id # type: ignore[arg-type]
|
||||
)
|
||||
def _get_installation_access_token(self, installation_id: int) -> str:
|
||||
token_data = self.github_integration.get_access_token(installation_id)
|
||||
return token_data.token
|
||||
|
||||
def _check_openhands_author(self, name, login) -> bool:
|
||||
@@ -134,7 +132,7 @@ class GitHubDataCollector:
|
||||
)
|
||||
|
||||
def _get_issue_comments(
|
||||
self, installation_id: str, repo_name: str, issue_number: int, conversation_id
|
||||
self, installation_id: int, repo_name: str, issue_number: int, conversation_id
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Retrieve all comments from an issue until a comment with conversation_id is found
|
||||
@@ -234,7 +232,7 @@ class GitHubDataCollector:
|
||||
f'[Github]: Saved issue #{issue_number} for {github_view.full_repo_name}'
|
||||
)
|
||||
|
||||
def _get_pr_commits(self, installation_id: str, repo_name: str, pr_number: int):
|
||||
def _get_pr_commits(self, installation_id: int, repo_name: str, pr_number: int):
|
||||
commits = []
|
||||
installation_token = self._get_installation_access_token(installation_id)
|
||||
with Github(auth=Auth.Token(installation_token)) as github_client:
|
||||
@@ -431,7 +429,7 @@ class GitHubDataCollector:
|
||||
- Num openhands review comments
|
||||
"""
|
||||
pr_number = openhands_pr.pr_number
|
||||
installation_id = openhands_pr.installation_id
|
||||
installation_id = int(openhands_pr.installation_id)
|
||||
repo_id = openhands_pr.repo_id
|
||||
|
||||
# Get installation token and create Github client
|
||||
@@ -569,7 +567,7 @@ class GitHubDataCollector:
|
||||
openhands_helped_author = openhands_commit_count > 0
|
||||
|
||||
# Update the PR with OpenHands statistics
|
||||
update_success = store.update_pr_openhands_stats(
|
||||
update_success = await store.update_pr_openhands_stats(
|
||||
repo_id=repo_id,
|
||||
pr_number=pr_number,
|
||||
original_updated_at=openhands_pr.updated_at,
|
||||
@@ -612,7 +610,7 @@ class GitHubDataCollector:
|
||||
action = payload.get('action', '')
|
||||
return action == 'closed' and 'pull_request' in payload
|
||||
|
||||
def _track_closed_or_merged_pr(self, payload):
|
||||
async def _track_closed_or_merged_pr(self, payload):
|
||||
"""
|
||||
Track PR closed/merged event
|
||||
"""
|
||||
@@ -671,17 +669,17 @@ class GitHubDataCollector:
|
||||
num_general_comments=num_general_comments,
|
||||
)
|
||||
|
||||
store.insert_pr(pr)
|
||||
await store.insert_pr(pr)
|
||||
logger.info(f'Tracked PR {status}: {repo_id}#{pr_number}')
|
||||
|
||||
def process_payload(self, message: Message):
|
||||
async def process_payload(self, message: Message):
|
||||
if not COLLECT_GITHUB_INTERACTIONS:
|
||||
return
|
||||
|
||||
raw_payload = message.message.get('payload', {})
|
||||
|
||||
if self._is_pr_closed_or_merged(raw_payload):
|
||||
self._track_closed_or_merged_pr(raw_payload)
|
||||
await self._track_closed_or_merged_pr(raw_payload)
|
||||
|
||||
async def save_data(self, github_view: ResolverViewInterface):
|
||||
if not COLLECT_GITHUB_INTERACTIONS:
|
||||
|
||||
@@ -10,6 +10,7 @@ from integrations.github.github_view import (
|
||||
GithubIssue,
|
||||
GithubIssueComment,
|
||||
GithubPRComment,
|
||||
GithubViewType,
|
||||
)
|
||||
from integrations.manager import Manager
|
||||
from integrations.models import (
|
||||
@@ -19,9 +20,11 @@ from integrations.models import (
|
||||
from integrations.types import ResolverViewInterface
|
||||
from integrations.utils import (
|
||||
CONVERSATION_URL,
|
||||
ENABLE_SOLVABILITY_ANALYSIS,
|
||||
HOST_URL,
|
||||
OPENHANDS_RESOLVER_TEMPLATES_DIR,
|
||||
get_session_expired_message,
|
||||
get_user_not_found_message,
|
||||
)
|
||||
from integrations.v1_utils import get_saas_user_auth
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
@@ -40,10 +43,9 @@ from openhands.server.types import (
|
||||
SessionExpiredError,
|
||||
)
|
||||
from openhands.storage.data_models.secrets import Secrets
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
|
||||
|
||||
class GithubManager(Manager):
|
||||
class GithubManager(Manager[GithubViewType]):
|
||||
def __init__(
|
||||
self, token_manager: TokenManager, data_collector: GitHubDataCollector
|
||||
):
|
||||
@@ -67,11 +69,8 @@ class GithubManager(Manager):
|
||||
|
||||
return f'{owner}/{repo_name}'
|
||||
|
||||
def _get_installation_access_token(self, installation_id: str) -> str:
|
||||
# get_access_token is typed to only accept int, but it can handle str.
|
||||
token_data = self.github_integration.get_access_token(
|
||||
installation_id # type: ignore[arg-type]
|
||||
)
|
||||
def _get_installation_access_token(self, installation_id: int) -> str:
|
||||
token_data = self.github_integration.get_access_token(installation_id)
|
||||
return token_data.token
|
||||
|
||||
def _add_reaction(
|
||||
@@ -126,6 +125,76 @@ class GithubManager(Manager):
|
||||
|
||||
return False
|
||||
|
||||
def _get_issue_number_from_payload(self, message: Message) -> int | None:
|
||||
"""Extract issue/PR number from a GitHub webhook payload.
|
||||
|
||||
Supports all event types that can trigger jobs:
|
||||
- Labeled issues: payload['issue']['number']
|
||||
- Issue comments: payload['issue']['number']
|
||||
- PR comments: payload['issue']['number'] (PRs are accessed via issue endpoint)
|
||||
- Inline PR comments: payload['pull_request']['number']
|
||||
|
||||
Args:
|
||||
message: The incoming GitHub webhook message
|
||||
|
||||
Returns:
|
||||
The issue/PR number, or None if not found
|
||||
"""
|
||||
payload = message.message.get('payload', {})
|
||||
|
||||
# Labeled issues, issue comments, and PR comments all have 'issue' in payload
|
||||
if 'issue' in payload:
|
||||
return payload['issue']['number']
|
||||
|
||||
# Inline PR comments have 'pull_request' directly in payload
|
||||
if 'pull_request' in payload:
|
||||
return payload['pull_request']['number']
|
||||
|
||||
return None
|
||||
|
||||
def _send_user_not_found_message(self, message: Message, username: str):
|
||||
"""Send a message to the user informing them they need to create an OpenHands account.
|
||||
|
||||
This method handles all supported trigger types:
|
||||
- Labeled issues (action='labeled' with openhands label)
|
||||
- Issue comments (comment containing @openhands)
|
||||
- PR comments (comment containing @openhands on a PR)
|
||||
- Inline PR review comments (comment containing @openhands)
|
||||
|
||||
Args:
|
||||
message: The incoming GitHub webhook message
|
||||
username: The GitHub username to mention in the response
|
||||
"""
|
||||
payload = message.message.get('payload', {})
|
||||
installation_id = message.message['installation']
|
||||
repo_obj = payload['repository']
|
||||
full_repo_name = self._get_full_repo_name(repo_obj)
|
||||
|
||||
# Get installation token to post the comment
|
||||
installation_token = self._get_installation_access_token(installation_id)
|
||||
|
||||
# Determine the issue/PR number based on the event type
|
||||
issue_number = self._get_issue_number_from_payload(message)
|
||||
|
||||
if not issue_number:
|
||||
logger.warning(
|
||||
f'[GitHub] Could not determine issue/PR number to send user not found message for {username}. '
|
||||
f'Payload keys: {list(payload.keys())}'
|
||||
)
|
||||
return
|
||||
|
||||
# Post the comment
|
||||
try:
|
||||
with Github(auth=Auth.Token(installation_token)) as github_client:
|
||||
repo = github_client.get_repo(full_repo_name)
|
||||
issue = repo.get_issue(number=issue_number)
|
||||
issue.create_comment(get_user_not_found_message(username))
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f'[GitHub] Failed to send user not found message to {username} '
|
||||
f'on {full_repo_name}#{issue_number}: {e}'
|
||||
)
|
||||
|
||||
async def is_job_requested(self, message: Message) -> bool:
|
||||
self._confirm_incoming_source_type(message)
|
||||
|
||||
@@ -145,11 +214,7 @@ class GithubManager(Manager):
|
||||
).get('body', ''):
|
||||
return False
|
||||
|
||||
if GithubFactory.is_eligible_for_conversation_starter(
|
||||
message
|
||||
) and self._user_has_write_access_to_repo(installation_id, repo_name, username):
|
||||
await GithubFactory.trigger_conversation_starter(message)
|
||||
|
||||
# Check event types before making expensive API calls (e.g., _user_has_write_access_to_repo)
|
||||
if not (
|
||||
GithubFactory.is_labeled_issue(message)
|
||||
or GithubFactory.is_issue_comment(message)
|
||||
@@ -159,13 +224,22 @@ class GithubManager(Manager):
|
||||
return False
|
||||
|
||||
logger.info(f'[GitHub] Checking permissions for {username} in {repo_name}')
|
||||
user_has_write_access = self._user_has_write_access_to_repo(
|
||||
installation_id, repo_name, username
|
||||
)
|
||||
|
||||
return self._user_has_write_access_to_repo(installation_id, repo_name, username)
|
||||
if (
|
||||
GithubFactory.is_eligible_for_conversation_starter(message)
|
||||
and user_has_write_access
|
||||
):
|
||||
await GithubFactory.trigger_conversation_starter(message)
|
||||
|
||||
return user_has_write_access
|
||||
|
||||
async def receive_message(self, message: Message):
|
||||
self._confirm_incoming_source_type(message)
|
||||
try:
|
||||
await call_sync_from_async(self.data_collector.process_payload, message)
|
||||
await self.data_collector.process_payload(message)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
'[Github]: Error processing payload for gh interaction', exc_info=True
|
||||
@@ -174,9 +248,20 @@ class GithubManager(Manager):
|
||||
if await self.is_job_requested(message):
|
||||
payload = message.message.get('payload', {})
|
||||
user_id = payload['sender']['id']
|
||||
username = payload['sender']['login']
|
||||
keycloak_user_id = await self.token_manager.get_user_id_from_idp_user_id(
|
||||
user_id, ProviderType.GITHUB
|
||||
)
|
||||
|
||||
# Check if the user has an OpenHands account
|
||||
if not keycloak_user_id:
|
||||
logger.warning(
|
||||
f'[GitHub] User {username} (id={user_id}) not found in Keycloak. '
|
||||
f'User must create an OpenHands account first.'
|
||||
)
|
||||
self._send_user_not_found_message(message, username)
|
||||
return
|
||||
|
||||
github_view = await GithubFactory.create_github_view_from_payload(
|
||||
message, keycloak_user_id
|
||||
)
|
||||
@@ -188,46 +273,51 @@ class GithubManager(Manager):
|
||||
github_view.installation_id
|
||||
)
|
||||
# Store the installation token
|
||||
self.token_manager.store_org_token(
|
||||
await self.token_manager.store_org_token(
|
||||
github_view.installation_id, installation_token
|
||||
)
|
||||
# Add eyes reaction to acknowledge we've read the request
|
||||
self._add_reaction(github_view, 'eyes', installation_token)
|
||||
await self.start_job(github_view)
|
||||
|
||||
async def send_message(self, message: Message, github_view: ResolverViewInterface):
|
||||
installation_token = self.token_manager.load_org_token(
|
||||
async def send_message(self, message: str, github_view: GithubViewType):
|
||||
"""Send a message to GitHub.
|
||||
|
||||
Args:
|
||||
message: The message content to send (plain text string)
|
||||
github_view: The GitHub view object containing issue/PR/comment info
|
||||
"""
|
||||
installation_token = await self.token_manager.load_org_token(
|
||||
github_view.installation_id
|
||||
)
|
||||
if not installation_token:
|
||||
logger.warning('Missing installation token')
|
||||
return
|
||||
|
||||
outgoing_message = message.message
|
||||
|
||||
if isinstance(github_view, GithubInlinePRComment):
|
||||
with Github(auth=Auth.Token(installation_token)) as github_client:
|
||||
repo = github_client.get_repo(github_view.full_repo_name)
|
||||
pr = repo.get_pull(github_view.issue_number)
|
||||
pr.create_review_comment_reply(
|
||||
comment_id=github_view.comment_id, body=outgoing_message
|
||||
comment_id=github_view.comment_id, body=message
|
||||
)
|
||||
|
||||
elif (
|
||||
isinstance(github_view, GithubPRComment)
|
||||
or isinstance(github_view, GithubIssueComment)
|
||||
or isinstance(github_view, GithubIssue)
|
||||
elif isinstance(
|
||||
github_view, (GithubPRComment, GithubIssueComment, GithubIssue)
|
||||
):
|
||||
with Github(auth=Auth.Token(installation_token)) as github_client:
|
||||
repo = github_client.get_repo(github_view.full_repo_name)
|
||||
issue = repo.get_issue(number=github_view.issue_number)
|
||||
issue.create_comment(outgoing_message)
|
||||
issue.create_comment(message)
|
||||
|
||||
else:
|
||||
logger.warning('Unsupported location')
|
||||
# Catch any new types added to GithubViewType that aren't handled above
|
||||
logger.warning( # type: ignore[unreachable]
|
||||
f'Unsupported github_view type: {type(github_view).__name__}'
|
||||
)
|
||||
return
|
||||
|
||||
async def start_job(self, github_view: ResolverViewInterface):
|
||||
async def start_job(self, github_view: GithubViewType) -> None:
|
||||
"""Kick off a job with openhands agent.
|
||||
|
||||
1. Get user credential
|
||||
@@ -240,7 +330,7 @@ class GithubManager(Manager):
|
||||
)
|
||||
|
||||
try:
|
||||
msg_info = None
|
||||
msg_info: str = ''
|
||||
|
||||
try:
|
||||
user_info = github_view.user_info
|
||||
@@ -281,19 +371,19 @@ class GithubManager(Manager):
|
||||
# 3. Once the conversation is started, its base cost will include the report's spend as well which allows us to control max budget per resolver task
|
||||
convo_metadata = await github_view.initialize_new_conversation()
|
||||
solvability_summary = None
|
||||
try:
|
||||
if user_token:
|
||||
if not ENABLE_SOLVABILITY_ANALYSIS:
|
||||
logger.info(
|
||||
'[Github]: Solvability report feature is disabled, skipping'
|
||||
)
|
||||
else:
|
||||
try:
|
||||
solvability_summary = await summarize_issue_solvability(
|
||||
github_view, user_token
|
||||
)
|
||||
else:
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
'[Github]: No user token available for solvability analysis'
|
||||
f'[Github]: Error summarizing issue solvability: {str(e)}'
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f'[Github]: Error summarizing issue solvability: {str(e)}'
|
||||
)
|
||||
|
||||
saas_user_auth = await get_saas_user_auth(
|
||||
github_view.user_info.keycloak_user_id, self.token_manager
|
||||
@@ -356,15 +446,13 @@ class GithubManager(Manager):
|
||||
|
||||
msg_info = get_session_expired_message(user_info.username)
|
||||
|
||||
msg = self.create_outgoing_message(msg_info)
|
||||
await self.send_message(msg, github_view)
|
||||
await self.send_message(msg_info, github_view)
|
||||
|
||||
except Exception:
|
||||
logger.exception('[Github]: Error starting job')
|
||||
msg = self.create_outgoing_message(
|
||||
msg='Uh oh! There was an unexpected error starting the job :('
|
||||
await self.send_message(
|
||||
'Uh oh! There was an unexpected error starting the job :(', github_view
|
||||
)
|
||||
await self.send_message(msg, github_view)
|
||||
|
||||
try:
|
||||
await self.data_collector.save_data(github_view)
|
||||
|
||||
@@ -122,13 +122,37 @@ class SaaSGitHubService(GitHubService):
|
||||
raise Exception(f'No node_id found for repository {repo_id}')
|
||||
return node_id
|
||||
|
||||
async def _get_external_auth_id(self) -> str | None:
|
||||
"""Get or fetch external_auth_id from Keycloak token if not already set."""
|
||||
if self.external_auth_id:
|
||||
return self.external_auth_id
|
||||
|
||||
if self.external_auth_token:
|
||||
try:
|
||||
user_info = await self.token_manager.get_user_info(
|
||||
self.external_auth_token.get_secret_value()
|
||||
)
|
||||
self.external_auth_id = user_info.sub
|
||||
logger.info(
|
||||
f'Determined external_auth_id from Keycloak token: {self.external_auth_id}'
|
||||
)
|
||||
return self.external_auth_id
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f'Could not determine external_auth_id from token: {e}',
|
||||
exc_info=True,
|
||||
)
|
||||
return None
|
||||
|
||||
async def get_paginated_repos(self, page, per_page, sort, installation_id):
|
||||
repositories = await super().get_paginated_repos(
|
||||
page, per_page, sort, installation_id
|
||||
)
|
||||
asyncio.create_task(
|
||||
store_repositories_in_db(repositories, self.external_auth_id)
|
||||
)
|
||||
external_auth_id = await self._get_external_auth_id()
|
||||
if external_auth_id:
|
||||
asyncio.create_task(
|
||||
store_repositories_in_db(repositories, external_auth_id)
|
||||
)
|
||||
return repositories
|
||||
|
||||
async def get_all_repositories(
|
||||
@@ -136,8 +160,10 @@ class SaaSGitHubService(GitHubService):
|
||||
) -> list[Repository]:
|
||||
repositories = await super().get_all_repositories(sort, app_mode)
|
||||
# Schedule the background task without awaiting it
|
||||
asyncio.create_task(
|
||||
store_repositories_in_db(repositories, self.external_auth_id)
|
||||
)
|
||||
external_auth_id = await self._get_external_auth_id()
|
||||
if external_auth_id:
|
||||
asyncio.create_task(
|
||||
store_repositories_in_db(repositories, external_auth_id)
|
||||
)
|
||||
# Return repositories immediately
|
||||
return repositories
|
||||
|
||||
@@ -14,7 +14,6 @@ from integrations.solvability.models.summary import SolvabilitySummary
|
||||
from integrations.utils import ENABLE_SOLVABILITY_ANALYSIS
|
||||
from pydantic import ValidationError
|
||||
from server.config import get_config
|
||||
from storage.database import session_maker
|
||||
from storage.saas_settings_store import SaasSettingsStore
|
||||
|
||||
from openhands.core.config import LLMConfig
|
||||
@@ -90,7 +89,6 @@ async def summarize_issue_solvability(
|
||||
# Grab the user's information so we can load their LLM configuration
|
||||
store = SaasSettingsStore(
|
||||
user_id=github_view.user_info.keycloak_user_id,
|
||||
session_maker=session_maker,
|
||||
config=get_config(),
|
||||
)
|
||||
|
||||
@@ -108,6 +106,11 @@ async def summarize_issue_solvability(
|
||||
f'Solvability analysis disabled for user {github_view.user_info.user_id}'
|
||||
)
|
||||
|
||||
if user_settings.llm_api_key is None:
|
||||
raise ValueError(
|
||||
f'[Solvability] No LLM API key found for user {github_view.user_info.user_id}'
|
||||
)
|
||||
|
||||
try:
|
||||
llm_config = LLMConfig(
|
||||
model=user_settings.llm_model,
|
||||
|
||||
@@ -3,8 +3,9 @@ from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
import httpx
|
||||
from github import Auth, Github, GithubIntegration
|
||||
from integrations.utils import CONVERSATION_URL, get_summary_instruction
|
||||
from github import Auth, Github, GithubException, GithubIntegration
|
||||
from integrations.utils import get_summary_instruction
|
||||
from integrations.v1_utils import handle_callback_error
|
||||
from pydantic import Field
|
||||
from server.auth.constants import GITHUB_APP_CLIENT_ID, GITHUB_APP_PRIVATE_KEY
|
||||
|
||||
@@ -42,7 +43,6 @@ class GithubV1CallbackProcessor(EventCallbackProcessor):
|
||||
event: Event,
|
||||
) -> EventCallbackResult | None:
|
||||
"""Process events for GitHub V1 integration."""
|
||||
|
||||
# Only handle ConversationStateUpdateEvent
|
||||
if not isinstance(event, ConversationStateUpdateEvent):
|
||||
return None
|
||||
@@ -78,25 +78,20 @@ class GithubV1CallbackProcessor(EventCallbackProcessor):
|
||||
detail=summary,
|
||||
)
|
||||
except Exception as e:
|
||||
_logger.exception('[GitHub V1] Error processing callback: %s', e)
|
||||
|
||||
# Only try to post error to GitHub if we have basic requirements
|
||||
try:
|
||||
# Check if we have installation ID and credentials before posting
|
||||
if (
|
||||
self.github_view_data.get('installation_id')
|
||||
and GITHUB_APP_CLIENT_ID
|
||||
and GITHUB_APP_PRIVATE_KEY
|
||||
):
|
||||
await self._post_summary_to_github(
|
||||
f'OpenHands encountered an error: **{str(e)}**.\n\n'
|
||||
f'[See the conversation]({CONVERSATION_URL.format(conversation_id)})'
|
||||
'for more information.'
|
||||
)
|
||||
except Exception as post_error:
|
||||
_logger.warning(
|
||||
'[GitHub V1] Failed to post error message to GitHub: %s', post_error
|
||||
)
|
||||
# Check if we have installation ID and credentials before posting
|
||||
can_post_error = bool(
|
||||
self.github_view_data.get('installation_id')
|
||||
and GITHUB_APP_CLIENT_ID
|
||||
and GITHUB_APP_PRIVATE_KEY
|
||||
)
|
||||
await handle_callback_error(
|
||||
error=e,
|
||||
conversation_id=conversation_id,
|
||||
service_name='GitHub',
|
||||
service_logger=_logger,
|
||||
can_post_error=can_post_error,
|
||||
post_error_func=self._post_summary_to_github,
|
||||
)
|
||||
|
||||
return EventCallbackResult(
|
||||
status=EventCallbackResultStatus.ERROR,
|
||||
@@ -137,19 +132,30 @@ class GithubV1CallbackProcessor(EventCallbackProcessor):
|
||||
full_repo_name = self.github_view_data['full_repo_name']
|
||||
issue_number = self.github_view_data['issue_number']
|
||||
|
||||
if self.inline_pr_comment:
|
||||
try:
|
||||
if self.inline_pr_comment:
|
||||
with Github(auth=Auth.Token(installation_token)) as github_client:
|
||||
repo = github_client.get_repo(full_repo_name)
|
||||
pr = repo.get_pull(issue_number)
|
||||
pr.create_review_comment_reply(
|
||||
comment_id=self.github_view_data.get('comment_id', ''),
|
||||
body=summary,
|
||||
)
|
||||
return
|
||||
|
||||
with Github(auth=Auth.Token(installation_token)) as github_client:
|
||||
repo = github_client.get_repo(full_repo_name)
|
||||
pr = repo.get_pull(issue_number)
|
||||
pr.create_review_comment_reply(
|
||||
comment_id=self.github_view_data.get('comment_id', ''), body=summary
|
||||
issue = repo.get_issue(number=issue_number)
|
||||
issue.create_comment(summary)
|
||||
except GithubException as e:
|
||||
if e.status == 410:
|
||||
_logger.info(
|
||||
'[GitHub V1] Issue/PR %s#%s was deleted, skipping summary post',
|
||||
full_repo_name,
|
||||
issue_number,
|
||||
)
|
||||
return
|
||||
|
||||
with Github(auth=Auth.Token(installation_token)) as github_client:
|
||||
repo = github_client.get_repo(full_repo_name)
|
||||
issue = repo.get_issue(number=issue_number)
|
||||
issue.create_comment(summary)
|
||||
else:
|
||||
raise
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Agent / sandbox helpers
|
||||
@@ -167,8 +173,8 @@ class GithubV1CallbackProcessor(EventCallbackProcessor):
|
||||
send_message_request = AskAgentRequest(question=message_content)
|
||||
|
||||
url = (
|
||||
f'{agent_server_url.rstrip("/")}'
|
||||
f'/api/conversations/{conversation_id}/ask_agent'
|
||||
f"{agent_server_url.rstrip('/')}"
|
||||
f"/api/conversations/{conversation_id}/ask_agent"
|
||||
)
|
||||
headers = {'X-Session-API-Key': session_api_key}
|
||||
payload = send_message_request.model_dump()
|
||||
@@ -230,8 +236,7 @@ class GithubV1CallbackProcessor(EventCallbackProcessor):
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
async def _request_summary(self, conversation_id: UUID) -> str:
|
||||
"""
|
||||
Ask the agent to produce a summary of its work and return the agent response.
|
||||
"""Ask the agent to produce a summary of its work and return the agent response.
|
||||
|
||||
NOTE: This method now returns a string (the agent server's response text)
|
||||
and raises exceptions on errors. The wrapping into EventCallbackResult
|
||||
|
||||
@@ -24,7 +24,6 @@ from jinja2 import Environment
|
||||
from server.auth.constants import GITHUB_APP_CLIENT_ID, GITHUB_APP_PRIVATE_KEY
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.config import get_config
|
||||
from storage.database import session_maker
|
||||
from storage.org_store import OrgStore
|
||||
from storage.proactive_conversation_store import ProactiveConversationStore
|
||||
from storage.saas_secrets_store import SaasSecretsStore
|
||||
@@ -73,7 +72,6 @@ async def get_user_proactive_conversation_setting(user_id: str | None) -> bool:
|
||||
This function checks both the global environment variable kill switch AND
|
||||
the user's individual setting. Both must be true for the function to return true.
|
||||
"""
|
||||
|
||||
# If no user ID is provided, we can't check user settings
|
||||
if not user_id:
|
||||
return False
|
||||
@@ -82,13 +80,10 @@ async def get_user_proactive_conversation_setting(user_id: str | None) -> bool:
|
||||
if not ENABLE_PROACTIVE_CONVERSATION_STARTERS:
|
||||
return False
|
||||
|
||||
def _get_setting():
|
||||
org = OrgStore.get_current_org_from_keycloak_user_id(user_id)
|
||||
if not org:
|
||||
return False
|
||||
return bool(org.enable_proactive_conversation_starters)
|
||||
|
||||
return await call_sync_from_async(_get_setting)
|
||||
org = await OrgStore.get_current_org_from_keycloak_user_id(user_id)
|
||||
if not org:
|
||||
return False
|
||||
return bool(org.enable_proactive_conversation_starters)
|
||||
|
||||
|
||||
# =================================================
|
||||
@@ -153,9 +148,7 @@ class GithubIssue(ResolverViewInterface):
|
||||
return user_instructions, conversation_instructions
|
||||
|
||||
async def _get_user_secrets(self):
|
||||
secrets_store = SaasSecretsStore(
|
||||
self.user_info.keycloak_user_id, session_maker, get_config()
|
||||
)
|
||||
secrets_store = SaasSecretsStore(self.user_info.keycloak_user_id, get_config())
|
||||
user_secrets = await secrets_store.load()
|
||||
|
||||
return user_secrets.custom_secrets if user_secrets else None
|
||||
@@ -238,6 +231,29 @@ class GithubIssue(ResolverViewInterface):
|
||||
conversation_instructions=conversation_instructions,
|
||||
)
|
||||
|
||||
async def _get_v1_initial_user_message(self, jinja_env: Environment) -> str:
|
||||
"""Build the initial user message for V1 resolver conversations.
|
||||
|
||||
For "issue opened" events (no specific comment body), we can simply
|
||||
concatenate the user prompt and the rendered issue context.
|
||||
|
||||
Subclasses that represent comment-driven events (issue comments, PR review
|
||||
comments, inline review comments) override this method to control ordering
|
||||
(e.g., context first, then the triggering comment, then previous comments).
|
||||
"""
|
||||
|
||||
user_instructions, conversation_instructions = await self._get_instructions(
|
||||
jinja_env
|
||||
)
|
||||
|
||||
parts: list[str] = []
|
||||
if user_instructions.strip():
|
||||
parts.append(user_instructions.strip())
|
||||
if conversation_instructions.strip():
|
||||
parts.append(conversation_instructions.strip())
|
||||
|
||||
return '\n\n'.join(parts)
|
||||
|
||||
async def _create_v1_conversation(
|
||||
self,
|
||||
jinja_env: Environment,
|
||||
@@ -247,13 +263,11 @@ class GithubIssue(ResolverViewInterface):
|
||||
"""Create conversation using the new V1 app conversation system."""
|
||||
logger.info('[GitHub V1]: Creating V1 conversation')
|
||||
|
||||
user_instructions, conversation_instructions = await self._get_instructions(
|
||||
jinja_env
|
||||
)
|
||||
initial_user_text = await self._get_v1_initial_user_message(jinja_env)
|
||||
|
||||
# Create the initial message request
|
||||
initial_message = SendMessageRequest(
|
||||
role='user', content=[TextContent(text=user_instructions)]
|
||||
role='user', content=[TextContent(text=initial_user_text)]
|
||||
)
|
||||
|
||||
# Create the GitHub V1 callback processor
|
||||
@@ -265,7 +279,9 @@ class GithubIssue(ResolverViewInterface):
|
||||
# Create the V1 conversation start request with the callback processor
|
||||
start_request = AppConversationStartRequest(
|
||||
conversation_id=UUID(conversation_metadata.conversation_id),
|
||||
system_message_suffix=conversation_instructions,
|
||||
# NOTE: Resolver instructions are intended to be lower priority than the
|
||||
# system prompt, so we inject them into the initial user message.
|
||||
system_message_suffix=None,
|
||||
initial_message=initial_message,
|
||||
selected_repository=self.full_repo_name,
|
||||
selected_branch=self._get_branch_name(),
|
||||
@@ -336,6 +352,17 @@ class GithubIssueComment(GithubIssue):
|
||||
|
||||
return user_instructions, conversation_instructions
|
||||
|
||||
async def _get_v1_initial_user_message(self, jinja_env: Environment) -> str:
|
||||
await self._load_resolver_context()
|
||||
template = jinja_env.get_template('issue_comment_initial_message.j2')
|
||||
return template.render(
|
||||
issue_number=self.issue_number,
|
||||
issue_title=self.title,
|
||||
issue_body=self.description,
|
||||
issue_comment=self.comment_body,
|
||||
previous_comments=self.previous_comments,
|
||||
).strip()
|
||||
|
||||
|
||||
@dataclass
|
||||
class GithubPRComment(GithubIssueComment):
|
||||
@@ -362,6 +389,18 @@ class GithubPRComment(GithubIssueComment):
|
||||
|
||||
return user_instructions, conversation_instructions
|
||||
|
||||
async def _get_v1_initial_user_message(self, jinja_env: Environment) -> str:
|
||||
await self._load_resolver_context()
|
||||
template = jinja_env.get_template('pr_update_initial_message.j2')
|
||||
return template.render(
|
||||
pr_number=self.issue_number,
|
||||
branch_name=self.branch_name,
|
||||
pr_title=self.title,
|
||||
pr_body=self.description,
|
||||
pr_comment=self.comment_body,
|
||||
comments=self.previous_comments,
|
||||
).strip()
|
||||
|
||||
|
||||
@dataclass
|
||||
class GithubInlinePRComment(GithubPRComment):
|
||||
@@ -408,6 +447,20 @@ class GithubInlinePRComment(GithubPRComment):
|
||||
|
||||
return user_instructions, conversation_instructions
|
||||
|
||||
async def _get_v1_initial_user_message(self, jinja_env: Environment) -> str:
|
||||
await self._load_resolver_context()
|
||||
template = jinja_env.get_template('pr_update_initial_message.j2')
|
||||
return template.render(
|
||||
pr_number=self.issue_number,
|
||||
branch_name=self.branch_name,
|
||||
pr_title=self.title,
|
||||
pr_body=self.description,
|
||||
file_location=self.file_location,
|
||||
line_number=self.line_number,
|
||||
pr_comment=self.comment_body,
|
||||
comments=self.previous_comments,
|
||||
).strip()
|
||||
|
||||
def _create_github_v1_callback_processor(self):
|
||||
"""Create a V1 callback processor for GitHub integration."""
|
||||
from integrations.github.github_v1_callback_processor import (
|
||||
@@ -740,7 +793,7 @@ class GithubFactory:
|
||||
@staticmethod
|
||||
async def create_github_view_from_payload(
|
||||
message: Message, keycloak_user_id: str
|
||||
) -> ResolverViewInterface:
|
||||
) -> GithubViewType:
|
||||
"""Create the appropriate class (GithubIssue or GithubPRComment) based on the payload.
|
||||
Also return metadata about the event (e.g., action type).
|
||||
"""
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import MappingProxyType
|
||||
from typing import cast
|
||||
|
||||
from integrations.gitlab.gitlab_view import (
|
||||
GitlabFactory,
|
||||
@@ -17,6 +20,7 @@ from integrations.utils import (
|
||||
OPENHANDS_RESOLVER_TEMPLATES_DIR,
|
||||
get_session_expired_message,
|
||||
)
|
||||
from integrations.v1_utils import get_saas_user_auth
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
from pydantic import SecretStr
|
||||
from server.auth.token_manager import TokenManager
|
||||
@@ -33,7 +37,7 @@ from openhands.server.types import (
|
||||
from openhands.storage.data_models.secrets import Secrets
|
||||
|
||||
|
||||
class GitlabManager(Manager):
|
||||
class GitlabManager(Manager[GitlabViewType]):
|
||||
def __init__(self, token_manager: TokenManager, data_collector: None = None):
|
||||
self.token_manager = token_manager
|
||||
|
||||
@@ -67,11 +71,11 @@ class GitlabManager(Manager):
|
||||
logger.warning(f'Got invalid keyloak user id for GitLab User {user_id}')
|
||||
return False
|
||||
|
||||
# Importing here prevents circular import
|
||||
# GitLabServiceImpl returns SaaSGitLabService in enterprise context
|
||||
from integrations.gitlab.gitlab_service import SaaSGitLabService
|
||||
|
||||
gitlab_service: SaaSGitLabService = GitLabServiceImpl(
|
||||
external_auth_id=keycloak_user_id
|
||||
gitlab_service = cast(
|
||||
SaaSGitLabService, GitLabServiceImpl(external_auth_id=keycloak_user_id)
|
||||
)
|
||||
|
||||
return await gitlab_service.user_has_write_access(project_id)
|
||||
@@ -121,55 +125,52 @@ class GitlabManager(Manager):
|
||||
# Check if the user has write access to the repository
|
||||
return has_write_access
|
||||
|
||||
async def send_message(self, message: Message, gitlab_view: ResolverViewInterface):
|
||||
"""
|
||||
Send a message to GitLab based on the view type.
|
||||
async def send_message(self, message: str, gitlab_view: ResolverViewInterface):
|
||||
"""Send a message to GitLab based on the view type.
|
||||
|
||||
Args:
|
||||
message: The message to send
|
||||
message: The message content to send (plain text string)
|
||||
gitlab_view: The GitLab view object containing issue/PR/comment info
|
||||
"""
|
||||
keycloak_user_id = gitlab_view.user_info.keycloak_user_id
|
||||
|
||||
# Importing here prevents circular import
|
||||
# GitLabServiceImpl returns SaaSGitLabService in enterprise context
|
||||
from integrations.gitlab.gitlab_service import SaaSGitLabService
|
||||
|
||||
gitlab_service: SaaSGitLabService = GitLabServiceImpl(
|
||||
external_auth_id=keycloak_user_id
|
||||
gitlab_service = cast(
|
||||
SaaSGitLabService, GitLabServiceImpl(external_auth_id=keycloak_user_id)
|
||||
)
|
||||
|
||||
outgoing_message = message.message
|
||||
|
||||
if isinstance(gitlab_view, GitlabInlineMRComment) or isinstance(
|
||||
gitlab_view, GitlabMRComment
|
||||
):
|
||||
await gitlab_service.reply_to_mr(
|
||||
gitlab_view.project_id,
|
||||
gitlab_view.issue_number,
|
||||
gitlab_view.discussion_id,
|
||||
message.message,
|
||||
project_id=str(gitlab_view.project_id),
|
||||
merge_request_iid=str(gitlab_view.issue_number),
|
||||
discussion_id=gitlab_view.discussion_id,
|
||||
body=message,
|
||||
)
|
||||
|
||||
elif isinstance(gitlab_view, GitlabIssueComment):
|
||||
await gitlab_service.reply_to_issue(
|
||||
gitlab_view.project_id,
|
||||
gitlab_view.issue_number,
|
||||
gitlab_view.discussion_id,
|
||||
outgoing_message,
|
||||
project_id=str(gitlab_view.project_id),
|
||||
issue_number=str(gitlab_view.issue_number),
|
||||
discussion_id=gitlab_view.discussion_id,
|
||||
body=message,
|
||||
)
|
||||
elif isinstance(gitlab_view, GitlabIssue):
|
||||
await gitlab_service.reply_to_issue(
|
||||
gitlab_view.project_id,
|
||||
gitlab_view.issue_number,
|
||||
None, # no discussion id, issue is tagged
|
||||
outgoing_message,
|
||||
project_id=str(gitlab_view.project_id),
|
||||
issue_number=str(gitlab_view.issue_number),
|
||||
discussion_id=None, # no discussion id, issue is tagged
|
||||
body=message,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f'[GitLab] Unsupported view type: {type(gitlab_view).__name__}'
|
||||
)
|
||||
|
||||
async def start_job(self, gitlab_view: GitlabViewType):
|
||||
async def start_job(self, gitlab_view: GitlabViewType) -> None:
|
||||
"""
|
||||
Start a job for the GitLab view.
|
||||
|
||||
@@ -214,8 +215,18 @@ class GitlabManager(Manager):
|
||||
)
|
||||
)
|
||||
|
||||
# Initialize conversation and get metadata (following GitHub pattern)
|
||||
convo_metadata = await gitlab_view.initialize_new_conversation()
|
||||
|
||||
saas_user_auth = await get_saas_user_auth(
|
||||
gitlab_view.user_info.keycloak_user_id, self.token_manager
|
||||
)
|
||||
|
||||
await gitlab_view.create_new_conversation(
|
||||
self.jinja_env, secret_store.provider_tokens
|
||||
self.jinja_env,
|
||||
secret_store.provider_tokens,
|
||||
convo_metadata,
|
||||
saas_user_auth,
|
||||
)
|
||||
|
||||
conversation_id = gitlab_view.conversation_id
|
||||
@@ -224,18 +235,19 @@ class GitlabManager(Manager):
|
||||
f'[GitLab] Created conversation {conversation_id} for user {user_info.username}'
|
||||
)
|
||||
|
||||
# Create a GitlabCallbackProcessor for this conversation
|
||||
processor = GitlabCallbackProcessor(
|
||||
gitlab_view=gitlab_view,
|
||||
send_summary_instruction=True,
|
||||
)
|
||||
if not gitlab_view.v1_enabled:
|
||||
# Create a GitlabCallbackProcessor for this conversation
|
||||
processor = GitlabCallbackProcessor(
|
||||
gitlab_view=gitlab_view,
|
||||
send_summary_instruction=True,
|
||||
)
|
||||
|
||||
# Register the callback processor
|
||||
register_callback_processor(conversation_id, processor)
|
||||
# Register the callback processor
|
||||
register_callback_processor(conversation_id, processor)
|
||||
|
||||
logger.info(
|
||||
f'[GitLab] Created callback processor for conversation {conversation_id}'
|
||||
)
|
||||
logger.info(
|
||||
f'[GitLab] Created callback processor for conversation {conversation_id}'
|
||||
)
|
||||
|
||||
conversation_link = CONVERSATION_URL.format(conversation_id)
|
||||
msg_info = f"I'm on it! {user_info.username} can [track my progress at all-hands.dev]({conversation_link})"
|
||||
@@ -262,12 +274,10 @@ class GitlabManager(Manager):
|
||||
msg_info = get_session_expired_message(user_info.username)
|
||||
|
||||
# Send the acknowledgment message
|
||||
msg = self.create_outgoing_message(msg_info)
|
||||
await self.send_message(msg, gitlab_view)
|
||||
await self.send_message(msg_info, gitlab_view)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f'[GitLab] Error starting job: {str(e)}')
|
||||
msg = self.create_outgoing_message(
|
||||
msg='Uh oh! There was an unexpected error starting the job :('
|
||||
await self.send_message(
|
||||
'Uh oh! There was an unexpected error starting the job :(', gitlab_view
|
||||
)
|
||||
await self.send_message(msg, gitlab_view)
|
||||
|
||||
@@ -185,6 +185,30 @@ class SaaSGitLabService(GitLabService):
|
||||
users_personal_projects: List of personal projects owned by the user
|
||||
repositories: List of Repository objects to store
|
||||
"""
|
||||
# If external_auth_id is not set, try to determine it from the Keycloak token
|
||||
if not self.external_auth_id and self.external_auth_token:
|
||||
try:
|
||||
user_info = await self.token_manager.get_user_info(
|
||||
self.external_auth_token.get_secret_value()
|
||||
)
|
||||
keycloak_user_id = user_info.sub
|
||||
self.external_auth_id = keycloak_user_id
|
||||
logger.info(
|
||||
f'Determined external_auth_id from Keycloak token: {self.external_auth_id}'
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
'Cannot store repository data: external_auth_id is not set and could not be determined from token',
|
||||
exc_info=True,
|
||||
)
|
||||
return
|
||||
|
||||
if not self.external_auth_id:
|
||||
logger.warning(
|
||||
'Cannot store repository data: external_auth_id could not be determined'
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
# First, add owned projects and groups to the database
|
||||
await self.add_owned_projects_and_groups_to_db(users_personal_projects)
|
||||
|
||||
269
enterprise/integrations/gitlab/gitlab_v1_callback_processor.py
Normal file
269
enterprise/integrations/gitlab/gitlab_v1_callback_processor.py
Normal file
@@ -0,0 +1,269 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
import httpx
|
||||
from integrations.utils import get_summary_instruction
|
||||
from integrations.v1_utils import handle_callback_error
|
||||
from pydantic import Field
|
||||
|
||||
from openhands.agent_server.models import AskAgentRequest, AskAgentResponse
|
||||
from openhands.app_server.event_callback.event_callback_models import (
|
||||
EventCallback,
|
||||
EventCallbackProcessor,
|
||||
)
|
||||
from openhands.app_server.event_callback.event_callback_result_models import (
|
||||
EventCallbackResult,
|
||||
EventCallbackResultStatus,
|
||||
)
|
||||
from openhands.app_server.event_callback.util import (
|
||||
ensure_conversation_found,
|
||||
ensure_running_sandbox,
|
||||
get_agent_server_url_from_sandbox,
|
||||
)
|
||||
from openhands.sdk import Event
|
||||
from openhands.sdk.event import ConversationStateUpdateEvent
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GitlabV1CallbackProcessor(EventCallbackProcessor):
|
||||
"""Callback processor for GitLab V1 integrations."""
|
||||
|
||||
gitlab_view_data: dict[str, Any] = Field(default_factory=dict)
|
||||
should_request_summary: bool = Field(default=True)
|
||||
inline_mr_comment: bool = Field(default=False)
|
||||
|
||||
async def __call__(
|
||||
self,
|
||||
conversation_id: UUID,
|
||||
callback: EventCallback,
|
||||
event: Event,
|
||||
) -> EventCallbackResult | None:
|
||||
"""Process events for GitLab V1 integration."""
|
||||
# Only handle ConversationStateUpdateEvent
|
||||
if not isinstance(event, ConversationStateUpdateEvent):
|
||||
return None
|
||||
|
||||
# Only act when execution has finished
|
||||
if not (event.key == 'execution_status' and event.value == 'finished'):
|
||||
return None
|
||||
|
||||
_logger.info('[GitLab V1] Callback agent state was %s', event)
|
||||
_logger.info(
|
||||
'[GitLab V1] Should request summary: %s', self.should_request_summary
|
||||
)
|
||||
|
||||
if not self.should_request_summary:
|
||||
return None
|
||||
|
||||
self.should_request_summary = False
|
||||
|
||||
try:
|
||||
_logger.info(f'[GitLab V1] Requesting summary {conversation_id}')
|
||||
summary = await self._request_summary(conversation_id)
|
||||
_logger.info(
|
||||
f'[GitLab V1] Posting summary {conversation_id}',
|
||||
extra={'summary': summary},
|
||||
)
|
||||
await self._post_summary_to_gitlab(summary)
|
||||
|
||||
return EventCallbackResult(
|
||||
status=EventCallbackResultStatus.SUCCESS,
|
||||
event_callback_id=callback.id,
|
||||
event_id=event.id,
|
||||
conversation_id=conversation_id,
|
||||
detail=summary,
|
||||
)
|
||||
except Exception as e:
|
||||
can_post_error = bool(self.gitlab_view_data.get('keycloak_user_id'))
|
||||
await handle_callback_error(
|
||||
error=e,
|
||||
conversation_id=conversation_id,
|
||||
service_name='GitLab',
|
||||
service_logger=_logger,
|
||||
can_post_error=can_post_error,
|
||||
post_error_func=self._post_summary_to_gitlab,
|
||||
)
|
||||
|
||||
return EventCallbackResult(
|
||||
status=EventCallbackResultStatus.ERROR,
|
||||
event_callback_id=callback.id,
|
||||
event_id=event.id,
|
||||
conversation_id=conversation_id,
|
||||
detail=str(e),
|
||||
)
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# GitLab helpers
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
async def _post_summary_to_gitlab(self, summary: str) -> None:
|
||||
"""Post a summary comment to the configured GitLab issue or MR."""
|
||||
# Import here to avoid circular imports
|
||||
from integrations.gitlab.gitlab_service import SaaSGitLabService
|
||||
|
||||
keycloak_user_id = self.gitlab_view_data.get('keycloak_user_id')
|
||||
if not keycloak_user_id:
|
||||
raise RuntimeError('Missing keycloak user ID for GitLab')
|
||||
|
||||
gitlab_service = SaaSGitLabService(external_auth_id=keycloak_user_id)
|
||||
|
||||
project_id = self.gitlab_view_data['project_id']
|
||||
issue_number = self.gitlab_view_data['issue_number']
|
||||
discussion_id = self.gitlab_view_data['discussion_id']
|
||||
is_mr = self.gitlab_view_data.get('is_mr', False)
|
||||
|
||||
if is_mr:
|
||||
await gitlab_service.reply_to_mr(
|
||||
project_id,
|
||||
issue_number,
|
||||
discussion_id,
|
||||
summary,
|
||||
)
|
||||
else:
|
||||
await gitlab_service.reply_to_issue(
|
||||
project_id,
|
||||
issue_number,
|
||||
discussion_id,
|
||||
summary,
|
||||
)
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Agent / sandbox helpers
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
async def _ask_question(
|
||||
self,
|
||||
httpx_client: httpx.AsyncClient,
|
||||
agent_server_url: str,
|
||||
conversation_id: UUID,
|
||||
session_api_key: str,
|
||||
message_content: str,
|
||||
) -> str:
|
||||
"""Send a message to the agent server via the V1 API and return response text."""
|
||||
send_message_request = AskAgentRequest(question=message_content)
|
||||
|
||||
url = (
|
||||
f"{agent_server_url.rstrip('/')}"
|
||||
f"/api/conversations/{conversation_id}/ask_agent"
|
||||
)
|
||||
headers = {'X-Session-API-Key': session_api_key}
|
||||
payload = send_message_request.model_dump()
|
||||
|
||||
try:
|
||||
response = await httpx_client.post(
|
||||
url,
|
||||
json=payload,
|
||||
headers=headers,
|
||||
timeout=30.0,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
agent_response = AskAgentResponse.model_validate(response.json())
|
||||
return agent_response.response
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
error_detail = f'HTTP {e.response.status_code} error'
|
||||
try:
|
||||
error_body = e.response.text
|
||||
if error_body:
|
||||
error_detail += f': {error_body}'
|
||||
except Exception: # noqa: BLE001
|
||||
pass
|
||||
|
||||
_logger.error(
|
||||
'[GitLab V1] HTTP error sending message to %s: %s. '
|
||||
'Request payload: %s. Response headers: %s',
|
||||
url,
|
||||
error_detail,
|
||||
payload,
|
||||
dict(e.response.headers),
|
||||
exc_info=True,
|
||||
)
|
||||
raise Exception(f'Failed to send message to agent server: {error_detail}')
|
||||
|
||||
except httpx.TimeoutException:
|
||||
error_detail = f'Request timeout after 30 seconds to {url}'
|
||||
_logger.error(
|
||||
'[GitLab V1] %s. Request payload: %s',
|
||||
error_detail,
|
||||
payload,
|
||||
exc_info=True,
|
||||
)
|
||||
raise Exception(error_detail)
|
||||
|
||||
except httpx.RequestError as e:
|
||||
error_detail = f'Request error to {url}: {str(e)}'
|
||||
_logger.error(
|
||||
'[GitLab V1] %s. Request payload: %s',
|
||||
error_detail,
|
||||
payload,
|
||||
exc_info=True,
|
||||
)
|
||||
raise Exception(error_detail)
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Summary orchestration
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
async def _request_summary(self, conversation_id: UUID) -> str:
|
||||
"""Ask the agent to produce a summary of its work and return the agent response.
|
||||
|
||||
NOTE: This method now returns a string (the agent server's response text)
|
||||
and raises exceptions on errors. The wrapping into EventCallbackResult
|
||||
is handled by __call__.
|
||||
"""
|
||||
# Import services within the method to avoid circular imports
|
||||
from openhands.app_server.config import (
|
||||
get_app_conversation_info_service,
|
||||
get_httpx_client,
|
||||
get_sandbox_service,
|
||||
)
|
||||
from openhands.app_server.services.injector import InjectorState
|
||||
from openhands.app_server.user.specifiy_user_context import (
|
||||
ADMIN,
|
||||
USER_CONTEXT_ATTR,
|
||||
)
|
||||
|
||||
# Create injector state for dependency injection
|
||||
state = InjectorState()
|
||||
setattr(state, USER_CONTEXT_ATTR, ADMIN)
|
||||
|
||||
async with (
|
||||
get_app_conversation_info_service(state) as app_conversation_info_service,
|
||||
get_sandbox_service(state) as sandbox_service,
|
||||
get_httpx_client(state) as httpx_client,
|
||||
):
|
||||
# 1. Conversation lookup
|
||||
app_conversation_info = ensure_conversation_found(
|
||||
await app_conversation_info_service.get_app_conversation_info(
|
||||
conversation_id
|
||||
),
|
||||
conversation_id,
|
||||
)
|
||||
|
||||
# 2. Sandbox lookup + validation
|
||||
sandbox = ensure_running_sandbox(
|
||||
await sandbox_service.get_sandbox(app_conversation_info.sandbox_id),
|
||||
app_conversation_info.sandbox_id,
|
||||
)
|
||||
|
||||
assert (
|
||||
sandbox.session_api_key is not None
|
||||
), f'No session API key for sandbox: {sandbox.id}'
|
||||
|
||||
# 3. URL + instruction
|
||||
agent_server_url = get_agent_server_url_from_sandbox(sandbox)
|
||||
|
||||
# Prepare message based on agent state
|
||||
message_content = get_summary_instruction()
|
||||
|
||||
# Ask the agent and return the response text
|
||||
return await self._ask_question(
|
||||
httpx_client=httpx_client,
|
||||
agent_server_url=agent_server_url,
|
||||
conversation_id=conversation_id,
|
||||
session_api_key=sandbox.session_api_key,
|
||||
message_content=message_content,
|
||||
)
|
||||
@@ -1,25 +1,53 @@
|
||||
from dataclasses import dataclass
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from integrations.models import Message
|
||||
from integrations.resolver_context import ResolverUserContext
|
||||
from integrations.types import ResolverViewInterface, UserData
|
||||
from integrations.utils import HOST, get_oh_labels, has_exact_mention
|
||||
from integrations.utils import (
|
||||
ENABLE_V1_GITLAB_RESOLVER,
|
||||
HOST,
|
||||
get_oh_labels,
|
||||
get_user_v1_enabled_setting,
|
||||
has_exact_mention,
|
||||
)
|
||||
from jinja2 import Environment
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.config import get_config
|
||||
from storage.database import session_maker
|
||||
from storage.saas_secrets_store import SaasSecretsStore
|
||||
|
||||
from openhands.agent_server.models import SendMessageRequest
|
||||
from openhands.app_server.app_conversation.app_conversation_models import (
|
||||
AppConversationStartRequest,
|
||||
AppConversationStartTaskStatus,
|
||||
)
|
||||
from openhands.app_server.config import get_app_conversation_service
|
||||
from openhands.app_server.services.injector import InjectorState
|
||||
from openhands.app_server.user.specifiy_user_context import USER_CONTEXT_ATTR
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.gitlab.gitlab_service import GitLabServiceImpl
|
||||
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE, ProviderType
|
||||
from openhands.integrations.service_types import Comment
|
||||
from openhands.server.services.conversation_service import create_new_conversation
|
||||
from openhands.storage.data_models.conversation_metadata import ConversationTrigger
|
||||
from openhands.sdk import TextContent
|
||||
from openhands.server.services.conversation_service import (
|
||||
initialize_conversation,
|
||||
start_conversation,
|
||||
)
|
||||
from openhands.server.user_auth.user_auth import UserAuth
|
||||
from openhands.storage.data_models.conversation_metadata import (
|
||||
ConversationMetadata,
|
||||
ConversationTrigger,
|
||||
)
|
||||
|
||||
OH_LABEL, INLINE_OH_LABEL = get_oh_labels(HOST)
|
||||
CONFIDENTIAL_NOTE = 'confidential_note'
|
||||
NOTE_TYPES = ['note', CONFIDENTIAL_NOTE]
|
||||
|
||||
|
||||
async def is_v1_enabled_for_gitlab_resolver(user_id: str) -> bool:
|
||||
return await get_user_v1_enabled_setting(user_id) and ENABLE_V1_GITLAB_RESOLVER
|
||||
|
||||
|
||||
# =================================================
|
||||
# SECTION: Factory to create appriorate Gitlab view
|
||||
# =================================================
|
||||
@@ -41,6 +69,10 @@ class GitlabIssue(ResolverViewInterface):
|
||||
description: str
|
||||
previous_comments: list[Comment]
|
||||
is_mr: bool
|
||||
v1_enabled: bool
|
||||
|
||||
def _get_branch_name(self) -> str | None:
|
||||
return getattr(self, 'branch_name', None)
|
||||
|
||||
async def _load_resolver_context(self):
|
||||
gitlab_service = GitLabServiceImpl(
|
||||
@@ -78,35 +110,158 @@ class GitlabIssue(ResolverViewInterface):
|
||||
return user_instructions, conversation_instructions
|
||||
|
||||
async def _get_user_secrets(self):
|
||||
secrets_store = SaasSecretsStore(
|
||||
self.user_info.keycloak_user_id, session_maker, get_config()
|
||||
)
|
||||
secrets_store = SaasSecretsStore(self.user_info.keycloak_user_id, get_config())
|
||||
user_secrets = await secrets_store.load()
|
||||
|
||||
return user_secrets.custom_secrets if user_secrets else None
|
||||
|
||||
async def initialize_new_conversation(self) -> ConversationMetadata:
|
||||
# v1_enabled is already set at construction time in the factory method
|
||||
# This is the source of truth for the conversation type
|
||||
if self.v1_enabled:
|
||||
# Create dummy conversation metadata
|
||||
# Don't save to conversation store
|
||||
# V1 conversations are stored in a separate table
|
||||
self.conversation_id = uuid4().hex
|
||||
return ConversationMetadata(
|
||||
conversation_id=self.conversation_id,
|
||||
selected_repository=self.full_repo_name,
|
||||
)
|
||||
|
||||
conversation_metadata: ConversationMetadata = await initialize_conversation( # type: ignore[assignment]
|
||||
user_id=self.user_info.keycloak_user_id,
|
||||
conversation_id=None,
|
||||
selected_repository=self.full_repo_name,
|
||||
selected_branch=self._get_branch_name(),
|
||||
conversation_trigger=ConversationTrigger.RESOLVER,
|
||||
git_provider=ProviderType.GITLAB,
|
||||
)
|
||||
|
||||
self.conversation_id = conversation_metadata.conversation_id
|
||||
return conversation_metadata
|
||||
|
||||
async def create_new_conversation(
|
||||
self, jinja_env: Environment, git_provider_tokens: PROVIDER_TOKEN_TYPE
|
||||
self,
|
||||
jinja_env: Environment,
|
||||
git_provider_tokens: PROVIDER_TOKEN_TYPE,
|
||||
conversation_metadata: ConversationMetadata,
|
||||
saas_user_auth: UserAuth,
|
||||
):
|
||||
# v1_enabled is already set at construction time in the factory method
|
||||
if self.v1_enabled:
|
||||
# Use V1 app conversation service
|
||||
await self._create_v1_conversation(
|
||||
jinja_env, saas_user_auth, conversation_metadata
|
||||
)
|
||||
else:
|
||||
await self._create_v0_conversation(
|
||||
jinja_env, git_provider_tokens, conversation_metadata
|
||||
)
|
||||
|
||||
async def _create_v0_conversation(
|
||||
self,
|
||||
jinja_env: Environment,
|
||||
git_provider_tokens: PROVIDER_TOKEN_TYPE,
|
||||
conversation_metadata: ConversationMetadata,
|
||||
):
|
||||
"""Create conversation using the legacy V0 system."""
|
||||
logger.info('[GitLab]: Creating V0 conversation')
|
||||
custom_secrets = await self._get_user_secrets()
|
||||
|
||||
user_instructions, conversation_instructions = await self._get_instructions(
|
||||
jinja_env
|
||||
)
|
||||
agent_loop_info = await create_new_conversation(
|
||||
|
||||
await start_conversation(
|
||||
user_id=self.user_info.keycloak_user_id,
|
||||
git_provider_tokens=git_provider_tokens,
|
||||
custom_secrets=custom_secrets,
|
||||
selected_repository=self.full_repo_name,
|
||||
selected_branch=None,
|
||||
initial_user_msg=user_instructions,
|
||||
conversation_instructions=conversation_instructions,
|
||||
image_urls=None,
|
||||
conversation_trigger=ConversationTrigger.RESOLVER,
|
||||
replay_json=None,
|
||||
conversation_id=conversation_metadata.conversation_id,
|
||||
conversation_metadata=conversation_metadata,
|
||||
conversation_instructions=conversation_instructions,
|
||||
)
|
||||
|
||||
async def _create_v1_conversation(
|
||||
self,
|
||||
jinja_env: Environment,
|
||||
saas_user_auth: UserAuth,
|
||||
conversation_metadata: ConversationMetadata,
|
||||
):
|
||||
"""Create conversation using the new V1 app conversation system."""
|
||||
logger.info('[GitLab V1]: Creating V1 conversation')
|
||||
|
||||
user_instructions, conversation_instructions = await self._get_instructions(
|
||||
jinja_env
|
||||
)
|
||||
|
||||
# Create the initial message request
|
||||
initial_message = SendMessageRequest(
|
||||
role='user', content=[TextContent(text=user_instructions)]
|
||||
)
|
||||
|
||||
# Create the GitLab V1 callback processor
|
||||
gitlab_callback_processor = self._create_gitlab_v1_callback_processor()
|
||||
|
||||
# Get the app conversation service and start the conversation
|
||||
injector_state = InjectorState()
|
||||
|
||||
# Determine the title based on whether it's an MR or issue
|
||||
title_prefix = 'GitLab MR' if self.is_mr else 'GitLab Issue'
|
||||
title = f'{title_prefix} #{self.issue_number}: {self.title}'
|
||||
|
||||
# Create the V1 conversation start request with the callback processor
|
||||
start_request = AppConversationStartRequest(
|
||||
conversation_id=UUID(conversation_metadata.conversation_id),
|
||||
system_message_suffix=conversation_instructions,
|
||||
initial_message=initial_message,
|
||||
selected_repository=self.full_repo_name,
|
||||
selected_branch=self._get_branch_name(),
|
||||
git_provider=ProviderType.GITLAB,
|
||||
title=title,
|
||||
trigger=ConversationTrigger.RESOLVER,
|
||||
processors=[
|
||||
gitlab_callback_processor
|
||||
], # Pass the callback processor directly
|
||||
)
|
||||
|
||||
# Set up the GitLab user context for the V1 system
|
||||
gitlab_user_context = ResolverUserContext(saas_user_auth=saas_user_auth)
|
||||
setattr(injector_state, USER_CONTEXT_ATTR, gitlab_user_context)
|
||||
|
||||
async with get_app_conversation_service(
|
||||
injector_state
|
||||
) as app_conversation_service:
|
||||
async for task in app_conversation_service.start_app_conversation(
|
||||
start_request
|
||||
):
|
||||
if task.status == AppConversationStartTaskStatus.ERROR:
|
||||
logger.error(f'Failed to start V1 conversation: {task.detail}')
|
||||
raise RuntimeError(
|
||||
f'Failed to start V1 conversation: {task.detail}'
|
||||
)
|
||||
|
||||
def _create_gitlab_v1_callback_processor(self):
|
||||
"""Create a V1 callback processor for GitLab integration."""
|
||||
from integrations.gitlab.gitlab_v1_callback_processor import (
|
||||
GitlabV1CallbackProcessor,
|
||||
)
|
||||
|
||||
# Create and return the GitLab V1 callback processor
|
||||
return GitlabV1CallbackProcessor(
|
||||
gitlab_view_data={
|
||||
'issue_number': self.issue_number,
|
||||
'project_id': self.project_id,
|
||||
'full_repo_name': self.full_repo_name,
|
||||
'installation_id': self.installation_id,
|
||||
'keycloak_user_id': self.user_info.keycloak_user_id,
|
||||
'is_mr': self.is_mr,
|
||||
'discussion_id': getattr(self, 'discussion_id', None),
|
||||
},
|
||||
send_summary_instruction=self.send_summary_instruction,
|
||||
)
|
||||
self.conversation_id = agent_loop_info.conversation_id
|
||||
return self.conversation_id
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -141,6 +296,9 @@ class GitlabIssueComment(GitlabIssue):
|
||||
class GitlabMRComment(GitlabIssueComment):
|
||||
branch_name: str
|
||||
|
||||
def _get_branch_name(self) -> str | None:
|
||||
return self.branch_name
|
||||
|
||||
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
user_instructions_template = jinja_env.get_template('mr_update_prompt.j2')
|
||||
await self._load_resolver_context()
|
||||
@@ -162,29 +320,6 @@ class GitlabMRComment(GitlabIssueComment):
|
||||
|
||||
return user_instructions, conversation_instructions
|
||||
|
||||
async def create_new_conversation(
|
||||
self, jinja_env: Environment, git_provider_tokens: PROVIDER_TOKEN_TYPE
|
||||
):
|
||||
custom_secrets = await self._get_user_secrets()
|
||||
|
||||
user_instructions, conversation_instructions = await self._get_instructions(
|
||||
jinja_env
|
||||
)
|
||||
agent_loop_info = await create_new_conversation(
|
||||
user_id=self.user_info.keycloak_user_id,
|
||||
git_provider_tokens=git_provider_tokens,
|
||||
custom_secrets=custom_secrets,
|
||||
selected_repository=self.full_repo_name,
|
||||
selected_branch=self.branch_name,
|
||||
initial_user_msg=user_instructions,
|
||||
conversation_instructions=conversation_instructions,
|
||||
image_urls=None,
|
||||
conversation_trigger=ConversationTrigger.RESOLVER,
|
||||
replay_json=None,
|
||||
)
|
||||
self.conversation_id = agent_loop_info.conversation_id
|
||||
return self.conversation_id
|
||||
|
||||
|
||||
@dataclass
|
||||
class GitlabInlineMRComment(GitlabMRComment):
|
||||
@@ -306,7 +441,7 @@ class GitlabFactory:
|
||||
@staticmethod
|
||||
async def create_gitlab_view_from_payload(
|
||||
message: Message, token_manager: TokenManager
|
||||
) -> ResolverViewInterface:
|
||||
) -> GitlabViewType:
|
||||
payload = message.message['payload']
|
||||
installation_id = message.message['installation_id']
|
||||
user = payload['user']
|
||||
@@ -325,6 +460,16 @@ class GitlabFactory:
|
||||
user_id=user_id, username=username, keycloak_user_id=keycloak_user_id
|
||||
)
|
||||
|
||||
# Check v1_enabled at construction time - this is the source of truth
|
||||
v1_enabled = (
|
||||
await is_v1_enabled_for_gitlab_resolver(keycloak_user_id)
|
||||
if keycloak_user_id
|
||||
else False
|
||||
)
|
||||
logger.info(
|
||||
f'[GitLab V1]: User flag found for {keycloak_user_id} is {v1_enabled}'
|
||||
)
|
||||
|
||||
if GitlabFactory.is_labeled_issue(message):
|
||||
issue_iid = payload['object_attributes']['iid']
|
||||
|
||||
@@ -346,6 +491,7 @@ class GitlabFactory:
|
||||
description='',
|
||||
previous_comments=[],
|
||||
is_mr=False,
|
||||
v1_enabled=v1_enabled,
|
||||
)
|
||||
|
||||
elif GitlabFactory.is_issue_comment(message):
|
||||
@@ -376,6 +522,7 @@ class GitlabFactory:
|
||||
description='',
|
||||
previous_comments=[],
|
||||
is_mr=False,
|
||||
v1_enabled=v1_enabled,
|
||||
)
|
||||
|
||||
elif GitlabFactory.is_mr_comment(message):
|
||||
@@ -408,6 +555,7 @@ class GitlabFactory:
|
||||
description='',
|
||||
previous_comments=[],
|
||||
is_mr=True,
|
||||
v1_enabled=v1_enabled,
|
||||
)
|
||||
|
||||
elif GitlabFactory.is_mr_comment(message, inline=True):
|
||||
@@ -448,4 +596,7 @@ class GitlabFactory:
|
||||
description='',
|
||||
previous_comments=[],
|
||||
is_mr=True,
|
||||
v1_enabled=v1_enabled,
|
||||
)
|
||||
|
||||
raise ValueError(f'Unhandled GitLab webhook event: {message}')
|
||||
|
||||
@@ -4,7 +4,9 @@ This module contains reusable functions and classes for installing GitLab webhoo
|
||||
that can be used by both the cron job and API routes.
|
||||
"""
|
||||
|
||||
from typing import cast
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import uuid4
|
||||
|
||||
from integrations.types import GitLabResourceType
|
||||
@@ -13,7 +15,9 @@ from storage.gitlab_webhook import GitlabWebhook, WebhookStatus
|
||||
from storage.gitlab_webhook_store import GitlabWebhookStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.service_types import GitService
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from integrations.gitlab.gitlab_service import SaaSGitLabService
|
||||
|
||||
# Webhook configuration constants
|
||||
WEBHOOK_NAME = 'OpenHands Resolver'
|
||||
@@ -35,7 +39,7 @@ class BreakLoopException(Exception):
|
||||
|
||||
|
||||
async def verify_webhook_conditions(
|
||||
gitlab_service: type[GitService],
|
||||
gitlab_service: SaaSGitLabService,
|
||||
resource_type: GitLabResourceType,
|
||||
resource_id: str,
|
||||
webhook_store: GitlabWebhookStore,
|
||||
@@ -52,10 +56,6 @@ async def verify_webhook_conditions(
|
||||
webhook_store: Webhook store instance
|
||||
webhook: Webhook object to verify
|
||||
"""
|
||||
from integrations.gitlab.gitlab_service import SaaSGitLabService
|
||||
|
||||
gitlab_service = cast(type[SaaSGitLabService], gitlab_service)
|
||||
|
||||
# Check if resource exists
|
||||
does_resource_exist, status = await gitlab_service.check_resource_exists(
|
||||
resource_type, resource_id
|
||||
@@ -106,7 +106,9 @@ async def verify_webhook_conditions(
|
||||
does_webhook_exist_on_resource,
|
||||
status,
|
||||
) = await gitlab_service.check_webhook_exists_on_resource(
|
||||
resource_type, resource_id, GITLAB_WEBHOOK_URL
|
||||
resource_type=resource_type,
|
||||
resource_id=resource_id,
|
||||
webhook_url=GITLAB_WEBHOOK_URL,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
@@ -131,7 +133,7 @@ async def verify_webhook_conditions(
|
||||
|
||||
|
||||
async def install_webhook_on_resource(
|
||||
gitlab_service: type[GitService],
|
||||
gitlab_service: SaaSGitLabService,
|
||||
resource_type: GitLabResourceType,
|
||||
resource_id: str,
|
||||
webhook_store: GitlabWebhookStore,
|
||||
@@ -150,10 +152,6 @@ async def install_webhook_on_resource(
|
||||
Returns:
|
||||
Tuple of (webhook_id, status)
|
||||
"""
|
||||
from integrations.gitlab.gitlab_service import SaaSGitLabService
|
||||
|
||||
gitlab_service = cast(type[SaaSGitLabService], gitlab_service)
|
||||
|
||||
webhook_secret = f'{webhook.user_id}-{str(uuid4())}'
|
||||
webhook_uuid = f'{str(uuid4())}'
|
||||
|
||||
@@ -167,17 +165,15 @@ async def install_webhook_on_resource(
|
||||
scopes=SCOPES,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
'Creating new webhook',
|
||||
extra={
|
||||
'webhook_id': webhook_id,
|
||||
'status': status,
|
||||
'resource_id': resource_id,
|
||||
'resource_type': resource_type,
|
||||
},
|
||||
)
|
||||
log_extra = {
|
||||
'webhook_id': webhook_id,
|
||||
'status': status,
|
||||
'resource_id': resource_id,
|
||||
'resource_type': resource_type,
|
||||
}
|
||||
|
||||
if status == WebhookStatus.RATE_LIMITED:
|
||||
logger.warning('Rate limited while creating webhook', extra=log_extra)
|
||||
raise BreakLoopException()
|
||||
|
||||
if webhook_id:
|
||||
@@ -191,9 +187,8 @@ async def install_webhook_on_resource(
|
||||
'webhook_uuid': webhook_uuid, # required to identify which webhook installation is sending payload
|
||||
},
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f'Installed webhook for {webhook.user_id} on {resource_type}:{resource_id}'
|
||||
)
|
||||
logger.info('Created new webhook', extra=log_extra)
|
||||
else:
|
||||
logger.error('Failed to create webhook', extra=log_extra)
|
||||
|
||||
return webhook_id, status
|
||||
|
||||
@@ -57,7 +57,7 @@ JIRA_CLOUD_API_URL = 'https://api.atlassian.com/ex/jira'
|
||||
OH_LABEL, INLINE_OH_LABEL = get_oh_labels(HOST)
|
||||
|
||||
|
||||
class JiraManager(Manager):
|
||||
class JiraManager(Manager[JiraViewInterface]):
|
||||
"""Manager for processing Jira webhook events.
|
||||
|
||||
This class orchestrates the flow from webhook receipt to conversation creation,
|
||||
@@ -257,7 +257,7 @@ class JiraManager(Manager):
|
||||
|
||||
return jira_user, saas_user_auth
|
||||
|
||||
async def start_job(self, view: JiraViewInterface):
|
||||
async def start_job(self, view: JiraViewInterface) -> None:
|
||||
"""Start a Jira job/conversation."""
|
||||
# Import here to prevent circular import
|
||||
from server.conversation_callback_processor.jira_callback_processor import (
|
||||
@@ -341,17 +341,25 @@ class JiraManager(Manager):
|
||||
|
||||
async def send_message(
|
||||
self,
|
||||
message: Message,
|
||||
message: str,
|
||||
issue_key: str,
|
||||
jira_cloud_id: str,
|
||||
svc_acc_email: str,
|
||||
svc_acc_api_key: str,
|
||||
):
|
||||
"""Send a comment to a Jira issue."""
|
||||
"""Send a comment to a Jira issue.
|
||||
|
||||
Args:
|
||||
message: The message content to send (plain text string)
|
||||
issue_key: The Jira issue key (e.g., 'PROJ-123')
|
||||
jira_cloud_id: The Jira Cloud ID
|
||||
svc_acc_email: Service account email for authentication
|
||||
svc_acc_api_key: Service account API key for authentication
|
||||
"""
|
||||
url = (
|
||||
f'{JIRA_CLOUD_API_URL}/{jira_cloud_id}/rest/api/2/issue/{issue_key}/comment'
|
||||
)
|
||||
data = {'body': message.message}
|
||||
data = {'body': message}
|
||||
async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
|
||||
response = await client.post(
|
||||
url, auth=(svc_acc_email, svc_acc_api_key), json=data
|
||||
@@ -366,7 +374,7 @@ class JiraManager(Manager):
|
||||
view.jira_workspace.svc_acc_api_key
|
||||
)
|
||||
await self.send_message(
|
||||
self.create_outgoing_message(msg=msg),
|
||||
msg,
|
||||
issue_key=view.payload.issue_key,
|
||||
jira_cloud_id=view.jira_workspace.jira_cloud_id,
|
||||
svc_acc_email=view.jira_workspace.svc_acc_email,
|
||||
@@ -388,7 +396,7 @@ class JiraManager(Manager):
|
||||
try:
|
||||
api_key = self.token_manager.decrypt_text(workspace.svc_acc_api_key)
|
||||
await self.send_message(
|
||||
self.create_outgoing_message(msg=error_msg),
|
||||
error_msg,
|
||||
issue_key=payload.issue_key,
|
||||
jira_cloud_id=workspace.jira_cloud_id,
|
||||
svc_acc_email=workspace.svc_acc_email,
|
||||
|
||||
@@ -212,8 +212,6 @@ class JiraPayloadParser:
|
||||
missing.append('issue.id')
|
||||
if not issue_key:
|
||||
missing.append('issue.key')
|
||||
if not user_email:
|
||||
missing.append('user.emailAddress')
|
||||
if not display_name:
|
||||
missing.append('user.displayName')
|
||||
if not account_id:
|
||||
|
||||
@@ -42,7 +42,7 @@ from openhands.server.user_auth.user_auth import UserAuth
|
||||
from openhands.utils.http_session import httpx_verify_option
|
||||
|
||||
|
||||
class JiraDcManager(Manager):
|
||||
class JiraDcManager(Manager[JiraDcViewInterface]):
|
||||
def __init__(self, token_manager: TokenManager):
|
||||
self.token_manager = token_manager
|
||||
self.integration_store = JiraDcIntegrationStore.get_instance()
|
||||
@@ -353,7 +353,7 @@ class JiraDcManager(Manager):
|
||||
logger.error(f'[Jira DC] Error in is_job_requested: {str(e)}')
|
||||
return False
|
||||
|
||||
async def start_job(self, jira_dc_view: JiraDcViewInterface):
|
||||
async def start_job(self, jira_dc_view: JiraDcViewInterface) -> None:
|
||||
"""Start a Jira DC job/conversation."""
|
||||
# Import here to prevent circular import
|
||||
from server.conversation_callback_processor.jira_dc_callback_processor import (
|
||||
@@ -418,7 +418,7 @@ class JiraDcManager(Manager):
|
||||
jira_dc_view.jira_dc_workspace.svc_acc_api_key
|
||||
)
|
||||
await self.send_message(
|
||||
self.create_outgoing_message(msg=msg_info),
|
||||
msg_info,
|
||||
issue_key=jira_dc_view.job_context.issue_key,
|
||||
base_api_url=jira_dc_view.job_context.base_api_url,
|
||||
svc_acc_api_key=api_key,
|
||||
@@ -456,12 +456,19 @@ class JiraDcManager(Manager):
|
||||
return title, description
|
||||
|
||||
async def send_message(
|
||||
self, message: Message, issue_key: str, base_api_url: str, svc_acc_api_key: str
|
||||
self, message: str, issue_key: str, base_api_url: str, svc_acc_api_key: str
|
||||
):
|
||||
"""Send message/comment to Jira DC issue."""
|
||||
"""Send message/comment to Jira DC issue.
|
||||
|
||||
Args:
|
||||
message: The message content to send (plain text string)
|
||||
issue_key: The Jira issue key (e.g., 'PROJ-123')
|
||||
base_api_url: The base API URL for the Jira DC instance
|
||||
svc_acc_api_key: Service account API key for authentication
|
||||
"""
|
||||
url = f'{base_api_url}/rest/api/2/issue/{issue_key}/comment'
|
||||
headers = {'Authorization': f'Bearer {svc_acc_api_key}'}
|
||||
data = {'body': message.message}
|
||||
data = {'body': message}
|
||||
async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
|
||||
response = await client.post(url, headers=headers, json=data)
|
||||
response.raise_for_status()
|
||||
@@ -481,7 +488,7 @@ class JiraDcManager(Manager):
|
||||
try:
|
||||
api_key = self.token_manager.decrypt_text(workspace.svc_acc_api_key)
|
||||
await self.send_message(
|
||||
self.create_outgoing_message(msg=error_msg),
|
||||
error_msg,
|
||||
issue_key=job_context.issue_key,
|
||||
base_api_url=job_context.base_api_url,
|
||||
svc_acc_api_key=api_key,
|
||||
@@ -502,7 +509,7 @@ class JiraDcManager(Manager):
|
||||
)
|
||||
|
||||
await self.send_message(
|
||||
self.create_outgoing_message(msg=comment_msg),
|
||||
comment_msg,
|
||||
issue_key=jira_dc_view.job_context.issue_key,
|
||||
base_api_url=jira_dc_view.job_context.base_api_url,
|
||||
svc_acc_api_key=api_key,
|
||||
|
||||
@@ -19,7 +19,7 @@ class JiraDcViewInterface(ABC):
|
||||
conversation_id: str
|
||||
|
||||
@abstractmethod
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
"""Get initial instructions for the conversation."""
|
||||
pass
|
||||
|
||||
|
||||
@@ -36,7 +36,7 @@ class JiraDcNewConversationView(JiraDcViewInterface):
|
||||
selected_repo: str | None
|
||||
conversation_id: str
|
||||
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
"""Instructions passed when conversation is first initialized"""
|
||||
|
||||
instructions_template = jinja_env.get_template('jira_dc_instructions.j2')
|
||||
@@ -61,7 +61,7 @@ class JiraDcNewConversationView(JiraDcViewInterface):
|
||||
|
||||
provider_tokens = await self.saas_user_auth.get_provider_tokens()
|
||||
user_secrets = await self.saas_user_auth.get_secrets()
|
||||
instructions, user_msg = self._get_instructions(jinja_env)
|
||||
instructions, user_msg = await self._get_instructions(jinja_env)
|
||||
|
||||
try:
|
||||
agent_loop_info = await create_new_conversation(
|
||||
@@ -113,7 +113,7 @@ class JiraDcExistingConversationView(JiraDcViewInterface):
|
||||
selected_repo: str | None
|
||||
conversation_id: str
|
||||
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
"""Instructions passed when conversation is first initialized"""
|
||||
|
||||
user_msg_template = jinja_env.get_template('jira_dc_existing_conversation.j2')
|
||||
@@ -155,6 +155,9 @@ class JiraDcExistingConversationView(JiraDcViewInterface):
|
||||
self.conversation_id, conversation_init_data, user_id
|
||||
)
|
||||
|
||||
if agent_loop_info.event_store is None:
|
||||
raise StartingConvoException('Event store not available')
|
||||
|
||||
final_agent_observation = get_final_agent_observation(
|
||||
agent_loop_info.event_store
|
||||
)
|
||||
@@ -167,7 +170,7 @@ class JiraDcExistingConversationView(JiraDcViewInterface):
|
||||
if not agent_state or agent_state == AgentState.LOADING:
|
||||
raise StartingConvoException('Conversation is still starting')
|
||||
|
||||
_, user_msg = self._get_instructions(jinja_env)
|
||||
_, user_msg = await self._get_instructions(jinja_env)
|
||||
user_message_event = MessageAction(content=user_msg)
|
||||
await conversation_manager.send_event_to_conversation(
|
||||
self.conversation_id, event_to_dict(user_message_event)
|
||||
|
||||
@@ -39,7 +39,7 @@ from openhands.server.user_auth.user_auth import UserAuth
|
||||
from openhands.utils.http_session import httpx_verify_option
|
||||
|
||||
|
||||
class LinearManager(Manager):
|
||||
class LinearManager(Manager[LinearViewInterface]):
|
||||
def __init__(self, token_manager: TokenManager):
|
||||
self.token_manager = token_manager
|
||||
self.integration_store = LinearIntegrationStore.get_instance()
|
||||
@@ -343,7 +343,7 @@ class LinearManager(Manager):
|
||||
logger.error(f'[Linear] Error in is_job_requested: {str(e)}')
|
||||
return False
|
||||
|
||||
async def start_job(self, linear_view: LinearViewInterface):
|
||||
async def start_job(self, linear_view: LinearViewInterface) -> None:
|
||||
"""Start a Linear job/conversation."""
|
||||
# Import here to prevent circular import
|
||||
from server.conversation_callback_processor.linear_callback_processor import (
|
||||
@@ -408,7 +408,7 @@ class LinearManager(Manager):
|
||||
linear_view.linear_workspace.svc_acc_api_key
|
||||
)
|
||||
await self.send_message(
|
||||
self.create_outgoing_message(msg=msg_info),
|
||||
msg_info,
|
||||
linear_view.job_context.issue_id,
|
||||
api_key,
|
||||
)
|
||||
@@ -473,8 +473,14 @@ class LinearManager(Manager):
|
||||
|
||||
return title, description
|
||||
|
||||
async def send_message(self, message: Message, issue_id: str, api_key: str):
|
||||
"""Send message/comment to Linear issue."""
|
||||
async def send_message(self, message: str, issue_id: str, api_key: str):
|
||||
"""Send message/comment to Linear issue.
|
||||
|
||||
Args:
|
||||
message: The message content to send (plain text string)
|
||||
issue_id: The Linear issue ID to comment on
|
||||
api_key: The Linear API key for authentication
|
||||
"""
|
||||
query = """
|
||||
mutation CommentCreate($input: CommentCreateInput!) {
|
||||
commentCreate(input: $input) {
|
||||
@@ -485,7 +491,7 @@ class LinearManager(Manager):
|
||||
}
|
||||
}
|
||||
"""
|
||||
variables = {'input': {'issueId': issue_id, 'body': message.message}}
|
||||
variables = {'input': {'issueId': issue_id, 'body': message}}
|
||||
return await self._query_api(query, variables, api_key)
|
||||
|
||||
async def _send_error_comment(
|
||||
@@ -498,9 +504,7 @@ class LinearManager(Manager):
|
||||
|
||||
try:
|
||||
api_key = self.token_manager.decrypt_text(workspace.svc_acc_api_key)
|
||||
await self.send_message(
|
||||
self.create_outgoing_message(msg=error_msg), issue_id, api_key
|
||||
)
|
||||
await self.send_message(error_msg, issue_id, api_key)
|
||||
except Exception as e:
|
||||
logger.error(f'[Linear] Failed to send error comment: {str(e)}')
|
||||
|
||||
@@ -517,7 +521,7 @@ class LinearManager(Manager):
|
||||
)
|
||||
|
||||
await self.send_message(
|
||||
self.create_outgoing_message(msg=comment_msg),
|
||||
comment_msg,
|
||||
linear_view.job_context.issue_id,
|
||||
api_key,
|
||||
)
|
||||
|
||||
@@ -19,7 +19,7 @@ class LinearViewInterface(ABC):
|
||||
conversation_id: str
|
||||
|
||||
@abstractmethod
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
"""Get initial instructions for the conversation."""
|
||||
pass
|
||||
|
||||
|
||||
@@ -33,7 +33,7 @@ class LinearNewConversationView(LinearViewInterface):
|
||||
selected_repo: str | None
|
||||
conversation_id: str
|
||||
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
"""Instructions passed when conversation is first initialized"""
|
||||
|
||||
instructions_template = jinja_env.get_template('linear_instructions.j2')
|
||||
@@ -58,7 +58,7 @@ class LinearNewConversationView(LinearViewInterface):
|
||||
|
||||
provider_tokens = await self.saas_user_auth.get_provider_tokens()
|
||||
user_secrets = await self.saas_user_auth.get_secrets()
|
||||
instructions, user_msg = self._get_instructions(jinja_env)
|
||||
instructions, user_msg = await self._get_instructions(jinja_env)
|
||||
|
||||
try:
|
||||
agent_loop_info = await create_new_conversation(
|
||||
@@ -110,7 +110,7 @@ class LinearExistingConversationView(LinearViewInterface):
|
||||
selected_repo: str | None
|
||||
conversation_id: str
|
||||
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
"""Instructions passed when conversation is first initialized"""
|
||||
|
||||
user_msg_template = jinja_env.get_template('linear_existing_conversation.j2')
|
||||
@@ -152,6 +152,9 @@ class LinearExistingConversationView(LinearViewInterface):
|
||||
self.conversation_id, conversation_init_data, user_id
|
||||
)
|
||||
|
||||
if agent_loop_info.event_store is None:
|
||||
raise StartingConvoException('Event store not available')
|
||||
|
||||
final_agent_observation = get_final_agent_observation(
|
||||
agent_loop_info.event_store
|
||||
)
|
||||
@@ -164,7 +167,7 @@ class LinearExistingConversationView(LinearViewInterface):
|
||||
if not agent_state or agent_state == AgentState.LOADING:
|
||||
raise StartingConvoException('Conversation is still starting')
|
||||
|
||||
_, user_msg = self._get_instructions(jinja_env)
|
||||
_, user_msg = await self._get_instructions(jinja_env)
|
||||
user_message_event = MessageAction(content=user_msg)
|
||||
await conversation_manager.send_event_to_conversation(
|
||||
self.conversation_id, event_to_dict(user_message_event)
|
||||
|
||||
@@ -1,9 +1,13 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
from integrations.models import Message, SourceType
|
||||
|
||||
# TypeVar for view types - each manager subclass specifies its own view type
|
||||
ViewT = TypeVar('ViewT')
|
||||
|
||||
class Manager(ABC):
|
||||
|
||||
class Manager(ABC, Generic[ViewT]):
|
||||
manager_type: SourceType
|
||||
|
||||
@abstractmethod
|
||||
@@ -12,14 +16,21 @@ class Manager(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def send_message(self, message: Message):
|
||||
"Send message to integration from Openhands server"
|
||||
def send_message(self, message: str, *args: Any, **kwargs: Any):
|
||||
"""Send message to integration from OpenHands server.
|
||||
|
||||
Args:
|
||||
message: The message content to send (plain text string).
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def start_job(self):
|
||||
"Kick off a job with openhands agent"
|
||||
raise NotImplementedError
|
||||
async def start_job(self, view: ViewT) -> None:
|
||||
"""Kick off a job with openhands agent.
|
||||
|
||||
def create_outgoing_message(self, msg: str | dict, ephemeral: bool = False):
|
||||
return Message(source=SourceType.OPENHANDS, message=msg, ephemeral=ephemeral)
|
||||
Args:
|
||||
view: Integration-specific view object containing job context.
|
||||
Each manager subclass accepts its own view type
|
||||
(e.g., SlackViewInterface, JiraViewInterface, etc.)
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -16,8 +17,16 @@ class SourceType(str, Enum):
|
||||
|
||||
|
||||
class Message(BaseModel):
|
||||
"""Message model for incoming webhook payloads from integrations.
|
||||
|
||||
Note: This model is intended for INCOMING messages only.
|
||||
For outgoing messages (e.g., sending comments to GitHub/GitLab),
|
||||
pass strings directly to the send_message methods instead of
|
||||
wrapping them in a Message object.
|
||||
"""
|
||||
|
||||
source: SourceType
|
||||
message: str | dict
|
||||
message: dict[str, Any]
|
||||
ephemeral: bool = False
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from openhands.app_server.user.user_context import UserContext
|
||||
from openhands.app_server.user.user_models import UserInfo
|
||||
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE
|
||||
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE, ProviderHandler
|
||||
from openhands.integrations.service_types import ProviderType
|
||||
from openhands.sdk.secret import SecretSource, StaticSecret
|
||||
from openhands.server.user_auth.user_auth import UserAuth
|
||||
@@ -14,6 +14,7 @@ class ResolverUserContext(UserContext):
|
||||
saas_user_auth: UserAuth,
|
||||
):
|
||||
self.saas_user_auth = saas_user_auth
|
||||
self._provider_handler: ProviderHandler | None = None
|
||||
|
||||
async def get_user_id(self) -> str | None:
|
||||
return await self.saas_user_auth.get_user_id()
|
||||
@@ -29,12 +30,26 @@ class ResolverUserContext(UserContext):
|
||||
|
||||
return UserInfo(id=user_id)
|
||||
|
||||
async def _get_provider_handler(self) -> ProviderHandler:
|
||||
"""Get or create a ProviderHandler for git operations."""
|
||||
if self._provider_handler is None:
|
||||
provider_tokens = await self.saas_user_auth.get_provider_tokens()
|
||||
if provider_tokens is None:
|
||||
raise ValueError('No provider tokens available')
|
||||
user_id = await self.saas_user_auth.get_user_id()
|
||||
self._provider_handler = ProviderHandler(
|
||||
provider_tokens=provider_tokens, external_auth_id=user_id
|
||||
)
|
||||
return self._provider_handler
|
||||
|
||||
async def get_authenticated_git_url(
|
||||
self, repository: str, is_optional: bool = False
|
||||
) -> str:
|
||||
# This would need to be implemented based on the git provider tokens
|
||||
# For now, return a basic HTTPS URL
|
||||
return f'https://github.com/{repository}.git'
|
||||
provider_handler = await self._get_provider_handler()
|
||||
url = await provider_handler.get_authenticated_git_url(
|
||||
repository, is_optional=is_optional
|
||||
)
|
||||
return url
|
||||
|
||||
async def get_latest_token(self, provider_type: ProviderType) -> str | None:
|
||||
# Return the appropriate token string from git_provider_tokens
|
||||
|
||||
128
enterprise/integrations/slack/slack_errors.py
Normal file
128
enterprise/integrations/slack/slack_errors.py
Normal file
@@ -0,0 +1,128 @@
|
||||
"""Centralized error handling for Slack integration.
|
||||
|
||||
This module provides:
|
||||
- SlackErrorCode: Unique error codes for traceability
|
||||
- SlackError: Exception class for user-facing errors
|
||||
- get_user_message(): Function to get user-facing messages for error codes
|
||||
"""
|
||||
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from integrations.utils import HOST_URL
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SlackErrorCode(Enum):
|
||||
"""Unique error codes for traceability in logs and user messages."""
|
||||
|
||||
SESSION_EXPIRED = 'SLACK_ERR_001'
|
||||
REDIS_STORE_FAILED = 'SLACK_ERR_002'
|
||||
REDIS_RETRIEVE_FAILED = 'SLACK_ERR_003'
|
||||
USER_NOT_AUTHENTICATED = 'SLACK_ERR_004'
|
||||
PROVIDER_TIMEOUT = 'SLACK_ERR_005'
|
||||
PROVIDER_AUTH_FAILED = 'SLACK_ERR_006'
|
||||
LLM_AUTH_FAILED = 'SLACK_ERR_007'
|
||||
MISSING_SETTINGS = 'SLACK_ERR_008'
|
||||
UNEXPECTED_ERROR = 'SLACK_ERR_999'
|
||||
|
||||
|
||||
class SlackError(Exception):
|
||||
"""Exception for errors that should be communicated to the Slack user.
|
||||
|
||||
This exception is caught by the centralized error handler in SlackManager,
|
||||
which logs the error and sends an appropriate message to the user.
|
||||
|
||||
Usage:
|
||||
raise SlackError(SlackErrorCode.USER_NOT_AUTHENTICATED,
|
||||
message_kwargs={'login_link': link})
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
code: SlackErrorCode,
|
||||
message_kwargs: dict[str, Any] | None = None,
|
||||
log_context: dict[str, Any] | None = None,
|
||||
):
|
||||
"""Initialize a SlackError.
|
||||
|
||||
Args:
|
||||
code: The error code identifying the type of error
|
||||
message_kwargs: Kwargs for formatting the user message
|
||||
(e.g., {'login_link': '...'})
|
||||
log_context: Additional context for structured logging
|
||||
"""
|
||||
self.code = code
|
||||
self.message_kwargs = message_kwargs or {}
|
||||
self.log_context = log_context or {}
|
||||
super().__init__(f'{code.value}: {code.name}')
|
||||
|
||||
def get_user_message(self) -> str:
|
||||
"""Get the user-facing message for this error."""
|
||||
return get_user_message(self.code, **self.message_kwargs)
|
||||
|
||||
|
||||
# Centralized user-facing messages
|
||||
_USER_MESSAGES: dict[SlackErrorCode, str] = {
|
||||
SlackErrorCode.SESSION_EXPIRED: (
|
||||
'⏰ Your session has expired. '
|
||||
'Please mention me again with your request to start a new conversation.'
|
||||
),
|
||||
SlackErrorCode.REDIS_STORE_FAILED: (
|
||||
'⚠️ Something went wrong on our end (ref: {code}). '
|
||||
'Please try again in a few moments.'
|
||||
),
|
||||
SlackErrorCode.REDIS_RETRIEVE_FAILED: (
|
||||
'⚠️ Something went wrong on our end (ref: {code}). '
|
||||
'Please try again in a few moments.'
|
||||
),
|
||||
SlackErrorCode.USER_NOT_AUTHENTICATED: (
|
||||
'🔐 Please link your Slack account to OpenHands: '
|
||||
'[Click here to Login]({login_link})'
|
||||
),
|
||||
SlackErrorCode.PROVIDER_TIMEOUT: (
|
||||
'⏱️ The request timed out while connecting to your git provider. '
|
||||
'Please try again.'
|
||||
),
|
||||
SlackErrorCode.PROVIDER_AUTH_FAILED: (
|
||||
'🔐 Authentication with your git provider failed. '
|
||||
f'Please re-login at [OpenHands Cloud]({HOST_URL}) and try again.'
|
||||
),
|
||||
SlackErrorCode.LLM_AUTH_FAILED: (
|
||||
'@{username} please set a valid LLM API key in '
|
||||
f'[OpenHands Cloud]({HOST_URL}) before starting a job.'
|
||||
),
|
||||
SlackErrorCode.MISSING_SETTINGS: (
|
||||
'{username} please re-login into '
|
||||
f'[OpenHands Cloud]({HOST_URL}) before starting a job.'
|
||||
),
|
||||
SlackErrorCode.UNEXPECTED_ERROR: (
|
||||
'Uh oh! There was an unexpected error (ref: {code}). Please try again later.'
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def get_user_message(error_code: SlackErrorCode, **kwargs) -> str:
|
||||
"""Get a user-facing message for a given error code.
|
||||
|
||||
Args:
|
||||
error_code: The error code to get a message for
|
||||
**kwargs: Additional formatting arguments (e.g., username, login_link)
|
||||
|
||||
Returns:
|
||||
Formatted user-facing message string
|
||||
"""
|
||||
msg = _USER_MESSAGES.get(
|
||||
error_code, _USER_MESSAGES[SlackErrorCode.UNEXPECTED_ERROR]
|
||||
)
|
||||
try:
|
||||
return msg.format(code=error_code.value, **kwargs)
|
||||
except KeyError as e:
|
||||
logger.warning(
|
||||
f'Missing format key {e} in error message',
|
||||
extra={'error_code': error_code.value},
|
||||
)
|
||||
# Return a generic error message with the code for debugging
|
||||
return f'An error occurred (ref: {error_code.value}). Please try again later.'
|
||||
@@ -1,20 +1,25 @@
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
import jwt
|
||||
from integrations.manager import Manager
|
||||
from integrations.models import Message, SourceType
|
||||
from integrations.slack.slack_types import SlackViewInterface, StartingConvoException
|
||||
from integrations.slack.slack_errors import SlackError, SlackErrorCode
|
||||
from integrations.slack.slack_types import (
|
||||
SlackMessageView,
|
||||
SlackViewInterface,
|
||||
StartingConvoException,
|
||||
)
|
||||
from integrations.slack.slack_view import (
|
||||
SlackFactory,
|
||||
SlackNewConversationFromRepoFormView,
|
||||
SlackNewConversationView,
|
||||
SlackUnkownUserView,
|
||||
SlackUpdateExistingConversationView,
|
||||
)
|
||||
from integrations.utils import (
|
||||
HOST_URL,
|
||||
OPENHANDS_RESOLVER_TEMPLATES_DIR,
|
||||
get_session_expired_message,
|
||||
infer_repo_from_message,
|
||||
)
|
||||
from integrations.v1_utils import get_saas_user_auth
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
@@ -22,13 +27,18 @@ from server.constants import SLACK_CLIENT_ID
|
||||
from server.utils.conversation_callback_utils import register_callback_processor
|
||||
from slack_sdk.oauth import AuthorizeUrlGenerator
|
||||
from slack_sdk.web.async_client import AsyncWebClient
|
||||
from storage.database import session_maker
|
||||
from sqlalchemy import select
|
||||
from storage.database import a_session_maker
|
||||
from storage.slack_user import SlackUser
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.provider import ProviderHandler
|
||||
from openhands.integrations.service_types import Repository
|
||||
from openhands.server.shared import config, server_config
|
||||
from openhands.integrations.service_types import (
|
||||
AuthenticationError,
|
||||
ProviderTimeoutError,
|
||||
Repository,
|
||||
)
|
||||
from openhands.server.shared import config, server_config, sio
|
||||
from openhands.server.types import (
|
||||
LLMAuthenticationError,
|
||||
MissingSettingsError,
|
||||
@@ -42,8 +52,14 @@ authorize_url_generator = AuthorizeUrlGenerator(
|
||||
user_scopes=['search:read'],
|
||||
)
|
||||
|
||||
# Key prefix for storing user messages in Redis during repo selection flow
|
||||
SLACK_USER_MSG_KEY_PREFIX = 'slack_user_msg'
|
||||
# Expiration time for stored user messages (5 minutes)
|
||||
# Arbitrary timeout based on typical user attention span; may be tuned based on feedback
|
||||
SLACK_USER_MSG_EXPIRATION = 300
|
||||
|
||||
class SlackManager(Manager):
|
||||
|
||||
class SlackManager(Manager[SlackViewInterface]):
|
||||
def __init__(self, token_manager):
|
||||
self.token_manager = token_manager
|
||||
self.login_link = (
|
||||
@@ -63,12 +79,11 @@ class SlackManager(Manager):
|
||||
) -> tuple[SlackUser | None, UserAuth | None]:
|
||||
# We get the user and correlate them back to a user in OpenHands - if we can
|
||||
slack_user = None
|
||||
with session_maker() as session:
|
||||
slack_user = (
|
||||
session.query(SlackUser)
|
||||
.filter(SlackUser.slack_user_id == slack_user_id)
|
||||
.first()
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(SlackUser).where(SlackUser.slack_user_id == slack_user_id)
|
||||
)
|
||||
slack_user = result.scalar_one_or_none()
|
||||
|
||||
# slack_view.slack_to_openhands_user = slack_user # attach user auth info to view
|
||||
|
||||
@@ -81,18 +96,126 @@ class SlackManager(Manager):
|
||||
|
||||
return slack_user, saas_user_auth
|
||||
|
||||
def _infer_repo_from_message(self, user_msg: str) -> str | None:
|
||||
# Regular expression to match patterns like "OpenHands/OpenHands" or "deploy repo"
|
||||
pattern = r'([a-zA-Z0-9_-]+/[a-zA-Z0-9_-]+)|([a-zA-Z0-9_-]+)(?=\s+repo)'
|
||||
match = re.search(pattern, user_msg)
|
||||
async def _store_user_msg_for_form(
|
||||
self, message_ts: str, thread_ts: str | None, user_msg: str
|
||||
) -> None:
|
||||
"""Store user message in Redis for later retrieval when form is submitted.
|
||||
|
||||
if match:
|
||||
repo = match.group(1) if match.group(1) else match.group(2)
|
||||
return repo
|
||||
This is needed because when a user selects a repo from the external_select
|
||||
dropdown, Slack sends a separate interaction payload that doesn't include
|
||||
the original user message.
|
||||
|
||||
return None
|
||||
Args:
|
||||
message_ts: The message timestamp (unique identifier)
|
||||
thread_ts: The thread timestamp (if in a thread)
|
||||
user_msg: The original user message to store
|
||||
|
||||
async def _get_repositories(self, user_auth: UserAuth) -> list[Repository]:
|
||||
Raises:
|
||||
SlackError: If storage fails (REDIS_STORE_FAILED)
|
||||
"""
|
||||
key = f'{SLACK_USER_MSG_KEY_PREFIX}:{message_ts}:{thread_ts}'
|
||||
try:
|
||||
redis = sio.manager.redis
|
||||
await redis.set(key, user_msg, ex=SLACK_USER_MSG_EXPIRATION)
|
||||
logger.info(
|
||||
'slack_stored_user_msg',
|
||||
extra={
|
||||
'message_ts': message_ts,
|
||||
'thread_ts': thread_ts,
|
||||
'key': key,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'slack_store_user_msg_failed',
|
||||
extra={
|
||||
'message_ts': message_ts,
|
||||
'thread_ts': thread_ts,
|
||||
'key': key,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
raise SlackError(
|
||||
SlackErrorCode.REDIS_STORE_FAILED,
|
||||
log_context={'message_ts': message_ts, 'thread_ts': thread_ts},
|
||||
)
|
||||
|
||||
async def _retrieve_user_msg_for_form(
|
||||
self, message_ts: str, thread_ts: str | None
|
||||
) -> str:
|
||||
"""Retrieve stored user message from Redis.
|
||||
|
||||
Args:
|
||||
message_ts: The message timestamp
|
||||
thread_ts: The thread timestamp (if in a thread)
|
||||
|
||||
Returns:
|
||||
The stored user message
|
||||
|
||||
Raises:
|
||||
SlackError: If retrieval fails (REDIS_RETRIEVE_FAILED) or message
|
||||
not found (SESSION_EXPIRED)
|
||||
"""
|
||||
key = f'{SLACK_USER_MSG_KEY_PREFIX}:{message_ts}:{thread_ts}'
|
||||
try:
|
||||
redis = sio.manager.redis
|
||||
user_msg = await redis.get(key)
|
||||
if user_msg:
|
||||
# Redis returns bytes, decode to string
|
||||
if isinstance(user_msg, bytes):
|
||||
user_msg = user_msg.decode('utf-8')
|
||||
logger.info(
|
||||
'slack_retrieved_user_msg',
|
||||
extra={
|
||||
'message_ts': message_ts,
|
||||
'thread_ts': thread_ts,
|
||||
'key': key,
|
||||
},
|
||||
)
|
||||
return user_msg
|
||||
else:
|
||||
logger.warning(
|
||||
'slack_user_msg_not_found',
|
||||
extra={
|
||||
'message_ts': message_ts,
|
||||
'thread_ts': thread_ts,
|
||||
'key': key,
|
||||
},
|
||||
)
|
||||
raise SlackError(
|
||||
SlackErrorCode.SESSION_EXPIRED,
|
||||
log_context={'message_ts': message_ts, 'thread_ts': thread_ts},
|
||||
)
|
||||
except SlackError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'slack_retrieve_user_msg_failed',
|
||||
extra={
|
||||
'message_ts': message_ts,
|
||||
'thread_ts': thread_ts,
|
||||
'key': key,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
raise SlackError(
|
||||
SlackErrorCode.REDIS_RETRIEVE_FAILED,
|
||||
log_context={'message_ts': message_ts, 'thread_ts': thread_ts},
|
||||
)
|
||||
|
||||
async def _search_repositories(
|
||||
self, user_auth: UserAuth, query: str = '', per_page: int = 100
|
||||
) -> list[Repository]:
|
||||
"""Search repositories for a user with optional query filtering.
|
||||
|
||||
Args:
|
||||
user_auth: The user's authentication context
|
||||
query: Search query to filter repositories (empty string returns all)
|
||||
per_page: Maximum number of results to return
|
||||
|
||||
Returns:
|
||||
List of matching Repository objects
|
||||
"""
|
||||
provider_tokens = await user_auth.get_provider_tokens()
|
||||
if provider_tokens is None:
|
||||
return []
|
||||
@@ -103,31 +226,33 @@ class SlackManager(Manager):
|
||||
external_auth_token=access_token,
|
||||
external_auth_id=user_id,
|
||||
)
|
||||
repos: list[Repository] = await client.get_repositories(
|
||||
'pushed', server_config.app_mode, None, None, None, None
|
||||
repos: list[Repository] = await client.search_repositories(
|
||||
selected_provider=None,
|
||||
query=query,
|
||||
per_page=per_page,
|
||||
sort='pushed',
|
||||
order='desc',
|
||||
app_mode=server_config.app_mode,
|
||||
)
|
||||
return repos
|
||||
|
||||
def _generate_repo_selection_form(
|
||||
self, repo_list: list[Repository], message_ts: str, thread_ts: str | None
|
||||
):
|
||||
options = [
|
||||
{
|
||||
'text': {'type': 'plain_text', 'text': 'No Repository'},
|
||||
'value': '-',
|
||||
}
|
||||
]
|
||||
options.extend(
|
||||
{
|
||||
'text': {
|
||||
'type': 'plain_text',
|
||||
'text': repo.full_name,
|
||||
},
|
||||
'value': repo.full_name,
|
||||
}
|
||||
for repo in repo_list
|
||||
)
|
||||
self, message_ts: str, thread_ts: str | None
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Generate a repo selection form using external_select for dynamic loading.
|
||||
|
||||
This uses Slack's external_select element which allows:
|
||||
- Type-ahead search for repositories
|
||||
- Dynamic loading of options from an external endpoint
|
||||
- Support for users with many repositories (no 100 option limit)
|
||||
|
||||
Args:
|
||||
message_ts: The message timestamp for tracking
|
||||
thread_ts: The thread timestamp if in a thread
|
||||
|
||||
Returns:
|
||||
List of Slack Block Kit blocks for the selection form
|
||||
"""
|
||||
return [
|
||||
{
|
||||
'type': 'header',
|
||||
@@ -137,158 +262,395 @@ class SlackManager(Manager):
|
||||
'emoji': True,
|
||||
},
|
||||
},
|
||||
{
|
||||
'type': 'section',
|
||||
'text': {
|
||||
'type': 'mrkdwn',
|
||||
'text': 'Type to search your repositories:',
|
||||
},
|
||||
},
|
||||
{
|
||||
'type': 'actions',
|
||||
'elements': [
|
||||
{
|
||||
'type': 'static_select',
|
||||
'type': 'external_select',
|
||||
'action_id': f'repository_select:{message_ts}:{thread_ts}',
|
||||
'options': options,
|
||||
'placeholder': {
|
||||
'type': 'plain_text',
|
||||
'text': 'Search repositories...',
|
||||
},
|
||||
'min_query_length': 0, # Load initial options immediately
|
||||
}
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
def filter_potential_repos_by_user_msg(
|
||||
self, user_msg: str, user_repos: list[Repository]
|
||||
) -> tuple[bool, list[Repository]]:
|
||||
inferred_repo = self._infer_repo_from_message(user_msg)
|
||||
if not inferred_repo:
|
||||
return False, user_repos[0:99]
|
||||
def _build_repo_options(self, repos: list[Repository]) -> list[dict[str, Any]]:
|
||||
"""Build Slack options list from repositories.
|
||||
|
||||
final_repos = []
|
||||
for repo in user_repos:
|
||||
if inferred_repo.lower() in repo.full_name.lower():
|
||||
final_repos.append(repo)
|
||||
Always includes a "No Repository" option at the top, followed by up to 99
|
||||
repositories (Slack has a 100 option limit for external_select).
|
||||
|
||||
# no repos matched, return original list
|
||||
if len(final_repos) == 0:
|
||||
return False, user_repos[0:99]
|
||||
Args:
|
||||
repos: List of Repository objects
|
||||
|
||||
# Found exact match
|
||||
elif len(final_repos) == 1:
|
||||
return True, final_repos
|
||||
Returns:
|
||||
List of Slack option objects
|
||||
"""
|
||||
options: list[dict[str, Any]] = [
|
||||
{
|
||||
'text': {'type': 'plain_text', 'text': 'No Repository'},
|
||||
'value': '-',
|
||||
}
|
||||
]
|
||||
options.extend(
|
||||
{
|
||||
'text': {
|
||||
'type': 'plain_text',
|
||||
'text': repo.full_name[:75], # Slack has 75 char limit for text
|
||||
},
|
||||
'value': repo.full_name,
|
||||
}
|
||||
for repo in repos[:99] # Leave room for "No Repository" option
|
||||
)
|
||||
return options
|
||||
|
||||
# Found partial matches
|
||||
return False, final_repos[0:99]
|
||||
async def search_repos_for_slack(
|
||||
self, user_auth: UserAuth, query: str, per_page: int = 20
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Public API for repository search with formatted Slack options.
|
||||
|
||||
This method searches for repositories and formats the results as Slack
|
||||
external_select options.
|
||||
|
||||
Args:
|
||||
user_auth: The user's authentication context
|
||||
query: Search query to filter repositories (empty string returns all)
|
||||
per_page: Maximum number of results to return (default: 20)
|
||||
|
||||
Returns:
|
||||
List of Slack option objects ready for external_select response
|
||||
"""
|
||||
repos = await self._search_repositories(
|
||||
user_auth, query=query, per_page=per_page
|
||||
)
|
||||
return self._build_repo_options(repos)
|
||||
|
||||
async def receive_message(self, message: Message):
|
||||
"""Process an incoming Slack message.
|
||||
|
||||
This is the single entry point for all Slack message processing.
|
||||
All SlackErrors raised during processing are caught and handled here,
|
||||
sending appropriate error messages to the user.
|
||||
"""
|
||||
self._confirm_incoming_source_type(message)
|
||||
|
||||
try:
|
||||
slack_view = await self._process_message(message)
|
||||
if slack_view and await self.is_job_requested(message, slack_view):
|
||||
await self.start_job(slack_view)
|
||||
|
||||
except SlackError as e:
|
||||
await self.handle_slack_error(message.message, e)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'slack_unexpected_error',
|
||||
extra={'error': str(e), **message.message},
|
||||
)
|
||||
await self.handle_slack_error(
|
||||
message.message,
|
||||
SlackError(SlackErrorCode.UNEXPECTED_ERROR),
|
||||
)
|
||||
|
||||
async def receive_form_interaction(self, slack_payload: dict):
|
||||
"""Process a Slack form interaction (repository selection).
|
||||
|
||||
This handles the block_actions payload when a user selects a repository
|
||||
from the dropdown form. It retrieves the original user message from Redis
|
||||
and delegates to receive_message for processing.
|
||||
|
||||
Args:
|
||||
slack_payload: The raw Slack interaction payload
|
||||
"""
|
||||
# Extract fields from the Slack interaction payload
|
||||
selected_repository = slack_payload['actions'][0]['selected_option']['value']
|
||||
if selected_repository == '-':
|
||||
selected_repository = None
|
||||
|
||||
slack_user_id = slack_payload['user']['id']
|
||||
channel_id = slack_payload['container']['channel_id']
|
||||
team_id = slack_payload['team']['id']
|
||||
|
||||
# Get original message_ts and thread_ts from action_id
|
||||
attribs = slack_payload['actions'][0]['action_id'].split('repository_select:')[
|
||||
-1
|
||||
]
|
||||
message_ts, thread_ts = attribs.split(':')
|
||||
thread_ts = None if thread_ts == 'None' else thread_ts
|
||||
|
||||
# Build partial payload for error handling during Redis retrieval
|
||||
payload = {
|
||||
'team_id': team_id,
|
||||
'channel_id': channel_id,
|
||||
'slack_user_id': slack_user_id,
|
||||
'message_ts': message_ts,
|
||||
'thread_ts': thread_ts,
|
||||
}
|
||||
|
||||
# Retrieve the original user message from Redis
|
||||
try:
|
||||
user_msg = await self._retrieve_user_msg_for_form(message_ts, thread_ts)
|
||||
except SlackError as e:
|
||||
await self.handle_slack_error(payload, e)
|
||||
return
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
'slack_unexpected_error',
|
||||
extra={'error': str(e), **payload},
|
||||
)
|
||||
await self.handle_slack_error(
|
||||
payload, SlackError(SlackErrorCode.UNEXPECTED_ERROR)
|
||||
)
|
||||
return
|
||||
|
||||
# Complete the payload and delegate to receive_message
|
||||
payload['selected_repo'] = selected_repository
|
||||
payload['user_msg'] = user_msg
|
||||
|
||||
message = Message(source=SourceType.SLACK, message=payload)
|
||||
await self.receive_message(message)
|
||||
|
||||
async def _process_message(self, message: Message) -> SlackViewInterface | None:
|
||||
"""Process message and return view if authenticated, or raise SlackError.
|
||||
|
||||
Returns:
|
||||
SlackViewInterface if user is authenticated and ready to proceed,
|
||||
None if processing should stop (but no error).
|
||||
|
||||
Raises:
|
||||
SlackError: If user is not authenticated or other recoverable error.
|
||||
"""
|
||||
slack_user, saas_user_auth = await self.authenticate_user(
|
||||
slack_user_id=message.message['slack_user_id']
|
||||
)
|
||||
|
||||
try:
|
||||
slack_view = SlackFactory.create_slack_view_from_payload(
|
||||
message, slack_user, saas_user_auth
|
||||
slack_view = await SlackFactory.create_slack_view_from_payload(
|
||||
message, slack_user, saas_user_auth
|
||||
)
|
||||
|
||||
# Check if this is an unauthenticated user (SlackMessageView but not SlackViewInterface)
|
||||
if not isinstance(slack_view, SlackViewInterface):
|
||||
login_link = self._generate_login_link_with_state(message)
|
||||
raise SlackError(
|
||||
SlackErrorCode.USER_NOT_AUTHENTICATED,
|
||||
message_kwargs={'login_link': login_link},
|
||||
log_context=slack_view.to_log_context(),
|
||||
)
|
||||
except Exception as e:
|
||||
|
||||
return slack_view
|
||||
|
||||
def _generate_login_link_with_state(self, message: Message) -> str:
|
||||
"""Generate OAuth login link with message state encoded."""
|
||||
jwt_secret = config.jwt_secret
|
||||
if not jwt_secret:
|
||||
raise ValueError('Must configure jwt_secret')
|
||||
state = jwt.encode(
|
||||
message.message, jwt_secret.get_secret_value(), algorithm='HS256'
|
||||
)
|
||||
return authorize_url_generator.generate(state)
|
||||
|
||||
async def handle_slack_error(self, payload: dict, error: SlackError) -> None:
|
||||
"""Handle a SlackError by logging and sending user message.
|
||||
|
||||
This is the centralized error handler for all SlackErrors, used by both
|
||||
the manager and routes.
|
||||
|
||||
Args:
|
||||
payload: The Slack payload dict containing channel/user info
|
||||
error: The SlackError to handle
|
||||
"""
|
||||
# Create a minimal view for sending the error message
|
||||
view = await SlackMessageView.from_payload(
|
||||
payload, self._get_slack_team_store()
|
||||
)
|
||||
|
||||
if not view:
|
||||
logger.error(
|
||||
f'[Slack]: Failed to create slack view: {e}',
|
||||
exc_info=True,
|
||||
stack_info=True,
|
||||
'slack_error_no_view',
|
||||
extra={
|
||||
'error_code': error.code.value,
|
||||
**error.log_context,
|
||||
},
|
||||
)
|
||||
return
|
||||
|
||||
if isinstance(slack_view, SlackUnkownUserView):
|
||||
jwt_secret = config.jwt_secret
|
||||
if not jwt_secret:
|
||||
raise ValueError('Must configure jwt_secret')
|
||||
state = jwt.encode(
|
||||
message.message, jwt_secret.get_secret_value(), algorithm='HS256'
|
||||
)
|
||||
link = authorize_url_generator.generate(state)
|
||||
msg = self.login_link.format(link)
|
||||
# Log the error
|
||||
log_level = (
|
||||
'exception' if error.code == SlackErrorCode.UNEXPECTED_ERROR else 'warning'
|
||||
)
|
||||
log_data = {
|
||||
'error_code': error.code.value,
|
||||
**view.to_log_context(),
|
||||
**error.log_context,
|
||||
}
|
||||
getattr(logger, log_level)(
|
||||
f'slack_error_{error.code.name.lower()}', extra=log_data
|
||||
)
|
||||
|
||||
logger.info('slack_not_yet_authenticated')
|
||||
await self.send_message(
|
||||
self.create_outgoing_message(msg, ephemeral=True), slack_view
|
||||
)
|
||||
return
|
||||
# Send user-facing message
|
||||
await self.send_message(error.get_user_message(), view, ephemeral=True)
|
||||
|
||||
if not await self.is_job_requested(message, slack_view):
|
||||
return
|
||||
def _get_slack_team_store(self):
|
||||
"""Get the SlackTeamStore instance (lazy import to avoid circular deps)."""
|
||||
from storage.slack_team_store import SlackTeamStore
|
||||
|
||||
await self.start_job(slack_view)
|
||||
return SlackTeamStore.get_instance()
|
||||
|
||||
async def send_message(self, message: Message, slack_view: SlackViewInterface):
|
||||
async def send_message(
|
||||
self,
|
||||
message: str | dict[str, Any],
|
||||
slack_view: SlackMessageView,
|
||||
ephemeral: bool = False,
|
||||
):
|
||||
"""Send a message to Slack.
|
||||
|
||||
Args:
|
||||
message: The message content. Can be a string (for simple text) or
|
||||
a dict with 'text' and 'blocks' keys (for structured messages).
|
||||
slack_view: The Slack view object containing channel/thread info.
|
||||
Can be either SlackMessageView (for unauthenticated users)
|
||||
or SlackViewInterface (for authenticated users).
|
||||
ephemeral: If True, send as an ephemeral message visible only to the user.
|
||||
"""
|
||||
client = AsyncWebClient(token=slack_view.bot_access_token)
|
||||
if message.ephemeral and isinstance(message.message, str):
|
||||
if ephemeral and isinstance(message, str):
|
||||
await client.chat_postEphemeral(
|
||||
channel=slack_view.channel_id,
|
||||
markdown_text=message.message,
|
||||
markdown_text=message,
|
||||
user=slack_view.slack_user_id,
|
||||
thread_ts=slack_view.thread_ts,
|
||||
)
|
||||
elif message.ephemeral and isinstance(message.message, dict):
|
||||
elif ephemeral and isinstance(message, dict):
|
||||
await client.chat_postEphemeral(
|
||||
channel=slack_view.channel_id,
|
||||
user=slack_view.slack_user_id,
|
||||
thread_ts=slack_view.thread_ts,
|
||||
text=message.message['text'],
|
||||
blocks=message.message['blocks'],
|
||||
text=message['text'],
|
||||
blocks=message['blocks'],
|
||||
)
|
||||
else:
|
||||
await client.chat_postMessage(
|
||||
channel=slack_view.channel_id,
|
||||
markdown_text=message.message,
|
||||
markdown_text=message,
|
||||
thread_ts=slack_view.message_ts,
|
||||
)
|
||||
|
||||
async def _try_verify_inferred_repo(
|
||||
self, slack_view: SlackNewConversationView
|
||||
) -> bool:
|
||||
"""Try to infer and verify a repository from the user's message.
|
||||
|
||||
Returns:
|
||||
True if a valid repo was found and verified, False otherwise
|
||||
"""
|
||||
user = slack_view.slack_to_openhands_user
|
||||
inferred_repos = infer_repo_from_message(slack_view.user_msg)
|
||||
|
||||
if len(inferred_repos) != 1:
|
||||
return False
|
||||
|
||||
inferred_repo = inferred_repos[0]
|
||||
logger.info(
|
||||
f'[Slack] Verifying inferred repo "{inferred_repo}" '
|
||||
f'for user {user.slack_display_name} (id={slack_view.saas_user_auth.get_user_id()})'
|
||||
)
|
||||
|
||||
try:
|
||||
provider_tokens = await slack_view.saas_user_auth.get_provider_tokens()
|
||||
if not provider_tokens:
|
||||
return False
|
||||
|
||||
access_token = await slack_view.saas_user_auth.get_access_token()
|
||||
user_id = await slack_view.saas_user_auth.get_user_id()
|
||||
provider_handler = ProviderHandler(
|
||||
provider_tokens=provider_tokens,
|
||||
external_auth_token=access_token,
|
||||
external_auth_id=user_id,
|
||||
)
|
||||
repo = await provider_handler.verify_repo_provider(inferred_repo)
|
||||
slack_view.selected_repo = repo.full_name
|
||||
return True
|
||||
except (AuthenticationError, ProviderTimeoutError) as e:
|
||||
logger.info(
|
||||
f'[Slack] Could not verify repo "{inferred_repo}": {e}. '
|
||||
f'Showing repository selector.'
|
||||
)
|
||||
return False
|
||||
|
||||
async def _show_repo_selection_form(
|
||||
self, slack_view: SlackNewConversationView
|
||||
) -> None:
|
||||
"""Display the repository selection form to the user.
|
||||
|
||||
Raises:
|
||||
SlackError: If storing the user message fails (REDIS_STORE_FAILED)
|
||||
"""
|
||||
user = slack_view.slack_to_openhands_user
|
||||
logger.info(
|
||||
'render_repository_selector',
|
||||
extra={
|
||||
'slack_user_id': user.slack_user_id,
|
||||
'keycloak_user_id': user.keycloak_user_id,
|
||||
'message_ts': slack_view.message_ts,
|
||||
'thread_ts': slack_view.thread_ts,
|
||||
},
|
||||
)
|
||||
|
||||
# Store the user message for later retrieval - raises SlackError on failure
|
||||
await self._store_user_msg_for_form(
|
||||
slack_view.message_ts, slack_view.thread_ts, slack_view.user_msg
|
||||
)
|
||||
|
||||
repo_selection_msg = {
|
||||
'text': 'Choose a Repository:',
|
||||
'blocks': self._generate_repo_selection_form(
|
||||
slack_view.message_ts, slack_view.thread_ts
|
||||
),
|
||||
}
|
||||
await self.send_message(repo_selection_msg, slack_view, ephemeral=True)
|
||||
|
||||
async def is_job_requested(
|
||||
self, message: Message, slack_view: SlackViewInterface
|
||||
) -> bool:
|
||||
"""A job is always request we only receive webhooks for events associated with the slack bot
|
||||
This method really just checks
|
||||
1. Is the user is authenticated
|
||||
2. Do we have the necessary information to start a job (either by inferring the selected repo, otherwise asking the user)
|
||||
"""Determine if a job should be started based on the current context.
|
||||
|
||||
This method checks:
|
||||
1. If the view type allows immediate job start
|
||||
2. If a repo can be inferred and verified from the message
|
||||
3. Otherwise shows the repo selection form
|
||||
|
||||
Args:
|
||||
slack_view: Must be a SlackViewType (authenticated view that can start jobs)
|
||||
|
||||
Returns:
|
||||
True if job should start, False if waiting for user input
|
||||
"""
|
||||
# Infer repo from user message is not needed; user selected repo from the form or is updating existing convo
|
||||
# Check if view type allows immediate start
|
||||
if isinstance(slack_view, SlackUpdateExistingConversationView):
|
||||
return True
|
||||
elif isinstance(slack_view, SlackNewConversationFromRepoFormView):
|
||||
if isinstance(slack_view, SlackNewConversationFromRepoFormView):
|
||||
return True
|
||||
elif isinstance(slack_view, SlackNewConversationView):
|
||||
user = slack_view.slack_to_openhands_user
|
||||
user_repos: list[Repository] = await self._get_repositories(
|
||||
slack_view.saas_user_auth
|
||||
)
|
||||
match, repos = self.filter_potential_repos_by_user_msg(
|
||||
slack_view.user_msg, user_repos
|
||||
)
|
||||
|
||||
# User mentioned a matching repo is their message, start job without repo selection form
|
||||
if match:
|
||||
slack_view.selected_repo = repos[0].full_name
|
||||
# For new conversations, try to infer/verify repo or show selection form
|
||||
if isinstance(slack_view, SlackNewConversationView):
|
||||
if await self._try_verify_inferred_repo(slack_view):
|
||||
return True
|
||||
await self._show_repo_selection_form(slack_view)
|
||||
|
||||
logger.info(
|
||||
'render_repository_selector',
|
||||
extra={
|
||||
'slack_user_id': user,
|
||||
'keycloak_user_id': user.keycloak_user_id,
|
||||
'message_ts': slack_view.message_ts,
|
||||
'thread_ts': slack_view.thread_ts,
|
||||
},
|
||||
)
|
||||
return False
|
||||
|
||||
repo_selection_msg = {
|
||||
'text': 'Choose a Repository:',
|
||||
'blocks': self._generate_repo_selection_form(
|
||||
repos, slack_view.message_ts, slack_view.thread_ts
|
||||
),
|
||||
}
|
||||
await self.send_message(
|
||||
self.create_outgoing_message(repo_selection_msg, ephemeral=True),
|
||||
slack_view,
|
||||
)
|
||||
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
async def start_job(self, slack_view: SlackViewInterface):
|
||||
async def start_job(self, slack_view: SlackViewInterface) -> None:
|
||||
# Importing here prevents circular import
|
||||
from server.conversation_callback_processor.slack_callback_processor import (
|
||||
SlackCallbackProcessor,
|
||||
@@ -296,7 +658,7 @@ class SlackManager(Manager):
|
||||
|
||||
try:
|
||||
msg_info = None
|
||||
user_info: SlackUser = slack_view.slack_to_openhands_user
|
||||
user_info = slack_view.slack_to_openhands_user
|
||||
try:
|
||||
logger.info(
|
||||
f'[Slack] Starting job for user {user_info.slack_display_name} (id={user_info.slack_user_id})',
|
||||
@@ -368,9 +730,10 @@ class SlackManager(Manager):
|
||||
except StartingConvoException as e:
|
||||
msg_info = str(e)
|
||||
|
||||
await self.send_message(self.create_outgoing_message(msg_info), slack_view)
|
||||
await self.send_message(msg_info, slack_view)
|
||||
|
||||
except Exception:
|
||||
logger.exception('[Slack]: Error starting job')
|
||||
msg = 'Uh oh! There was an unexpected error starting the job :('
|
||||
await self.send_message(self.create_outgoing_message(msg), slack_view)
|
||||
await self.send_message(
|
||||
'Uh oh! There was an unexpected error starting the job :(', slack_view
|
||||
)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
|
||||
from integrations.types import SummaryExtractionTracker
|
||||
from jinja2 import Environment
|
||||
@@ -7,24 +8,114 @@ from storage.slack_user import SlackUser
|
||||
from openhands.server.user_auth.user_auth import UserAuth
|
||||
|
||||
|
||||
class SlackViewInterface(SummaryExtractionTracker, ABC):
|
||||
@dataclass
|
||||
class SlackMessageView:
|
||||
"""Minimal view for sending messages to Slack.
|
||||
|
||||
This class contains only the fields needed to send messages,
|
||||
without requiring user authentication. Can be used directly for
|
||||
simple message operations or as a base class for more complex views.
|
||||
"""
|
||||
|
||||
bot_access_token: str
|
||||
user_msg: str | None
|
||||
slack_user_id: str
|
||||
slack_to_openhands_user: SlackUser | None
|
||||
saas_user_auth: UserAuth | None
|
||||
channel_id: str
|
||||
message_ts: str
|
||||
thread_ts: str | None
|
||||
team_id: str
|
||||
|
||||
def to_log_context(self) -> dict:
|
||||
"""Return dict suitable for structured logging."""
|
||||
return {
|
||||
'slack_channel_id': self.channel_id,
|
||||
'slack_user_id': self.slack_user_id,
|
||||
'slack_team_id': self.team_id,
|
||||
'slack_thread_ts': self.thread_ts,
|
||||
'slack_message_ts': self.message_ts,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
async def from_payload(
|
||||
cls,
|
||||
payload: dict,
|
||||
slack_team_store,
|
||||
) -> 'SlackMessageView | None':
|
||||
"""Create a view from a raw Slack payload.
|
||||
|
||||
This factory method handles the various payload formats from different
|
||||
Slack interactions (events, form submissions, block suggestions).
|
||||
|
||||
Args:
|
||||
payload: Raw Slack payload dictionary
|
||||
slack_team_store: Store for retrieving bot tokens
|
||||
|
||||
Returns:
|
||||
SlackMessageView if all required fields are available,
|
||||
None if required fields are missing or bot token unavailable.
|
||||
"""
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
team_id = payload.get('team', {}).get('id') or payload.get('team_id')
|
||||
channel_id = (
|
||||
payload.get('container', {}).get('channel_id')
|
||||
or payload.get('channel', {}).get('id')
|
||||
or payload.get('channel_id')
|
||||
)
|
||||
user_id = payload.get('user', {}).get('id') or payload.get('slack_user_id')
|
||||
message_ts = payload.get('message_ts', '')
|
||||
thread_ts = payload.get('thread_ts')
|
||||
|
||||
if not team_id or not channel_id or not user_id:
|
||||
logger.warning(
|
||||
'slack_message_view_from_payload_missing_fields',
|
||||
extra={
|
||||
'has_team_id': bool(team_id),
|
||||
'has_channel_id': bool(channel_id),
|
||||
'has_user_id': bool(user_id),
|
||||
'payload_keys': list(payload.keys()),
|
||||
},
|
||||
)
|
||||
return None
|
||||
|
||||
bot_token = await slack_team_store.get_team_bot_token(team_id)
|
||||
if not bot_token:
|
||||
logger.warning(
|
||||
'slack_message_view_from_payload_no_bot_token',
|
||||
extra={'team_id': team_id},
|
||||
)
|
||||
return None
|
||||
|
||||
return cls(
|
||||
bot_access_token=bot_token,
|
||||
slack_user_id=user_id,
|
||||
channel_id=channel_id,
|
||||
message_ts=message_ts,
|
||||
thread_ts=thread_ts,
|
||||
team_id=team_id,
|
||||
)
|
||||
|
||||
|
||||
class SlackViewInterface(SlackMessageView, SummaryExtractionTracker, ABC):
|
||||
"""Interface for authenticated Slack views that can create conversations.
|
||||
|
||||
All fields are required (non-None) because this interface is only used
|
||||
for users who have linked their Slack account to OpenHands.
|
||||
|
||||
Inherits from SlackMessageView:
|
||||
bot_access_token, slack_user_id, channel_id, message_ts, thread_ts, team_id
|
||||
"""
|
||||
|
||||
user_msg: str
|
||||
slack_to_openhands_user: SlackUser
|
||||
saas_user_auth: UserAuth
|
||||
selected_repo: str | None
|
||||
should_extract: bool
|
||||
send_summary_instruction: bool
|
||||
conversation_id: str
|
||||
team_id: str
|
||||
v1_enabled: bool
|
||||
|
||||
@abstractmethod
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
"""Instructions passed when conversation is first initialized"""
|
||||
pass
|
||||
|
||||
@@ -39,4 +130,4 @@ class SlackViewInterface(SummaryExtractionTracker, ABC):
|
||||
|
||||
|
||||
class StartingConvoException(Exception):
|
||||
"""Raised when trying to send message to a conversation that's is still starting up"""
|
||||
"""Raised when trying to send message to a conversation that is still starting up."""
|
||||
|
||||
@@ -2,7 +2,8 @@ import logging
|
||||
from uuid import UUID
|
||||
|
||||
import httpx
|
||||
from integrations.utils import CONVERSATION_URL, get_summary_instruction
|
||||
from integrations.utils import get_summary_instruction
|
||||
from integrations.v1_utils import handle_callback_error
|
||||
from pydantic import Field
|
||||
from slack_sdk import WebClient
|
||||
from storage.slack_team_store import SlackTeamStore
|
||||
@@ -39,7 +40,6 @@ class SlackV1CallbackProcessor(EventCallbackProcessor):
|
||||
event: Event,
|
||||
) -> EventCallbackResult | None:
|
||||
"""Process events for Slack V1 integration."""
|
||||
|
||||
# Only handle ConversationStateUpdateEvent
|
||||
if not isinstance(event, ConversationStateUpdateEvent):
|
||||
return None
|
||||
@@ -62,19 +62,14 @@ class SlackV1CallbackProcessor(EventCallbackProcessor):
|
||||
detail=summary,
|
||||
)
|
||||
except Exception as e:
|
||||
_logger.exception('[Slack V1] Error processing callback: %s', e)
|
||||
|
||||
# Only try to post error to Slack if we have basic requirements
|
||||
try:
|
||||
await self._post_summary_to_slack(
|
||||
f'OpenHands encountered an error: **{str(e)}**.\n\n'
|
||||
f'[See the conversation]({CONVERSATION_URL.format(conversation_id)})'
|
||||
'for more information.'
|
||||
)
|
||||
except Exception as post_error:
|
||||
_logger.warning(
|
||||
'[Slack V1] Failed to post error message to Slack: %s', post_error
|
||||
)
|
||||
await handle_callback_error(
|
||||
error=e,
|
||||
conversation_id=conversation_id,
|
||||
service_name='Slack',
|
||||
service_logger=_logger,
|
||||
can_post_error=True, # Slack always attempts to post errors
|
||||
post_error_func=self._post_summary_to_slack,
|
||||
)
|
||||
|
||||
return EventCallbackResult(
|
||||
status=EventCallbackResultStatus.ERROR,
|
||||
@@ -88,17 +83,18 @@ class SlackV1CallbackProcessor(EventCallbackProcessor):
|
||||
# Slack helpers
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def _get_bot_access_token(self):
|
||||
async def _get_bot_access_token(self) -> str | None:
|
||||
team_id = self.slack_view_data.get('team_id')
|
||||
if team_id is None:
|
||||
return None
|
||||
slack_team_store = SlackTeamStore.get_instance()
|
||||
bot_access_token = slack_team_store.get_team_bot_token(
|
||||
self.slack_view_data['team_id']
|
||||
)
|
||||
bot_access_token = await slack_team_store.get_team_bot_token(team_id)
|
||||
|
||||
return bot_access_token
|
||||
|
||||
async def _post_summary_to_slack(self, summary: str) -> None:
|
||||
"""Post a summary message to the configured Slack channel."""
|
||||
bot_access_token = self._get_bot_access_token()
|
||||
bot_access_token = await self._get_bot_access_token()
|
||||
if not bot_access_token:
|
||||
raise RuntimeError('Missing Slack bot access token')
|
||||
|
||||
@@ -148,8 +144,8 @@ class SlackV1CallbackProcessor(EventCallbackProcessor):
|
||||
send_message_request = AskAgentRequest(question=message_content)
|
||||
|
||||
url = (
|
||||
f'{agent_server_url.rstrip("/")}'
|
||||
f'/api/conversations/{conversation_id}/ask_agent'
|
||||
f"{agent_server_url.rstrip('/')}"
|
||||
f"/api/conversations/{conversation_id}/ask_agent"
|
||||
)
|
||||
headers = {'X-Session-API-Key': session_api_key}
|
||||
payload = send_message_request.model_dump()
|
||||
@@ -211,8 +207,7 @@ class SlackV1CallbackProcessor(EventCallbackProcessor):
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
async def _request_summary(self, conversation_id: UUID) -> str:
|
||||
"""
|
||||
Ask the agent to produce a summary of its work and return the agent response.
|
||||
"""Ask the agent to produce a summary of its work and return the agent response.
|
||||
|
||||
NOTE: This method now returns a string (the agent server's response text)
|
||||
and raises exceptions on errors. The wrapping into EventCallbackResult
|
||||
|
||||
@@ -1,9 +1,14 @@
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from integrations.models import Message
|
||||
from integrations.resolver_context import ResolverUserContext
|
||||
from integrations.slack.slack_types import SlackViewInterface, StartingConvoException
|
||||
from integrations.slack.slack_types import (
|
||||
SlackMessageView,
|
||||
SlackViewInterface,
|
||||
StartingConvoException,
|
||||
)
|
||||
from integrations.slack.slack_v1_callback_processor import SlackV1CallbackProcessor
|
||||
from integrations.utils import (
|
||||
CONVERSATION_URL,
|
||||
@@ -42,7 +47,7 @@ from openhands.server.user_auth.user_auth import UserAuth
|
||||
from openhands.storage.data_models.conversation_metadata import (
|
||||
ConversationTrigger,
|
||||
)
|
||||
from openhands.utils.async_utils import GENERAL_TIMEOUT, call_async_from_sync
|
||||
from openhands.utils.async_utils import GENERAL_TIMEOUT
|
||||
|
||||
# =================================================
|
||||
# SECTION: Slack view types
|
||||
@@ -58,37 +63,10 @@ async def is_v1_enabled_for_slack_resolver(user_id: str) -> bool:
|
||||
return await get_user_v1_enabled_setting(user_id) and ENABLE_V1_SLACK_RESOLVER
|
||||
|
||||
|
||||
@dataclass
|
||||
class SlackUnkownUserView(SlackViewInterface):
|
||||
bot_access_token: str
|
||||
user_msg: str | None
|
||||
slack_user_id: str
|
||||
slack_to_openhands_user: SlackUser | None
|
||||
saas_user_auth: UserAuth | None
|
||||
channel_id: str
|
||||
message_ts: str
|
||||
thread_ts: str | None
|
||||
selected_repo: str | None
|
||||
should_extract: bool
|
||||
send_summary_instruction: bool
|
||||
conversation_id: str
|
||||
team_id: str
|
||||
v1_enabled: bool
|
||||
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
raise NotImplementedError
|
||||
|
||||
async def create_or_update_conversation(self, jinja_env: Environment):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_response_msg(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@dataclass
|
||||
class SlackNewConversationView(SlackViewInterface):
|
||||
bot_access_token: str
|
||||
user_msg: str | None
|
||||
user_msg: str
|
||||
slack_user_id: str
|
||||
slack_to_openhands_user: SlackUser
|
||||
saas_user_auth: UserAuth
|
||||
@@ -118,7 +96,7 @@ class SlackNewConversationView(SlackViewInterface):
|
||||
return block['user_id']
|
||||
return ''
|
||||
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
"""Instructions passed when conversation is first initialized"""
|
||||
user_info: SlackUser = self.slack_to_openhands_user
|
||||
|
||||
@@ -242,7 +220,9 @@ class SlackNewConversationView(SlackViewInterface):
|
||||
self, jinja: Environment, provider_tokens, user_secrets
|
||||
) -> None:
|
||||
"""Create conversation using the legacy V0 system."""
|
||||
user_instructions, conversation_instructions = self._get_instructions(jinja)
|
||||
user_instructions, conversation_instructions = await self._get_instructions(
|
||||
jinja
|
||||
)
|
||||
|
||||
# Determine git provider from repository
|
||||
git_provider = None
|
||||
@@ -273,7 +253,9 @@ class SlackNewConversationView(SlackViewInterface):
|
||||
|
||||
async def _create_v1_conversation(self, jinja: Environment) -> None:
|
||||
"""Create conversation using the new V1 app conversation system."""
|
||||
user_instructions, conversation_instructions = self._get_instructions(jinja)
|
||||
user_instructions, conversation_instructions = await self._get_instructions(
|
||||
jinja
|
||||
)
|
||||
|
||||
# Create the initial message request
|
||||
initial_message = SendMessageRequest(
|
||||
@@ -346,7 +328,7 @@ class SlackNewConversationFromRepoFormView(SlackNewConversationView):
|
||||
class SlackUpdateExistingConversationView(SlackNewConversationView):
|
||||
slack_conversation: SlackConversation
|
||||
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
client = WebClient(token=self.bot_access_token)
|
||||
result = client.conversations_replies(
|
||||
channel=self.channel_id,
|
||||
@@ -389,6 +371,9 @@ class SlackUpdateExistingConversationView(SlackNewConversationView):
|
||||
self.conversation_id, conversation_init_data, user_id
|
||||
)
|
||||
|
||||
if agent_loop_info.event_store is None:
|
||||
raise StartingConvoException('Event store not available')
|
||||
|
||||
final_agent_observation = get_final_agent_observation(
|
||||
agent_loop_info.event_store
|
||||
)
|
||||
@@ -401,7 +386,7 @@ class SlackUpdateExistingConversationView(SlackNewConversationView):
|
||||
if not agent_state or agent_state == AgentState.LOADING:
|
||||
raise StartingConvoException('Conversation is still starting')
|
||||
|
||||
instructions, _ = self._get_instructions(jinja)
|
||||
instructions, _ = await self._get_instructions(jinja)
|
||||
user_msg = MessageAction(content=instructions)
|
||||
await conversation_manager.send_event_to_conversation(
|
||||
self.conversation_id, event_to_dict(user_msg)
|
||||
@@ -469,7 +454,7 @@ class SlackUpdateExistingConversationView(SlackNewConversationView):
|
||||
agent_server_url = get_agent_server_url_from_sandbox(running_sandbox)
|
||||
|
||||
# 4. Prepare the message content
|
||||
user_msg, _ = self._get_instructions(jinja)
|
||||
user_msg, _ = await self._get_instructions(jinja)
|
||||
|
||||
# 5. Create the message request
|
||||
send_message_request = SendMessageRequest(
|
||||
@@ -477,7 +462,7 @@ class SlackUpdateExistingConversationView(SlackNewConversationView):
|
||||
)
|
||||
|
||||
# 6. Send the message to the agent server
|
||||
url = f'{agent_server_url.rstrip("/")}/api/conversations/{UUID(self.conversation_id)}/events'
|
||||
url = f"{agent_server_url.rstrip('/')}/api/conversations/{UUID(self.conversation_id)}/events"
|
||||
|
||||
headers = {'X-Session-API-Key': running_sandbox.session_api_key}
|
||||
payload = send_message_request.model_dump()
|
||||
@@ -545,11 +530,14 @@ class SlackFactory:
|
||||
return None
|
||||
|
||||
# thread_ts in slack payloads in the parent's (root level msg's) message ID
|
||||
if channel_id is None:
|
||||
return None
|
||||
return await slack_conversation_store.get_slack_conversation(
|
||||
channel_id, thread_ts
|
||||
)
|
||||
|
||||
def create_slack_view_from_payload(
|
||||
@staticmethod
|
||||
async def create_slack_view_from_payload(
|
||||
message: Message, slack_user: SlackUser | None, saas_user_auth: UserAuth | None
|
||||
):
|
||||
payload = message.message
|
||||
@@ -560,7 +548,7 @@ class SlackFactory:
|
||||
team_id = payload['team_id']
|
||||
user_msg = payload.get('user_msg')
|
||||
|
||||
bot_access_token = slack_team_store.get_team_bot_token(team_id)
|
||||
bot_access_token = await slack_team_store.get_team_bot_token(team_id)
|
||||
if not bot_access_token:
|
||||
logger.error(
|
||||
'Did not find slack team',
|
||||
@@ -572,28 +560,27 @@ class SlackFactory:
|
||||
raise Exception('Did not find slack team')
|
||||
|
||||
# Determine if this is a known slack user by openhands
|
||||
if not slack_user or not saas_user_auth or not channel_id:
|
||||
return SlackUnkownUserView(
|
||||
# Return SlackMessageView (not SlackViewInterface) for unauthenticated users
|
||||
if not slack_user or not saas_user_auth or not channel_id or not message_ts:
|
||||
return SlackMessageView(
|
||||
bot_access_token=bot_access_token,
|
||||
user_msg=user_msg,
|
||||
slack_user_id=slack_user_id,
|
||||
slack_to_openhands_user=slack_user,
|
||||
saas_user_auth=saas_user_auth,
|
||||
channel_id=channel_id,
|
||||
message_ts=message_ts,
|
||||
channel_id=channel_id or '',
|
||||
message_ts=message_ts or '',
|
||||
thread_ts=thread_ts,
|
||||
selected_repo=None,
|
||||
should_extract=False,
|
||||
send_summary_instruction=False,
|
||||
conversation_id='',
|
||||
team_id=team_id,
|
||||
v1_enabled=False,
|
||||
)
|
||||
|
||||
conversation: SlackConversation | None = call_async_from_sync(
|
||||
SlackFactory.determine_if_updating_existing_conversation,
|
||||
GENERAL_TIMEOUT,
|
||||
message,
|
||||
# At this point, we've verified slack_user, saas_user_auth, channel_id, and message_ts are set
|
||||
# user_msg should always be present in Slack payloads
|
||||
if not user_msg:
|
||||
raise ValueError('user_msg is required but was not provided in payload')
|
||||
assert channel_id is not None
|
||||
assert message_ts is not None
|
||||
|
||||
conversation = await asyncio.wait_for(
|
||||
SlackFactory.determine_if_updating_existing_conversation(message),
|
||||
timeout=GENERAL_TIMEOUT,
|
||||
)
|
||||
if conversation:
|
||||
logger.info(
|
||||
@@ -656,3 +643,11 @@ class SlackFactory:
|
||||
team_id=team_id,
|
||||
v1_enabled=False,
|
||||
)
|
||||
|
||||
|
||||
# Type alias for all authenticated Slack view types that can start conversations
|
||||
SlackViewType = (
|
||||
SlackNewConversationView
|
||||
| SlackNewConversationFromRepoFormView
|
||||
| SlackUpdateExistingConversationView
|
||||
)
|
||||
|
||||
@@ -42,11 +42,11 @@ async def store_repositories_in_db(repos: list[Repository], user_id: str) -> Non
|
||||
try:
|
||||
# Store repositories in the repos table
|
||||
repo_store = RepositoryStore.get_instance(config)
|
||||
repo_store.store_projects(stored_repos)
|
||||
await repo_store.store_projects(stored_repos)
|
||||
|
||||
# Store user-repository mappings in the user-repos table
|
||||
user_repo_store = UserRepositoryMapStore.get_instance(config)
|
||||
user_repo_store.store_user_repo_mappings(user_repos)
|
||||
await user_repo_store.store_user_repo_mappings(user_repos)
|
||||
|
||||
logger.info(f'Saved repos for user {user_id}')
|
||||
except Exception:
|
||||
|
||||
@@ -3,24 +3,20 @@ from uuid import UUID
|
||||
import stripe
|
||||
from server.constants import STRIPE_API_KEY
|
||||
from server.logger import logger
|
||||
from sqlalchemy.orm import Session
|
||||
from storage.database import session_maker
|
||||
from sqlalchemy import select
|
||||
from storage.database import a_session_maker
|
||||
from storage.org import Org
|
||||
from storage.org_store import OrgStore
|
||||
from storage.stripe_customer import StripeCustomer
|
||||
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
|
||||
stripe.api_key = STRIPE_API_KEY
|
||||
|
||||
|
||||
async def find_customer_id_by_org_id(org_id: UUID) -> str | None:
|
||||
with session_maker() as session:
|
||||
stripe_customer = (
|
||||
session.query(StripeCustomer)
|
||||
.filter(StripeCustomer.org_id == org_id)
|
||||
.first()
|
||||
)
|
||||
async with a_session_maker() as session:
|
||||
stmt = select(StripeCustomer).where(StripeCustomer.org_id == org_id)
|
||||
result = await session.execute(stmt)
|
||||
stripe_customer = result.scalar_one_or_none()
|
||||
if stripe_customer:
|
||||
return stripe_customer.stripe_customer_id
|
||||
|
||||
@@ -40,9 +36,7 @@ async def find_customer_id_by_org_id(org_id: UUID) -> str | None:
|
||||
|
||||
async def find_customer_id_by_user_id(user_id: str) -> str | None:
|
||||
# First search our own DB...
|
||||
org = await call_sync_from_async(
|
||||
OrgStore.get_current_org_from_keycloak_user_id, user_id
|
||||
)
|
||||
org = await OrgStore.get_current_org_from_keycloak_user_id(user_id)
|
||||
if not org:
|
||||
logger.warning(f'Org not found for user {user_id}')
|
||||
return None
|
||||
@@ -52,9 +46,7 @@ async def find_customer_id_by_user_id(user_id: str) -> str | None:
|
||||
|
||||
async def find_or_create_customer_by_user_id(user_id: str) -> dict | None:
|
||||
# Get the current org for the user
|
||||
org = await call_sync_from_async(
|
||||
OrgStore.get_current_org_from_keycloak_user_id, user_id
|
||||
)
|
||||
org = await OrgStore.get_current_org_from_keycloak_user_id(user_id)
|
||||
if not org:
|
||||
logger.warning(f'Org not found for user {user_id}')
|
||||
return None
|
||||
@@ -74,7 +66,7 @@ async def find_or_create_customer_by_user_id(user_id: str) -> dict | None:
|
||||
)
|
||||
|
||||
# Save the stripe customer in the local db
|
||||
with session_maker() as session:
|
||||
async with a_session_maker() as session:
|
||||
session.add(
|
||||
StripeCustomer(
|
||||
keycloak_user_id=user_id,
|
||||
@@ -82,7 +74,7 @@ async def find_or_create_customer_by_user_id(user_id: str) -> dict | None:
|
||||
stripe_customer_id=customer.id,
|
||||
)
|
||||
)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
logger.info(
|
||||
'created_customer',
|
||||
@@ -108,26 +100,27 @@ async def has_payment_method_by_user_id(user_id: str) -> bool:
|
||||
return bool(payment_methods.data)
|
||||
|
||||
|
||||
async def migrate_customer(session: Session, user_id: str, org: Org):
|
||||
stripe_customer = (
|
||||
session.query(StripeCustomer)
|
||||
.filter(StripeCustomer.keycloak_user_id == user_id)
|
||||
.first()
|
||||
)
|
||||
if stripe_customer is None:
|
||||
return
|
||||
stripe_customer.org_id = org.id
|
||||
customer = await stripe.Customer.modify_async(
|
||||
id=stripe_customer.stripe_customer_id,
|
||||
email=org.contact_email,
|
||||
metadata={'user_id': '', 'org_id': str(org.id)},
|
||||
)
|
||||
async def migrate_customer(user_id: str, org: Org):
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(StripeCustomer).where(StripeCustomer.keycloak_user_id == user_id)
|
||||
)
|
||||
stripe_customer = result.scalar_one_or_none()
|
||||
if stripe_customer is None:
|
||||
return
|
||||
stripe_customer.org_id = org.id
|
||||
customer = await stripe.Customer.modify_async(
|
||||
id=stripe_customer.stripe_customer_id,
|
||||
email=org.contact_email,
|
||||
metadata={'user_id': '', 'org_id': str(org.id)},
|
||||
)
|
||||
|
||||
logger.info(
|
||||
'migrated_customer',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org.id),
|
||||
'stripe_customer_id': customer.id,
|
||||
},
|
||||
)
|
||||
logger.info(
|
||||
'migrated_customer',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org.id),
|
||||
'stripe_customer_id': customer.id,
|
||||
},
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
@@ -1,9 +1,17 @@
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from jinja2 import Environment
|
||||
from pydantic import BaseModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from integrations.models import Message
|
||||
|
||||
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE
|
||||
from openhands.server.user_auth.user_auth import UserAuth
|
||||
from openhands.storage.data_models.conversation_metadata import ConversationMetadata
|
||||
|
||||
|
||||
class GitLabResourceType(Enum):
|
||||
GROUP = 'group'
|
||||
@@ -31,17 +39,41 @@ class SummaryExtractionTracker:
|
||||
|
||||
@dataclass
|
||||
class ResolverViewInterface(SummaryExtractionTracker):
|
||||
installation_id: int
|
||||
# installation_id type varies by provider:
|
||||
# - GitHub: int (GitHub App installation ID)
|
||||
# - GitLab: str (webhook installation ID from our DB)
|
||||
installation_id: int | str
|
||||
user_info: UserData
|
||||
issue_number: int
|
||||
full_repo_name: str
|
||||
is_public_repo: bool
|
||||
raw_payload: dict
|
||||
raw_payload: 'Message'
|
||||
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
"Instructions passed when conversation is first initialized"
|
||||
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
"""Instructions passed when conversation is first initialized."""
|
||||
raise NotImplementedError()
|
||||
|
||||
async def create_new_conversation(self, jinja_env: Environment, token: str):
|
||||
"Create a new conversation"
|
||||
async def initialize_new_conversation(self) -> 'ConversationMetadata':
|
||||
"""Initialize a new conversation and return metadata.
|
||||
|
||||
For V1 conversations, creates a dummy ConversationMetadata.
|
||||
For V0 conversations, initializes through the conversation store.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
async def create_new_conversation(
|
||||
self,
|
||||
jinja_env: Environment,
|
||||
git_provider_tokens: 'PROVIDER_TOKEN_TYPE',
|
||||
conversation_metadata: 'ConversationMetadata',
|
||||
saas_user_auth: 'UserAuth',
|
||||
) -> None:
|
||||
"""Create a new conversation.
|
||||
|
||||
Args:
|
||||
jinja_env: Jinja2 environment for template rendering
|
||||
git_provider_tokens: Token mapping for git providers
|
||||
conversation_metadata: Metadata for the conversation
|
||||
saas_user_auth: User authentication for SaaS
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@@ -21,7 +21,6 @@ from openhands.events.event_store_abc import EventStoreABC
|
||||
from openhands.events.observation.agent import AgentStateChangedObservation
|
||||
from openhands.integrations.service_types import Repository
|
||||
from openhands.storage.data_models.conversation_status import ConversationStatus
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from openhands.server.conversation_manager.conversation_manager import (
|
||||
@@ -33,7 +32,8 @@ if TYPE_CHECKING:
|
||||
HOST = WEB_HOST
|
||||
# ---- DO NOT REMOVE ----
|
||||
|
||||
HOST_URL = f'https://{HOST}' if 'localhost' not in HOST else f'http://{HOST}'
|
||||
IS_LOCAL_DEPLOYMENT = 'localhost' in HOST
|
||||
HOST_URL = f'https://{HOST}' if not IS_LOCAL_DEPLOYMENT else f'http://{HOST}'
|
||||
GITHUB_WEBHOOK_URL = f'{HOST_URL}/integration/github/events'
|
||||
GITLAB_WEBHOOK_URL = f'{HOST_URL}/integration/gitlab/events'
|
||||
conversation_prefix = 'conversations/{}'
|
||||
@@ -65,6 +65,25 @@ def get_session_expired_message(username: str | None = None) -> str:
|
||||
return f'Your session has expired. Please login again at [OpenHands Cloud]({HOST_URL}) and try again.'
|
||||
|
||||
|
||||
def get_user_not_found_message(username: str | None = None) -> str:
|
||||
"""Get a user-friendly message when a user hasn't created an OpenHands account.
|
||||
|
||||
Used by integrations to notify users when they try to use OpenHands features
|
||||
but haven't logged into OpenHands Cloud yet (no Keycloak account exists).
|
||||
|
||||
Args:
|
||||
username: Optional username to mention in the message. If provided,
|
||||
the message will include @username prefix (used by Git providers
|
||||
like GitHub, GitLab, Slack). If None, returns a generic message.
|
||||
|
||||
Returns:
|
||||
A formatted user not found message
|
||||
"""
|
||||
if username:
|
||||
return f"@{username} it looks like you haven't created an OpenHands account yet. Please sign up at [OpenHands Cloud]({HOST_URL}) and try again."
|
||||
return f"It looks like you haven't created an OpenHands account yet. Please sign up at [OpenHands Cloud]({HOST_URL}) and try again."
|
||||
|
||||
|
||||
# Toggle for solvability report feature
|
||||
ENABLE_SOLVABILITY_ANALYSIS = (
|
||||
os.getenv('ENABLE_SOLVABILITY_ANALYSIS', 'false').lower() == 'true'
|
||||
@@ -79,6 +98,11 @@ ENABLE_V1_SLACK_RESOLVER = (
|
||||
os.getenv('ENABLE_V1_SLACK_RESOLVER', 'false').lower() == 'true'
|
||||
)
|
||||
|
||||
# Toggle for V1 GitLab resolver feature
|
||||
ENABLE_V1_GITLAB_RESOLVER = (
|
||||
os.getenv('ENABLE_V1_GITLAB_RESOLVER', 'false').lower() == 'true'
|
||||
)
|
||||
|
||||
OPENHANDS_RESOLVER_TEMPLATES_DIR = (
|
||||
os.getenv('OPENHANDS_RESOLVER_TEMPLATES_DIR')
|
||||
or 'openhands/integrations/templates/resolver/'
|
||||
@@ -122,9 +146,7 @@ async def get_user_v1_enabled_setting(user_id: str | None) -> bool:
|
||||
if not user_id:
|
||||
return False
|
||||
|
||||
org = await call_sync_from_async(
|
||||
OrgStore.get_current_org_from_keycloak_user_id, user_id
|
||||
)
|
||||
org = await OrgStore.get_current_org_from_keycloak_user_id(user_id)
|
||||
|
||||
if not org or org.v1_enabled is None:
|
||||
return False
|
||||
|
||||
@@ -1,3 +1,8 @@
|
||||
import logging
|
||||
from typing import Callable, Coroutine
|
||||
from uuid import UUID
|
||||
|
||||
from integrations.utils import CONVERSATION_URL
|
||||
from pydantic import SecretStr
|
||||
from server.auth.saas_user_auth import SaasUserAuth
|
||||
from server.auth.token_manager import TokenManager
|
||||
@@ -6,6 +11,78 @@ from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.server.user_auth.user_auth import UserAuth
|
||||
|
||||
|
||||
def is_budget_exceeded_error(error_message: str) -> bool:
|
||||
"""Check if an error message indicates a budget exceeded condition.
|
||||
|
||||
This is used to downgrade error logs to info logs for budget exceeded errors
|
||||
since they are expected cost control behavior rather than unexpected errors.
|
||||
"""
|
||||
lower_message = error_message.lower()
|
||||
return 'budget' in lower_message and 'exceeded' in lower_message
|
||||
|
||||
|
||||
BUDGET_EXCEEDED_USER_MESSAGE = 'LLM budget has been exceeded, please re-fill.'
|
||||
|
||||
|
||||
async def handle_callback_error(
|
||||
error: Exception,
|
||||
conversation_id: UUID,
|
||||
service_name: str,
|
||||
service_logger: logging.Logger,
|
||||
can_post_error: bool,
|
||||
post_error_func: Callable[[str], Coroutine],
|
||||
) -> None:
|
||||
"""Handle callback processing errors with appropriate logging and user messages.
|
||||
|
||||
This centralizes the error handling logic for V1 callback processors to:
|
||||
- Log budget exceeded errors at INFO level (expected cost control behavior)
|
||||
- Log other errors at EXCEPTION level
|
||||
- Post user-friendly error messages to the integration platform
|
||||
|
||||
Args:
|
||||
error: The exception that occurred
|
||||
conversation_id: The conversation ID for logging and linking
|
||||
service_name: The service name for log messages (e.g., "GitHub", "GitLab", "Slack")
|
||||
service_logger: The logger instance to use for logging
|
||||
can_post_error: Whether the prerequisites are met to post an error message
|
||||
post_error_func: Async function to post the error message to the platform
|
||||
"""
|
||||
error_str = str(error)
|
||||
budget_exceeded = is_budget_exceeded_error(error_str)
|
||||
|
||||
# Log appropriately based on error type
|
||||
if budget_exceeded:
|
||||
service_logger.info(
|
||||
'[%s V1] Budget exceeded for conversation %s: %s',
|
||||
service_name,
|
||||
conversation_id,
|
||||
error,
|
||||
)
|
||||
else:
|
||||
service_logger.exception(
|
||||
'[%s V1] Error processing callback: %s', service_name, error
|
||||
)
|
||||
|
||||
# Try to post error message to the platform
|
||||
if can_post_error:
|
||||
try:
|
||||
error_detail = (
|
||||
BUDGET_EXCEEDED_USER_MESSAGE if budget_exceeded else error_str
|
||||
)
|
||||
await post_error_func(
|
||||
f'OpenHands encountered an error: **{error_detail}**\n\n'
|
||||
f'[See the conversation]({CONVERSATION_URL.format(conversation_id)}) '
|
||||
'for more information.'
|
||||
)
|
||||
except Exception as post_error:
|
||||
service_logger.warning(
|
||||
'[%s V1] Failed to post error message to %s: %s',
|
||||
service_name,
|
||||
service_name,
|
||||
post_error,
|
||||
)
|
||||
|
||||
|
||||
async def get_saas_user_auth(
|
||||
keycloak_user_id: str, token_manager: TokenManager
|
||||
) -> UserAuth:
|
||||
|
||||
@@ -1,10 +1,15 @@
|
||||
import logging
|
||||
import os
|
||||
from logging.config import fileConfig
|
||||
|
||||
from alembic import context
|
||||
from google.cloud.sql.connector import Connector
|
||||
from sqlalchemy import create_engine
|
||||
from storage.base import Base
|
||||
# Suppress alembic.runtime.plugins INFO logs during import to prevent non-JSON logs in production
|
||||
# These plugin setup messages would otherwise appear before logging is configured
|
||||
logging.getLogger('alembic.runtime.plugins').setLevel(logging.WARNING)
|
||||
|
||||
from alembic import context # noqa: E402
|
||||
from google.cloud.sql.connector import Connector # noqa: E402
|
||||
from sqlalchemy import create_engine # noqa: E402
|
||||
from storage.base import Base # noqa: E402
|
||||
|
||||
target_metadata = Base.metadata
|
||||
|
||||
|
||||
@@ -0,0 +1,46 @@
|
||||
"""Add byor_export_enabled flag to org table.
|
||||
|
||||
Revision ID: 091
|
||||
Revises: 090
|
||||
Create Date: 2025-01-15 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '091'
|
||||
down_revision: Union[str, None] = '090'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add byor_export_enabled column to org table with default false
|
||||
op.add_column(
|
||||
'org',
|
||||
sa.Column(
|
||||
'byor_export_enabled',
|
||||
sa.Boolean,
|
||||
nullable=False,
|
||||
server_default=sa.text('false'),
|
||||
),
|
||||
)
|
||||
|
||||
# Set byor_export_enabled to true for orgs that have completed billing sessions
|
||||
op.execute(
|
||||
sa.text("""
|
||||
UPDATE org SET byor_export_enabled = TRUE
|
||||
WHERE id IN (
|
||||
SELECT DISTINCT org_id FROM billing_sessions
|
||||
WHERE status = 'completed' AND org_id IS NOT NULL
|
||||
)
|
||||
""")
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column('org', 'byor_export_enabled')
|
||||
@@ -0,0 +1,29 @@
|
||||
"""Rename 'user' role to 'member' in role table.
|
||||
|
||||
Revision ID: 092
|
||||
Revises: 091
|
||||
Create Date: 2025-02-12 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '092'
|
||||
down_revision: Union[str, None] = '091'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Rename 'user' role to 'member' for clarity
|
||||
# This avoids confusion between the 'user' role and the 'user' entity/account
|
||||
op.execute(sa.text("UPDATE role SET name = 'member' WHERE name = 'user'"))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Revert 'member' role back to 'user'
|
||||
op.execute(sa.text("UPDATE role SET name = 'user' WHERE name = 'member'"))
|
||||
@@ -0,0 +1,37 @@
|
||||
"""Add pending_free_credits flag to org table.
|
||||
|
||||
Revision ID: 093
|
||||
Revises: 092
|
||||
Create Date: 2025-02-17 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '093'
|
||||
down_revision: Union[str, None] = '092'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add pending_free_credits column to org table with default false.
|
||||
# New orgs will have this set to TRUE at creation time.
|
||||
# Existing orgs default to FALSE (not eligible - they already got $10 at signup).
|
||||
op.add_column(
|
||||
'org',
|
||||
sa.Column(
|
||||
'pending_free_credits',
|
||||
sa.Boolean,
|
||||
nullable=False,
|
||||
server_default=sa.text('false'),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column('org', 'pending_free_credits')
|
||||
@@ -0,0 +1,110 @@
|
||||
"""create org_invitation table
|
||||
|
||||
Revision ID: 094
|
||||
Revises: 093
|
||||
Create Date: 2026-02-18 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '094'
|
||||
down_revision: Union[str, None] = '093'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Create org_invitation table
|
||||
op.create_table(
|
||||
'org_invitation',
|
||||
sa.Column('id', sa.Integer, sa.Identity(), primary_key=True),
|
||||
sa.Column('token', sa.String(64), nullable=False),
|
||||
sa.Column('org_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column('email', sa.String(255), nullable=False),
|
||||
sa.Column('role_id', sa.Integer, nullable=False),
|
||||
sa.Column('inviter_id', postgresql.UUID(as_uuid=True), nullable=False),
|
||||
sa.Column(
|
||||
'status',
|
||||
sa.String(20),
|
||||
nullable=False,
|
||||
server_default=sa.text("'pending'"),
|
||||
),
|
||||
sa.Column(
|
||||
'created_at',
|
||||
sa.DateTime,
|
||||
nullable=False,
|
||||
server_default=sa.text('CURRENT_TIMESTAMP'),
|
||||
),
|
||||
sa.Column('expires_at', sa.DateTime, nullable=False),
|
||||
sa.Column('accepted_at', sa.DateTime, nullable=True),
|
||||
sa.Column('accepted_by_user_id', postgresql.UUID(as_uuid=True), nullable=True),
|
||||
# Foreign key constraints
|
||||
sa.ForeignKeyConstraint(
|
||||
['org_id'],
|
||||
['org.id'],
|
||||
name='org_invitation_org_fkey',
|
||||
ondelete='CASCADE',
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
['role_id'],
|
||||
['role.id'],
|
||||
name='org_invitation_role_fkey',
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
['inviter_id'],
|
||||
['user.id'],
|
||||
name='org_invitation_inviter_fkey',
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
['accepted_by_user_id'],
|
||||
['user.id'],
|
||||
name='org_invitation_accepter_fkey',
|
||||
),
|
||||
)
|
||||
|
||||
# Create indexes
|
||||
op.create_index(
|
||||
'ix_org_invitation_token',
|
||||
'org_invitation',
|
||||
['token'],
|
||||
unique=True,
|
||||
)
|
||||
op.create_index(
|
||||
'ix_org_invitation_org_id',
|
||||
'org_invitation',
|
||||
['org_id'],
|
||||
)
|
||||
op.create_index(
|
||||
'ix_org_invitation_email',
|
||||
'org_invitation',
|
||||
['email'],
|
||||
)
|
||||
op.create_index(
|
||||
'ix_org_invitation_status',
|
||||
'org_invitation',
|
||||
['status'],
|
||||
)
|
||||
# Composite index for checking pending invitations
|
||||
op.create_index(
|
||||
'ix_org_invitation_org_email_status',
|
||||
'org_invitation',
|
||||
['org_id', 'email', 'status'],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop indexes
|
||||
op.drop_index('ix_org_invitation_org_email_status', table_name='org_invitation')
|
||||
op.drop_index('ix_org_invitation_status', table_name='org_invitation')
|
||||
op.drop_index('ix_org_invitation_email', table_name='org_invitation')
|
||||
op.drop_index('ix_org_invitation_org_id', table_name='org_invitation')
|
||||
op.drop_index('ix_org_invitation_token', table_name='org_invitation')
|
||||
|
||||
# Drop table
|
||||
op.drop_table('org_invitation')
|
||||
@@ -0,0 +1,37 @@
|
||||
"""Drop pending_free_credits column from org table.
|
||||
|
||||
Revision ID: 095
|
||||
Revises: 094
|
||||
Create Date: 2025-02-18 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '095'
|
||||
down_revision: Union[str, None] = '094'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Drop the pending_free_credits column from org table.
|
||||
# This column was used for tracking free credit eligibility but is no longer needed.
|
||||
op.drop_column('org', 'pending_free_credits')
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Re-add pending_free_credits column with default false.
|
||||
op.add_column(
|
||||
'org',
|
||||
sa.Column(
|
||||
'pending_free_credits',
|
||||
sa.Boolean,
|
||||
nullable=False,
|
||||
server_default=sa.text('false'),
|
||||
),
|
||||
)
|
||||
@@ -0,0 +1,67 @@
|
||||
"""Create resend_synced_users table.
|
||||
|
||||
Revision ID: 096
|
||||
Revises: 095
|
||||
Create Date: 2025-02-17 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '096'
|
||||
down_revision: Union[str, None] = '095'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Create resend_synced_users table for tracking users synced to Resend audiences."""
|
||||
op.create_table(
|
||||
'resend_synced_users',
|
||||
sa.Column(
|
||||
'id',
|
||||
sa.UUID(as_uuid=True),
|
||||
nullable=False,
|
||||
primary_key=True,
|
||||
),
|
||||
sa.Column('email', sa.String(), nullable=False),
|
||||
sa.Column('audience_id', sa.String(), nullable=False),
|
||||
sa.Column(
|
||||
'synced_at',
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.text('CURRENT_TIMESTAMP'),
|
||||
),
|
||||
sa.Column('keycloak_user_id', sa.String(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint(
|
||||
'email', 'audience_id', name='uq_resend_synced_email_audience'
|
||||
),
|
||||
)
|
||||
|
||||
# Create index on email for fast lookups
|
||||
op.create_index(
|
||||
'ix_resend_synced_users_email',
|
||||
'resend_synced_users',
|
||||
['email'],
|
||||
)
|
||||
|
||||
# Create index on audience_id for filtering by audience
|
||||
op.create_index(
|
||||
'ix_resend_synced_users_audience_id',
|
||||
'resend_synced_users',
|
||||
['audience_id'],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Drop resend_synced_users table."""
|
||||
op.drop_index(
|
||||
'ix_resend_synced_users_audience_id', table_name='resend_synced_users'
|
||||
)
|
||||
op.drop_index('ix_resend_synced_users_email', table_name='resend_synced_users')
|
||||
op.drop_table('resend_synced_users')
|
||||
@@ -0,0 +1,41 @@
|
||||
"""Add session_api_key_hash to v1_remote_sandbox table
|
||||
|
||||
Revision ID: 097
|
||||
Revises: 096
|
||||
Create Date: 2025-02-24 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '097'
|
||||
down_revision: Union[str, None] = '096'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Add session_api_key_hash column to v1_remote_sandbox table."""
|
||||
op.add_column(
|
||||
'v1_remote_sandbox',
|
||||
sa.Column('session_api_key_hash', sa.String(), nullable=True),
|
||||
)
|
||||
op.create_index(
|
||||
op.f('ix_v1_remote_sandbox_session_api_key_hash'),
|
||||
'v1_remote_sandbox',
|
||||
['session_api_key_hash'],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Remove session_api_key_hash column from v1_remote_sandbox table."""
|
||||
op.drop_index(
|
||||
op.f('ix_v1_remote_sandbox_session_api_key_hash'),
|
||||
table_name='v1_remote_sandbox',
|
||||
)
|
||||
op.drop_column('v1_remote_sandbox', 'session_api_key_hash')
|
||||
@@ -0,0 +1,92 @@
|
||||
"""Create verified_models table.
|
||||
|
||||
Revision ID: 098
|
||||
Revises: 097
|
||||
Create Date: 2026-02-26 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '098'
|
||||
down_revision: Union[str, None] = '097'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Create verified_models table and seed with current model list."""
|
||||
op.create_table(
|
||||
'verified_models',
|
||||
sa.Column('id', sa.Integer, sa.Identity(), primary_key=True),
|
||||
sa.Column('model_name', sa.String(255), nullable=False),
|
||||
sa.Column('provider', sa.String(100), nullable=False),
|
||||
sa.Column(
|
||||
'is_enabled',
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.text('true'),
|
||||
),
|
||||
sa.Column(
|
||||
'created_at',
|
||||
sa.DateTime(),
|
||||
nullable=False,
|
||||
server_default=sa.text('CURRENT_TIMESTAMP'),
|
||||
),
|
||||
sa.Column(
|
||||
'updated_at',
|
||||
sa.DateTime(),
|
||||
nullable=False,
|
||||
server_default=sa.text('CURRENT_TIMESTAMP'),
|
||||
),
|
||||
sa.UniqueConstraint(
|
||||
'model_name', 'provider', name='uq_verified_model_provider'
|
||||
),
|
||||
)
|
||||
|
||||
op.create_index(
|
||||
'ix_verified_models_provider',
|
||||
'verified_models',
|
||||
['provider'],
|
||||
)
|
||||
op.create_index(
|
||||
'ix_verified_models_is_enabled',
|
||||
'verified_models',
|
||||
['is_enabled'],
|
||||
)
|
||||
|
||||
# Seed with current openhands provider models
|
||||
models = [
|
||||
('claude-opus-4-5-20251101', 'openhands'),
|
||||
('claude-sonnet-4-5-20250929', 'openhands'),
|
||||
('gpt-5.2-codex', 'openhands'),
|
||||
('gpt-5.2', 'openhands'),
|
||||
('minimax-m2.5', 'openhands'),
|
||||
('gemini-3-pro-preview', 'openhands'),
|
||||
('gemini-3-flash-preview', 'openhands'),
|
||||
('deepseek-chat', 'openhands'),
|
||||
('devstral-medium-2512', 'openhands'),
|
||||
('kimi-k2-0711-preview', 'openhands'),
|
||||
('qwen3-coder-480b', 'openhands'),
|
||||
]
|
||||
|
||||
for model_name, provider in models:
|
||||
op.execute(
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO verified_models (model_name, provider)
|
||||
VALUES (:model_name, :provider)
|
||||
"""
|
||||
).bindparams(model_name=model_name, provider=provider)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Drop verified_models table."""
|
||||
op.drop_index('ix_verified_models_is_enabled', table_name='verified_models')
|
||||
op.drop_index('ix_verified_models_provider', table_name='verified_models')
|
||||
op.drop_table('verified_models')
|
||||
@@ -0,0 +1,136 @@
|
||||
"""Create user_authorizations table and migrate blocked_email_domains
|
||||
|
||||
Revision ID: 099
|
||||
Revises: 098
|
||||
Create Date: 2025-03-05 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '099'
|
||||
down_revision: Union[str, None] = '098'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def _seed_from_environment() -> None:
|
||||
"""Seed user_authorizations table from environment variables.
|
||||
|
||||
Reads EMAIL_PATTERN_BLACKLIST and EMAIL_PATTERN_WHITELIST environment variables.
|
||||
Each should be a comma-separated list of SQL LIKE patterns (e.g., '%@example.com').
|
||||
|
||||
If the environment variables are not set or empty, this function does nothing.
|
||||
|
||||
This allows us to set up feature deployments with particular patterns already
|
||||
blacklisted or whitelisted. (For example, you could blacklist everything with
|
||||
`%`, and then whitelist certain email accounts.)
|
||||
"""
|
||||
blacklist_patterns = os.environ.get('EMAIL_PATTERN_BLACKLIST', '').strip()
|
||||
whitelist_patterns = os.environ.get('EMAIL_PATTERN_WHITELIST', '').strip()
|
||||
|
||||
connection = op.get_bind()
|
||||
|
||||
if blacklist_patterns:
|
||||
for pattern in blacklist_patterns.split(','):
|
||||
pattern = pattern.strip()
|
||||
if pattern:
|
||||
connection.execute(
|
||||
sa.text("""
|
||||
INSERT INTO user_authorizations
|
||||
(email_pattern, provider_type, type)
|
||||
VALUES
|
||||
(:pattern, NULL, 'blacklist')
|
||||
"""),
|
||||
{'pattern': pattern},
|
||||
)
|
||||
|
||||
if whitelist_patterns:
|
||||
for pattern in whitelist_patterns.split(','):
|
||||
pattern = pattern.strip()
|
||||
if pattern:
|
||||
connection.execute(
|
||||
sa.text("""
|
||||
INSERT INTO user_authorizations
|
||||
(email_pattern, provider_type, type)
|
||||
VALUES
|
||||
(:pattern, NULL, 'whitelist')
|
||||
"""),
|
||||
{'pattern': pattern},
|
||||
)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Create user_authorizations table, migrate data, and drop blocked_email_domains."""
|
||||
# Create user_authorizations table
|
||||
op.create_table(
|
||||
'user_authorizations',
|
||||
sa.Column('id', sa.Integer(), sa.Identity(), nullable=False, primary_key=True),
|
||||
sa.Column('email_pattern', sa.String(), nullable=True),
|
||||
sa.Column('provider_type', sa.String(), nullable=True),
|
||||
sa.Column('type', sa.String(), nullable=False),
|
||||
sa.Column(
|
||||
'created_at',
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.text('CURRENT_TIMESTAMP'),
|
||||
),
|
||||
sa.Column(
|
||||
'updated_at',
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.text('CURRENT_TIMESTAMP'),
|
||||
),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
)
|
||||
|
||||
# Create index on email_pattern for efficient LIKE queries
|
||||
op.create_index(
|
||||
'ix_user_authorizations_email_pattern',
|
||||
'user_authorizations',
|
||||
['email_pattern'],
|
||||
)
|
||||
|
||||
# Create index on type for efficient filtering
|
||||
op.create_index(
|
||||
'ix_user_authorizations_type',
|
||||
'user_authorizations',
|
||||
['type'],
|
||||
)
|
||||
|
||||
# Migrate existing blocked_email_domains to user_authorizations as blacklist entries
|
||||
# The domain patterns are converted to SQL LIKE patterns:
|
||||
# - 'example.com' becomes '%@example.com' (matches user@example.com)
|
||||
# - '.us' becomes '%@%.us' (matches user@anything.us)
|
||||
# We also add '%.' prefix for subdomain matching
|
||||
op.execute("""
|
||||
INSERT INTO user_authorizations (email_pattern, provider_type, type, created_at, updated_at)
|
||||
SELECT
|
||||
CASE
|
||||
WHEN domain LIKE '.%' THEN '%' || domain
|
||||
ELSE '%@%' || domain
|
||||
END as email_pattern,
|
||||
NULL as provider_type,
|
||||
'blacklist' as type,
|
||||
created_at,
|
||||
updated_at
|
||||
FROM blocked_email_domains
|
||||
""")
|
||||
|
||||
# Seed additional patterns from environment variables (if set)
|
||||
_seed_from_environment()
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Recreate blocked_email_domains table and migrate data back."""
|
||||
# Drop user_authorizations table
|
||||
op.drop_index('ix_user_authorizations_type', table_name='user_authorizations')
|
||||
op.drop_index(
|
||||
'ix_user_authorizations_email_pattern', table_name='user_authorizations'
|
||||
)
|
||||
op.drop_table('user_authorizations')
|
||||
1100
enterprise/poetry.lock
generated
1100
enterprise/poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -17,7 +17,6 @@ packages = [
|
||||
{ include = "storage" },
|
||||
{ include = "sync" },
|
||||
{ include = "integrations" },
|
||||
{ include = "experiments" },
|
||||
]
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
@@ -44,6 +43,12 @@ httpx = "*"
|
||||
scikit-learn = "^1.7.0"
|
||||
shap = "^0.48.0"
|
||||
google-cloud-recaptcha-enterprise = "^1.24.0"
|
||||
# Dependencies previously only in Dockerfile, now managed via poetry.lock
|
||||
prometheus-client = "^0.24.0"
|
||||
pandas = "^2.2.0"
|
||||
numpy = "^2.2.0"
|
||||
mcp = "^1.10.0"
|
||||
pillow = "^12.1.1"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
ruff = "0.8.3"
|
||||
|
||||
@@ -21,11 +21,12 @@ async def main():
|
||||
|
||||
|
||||
def set_stale_task_error():
|
||||
# started_at is naive UTC; strip tzinfo before comparing.
|
||||
cutoff = datetime.now(timezone.utc).replace(tzinfo=None) - timedelta(hours=1)
|
||||
with session_maker() as session:
|
||||
session.query(MaintenanceTask).filter(
|
||||
MaintenanceTask.status == MaintenanceTaskStatus.WORKING,
|
||||
MaintenanceTask.started_at
|
||||
< datetime.now(timezone.utc) - timedelta(hours=1),
|
||||
MaintenanceTask.started_at < cutoff,
|
||||
).update({MaintenanceTask.status: MaintenanceTaskStatus.ERROR})
|
||||
session.commit()
|
||||
|
||||
@@ -37,9 +38,10 @@ async def run_tasks():
|
||||
if not task:
|
||||
return
|
||||
|
||||
# Update the status
|
||||
# started_at/updated_at are naive UTC; strip tzinfo.
|
||||
now_utc = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
task.status = MaintenanceTaskStatus.WORKING
|
||||
task.updated_at = task.started_at = datetime.now(timezone.utc)
|
||||
task.updated_at = task.started_at = now_utc
|
||||
session.commit()
|
||||
|
||||
try:
|
||||
|
||||
@@ -14,6 +14,7 @@ from fastapi.middleware.cors import CORSMiddleware # noqa: E402
|
||||
from fastapi.responses import JSONResponse # noqa: E402
|
||||
from server.auth.auth_error import ExpiredError, NoCredentialsError # noqa: E402
|
||||
from server.auth.constants import ( # noqa: E402
|
||||
BITBUCKET_DATA_CENTER_HOST,
|
||||
ENABLE_JIRA,
|
||||
ENABLE_JIRA_DC,
|
||||
ENABLE_LINEAR,
|
||||
@@ -27,7 +28,6 @@ from server.rate_limit import setup_rate_limit_handler # noqa: E402
|
||||
from server.routes.api_keys import api_router as api_keys_router # noqa: E402
|
||||
from server.routes.auth import api_router, oauth_router # noqa: E402
|
||||
from server.routes.billing import billing_router # noqa: E402
|
||||
from server.routes.debugging import add_debugging_routes # noqa: E402
|
||||
from server.routes.email import api_router as email_router # noqa: E402
|
||||
from server.routes.event_webhook import event_webhook_router # noqa: E402
|
||||
from server.routes.feedback import router as feedback_router # noqa: E402
|
||||
@@ -38,15 +38,28 @@ from server.routes.integration.linear import linear_integration_router # noqa:
|
||||
from server.routes.integration.slack import slack_router # noqa: E402
|
||||
from server.routes.mcp_patch import patch_mcp_server # noqa: E402
|
||||
from server.routes.oauth_device import oauth_device_router # noqa: E402
|
||||
from server.routes.org_invitations import ( # noqa: E402
|
||||
accept_router as invitation_accept_router,
|
||||
)
|
||||
from server.routes.org_invitations import ( # noqa: E402
|
||||
invitation_router,
|
||||
)
|
||||
from server.routes.orgs import org_router # noqa: E402
|
||||
from server.routes.readiness import readiness_router # noqa: E402
|
||||
from server.routes.user import saas_user_router # noqa: E402
|
||||
from server.routes.user_app_settings import user_app_settings_router # noqa: E402
|
||||
from server.sharing.shared_conversation_router import ( # noqa: E402
|
||||
router as shared_conversation_router,
|
||||
)
|
||||
from server.sharing.shared_event_router import ( # noqa: E402
|
||||
router as shared_event_router,
|
||||
)
|
||||
from server.verified_models.verified_model_router import ( # noqa: E402
|
||||
api_router as verified_models_router,
|
||||
)
|
||||
from server.verified_models.verified_model_router import ( # noqa: E402
|
||||
override_llm_models_dependency,
|
||||
)
|
||||
|
||||
from openhands.server.app import app as base_app # noqa: E402
|
||||
from openhands.server.listen_socket import sio # noqa: E402
|
||||
@@ -70,6 +83,7 @@ base_app.include_router(api_router) # Add additional route for github auth
|
||||
base_app.include_router(oauth_router) # Add additional route for oauth callback
|
||||
base_app.include_router(oauth_device_router) # Add OAuth 2.0 Device Flow routes
|
||||
base_app.include_router(saas_user_router) # Add additional route SAAS user calls
|
||||
base_app.include_router(user_app_settings_router) # Add routes for user app settings
|
||||
base_app.include_router(
|
||||
billing_router
|
||||
) # Add routes for credit management and Stripe payment integration
|
||||
@@ -78,8 +92,15 @@ base_app.include_router(shared_event_router)
|
||||
|
||||
# Add GitHub integration router only if GITHUB_APP_CLIENT_ID is set
|
||||
if GITHUB_APP_CLIENT_ID:
|
||||
# Make sure that the callback processor is loaded here so we don't get an error when deserializing
|
||||
from integrations.github.github_v1_callback_processor import ( # noqa: E402
|
||||
GithubV1CallbackProcessor,
|
||||
)
|
||||
from server.routes.integration.github import github_integration_router # noqa: E402
|
||||
|
||||
# Bludgeon mypy into not deleting my import
|
||||
logger.debug(f'Loaded {GithubV1CallbackProcessor.__name__}')
|
||||
|
||||
base_app.include_router(
|
||||
github_integration_router
|
||||
) # Add additional route for integration webhook events
|
||||
@@ -92,10 +113,17 @@ if GITLAB_APP_CLIENT_ID:
|
||||
|
||||
base_app.include_router(api_keys_router) # Add routes for API key management
|
||||
base_app.include_router(org_router) # Add routes for organization management
|
||||
base_app.include_router(
|
||||
verified_models_router
|
||||
) # Add routes for verified models management
|
||||
|
||||
# Override the default LLM models implementation with SaaS version
|
||||
# This must happen after all routers are included
|
||||
override_llm_models_dependency(base_app)
|
||||
|
||||
base_app.include_router(invitation_router) # Add routes for org invitation management
|
||||
base_app.include_router(invitation_accept_router) # Add route for accepting invitations
|
||||
add_github_proxy_routes(base_app)
|
||||
add_debugging_routes(
|
||||
base_app
|
||||
) # Add diagnostic routes for testing and debugging (disabled in production)
|
||||
base_app.include_router(slack_router)
|
||||
if ENABLE_JIRA:
|
||||
base_app.include_router(jira_integration_router)
|
||||
@@ -103,6 +131,12 @@ if ENABLE_JIRA_DC:
|
||||
base_app.include_router(jira_dc_integration_router)
|
||||
if ENABLE_LINEAR:
|
||||
base_app.include_router(linear_integration_router)
|
||||
if BITBUCKET_DATA_CENTER_HOST:
|
||||
from server.routes.bitbucket_dc_proxy import (
|
||||
router as bitbucket_dc_proxy_router, # noqa: E402
|
||||
)
|
||||
|
||||
base_app.include_router(bitbucket_dc_proxy_router)
|
||||
base_app.include_router(email_router) # Add routes for email management
|
||||
base_app.include_router(feedback_router) # Add routes for conversation feedback
|
||||
base_app.include_router(
|
||||
|
||||
@@ -38,3 +38,9 @@ class ExpiredError(AuthError):
|
||||
"""Error when a token has expired (Usually the refresh token)"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class TokenRefreshError(AuthError):
|
||||
"""Error when token refresh fails due to timeout or lock contention"""
|
||||
|
||||
pass
|
||||
|
||||
@@ -1,79 +0,0 @@
|
||||
import os
|
||||
|
||||
from server.auth.sheets_client import GoogleSheetsClient
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
|
||||
class UserVerifier:
|
||||
def __init__(self) -> None:
|
||||
logger.debug('Initializing UserVerifier')
|
||||
self.file_users: list[str] | None = None
|
||||
self.sheets_client: GoogleSheetsClient | None = None
|
||||
self.spreadsheet_id: str | None = None
|
||||
|
||||
# Initialize from environment variables
|
||||
self._init_file_users()
|
||||
self._init_sheets_client()
|
||||
|
||||
def _init_file_users(self) -> None:
|
||||
"""Load users from text file if configured."""
|
||||
waitlist = os.getenv('GITHUB_USER_LIST_FILE')
|
||||
if not waitlist:
|
||||
logger.debug('GITHUB_USER_LIST_FILE not configured')
|
||||
return
|
||||
|
||||
if not os.path.exists(waitlist):
|
||||
logger.error(f'User list file not found: {waitlist}')
|
||||
raise FileNotFoundError(f'User list file not found: {waitlist}')
|
||||
|
||||
try:
|
||||
with open(waitlist, 'r') as f:
|
||||
self.file_users = [line.strip().lower() for line in f if line.strip()]
|
||||
logger.info(
|
||||
f'Successfully loaded {len(self.file_users)} users from {waitlist}'
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(f'Error reading user list file {waitlist}')
|
||||
|
||||
def _init_sheets_client(self) -> None:
|
||||
"""Initialize Google Sheets client if configured."""
|
||||
sheet_id = os.getenv('GITHUB_USERS_SHEET_ID')
|
||||
|
||||
if not sheet_id:
|
||||
logger.debug('GITHUB_USERS_SHEET_ID not configured')
|
||||
return
|
||||
|
||||
logger.debug('Initializing Google Sheets integration')
|
||||
self.sheets_client = GoogleSheetsClient()
|
||||
self.spreadsheet_id = sheet_id
|
||||
|
||||
def is_active(self) -> bool:
|
||||
if os.getenv('DISABLE_WAITLIST', '').lower() == 'true':
|
||||
logger.info('Waitlist disabled via DISABLE_WAITLIST env var')
|
||||
return False
|
||||
return bool(self.file_users or (self.sheets_client and self.spreadsheet_id))
|
||||
|
||||
def is_user_allowed(self, username: str) -> bool:
|
||||
"""Check if user is allowed based on file and/or sheet configuration."""
|
||||
logger.debug(f'Checking if GitHub user {username} is allowed')
|
||||
if self.file_users:
|
||||
if username.lower() in self.file_users:
|
||||
logger.debug(f'User {username} found in text file allowlist')
|
||||
return True
|
||||
logger.debug(f'User {username} not found in text file allowlist')
|
||||
|
||||
if self.sheets_client and self.spreadsheet_id:
|
||||
sheet_users = [
|
||||
u.lower() for u in self.sheets_client.get_usernames(self.spreadsheet_id)
|
||||
]
|
||||
if username.lower() in sheet_users:
|
||||
logger.debug(f'User {username} found in Google Sheets allowlist')
|
||||
return True
|
||||
logger.debug(f'User {username} not found in Google Sheets allowlist')
|
||||
|
||||
logger.debug(f'User {username} not found in any allowlist')
|
||||
return False
|
||||
|
||||
|
||||
user_verifier = UserVerifier()
|
||||
281
enterprise/server/auth/authorization.py
Normal file
281
enterprise/server/auth/authorization.py
Normal file
@@ -0,0 +1,281 @@
|
||||
"""
|
||||
Permission-based authorization dependencies for API endpoints.
|
||||
|
||||
This module provides FastAPI dependencies for checking user permissions
|
||||
within organizations. It uses a permission-based authorization model where
|
||||
roles (owner, admin, member) are mapped to specific permissions.
|
||||
|
||||
Permissions are defined in the Permission enum and mapped to roles via
|
||||
ROLE_PERMISSIONS. This allows fine-grained access control while maintaining
|
||||
the familiar role-based hierarchy.
|
||||
|
||||
Usage:
|
||||
from server.auth.authorization import (
|
||||
Permission,
|
||||
require_permission,
|
||||
)
|
||||
|
||||
@router.get('/{org_id}/settings')
|
||||
async def get_settings(
|
||||
org_id: UUID,
|
||||
user_id: str = Depends(require_permission(Permission.VIEW_LLM_SETTINGS)),
|
||||
):
|
||||
# Only users with VIEW_LLM_SETTINGS permission can access
|
||||
...
|
||||
|
||||
@router.patch('/{org_id}/settings')
|
||||
async def update_settings(
|
||||
org_id: UUID,
|
||||
user_id: str = Depends(require_permission(Permission.EDIT_LLM_SETTINGS)),
|
||||
):
|
||||
# Only users with EDIT_LLM_SETTINGS permission can access
|
||||
...
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from storage.org_member_store import OrgMemberStore
|
||||
from storage.role import Role
|
||||
from storage.role_store import RoleStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.server.user_auth import get_user_id
|
||||
|
||||
|
||||
class Permission(str, Enum):
|
||||
"""Permissions that can be assigned to roles."""
|
||||
|
||||
# Secrets
|
||||
MANAGE_SECRETS = 'manage_secrets'
|
||||
|
||||
# MCP
|
||||
MANAGE_MCP = 'manage_mcp'
|
||||
|
||||
# Integrations
|
||||
MANAGE_INTEGRATIONS = 'manage_integrations'
|
||||
|
||||
# Application Settings
|
||||
MANAGE_APPLICATION_SETTINGS = 'manage_application_settings'
|
||||
|
||||
# API Keys
|
||||
MANAGE_API_KEYS = 'manage_api_keys'
|
||||
|
||||
# LLM Settings
|
||||
VIEW_LLM_SETTINGS = 'view_llm_settings'
|
||||
EDIT_LLM_SETTINGS = 'edit_llm_settings'
|
||||
|
||||
# Billing
|
||||
VIEW_BILLING = 'view_billing'
|
||||
ADD_CREDITS = 'add_credits'
|
||||
|
||||
# Organization Members
|
||||
INVITE_USER_TO_ORGANIZATION = 'invite_user_to_organization'
|
||||
CHANGE_USER_ROLE_MEMBER = 'change_user_role:member'
|
||||
CHANGE_USER_ROLE_ADMIN = 'change_user_role:admin'
|
||||
CHANGE_USER_ROLE_OWNER = 'change_user_role:owner'
|
||||
|
||||
# Organization Management
|
||||
VIEW_ORG_SETTINGS = 'view_org_settings'
|
||||
CHANGE_ORGANIZATION_NAME = 'change_organization_name'
|
||||
DELETE_ORGANIZATION = 'delete_organization'
|
||||
|
||||
# Temporary permissions until we finish the API updates.
|
||||
EDIT_ORG_SETTINGS = 'edit_org_settings'
|
||||
|
||||
|
||||
class RoleName(str, Enum):
|
||||
"""Role names used in the system."""
|
||||
|
||||
OWNER = 'owner'
|
||||
ADMIN = 'admin'
|
||||
MEMBER = 'member'
|
||||
|
||||
|
||||
# Permission mappings for each role
|
||||
ROLE_PERMISSIONS: dict[RoleName, frozenset[Permission]] = {
|
||||
RoleName.OWNER: frozenset(
|
||||
[
|
||||
# Settings (Full access)
|
||||
Permission.MANAGE_SECRETS,
|
||||
Permission.MANAGE_MCP,
|
||||
Permission.MANAGE_INTEGRATIONS,
|
||||
Permission.MANAGE_APPLICATION_SETTINGS,
|
||||
Permission.MANAGE_API_KEYS,
|
||||
Permission.VIEW_LLM_SETTINGS,
|
||||
Permission.EDIT_LLM_SETTINGS,
|
||||
Permission.VIEW_BILLING,
|
||||
Permission.ADD_CREDITS,
|
||||
# Organization Members
|
||||
Permission.INVITE_USER_TO_ORGANIZATION,
|
||||
Permission.CHANGE_USER_ROLE_MEMBER,
|
||||
Permission.CHANGE_USER_ROLE_ADMIN,
|
||||
Permission.CHANGE_USER_ROLE_OWNER,
|
||||
# Organization Management
|
||||
Permission.VIEW_ORG_SETTINGS,
|
||||
Permission.EDIT_ORG_SETTINGS,
|
||||
# Organization Management (Owner only)
|
||||
Permission.CHANGE_ORGANIZATION_NAME,
|
||||
Permission.DELETE_ORGANIZATION,
|
||||
]
|
||||
),
|
||||
RoleName.ADMIN: frozenset(
|
||||
[
|
||||
# Settings (Full access)
|
||||
Permission.MANAGE_SECRETS,
|
||||
Permission.MANAGE_MCP,
|
||||
Permission.MANAGE_INTEGRATIONS,
|
||||
Permission.MANAGE_APPLICATION_SETTINGS,
|
||||
Permission.MANAGE_API_KEYS,
|
||||
Permission.VIEW_LLM_SETTINGS,
|
||||
Permission.EDIT_LLM_SETTINGS,
|
||||
Permission.VIEW_BILLING,
|
||||
Permission.ADD_CREDITS,
|
||||
# Organization Members
|
||||
Permission.INVITE_USER_TO_ORGANIZATION,
|
||||
Permission.CHANGE_USER_ROLE_MEMBER,
|
||||
Permission.CHANGE_USER_ROLE_ADMIN,
|
||||
# Organization Management
|
||||
Permission.VIEW_ORG_SETTINGS,
|
||||
Permission.EDIT_ORG_SETTINGS,
|
||||
]
|
||||
),
|
||||
RoleName.MEMBER: frozenset(
|
||||
[
|
||||
# Settings (Full access)
|
||||
Permission.MANAGE_SECRETS,
|
||||
Permission.MANAGE_MCP,
|
||||
Permission.MANAGE_INTEGRATIONS,
|
||||
Permission.MANAGE_APPLICATION_SETTINGS,
|
||||
Permission.MANAGE_API_KEYS,
|
||||
# Settings (View only)
|
||||
Permission.VIEW_ORG_SETTINGS,
|
||||
Permission.VIEW_LLM_SETTINGS,
|
||||
]
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
async def get_user_org_role(user_id: str, org_id: UUID | None) -> Role | None:
|
||||
"""
|
||||
Get the user's role in an organization.
|
||||
|
||||
Args:
|
||||
user_id: User ID (string that will be converted to UUID)
|
||||
org_id: Organization ID, or None to use the user's current organization
|
||||
|
||||
Returns:
|
||||
Role object if user is a member, None otherwise
|
||||
"""
|
||||
from uuid import UUID as parse_uuid
|
||||
|
||||
if org_id is None:
|
||||
org_member = await OrgMemberStore.get_org_member_for_current_org(
|
||||
parse_uuid(user_id)
|
||||
)
|
||||
else:
|
||||
org_member = await OrgMemberStore.get_org_member(org_id, parse_uuid(user_id))
|
||||
if not org_member:
|
||||
return None
|
||||
|
||||
return await RoleStore.get_role_by_id(org_member.role_id)
|
||||
|
||||
|
||||
def get_role_permissions(role_name: str) -> frozenset[Permission]:
|
||||
"""
|
||||
Get the permissions for a role.
|
||||
|
||||
Args:
|
||||
role_name: Name of the role
|
||||
|
||||
Returns:
|
||||
Set of permissions for the role
|
||||
"""
|
||||
try:
|
||||
role_enum = RoleName(role_name)
|
||||
return ROLE_PERMISSIONS.get(role_enum, frozenset())
|
||||
except ValueError:
|
||||
return frozenset()
|
||||
|
||||
|
||||
def has_permission(user_role: Role, permission: Permission) -> bool:
|
||||
"""
|
||||
Check if a role has a specific permission.
|
||||
|
||||
Args:
|
||||
user_role: User's Role object
|
||||
permission: Permission to check
|
||||
|
||||
Returns:
|
||||
True if the role has the permission
|
||||
"""
|
||||
permissions = get_role_permissions(user_role.name)
|
||||
return permission in permissions
|
||||
|
||||
|
||||
def require_permission(permission: Permission):
|
||||
"""
|
||||
Factory function that creates a dependency to require a specific permission.
|
||||
|
||||
This creates a FastAPI dependency that:
|
||||
1. Extracts org_id from the path parameter
|
||||
2. Gets the authenticated user_id
|
||||
3. Checks if the user has the required permission in the organization
|
||||
4. Returns the user_id if authorized, raises HTTPException otherwise
|
||||
|
||||
Usage:
|
||||
@router.get('/{org_id}/settings')
|
||||
async def get_settings(
|
||||
org_id: UUID,
|
||||
user_id: str = Depends(require_permission(Permission.VIEW_LLM_SETTINGS)),
|
||||
):
|
||||
...
|
||||
|
||||
Args:
|
||||
permission: The permission required to access the endpoint
|
||||
|
||||
Returns:
|
||||
Dependency function that validates permission and returns user_id
|
||||
"""
|
||||
|
||||
async def permission_checker(
|
||||
org_id: UUID | None = None,
|
||||
user_id: str | None = Depends(get_user_id),
|
||||
) -> str:
|
||||
if not user_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail='User not authenticated',
|
||||
)
|
||||
|
||||
user_role = await get_user_org_role(user_id, org_id)
|
||||
|
||||
if not user_role:
|
||||
logger.warning(
|
||||
'User not a member of organization',
|
||||
extra={'user_id': user_id, 'org_id': str(org_id)},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail='User is not a member of this organization',
|
||||
)
|
||||
|
||||
if not has_permission(user_role, permission):
|
||||
logger.warning(
|
||||
'Insufficient permissions',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'org_id': str(org_id),
|
||||
'user_role': user_role.name,
|
||||
'required_permission': permission.value,
|
||||
},
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f'Requires {permission.value} permission',
|
||||
)
|
||||
|
||||
return user_id
|
||||
|
||||
return permission_checker
|
||||
@@ -40,6 +40,16 @@ ROLE_CHECK_ENABLED = os.getenv('ROLE_CHECK_ENABLED', 'false').lower() in (
|
||||
)
|
||||
|
||||
DUPLICATE_EMAIL_CHECK = os.getenv('DUPLICATE_EMAIL_CHECK', 'true') in ('1', 'true')
|
||||
BITBUCKET_DATA_CENTER_CLIENT_ID = os.getenv(
|
||||
'BITBUCKET_DATA_CENTER_CLIENT_ID', ''
|
||||
).strip()
|
||||
BITBUCKET_DATA_CENTER_CLIENT_SECRET = os.getenv(
|
||||
'BITBUCKET_DATA_CENTER_CLIENT_SECRET', ''
|
||||
).strip()
|
||||
BITBUCKET_DATA_CENTER_HOST = os.getenv('BITBUCKET_DATA_CENTER_HOST', '').strip()
|
||||
BITBUCKET_DATA_CENTER_TOKEN_URL = (
|
||||
f'https://{BITBUCKET_DATA_CENTER_HOST}/rest/oauth2/latest/token'
|
||||
)
|
||||
|
||||
# reCAPTCHA Enterprise
|
||||
RECAPTCHA_PROJECT_ID = os.getenv('RECAPTCHA_PROJECT_ID', '').strip()
|
||||
|
||||
@@ -1,67 +0,0 @@
|
||||
from storage.blocked_email_domain_store import BlockedEmailDomainStore
|
||||
from storage.database import session_maker
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
|
||||
class DomainBlocker:
|
||||
def __init__(self, store: BlockedEmailDomainStore) -> None:
|
||||
logger.debug('Initializing DomainBlocker')
|
||||
self.store = store
|
||||
|
||||
def _extract_domain(self, email: str) -> str | None:
|
||||
"""Extract and normalize email domain from email address"""
|
||||
if not email:
|
||||
return None
|
||||
try:
|
||||
# Extract domain part after @
|
||||
if '@' not in email:
|
||||
return None
|
||||
domain = email.split('@')[1].strip().lower()
|
||||
return domain if domain else None
|
||||
except Exception:
|
||||
logger.debug(f'Error extracting domain from email: {email}', exc_info=True)
|
||||
return None
|
||||
|
||||
def is_domain_blocked(self, email: str) -> bool:
|
||||
"""Check if email domain is blocked by querying the database directly via SQL.
|
||||
|
||||
Supports blocking:
|
||||
- Exact domains: 'example.com' blocks 'user@example.com'
|
||||
- Subdomains: 'example.com' blocks 'user@subdomain.example.com'
|
||||
- TLDs: '.us' blocks 'user@company.us' and 'user@subdomain.company.us'
|
||||
|
||||
The blocking logic is handled efficiently in SQL, avoiding the need to load
|
||||
all blocked domains into memory.
|
||||
"""
|
||||
if not email:
|
||||
logger.debug('No email provided for domain check')
|
||||
return False
|
||||
|
||||
domain = self._extract_domain(email)
|
||||
if not domain:
|
||||
logger.debug(f'Could not extract domain from email: {email}')
|
||||
return False
|
||||
|
||||
try:
|
||||
# Query database directly via SQL to check if domain is blocked
|
||||
is_blocked = self.store.is_domain_blocked(domain)
|
||||
|
||||
if is_blocked:
|
||||
logger.warning(f'Email domain {domain} is blocked for email: {email}')
|
||||
else:
|
||||
logger.debug(f'Email domain {domain} is not blocked')
|
||||
|
||||
return is_blocked
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f'Error checking if domain is blocked for email {email}: {e}',
|
||||
exc_info=True,
|
||||
)
|
||||
# Fail-safe: if database query fails, don't block (allow auth to proceed)
|
||||
return False
|
||||
|
||||
|
||||
# Initialize store and domain blocker
|
||||
_store = BlockedEmailDomainStore(session_maker=session_maker)
|
||||
domain_blocker = DomainBlocker(store=_store)
|
||||
@@ -1,87 +1,11 @@
|
||||
import os
|
||||
|
||||
from integrations.github.github_service import SaaSGitHubService
|
||||
from pydantic import SecretStr
|
||||
from server.auth.sheets_client import GoogleSheetsClient
|
||||
from server.auth.auth_utils import user_verifier
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.github.github_types import GitHubUser
|
||||
|
||||
|
||||
class UserVerifier:
|
||||
def __init__(self) -> None:
|
||||
logger.debug('Initializing UserVerifier')
|
||||
self.file_users: list[str] | None = None
|
||||
self.sheets_client: GoogleSheetsClient | None = None
|
||||
self.spreadsheet_id: str | None = None
|
||||
|
||||
# Initialize from environment variables
|
||||
self._init_file_users()
|
||||
self._init_sheets_client()
|
||||
|
||||
def _init_file_users(self) -> None:
|
||||
"""Load users from text file if configured"""
|
||||
waitlist = os.getenv('GITHUB_USER_LIST_FILE')
|
||||
if not waitlist:
|
||||
logger.debug('GITHUB_USER_LIST_FILE not configured')
|
||||
return
|
||||
|
||||
if not os.path.exists(waitlist):
|
||||
logger.error(f'User list file not found: {waitlist}')
|
||||
raise FileNotFoundError(f'User list file not found: {waitlist}')
|
||||
|
||||
try:
|
||||
with open(waitlist, 'r') as f:
|
||||
self.file_users = [line.strip().lower() for line in f if line.strip()]
|
||||
logger.info(
|
||||
f'Successfully loaded {len(self.file_users)} users from {waitlist}'
|
||||
)
|
||||
except Exception:
|
||||
logger.error(f'Error reading user list file {waitlist}', exc_info=True)
|
||||
|
||||
def _init_sheets_client(self) -> None:
|
||||
"""Initialize Google Sheets client if configured"""
|
||||
sheet_id = os.getenv('GITHUB_USERS_SHEET_ID')
|
||||
|
||||
if not sheet_id:
|
||||
logger.debug('GITHUB_USERS_SHEET_ID not configured')
|
||||
return
|
||||
|
||||
logger.debug('Initializing Google Sheets integration')
|
||||
self.sheets_client = GoogleSheetsClient()
|
||||
self.spreadsheet_id = sheet_id
|
||||
|
||||
def is_active(self) -> bool:
|
||||
if os.getenv('DISABLE_WAITLIST', '').lower() == 'true':
|
||||
logger.info('Waitlist disabled via DISABLE_WAITLIST env var')
|
||||
return False
|
||||
return bool(self.file_users or (self.sheets_client and self.spreadsheet_id))
|
||||
|
||||
def is_user_allowed(self, username: str) -> bool:
|
||||
"""Check if user is allowed based on file and/or sheet configuration"""
|
||||
logger.debug(f'Checking if GitHub user {username} is allowed')
|
||||
if self.file_users:
|
||||
if username.lower() in self.file_users:
|
||||
logger.debug(f'User {username} found in text file allowlist')
|
||||
return True
|
||||
logger.debug(f'User {username} not found in text file allowlist')
|
||||
|
||||
if self.sheets_client and self.spreadsheet_id:
|
||||
sheet_users = [
|
||||
u.lower() for u in self.sheets_client.get_usernames(self.spreadsheet_id)
|
||||
]
|
||||
if username.lower() in sheet_users:
|
||||
logger.debug(f'User {username} found in Google Sheets allowlist')
|
||||
return True
|
||||
logger.debug(f'User {username} not found in Google Sheets allowlist')
|
||||
|
||||
logger.debug(f'User {username} not found in any allowlist')
|
||||
return False
|
||||
|
||||
|
||||
user_verifier = UserVerifier()
|
||||
|
||||
|
||||
def is_user_allowed(user_login: str):
|
||||
if user_verifier.is_active() and not user_verifier.is_user_allowed(user_login):
|
||||
logger.warning(f'GitHub user {user_login} not in allow list')
|
||||
|
||||
@@ -13,16 +13,19 @@ from server.auth.auth_error import (
|
||||
ExpiredError,
|
||||
NoCredentialsError,
|
||||
)
|
||||
from server.auth.domain_blocker import domain_blocker
|
||||
from server.auth.constants import BITBUCKET_DATA_CENTER_HOST
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.config import get_config
|
||||
from server.logger import logger
|
||||
from server.rate_limit import RateLimiter, create_redis_rate_limiter
|
||||
from sqlalchemy import delete, select
|
||||
from storage.api_key_store import ApiKeyStore
|
||||
from storage.auth_tokens import AuthTokens
|
||||
from storage.database import session_maker
|
||||
from storage.database import a_session_maker
|
||||
from storage.saas_secrets_store import SaasSecretsStore
|
||||
from storage.saas_settings_store import SaasSettingsStore
|
||||
from storage.user_authorization import UserAuthorizationType
|
||||
from storage.user_authorization_store import UserAuthorizationStore
|
||||
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed
|
||||
|
||||
from openhands.integrations.provider import (
|
||||
@@ -118,13 +121,12 @@ class SaasUserAuth(UserAuth):
|
||||
self._settings = settings
|
||||
return settings
|
||||
|
||||
async def get_secrets_store(self):
|
||||
async def get_secrets_store(self) -> SaasSecretsStore:
|
||||
logger.debug('saas_user_auth_get_secrets_store')
|
||||
secrets_store = self.secrets_store
|
||||
if secrets_store:
|
||||
return secrets_store
|
||||
user_id = await self.get_user_id()
|
||||
secrets_store = SaasSecretsStore(user_id, session_maker, get_config())
|
||||
secrets_store = SaasSecretsStore(self.user_id, get_config())
|
||||
self.secrets_store = secrets_store
|
||||
return secrets_store
|
||||
|
||||
@@ -161,12 +163,13 @@ class SaasUserAuth(UserAuth):
|
||||
|
||||
try:
|
||||
# TODO: I think we can do this in a single request if we refactor
|
||||
with session_maker() as session:
|
||||
tokens = (
|
||||
session.query(AuthTokens)
|
||||
.where(AuthTokens.keycloak_user_id == self.user_id)
|
||||
.all()
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(AuthTokens).where(
|
||||
AuthTokens.keycloak_user_id == self.user_id
|
||||
)
|
||||
)
|
||||
tokens = result.scalars().all()
|
||||
|
||||
for token in tokens:
|
||||
idp_type = ProviderType(token.identity_provider)
|
||||
@@ -175,6 +178,9 @@ class SaasUserAuth(UserAuth):
|
||||
if user_secrets and idp_type in user_secrets.provider_tokens:
|
||||
host = user_secrets.provider_tokens[idp_type].host
|
||||
|
||||
if idp_type == ProviderType.BITBUCKET_DATA_CENTER and not host:
|
||||
host = BITBUCKET_DATA_CENTER_HOST or None
|
||||
|
||||
provider_token = await token_manager.get_idp_token(
|
||||
access_token.get_secret_value(),
|
||||
idp=idp_type,
|
||||
@@ -192,11 +198,11 @@ class SaasUserAuth(UserAuth):
|
||||
'idp_type': token.identity_provider,
|
||||
},
|
||||
)
|
||||
with session_maker() as session:
|
||||
session.query(AuthTokens).filter(
|
||||
AuthTokens.id == token.id
|
||||
).delete()
|
||||
session.commit()
|
||||
async with a_session_maker() as session:
|
||||
await session.execute(
|
||||
delete(AuthTokens).where(AuthTokens.id == token.id)
|
||||
)
|
||||
await session.commit()
|
||||
raise
|
||||
|
||||
self.provider_tokens = MappingProxyType(provider_tokens)
|
||||
@@ -209,8 +215,7 @@ class SaasUserAuth(UserAuth):
|
||||
settings_store = self.settings_store
|
||||
if settings_store:
|
||||
return settings_store
|
||||
user_id = await self.get_user_id()
|
||||
settings_store = SaasSettingsStore(user_id, session_maker, get_config())
|
||||
settings_store = SaasSettingsStore(self.user_id, get_config())
|
||||
self.settings_store = settings_store
|
||||
return settings_store
|
||||
|
||||
@@ -278,7 +283,7 @@ async def saas_user_auth_from_bearer(request: Request) -> SaasUserAuth | None:
|
||||
return None
|
||||
|
||||
api_key_store = ApiKeyStore.get_instance()
|
||||
user_id = api_key_store.validate_api_key(api_key)
|
||||
user_id = await api_key_store.validate_api_key(api_key)
|
||||
if not user_id:
|
||||
return None
|
||||
offline_token = await token_manager.load_offline_token(user_id)
|
||||
@@ -326,14 +331,16 @@ async def saas_user_auth_from_signed_token(signed_token: str) -> SaasUserAuth:
|
||||
email = access_token_payload['email']
|
||||
email_verified = access_token_payload['email_verified']
|
||||
|
||||
# Check if email domain is blocked
|
||||
if email and domain_blocker.is_domain_blocked(email):
|
||||
logger.warning(
|
||||
f'Blocked authentication attempt for existing user with email: {email}'
|
||||
)
|
||||
raise AuthError(
|
||||
'Access denied: Your email domain is not allowed to access this service'
|
||||
)
|
||||
# Check if email is blacklisted (whitelist takes precedence)
|
||||
if email:
|
||||
auth_type = await UserAuthorizationStore.get_authorization_type(email, None)
|
||||
if auth_type == UserAuthorizationType.BLACKLIST:
|
||||
logger.warning(
|
||||
f'Blocked authentication attempt for existing user with email: {email}'
|
||||
)
|
||||
raise AuthError(
|
||||
'Access denied: Your email domain is not allowed to access this service'
|
||||
)
|
||||
|
||||
logger.debug('saas_user_auth_from_signed_token:return')
|
||||
|
||||
|
||||
@@ -16,10 +16,15 @@ from keycloak.exceptions import (
|
||||
KeycloakError,
|
||||
KeycloakPostError,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
from server.auth.auth_error import ExpiredError
|
||||
from server.auth.constants import (
|
||||
BITBUCKET_APP_CLIENT_ID,
|
||||
BITBUCKET_APP_CLIENT_SECRET,
|
||||
BITBUCKET_DATA_CENTER_CLIENT_ID,
|
||||
BITBUCKET_DATA_CENTER_CLIENT_SECRET,
|
||||
BITBUCKET_DATA_CENTER_HOST,
|
||||
BITBUCKET_DATA_CENTER_TOKEN_URL,
|
||||
DUPLICATE_EMAIL_CHECK,
|
||||
GITHUB_APP_CLIENT_ID,
|
||||
GITHUB_APP_CLIENT_SECRET,
|
||||
@@ -38,9 +43,9 @@ from server.auth.keycloak_manager import get_keycloak_admin, get_keycloak_openid
|
||||
from server.config import get_config
|
||||
from server.logger import logger
|
||||
from sqlalchemy import String as SQLString
|
||||
from sqlalchemy import type_coerce
|
||||
from sqlalchemy import select, type_coerce
|
||||
from storage.auth_token_store import AuthTokenStore
|
||||
from storage.database import session_maker
|
||||
from storage.database import a_session_maker
|
||||
from storage.github_app_installation import GithubAppInstallation
|
||||
from storage.offline_token_store import OfflineTokenStore
|
||||
from tenacity import RetryCallState, retry, retry_if_exception_type, stop_after_attempt
|
||||
@@ -50,6 +55,34 @@ from openhands.server.types import SessionExpiredError
|
||||
from openhands.utils.http_session import httpx_verify_option
|
||||
|
||||
|
||||
class KeycloakUserInfo(BaseModel):
|
||||
"""Pydantic model for Keycloak UserInfo endpoint response.
|
||||
|
||||
Based on OIDC standard claims. 'sub' is always required per OIDC spec.
|
||||
Additional fields from Keycloak are captured via model_config extra='allow'.
|
||||
"""
|
||||
|
||||
model_config = {'extra': 'allow'}
|
||||
|
||||
sub: str
|
||||
name: str | None = None
|
||||
given_name: str | None = None
|
||||
family_name: str | None = None
|
||||
preferred_username: str | None = None
|
||||
email: str | None = None
|
||||
email_verified: bool | None = None
|
||||
picture: str | None = None
|
||||
attributes: dict[str, list[str]] | None = None
|
||||
identity_provider: str | None = None
|
||||
company: str | None = None
|
||||
roles: list[str] | None = None
|
||||
|
||||
|
||||
# HTTP timeout for external IDP calls (in seconds)
|
||||
# This prevents indefinite blocking if an IDP is slow or unresponsive
|
||||
IDP_HTTP_TIMEOUT = 15.0
|
||||
|
||||
|
||||
def _before_sleep_callback(retry_state: RetryCallState) -> None:
|
||||
logger.info(f'Retry attempt {retry_state.attempt_number} for Keycloak operation')
|
||||
|
||||
@@ -137,22 +170,22 @@ class TokenManager:
|
||||
new_keycloak_tokens['refresh_token'],
|
||||
)
|
||||
|
||||
# UserInfo from Keycloak return a dictionary with the following format:
|
||||
# {
|
||||
# 'sub': '248289761001',
|
||||
# 'name': 'Jane Doe',
|
||||
# 'given_name': 'Jane',
|
||||
# 'family_name': 'Doe',
|
||||
# 'preferred_username': 'j.doe',
|
||||
# 'email': 'janedoe@example.com',
|
||||
# 'picture': 'http://example.com/janedoe/me.jpg'
|
||||
# 'github_id': '354322532'
|
||||
# }
|
||||
async def get_user_info(self, access_token: str) -> dict:
|
||||
if not access_token:
|
||||
return {}
|
||||
async def get_user_info(self, access_token: str) -> KeycloakUserInfo:
|
||||
"""Get user info from Keycloak userinfo endpoint.
|
||||
|
||||
Args:
|
||||
access_token: A valid Keycloak access token
|
||||
|
||||
Returns:
|
||||
KeycloakUserInfo with user claims. 'sub' is always present per OIDC spec.
|
||||
|
||||
Raises:
|
||||
KeycloakAuthenticationError: If the token is invalid
|
||||
ValidationError: If the response is missing the required 'sub' field
|
||||
"""
|
||||
user_info = await get_keycloak_openid(self.external).a_userinfo(access_token)
|
||||
return user_info
|
||||
# Pydantic validation will raise ValidationError if 'sub' is missing
|
||||
return KeycloakUserInfo.model_validate(user_info)
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(2),
|
||||
@@ -202,7 +235,9 @@ class TokenManager:
|
||||
access_token: str,
|
||||
idp: ProviderType,
|
||||
) -> dict[str, str | int]:
|
||||
async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
|
||||
async with httpx.AsyncClient(
|
||||
verify=httpx_verify_option(), timeout=IDP_HTTP_TIMEOUT
|
||||
) as client:
|
||||
base_url = KEYCLOAK_SERVER_URL_EXT if self.external else KEYCLOAK_SERVER_URL
|
||||
url = f'{base_url}/realms/{KEYCLOAK_REALM_NAME}/broker/{idp.value}/token'
|
||||
headers = {
|
||||
@@ -264,8 +299,8 @@ class TokenManager:
|
||||
) -> str:
|
||||
# Get user info to determine user_id and idp
|
||||
user_info = await self.get_user_info(access_token=access_token)
|
||||
user_id = user_info.get('sub')
|
||||
username = user_info.get('preferred_username')
|
||||
user_id = user_info.sub
|
||||
username = user_info.preferred_username
|
||||
logger.info(f'Getting token for user {username} and IDP {idp}')
|
||||
token_store = await AuthTokenStore.get_instance(
|
||||
keycloak_user_id=user_id, idp=idp
|
||||
@@ -348,6 +383,8 @@ class TokenManager:
|
||||
return await self._refresh_gitlab_token(refresh_token)
|
||||
elif idp == ProviderType.BITBUCKET:
|
||||
return await self._refresh_bitbucket_token(refresh_token)
|
||||
elif idp == ProviderType.BITBUCKET_DATA_CENTER:
|
||||
return await self._refresh_bitbucket_data_center_token(refresh_token)
|
||||
else:
|
||||
raise ValueError(f'Unsupported IDP: {idp}')
|
||||
|
||||
@@ -361,7 +398,9 @@ class TokenManager:
|
||||
'refresh_token': refresh_token,
|
||||
'grant_type': 'refresh_token',
|
||||
}
|
||||
async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
|
||||
async with httpx.AsyncClient(
|
||||
verify=httpx_verify_option(), timeout=IDP_HTTP_TIMEOUT
|
||||
) as client:
|
||||
response = await client.post(url, data=payload)
|
||||
response.raise_for_status()
|
||||
logger.info('Successfully refreshed GitHub token')
|
||||
@@ -387,7 +426,9 @@ class TokenManager:
|
||||
'refresh_token': refresh_token,
|
||||
'grant_type': 'refresh_token',
|
||||
}
|
||||
async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
|
||||
async with httpx.AsyncClient(
|
||||
verify=httpx_verify_option(), timeout=IDP_HTTP_TIMEOUT
|
||||
) as client:
|
||||
response = await client.post(url, data=payload)
|
||||
response.raise_for_status()
|
||||
logger.info('Successfully refreshed GitLab token')
|
||||
@@ -415,7 +456,9 @@ class TokenManager:
|
||||
'refresh_token': refresh_token,
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient(verify=httpx_verify_option()) as client:
|
||||
async with httpx.AsyncClient(
|
||||
verify=httpx_verify_option(), timeout=IDP_HTTP_TIMEOUT
|
||||
) as client:
|
||||
response = await client.post(url, data=data, headers=headers)
|
||||
response.raise_for_status()
|
||||
logger.info('Successfully refreshed Bitbucket token')
|
||||
@@ -423,6 +466,33 @@ class TokenManager:
|
||||
data = response.json()
|
||||
return await self._parse_refresh_response(data)
|
||||
|
||||
async def _refresh_bitbucket_data_center_token(
|
||||
self, refresh_token: str
|
||||
) -> dict[str, str | int]:
|
||||
if not BITBUCKET_DATA_CENTER_HOST:
|
||||
raise ValueError(
|
||||
'BITBUCKET_DATA_CENTER_HOST is not configured. '
|
||||
'Set the BITBUCKET_DATA_CENTER_HOST environment variable.'
|
||||
)
|
||||
url = BITBUCKET_DATA_CENTER_TOKEN_URL
|
||||
logger.info(f'Refreshing Bitbucket Data Center token with URL: {url}')
|
||||
|
||||
payload = {
|
||||
'client_id': BITBUCKET_DATA_CENTER_CLIENT_ID,
|
||||
'client_secret': BITBUCKET_DATA_CENTER_CLIENT_SECRET,
|
||||
'refresh_token': refresh_token,
|
||||
'grant_type': 'refresh_token',
|
||||
}
|
||||
async with httpx.AsyncClient(
|
||||
verify=httpx_verify_option(), timeout=IDP_HTTP_TIMEOUT
|
||||
) as client:
|
||||
response = await client.post(url, data=payload)
|
||||
response.raise_for_status()
|
||||
logger.info('Successfully refreshed Bitbucket Data Center token')
|
||||
|
||||
data = response.json()
|
||||
return await self._parse_refresh_response(data)
|
||||
|
||||
async def _parse_refresh_response(self, data: dict) -> dict[str, str | int]:
|
||||
access_token = data.get('access_token')
|
||||
refresh_token = data.get('refresh_token')
|
||||
@@ -771,25 +841,24 @@ class TokenManager:
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
def store_org_token(self, installation_id: int, installation_token: str):
|
||||
async def store_org_token(self, installation_id: int, installation_token: str):
|
||||
"""Store a GitHub App installation token.
|
||||
|
||||
Args:
|
||||
installation_id: GitHub installation ID (integer or string)
|
||||
installation_token: The token to store
|
||||
"""
|
||||
with session_maker() as session:
|
||||
async with a_session_maker() as session:
|
||||
# Ensure installation_id is a string
|
||||
str_installation_id = str(installation_id)
|
||||
# Use type_coerce to ensure SQLAlchemy treats the parameter as a string
|
||||
installation = (
|
||||
session.query(GithubAppInstallation)
|
||||
.filter(
|
||||
result = await session.execute(
|
||||
select(GithubAppInstallation).filter(
|
||||
GithubAppInstallation.installation_id
|
||||
== type_coerce(str_installation_id, SQLString)
|
||||
)
|
||||
.first()
|
||||
)
|
||||
installation = result.scalars().first()
|
||||
if installation:
|
||||
installation.encrypted_token = self.encrypt_text(installation_token)
|
||||
else:
|
||||
@@ -799,9 +868,9 @@ class TokenManager:
|
||||
encrypted_token=self.encrypt_text(installation_token),
|
||||
)
|
||||
)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
def load_org_token(self, installation_id: int) -> str | None:
|
||||
async def load_org_token(self, installation_id: int) -> str | None:
|
||||
"""Load a GitHub App installation token.
|
||||
|
||||
Args:
|
||||
@@ -810,17 +879,16 @@ class TokenManager:
|
||||
Returns:
|
||||
The decrypted token if found, None otherwise
|
||||
"""
|
||||
with session_maker() as session:
|
||||
async with a_session_maker() as session:
|
||||
# Ensure installation_id is a string and use type_coerce
|
||||
str_installation_id = str(installation_id)
|
||||
installation = (
|
||||
session.query(GithubAppInstallation)
|
||||
.filter(
|
||||
result = await session.execute(
|
||||
select(GithubAppInstallation).filter(
|
||||
GithubAppInstallation.installation_id
|
||||
== type_coerce(str_installation_id, SQLString)
|
||||
)
|
||||
.first()
|
||||
)
|
||||
installation = result.scalars().first()
|
||||
if not installation:
|
||||
return None
|
||||
token = self.decrypt_text(installation.encrypted_token)
|
||||
|
||||
98
enterprise/server/auth/user/default_user_authorizer.py
Normal file
98
enterprise/server/auth/user/default_user_authorizer.py
Normal file
@@ -0,0 +1,98 @@
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from fastapi import Request
|
||||
from pydantic import Field
|
||||
from server.auth.email_validation import extract_base_email
|
||||
from server.auth.token_manager import KeycloakUserInfo, TokenManager
|
||||
from server.auth.user.user_authorizer import (
|
||||
UserAuthorizationResponse,
|
||||
UserAuthorizer,
|
||||
UserAuthorizerInjector,
|
||||
)
|
||||
from storage.user_authorization import UserAuthorizationType
|
||||
from storage.user_authorization_store import UserAuthorizationStore
|
||||
|
||||
from openhands.app_server.services.injector import InjectorState
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
token_manager = TokenManager()
|
||||
|
||||
|
||||
@dataclass
|
||||
class DefaultUserAuthorizer(UserAuthorizer):
|
||||
"""Class determining whether a user may be authorized.
|
||||
|
||||
Uses the user_authorizations database table to check whitelist/blacklist rules.
|
||||
"""
|
||||
|
||||
prevent_duplicates: bool
|
||||
|
||||
async def authorize_user(
|
||||
self, user_info: KeycloakUserInfo
|
||||
) -> UserAuthorizationResponse:
|
||||
user_id = user_info.sub
|
||||
email = user_info.email
|
||||
provider_type = user_info.identity_provider
|
||||
try:
|
||||
if not email:
|
||||
logger.warning(f'No email provided for user_id: {user_id}')
|
||||
return UserAuthorizationResponse(
|
||||
success=False, error_detail='missing_email'
|
||||
)
|
||||
|
||||
if self.prevent_duplicates:
|
||||
has_duplicate = await token_manager.check_duplicate_base_email(
|
||||
email, user_id
|
||||
)
|
||||
if has_duplicate:
|
||||
logger.warning(
|
||||
f'Blocked signup attempt for email {email} - duplicate base email found',
|
||||
extra={'user_id': user_id, 'email': email},
|
||||
)
|
||||
return UserAuthorizationResponse(
|
||||
success=False, error_detail='duplicate_email'
|
||||
)
|
||||
|
||||
# Check authorization rules (whitelist takes precedence over blacklist)
|
||||
base_email = extract_base_email(email)
|
||||
if base_email is None:
|
||||
return UserAuthorizationResponse(
|
||||
success=False, error_detail='invalid_email'
|
||||
)
|
||||
auth_type = await UserAuthorizationStore.get_authorization_type(
|
||||
base_email, provider_type
|
||||
)
|
||||
|
||||
if auth_type == UserAuthorizationType.WHITELIST:
|
||||
logger.debug(
|
||||
f'User {email} matched whitelist rule',
|
||||
extra={'user_id': user_id, 'email': email},
|
||||
)
|
||||
return UserAuthorizationResponse(success=True)
|
||||
|
||||
if auth_type == UserAuthorizationType.BLACKLIST:
|
||||
logger.warning(
|
||||
f'Blocked authentication attempt for email: {email}, user_id: {user_id}'
|
||||
)
|
||||
return UserAuthorizationResponse(success=False, error_detail='blocked')
|
||||
|
||||
return UserAuthorizationResponse(success=True)
|
||||
except Exception:
|
||||
logger.exception('error authorizing user', extra={'user_id': user_id})
|
||||
return UserAuthorizationResponse(success=False)
|
||||
|
||||
|
||||
class DefaultUserAuthorizerInjector(UserAuthorizerInjector):
|
||||
prevent_duplicates: bool = Field(
|
||||
default=True,
|
||||
description='Whether duplicate emails (containing +) are filtered',
|
||||
)
|
||||
|
||||
async def inject(
|
||||
self, state: InjectorState, request: Request | None = None
|
||||
) -> AsyncGenerator[UserAuthorizer, None]:
|
||||
yield DefaultUserAuthorizer(
|
||||
prevent_duplicates=self.prevent_duplicates,
|
||||
)
|
||||
48
enterprise/server/auth/user/user_authorizer.py
Normal file
48
enterprise/server/auth/user/user_authorizer.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from fastapi import Depends
|
||||
from pydantic import BaseModel
|
||||
from server.auth.token_manager import KeycloakUserInfo
|
||||
|
||||
from openhands.agent_server.env_parser import from_env
|
||||
from openhands.app_server.services.injector import Injector
|
||||
from openhands.sdk.utils.models import DiscriminatedUnionMixin
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UserAuthorizationResponse(BaseModel):
|
||||
success: bool
|
||||
error_detail: str | None = None
|
||||
|
||||
|
||||
class UserAuthorizer(ABC):
|
||||
"""Class determining whether a user may be authorized."""
|
||||
|
||||
@abstractmethod
|
||||
async def authorize_user(
|
||||
self, user_info: KeycloakUserInfo
|
||||
) -> UserAuthorizationResponse:
|
||||
"""Determine whether the info given is permitted."""
|
||||
|
||||
|
||||
class UserAuthorizerInjector(DiscriminatedUnionMixin, Injector[UserAuthorizer], ABC):
|
||||
pass
|
||||
|
||||
|
||||
def depends_user_authorizer():
|
||||
from server.auth.user.default_user_authorizer import (
|
||||
DefaultUserAuthorizerInjector,
|
||||
)
|
||||
|
||||
try:
|
||||
injector: UserAuthorizerInjector = from_env(
|
||||
UserAuthorizerInjector, 'OH_USER_AUTHORIZER'
|
||||
)
|
||||
except Exception as ex:
|
||||
print(ex)
|
||||
logger.info('Using default UserAuthorizer')
|
||||
injector = DefaultUserAuthorizerInjector()
|
||||
|
||||
return Depends(injector.depends)
|
||||
@@ -7,7 +7,8 @@ from uuid import uuid4
|
||||
import socketio
|
||||
from server.logger import logger
|
||||
from server.utils.conversation_callback_utils import invoke_conversation_callbacks
|
||||
from storage.database import session_maker
|
||||
from sqlalchemy import select
|
||||
from storage.database import a_session_maker
|
||||
from storage.stored_conversation_metadata_saas import StoredConversationMetadataSaas
|
||||
|
||||
from openhands.core.config import LLMConfig
|
||||
@@ -523,15 +524,14 @@ class ClusteredConversationManager(StandaloneConversationManager):
|
||||
f'local_connection_to_stopped_conversation:{connection_id}:{conversation_id}'
|
||||
)
|
||||
# Look up the user_id from the database
|
||||
with session_maker() as session:
|
||||
conversation_metadata_saas = (
|
||||
session.query(StoredConversationMetadataSaas)
|
||||
.filter(
|
||||
async with a_session_maker() as session:
|
||||
result = await session.execute(
|
||||
select(StoredConversationMetadataSaas).where(
|
||||
StoredConversationMetadataSaas.conversation_id
|
||||
== conversation_id
|
||||
)
|
||||
.first()
|
||||
)
|
||||
conversation_metadata_saas = result.scalars().first()
|
||||
user_id = (
|
||||
str(conversation_metadata_saas.user_id)
|
||||
if conversation_metadata_saas
|
||||
@@ -749,6 +749,9 @@ class ClusteredConversationManager(StandaloneConversationManager):
|
||||
config = load_openhands_config()
|
||||
settings_store = await SaasSettingsStore.get_instance(config, user_id)
|
||||
settings = await settings_store.load()
|
||||
if not settings:
|
||||
logger.error(f'Failed to load settings for user {user_id}')
|
||||
return
|
||||
await self.maybe_start_agent_loop(conversation_id, settings, user_id)
|
||||
|
||||
async def _start_agent_loop(
|
||||
|
||||
@@ -9,6 +9,7 @@ import requests # type: ignore
|
||||
from fastapi import HTTPException
|
||||
from server.auth.constants import (
|
||||
BITBUCKET_APP_CLIENT_ID,
|
||||
BITBUCKET_DATA_CENTER_CLIENT_ID,
|
||||
ENABLE_ENTERPRISE_SSO,
|
||||
ENABLE_JIRA,
|
||||
ENABLE_JIRA_DC,
|
||||
@@ -164,6 +165,9 @@ class SaaSServerConfig(ServerConfig):
|
||||
if ENABLE_ENTERPRISE_SSO:
|
||||
providers_configured.append(ProviderType.ENTERPRISE_SSO)
|
||||
|
||||
if BITBUCKET_DATA_CENTER_CLIENT_ID:
|
||||
providers_configured.append(ProviderType.BITBUCKET_DATA_CENTER)
|
||||
|
||||
config: dict[str, typing.Any] = {
|
||||
'APP_MODE': self.app_mode,
|
||||
'APP_SLUG': self.app_slug,
|
||||
|
||||
@@ -15,6 +15,11 @@ IS_FEATURE_ENV = (
|
||||
) # Does not include the staging deployment
|
||||
IS_LOCAL_ENV = bool(HOST == 'localhost')
|
||||
|
||||
# Role name constants
|
||||
ROLE_OWNER = 'owner'
|
||||
ROLE_ADMIN = 'admin'
|
||||
ROLE_MEMBER = 'member'
|
||||
|
||||
# Deprecated - billing margins are now handled internally in litellm
|
||||
DEFAULT_BILLING_MARGIN = float(os.environ.get('DEFAULT_BILLING_MARGIN', '1.0'))
|
||||
|
||||
@@ -25,7 +30,9 @@ PERSONAL_WORKSPACE_VERSION_TO_MODEL = {
|
||||
2: 'claude-3-7-sonnet-20250219',
|
||||
3: 'claude-sonnet-4-20250514',
|
||||
4: 'claude-sonnet-4-20250514',
|
||||
5: 'claude-opus-4-5-20251101',
|
||||
# Minimax is now the default as it gives results close to claude in terms of quality
|
||||
# but at a much lower price
|
||||
5: 'minimax-m2.5',
|
||||
}
|
||||
|
||||
LITELLM_DEFAULT_MODEL = os.getenv('LITELLM_DEFAULT_MODEL')
|
||||
@@ -54,7 +61,6 @@ SUBSCRIPTION_PRICE_DATA = {
|
||||
},
|
||||
}
|
||||
|
||||
DEFAULT_INITIAL_BUDGET = float(os.environ.get('DEFAULT_INITIAL_BUDGET', '10'))
|
||||
STRIPE_API_KEY = os.environ.get('STRIPE_API_KEY', None)
|
||||
REQUIRE_PAYMENT = os.environ.get('REQUIRE_PAYMENT', '0') in ('1', 'true')
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@ from datetime import datetime
|
||||
|
||||
from integrations.github.github_manager import GithubManager
|
||||
from integrations.github.github_view import GithubViewType
|
||||
from integrations.models import Message, SourceType
|
||||
from integrations.utils import (
|
||||
extract_summary_from_conversation_manager,
|
||||
get_summary_instruction,
|
||||
@@ -35,16 +34,12 @@ class GithubCallbackProcessor(ConversationCallbackProcessor):
|
||||
send_summary_instruction: bool = True
|
||||
|
||||
async def _send_message_to_github(self, message: str) -> None:
|
||||
"""
|
||||
Send a message to GitHub.
|
||||
"""Send a message to GitHub.
|
||||
|
||||
Args:
|
||||
message: The message content to send to GitHub
|
||||
"""
|
||||
try:
|
||||
# Create a message object for GitHub
|
||||
message_obj = Message(source=SourceType.OPENHANDS, message=message)
|
||||
|
||||
# Get the token manager
|
||||
token_manager = TokenManager()
|
||||
|
||||
@@ -53,8 +48,8 @@ class GithubCallbackProcessor(ConversationCallbackProcessor):
|
||||
|
||||
github_manager = GithubManager(token_manager, GitHubDataCollector())
|
||||
|
||||
# Send the message
|
||||
await github_manager.send_message(message_obj, self.github_view)
|
||||
# Send the message directly as a string
|
||||
await github_manager.send_message(message, self.github_view)
|
||||
|
||||
logger.info(
|
||||
f'[GitHub] Sent summary message to {self.github_view.full_repo_name}#{self.github_view.issue_number}'
|
||||
|
||||
@@ -3,7 +3,6 @@ from datetime import datetime
|
||||
|
||||
from integrations.gitlab.gitlab_manager import GitlabManager
|
||||
from integrations.gitlab.gitlab_view import GitlabViewType
|
||||
from integrations.models import Message, SourceType
|
||||
from integrations.utils import (
|
||||
extract_summary_from_conversation_manager,
|
||||
get_summary_instruction,
|
||||
@@ -14,7 +13,7 @@ from storage.conversation_callback import (
|
||||
ConversationCallback,
|
||||
ConversationCallbackProcessor,
|
||||
)
|
||||
from storage.database import session_maker
|
||||
from storage.database import a_session_maker
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.schema.agent import AgentState
|
||||
@@ -28,8 +27,7 @@ gitlab_manager = GitlabManager(token_manager)
|
||||
|
||||
|
||||
class GitlabCallbackProcessor(ConversationCallbackProcessor):
|
||||
"""
|
||||
Processor for sending conversation summaries to GitLab.
|
||||
"""Processor for sending conversation summaries to GitLab.
|
||||
|
||||
This processor is used to send summaries of conversations to GitLab
|
||||
when agent state changes occur.
|
||||
@@ -39,22 +37,18 @@ class GitlabCallbackProcessor(ConversationCallbackProcessor):
|
||||
send_summary_instruction: bool = True
|
||||
|
||||
async def _send_message_to_gitlab(self, message: str) -> None:
|
||||
"""
|
||||
Send a message to GitLab.
|
||||
"""Send a message to GitLab.
|
||||
|
||||
Args:
|
||||
message: The message content to send to GitLab
|
||||
"""
|
||||
try:
|
||||
# Create a message object for GitHub
|
||||
message_obj = Message(source=SourceType.OPENHANDS, message=message)
|
||||
|
||||
# Get the token manager
|
||||
token_manager = TokenManager()
|
||||
gitlab_manager = GitlabManager(token_manager)
|
||||
|
||||
# Send the message
|
||||
await gitlab_manager.send_message(message_obj, self.gitlab_view)
|
||||
# Send the message directly as a string
|
||||
await gitlab_manager.send_message(message, self.gitlab_view)
|
||||
|
||||
logger.info(
|
||||
f'[GitLab] Sent summary message to {self.gitlab_view.full_repo_name}#{self.gitlab_view.issue_number}'
|
||||
@@ -111,9 +105,9 @@ class GitlabCallbackProcessor(ConversationCallbackProcessor):
|
||||
self.send_summary_instruction = False
|
||||
callback.set_processor(self)
|
||||
callback.updated_at = datetime.now()
|
||||
with session_maker() as session:
|
||||
async with a_session_maker() as session:
|
||||
session.merge(callback)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
return
|
||||
|
||||
# Extract the summary from the event store
|
||||
@@ -132,9 +126,9 @@ class GitlabCallbackProcessor(ConversationCallbackProcessor):
|
||||
# Mark callback as completed status
|
||||
callback.status = CallbackStatus.COMPLETED
|
||||
callback.updated_at = datetime.now()
|
||||
with session_maker() as session:
|
||||
async with a_session_maker() as session:
|
||||
session.merge(callback)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
|
||||
@@ -37,8 +37,7 @@ class JiraCallbackProcessor(ConversationCallbackProcessor):
|
||||
workspace_name: str
|
||||
|
||||
async def _send_comment_to_jira(self, message: str) -> None:
|
||||
"""
|
||||
Send a comment to Jira issue.
|
||||
"""Send a comment to Jira issue.
|
||||
|
||||
Args:
|
||||
message: The message content to send to Jira
|
||||
@@ -59,8 +58,9 @@ class JiraCallbackProcessor(ConversationCallbackProcessor):
|
||||
# Decrypt API key
|
||||
api_key = jira_manager.token_manager.decrypt_text(workspace.svc_acc_api_key)
|
||||
|
||||
# Send comment directly as a string
|
||||
await jira_manager.send_message(
|
||||
jira_manager.create_outgoing_message(msg=message),
|
||||
message,
|
||||
issue_key=self.issue_key,
|
||||
jira_cloud_id=workspace.jira_cloud_id,
|
||||
svc_acc_email=workspace.svc_acc_email,
|
||||
|
||||
@@ -37,8 +37,7 @@ class JiraDcCallbackProcessor(ConversationCallbackProcessor):
|
||||
base_api_url: str
|
||||
|
||||
async def _send_comment_to_jira_dc(self, message: str) -> None:
|
||||
"""
|
||||
Send a comment to Jira DC issue.
|
||||
"""Send a comment to Jira DC issue.
|
||||
|
||||
Args:
|
||||
message: The message content to send to Jira DC
|
||||
@@ -61,8 +60,9 @@ class JiraDcCallbackProcessor(ConversationCallbackProcessor):
|
||||
workspace.svc_acc_api_key
|
||||
)
|
||||
|
||||
# Send comment directly as a string
|
||||
await jira_dc_manager.send_message(
|
||||
jira_dc_manager.create_outgoing_message(msg=message),
|
||||
message,
|
||||
issue_key=self.issue_key,
|
||||
base_api_url=self.base_api_url,
|
||||
svc_acc_api_key=api_key,
|
||||
|
||||
@@ -36,8 +36,7 @@ class LinearCallbackProcessor(ConversationCallbackProcessor):
|
||||
workspace_name: str
|
||||
|
||||
async def _send_comment_to_linear(self, message: str) -> None:
|
||||
"""
|
||||
Send a comment to Linear issue.
|
||||
"""Send a comment to Linear issue.
|
||||
|
||||
Args:
|
||||
message: The message content to send to Linear
|
||||
@@ -60,9 +59,9 @@ class LinearCallbackProcessor(ConversationCallbackProcessor):
|
||||
workspace.svc_acc_api_key
|
||||
)
|
||||
|
||||
# Send comment
|
||||
# Send comment directly as a string
|
||||
await linear_manager.send_message(
|
||||
linear_manager.create_outgoing_message(msg=message),
|
||||
message,
|
||||
self.issue_id,
|
||||
api_key,
|
||||
)
|
||||
|
||||
@@ -26,8 +26,7 @@ slack_manager = SlackManager(token_manager)
|
||||
|
||||
|
||||
class SlackCallbackProcessor(ConversationCallbackProcessor):
|
||||
"""
|
||||
Processor for sending conversation summaries to Slack.
|
||||
"""Processor for sending conversation summaries to Slack.
|
||||
|
||||
This processor is used to send summaries of conversations to Slack channels
|
||||
when agent state changes occur.
|
||||
@@ -41,14 +40,13 @@ class SlackCallbackProcessor(ConversationCallbackProcessor):
|
||||
last_user_msg_id: int | None = None
|
||||
|
||||
async def _send_message_to_slack(self, message: str) -> None:
|
||||
"""
|
||||
Send a message to Slack using the conversation_manager's send_to_event_stream method.
|
||||
"""Send a message to Slack.
|
||||
|
||||
Args:
|
||||
message: The message content to send to Slack
|
||||
"""
|
||||
try:
|
||||
# Create a message object for Slack
|
||||
# Create a message object for Slack view creation (incoming message format)
|
||||
message_obj = Message(
|
||||
source=SourceType.SLACK,
|
||||
message={
|
||||
@@ -64,12 +62,11 @@ class SlackCallbackProcessor(ConversationCallbackProcessor):
|
||||
slack_user, saas_user_auth = await slack_manager.authenticate_user(
|
||||
self.slack_user_id
|
||||
)
|
||||
slack_view = SlackFactory.create_slack_view_from_payload(
|
||||
slack_view = await SlackFactory.create_slack_view_from_payload(
|
||||
message_obj, slack_user, saas_user_auth
|
||||
)
|
||||
await slack_manager.send_message(
|
||||
slack_manager.create_outgoing_message(message), slack_view
|
||||
)
|
||||
# Send the message directly as a string
|
||||
await slack_manager.send_message(message, slack_view)
|
||||
|
||||
logger.info(
|
||||
f'[Slack] Sent summary message to channel {self.channel_id} '
|
||||
|
||||
@@ -51,6 +51,14 @@ def custom_json_serializer(obj, **kwargs):
|
||||
obj['stack_info'] = format_stack(stack_info)
|
||||
|
||||
result = json.dumps(obj, **kwargs)
|
||||
|
||||
# Swap out newlines to make things easier to read. This will produce
|
||||
# invalid json but means we can have similar logs in local development
|
||||
# to production, making things easier to correlate. Obviously,
|
||||
# LOG_JSON_FOR_CONSOLE should not be used in production environments.
|
||||
if LOG_JSON_FOR_CONSOLE:
|
||||
result = result.replace('\\n', '\n')
|
||||
|
||||
return result
|
||||
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user