mirror of
https://github.com/All-Hands-AI/OpenHands.git
synced 2026-04-29 03:00:45 -04:00
Compare commits
82 Commits
refactor/s
...
test-user
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3a7df33acf | ||
|
|
7222730df0 | ||
|
|
910177fc57 | ||
|
|
ac9badbd20 | ||
|
|
02c299d88f | ||
|
|
f65fbef649 | ||
|
|
3c2acad28d | ||
|
|
0f1780728e | ||
|
|
d3f3378a4c | ||
|
|
65f4164749 | ||
|
|
3f984d878b | ||
|
|
10b871f4ab | ||
|
|
d664f516db | ||
|
|
e74bbd81d1 | ||
|
|
ab893f93f0 | ||
|
|
5aba498e77 | ||
|
|
1523555eea | ||
|
|
30604c40fc | ||
|
|
8dc46b7206 | ||
|
|
69498bebb4 | ||
|
|
77ee9e25d9 | ||
|
|
74753036bb | ||
|
|
95d7c10608 | ||
|
|
c142cc27ff | ||
|
|
0e20fc206b | ||
|
|
e21475a88e | ||
|
|
921fec0019 | ||
|
|
049f839a62 | ||
|
|
0dde758e13 | ||
|
|
8257ae70cc | ||
|
|
4513bcc622 | ||
|
|
b5b9a3f40b | ||
|
|
8ea1259943 | ||
|
|
ddb2794adf | ||
|
|
79fdcad7ef | ||
|
|
1de70b8ce4 | ||
|
|
3baeecb27c | ||
|
|
b08238c841 | ||
|
|
831084df4c | ||
|
|
eb4dacb577 | ||
|
|
8e71459601 | ||
|
|
fc29815aa0 | ||
|
|
a809d74b7d | ||
|
|
b090d097ed | ||
|
|
79f32a34a0 | ||
|
|
805bc5608e | ||
|
|
61e1957cee | ||
|
|
a25826a5f9 | ||
|
|
df9320f8ab | ||
|
|
af0ab5a9f2 | ||
|
|
9960d11d08 | ||
|
|
d5d5e265f8 | ||
|
|
69fddecc7f | ||
|
|
989a4e662b | ||
|
|
ecfbae2285 | ||
|
|
3afe5ccee5 | ||
|
|
3d5a8dcf5a | ||
|
|
c9cf351697 | ||
|
|
aca568cfbe | ||
|
|
2ee1abe22c | ||
|
|
148940f553 | ||
|
|
3366ad9de7 | ||
|
|
f442e07b33 | ||
|
|
fdf8b21b84 | ||
|
|
93e843a06b | ||
|
|
e37f7b0e0f | ||
|
|
bd8b1bfa25 | ||
|
|
a4f11006f6 | ||
|
|
c6950946bb | ||
|
|
81d6341f9d | ||
|
|
55a6bbd9a4 | ||
|
|
20e5c40969 | ||
|
|
3e8dc41bdf | ||
|
|
1f09296136 | ||
|
|
70e5d12ba9 | ||
|
|
bcb3160d95 | ||
|
|
174c691744 | ||
|
|
af34d446e9 | ||
|
|
6604924f76 | ||
|
|
b2def1e438 | ||
|
|
2b8e47aca9 | ||
|
|
dba8b28824 |
23
.github/workflows/dispatch-to-docs.yml
vendored
Normal file
23
.github/workflows/dispatch-to-docs.yml
vendored
Normal file
@@ -0,0 +1,23 @@
|
||||
name: Dispatch to docs repo
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [main]
|
||||
paths:
|
||||
- 'docs/**'
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
dispatch:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
repo: ["All-Hands-AI/docs"]
|
||||
steps:
|
||||
- name: Push to docs repo
|
||||
uses: peter-evans/repository-dispatch@v3
|
||||
with:
|
||||
token: ${{ secrets.ALLHANDS_BOT_GITHUB_PAT }}
|
||||
repository: ${{ matrix.repo }}
|
||||
event-type: update
|
||||
client-payload: '{"ref": "${{ github.ref }}", "sha": "${{ github.sha }}", "module": "openhands", "branch": "main"}'
|
||||
29
.github/workflows/enterprise-preview.yml
vendored
Normal file
29
.github/workflows/enterprise-preview.yml
vendored
Normal file
@@ -0,0 +1,29 @@
|
||||
# Feature branch preview for enterprise code
|
||||
name: Enterprise Preview
|
||||
|
||||
# Run on PRs labeled
|
||||
on:
|
||||
pull_request:
|
||||
types: [labeled]
|
||||
|
||||
# Match ghcr-build.yml, but don't interrupt it.
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ (github.head_ref && github.ref) || github.run_id }}
|
||||
cancel-in-progress: false
|
||||
|
||||
jobs:
|
||||
# This must happen for the PR Docker workflow when the label is present,
|
||||
# and also if it's added after the fact. Thus, it exists in both places.
|
||||
enterprise-preview:
|
||||
name: Enterprise preview
|
||||
if: github.event.label.name == 'deploy'
|
||||
runs-on: blacksmith-4vcpu-ubuntu-2204
|
||||
steps:
|
||||
# This should match the version in ghcr-build.yml
|
||||
- name: Trigger remote job
|
||||
run: |
|
||||
curl --fail-with-body -sS -X POST \
|
||||
-H "Authorization: Bearer ${{ secrets.PAT_TOKEN }}" \
|
||||
-H "Accept: application/vnd.github+json" \
|
||||
-d "{\"ref\": \"main\", \"inputs\": {\"openhandsPrNumber\": \"${{ github.event.pull_request.number }}\", \"deployEnvironment\": \"feature\", \"enterpriseImageTag\": \"pr-${{ github.event.pull_request.number }}\" }}" \
|
||||
https://api.github.com/repos/All-Hands-AI/deploy/actions/workflows/deploy.yaml/dispatches
|
||||
96
.github/workflows/ghcr-build.yml
vendored
96
.github/workflows/ghcr-build.yml
vendored
@@ -10,14 +10,14 @@ on:
|
||||
branches:
|
||||
- main
|
||||
tags:
|
||||
- '*'
|
||||
- "*"
|
||||
pull_request:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
reason:
|
||||
description: 'Reason for manual trigger'
|
||||
description: "Reason for manual trigger"
|
||||
required: true
|
||||
default: ''
|
||||
default: ""
|
||||
|
||||
# If triggered by a PR, it will be in the same group. However, each commit on main will be in its own unique group
|
||||
concurrency:
|
||||
@@ -120,7 +120,7 @@ jobs:
|
||||
- name: Set up Python
|
||||
uses: useblacksmith/setup-python@v6
|
||||
with:
|
||||
python-version: '3.12'
|
||||
python-version: "3.12"
|
||||
cache: poetry
|
||||
- name: Install Python dependencies using Poetry
|
||||
run: make install-python-dependencies POETRY_GROUP=main INSTALL_PLAYWRIGHT=0
|
||||
@@ -166,6 +166,90 @@ jobs:
|
||||
name: runtime-src-${{ matrix.base_image.tag }}
|
||||
path: containers/runtime
|
||||
|
||||
ghcr_build_enterprise:
|
||||
name: Push Enterprise Image
|
||||
runs-on: blacksmith-8vcpu-ubuntu-2204
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
needs: [define-matrix, ghcr_build_app]
|
||||
# Do not build enterprise in forks
|
||||
if: github.event.pull_request.head.repo.fork != true
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.event.pull_request.head.sha }}
|
||||
|
||||
# Set up Docker Buildx for better performance
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
with:
|
||||
driver-opts: network=host
|
||||
|
||||
- name: Login to GHCR
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.repository_owner }}
|
||||
password: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
- name: Extract metadata (tags, labels) for Docker
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: ghcr.io/all-hands-ai/enterprise-server
|
||||
tags: |
|
||||
type=ref,event=branch
|
||||
type=ref,event=pr
|
||||
type=sha
|
||||
type=sha,format=long
|
||||
type=semver,pattern={{version}}
|
||||
type=semver,pattern={{major}}.{{minor}}
|
||||
type=semver,pattern={{major}}
|
||||
flavor: |
|
||||
latest=auto
|
||||
prefix=
|
||||
suffix=
|
||||
- name: Determine app image tag
|
||||
shell: bash
|
||||
run: |
|
||||
# Duplicated with build.sh
|
||||
sanitized_ref_name=$(echo "$GITHUB_REF_NAME" | sed 's/[^a-zA-Z0-9.-]\+/-/g')
|
||||
OPENHANDS_BUILD_VERSION=$sanitized_ref_name
|
||||
sanitized_ref_name=$(echo "$sanitized_ref_name" | tr '[:upper:]' '[:lower:]') # lower case is required in tagging
|
||||
echo "OPENHANDS_DOCKER_TAG=${sanitized_ref_name}" >> $GITHUB_ENV
|
||||
- name: Build and push Docker image
|
||||
uses: useblacksmith/build-push-action@v1
|
||||
with:
|
||||
context: .
|
||||
file: enterprise/Dockerfile
|
||||
push: true
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
build-args: |
|
||||
OPENHANDS_VERSION=${{ env.OPENHANDS_DOCKER_TAG }}
|
||||
platforms: linux/amd64
|
||||
# Add build provenance
|
||||
provenance: true
|
||||
# Add build attestations for better security
|
||||
sbom: true
|
||||
|
||||
enterprise-preview:
|
||||
name: Enterprise preview
|
||||
if: github.event_name == 'pull_request' && contains(github.event.pull_request.labels.*.name, 'deploy')
|
||||
runs-on: blacksmith-4vcpu-ubuntu-2204
|
||||
needs: [ghcr_build_enterprise]
|
||||
steps:
|
||||
# This should match the version in enterprise-preview.yml
|
||||
- name: Trigger remote job
|
||||
run: |
|
||||
curl --fail-with-body -sS -X POST \
|
||||
-H "Authorization: Bearer ${{ secrets.PAT_TOKEN }}" \
|
||||
-H "Accept: application/vnd.github+json" \
|
||||
-d "{\"ref\": \"main\", \"inputs\": {\"openhandsPrNumber\": \"${{ github.event.pull_request.number }}\", \"deployEnvironment\": \"feature\", \"enterpriseImageTag\": \"pr-${{ github.event.pull_request.number }}\" }}" \
|
||||
https://api.github.com/repos/All-Hands-AI/deploy/actions/workflows/deploy.yaml/dispatches
|
||||
|
||||
# Run unit tests with the Docker runtime Docker images as root
|
||||
test_runtime_root:
|
||||
name: RT Unit Tests (Root)
|
||||
@@ -202,7 +286,7 @@ jobs:
|
||||
- name: Set up Python
|
||||
uses: useblacksmith/setup-python@v6
|
||||
with:
|
||||
python-version: '3.12'
|
||||
python-version: "3.12"
|
||||
cache: poetry
|
||||
- name: Install Python dependencies using Poetry
|
||||
run: make install-python-dependencies INSTALL_PLAYWRIGHT=0
|
||||
@@ -264,7 +348,7 @@ jobs:
|
||||
- name: Set up Python
|
||||
uses: useblacksmith/setup-python@v6
|
||||
with:
|
||||
python-version: '3.12'
|
||||
python-version: "3.12"
|
||||
cache: poetry
|
||||
- name: Install Python dependencies using Poetry
|
||||
run: make install-python-dependencies POETRY_GROUP=main,test,runtime INSTALL_PLAYWRIGHT=0
|
||||
|
||||
18
.github/workflows/lint.yml
vendored
18
.github/workflows/lint.yml
vendored
@@ -55,6 +55,24 @@ jobs:
|
||||
- name: Run pre-commit hooks
|
||||
run: pre-commit run --all-files --show-diff-on-failure --config ./dev_config/python/.pre-commit-config.yaml
|
||||
|
||||
lint-enterprise-python:
|
||||
name: Lint enterprise python
|
||||
runs-on: blacksmith-4vcpu-ubuntu-2204
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Set up python
|
||||
uses: useblacksmith/setup-python@v6
|
||||
with:
|
||||
python-version: 3.12
|
||||
cache: "pip"
|
||||
- name: Install pre-commit
|
||||
run: pip install pre-commit==4.2.0
|
||||
- name: Run pre-commit hooks
|
||||
working-directory: ./enterprise
|
||||
run: pre-commit run --all-files --config ./dev_config/python/.pre-commit-config.yaml
|
||||
|
||||
# Check version consistency across documentation
|
||||
check-version-consistency:
|
||||
name: Check version consistency
|
||||
|
||||
70
.github/workflows/mdx-lint.yml
vendored
Normal file
70
.github/workflows/mdx-lint.yml
vendored
Normal file
@@ -0,0 +1,70 @@
|
||||
# Workflow that checks MDX format in docs/ folder
|
||||
name: MDX Lint
|
||||
|
||||
# Run on pushes to main and on pull requests that modify docs/ files
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- 'docs/**/*.mdx'
|
||||
pull_request:
|
||||
paths:
|
||||
- 'docs/**/*.mdx'
|
||||
|
||||
# If triggered by a PR, it will be in the same group. However, each commit on main will be in its own unique group
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ (github.head_ref && github.ref) || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
mdx-lint:
|
||||
name: Lint MDX files
|
||||
runs-on: blacksmith-4vcpu-ubuntu-2204
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Install Node.js 22
|
||||
uses: useblacksmith/setup-node@v5
|
||||
with:
|
||||
node-version: 22
|
||||
|
||||
- name: Install MDX dependencies
|
||||
run: |
|
||||
npm install @mdx-js/mdx@3 glob@10
|
||||
|
||||
- name: Validate MDX files
|
||||
run: |
|
||||
node -e "
|
||||
const {compile} = require('@mdx-js/mdx');
|
||||
const fs = require('fs');
|
||||
const path = require('path');
|
||||
const glob = require('glob');
|
||||
|
||||
async function validateMDXFiles() {
|
||||
const files = glob.sync('docs/**/*.mdx');
|
||||
console.log('Found', files.length, 'MDX files to validate');
|
||||
|
||||
let hasErrors = false;
|
||||
|
||||
for (const file of files) {
|
||||
try {
|
||||
const content = fs.readFileSync(file, 'utf8');
|
||||
await compile(content);
|
||||
console.log('✅ MDX parsing successful for', file);
|
||||
} catch (err) {
|
||||
console.error('❌ MDX parsing failed for', file, ':', err.message);
|
||||
hasErrors = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (hasErrors) {
|
||||
console.error('\\n❌ Some MDX files have parsing errors. Please fix them before merging.');
|
||||
process.exit(1);
|
||||
} else {
|
||||
console.log('\\n✅ All MDX files are valid!');
|
||||
}
|
||||
}
|
||||
|
||||
validateMDXFiles();
|
||||
"
|
||||
33
.github/workflows/py-tests.yml
vendored
33
.github/workflows/py-tests.yml
vendored
@@ -21,10 +21,10 @@ jobs:
|
||||
name: Python Tests on Linux
|
||||
runs-on: blacksmith-4vcpu-ubuntu-2204
|
||||
env:
|
||||
INSTALL_DOCKER: '0' # Set to '0' to skip Docker installation
|
||||
INSTALL_DOCKER: "0" # Set to '0' to skip Docker installation
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ['3.12']
|
||||
python-version: ["3.12"]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Docker Buildx
|
||||
@@ -35,14 +35,14 @@ jobs:
|
||||
- name: Setup Node.js
|
||||
uses: useblacksmith/setup-node@v5
|
||||
with:
|
||||
node-version: '22.x'
|
||||
node-version: "22.x"
|
||||
- name: Install poetry via pipx
|
||||
run: pipx install poetry
|
||||
- name: Set up Python
|
||||
uses: useblacksmith/setup-python@v6
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
cache: 'poetry'
|
||||
cache: "poetry"
|
||||
- name: Install Python dependencies using Poetry
|
||||
run: poetry install --with dev,test,runtime
|
||||
- name: Build Environment
|
||||
@@ -58,7 +58,7 @@ jobs:
|
||||
runs-on: windows-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ['3.12']
|
||||
python-version: ["3.12"]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Install pipx
|
||||
@@ -69,7 +69,7 @@ jobs:
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
cache: 'poetry'
|
||||
cache: "poetry"
|
||||
- name: Install Python dependencies using Poetry
|
||||
run: poetry install --with dev,test,runtime
|
||||
- name: Run Windows unit tests
|
||||
@@ -83,3 +83,24 @@ jobs:
|
||||
PYTHONPATH: ".;$env:PYTHONPATH"
|
||||
TEST_RUNTIME: local
|
||||
DEBUG: "1"
|
||||
test-enterprise:
|
||||
name: Enterprise Python Unit Tests
|
||||
runs-on: blacksmith-4vcpu-ubuntu-2204
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ["3.12"]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Install poetry via pipx
|
||||
run: pipx install poetry
|
||||
- name: Set up Python
|
||||
uses: useblacksmith/setup-python@v6
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
cache: "poetry"
|
||||
- name: Install Python dependencies using Poetry
|
||||
working-directory: ./enterprise
|
||||
run: poetry install --with dev,test
|
||||
- name: Run Unit Tests
|
||||
working-directory: ./enterprise
|
||||
run: PYTHONPATH=".:$PYTHONPATH" poetry run pytest --forked -n auto -svv -p no:ddtrace -p no:ddtrace.pytest_bdd -p no:ddtrace.pytest_benchmark ./tests/unit
|
||||
|
||||
@@ -87,8 +87,6 @@ VSCode Extension:
|
||||
|
||||
If you are starting a pull request (PR), please follow the template in `.github/pull_request_template.md`.
|
||||
|
||||
If you need to add labels when opening a PR, check the existing labels defined on that repository and select from existing ones. Do not invent your own labels.
|
||||
|
||||
## Implementation Details
|
||||
|
||||
These details may or may not be useful for your current task.
|
||||
|
||||
@@ -159,7 +159,7 @@ poetry run pytest ./tests/unit/test_*.py
|
||||
To reduce build time (e.g., if no changes were made to the client-runtime component), you can use an existing Docker
|
||||
container image by setting the SANDBOX_RUNTIME_CONTAINER_IMAGE environment variable to the desired Docker image.
|
||||
|
||||
Example: `export SANDBOX_RUNTIME_CONTAINER_IMAGE=ghcr.io/all-hands-ai/runtime:0.55-nikolaik`
|
||||
Example: `export SANDBOX_RUNTIME_CONTAINER_IMAGE=ghcr.io/all-hands-ai/runtime:0.56-nikolaik`
|
||||
|
||||
## Develop inside Docker container
|
||||
|
||||
|
||||
@@ -1,277 +0,0 @@
|
||||
# Migration Guide: From Shared Globals to Context System
|
||||
|
||||
This guide explains how to migrate from the deprecated `openhands.server.shared` globals to the new context system.
|
||||
|
||||
## Overview
|
||||
|
||||
The new context system replaces global variables with dependency injection, providing:
|
||||
|
||||
- **Better testability**: Easy to mock dependencies in tests
|
||||
- **SaaS extensibility**: Custom contexts for multi-tenant scenarios
|
||||
- **Per-request contexts**: Different configurations per request
|
||||
- **No import-time side effects**: Lazy initialization of dependencies
|
||||
- **Type safety**: Better IDE support and type checking
|
||||
|
||||
## Quick Migration
|
||||
|
||||
### Before (Deprecated)
|
||||
```python
|
||||
from openhands.server.shared import config, server_config, file_store, sio
|
||||
|
||||
def my_function():
|
||||
# Use global variables
|
||||
workspace_dir = config.workspace_dir
|
||||
app_mode = server_config.app_mode
|
||||
file_store.save_file(...)
|
||||
```
|
||||
|
||||
### After (Recommended)
|
||||
```python
|
||||
from fastapi import Depends, Request
|
||||
from openhands.server.context import get_server_context, ServerContext
|
||||
|
||||
@app.get('/my-endpoint')
|
||||
async def my_endpoint(
|
||||
request: Request,
|
||||
context: ServerContext = Depends(get_server_context)
|
||||
):
|
||||
# Use context instead of globals
|
||||
config = context.get_config()
|
||||
server_config = context.get_server_config()
|
||||
file_store = context.get_file_store()
|
||||
|
||||
workspace_dir = config.workspace_dir
|
||||
app_mode = server_config.app_mode
|
||||
file_store.save_file(...)
|
||||
```
|
||||
|
||||
## Detailed Migration Steps
|
||||
|
||||
### 1. Route Handlers
|
||||
|
||||
**Before:**
|
||||
```python
|
||||
from openhands.server.shared import config, conversation_manager
|
||||
|
||||
@app.post('/conversations')
|
||||
async def create_conversation(request: ConversationRequest):
|
||||
conversation = conversation_manager.create_conversation(
|
||||
request.user_id,
|
||||
config.default_agent
|
||||
)
|
||||
return conversation
|
||||
```
|
||||
|
||||
**After:**
|
||||
```python
|
||||
from fastapi import Depends
|
||||
from openhands.server.context import get_server_context, ServerContext
|
||||
|
||||
@app.post('/conversations')
|
||||
async def create_conversation(
|
||||
request: ConversationRequest,
|
||||
context: ServerContext = Depends(get_server_context)
|
||||
):
|
||||
config = context.get_config()
|
||||
conversation_manager = context.get_conversation_manager()
|
||||
|
||||
conversation = conversation_manager.create_conversation(
|
||||
request.user_id,
|
||||
config.default_agent
|
||||
)
|
||||
return conversation
|
||||
```
|
||||
|
||||
### 2. Service Classes
|
||||
|
||||
**Before:**
|
||||
```python
|
||||
from openhands.server.shared import file_store, monitoring_listener
|
||||
|
||||
class MyService:
|
||||
def process_file(self, file_path: str):
|
||||
content = file_store.read(file_path)
|
||||
monitoring_listener.log_event('file_processed')
|
||||
return content
|
||||
```
|
||||
|
||||
**After:**
|
||||
```python
|
||||
from openhands.server.context import ServerContext
|
||||
|
||||
class MyService:
|
||||
def __init__(self, context: ServerContext):
|
||||
self.context = context
|
||||
|
||||
def process_file(self, file_path: str):
|
||||
file_store = self.context.get_file_store()
|
||||
monitoring_listener = self.context.get_monitoring_listener()
|
||||
|
||||
content = file_store.read(file_path)
|
||||
monitoring_listener.log_event('file_processed')
|
||||
return content
|
||||
|
||||
# In route handler:
|
||||
@app.post('/process')
|
||||
async def process_endpoint(
|
||||
request: ProcessRequest,
|
||||
context: ServerContext = Depends(get_server_context)
|
||||
):
|
||||
service = MyService(context)
|
||||
return service.process_file(request.file_path)
|
||||
```
|
||||
|
||||
### 3. Store Classes
|
||||
|
||||
**Before:**
|
||||
```python
|
||||
from openhands.server.shared import SettingsStoreImpl
|
||||
|
||||
def get_user_settings(user_id: str):
|
||||
store = SettingsStoreImpl(user_id)
|
||||
return store.load()
|
||||
```
|
||||
|
||||
**After:**
|
||||
```python
|
||||
from openhands.server.context import ServerContext
|
||||
|
||||
def get_user_settings(user_id: str, context: ServerContext):
|
||||
SettingsStoreClass = context.get_settings_store_class()
|
||||
store = SettingsStoreClass(user_id)
|
||||
return store.load()
|
||||
|
||||
# In route handler:
|
||||
@app.get('/settings/{user_id}')
|
||||
async def get_settings(
|
||||
user_id: str,
|
||||
context: ServerContext = Depends(get_server_context)
|
||||
):
|
||||
return get_user_settings(user_id, context)
|
||||
```
|
||||
|
||||
### 4. Testing
|
||||
|
||||
**Before:**
|
||||
```python
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
|
||||
def test_my_function():
|
||||
with patch('openhands.server.shared.config') as mock_config:
|
||||
mock_config.workspace_dir = '/test'
|
||||
result = my_function()
|
||||
assert result == expected
|
||||
```
|
||||
|
||||
**After:**
|
||||
```python
|
||||
import pytest
|
||||
from openhands.server.context import create_server_context
|
||||
|
||||
class MockServerContext:
|
||||
def get_config(self):
|
||||
mock_config = Mock()
|
||||
mock_config.workspace_dir = '/test'
|
||||
return mock_config
|
||||
|
||||
def test_my_function():
|
||||
context = MockServerContext()
|
||||
result = my_function(context)
|
||||
assert result == expected
|
||||
```
|
||||
|
||||
## SaaS Extension Example
|
||||
|
||||
The new context system makes it easy to extend OpenHands for SaaS scenarios:
|
||||
|
||||
```python
|
||||
from openhands.server.context import ServerContext, set_context_class
|
||||
|
||||
class SaaSServerContext(ServerContext):
|
||||
def __init__(self, user_id: str, org_id: str):
|
||||
self.user_id = user_id
|
||||
self.org_id = org_id
|
||||
|
||||
def get_file_store(self):
|
||||
# Return tenant-isolated file store
|
||||
return MultiTenantFileStore(self.user_id, self.org_id)
|
||||
|
||||
def get_server_config(self):
|
||||
# Return SaaS-specific configuration
|
||||
return SaaSServerConfig(org_id=self.org_id)
|
||||
|
||||
# Configure globally
|
||||
set_context_class('myapp.context.SaaSServerContext')
|
||||
|
||||
# Use in routes with tenant context
|
||||
@app.get('/tenant/{org_id}/files')
|
||||
async def get_tenant_files(
|
||||
org_id: str,
|
||||
context: SaaSServerContext = Depends(get_server_context)
|
||||
):
|
||||
file_store = context.get_file_store()
|
||||
return file_store.list_files()
|
||||
```
|
||||
|
||||
## Migration Checklist
|
||||
|
||||
- [ ] Replace `from openhands.server.shared import ...` with context injection
|
||||
- [ ] Update route handlers to use `Depends(get_server_context)`
|
||||
- [ ] Modify service classes to accept `ServerContext` parameter
|
||||
- [ ] Update tests to use mock contexts instead of patching globals
|
||||
- [ ] Remove direct imports of shared globals
|
||||
- [ ] Test that all functionality still works
|
||||
|
||||
## Backward Compatibility
|
||||
|
||||
The old `openhands.server.shared` module still works but is deprecated. It will show deprecation warnings when imported. The globals are now implemented using the new context system internally.
|
||||
|
||||
## Benefits After Migration
|
||||
|
||||
1. **Better Testing**: Easy to mock dependencies without patching globals
|
||||
2. **Type Safety**: Better IDE support and type checking
|
||||
3. **Extensibility**: Easy to create custom contexts for different scenarios
|
||||
4. **Performance**: Lazy initialization reduces startup time
|
||||
5. **Maintainability**: Clear dependency relationships
|
||||
|
||||
## Common Issues
|
||||
|
||||
### Issue: Import errors during migration
|
||||
**Solution**: Make sure to import the context system correctly:
|
||||
```python
|
||||
from openhands.server.context import get_server_context, ServerContext
|
||||
```
|
||||
|
||||
### Issue: Context not available in non-route functions
|
||||
**Solution**: Pass the context as a parameter:
|
||||
```python
|
||||
def helper_function(data: str, context: ServerContext):
|
||||
config = context.get_config()
|
||||
# ... use config
|
||||
```
|
||||
|
||||
### Issue: Testing becomes more complex
|
||||
**Solution**: Create reusable mock contexts:
|
||||
```python
|
||||
# test_utils.py
|
||||
class TestServerContext(ServerContext):
|
||||
def __init__(self):
|
||||
self.mock_config = create_mock_config()
|
||||
self.mock_file_store = create_mock_file_store()
|
||||
|
||||
def get_config(self):
|
||||
return self.mock_config
|
||||
|
||||
def get_file_store(self):
|
||||
return self.mock_file_store
|
||||
```
|
||||
|
||||
## Getting Help
|
||||
|
||||
If you encounter issues during migration:
|
||||
|
||||
1. Check the examples in `examples/saas_extension.py`
|
||||
2. Look at the implementation in `openhands/server/context/`
|
||||
3. Review existing route handlers that have been migrated
|
||||
4. Create an issue if you find bugs or need clarification
|
||||
12
README.md
12
README.md
@@ -11,7 +11,7 @@
|
||||
<a href="https://github.com/All-Hands-AI/OpenHands/stargazers"><img src="https://img.shields.io/github/stars/All-Hands-AI/OpenHands?style=for-the-badge&color=blue" alt="Stargazers"></a>
|
||||
<a href="https://github.com/All-Hands-AI/OpenHands/blob/main/LICENSE"><img src="https://img.shields.io/github/license/All-Hands-AI/OpenHands?style=for-the-badge&color=blue" alt="MIT License"></a>
|
||||
<br/>
|
||||
<a href="https://join.slack.com/t/openhands-ai/shared_invite/zt-3847of6xi-xuYJIPa6YIPg4ElbDWbtSA"><img src="https://img.shields.io/badge/Slack-Join%20Us-red?logo=slack&logoColor=white&style=for-the-badge" alt="Join our Slack community"></a>
|
||||
<a href="https://dub.sh/openhands"><img src="https://img.shields.io/badge/Slack-Join%20Us-red?logo=slack&logoColor=white&style=for-the-badge" alt="Join our Slack community"></a>
|
||||
<a href="https://discord.gg/ESHStjSjD4"><img src="https://img.shields.io/badge/Discord-Join%20Us-purple?logo=discord&logoColor=white&style=for-the-badge" alt="Join our Discord community"></a>
|
||||
<a href="https://github.com/All-Hands-AI/OpenHands/blob/main/CREDITS.md"><img src="https://img.shields.io/badge/Project-Credits-blue?style=for-the-badge&color=FFE165&logo=github&logoColor=white" alt="Credits"></a>
|
||||
<br/>
|
||||
@@ -79,17 +79,17 @@ You'll find OpenHands running at [http://localhost:3000](http://localhost:3000)
|
||||
You can also run OpenHands directly with Docker:
|
||||
|
||||
```bash
|
||||
docker pull docker.all-hands.dev/all-hands-ai/runtime:0.55-nikolaik
|
||||
docker pull docker.all-hands.dev/all-hands-ai/runtime:0.56-nikolaik
|
||||
|
||||
docker run -it --rm --pull=always \
|
||||
-e SANDBOX_RUNTIME_CONTAINER_IMAGE=docker.all-hands.dev/all-hands-ai/runtime:0.55-nikolaik \
|
||||
-e SANDBOX_RUNTIME_CONTAINER_IMAGE=docker.all-hands.dev/all-hands-ai/runtime:0.56-nikolaik \
|
||||
-e LOG_ALL_EVENTS=true \
|
||||
-v /var/run/docker.sock:/var/run/docker.sock \
|
||||
-v ~/.openhands:/.openhands \
|
||||
-p 3000:3000 \
|
||||
--add-host host.docker.internal:host-gateway \
|
||||
--name openhands-app \
|
||||
docker.all-hands.dev/all-hands-ai/openhands:0.55
|
||||
docker.all-hands.dev/all-hands-ai/openhands:0.56
|
||||
```
|
||||
|
||||
</details>
|
||||
@@ -142,7 +142,7 @@ troubleshooting resources, and advanced configuration options.
|
||||
OpenHands is a community-driven project, and we welcome contributions from everyone. We do most of our communication
|
||||
through Slack, so this is the best place to start, but we also are happy to have you contact us on Discord or Github:
|
||||
|
||||
- [Join our Slack workspace](https://join.slack.com/t/openhands-ai/shared_invite/zt-3847of6xi-xuYJIPa6YIPg4ElbDWbtSA) - Here we talk about research, architecture, and future development.
|
||||
- [Join our Slack workspace](https://dub.sh/openhands) - Here we talk about research, architecture, and future development.
|
||||
- [Join our Discord server](https://discord.gg/ESHStjSjD4) - This is a community-run server for general discussion, questions, and feedback.
|
||||
- [Read or post Github Issues](https://github.com/All-Hands-AI/OpenHands/issues) - Check out the issues we're working on, or add your own ideas.
|
||||
|
||||
@@ -160,7 +160,7 @@ See the monthly OpenHands roadmap [here](https://github.com/orgs/All-Hands-AI/pr
|
||||
|
||||
## 📜 License
|
||||
|
||||
Distributed under the MIT License. See [`LICENSE`](./LICENSE) for more information.
|
||||
Distributed under the MIT License, with the exception of the `enterprise/` folder. See [`LICENSE`](./LICENSE) for more information.
|
||||
|
||||
## 🙏 Acknowledgements
|
||||
|
||||
|
||||
10
README_CN.md
10
README_CN.md
@@ -12,7 +12,7 @@
|
||||
<a href="https://github.com/All-Hands-AI/OpenHands/stargazers"><img src="https://img.shields.io/github/stars/All-Hands-AI/OpenHands?style=for-the-badge&color=blue" alt="Stargazers"></a>
|
||||
<a href="https://github.com/All-Hands-AI/OpenHands/blob/main/LICENSE"><img src="https://img.shields.io/github/license/All-Hands-AI/OpenHands?style=for-the-badge&color=blue" alt="MIT License"></a>
|
||||
<br/>
|
||||
<a href="https://join.slack.com/t/openhands-ai/shared_invite/zt-3847of6xi-xuYJIPa6YIPg4ElbDWbtSA"><img src="https://img.shields.io/badge/Slack-Join%20Us-red?logo=slack&logoColor=white&style=for-the-badge" alt="加入我们的Slack社区"></a>
|
||||
<a href="https://dub.sh/openhands"><img src="https://img.shields.io/badge/Slack-Join%20Us-red?logo=slack&logoColor=white&style=for-the-badge" alt="加入我们的Slack社区"></a>
|
||||
<a href="https://discord.gg/ESHStjSjD4"><img src="https://img.shields.io/badge/Discord-Join%20Us-purple?logo=discord&logoColor=white&style=for-the-badge" alt="加入我们的Discord社区"></a>
|
||||
<a href="https://github.com/All-Hands-AI/OpenHands/blob/main/CREDITS.md"><img src="https://img.shields.io/badge/Project-Credits-blue?style=for-the-badge&color=FFE165&logo=github&logoColor=white" alt="致谢"></a>
|
||||
<br/>
|
||||
@@ -51,17 +51,17 @@ OpenHands也可以使用Docker在本地系统上运行。
|
||||
|
||||
|
||||
```bash
|
||||
docker pull docker.all-hands.dev/all-hands-ai/runtime:0.55-nikolaik
|
||||
docker pull docker.all-hands.dev/all-hands-ai/runtime:0.56-nikolaik
|
||||
|
||||
docker run -it --rm --pull=always \
|
||||
-e SANDBOX_RUNTIME_CONTAINER_IMAGE=docker.all-hands.dev/all-hands-ai/runtime:0.55-nikolaik \
|
||||
-e SANDBOX_RUNTIME_CONTAINER_IMAGE=docker.all-hands.dev/all-hands-ai/runtime:0.56-nikolaik \
|
||||
-e LOG_ALL_EVENTS=true \
|
||||
-v /var/run/docker.sock:/var/run/docker.sock \
|
||||
-v ~/.openhands:/.openhands \
|
||||
-p 3000:3000 \
|
||||
--add-host host.docker.internal:host-gateway \
|
||||
--name openhands-app \
|
||||
docker.all-hands.dev/all-hands-ai/openhands:0.55
|
||||
docker.all-hands.dev/all-hands-ai/openhands:0.56
|
||||
```
|
||||
|
||||
> **注意**: 如果您在0.44版本之前使用过OpenHands,您可能需要运行 `mv ~/.openhands-state ~/.openhands` 来将对话历史迁移到新位置。
|
||||
@@ -107,7 +107,7 @@ docker run -it --rm --pull=always \
|
||||
OpenHands是一个社区驱动的项目,我们欢迎每个人的贡献。我们大部分沟通
|
||||
通过Slack进行,因此这是开始的最佳场所,但我们也很乐意您通过Discord或Github与我们联系:
|
||||
|
||||
- [加入我们的Slack工作空间](https://join.slack.com/t/openhands-ai/shared_invite/zt-3847of6xi-xuYJIPa6YIPg4ElbDWbtSA) - 这里我们讨论研究、架构和未来发展。
|
||||
- [加入我们的Slack工作空间](https://dub.sh/openhands) - 这里我们讨论研究、架构和未来发展。
|
||||
- [加入我们的Discord服务器](https://discord.gg/ESHStjSjD4) - 这是一个社区运营的服务器,用于一般讨论、问题和反馈。
|
||||
- [阅读或发布Github问题](https://github.com/All-Hands-AI/OpenHands/issues) - 查看我们正在处理的问题,或添加您自己的想法。
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@
|
||||
<a href="https://github.com/All-Hands-AI/OpenHands/stargazers"><img src="https://img.shields.io/github/stars/All-Hands-AI/OpenHands?style=for-the-badge&color=blue" alt="Stargazers"></a>
|
||||
<a href="https://github.com/All-Hands-AI/OpenHands/blob/main/LICENSE"><img src="https://img.shields.io/github/license/All-Hands-AI/OpenHands?style=for-the-badge&color=blue" alt="MIT License"></a>
|
||||
<br/>
|
||||
<a href="https://join.slack.com/t/openhands-ai/shared_invite/zt-3847of6xi-xuYJIPa6YIPg4ElbDWbtSA"><img src="https://img.shields.io/badge/Slack-Join%20Us-red?logo=slack&logoColor=white&style=for-the-badge" alt="Slackコミュニティに参加"></a>
|
||||
<a href="https://dub.sh/openhands"><img src="https://img.shields.io/badge/Slack-Join%20Us-red?logo=slack&logoColor=white&style=for-the-badge" alt="Slackコミュニティに参加"></a>
|
||||
<a href="https://discord.gg/ESHStjSjD4"><img src="https://img.shields.io/badge/Discord-Join%20Us-purple?logo=discord&logoColor=white&style=for-the-badge" alt="Discordコミュニティに参加"></a>
|
||||
<a href="https://github.com/All-Hands-AI/OpenHands/blob/main/CREDITS.md"><img src="https://img.shields.io/badge/Project-Credits-blue?style=for-the-badge&color=FFE165&logo=github&logoColor=white" alt="クレジット"></a>
|
||||
<br/>
|
||||
@@ -42,17 +42,17 @@ OpenHandsはDockerを利用してローカル環境でも実行できます。
|
||||
> 公共ネットワークで実行していますか?[Hardened Docker Installation Guide](https://docs.all-hands.dev/usage/runtimes/docker#hardened-docker-installation)を参照して、ネットワークバインディングの制限や追加のセキュリティ対策を実施してください。
|
||||
|
||||
```bash
|
||||
docker pull docker.all-hands.dev/all-hands-ai/runtime:0.55-nikolaik
|
||||
docker pull docker.all-hands.dev/all-hands-ai/runtime:0.56-nikolaik
|
||||
|
||||
docker run -it --rm --pull=always \
|
||||
-e SANDBOX_RUNTIME_CONTAINER_IMAGE=docker.all-hands.dev/all-hands-ai/runtime:0.55-nikolaik \
|
||||
-e SANDBOX_RUNTIME_CONTAINER_IMAGE=docker.all-hands.dev/all-hands-ai/runtime:0.56-nikolaik \
|
||||
-e LOG_ALL_EVENTS=true \
|
||||
-v /var/run/docker.sock:/var/run/docker.sock \
|
||||
-v ~/.openhands:/.openhands \
|
||||
-p 3000:3000 \
|
||||
--add-host host.docker.internal:host-gateway \
|
||||
--name openhands-app \
|
||||
docker.all-hands.dev/all-hands-ai/openhands:0.55
|
||||
docker.all-hands.dev/all-hands-ai/openhands:0.56
|
||||
```
|
||||
|
||||
**注**: バージョン0.44以前のOpenHandsを使用していた場合は、会話履歴を移行するために `mv ~/.openhands-state ~/.openhands` を実行してください。
|
||||
|
||||
230
REFACTOR_PLAN.md
230
REFACTOR_PLAN.md
@@ -1,230 +0,0 @@
|
||||
# OpenHands Server Context Refactoring Plan
|
||||
|
||||
## Problem Statement
|
||||
|
||||
The current OpenHands architecture has globals in `server/shared.py` that are initialized at import time based on environment variables. This creates several issues for the SaaS version:
|
||||
|
||||
1. **Import-time dependencies**: All globals are created when modules are imported
|
||||
2. **Hard to extend**: SaaS can't easily override or extend components
|
||||
3. **CI/CD issues**: Everything depends on env vars being set correctly at import time
|
||||
4. **Per-user behavior**: Difficult to implement per-user/per-request behavior
|
||||
5. **Outside repo issues**: Hard to run SaaS from outside repo due to import dependencies
|
||||
|
||||
## Current Problematic Globals
|
||||
|
||||
From `openhands/server/shared.py`:
|
||||
- `config: OpenHandsConfig` - Core app configuration
|
||||
- `server_config: ServerConfig` - Server-specific configuration
|
||||
- `file_store: FileStore` - File storage implementation
|
||||
- `sio: socketio.AsyncServer` - Socket.IO server instance
|
||||
- `conversation_manager` - Conversation management implementation
|
||||
- `monitoring_listener` - Monitoring implementation
|
||||
- `SettingsStoreImpl`, `SecretsStoreImpl`, `ConversationStoreImpl` - Storage implementations
|
||||
|
||||
## Solution: ServerContext Pattern
|
||||
|
||||
### 1. Create ServerContext Base Class
|
||||
|
||||
Create `openhands/server/context/server_context.py`:
|
||||
|
||||
```python
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
import socketio
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
from openhands.server.config.server_config import ServerConfig
|
||||
from openhands.storage.files import FileStore
|
||||
# ... other imports
|
||||
|
||||
class ServerContext(ABC):
|
||||
"""Base class for server context that holds all server dependencies.
|
||||
|
||||
This replaces the global variables in shared.py and allows for:
|
||||
- Dependency injection
|
||||
- Easy extensibility for SaaS
|
||||
- Per-request contexts
|
||||
- Testability
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._config: Optional[OpenHandsConfig] = None
|
||||
self._server_config: Optional[ServerConfig] = None
|
||||
self._file_store: Optional[FileStore] = None
|
||||
# ... other cached instances
|
||||
|
||||
@abstractmethod
|
||||
def get_config(self) -> OpenHandsConfig:
|
||||
"""Get the OpenHands configuration"""
|
||||
|
||||
@abstractmethod
|
||||
def get_server_config(self) -> ServerConfig:
|
||||
"""Get the server configuration"""
|
||||
|
||||
@abstractmethod
|
||||
def get_file_store(self) -> FileStore:
|
||||
"""Get the file store implementation"""
|
||||
|
||||
# ... other abstract methods for all current globals
|
||||
```
|
||||
|
||||
### 2. Create Default Implementation
|
||||
|
||||
Create `openhands/server/context/default_server_context.py`:
|
||||
|
||||
```python
|
||||
class DefaultServerContext(ServerContext):
|
||||
"""Default implementation that maintains current behavior"""
|
||||
|
||||
def get_config(self) -> OpenHandsConfig:
|
||||
if self._config is None:
|
||||
self._config = load_openhands_config()
|
||||
return self._config
|
||||
|
||||
def get_server_config(self) -> ServerConfig:
|
||||
if self._server_config is None:
|
||||
self._server_config = load_server_config()
|
||||
return self._server_config
|
||||
|
||||
# ... implement all methods with current logic
|
||||
```
|
||||
|
||||
### 3. Context Provider System
|
||||
|
||||
Create `openhands/server/context/context_provider.py`:
|
||||
|
||||
```python
|
||||
from fastapi import Request
|
||||
from openhands.utils.import_utils import get_impl
|
||||
|
||||
_context_class: Optional[str] = None
|
||||
|
||||
def set_context_class(context_class: str):
|
||||
"""Set the server context class to use"""
|
||||
global _context_class
|
||||
_context_class = context_class
|
||||
|
||||
async def get_server_context(request: Request) -> ServerContext:
|
||||
"""Get server context from request, with caching"""
|
||||
context = getattr(request.state, 'server_context', None)
|
||||
if context:
|
||||
return context
|
||||
|
||||
# Use configured context class or default
|
||||
context_cls_name = _context_class or 'openhands.server.context.default_server_context.DefaultServerContext'
|
||||
context_cls = get_impl(ServerContext, context_cls_name)
|
||||
context = context_cls()
|
||||
|
||||
request.state.server_context = context
|
||||
return context
|
||||
```
|
||||
|
||||
### 4. Update Shared.py (Backward Compatibility)
|
||||
|
||||
Keep `shared.py` for backward compatibility but make it use the context:
|
||||
|
||||
```python
|
||||
# openhands/server/shared.py
|
||||
from openhands.server.context.default_server_context import DefaultServerContext
|
||||
|
||||
# Create default context for backward compatibility
|
||||
_default_context = DefaultServerContext()
|
||||
|
||||
# Expose globals for backward compatibility
|
||||
config = _default_context.get_config()
|
||||
server_config = _default_context.get_server_config()
|
||||
file_store = _default_context.get_file_store()
|
||||
# ... etc
|
||||
```
|
||||
|
||||
### 5. Update Routes to Use Context
|
||||
|
||||
Update all route files to use dependency injection:
|
||||
|
||||
```python
|
||||
# Example: openhands/server/routes/settings.py
|
||||
from openhands.server.context import get_server_context
|
||||
|
||||
@app.get('/settings')
|
||||
async def get_settings(
|
||||
request: Request,
|
||||
context: ServerContext = Depends(get_server_context)
|
||||
):
|
||||
config = context.get_config()
|
||||
# ... use config instead of importing from shared
|
||||
```
|
||||
|
||||
## Benefits for SaaS
|
||||
|
||||
### 1. Easy Extension
|
||||
|
||||
SaaS can create their own context:
|
||||
|
||||
```python
|
||||
# In SaaS repo: saas/server_context.py
|
||||
from openhands.server.context import ServerContext
|
||||
|
||||
class SaaSServerContext(ServerContext):
|
||||
def get_server_config(self) -> ServerConfig:
|
||||
# Return SaaS-specific config with enterprise features
|
||||
return SaaSServerConfig()
|
||||
|
||||
def get_conversation_manager(self) -> ConversationManager:
|
||||
# Return multi-tenant conversation manager
|
||||
return MultiTenantConversationManager()
|
||||
```
|
||||
|
||||
### 2. Per-Request Contexts
|
||||
|
||||
SaaS can implement per-user contexts:
|
||||
|
||||
```python
|
||||
class PerUserServerContext(ServerContext):
|
||||
def __init__(self, user_id: str, org_id: str):
|
||||
super().__init__()
|
||||
self.user_id = user_id
|
||||
self.org_id = org_id
|
||||
|
||||
def get_file_store(self) -> FileStore:
|
||||
# Return user-specific file store
|
||||
return UserFileStore(self.user_id, self.org_id)
|
||||
```
|
||||
|
||||
### 3. No Import-Time Dependencies
|
||||
|
||||
SaaS can run without setting environment variables at import time:
|
||||
|
||||
```python
|
||||
# In SaaS startup
|
||||
from openhands.server.context import set_context_class
|
||||
set_context_class('saas.server_context.SaaSServerContext')
|
||||
```
|
||||
|
||||
## Migration Strategy
|
||||
|
||||
### Phase 1: Create Context System
|
||||
1. Create ServerContext base class and default implementation
|
||||
2. Create context provider system
|
||||
3. Update shared.py for backward compatibility
|
||||
|
||||
### Phase 2: Update Routes Gradually
|
||||
1. Update one route at a time to use context injection
|
||||
2. Test each route to ensure no regressions
|
||||
3. Keep backward compatibility during transition
|
||||
|
||||
### Phase 3: Clean Up
|
||||
1. Remove globals from shared.py once all routes are updated
|
||||
2. Update documentation
|
||||
3. Create examples for SaaS extension
|
||||
|
||||
## Implementation Order
|
||||
|
||||
1. `openhands/server/context/server_context.py` - Base class
|
||||
2. `openhands/server/context/default_server_context.py` - Default implementation
|
||||
3. `openhands/server/context/context_provider.py` - Provider system
|
||||
4. `openhands/server/context/__init__.py` - Public API
|
||||
5. Update `openhands/server/shared.py` for backward compatibility
|
||||
6. Update routes one by one to use context injection
|
||||
7. Update tests to use context system
|
||||
8. Documentation and examples
|
||||
|
||||
This approach provides a clean migration path while maintaining backward compatibility and enabling the SaaS extensibility requirements.
|
||||
@@ -1,206 +0,0 @@
|
||||
# OpenHands Server Globals Refactoring - Summary
|
||||
|
||||
## Overview
|
||||
|
||||
Successfully refactored OpenHands server globals in `shared.py` and `server_config.py` to enable SaaS extensibility without import-time dependencies. The refactoring introduces a dependency injection pattern using a `ServerContext` system that maintains backward compatibility while enabling multi-tenant SaaS scenarios.
|
||||
|
||||
## Problem Solved
|
||||
|
||||
### Before Refactoring
|
||||
- **Global variables on import**: `shared.py` created globals like `config`, `server_config`, `file_store`, `sio`, etc. on module import
|
||||
- **Import-time side effects**: Loading the module triggered configuration loading and dependency initialization
|
||||
- **SaaS integration issues**: External SaaS repos had CI/CD problems due to environment variable dependencies
|
||||
- **Testing difficulties**: Hard to mock dependencies due to global state
|
||||
- **No extensibility**: Impossible to customize behavior for different tenants or environments
|
||||
|
||||
### After Refactoring
|
||||
- **Dependency injection**: Clean `ServerContext` pattern with lazy initialization
|
||||
- **No import-time side effects**: Dependencies only loaded when actually needed
|
||||
- **SaaS extensibility**: Easy to create custom contexts for multi-tenant scenarios
|
||||
- **Better testability**: Easy to mock contexts for testing
|
||||
- **Backward compatibility**: Existing code continues to work with deprecation warnings
|
||||
|
||||
## Architecture Changes
|
||||
|
||||
### New Context System
|
||||
|
||||
```
|
||||
openhands/server/context/
|
||||
├── __init__.py # Public API
|
||||
├── server_context.py # Abstract base class
|
||||
├── default_server_context.py # Default implementation
|
||||
└── context_provider.py # Dependency injection system
|
||||
```
|
||||
|
||||
### Key Components
|
||||
|
||||
1. **ServerContext (Abstract Base Class)**
|
||||
- Defines interface for all server dependencies
|
||||
- 9 abstract methods for different dependency types
|
||||
- Extensible for SaaS implementations
|
||||
|
||||
2. **DefaultServerContext**
|
||||
- Maintains exact behavior of original shared.py
|
||||
- Lazy initialization of all dependencies
|
||||
- No import-time side effects
|
||||
|
||||
3. **Context Provider System**
|
||||
- `get_server_context()` for FastAPI dependency injection
|
||||
- `set_context_class()` for global configuration
|
||||
- `create_server_context()` for testing/CLI usage
|
||||
|
||||
4. **Backward Compatibility Layer**
|
||||
- `shared.py` now uses `__getattr__` for lazy loading
|
||||
- All existing imports continue to work
|
||||
- Deprecation warnings guide migration
|
||||
|
||||
## SaaS Extensibility
|
||||
|
||||
### Multi-Tenant Context Example
|
||||
|
||||
```python
|
||||
class SaaSServerContext(ServerContext):
|
||||
def __init__(self, user_id: str, org_id: str):
|
||||
self.user_id = user_id
|
||||
self.org_id = org_id
|
||||
|
||||
def get_file_store(self):
|
||||
# Return tenant-isolated file store
|
||||
return MultiTenantFileStore(self.user_id, self.org_id)
|
||||
|
||||
def get_server_config(self):
|
||||
# Return SaaS-specific configuration
|
||||
return SaaSServerConfig(org_id=self.org_id)
|
||||
|
||||
# Configure globally
|
||||
set_context_class('myapp.context.SaaSServerContext')
|
||||
```
|
||||
|
||||
### Benefits for SaaS
|
||||
- **Per-tenant isolation**: Different storage, config, and features per organization
|
||||
- **Enterprise features**: Easy to add billing, advanced monitoring, etc.
|
||||
- **Scalable architecture**: Context per request enables horizontal scaling
|
||||
- **Clean separation**: SaaS code stays in external repo, extends OpenHands cleanly
|
||||
|
||||
## Migration Path
|
||||
|
||||
### For OpenHands Core
|
||||
- **Phase 1**: Refactoring complete, backward compatibility maintained
|
||||
- **Phase 2**: Gradually migrate routes to use dependency injection
|
||||
- **Phase 3**: Remove deprecated shared.py (future release)
|
||||
|
||||
### For SaaS Implementations
|
||||
- **Immediate**: Can use new context system for new features
|
||||
- **Gradual**: Migrate existing code using migration guide
|
||||
- **Benefits**: Cleaner architecture, better testing, easier deployment
|
||||
|
||||
## Files Created/Modified
|
||||
|
||||
### New Files
|
||||
- `openhands/server/context/__init__.py` - Public API
|
||||
- `openhands/server/context/server_context.py` - Abstract base class
|
||||
- `openhands/server/context/default_server_context.py` - Default implementation
|
||||
- `openhands/server/context/context_provider.py` - Dependency injection
|
||||
- `examples/saas_extension.py` - SaaS extension example
|
||||
- `MIGRATION_GUIDE.md` - Detailed migration instructions
|
||||
- `test_refactor.py` - Comprehensive test suite
|
||||
|
||||
### Modified Files
|
||||
- `openhands/server/shared.py` - Backward compatibility layer
|
||||
|
||||
## Testing Results
|
||||
|
||||
Comprehensive test suite with 5 test categories:
|
||||
|
||||
1. ✅ **Context System**: Import, creation, class switching
|
||||
2. ✅ **Backward Compatibility**: Lazy loading, attribute access
|
||||
3. ✅ **Abstract Base Class**: Proper abstraction, required methods
|
||||
4. ✅ **Default Context**: Instantiation, method availability
|
||||
5. ✅ **SaaS Example**: Multi-tenant context structure
|
||||
|
||||
**Result: 5/5 tests passed** 🎉
|
||||
|
||||
## Usage Examples
|
||||
|
||||
### New Way (Recommended)
|
||||
```python
|
||||
from fastapi import Depends
|
||||
from openhands.server.context import get_server_context, ServerContext
|
||||
|
||||
@app.get('/conversations')
|
||||
async def get_conversations(
|
||||
context: ServerContext = Depends(get_server_context)
|
||||
):
|
||||
config = context.get_config()
|
||||
conversation_manager = context.get_conversation_manager()
|
||||
return conversation_manager.list_conversations()
|
||||
```
|
||||
|
||||
### Old Way (Still Works)
|
||||
```python
|
||||
from openhands.server.shared import config, conversation_manager
|
||||
|
||||
@app.get('/conversations')
|
||||
async def get_conversations():
|
||||
# Shows deprecation warning but works
|
||||
return conversation_manager.list_conversations()
|
||||
```
|
||||
|
||||
### SaaS Extension
|
||||
```python
|
||||
# In SaaS application startup
|
||||
from openhands.server.context import set_context_class
|
||||
set_context_class('myapp.context.SaaSServerContext')
|
||||
|
||||
# Routes automatically get tenant-aware context
|
||||
@app.get('/tenant/{org_id}/conversations')
|
||||
async def get_tenant_conversations(
|
||||
org_id: str,
|
||||
context: SaaSServerContext = Depends(get_server_context)
|
||||
):
|
||||
# context.org_id and context.user_id available
|
||||
# All dependencies are tenant-isolated
|
||||
conversation_manager = context.get_conversation_manager()
|
||||
return conversation_manager.list_conversations()
|
||||
```
|
||||
|
||||
## Benefits Achieved
|
||||
|
||||
### For OpenHands Core
|
||||
- ✅ **Better Architecture**: Clean dependency injection pattern
|
||||
- ✅ **Improved Testing**: Easy to mock dependencies
|
||||
- ✅ **No Breaking Changes**: Full backward compatibility
|
||||
- ✅ **Performance**: Lazy loading reduces startup time
|
||||
- ✅ **Type Safety**: Better IDE support and type checking
|
||||
|
||||
### For SaaS Implementations
|
||||
- ✅ **Multi-Tenancy**: Per-organization contexts and isolation
|
||||
- ✅ **Extensibility**: Easy to add enterprise features
|
||||
- ✅ **Clean Integration**: No need to fork OpenHands
|
||||
- ✅ **Deployment Flexibility**: Can run from external repos
|
||||
- ✅ **CI/CD Fixes**: No more environment variable dependencies
|
||||
|
||||
### For Development
|
||||
- ✅ **Maintainability**: Clear dependency relationships
|
||||
- ✅ **Debugging**: Easier to trace dependency issues
|
||||
- ✅ **Documentation**: Clear migration path and examples
|
||||
- ✅ **Future-Proof**: Extensible architecture for new features
|
||||
|
||||
## Next Steps
|
||||
|
||||
1. **Immediate**: Refactoring is complete and tested
|
||||
2. **Short-term**: Begin migrating core routes to use dependency injection
|
||||
3. **Medium-term**: SaaS implementations can adopt new context system
|
||||
4. **Long-term**: Remove deprecated shared.py in future major release
|
||||
|
||||
## Conclusion
|
||||
|
||||
The refactoring successfully addresses all the original problems:
|
||||
|
||||
- ❌ **Import-time dependencies** → ✅ **Lazy initialization**
|
||||
- ❌ **Global state pollution** → ✅ **Clean dependency injection**
|
||||
- ❌ **SaaS integration issues** → ✅ **Multi-tenant context system**
|
||||
- ❌ **Testing difficulties** → ✅ **Easy mocking and testing**
|
||||
- ❌ **No extensibility** → ✅ **Pluggable context implementations**
|
||||
|
||||
The new architecture enables OpenHands to support SaaS scenarios while maintaining full backward compatibility and improving the overall codebase quality.
|
||||
@@ -219,6 +219,14 @@ correct_num = 5
|
||||
api_key = ""
|
||||
model = "gpt-4o"
|
||||
|
||||
# Example routing LLM configuration for multimodal model routing
|
||||
# Uncomment and configure to enable model routing with a secondary model
|
||||
#[llm.secondary_model]
|
||||
#model = "kimi-k2"
|
||||
#api_key = ""
|
||||
#for_routing = true
|
||||
#max_input_tokens = 128000
|
||||
|
||||
|
||||
#################################### Agent ###################################
|
||||
# Configuration for agents (group name starts with 'agent')
|
||||
@@ -480,3 +488,14 @@ type = "noop"
|
||||
|
||||
# Run the runtime sandbox container in privileged mode for use with docker-in-docker
|
||||
#privileged = false
|
||||
|
||||
#################################### Model Routing ############################
|
||||
# Configuration for experimental model routing feature
|
||||
# Enables intelligent switching between different LLM models for specific purposes
|
||||
##############################################################################
|
||||
[model_routing]
|
||||
# Router to use for model selection
|
||||
# Available options:
|
||||
# - "noop_router" (default): No routing, always uses primary LLM
|
||||
# - "multimodal_router": A router that switches between primary and secondary models, depending on whether the input is multimodal or not
|
||||
#router_name = "noop_router"
|
||||
|
||||
@@ -12,7 +12,7 @@ services:
|
||||
- SANDBOX_API_HOSTNAME=host.docker.internal
|
||||
- DOCKER_HOST_ADDR=host.docker.internal
|
||||
#
|
||||
- SANDBOX_RUNTIME_CONTAINER_IMAGE=${SANDBOX_RUNTIME_CONTAINER_IMAGE:-ghcr.io/all-hands-ai/runtime:0.55-nikolaik}
|
||||
- SANDBOX_RUNTIME_CONTAINER_IMAGE=${SANDBOX_RUNTIME_CONTAINER_IMAGE:-ghcr.io/all-hands-ai/runtime:0.56-nikolaik}
|
||||
- SANDBOX_USER_ID=${SANDBOX_USER_ID:-1234}
|
||||
- WORKSPACE_MOUNT_PATH=${WORKSPACE_BASE:-$PWD/workspace}
|
||||
ports:
|
||||
|
||||
@@ -3,9 +3,9 @@ repos:
|
||||
rev: v5.0.0
|
||||
hooks:
|
||||
- id: trailing-whitespace
|
||||
exclude: ^(docs/|modules/|python/|openhands-ui/|third_party/)
|
||||
exclude: ^(docs/|modules/|python/|openhands-ui/|third_party/|enterprise/)
|
||||
- id: end-of-file-fixer
|
||||
exclude: ^(docs/|modules/|python/|openhands-ui/|third_party/)
|
||||
exclude: ^(docs/|modules/|python/|openhands-ui/|third_party/|enterprise/)
|
||||
- id: check-yaml
|
||||
args: ["--allow-multiple-documents"]
|
||||
- id: debug-statements
|
||||
@@ -28,19 +28,28 @@ repos:
|
||||
entry: ruff check --config dev_config/python/ruff.toml
|
||||
types_or: [python, pyi, jupyter]
|
||||
args: [--fix, --unsafe-fixes]
|
||||
exclude: third_party/
|
||||
exclude: ^(third_party/|enterprise/)
|
||||
# Run the formatter.
|
||||
- id: ruff-format
|
||||
entry: ruff format --config dev_config/python/ruff.toml
|
||||
types_or: [python, pyi, jupyter]
|
||||
exclude: third_party/
|
||||
exclude: ^(third_party/|enterprise/)
|
||||
|
||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||
rev: v1.15.0
|
||||
hooks:
|
||||
- id: mypy
|
||||
additional_dependencies:
|
||||
[types-requests, types-setuptools, types-pyyaml, types-toml, types-docker, types-Markdown, pydantic, lxml]
|
||||
[
|
||||
types-requests,
|
||||
types-setuptools,
|
||||
types-pyyaml,
|
||||
types-toml,
|
||||
types-docker,
|
||||
types-Markdown,
|
||||
pydantic,
|
||||
lxml,
|
||||
]
|
||||
# To see gaps add `--html-report mypy-report/`
|
||||
entry: mypy --config-file dev_config/python/mypy.ini openhands/
|
||||
always_run: true
|
||||
|
||||
@@ -7,9 +7,10 @@ warn_unreachable = True
|
||||
warn_redundant_casts = True
|
||||
no_implicit_optional = True
|
||||
strict_optional = True
|
||||
disable_error_code = type-abstract
|
||||
|
||||
# Exclude third-party runtime directory from type checking
|
||||
exclude = third_party/
|
||||
exclude = (third_party/|enterprise/)
|
||||
|
||||
[mypy-openhands.memory.condenser.impl.*]
|
||||
disable_error_code = override
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# Exclude third-party runtime directory from linting
|
||||
exclude = ["third_party/"]
|
||||
exclude = ["third_party/", "enterprise/"]
|
||||
|
||||
[lint]
|
||||
select = [
|
||||
|
||||
@@ -7,7 +7,7 @@ services:
|
||||
image: openhands:latest
|
||||
container_name: openhands-app-${DATE:-}
|
||||
environment:
|
||||
- SANDBOX_RUNTIME_CONTAINER_IMAGE=${SANDBOX_RUNTIME_CONTAINER_IMAGE:-docker.all-hands.dev/all-hands-ai/runtime:0.55-nikolaik}
|
||||
- SANDBOX_RUNTIME_CONTAINER_IMAGE=${SANDBOX_RUNTIME_CONTAINER_IMAGE:-docker.all-hands.dev/all-hands-ai/runtime:0.56-nikolaik}
|
||||
#- SANDBOX_USER_ID=${SANDBOX_USER_ID:-1234} # enable this only if you want a specific non-root sandbox user but you will have to manually adjust permissions of ~/.openhands for this user
|
||||
- WORKSPACE_MOUNT_PATH=${WORKSPACE_BASE:-$PWD/workspace}
|
||||
ports:
|
||||
|
||||
@@ -1,272 +0,0 @@
|
||||
# OpenHands Extensibility Migration Guide
|
||||
|
||||
This guide explains how to migrate from the old global variable approach to the new factory-based extensibility system.
|
||||
|
||||
## Overview
|
||||
|
||||
OpenHands has been refactored to eliminate import-time dependencies on environment variables and global state. This enables external repositories to cleanly extend OpenHands without configuration conflicts.
|
||||
|
||||
## The Problem We Solved
|
||||
|
||||
### Before (Problematic)
|
||||
```python
|
||||
# In OpenHands shared.py - loaded at import time
|
||||
config = Config() # Reads environment variables
|
||||
server_config = ServerConfig() # More environment variables
|
||||
|
||||
# External repos had to:
|
||||
# 1. Set environment variables before importing OpenHands
|
||||
# 2. Deal with global state conflicts
|
||||
# 3. Couldn't easily override specific behaviors
|
||||
```
|
||||
|
||||
### After (Clean)
|
||||
```python
|
||||
# External repos can now:
|
||||
from openhands.server.factory import create_openhands_app
|
||||
|
||||
app = create_openhands_app(
|
||||
context_factory=lambda: MyCustomContext(),
|
||||
include_oss_routes=False
|
||||
)
|
||||
```
|
||||
|
||||
## Migration Paths
|
||||
|
||||
### 1. For External Repositories (Recommended)
|
||||
|
||||
**Old Way (Don't do this):**
|
||||
```python
|
||||
# external_repo/main.py
|
||||
import os
|
||||
os.environ['OPENHANDS_CONFIG_CLS'] = 'my_config.MyConfig'
|
||||
os.environ['CONVERSATION_MANAGER_CLASS'] = 'my_manager.MyManager'
|
||||
|
||||
from openhands.server.app import app # Imports with global state
|
||||
```
|
||||
|
||||
**New Way (Recommended):**
|
||||
```python
|
||||
# external_repo/main.py
|
||||
from openhands.server.factory import create_openhands_app
|
||||
from external_repo.context import ExternalRepoContext
|
||||
|
||||
def create_app():
|
||||
return create_openhands_app(
|
||||
context_factory=lambda: ExternalRepoContext(),
|
||||
include_oss_routes=False, # Skip OSS-specific routes
|
||||
title='My Enterprise Platform'
|
||||
)
|
||||
|
||||
app = create_app()
|
||||
|
||||
# Add your own routes
|
||||
@app.get('/enterprise/dashboard')
|
||||
async def dashboard():
|
||||
return {'status': 'enterprise'}
|
||||
```
|
||||
|
||||
### 2. For OpenHands Core Development
|
||||
|
||||
**Old Way:**
|
||||
```python
|
||||
# In route handlers
|
||||
from openhands.server.shared import config, server_config
|
||||
|
||||
@app.get('/example')
|
||||
async def example_route():
|
||||
storage_path = config.workspace_base
|
||||
app_mode = server_config.app_mode
|
||||
```
|
||||
|
||||
**New Way:**
|
||||
```python
|
||||
# In route handlers
|
||||
from fastapi import Depends
|
||||
from openhands.server.context import get_server_context, ServerContext
|
||||
|
||||
@app.get('/example')
|
||||
async def example_route(
|
||||
context: ServerContext = Depends(get_server_context)
|
||||
):
|
||||
config = context.get_config()
|
||||
server_config = context.get_server_config()
|
||||
storage_path = config.workspace_base
|
||||
app_mode = server_config.app_mode
|
||||
```
|
||||
|
||||
## Custom Context Implementation
|
||||
|
||||
### Step 1: Create Your Context Class
|
||||
|
||||
```python
|
||||
# my_extension/context.py
|
||||
from openhands.server.context.server_context import ServerContext
|
||||
|
||||
class MyCustomContext(ServerContext):
|
||||
def __init__(self, tenant_id: str = 'default'):
|
||||
super().__init__()
|
||||
self.tenant_id = tenant_id
|
||||
|
||||
def get_config(self):
|
||||
"""Override with tenant-specific configuration."""
|
||||
config = super().get_config()
|
||||
config.workspace_base = f'/data/tenants/{self.tenant_id}/workspace'
|
||||
return config
|
||||
|
||||
def get_server_config(self):
|
||||
"""Override server configuration."""
|
||||
server_config = super().get_server_config()
|
||||
server_config.app_mode = 'ENTERPRISE'
|
||||
server_config.enable_billing = True
|
||||
return server_config
|
||||
```
|
||||
|
||||
### Step 2: Create Your FastAPI App
|
||||
|
||||
```python
|
||||
# my_extension/app.py
|
||||
from openhands.server.factory import create_openhands_app
|
||||
from my_extension.context import MyCustomContext
|
||||
|
||||
def create_my_app():
|
||||
# Option A: Extend OpenHands app directly
|
||||
app = create_openhands_app(
|
||||
context_factory=lambda: MyCustomContext(),
|
||||
title='My Enterprise Platform'
|
||||
)
|
||||
|
||||
# Add your routes
|
||||
@app.get('/enterprise/status')
|
||||
async def enterprise_status():
|
||||
return {'mode': 'enterprise'}
|
||||
|
||||
return app
|
||||
|
||||
# Option B: Create your own app and mount OpenHands
|
||||
from fastapi import FastAPI
|
||||
|
||||
def create_my_app_with_mount():
|
||||
main_app = FastAPI(title='My Platform')
|
||||
|
||||
openhands_app = create_openhands_app(
|
||||
context_factory=lambda: MyCustomContext()
|
||||
)
|
||||
|
||||
main_app.mount('/openhands', openhands_app)
|
||||
|
||||
@main_app.get('/my-dashboard')
|
||||
async def dashboard():
|
||||
return {'dashboard': 'data'}
|
||||
|
||||
return main_app
|
||||
```
|
||||
|
||||
### Step 3: Run Your Application
|
||||
|
||||
```python
|
||||
# my_extension/main.py
|
||||
import uvicorn
|
||||
from my_extension.app import create_my_app
|
||||
|
||||
if __name__ == '__main__':
|
||||
app = create_my_app()
|
||||
uvicorn.run(app, host='0.0.0.0', port=8000)
|
||||
```
|
||||
|
||||
## Advanced Patterns
|
||||
|
||||
### Multi-Tenant Context
|
||||
|
||||
```python
|
||||
class MultiTenantContext(ServerContext):
|
||||
def __init__(self, request: Request):
|
||||
super().__init__()
|
||||
# Extract tenant from request
|
||||
self.tenant_id = request.headers.get('X-Tenant-ID', 'default')
|
||||
|
||||
def get_file_store(self):
|
||||
# Return tenant-isolated file store
|
||||
return TenantFileStore(tenant_id=self.tenant_id)
|
||||
|
||||
# Use with factory
|
||||
def create_tenant_context(request: Request):
|
||||
return MultiTenantContext(request)
|
||||
|
||||
app = create_openhands_app(
|
||||
context_factory=create_tenant_context
|
||||
)
|
||||
```
|
||||
|
||||
### Custom Lifespan Management
|
||||
|
||||
```python
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
@asynccontextmanager
|
||||
async def my_lifespan(app: FastAPI):
|
||||
# Startup
|
||||
print("Starting my custom services...")
|
||||
await initialize_my_database()
|
||||
|
||||
yield
|
||||
|
||||
# Shutdown
|
||||
print("Shutting down my custom services...")
|
||||
await cleanup_my_database()
|
||||
|
||||
app = create_openhands_app(
|
||||
context_factory=MyContext,
|
||||
custom_lifespan=my_lifespan
|
||||
)
|
||||
```
|
||||
|
||||
## Testing Your Extension
|
||||
|
||||
```python
|
||||
# tests/test_my_extension.py
|
||||
from fastapi.testclient import TestClient
|
||||
from my_extension.app import create_my_app
|
||||
|
||||
def test_my_extension():
|
||||
app = create_my_app()
|
||||
client = TestClient(app)
|
||||
|
||||
# Test your custom routes
|
||||
response = client.get('/enterprise/status')
|
||||
assert response.status_code == 200
|
||||
assert response.json()['mode'] == 'enterprise'
|
||||
|
||||
# Test OpenHands routes still work
|
||||
response = client.get('/api/health')
|
||||
assert response.status_code == 200
|
||||
```
|
||||
|
||||
## Benefits of the New Approach
|
||||
|
||||
1. **No Environment Variables**: Configuration is done through code, not environment variables
|
||||
2. **Clean Separation**: External repos don't modify OpenHands globals
|
||||
3. **Dependency Injection**: Proper FastAPI dependency injection patterns
|
||||
4. **Testability**: Easy to mock contexts for testing
|
||||
5. **Flexibility**: Can create multiple apps with different configurations
|
||||
6. **No Import-Time Side Effects**: Safe to import OpenHands modules
|
||||
|
||||
## Backward Compatibility
|
||||
|
||||
The old `openhands.server.shared` module still works but is deprecated. It will show deprecation warnings and should be migrated to the new context system.
|
||||
|
||||
## Common Pitfalls
|
||||
|
||||
1. **Don't set environment variables**: Use the factory pattern instead
|
||||
2. **Don't import `openhands.server.app` directly**: Use the factory to create your own app
|
||||
3. **Don't modify global state**: Use dependency injection through contexts
|
||||
4. **Don't forget to override dependencies**: Use `app.dependency_overrides` if needed
|
||||
|
||||
## Getting Help
|
||||
|
||||
If you need help migrating your extension, please:
|
||||
1. Check the examples in `examples/external_repo_extension.py`
|
||||
2. Look at the test cases for patterns
|
||||
3. Open an issue with your specific use case
|
||||
|
||||
The new system is designed to be more flexible and maintainable while enabling clean extensibility for all types of OpenHands deployments.
|
||||
@@ -1,17 +1,36 @@
|
||||
# Setup
|
||||
# OpenHands Documentation
|
||||
|
||||
```
|
||||
This directory contains the documentation for OpenHands. The documentation is automatically synchronized with the [All-Hands-AI/docs](https://github.com/All-Hands-AI/docs) repository, which hosts the unified documentation site using Mintlify.
|
||||
|
||||
## Documentation Structure
|
||||
|
||||
The documentation files in this directory are automatically included in the main documentation site via Git submodules. When you make changes to documentation in this repository, they will be automatically synchronized to the docs repository.
|
||||
|
||||
## How It Works
|
||||
|
||||
1. **Automatic Sync**: When documentation changes are pushed to the `main` branch, a GitHub Action automatically notifies the docs repository
|
||||
2. **Submodule Update**: The docs repository updates its submodule reference to include your latest changes
|
||||
3. **Site Rebuild**: Mintlify automatically rebuilds and deploys the documentation site
|
||||
|
||||
## Making Documentation Changes
|
||||
|
||||
Simply edit the documentation files in this directory as usual. The synchronization happens automatically when changes are merged to the main branch.
|
||||
|
||||
## Local Development
|
||||
|
||||
For local documentation development in this repository only:
|
||||
|
||||
```bash
|
||||
npm install -g mint
|
||||
```
|
||||
|
||||
or
|
||||
|
||||
```
|
||||
# or
|
||||
yarn global add mint
|
||||
```
|
||||
|
||||
# Preview
|
||||
|
||||
```
|
||||
# Preview local changes
|
||||
mint dev
|
||||
```
|
||||
|
||||
For the complete unified documentation site, work with the [All-Hands-AI/docs](https://github.com/All-Hands-AI/docs) repository.
|
||||
|
||||
## Configuration
|
||||
|
||||
The Mintlify configuration (`docs.json`) has been moved to the root of the [All-Hands-AI/docs](https://github.com/All-Hands-AI/docs) repository to enable unified documentation across multiple repositories.
|
||||
|
||||
@@ -208,7 +208,7 @@
|
||||
},
|
||||
"footer": {
|
||||
"socials": {
|
||||
"slack": "https://join.slack.com/t/openhands-ai/shared_invite/zt-3847of6xi-xuYJIPa6YIPg4ElbDWbtSA",
|
||||
"slack": "https://dub.sh/openhands",
|
||||
"github": "https://github.com/All-Hands-AI/OpenHands",
|
||||
"discord": "https://discord.gg/ESHStjSjD4"
|
||||
}
|
||||
|
||||
@@ -124,7 +124,7 @@ This tagging approach allows OpenHands to efficiently manage both development an
|
||||
OpenHands supports both bind mounts and Docker named volumes in SandboxConfig.volumes:
|
||||
|
||||
- Bind mount: "/abs/host/path:/container/path[:mode]"
|
||||
- Named volume: "volume:<name>:/container/path[:mode]" or any non-absolute host spec treated as a named volume
|
||||
- Named volume: "volume:`<name>`:/container/path[:mode]" or any non-absolute host spec treated as a named volume
|
||||
|
||||
Overlay mode (copy-on-write layer) is supported for bind mounts by appending ":overlay" to the mode (e.g., ":ro,overlay").
|
||||
To enable overlay COW, set SANDBOX_VOLUME_OVERLAYS to a writable host directory; per-container upper/work dirs are created under it. If SANDBOX_VOLUME_OVERLAYS is unset, overlay mounts are skipped.
|
||||
|
||||
@@ -8,6 +8,11 @@ description: This page outlines all available configuration options for OpenHand
|
||||
In GUI Mode, any settings applied through the Settings UI will take precedence.
|
||||
</Note>
|
||||
|
||||
<Note>
|
||||
**Looking for Environment Variables?** All configuration options can also be set using environment variables.
|
||||
See the [Environment Variables Reference](./environment-variables) for a complete list with examples.
|
||||
</Note>
|
||||
|
||||
## Location of the `config.toml` File
|
||||
|
||||
When running OpenHands in CLI, headless, or development mode, you can use a project-specific `config.toml` file for configuration, which must be
|
||||
@@ -18,6 +23,11 @@ specify a different path to the `config.toml` file.
|
||||
|
||||
The core configuration options are defined in the `[core]` section of the `config.toml` file.
|
||||
|
||||
Core configuration options can be set as environment variables by converting to uppercase. For example:
|
||||
- `debug` → `DEBUG`
|
||||
- `cache_dir` → `CACHE_DIR`
|
||||
- `runtime` → `RUNTIME`
|
||||
|
||||
### Workspace
|
||||
- `workspace_base` **(Deprecated)**
|
||||
- Type: `str`
|
||||
@@ -141,6 +151,11 @@ The LLM (Large Language Model) configuration options are defined in the `[llm]`
|
||||
|
||||
To use these with the docker command, pass in `-e LLM_<option>`. Example: `-e LLM_NUM_RETRIES`.
|
||||
|
||||
All LLM configuration options can be set as environment variables by prefixing with `LLM_` and converting to uppercase. For example:
|
||||
- `model` → `LLM_MODEL`
|
||||
- `api_key` → `LLM_API_KEY`
|
||||
- `base_url` → `LLM_BASE_URL`
|
||||
|
||||
<Note>
|
||||
For development setups, you can also define custom named LLM configurations. See [Custom LLM Configurations](./llms/custom-llm-configs) for details.
|
||||
</Note>
|
||||
@@ -277,6 +292,11 @@ For development setups, you can also define custom named LLM configurations. See
|
||||
|
||||
The agent configuration options are defined in the `[agent]` and `[agent.<agent_name>]` sections of the `config.toml` file.
|
||||
|
||||
Agent configuration options can be set as environment variables by prefixing with `AGENT_` and converting to uppercase. For example:
|
||||
- `enable_browsing` → `AGENT_ENABLE_BROWSING`
|
||||
- `function_calling` → `AGENT_FUNCTION_CALLING`
|
||||
- `llm_config` → `AGENT_LLM_CONFIG`
|
||||
|
||||
### LLM Configuration
|
||||
- `llm_config`
|
||||
- Type: `str`
|
||||
@@ -328,6 +348,11 @@ The sandbox configuration options are defined in the `[sandbox]` section of the
|
||||
|
||||
To use these with the docker command, pass in `-e SANDBOX_<option>`. Example: `-e SANDBOX_TIMEOUT`.
|
||||
|
||||
All sandbox configuration options can be set as environment variables by prefixing with `SANDBOX_` and converting to uppercase. For example:
|
||||
- `timeout` → `SANDBOX_TIMEOUT`
|
||||
- `user_id` → `SANDBOX_USER_ID`
|
||||
- `base_container_image` → `SANDBOX_BASE_CONTAINER_IMAGE`
|
||||
|
||||
### Execution
|
||||
- `timeout`
|
||||
- Type: `int`
|
||||
@@ -390,6 +415,10 @@ The security configuration options are defined in the `[security]` section of th
|
||||
|
||||
To use these with the docker command, pass in `-e SECURITY_<option>`. Example: `-e SECURITY_CONFIRMATION_MODE`.
|
||||
|
||||
All security configuration options can be set as environment variables by prefixing with `SECURITY_` and converting to uppercase. For example:
|
||||
- `confirmation_mode` → `SECURITY_CONFIRMATION_MODE`
|
||||
- `security_analyzer` → `SECURITY_SECURITY_ANALYZER`
|
||||
|
||||
### Confirmation Mode
|
||||
- `confirmation_mode`
|
||||
- Type: `bool`
|
||||
|
||||
251
docs/usage/environment-variables.mdx
Normal file
251
docs/usage/environment-variables.mdx
Normal file
@@ -0,0 +1,251 @@
|
||||
---
|
||||
title: Environment Variables Reference
|
||||
description: Complete reference of all environment variables supported by OpenHands
|
||||
---
|
||||
|
||||
This page provides a reference of environment variables that can be used to configure OpenHands. Environment variables provide an alternative to TOML configuration files and are particularly useful for containerized deployments, CI/CD pipelines, and cloud environments.
|
||||
|
||||
## Environment Variable Naming Convention
|
||||
|
||||
OpenHands follows a consistent naming pattern for environment variables:
|
||||
|
||||
- **Core settings**: Direct uppercase mapping (e.g., `debug` → `DEBUG`)
|
||||
- **LLM settings**: Prefixed with `LLM_` (e.g., `model` → `LLM_MODEL`)
|
||||
- **Agent settings**: Prefixed with `AGENT_` (e.g., `enable_browsing` → `AGENT_ENABLE_BROWSING`)
|
||||
- **Sandbox settings**: Prefixed with `SANDBOX_` (e.g., `timeout` → `SANDBOX_TIMEOUT`)
|
||||
- **Security settings**: Prefixed with `SECURITY_` (e.g., `confirmation_mode` → `SECURITY_CONFIRMATION_MODE`)
|
||||
|
||||
## Core Configuration Variables
|
||||
|
||||
These variables correspond to the `[core]` section in `config.toml`:
|
||||
|
||||
| Environment Variable | Type | Default | Description |
|
||||
|---------------------|------|---------|-------------|
|
||||
| `DEBUG` | boolean | `false` | Enable debug logging throughout the application |
|
||||
| `DISABLE_COLOR` | boolean | `false` | Disable colored output in terminal |
|
||||
| `CACHE_DIR` | string | `"/tmp/cache"` | Directory path for caching |
|
||||
| `SAVE_TRAJECTORY_PATH` | string | `"./trajectories"` | Path to store conversation trajectories |
|
||||
| `REPLAY_TRAJECTORY_PATH` | string | `""` | Path to load and replay a trajectory file |
|
||||
| `FILE_STORE_PATH` | string | `"/tmp/file_store"` | File store directory path |
|
||||
| `FILE_STORE` | string | `"memory"` | File store type (`memory`, `local`, etc.) |
|
||||
| `FILE_UPLOADS_MAX_FILE_SIZE_MB` | integer | `0` | Maximum file upload size in MB (0 = no limit) |
|
||||
| `FILE_UPLOADS_RESTRICT_FILE_TYPES` | boolean | `false` | Whether to restrict file upload types |
|
||||
| `FILE_UPLOADS_ALLOWED_EXTENSIONS` | list | `[".*"]` | List of allowed file extensions for uploads |
|
||||
| `MAX_BUDGET_PER_TASK` | float | `0.0` | Maximum budget per task (0.0 = no limit) |
|
||||
| `MAX_ITERATIONS` | integer | `100` | Maximum number of iterations per task |
|
||||
| `RUNTIME` | string | `"docker"` | Runtime environment (`docker`, `local`, `cli`, etc.) |
|
||||
| `DEFAULT_AGENT` | string | `"CodeActAgent"` | Default agent class to use |
|
||||
| `JWT_SECRET` | string | auto-generated | JWT secret for authentication |
|
||||
| `RUN_AS_OPENHANDS` | boolean | `true` | Whether to run as the openhands user |
|
||||
| `VOLUMES` | string | `""` | Volume mounts in format `host:container[:mode]` |
|
||||
|
||||
## LLM Configuration Variables
|
||||
|
||||
These variables correspond to the `[llm]` section in `config.toml`:
|
||||
|
||||
| Environment Variable | Type | Default | Description |
|
||||
|---------------------|------|---------|-------------|
|
||||
| `LLM_MODEL` | string | `"claude-3-5-sonnet-20241022"` | LLM model to use |
|
||||
| `LLM_API_KEY` | string | `""` | API key for the LLM provider |
|
||||
| `LLM_BASE_URL` | string | `""` | Custom API base URL |
|
||||
| `LLM_API_VERSION` | string | `""` | API version to use |
|
||||
| `LLM_TEMPERATURE` | float | `0.0` | Sampling temperature |
|
||||
| `LLM_TOP_P` | float | `1.0` | Top-p sampling parameter |
|
||||
| `LLM_MAX_INPUT_TOKENS` | integer | `0` | Maximum input tokens (0 = no limit) |
|
||||
| `LLM_MAX_OUTPUT_TOKENS` | integer | `0` | Maximum output tokens (0 = no limit) |
|
||||
| `LLM_MAX_MESSAGE_CHARS` | integer | `30000` | Maximum characters that will be sent to the model in observation content |
|
||||
| `LLM_TIMEOUT` | integer | `0` | API timeout in seconds (0 = no timeout) |
|
||||
| `LLM_NUM_RETRIES` | integer | `8` | Number of retry attempts |
|
||||
| `LLM_RETRY_MIN_WAIT` | integer | `15` | Minimum wait time between retries (seconds) |
|
||||
| `LLM_RETRY_MAX_WAIT` | integer | `120` | Maximum wait time between retries (seconds) |
|
||||
| `LLM_RETRY_MULTIPLIER` | float | `2.0` | Exponential backoff multiplier |
|
||||
| `LLM_DROP_PARAMS` | boolean | `false` | Drop unsupported parameters without error |
|
||||
| `LLM_CACHING_PROMPT` | boolean | `true` | Enable prompt caching if supported |
|
||||
| `LLM_DISABLE_VISION` | boolean | `false` | Disable vision capabilities for cost reduction |
|
||||
| `LLM_CUSTOM_LLM_PROVIDER` | string | `""` | Custom LLM provider name |
|
||||
| `LLM_OLLAMA_BASE_URL` | string | `""` | Base URL for Ollama API |
|
||||
| `LLM_INPUT_COST_PER_TOKEN` | float | `0.0` | Cost per input token |
|
||||
| `LLM_OUTPUT_COST_PER_TOKEN` | float | `0.0` | Cost per output token |
|
||||
| `LLM_REASONING_EFFORT` | string | `""` | Reasoning effort for o-series models (`low`, `medium`, `high`) |
|
||||
|
||||
### AWS Configuration
|
||||
| Environment Variable | Type | Default | Description |
|
||||
|---------------------|------|---------|-------------|
|
||||
| `LLM_AWS_ACCESS_KEY_ID` | string | `""` | AWS access key ID |
|
||||
| `LLM_AWS_SECRET_ACCESS_KEY` | string | `""` | AWS secret access key |
|
||||
| `LLM_AWS_REGION_NAME` | string | `""` | AWS region name |
|
||||
|
||||
## Agent Configuration Variables
|
||||
|
||||
These variables correspond to the `[agent]` section in `config.toml`:
|
||||
|
||||
| Environment Variable | Type | Default | Description |
|
||||
|---------------------|------|---------|-------------|
|
||||
| `AGENT_LLM_CONFIG` | string | `""` | Name of LLM config group to use |
|
||||
| `AGENT_FUNCTION_CALLING` | boolean | `true` | Enable function calling |
|
||||
| `AGENT_ENABLE_BROWSING` | boolean | `false` | Enable browsing delegate |
|
||||
| `AGENT_ENABLE_LLM_EDITOR` | boolean | `false` | Enable LLM-based editor |
|
||||
| `AGENT_ENABLE_JUPYTER` | boolean | `false` | Enable Jupyter integration |
|
||||
| `AGENT_ENABLE_HISTORY_TRUNCATION` | boolean | `true` | Enable history truncation |
|
||||
| `AGENT_ENABLE_PROMPT_EXTENSIONS` | boolean | `true` | Enable microagents (prompt extensions) |
|
||||
| `AGENT_DISABLED_MICROAGENTS` | list | `[]` | List of microagents to disable |
|
||||
|
||||
## Sandbox Configuration Variables
|
||||
|
||||
These variables correspond to the `[sandbox]` section in `config.toml`:
|
||||
|
||||
| Environment Variable | Type | Default | Description |
|
||||
|---------------------|------|---------|-------------|
|
||||
| `SANDBOX_TIMEOUT` | integer | `120` | Sandbox timeout in seconds |
|
||||
| `SANDBOX_USER_ID` | integer | `1000` | User ID for sandbox processes |
|
||||
| `SANDBOX_BASE_CONTAINER_IMAGE` | string | `"nikolaik/python-nodejs:python3.12-nodejs22"` | Base container image |
|
||||
| `SANDBOX_USE_HOST_NETWORK` | boolean | `false` | Use host networking |
|
||||
| `SANDBOX_RUNTIME_BINDING_ADDRESS` | string | `"0.0.0.0"` | Runtime binding address |
|
||||
| `SANDBOX_ENABLE_AUTO_LINT` | boolean | `false` | Enable automatic linting |
|
||||
| `SANDBOX_INITIALIZE_PLUGINS` | boolean | `true` | Initialize sandbox plugins |
|
||||
| `SANDBOX_RUNTIME_EXTRA_DEPS` | string | `""` | Extra dependencies to install |
|
||||
| `SANDBOX_RUNTIME_STARTUP_ENV_VARS` | dict | `{}` | Environment variables for runtime |
|
||||
| `SANDBOX_BROWSERGYM_EVAL_ENV` | string | `""` | BrowserGym evaluation environment |
|
||||
| `SANDBOX_VOLUMES` | string | `""` | Volume mounts (replaces deprecated workspace settings) |
|
||||
| `SANDBOX_RUNTIME_CONTAINER_IMAGE` | string | `""` | Pre-built runtime container image |
|
||||
| `SANDBOX_KEEP_RUNTIME_ALIVE` | boolean | `false` | Keep runtime alive after session ends |
|
||||
| `SANDBOX_PAUSE_CLOSED_RUNTIMES` | boolean | `false` | Pause instead of stopping closed runtimes |
|
||||
| `SANDBOX_CLOSE_DELAY` | integer | `300` | Delay before closing idle runtimes (seconds) |
|
||||
| `SANDBOX_RM_ALL_CONTAINERS` | boolean | `false` | Remove all containers when stopping |
|
||||
| `SANDBOX_ENABLE_GPU` | boolean | `false` | Enable GPU support |
|
||||
| `SANDBOX_CUDA_VISIBLE_DEVICES` | string | `""` | Specify GPU devices by ID |
|
||||
| `SANDBOX_VSCODE_PORT` | integer | auto | Specific port for VSCode server |
|
||||
|
||||
### Sandbox Environment Variables
|
||||
Variables prefixed with `SANDBOX_ENV_` are passed through to the sandbox environment:
|
||||
|
||||
| Environment Variable | Description |
|
||||
|---------------------|-------------|
|
||||
| `SANDBOX_ENV_*` | Any variable with this prefix is passed to the sandbox (e.g., `SANDBOX_ENV_OPENAI_API_KEY`) |
|
||||
|
||||
## Security Configuration Variables
|
||||
|
||||
These variables correspond to the `[security]` section in `config.toml`:
|
||||
|
||||
| Environment Variable | Type | Default | Description |
|
||||
|---------------------|------|---------|-------------|
|
||||
| `SECURITY_CONFIRMATION_MODE` | boolean | `false` | Enable confirmation mode for actions |
|
||||
| `SECURITY_SECURITY_ANALYZER` | string | `"llm"` | Security analyzer to use (`llm`, `invariant`) |
|
||||
| `SECURITY_ENABLE_SECURITY_ANALYZER` | boolean | `true` | Enable security analysis |
|
||||
|
||||
## Debug and Logging Variables
|
||||
|
||||
| Environment Variable | Type | Default | Description |
|
||||
|---------------------|------|---------|-------------|
|
||||
| `DEBUG` | boolean | `false` | Enable general debug logging |
|
||||
| `DEBUG_LLM` | boolean | `false` | Enable LLM-specific debug logging |
|
||||
| `DEBUG_RUNTIME` | boolean | `false` | Enable runtime debug logging |
|
||||
| `LOG_TO_FILE` | boolean | auto | Log to file (auto-enabled when DEBUG=true) |
|
||||
|
||||
## Runtime-Specific Variables
|
||||
|
||||
### Docker Runtime
|
||||
| Environment Variable | Type | Default | Description |
|
||||
|---------------------|------|---------|-------------|
|
||||
| `SANDBOX_VOLUME_OVERLAYS` | string | `""` | Volume overlay configurations |
|
||||
|
||||
### Remote Runtime
|
||||
| Environment Variable | Type | Default | Description |
|
||||
|---------------------|------|---------|-------------|
|
||||
| `SANDBOX_API_KEY` | string | `""` | API key for remote runtime |
|
||||
| `SANDBOX_REMOTE_RUNTIME_API_URL` | string | `""` | Remote runtime API URL |
|
||||
|
||||
### Local Runtime
|
||||
| Environment Variable | Type | Default | Description |
|
||||
|---------------------|------|---------|-------------|
|
||||
| `RUNTIME_URL` | string | `""` | Runtime URL for local runtime |
|
||||
| `RUNTIME_URL_PATTERN` | string | `""` | Runtime URL pattern |
|
||||
| `RUNTIME_ID` | string | `""` | Runtime identifier |
|
||||
| `LOCAL_RUNTIME_MODE` | string | `""` | Enable local runtime mode (`1` to enable) |
|
||||
|
||||
## Integration Variables
|
||||
|
||||
### GitHub Integration
|
||||
| Environment Variable | Type | Default | Description |
|
||||
|---------------------|------|---------|-------------|
|
||||
| `GITHUB_TOKEN` | string | `""` | GitHub personal access token |
|
||||
|
||||
### Third-Party API Keys
|
||||
| Environment Variable | Type | Default | Description |
|
||||
|---------------------|------|---------|-------------|
|
||||
| `OPENAI_API_KEY` | string | `""` | OpenAI API key |
|
||||
| `ANTHROPIC_API_KEY` | string | `""` | Anthropic API key |
|
||||
| `GOOGLE_API_KEY` | string | `""` | Google API key |
|
||||
| `AZURE_API_KEY` | string | `""` | Azure API key |
|
||||
| `TAVILY_API_KEY` | string | `""` | Tavily search API key |
|
||||
|
||||
## Server Configuration Variables
|
||||
|
||||
These are primarily used when running OpenHands as a server:
|
||||
|
||||
| Environment Variable | Type | Default | Description |
|
||||
|---------------------|------|---------|-------------|
|
||||
| `FRONTEND_PORT` | integer | `3000` | Frontend server port |
|
||||
| `BACKEND_PORT` | integer | `8000` | Backend server port |
|
||||
| `FRONTEND_HOST` | string | `"localhost"` | Frontend host address |
|
||||
| `BACKEND_HOST` | string | `"localhost"` | Backend host address |
|
||||
| `WEB_HOST` | string | `"localhost"` | Web server host |
|
||||
| `SERVE_FRONTEND` | boolean | `true` | Whether to serve frontend |
|
||||
|
||||
## Deprecated Variables
|
||||
|
||||
These variables are deprecated and should be replaced:
|
||||
|
||||
| Environment Variable | Replacement | Description |
|
||||
|---------------------|-------------|-------------|
|
||||
| `WORKSPACE_BASE` | `SANDBOX_VOLUMES` | Use volume mounting instead |
|
||||
| `WORKSPACE_MOUNT_PATH` | `SANDBOX_VOLUMES` | Use volume mounting instead |
|
||||
| `WORKSPACE_MOUNT_PATH_IN_SANDBOX` | `SANDBOX_VOLUMES` | Use volume mounting instead |
|
||||
| `WORKSPACE_MOUNT_REWRITE` | `SANDBOX_VOLUMES` | Use volume mounting instead |
|
||||
|
||||
## Usage Examples
|
||||
|
||||
### Basic Setup with OpenAI
|
||||
```bash
|
||||
export LLM_MODEL="gpt-4o"
|
||||
export LLM_API_KEY="your-openai-api-key"
|
||||
export DEBUG=true
|
||||
```
|
||||
|
||||
### Docker Deployment with Custom Volumes
|
||||
```bash
|
||||
export RUNTIME="docker"
|
||||
export SANDBOX_VOLUMES="/host/workspace:/workspace:rw,/host/data:/data:ro"
|
||||
export SANDBOX_TIMEOUT=300
|
||||
```
|
||||
|
||||
### Remote Runtime Configuration
|
||||
```bash
|
||||
export RUNTIME="remote"
|
||||
export SANDBOX_API_KEY="your-remote-api-key"
|
||||
export SANDBOX_REMOTE_RUNTIME_API_URL="https://your-runtime-api.com"
|
||||
```
|
||||
|
||||
### Security-Enhanced Setup
|
||||
```bash
|
||||
export SECURITY_CONFIRMATION_MODE=true
|
||||
export SECURITY_SECURITY_ANALYZER="llm"
|
||||
export DEBUG_RUNTIME=true
|
||||
```
|
||||
|
||||
## Notes
|
||||
|
||||
1. **Boolean Values**: Environment variables expecting boolean values accept `true`/`false`, `1`/`0`, or `yes`/`no` (case-insensitive).
|
||||
|
||||
2. **List Values**: Lists should be provided as Python literal strings, e.g., `AGENT_DISABLED_MICROAGENTS='["microagent1", "microagent2"]'`.
|
||||
|
||||
3. **Dictionary Values**: Dictionaries should be provided as Python literal strings, e.g., `SANDBOX_RUNTIME_STARTUP_ENV_VARS='{"KEY": "value"}'`.
|
||||
|
||||
4. **Precedence**: Environment variables take precedence over TOML configuration files.
|
||||
|
||||
5. **Docker Usage**: When using Docker, pass environment variables with the `-e` flag:
|
||||
```bash
|
||||
docker run -e LLM_API_KEY="your-key" -e DEBUG=true openhands/openhands
|
||||
```
|
||||
|
||||
6. **Validation**: Invalid environment variable values will be logged as errors and fall back to defaults.
|
||||
@@ -89,7 +89,7 @@ If you would like to set things up more systematically, you can:
|
||||
1. **Search existing issues**: Check our [GitHub issues](https://github.com/All-Hands-AI/OpenHands/issues) to see if
|
||||
others have encountered the same problem.
|
||||
2. **Join our community**: Get help from other users and developers:
|
||||
- [Slack community](https://join.slack.com/t/openhands-ai/shared_invite/zt-3847of6xi-xuYJIPa6YIPg4ElbDWbtSA)
|
||||
- [Slack community](https://dub.sh/openhands)
|
||||
- [Discord server](https://discord.gg/ESHStjSjD4)
|
||||
3. **Check our troubleshooting guide**: Common issues and solutions are documented in
|
||||
[Troubleshooting](/usage/troubleshooting/troubleshooting).
|
||||
|
||||
@@ -113,7 +113,7 @@ The conversation history will be saved in `~/.openhands/sessions`.
|
||||
```bash
|
||||
docker run -it \
|
||||
--pull=always \
|
||||
-e SANDBOX_RUNTIME_CONTAINER_IMAGE=docker.all-hands.dev/all-hands-ai/runtime:0.55-nikolaik \
|
||||
-e SANDBOX_RUNTIME_CONTAINER_IMAGE=docker.all-hands.dev/all-hands-ai/runtime:0.56-nikolaik \
|
||||
-e SANDBOX_USER_ID=$(id -u) \
|
||||
-e SANDBOX_VOLUMES=$SANDBOX_VOLUMES \
|
||||
-e LLM_API_KEY=$LLM_API_KEY \
|
||||
@@ -122,7 +122,7 @@ docker run -it \
|
||||
-v ~/.openhands:/.openhands \
|
||||
--add-host host.docker.internal:host-gateway \
|
||||
--name openhands-app-$(date +%Y%m%d%H%M%S) \
|
||||
docker.all-hands.dev/all-hands-ai/openhands:0.55 \
|
||||
docker.all-hands.dev/all-hands-ai/openhands:0.56 \
|
||||
python -m openhands.cli.entry --override-cli-mode true
|
||||
```
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@ Set environment variables and run the Docker command:
|
||||
|
||||
```bash
|
||||
# Set required environment variables
|
||||
export SANDBOX_VOLUMES="/path/to/workspace" # See SANDBOX_VOLUMES docs for details
|
||||
export SANDBOX_VOLUMES="/path/to/workspace:/workspace:rw" # Format: host_path:container_path:mode
|
||||
export LLM_MODEL="anthropic/claude-sonnet-4-20250514"
|
||||
export LLM_API_KEY="your-api-key"
|
||||
export SANDBOX_SELECTED_REPO="owner/repo-name" # Optional: requires GITHUB_TOKEN
|
||||
@@ -61,7 +61,7 @@ export GITHUB_TOKEN="your-token" # Required for repository operations
|
||||
# Run OpenHands
|
||||
docker run -it \
|
||||
--pull=always \
|
||||
-e SANDBOX_RUNTIME_CONTAINER_IMAGE=docker.all-hands.dev/all-hands-ai/runtime:0.55-nikolaik \
|
||||
-e SANDBOX_RUNTIME_CONTAINER_IMAGE=docker.all-hands.dev/all-hands-ai/runtime:0.56-nikolaik \
|
||||
-e SANDBOX_USER_ID=$(id -u) \
|
||||
-e SANDBOX_VOLUMES=$SANDBOX_VOLUMES \
|
||||
-e LLM_API_KEY=$LLM_API_KEY \
|
||||
@@ -73,7 +73,7 @@ docker run -it \
|
||||
-v ~/.openhands:/.openhands \
|
||||
--add-host host.docker.internal:host-gateway \
|
||||
--name openhands-app-$(date +%Y%m%d%H%M%S) \
|
||||
docker.all-hands.dev/all-hands-ai/openhands:0.55 \
|
||||
docker.all-hands.dev/all-hands-ai/openhands:0.56 \
|
||||
python -m openhands.core.main -t "write a bash script that prints hi"
|
||||
```
|
||||
|
||||
|
||||
@@ -68,23 +68,23 @@ Download and install the LM Studio desktop app from [lmstudio.ai](https://lmstud
|
||||
1. Check [the installation guide](/usage/local-setup) and ensure all prerequisites are met before running OpenHands, then run:
|
||||
|
||||
```bash
|
||||
docker pull docker.all-hands.dev/all-hands-ai/runtime:0.55-nikolaik
|
||||
docker pull docker.all-hands.dev/all-hands-ai/runtime:0.56-nikolaik
|
||||
|
||||
docker run -it --rm --pull=always \
|
||||
-e SANDBOX_RUNTIME_CONTAINER_IMAGE=docker.all-hands.dev/all-hands-ai/runtime:0.55-nikolaik \
|
||||
-e SANDBOX_RUNTIME_CONTAINER_IMAGE=docker.all-hands.dev/all-hands-ai/runtime:0.56-nikolaik \
|
||||
-e LOG_ALL_EVENTS=true \
|
||||
-v /var/run/docker.sock:/var/run/docker.sock \
|
||||
-v ~/.openhands:/.openhands \
|
||||
-p 3000:3000 \
|
||||
--add-host host.docker.internal:host-gateway \
|
||||
--name openhands-app \
|
||||
docker.all-hands.dev/all-hands-ai/openhands:0.55
|
||||
docker.all-hands.dev/all-hands-ai/openhands:0.56
|
||||
```
|
||||
|
||||
2. Wait until the server is running (see log below):
|
||||
```
|
||||
Digest: sha256:e72f9baecb458aedb9afc2cd5bc935118d1868719e55d50da73190d3a85c674f
|
||||
Status: Image is up to date for docker.all-hands.dev/all-hands-ai/openhands:0.55
|
||||
Status: Image is up to date for docker.all-hands.dev/all-hands-ai/openhands:0.56
|
||||
Starting OpenHands...
|
||||
Running OpenHands as root
|
||||
14:22:13 - openhands:INFO: server_config.py:50 - Using config class None
|
||||
@@ -119,7 +119,7 @@ When started for the first time, OpenHands will prompt you to set up the LLM pro
|
||||
|
||||
That's it! You can now start using OpenHands with the local LLM server.
|
||||
|
||||
If you encounter any issues, let us know on [Slack](https://join.slack.com/t/openhands-ai/shared_invite/zt-3847of6xi-xuYJIPa6YIPg4ElbDWbtSA) or [Discord](https://discord.gg/ESHStjSjD4).
|
||||
If you encounter any issues, let us know on [Slack](https://dub.sh/openhands) or [Discord](https://discord.gg/ESHStjSjD4).
|
||||
|
||||
## Advanced: Alternative LLM Backends
|
||||
|
||||
|
||||
@@ -116,17 +116,17 @@ Note that you'll still need `uv` installed for the default MCP servers to work p
|
||||
<Accordion title="Docker Command (Click to expand)">
|
||||
|
||||
```bash
|
||||
docker pull docker.all-hands.dev/all-hands-ai/runtime:0.55-nikolaik
|
||||
docker pull docker.all-hands.dev/all-hands-ai/runtime:0.56-nikolaik
|
||||
|
||||
docker run -it --rm --pull=always \
|
||||
-e SANDBOX_RUNTIME_CONTAINER_IMAGE=docker.all-hands.dev/all-hands-ai/runtime:0.55-nikolaik \
|
||||
-e SANDBOX_RUNTIME_CONTAINER_IMAGE=docker.all-hands.dev/all-hands-ai/runtime:0.56-nikolaik \
|
||||
-e LOG_ALL_EVENTS=true \
|
||||
-v /var/run/docker.sock:/var/run/docker.sock \
|
||||
-v ~/.openhands:/.openhands \
|
||||
-p 3000:3000 \
|
||||
--add-host host.docker.internal:host-gateway \
|
||||
--name openhands-app \
|
||||
docker.all-hands.dev/all-hands-ai/openhands:0.55
|
||||
docker.all-hands.dev/all-hands-ai/openhands:0.56
|
||||
```
|
||||
|
||||
</Accordion>
|
||||
|
||||
26
enterprise/Dockerfile
Normal file
26
enterprise/Dockerfile
Normal file
@@ -0,0 +1,26 @@
|
||||
ARG OPENHANDS_VERSION=latest
|
||||
ARG BASE="ghcr.io/all-hands-ai/openhands"
|
||||
FROM ${BASE}:${OPENHANDS_VERSION}
|
||||
|
||||
# Datadog labels
|
||||
LABEL com.datadoghq.tags.service="deploy"
|
||||
LABEL com.datadoghq.tags.env="${DD_ENV}"
|
||||
|
||||
# Install Node.js v20+ and npm (which includes npx)
|
||||
RUN apt-get update && \
|
||||
apt-get install -y curl && \
|
||||
curl -fsSL https://deb.nodesource.com/setup_20.x | bash - && \
|
||||
apt-get install -y nodejs && \
|
||||
apt-get install -y jq gettext && \
|
||||
apt-get clean
|
||||
|
||||
RUN pip install alembic psycopg2-binary cloud-sql-python-connector pg8000 gspread stripe python-keycloak asyncpg sqlalchemy[asyncio] resend tenacity slack-sdk ddtrace posthog "limits==5.2.0" coredis prometheus-client shap scikit-learn pandas numpy
|
||||
|
||||
WORKDIR /app
|
||||
COPY enterprise .
|
||||
|
||||
RUN chown -R openhands:openhands /app && chmod -R 770 /app
|
||||
USER openhands
|
||||
|
||||
# Command will be overridden by Kubernetes deployment template
|
||||
CMD ["uvicorn", "saas_server:app", "--host", "0.0.0.0", "--port", "3000"]
|
||||
42
enterprise/Makefile
Normal file
42
enterprise/Makefile
Normal file
@@ -0,0 +1,42 @@
|
||||
BACKEND_HOST ?= "127.0.0.1"
|
||||
BACKEND_PORT = 3000
|
||||
BACKEND_HOST_PORT = "$(BACKEND_HOST):$(BACKEND_PORT)"
|
||||
FRONTEND_PORT = 3001
|
||||
OPENHANDS_PATH ?= "../../OpenHands"
|
||||
OPENHANDS := $(OPENHANDS_PATH)
|
||||
OPENHANDS_FRONTEND_PATH = $(OPENHANDS)/frontend/build
|
||||
|
||||
# ANSI color codes
|
||||
GREEN=$(shell tput -Txterm setaf 2)
|
||||
YELLOW=$(shell tput -Txterm setaf 3)
|
||||
RED=$(shell tput -Txterm setaf 1)
|
||||
BLUE=$(shell tput -Txterm setaf 6)
|
||||
RESET=$(shell tput -Txterm sgr0)
|
||||
|
||||
build:
|
||||
@poetry install
|
||||
@cd $(OPENHANDS) && $(MAKE) build
|
||||
|
||||
|
||||
_run_setup:
|
||||
@echo "$(YELLOW)Starting backend server...$(RESET)"
|
||||
@cd app && FRONTEND_DIRECTORY=$(OPENHANDS_FRONTEND_PATH) poetry run uvicorn saas_server:app --host $(BACKEND_HOST) --port $(BACKEND_PORT) &
|
||||
@echo "$(YELLOW)Waiting for the backend to start...$(RESET)"
|
||||
@until nc -z localhost $(BACKEND_PORT); do sleep 0.1; done
|
||||
@echo "$(GREEN)Backend started successfully.$(RESET)"
|
||||
|
||||
run:
|
||||
@echo "$(YELLOW)Running the app...$(RESET)"
|
||||
@$(MAKE) -s _run_setup
|
||||
@cd $(OPENHANDS) && $(MAKE) -s start-frontend
|
||||
@echo "$(GREEN)Application started successfully.$(RESET)"
|
||||
|
||||
# Start backend
|
||||
start-backend:
|
||||
@echo "$(YELLOW)Starting backend...$(RESET)"
|
||||
@echo "$(OPENHANDS_FRONTEND_PATH)"
|
||||
@cd app && FRONTEND_DIRECTORY=$(OPENHANDS_FRONTEND_PATH) poetry run uvicorn saas_server:app --host $(BACKEND_HOST) --port $(BACKEND_PORT) --reload-dir $(OPENHANDS_PATH) --reload --reload-dir ./ --reload-exclude "./workspace"
|
||||
|
||||
|
||||
lint:
|
||||
@poetry run pre-commit run --all-files --show-diff-on-failure --config ./dev_config/python/.pre-commit-config.yaml
|
||||
56
enterprise/README.md
Normal file
56
enterprise/README.md
Normal file
@@ -0,0 +1,56 @@
|
||||
# OpenHands Enterprise Server
|
||||
> [!WARNING]
|
||||
> This software is licensed under the [Polyform Free Trial License](./LICENSE). This is **NOT** an open source license. Usage is limited to 30 days per calendar year without a commercial license. If you would like to use it beyond 30 days, please [contact us](https://www.all-hands.dev/contact).
|
||||
|
||||
> [!WARNING]
|
||||
> This is a work in progress and may contain bugs, incomplete features, or breaking changes.
|
||||
|
||||
This directory contains the enterprise server used by [OpenHands Cloud](https://github.com/All-Hands-AI/OpenHands-Cloud/). The official, public version of OpenHands Cloud is available at
|
||||
[app.all-hands.dev](https://app.all-hands.dev).
|
||||
|
||||
You may also want to check out the MIT-licensed [OpenHands](https://github.com/All-Hands-AI/OpenHands)
|
||||
|
||||
## Extension of OpenHands (OSS)
|
||||
|
||||
The code in `/enterprise` directory builds on top of open source (OSS) code, extending its functionality. The enterprise code is entangled with the OSS code in two ways
|
||||
|
||||
- Enterprise stacks on top of OSS. For example, the middleware in enterprise is stacked right on top of the middlewares in OSS. In `SAAS`, the middleware from BOTH repos will be present and running (which can sometimes cause conflicts)
|
||||
|
||||
- Enterprise overrides the implementation in OSS (only one is present at a time). For example, the server config SaasServerConfig which overrides [`ServerConfig`](https://github.com/All-Hands-AI/OpenHands/blob/main/openhands/server/config/server_config.py#L8) on OSS. This is done through dynamic imports ([see here](https://github.com/All-Hands-AI/OpenHands/blob/main/openhands/server/config/server_config.py#L37-#L45))
|
||||
|
||||
Key areas that change on `SAAS` are
|
||||
|
||||
- Authentication
|
||||
- User settings
|
||||
- etc
|
||||
|
||||
### Authentication
|
||||
|
||||
| Aspect | OSS | Enterprise |
|
||||
| ------------------------- | ------------------------------------------------------ | ----------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| **Authentication Method** | User adds a personal access token (PAT) through the UI | User performs OAuth through the UI. The Github app provides a short-lived access token and refresh token |
|
||||
| **Token Storage** | PAT is stored in **Settings** | Token is stored in **GithubTokenManager** (a file store in our backend) |
|
||||
| **Authenticated status** | We simply check if token exists in `Settings` | We issue a signed cookie with `github_user_id` during oauth, so subsequent requests with the cookie can be considered authenticated |
|
||||
|
||||
Note that in the future, authentication will happen via keycloak. All modifications for authentication will happen in enterprise.
|
||||
|
||||
### GitHub Service
|
||||
|
||||
The github service is responsible for interacting with Github APIs. As a consequence, it uses the user's token and refreshes it if need be
|
||||
|
||||
| Aspect | OSS | Enterprise |
|
||||
| ------------------------- | -------------------------------------- | ---------------------------------------------- |
|
||||
| **Class used** | `GitHubService` | `SaaSGitHubService` |
|
||||
| **Token used** | User's PAT fetched from `Settings` | User's token fetched from `GitHubTokenManager` |
|
||||
| **Refresh functionality** | **N/A**; user provides PAT for the app | Uses the `GitHubTokenManager` to refresh |
|
||||
|
||||
NOTE: in the future we will simply replace the `GithubTokenManager` with keycloak. The `SaaSGithubService` should interact with keycloack instead.
|
||||
|
||||
# Areas that are BRITTLE!
|
||||
|
||||
## User ID vs User Token
|
||||
|
||||
- On OSS, the entire APP revolves around the Github token the user sets. `openhands/server` uses `request.state.github_token` for the entire app
|
||||
- On Enterprise, the entire APP resolves around the Github User ID. This is because the cookie sets it, so `openhands/server` AND `enterprise/server` depend on it and completly ignore `request.state.github_token` (token is fetched from `GithubTokenManager` instead)
|
||||
|
||||
Note that introducing Github User ID on OSS, for instance, will cause large breakages.
|
||||
1
enterprise/__init__.py
Normal file
1
enterprise/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# App package for OpenHands
|
||||
79
enterprise/alembic.ini
Normal file
79
enterprise/alembic.ini
Normal file
@@ -0,0 +1,79 @@
|
||||
# A generic, single database configuration.
|
||||
|
||||
[alembic]
|
||||
# path to migration scripts
|
||||
script_location = migrations
|
||||
|
||||
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
|
||||
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
|
||||
|
||||
# sys.path path, will be prepended to sys.path if present.
|
||||
# defaults to the current working directory.
|
||||
prepend_sys_path = .
|
||||
|
||||
# timezone to use when rendering the date within the migration file
|
||||
# as well as the filename.
|
||||
# If specified, requires the python>=3.9 or backports.zoneinfo library.
|
||||
# timezone =
|
||||
|
||||
# max length of characters to apply to the "slug" field
|
||||
# truncate_slug_length = 40
|
||||
|
||||
# set to 'true' to run the environment during
|
||||
# the 'revision' command, regardless of autogenerate
|
||||
# revision_environment = false
|
||||
|
||||
# set to 'true' to allow .pyc and .pyo files without
|
||||
# a source .py file to be detected as revisions in the
|
||||
# versions/ directory
|
||||
# sourceless = false
|
||||
|
||||
# version path separator; As mentioned above, this is the character used to split
|
||||
# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
|
||||
version_path_separator = os # Use os.pathsep. Default configuration used for new projects.
|
||||
|
||||
# the output encoding used when revision files
|
||||
# are written from script.py.mako
|
||||
# output_encoding = utf-8
|
||||
|
||||
sqlalchemy.url = driver://user:pass@localhost/dbname
|
||||
|
||||
[post_write_hooks]
|
||||
# post_write_hooks defines scripts or Python functions that are run
|
||||
# on newly generated revision scripts. See the documentation for further
|
||||
# detail and examples
|
||||
|
||||
# Logging configuration
|
||||
[loggers]
|
||||
keys = root,sqlalchemy,alembic
|
||||
|
||||
[handlers]
|
||||
keys = console
|
||||
|
||||
[formatters]
|
||||
keys = generic
|
||||
|
||||
[logger_root]
|
||||
level = DEBUG
|
||||
handlers = console
|
||||
qualname =
|
||||
|
||||
[logger_sqlalchemy]
|
||||
level = DEBUG
|
||||
handlers =
|
||||
qualname = sqlalchemy.engine
|
||||
|
||||
[logger_alembic]
|
||||
level = DEBUG
|
||||
handlers =
|
||||
qualname = alembic
|
||||
|
||||
[handler_console]
|
||||
class = StreamHandler
|
||||
args = (sys.stderr,)
|
||||
level = NOTSET
|
||||
formatter = generic
|
||||
|
||||
[formatter_generic]
|
||||
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||
datefmt = %H:%M:%S
|
||||
2770
enterprise/allhands-realm-github-provider.json.tmpl
Normal file
2770
enterprise/allhands-realm-github-provider.json.tmpl
Normal file
File diff suppressed because it is too large
Load Diff
57
enterprise/dev_config/python/.pre-commit-config.yaml
Normal file
57
enterprise/dev_config/python/.pre-commit-config.yaml
Normal file
@@ -0,0 +1,57 @@
|
||||
repos:
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v4.5.0
|
||||
hooks:
|
||||
- id: trailing-whitespace
|
||||
exclude: docs/modules/python
|
||||
files: ^enterprise/
|
||||
- id: end-of-file-fixer
|
||||
exclude: docs/modules/python
|
||||
files: ^enterprise/
|
||||
- id: check-yaml
|
||||
files: ^enterprise/
|
||||
- id: debug-statements
|
||||
files: ^enterprise/
|
||||
- repo: https://github.com/abravalheri/validate-pyproject
|
||||
rev: v0.16
|
||||
hooks:
|
||||
- id: validate-pyproject
|
||||
types: [toml]
|
||||
files: ^enterprise/pyproject\.toml$
|
||||
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
# Ruff version.
|
||||
rev: v0.4.1
|
||||
hooks:
|
||||
# Run the linter.
|
||||
- id: ruff
|
||||
entry: ruff check --config enterprise/dev_config/python/ruff.toml
|
||||
types_or: [python, pyi, jupyter]
|
||||
args: [--fix]
|
||||
files: ^enterprise/
|
||||
# Run the formatter.
|
||||
- id: ruff-format
|
||||
entry: ruff format --config enterprise/dev_config/python/ruff.toml
|
||||
types_or: [python, pyi, jupyter]
|
||||
files: ^enterprise/
|
||||
|
||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||
rev: v1.9.0
|
||||
hooks:
|
||||
- id: mypy
|
||||
additional_dependencies:
|
||||
- types-requests
|
||||
- types-setuptools
|
||||
- types-pyyaml
|
||||
- types-toml
|
||||
- types-redis
|
||||
- lxml
|
||||
# OpenHands package in repo root
|
||||
- ./
|
||||
- stripe==11.5.0
|
||||
- pygithub==2.6.1
|
||||
# To see gaps add `--html-report mypy-report/`
|
||||
entry: mypy --config-file enterprise/dev_config/python/mypy.ini enterprise/
|
||||
always_run: true
|
||||
pass_filenames: false
|
||||
files: ^enterprise/
|
||||
17
enterprise/dev_config/python/mypy.ini
Normal file
17
enterprise/dev_config/python/mypy.ini
Normal file
@@ -0,0 +1,17 @@
|
||||
[mypy]
|
||||
warn_unused_configs = True
|
||||
ignore_missing_imports = True
|
||||
check_untyped_defs = True
|
||||
explicit_package_bases = True
|
||||
warn_unreachable = True
|
||||
warn_redundant_casts = True
|
||||
no_implicit_optional = True
|
||||
strict_optional = True
|
||||
disable_error_code = type-abstract
|
||||
exclude = (^enterprise/migrations/.*)
|
||||
|
||||
[mypy-enterprise.tests.unit.test_auth_routes.*]
|
||||
disable_error_code = union-attr
|
||||
|
||||
[mypy-enterprise.sync.install_gitlab_webhooks.*]
|
||||
disable_error_code = redundant-cast
|
||||
31
enterprise/dev_config/python/ruff.toml
Normal file
31
enterprise/dev_config/python/ruff.toml
Normal file
@@ -0,0 +1,31 @@
|
||||
[lint]
|
||||
select = [
|
||||
"E",
|
||||
"W",
|
||||
"F",
|
||||
"I",
|
||||
"Q",
|
||||
"B",
|
||||
]
|
||||
|
||||
ignore = [
|
||||
"E501",
|
||||
"B003",
|
||||
"B007",
|
||||
"B008", # Allow function calls in argument defaults (FastAPI Query pattern)
|
||||
"B009",
|
||||
"B010",
|
||||
"B904",
|
||||
"B018",
|
||||
]
|
||||
|
||||
exclude = [
|
||||
"app/migrations/*"
|
||||
]
|
||||
|
||||
[lint.flake8-quotes]
|
||||
docstring-quotes = "double"
|
||||
inline-quotes = "single"
|
||||
|
||||
[format]
|
||||
quote-style = "single"
|
||||
0
enterprise/experiments/__init__.py
Normal file
0
enterprise/experiments/__init__.py
Normal file
47
enterprise/experiments/constants.py
Normal file
47
enterprise/experiments/constants.py
Normal file
@@ -0,0 +1,47 @@
|
||||
import os
|
||||
|
||||
import posthog
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
# Initialize PostHog
|
||||
posthog.api_key = os.environ.get('POSTHOG_CLIENT_KEY', 'phc_placeholder')
|
||||
posthog.host = os.environ.get('POSTHOG_HOST', 'https://us.i.posthog.com')
|
||||
|
||||
# Log PostHog configuration with masked API key for security
|
||||
api_key = posthog.api_key
|
||||
if api_key and len(api_key) > 8:
|
||||
masked_key = f'{api_key[:4]}...{api_key[-4:]}'
|
||||
else:
|
||||
masked_key = 'not_set_or_too_short'
|
||||
logger.info('posthog_configuration', extra={'posthog_api_key_masked': masked_key})
|
||||
|
||||
# Global toggle for the experiment manager
|
||||
ENABLE_EXPERIMENT_MANAGER = (
|
||||
os.environ.get('ENABLE_EXPERIMENT_MANAGER', 'false').lower() == 'true'
|
||||
)
|
||||
|
||||
# Get the current experiment type from environment variable
|
||||
# If None, no experiment is running
|
||||
EXPERIMENT_LITELLM_DEFAULT_MODEL_EXPERIMENT = os.environ.get(
|
||||
'EXPERIMENT_LITELLM_DEFAULT_MODEL_EXPERIMENT', ''
|
||||
)
|
||||
# System prompt experiment toggle
|
||||
EXPERIMENT_SYSTEM_PROMPT_EXPERIMENT = os.environ.get(
|
||||
'EXPERIMENT_SYSTEM_PROMPT_EXPERIMENT', ''
|
||||
)
|
||||
|
||||
EXPERIMENT_CLAUDE4_VS_GPT5 = os.environ.get('EXPERIMENT_CLAUDE4_VS_GPT5', '')
|
||||
|
||||
EXPERIMENT_CONDENSER_MAX_STEP = os.environ.get('EXPERIMENT_CONDENSER_MAX_STEP', '')
|
||||
|
||||
logger.info(
|
||||
'experiment_manager:run_conversation_variant_test:experiment_config',
|
||||
extra={
|
||||
'enable_experiment_manager': ENABLE_EXPERIMENT_MANAGER,
|
||||
'experiment_litellm_default_model_experiment': EXPERIMENT_LITELLM_DEFAULT_MODEL_EXPERIMENT,
|
||||
'experiment_system_prompt_experiment': EXPERIMENT_SYSTEM_PROMPT_EXPERIMENT,
|
||||
'experiment_claude4_vs_gpt5_experiment': EXPERIMENT_CLAUDE4_VS_GPT5,
|
||||
'experiment_condenser_max_step': EXPERIMENT_CONDENSER_MAX_STEP,
|
||||
},
|
||||
)
|
||||
93
enterprise/experiments/experiment_manager.py
Normal file
93
enterprise/experiments/experiment_manager.py
Normal file
@@ -0,0 +1,93 @@
|
||||
from experiments.constants import (
|
||||
ENABLE_EXPERIMENT_MANAGER,
|
||||
)
|
||||
from experiments.experiment_versions import (
|
||||
handle_claude4_vs_gpt5_experiment,
|
||||
handle_condenser_max_step_experiment,
|
||||
handle_system_prompt_experiment,
|
||||
)
|
||||
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.experiments.experiment_manager import ExperimentManager
|
||||
from openhands.server.session.conversation_init_data import ConversationInitData
|
||||
|
||||
|
||||
class SaaSExperimentManager(ExperimentManager):
|
||||
@staticmethod
|
||||
def run_conversation_variant_test(
|
||||
user_id, conversation_id, conversation_settings
|
||||
) -> ConversationInitData:
|
||||
"""
|
||||
Run conversation variant test and potentially modify the conversation settings
|
||||
based on the PostHog feature flags.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
conversation_id: The conversation ID
|
||||
conversation_settings: The conversation settings that may include convo_id and llm_model
|
||||
|
||||
Returns:
|
||||
The modified conversation settings
|
||||
"""
|
||||
logger.debug(
|
||||
'experiment_manager:run_conversation_variant_test:started',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
|
||||
# Skip all experiment processing if the experiment manager is disabled
|
||||
if not ENABLE_EXPERIMENT_MANAGER:
|
||||
logger.info(
|
||||
'experiment_manager:run_conversation_variant_test:skipped',
|
||||
extra={'reason': 'experiment_manager_disabled'},
|
||||
)
|
||||
return conversation_settings
|
||||
|
||||
# Apply conversation-scoped experiments
|
||||
conversation_settings = handle_claude4_vs_gpt5_experiment(
|
||||
user_id, conversation_id, conversation_settings
|
||||
)
|
||||
conversation_settings = handle_condenser_max_step_experiment(
|
||||
user_id, conversation_id, conversation_settings
|
||||
)
|
||||
|
||||
return conversation_settings
|
||||
|
||||
@staticmethod
|
||||
def run_config_variant_test(
|
||||
user_id: str | None, conversation_id: str, config: OpenHandsConfig
|
||||
) -> OpenHandsConfig:
|
||||
"""
|
||||
Run agent config variant test and potentially modify the OpenHands config
|
||||
based on the current experiment type and PostHog feature flags.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
conversation_id: The conversation ID
|
||||
config: The OpenHands configuration
|
||||
|
||||
Returns:
|
||||
The modified OpenHands configuration
|
||||
"""
|
||||
logger.info(
|
||||
'experiment_manager:run_config_variant_test:started',
|
||||
extra={'user_id': user_id},
|
||||
)
|
||||
|
||||
# Skip all experiment processing if the experiment manager is disabled
|
||||
if not ENABLE_EXPERIMENT_MANAGER:
|
||||
logger.info(
|
||||
'experiment_manager:run_config_variant_test:skipped',
|
||||
extra={'reason': 'experiment_manager_disabled'},
|
||||
)
|
||||
return config
|
||||
|
||||
# Pass the entire OpenHands config to the system prompt experiment
|
||||
# Let the experiment handler directly modify the config as needed
|
||||
modified_config = handle_system_prompt_experiment(
|
||||
user_id, conversation_id, config
|
||||
)
|
||||
|
||||
# Condenser max step experiment is applied via conversation variant test,
|
||||
# not config variant test. Return modified config from system prompt only.
|
||||
return modified_config
|
||||
@@ -0,0 +1,107 @@
|
||||
"""
|
||||
LiteLLM model experiment handler.
|
||||
|
||||
This module contains the handler for the LiteLLM model experiment.
|
||||
"""
|
||||
|
||||
import posthog
|
||||
from experiments.constants import EXPERIMENT_LITELLM_DEFAULT_MODEL_EXPERIMENT
|
||||
from server.constants import (
|
||||
IS_FEATURE_ENV,
|
||||
build_litellm_proxy_model_path,
|
||||
get_default_litellm_model,
|
||||
)
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
|
||||
def handle_litellm_default_model_experiment(
|
||||
user_id, conversation_id, conversation_settings
|
||||
):
|
||||
"""
|
||||
Handle the LiteLLM model experiment.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
conversation_id: The conversation ID
|
||||
conversation_settings: The conversation settings
|
||||
|
||||
Returns:
|
||||
Modified conversation settings
|
||||
"""
|
||||
# No-op if the specific experiment is not enabled
|
||||
if not EXPERIMENT_LITELLM_DEFAULT_MODEL_EXPERIMENT:
|
||||
logger.info(
|
||||
'experiment_manager:ab_testing:skipped',
|
||||
extra={
|
||||
'convo_id': conversation_id,
|
||||
'reason': 'experiment_not_enabled',
|
||||
'experiment': EXPERIMENT_LITELLM_DEFAULT_MODEL_EXPERIMENT,
|
||||
},
|
||||
)
|
||||
return conversation_settings
|
||||
|
||||
# Use experiment name as the flag key
|
||||
try:
|
||||
enabled_variant = posthog.get_feature_flag(
|
||||
EXPERIMENT_LITELLM_DEFAULT_MODEL_EXPERIMENT, conversation_id
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'experiment_manager:get_feature_flag:failed',
|
||||
extra={
|
||||
'convo_id': conversation_id,
|
||||
'experiment': EXPERIMENT_LITELLM_DEFAULT_MODEL_EXPERIMENT,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
return conversation_settings
|
||||
|
||||
# Log the experiment event
|
||||
# If this is a feature environment, add "FEATURE_" prefix to user_id for PostHog
|
||||
posthog_user_id = f'FEATURE_{user_id}' if IS_FEATURE_ENV else user_id
|
||||
|
||||
try:
|
||||
posthog.capture(
|
||||
distinct_id=posthog_user_id,
|
||||
event='model_set',
|
||||
properties={
|
||||
'conversation_id': conversation_id,
|
||||
'variant': enabled_variant,
|
||||
'original_user_id': user_id,
|
||||
'is_feature_env': IS_FEATURE_ENV,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'experiment_manager:posthog_capture:failed',
|
||||
extra={
|
||||
'convo_id': conversation_id,
|
||||
'experiment': EXPERIMENT_LITELLM_DEFAULT_MODEL_EXPERIMENT,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
# Continue execution as this is not critical
|
||||
|
||||
logger.info(
|
||||
'posthog_capture',
|
||||
extra={
|
||||
'event': 'model_set',
|
||||
'posthog_user_id': posthog_user_id,
|
||||
'is_feature_env': IS_FEATURE_ENV,
|
||||
'conversation_id': conversation_id,
|
||||
'variant': enabled_variant,
|
||||
},
|
||||
)
|
||||
|
||||
# Set the model based on the feature flag variant
|
||||
if enabled_variant == 'claude37':
|
||||
# Use the shared utility to construct the LiteLLM proxy model path
|
||||
model = build_litellm_proxy_model_path('claude-3-7-sonnet-20250219')
|
||||
# Update the conversation settings with the selected model
|
||||
conversation_settings.llm_model = model
|
||||
else:
|
||||
# Update the conversation settings with the default model for the current version
|
||||
conversation_settings.llm_model = get_default_litellm_model()
|
||||
|
||||
return conversation_settings
|
||||
@@ -0,0 +1,181 @@
|
||||
"""
|
||||
System prompt experiment handler.
|
||||
|
||||
This module contains the handler for the system prompt experiment that uses
|
||||
the PostHog variant as the system prompt filename.
|
||||
"""
|
||||
|
||||
import copy
|
||||
|
||||
import posthog
|
||||
from experiments.constants import EXPERIMENT_SYSTEM_PROMPT_EXPERIMENT
|
||||
from server.constants import IS_FEATURE_ENV
|
||||
from storage.experiment_assignment_store import ExperimentAssignmentStore
|
||||
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
|
||||
|
||||
def _get_system_prompt_variant(user_id, conversation_id):
|
||||
"""
|
||||
Get the system prompt variant for the experiment.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
conversation_id: The conversation ID
|
||||
|
||||
Returns:
|
||||
str or None: The PostHog variant name or None if experiment is not enabled or error occurs
|
||||
"""
|
||||
# No-op if the specific experiment is not enabled
|
||||
if not EXPERIMENT_SYSTEM_PROMPT_EXPERIMENT:
|
||||
logger.info(
|
||||
'experiment_manager_002:ab_testing:skipped',
|
||||
extra={
|
||||
'convo_id': conversation_id,
|
||||
'reason': 'experiment_not_enabled',
|
||||
'experiment': EXPERIMENT_SYSTEM_PROMPT_EXPERIMENT,
|
||||
},
|
||||
)
|
||||
return None
|
||||
|
||||
# Use experiment name as the flag key
|
||||
try:
|
||||
enabled_variant = posthog.get_feature_flag(
|
||||
EXPERIMENT_SYSTEM_PROMPT_EXPERIMENT, conversation_id
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'experiment_manager:get_feature_flag:failed',
|
||||
extra={
|
||||
'convo_id': conversation_id,
|
||||
'experiment': EXPERIMENT_SYSTEM_PROMPT_EXPERIMENT,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
return None
|
||||
|
||||
# Store the experiment assignment in the database
|
||||
try:
|
||||
experiment_store = ExperimentAssignmentStore()
|
||||
experiment_store.update_experiment_variant(
|
||||
conversation_id=conversation_id,
|
||||
experiment_name='system_prompt_experiment',
|
||||
variant=enabled_variant,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'experiment_manager:store_assignment:failed',
|
||||
extra={
|
||||
'convo_id': conversation_id,
|
||||
'experiment': EXPERIMENT_SYSTEM_PROMPT_EXPERIMENT,
|
||||
'variant': enabled_variant,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
# Fail the experiment if we cannot track the splits - results would not be explainable
|
||||
return None
|
||||
|
||||
# Log the experiment event
|
||||
# If this is a feature environment, add "FEATURE_" prefix to user_id for PostHog
|
||||
posthog_user_id = f'FEATURE_{user_id}' if IS_FEATURE_ENV else user_id
|
||||
|
||||
try:
|
||||
posthog.capture(
|
||||
distinct_id=posthog_user_id,
|
||||
event='system_prompt_set',
|
||||
properties={
|
||||
'conversation_id': conversation_id,
|
||||
'variant': enabled_variant,
|
||||
'original_user_id': user_id,
|
||||
'is_feature_env': IS_FEATURE_ENV,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'experiment_manager:posthog_capture:failed',
|
||||
extra={
|
||||
'convo_id': conversation_id,
|
||||
'experiment': EXPERIMENT_SYSTEM_PROMPT_EXPERIMENT,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
# Continue execution as this is not critical
|
||||
|
||||
logger.info(
|
||||
'posthog_capture',
|
||||
extra={
|
||||
'event': 'system_prompt_set',
|
||||
'posthog_user_id': posthog_user_id,
|
||||
'is_feature_env': IS_FEATURE_ENV,
|
||||
'conversation_id': conversation_id,
|
||||
'variant': enabled_variant,
|
||||
},
|
||||
)
|
||||
|
||||
return enabled_variant
|
||||
|
||||
|
||||
def handle_system_prompt_experiment(
|
||||
user_id, conversation_id, config: OpenHandsConfig
|
||||
) -> OpenHandsConfig:
|
||||
"""
|
||||
Handle the system prompt experiment for OpenHands config.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
conversation_id: The conversation ID
|
||||
config: The OpenHands configuration
|
||||
|
||||
Returns:
|
||||
Modified OpenHands configuration
|
||||
"""
|
||||
enabled_variant = _get_system_prompt_variant(user_id, conversation_id)
|
||||
|
||||
# If variant is None, experiment is not enabled or there was an error
|
||||
if enabled_variant is None:
|
||||
return config
|
||||
|
||||
# Deep copy the config to avoid modifying the original
|
||||
modified_config = copy.deepcopy(config)
|
||||
|
||||
# Set the system prompt filename based on the variant
|
||||
if enabled_variant == 'control':
|
||||
# Use the long-horizon system prompt for the control variant
|
||||
agent_config = modified_config.get_agent_config(modified_config.default_agent)
|
||||
agent_config.system_prompt_filename = 'system_prompt_long_horizon.j2'
|
||||
agent_config.enable_plan_mode = True
|
||||
elif enabled_variant == 'interactive':
|
||||
modified_config.get_agent_config(
|
||||
modified_config.default_agent
|
||||
).system_prompt_filename = 'system_prompt_interactive.j2'
|
||||
elif enabled_variant == 'no_tools':
|
||||
modified_config.get_agent_config(
|
||||
modified_config.default_agent
|
||||
).system_prompt_filename = 'system_prompt.j2'
|
||||
else:
|
||||
logger.error(
|
||||
'system_prompt_experiment:unknown_variant',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'convo_id': conversation_id,
|
||||
'variant': enabled_variant,
|
||||
'reason': 'no explicit mapping; returning original config',
|
||||
},
|
||||
)
|
||||
return config
|
||||
|
||||
# Log which prompt is being used
|
||||
logger.info(
|
||||
'system_prompt_experiment:prompt_selected',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'convo_id': conversation_id,
|
||||
'system_prompt_filename': modified_config.get_agent_config(
|
||||
modified_config.default_agent
|
||||
).system_prompt_filename,
|
||||
'variant': enabled_variant,
|
||||
},
|
||||
)
|
||||
|
||||
return modified_config
|
||||
@@ -0,0 +1,137 @@
|
||||
"""
|
||||
LiteLLM model experiment handler.
|
||||
|
||||
This module contains the handler for the LiteLLM model experiment.
|
||||
"""
|
||||
|
||||
import posthog
|
||||
from experiments.constants import EXPERIMENT_CLAUDE4_VS_GPT5
|
||||
from server.constants import (
|
||||
IS_FEATURE_ENV,
|
||||
build_litellm_proxy_model_path,
|
||||
get_default_litellm_model,
|
||||
)
|
||||
from storage.experiment_assignment_store import ExperimentAssignmentStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.server.session.conversation_init_data import ConversationInitData
|
||||
|
||||
|
||||
def _get_model_variant(user_id: str | None, conversation_id: str) -> str | None:
|
||||
if not EXPERIMENT_CLAUDE4_VS_GPT5:
|
||||
logger.info(
|
||||
'experiment_manager:ab_testing:skipped',
|
||||
extra={
|
||||
'convo_id': conversation_id,
|
||||
'reason': 'experiment_not_enabled',
|
||||
'experiment': EXPERIMENT_CLAUDE4_VS_GPT5,
|
||||
},
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
enabled_variant = posthog.get_feature_flag(
|
||||
EXPERIMENT_CLAUDE4_VS_GPT5, conversation_id
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'experiment_manager:get_feature_flag:failed',
|
||||
extra={
|
||||
'convo_id': conversation_id,
|
||||
'experiment': EXPERIMENT_CLAUDE4_VS_GPT5,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
return None
|
||||
|
||||
# Store the experiment assignment in the database
|
||||
try:
|
||||
experiment_store = ExperimentAssignmentStore()
|
||||
experiment_store.update_experiment_variant(
|
||||
conversation_id=conversation_id,
|
||||
experiment_name='claude4_vs_gpt5_experiment',
|
||||
variant=enabled_variant,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'experiment_manager:store_assignment:failed',
|
||||
extra={
|
||||
'convo_id': conversation_id,
|
||||
'experiment': EXPERIMENT_CLAUDE4_VS_GPT5,
|
||||
'variant': enabled_variant,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
# Fail the experiment if we cannot track the splits - results would not be explainable
|
||||
return None
|
||||
|
||||
# Log the experiment event
|
||||
# If this is a feature environment, add "FEATURE_" prefix to user_id for PostHog
|
||||
posthog_user_id = f'FEATURE_{user_id}' if IS_FEATURE_ENV else user_id
|
||||
|
||||
try:
|
||||
posthog.capture(
|
||||
distinct_id=posthog_user_id,
|
||||
event='claude4_or_gpt5_set',
|
||||
properties={
|
||||
'conversation_id': conversation_id,
|
||||
'variant': enabled_variant,
|
||||
'original_user_id': user_id,
|
||||
'is_feature_env': IS_FEATURE_ENV,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'experiment_manager:posthog_capture:failed',
|
||||
extra={
|
||||
'convo_id': conversation_id,
|
||||
'experiment': EXPERIMENT_CLAUDE4_VS_GPT5,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
# Continue execution as this is not critical
|
||||
|
||||
logger.info(
|
||||
'posthog_capture',
|
||||
extra={
|
||||
'event': 'claude4_or_gpt5_set',
|
||||
'posthog_user_id': posthog_user_id,
|
||||
'is_feature_env': IS_FEATURE_ENV,
|
||||
'conversation_id': conversation_id,
|
||||
'variant': enabled_variant,
|
||||
},
|
||||
)
|
||||
|
||||
return enabled_variant
|
||||
|
||||
|
||||
def handle_claude4_vs_gpt5_experiment(
|
||||
user_id: str | None,
|
||||
conversation_id: str,
|
||||
conversation_settings: ConversationInitData,
|
||||
) -> ConversationInitData:
|
||||
"""
|
||||
Handle the LiteLLM model experiment.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
conversation_id: The conversation ID
|
||||
conversation_settings: The conversation settings
|
||||
|
||||
Returns:
|
||||
Modified conversation settings
|
||||
"""
|
||||
|
||||
enabled_variant = _get_model_variant(user_id, conversation_id)
|
||||
|
||||
if not enabled_variant:
|
||||
return conversation_settings
|
||||
|
||||
# Set the model based on the feature flag variant
|
||||
if enabled_variant == 'gpt5':
|
||||
model = build_litellm_proxy_model_path('gpt-5-2025-08-07')
|
||||
conversation_settings.llm_model = model
|
||||
else:
|
||||
conversation_settings.llm_model = get_default_litellm_model()
|
||||
|
||||
return conversation_settings
|
||||
@@ -0,0 +1,192 @@
|
||||
"""
|
||||
Condenser max step experiment handler.
|
||||
|
||||
This module contains the handler for the condenser max step experiment that tests
|
||||
different max_size values for the condenser configuration.
|
||||
"""
|
||||
|
||||
import posthog
|
||||
from experiments.constants import EXPERIMENT_CONDENSER_MAX_STEP
|
||||
from server.constants import IS_FEATURE_ENV
|
||||
from storage.experiment_assignment_store import ExperimentAssignmentStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.server.session.conversation_init_data import ConversationInitData
|
||||
|
||||
|
||||
def _get_condenser_max_step_variant(user_id, conversation_id):
|
||||
"""
|
||||
Get the condenser max step variant for the experiment.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
conversation_id: The conversation ID
|
||||
|
||||
Returns:
|
||||
str or None: The PostHog variant name or None if experiment is not enabled or error occurs
|
||||
"""
|
||||
# No-op if the specific experiment is not enabled
|
||||
if not EXPERIMENT_CONDENSER_MAX_STEP:
|
||||
logger.info(
|
||||
'experiment_manager_004:ab_testing:skipped',
|
||||
extra={
|
||||
'convo_id': conversation_id,
|
||||
'reason': 'experiment_not_enabled',
|
||||
'experiment': EXPERIMENT_CONDENSER_MAX_STEP,
|
||||
},
|
||||
)
|
||||
return None
|
||||
|
||||
# Use experiment name as the flag key
|
||||
try:
|
||||
enabled_variant = posthog.get_feature_flag(
|
||||
EXPERIMENT_CONDENSER_MAX_STEP, conversation_id
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'experiment_manager:get_feature_flag:failed',
|
||||
extra={
|
||||
'convo_id': conversation_id,
|
||||
'experiment': EXPERIMENT_CONDENSER_MAX_STEP,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
return None
|
||||
|
||||
# Store the experiment assignment in the database
|
||||
try:
|
||||
experiment_store = ExperimentAssignmentStore()
|
||||
experiment_store.update_experiment_variant(
|
||||
conversation_id=conversation_id,
|
||||
experiment_name='condenser_max_step_experiment',
|
||||
variant=enabled_variant,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'experiment_manager:store_assignment:failed',
|
||||
extra={
|
||||
'convo_id': conversation_id,
|
||||
'experiment': EXPERIMENT_CONDENSER_MAX_STEP,
|
||||
'variant': enabled_variant,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
# Fail the experiment if we cannot track the splits - results would not be explainable
|
||||
return None
|
||||
|
||||
# Log the experiment event
|
||||
# If this is a feature environment, add "FEATURE_" prefix to user_id for PostHog
|
||||
posthog_user_id = f'FEATURE_{user_id}' if IS_FEATURE_ENV else user_id
|
||||
|
||||
try:
|
||||
posthog.capture(
|
||||
distinct_id=posthog_user_id,
|
||||
event='condenser_max_step_set',
|
||||
properties={
|
||||
'conversation_id': conversation_id,
|
||||
'variant': enabled_variant,
|
||||
'original_user_id': user_id,
|
||||
'is_feature_env': IS_FEATURE_ENV,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'experiment_manager:posthog_capture:failed',
|
||||
extra={
|
||||
'convo_id': conversation_id,
|
||||
'experiment': EXPERIMENT_CONDENSER_MAX_STEP,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
# Continue execution as this is not critical
|
||||
|
||||
logger.info(
|
||||
'posthog_capture',
|
||||
extra={
|
||||
'event': 'condenser_max_step_set',
|
||||
'posthog_user_id': posthog_user_id,
|
||||
'is_feature_env': IS_FEATURE_ENV,
|
||||
'conversation_id': conversation_id,
|
||||
'variant': enabled_variant,
|
||||
},
|
||||
)
|
||||
|
||||
return enabled_variant
|
||||
|
||||
|
||||
def handle_condenser_max_step_experiment(
|
||||
user_id: str | None,
|
||||
conversation_id: str,
|
||||
conversation_settings: ConversationInitData,
|
||||
) -> ConversationInitData:
|
||||
"""
|
||||
Handle the condenser max step experiment for conversation settings.
|
||||
|
||||
We should not modify persistent user settings. Instead, apply the experiment
|
||||
variant to the conversation's in-memory settings object for this session only.
|
||||
|
||||
Variants:
|
||||
- control -> condenser_max_size = 120
|
||||
- treatment -> condenser_max_size = 80
|
||||
|
||||
Returns the (potentially) modified conversation_settings.
|
||||
"""
|
||||
|
||||
enabled_variant = _get_condenser_max_step_variant(user_id, conversation_id)
|
||||
|
||||
if enabled_variant is None:
|
||||
return conversation_settings
|
||||
|
||||
if enabled_variant == 'control':
|
||||
condenser_max_size = 120
|
||||
elif enabled_variant == 'treatment':
|
||||
condenser_max_size = 80
|
||||
else:
|
||||
logger.error(
|
||||
'condenser_max_step_experiment:unknown_variant',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'convo_id': conversation_id,
|
||||
'variant': enabled_variant,
|
||||
'reason': 'unknown variant; returning original conversation settings',
|
||||
},
|
||||
)
|
||||
return conversation_settings
|
||||
|
||||
try:
|
||||
# Apply the variant to this conversation only; do not persist to DB.
|
||||
# Not all OpenHands versions expose `condenser_max_size` on settings.
|
||||
if hasattr(conversation_settings, 'condenser_max_size'):
|
||||
conversation_settings.condenser_max_size = condenser_max_size
|
||||
logger.info(
|
||||
'condenser_max_step_experiment:conversation_settings_applied',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'convo_id': conversation_id,
|
||||
'variant': enabled_variant,
|
||||
'condenser_max_size': condenser_max_size,
|
||||
},
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
'condenser_max_step_experiment:field_missing_on_settings',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'convo_id': conversation_id,
|
||||
'variant': enabled_variant,
|
||||
'reason': 'condenser_max_size not present on ConversationInitData',
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'condenser_max_step_experiment:apply_failed',
|
||||
extra={
|
||||
'user_id': user_id,
|
||||
'convo_id': conversation_id,
|
||||
'variant': enabled_variant,
|
||||
'error': str(e),
|
||||
},
|
||||
)
|
||||
return conversation_settings
|
||||
|
||||
return conversation_settings
|
||||
25
enterprise/experiments/experiment_versions/__init__.py
Normal file
25
enterprise/experiments/experiment_versions/__init__.py
Normal file
@@ -0,0 +1,25 @@
|
||||
"""
|
||||
Experiment versions package.
|
||||
|
||||
This package contains handlers for different experiment versions.
|
||||
"""
|
||||
|
||||
from experiments.experiment_versions._001_litellm_default_model_experiment import (
|
||||
handle_litellm_default_model_experiment,
|
||||
)
|
||||
from experiments.experiment_versions._002_system_prompt_experiment import (
|
||||
handle_system_prompt_experiment,
|
||||
)
|
||||
from experiments.experiment_versions._003_llm_claude4_vs_gpt5_experiment import (
|
||||
handle_claude4_vs_gpt5_experiment,
|
||||
)
|
||||
from experiments.experiment_versions._004_condenser_max_step_experiment import (
|
||||
handle_condenser_max_step_experiment,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
'handle_litellm_default_model_experiment',
|
||||
'handle_system_prompt_experiment',
|
||||
'handle_claude4_vs_gpt5_experiment',
|
||||
'handle_condenser_max_step_experiment',
|
||||
]
|
||||
0
enterprise/integrations/__init__.py
Normal file
0
enterprise/integrations/__init__.py
Normal file
70
enterprise/integrations/bitbucket/bitbucket_service.py
Normal file
70
enterprise/integrations/bitbucket/bitbucket_service.py
Normal file
@@ -0,0 +1,70 @@
|
||||
from pydantic import SecretStr
|
||||
from server.auth.token_manager import TokenManager
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.bitbucket.bitbucket_service import BitBucketService
|
||||
from openhands.integrations.service_types import ProviderType
|
||||
|
||||
|
||||
class SaaSBitBucketService(BitBucketService):
|
||||
def __init__(
|
||||
self,
|
||||
user_id: str | None = None,
|
||||
external_auth_token: SecretStr | None = None,
|
||||
external_auth_id: str | None = None,
|
||||
token: SecretStr | None = None,
|
||||
external_token_manager: bool = False,
|
||||
base_domain: str | None = None,
|
||||
):
|
||||
logger.info(
|
||||
f'SaaSBitBucketService created with user_id {user_id}, external_auth_id {external_auth_id}, external_auth_token {'set' if external_auth_token else 'None'}, bitbucket_token {'set' if token else 'None'}, external_token_manager {external_token_manager}'
|
||||
)
|
||||
super().__init__(
|
||||
user_id=user_id,
|
||||
external_auth_token=external_auth_token,
|
||||
external_auth_id=external_auth_id,
|
||||
token=token,
|
||||
external_token_manager=external_token_manager,
|
||||
base_domain=base_domain,
|
||||
)
|
||||
|
||||
self.external_auth_token = external_auth_token
|
||||
self.external_auth_id = external_auth_id
|
||||
self.token_manager = TokenManager(external=external_token_manager)
|
||||
|
||||
async def get_latest_token(self) -> SecretStr | None:
|
||||
bitbucket_token = None
|
||||
if self.external_auth_token:
|
||||
bitbucket_token = SecretStr(
|
||||
await self.token_manager.get_idp_token(
|
||||
self.external_auth_token.get_secret_value(),
|
||||
idp=ProviderType.BITBUCKET,
|
||||
)
|
||||
)
|
||||
logger.debug(
|
||||
f'Got BitBucket token {bitbucket_token} from access token: {self.external_auth_token}'
|
||||
)
|
||||
elif self.external_auth_id:
|
||||
offline_token = await self.token_manager.load_offline_token(
|
||||
self.external_auth_id
|
||||
)
|
||||
bitbucket_token = SecretStr(
|
||||
await self.token_manager.get_idp_token_from_offline_token(
|
||||
offline_token, ProviderType.BITBUCKET
|
||||
)
|
||||
)
|
||||
logger.info(
|
||||
f'Got BitBucket token {bitbucket_token.get_secret_value()} from external auth user ID: {self.external_auth_id}'
|
||||
)
|
||||
elif self.user_id:
|
||||
bitbucket_token = SecretStr(
|
||||
await self.token_manager.get_idp_token_from_idp_user_id(
|
||||
self.user_id, ProviderType.BITBUCKET
|
||||
)
|
||||
)
|
||||
logger.debug(
|
||||
f'Got BitBucket token {bitbucket_token} from user ID: {self.user_id}'
|
||||
)
|
||||
else:
|
||||
logger.warning('external_auth_token and user_id not set!')
|
||||
return bitbucket_token
|
||||
692
enterprise/integrations/github/data_collector.py
Normal file
692
enterprise/integrations/github/data_collector.py
Normal file
@@ -0,0 +1,692 @@
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from github import Github, GithubIntegration
|
||||
from integrations.github.github_view import (
|
||||
GithubIssue,
|
||||
)
|
||||
from integrations.github.queries import PR_QUERY_BY_NODE_ID
|
||||
from integrations.models import Message
|
||||
from integrations.types import PRStatus, ResolverViewInterface
|
||||
from integrations.utils import HOST
|
||||
from pydantic import SecretStr
|
||||
from server.auth.constants import GITHUB_APP_CLIENT_ID, GITHUB_APP_PRIVATE_KEY
|
||||
from storage.openhands_pr import OpenhandsPR
|
||||
from storage.openhands_pr_store import OpenhandsPRStore
|
||||
|
||||
from openhands.core.config import load_openhands_config
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.github.github_service import GithubServiceImpl
|
||||
from openhands.integrations.service_types import ProviderType
|
||||
from openhands.storage import get_file_store
|
||||
from openhands.storage.locations import get_conversation_dir
|
||||
|
||||
config = load_openhands_config()
|
||||
file_store = get_file_store(config.file_store, config.file_store_path)
|
||||
|
||||
|
||||
COLLECT_GITHUB_INTERACTIONS = (
|
||||
os.getenv('COLLECT_GITHUB_INTERACTIONS', 'false') == 'true'
|
||||
)
|
||||
|
||||
|
||||
class TriggerType(str, Enum):
|
||||
ISSUE_LABEL = 'issue-label'
|
||||
ISSUE_COMMENT = 'issue-coment'
|
||||
PR_COMMENT_MACRO = 'label'
|
||||
INLINE_PR_COMMENT_MACRO = 'inline-label'
|
||||
|
||||
|
||||
class GitHubDataCollector:
|
||||
"""
|
||||
Saves data on Cloud Resolver Interactions
|
||||
|
||||
1. We always save
|
||||
- Resolver trigger (comment or label)
|
||||
- Metadata (who started the job, repo name, issue number)
|
||||
|
||||
2. We save data for the type of interaction
|
||||
a. For labelled issues, we save
|
||||
- {conversation_dir}/{conversation_id}/github_data/issue__{repo_name}_{issue_number}.json
|
||||
- issue number
|
||||
- trigger
|
||||
- metadata
|
||||
- body
|
||||
- title
|
||||
- comments
|
||||
|
||||
- {conversation_dir}/{conversation_id}/github_data/pr__{repo_name}_{pr_number}.json
|
||||
- pr_number
|
||||
- metadata
|
||||
- body
|
||||
- title
|
||||
- commits/authors
|
||||
|
||||
3. For all PRs that were opened with the resolver, we save
|
||||
- github_data/prs/{repo_name}_{pr_number}/data.json
|
||||
- pr_number
|
||||
- title
|
||||
- body
|
||||
- commits/authors
|
||||
- code diffs
|
||||
- merge status (either merged/closed)
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.file_store = file_store
|
||||
self.issues_path = 'github_data/issue-{}-{}/data.json'
|
||||
self.matching_pr_path = 'github_data/pr-{}-{}/data.json'
|
||||
# self.full_saved_pr_path = 'github_data/prs/{}-{}/data.json'
|
||||
self.full_saved_pr_path = 'prs/github/{}-{}/data.json'
|
||||
self.github_integration = GithubIntegration(
|
||||
GITHUB_APP_CLIENT_ID, GITHUB_APP_PRIVATE_KEY
|
||||
)
|
||||
self.conversation_id = None
|
||||
|
||||
async def _get_repo_node_id(self, repo_id: str, gh_client) -> str:
|
||||
"""
|
||||
Get the new GitHub GraphQL node ID for a repository using the GitHub client.
|
||||
|
||||
Args:
|
||||
repo_id: Numeric repository ID as string (e.g., "123456789")
|
||||
gh_client: SaaSGitHubService client with authentication
|
||||
|
||||
Returns:
|
||||
New format node ID for GraphQL queries (e.g., "R_kgDOLfkiww")
|
||||
"""
|
||||
try:
|
||||
return await gh_client.get_repository_node_id(repo_id)
|
||||
except Exception:
|
||||
# Fallback to old format if REST API fails
|
||||
node_string = f'010:Repository{repo_id}'
|
||||
return base64.b64encode(node_string.encode()).decode()
|
||||
|
||||
def _create_file_name(
|
||||
self, path: str, repo_id: str, number: int, conversation_id: str | None
|
||||
):
|
||||
suffix = path.format(repo_id, number)
|
||||
|
||||
if conversation_id:
|
||||
return f'{get_conversation_dir(conversation_id)}{suffix}'
|
||||
|
||||
return suffix
|
||||
|
||||
def _get_installation_access_token(self, installation_id: str) -> str:
|
||||
token_data = self.github_integration.get_access_token(
|
||||
installation_id # type: ignore[arg-type]
|
||||
)
|
||||
return token_data.token
|
||||
|
||||
def _check_openhands_author(self, name, login) -> bool:
|
||||
return (
|
||||
name == 'openhands'
|
||||
or login == 'openhands'
|
||||
or login == 'openhands-agent'
|
||||
or login == 'openhands-ai'
|
||||
or login == 'openhands-staging'
|
||||
or login == 'openhands-exp'
|
||||
or (login and 'openhands' in login.lower())
|
||||
)
|
||||
|
||||
def _get_issue_comments(
|
||||
self, installation_id: str, repo_name: str, issue_number: int, conversation_id
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Retrieve all comments from an issue until a comment with conversation_id is found
|
||||
"""
|
||||
|
||||
try:
|
||||
installation_token = self._get_installation_access_token(installation_id)
|
||||
|
||||
with Github(installation_token) as github_client:
|
||||
repo = github_client.get_repo(repo_name)
|
||||
issue = repo.get_issue(issue_number)
|
||||
comments = []
|
||||
|
||||
for comment in issue.get_comments():
|
||||
comment_data = {
|
||||
'id': comment.id,
|
||||
'body': comment.body,
|
||||
'created_at': comment.created_at.isoformat(),
|
||||
'user': comment.user.login,
|
||||
}
|
||||
|
||||
# If we find a comment containing conversation_id, stop collecting comments
|
||||
if conversation_id in comment.body:
|
||||
break
|
||||
|
||||
comments.append(comment_data)
|
||||
|
||||
return comments
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
def _save_data(self, path: str, data: dict[str, Any]):
|
||||
"""Save data to a path"""
|
||||
self.file_store.write(path, json.dumps(data))
|
||||
|
||||
def _save_issue(
|
||||
self,
|
||||
github_view: GithubIssue,
|
||||
trigger_type: TriggerType,
|
||||
) -> None:
|
||||
"""
|
||||
Save issue data when it's labeled with openhands
|
||||
|
||||
1. Save under {conversation_dir}/{conversation_id}/github_data/issue_{issue_number}.json
|
||||
2. Save issue snapshot (title, body, comments)
|
||||
3. Save trigger type (label)
|
||||
4. Save PR opened (if exists, this information comes later when agent has finished its task)
|
||||
- Save commit shas
|
||||
- Save author info
|
||||
5. Was PR merged or closed
|
||||
"""
|
||||
|
||||
conversation_id = github_view.conversation_id
|
||||
|
||||
if not conversation_id:
|
||||
return
|
||||
|
||||
issue_number = github_view.issue_number
|
||||
file_name = self._create_file_name(
|
||||
path=self.issues_path,
|
||||
repo_id=github_view.full_repo_name,
|
||||
number=issue_number,
|
||||
conversation_id=conversation_id,
|
||||
)
|
||||
|
||||
payload_data = github_view.raw_payload.message.get('payload', {})
|
||||
isssue_details = payload_data.get('issue', {})
|
||||
is_repo_private = payload_data.get('repository', {}).get('private', 'true')
|
||||
title = isssue_details.get('title', '')
|
||||
body = isssue_details.get('body', '')
|
||||
|
||||
# Get comments for the issue
|
||||
comments = self._get_issue_comments(
|
||||
github_view.installation_id,
|
||||
github_view.full_repo_name,
|
||||
issue_number,
|
||||
conversation_id,
|
||||
)
|
||||
|
||||
data = {
|
||||
'trigger': trigger_type,
|
||||
'metadata': {
|
||||
'user': github_view.user_info.username,
|
||||
'repo_name': github_view.full_repo_name,
|
||||
'is_repo_private': is_repo_private,
|
||||
'number': issue_number,
|
||||
},
|
||||
'contents': {
|
||||
'title': title,
|
||||
'body': body,
|
||||
'comments': comments,
|
||||
},
|
||||
}
|
||||
|
||||
self._save_data(file_name, data)
|
||||
logger.info(
|
||||
f'[Github]: Saved issue #{issue_number} for {github_view.full_repo_name}'
|
||||
)
|
||||
|
||||
def _get_pr_commits(self, installation_id: str, repo_name: str, pr_number: int):
|
||||
commits = []
|
||||
installation_token = self._get_installation_access_token(installation_id)
|
||||
with Github(installation_token) as github_client:
|
||||
repo = github_client.get_repo(repo_name)
|
||||
pr = repo.get_pull(pr_number)
|
||||
|
||||
for commit in pr.get_commits():
|
||||
commit_data = {
|
||||
'sha': commit.sha,
|
||||
'authors': commit.author.login if commit.author else None,
|
||||
'committed_date': commit.commit.committer.date.isoformat()
|
||||
if commit.commit and commit.commit.committer
|
||||
else None,
|
||||
}
|
||||
commits.append(commit_data)
|
||||
|
||||
return commits
|
||||
|
||||
def _extract_repo_metadata(self, repo_data: dict) -> dict:
|
||||
"""Extract repository metadata from GraphQL response"""
|
||||
return {
|
||||
'name': repo_data.get('name'),
|
||||
'owner': repo_data.get('owner', {}).get('login'),
|
||||
'languages': [
|
||||
lang['name'] for lang in repo_data.get('languages', {}).get('nodes', [])
|
||||
],
|
||||
}
|
||||
|
||||
def _process_commits_page(self, pr_data: dict, commits: list) -> None:
|
||||
"""Process commits from a single GraphQL page"""
|
||||
commit_nodes = pr_data.get('commits', {}).get('nodes', [])
|
||||
for commit_node in commit_nodes:
|
||||
commit = commit_node['commit']
|
||||
author_info = commit.get('author', {})
|
||||
commit_data = {
|
||||
'sha': commit['oid'],
|
||||
'message': commit['message'],
|
||||
'committed_date': commit.get('committedDate'),
|
||||
'author': {
|
||||
'name': author_info.get('name'),
|
||||
'email': author_info.get('email'),
|
||||
'github_login': author_info.get('user', {}).get('login'),
|
||||
},
|
||||
'stats': {
|
||||
'additions': commit.get('additions', 0),
|
||||
'deletions': commit.get('deletions', 0),
|
||||
'changed_files': commit.get('changedFiles', 0),
|
||||
},
|
||||
}
|
||||
commits.append(commit_data)
|
||||
|
||||
def _process_pr_comments_page(self, pr_data: dict, pr_comments: list) -> None:
|
||||
"""Process PR comments from a single GraphQL page"""
|
||||
comment_nodes = pr_data.get('comments', {}).get('nodes', [])
|
||||
for comment in comment_nodes:
|
||||
comment_data = {
|
||||
'author': comment.get('author', {}).get('login'),
|
||||
'body': comment.get('body'),
|
||||
'created_at': comment.get('createdAt'),
|
||||
'type': 'pr_comment',
|
||||
}
|
||||
pr_comments.append(comment_data)
|
||||
|
||||
def _process_review_comments_page(
|
||||
self, pr_data: dict, review_comments: list
|
||||
) -> None:
|
||||
"""Process reviews and review comments from a single GraphQL page"""
|
||||
review_nodes = pr_data.get('reviews', {}).get('nodes', [])
|
||||
for review in review_nodes:
|
||||
# Add the review itself if it has a body
|
||||
if review.get('body', '').strip():
|
||||
review_data = {
|
||||
'author': review.get('author', {}).get('login'),
|
||||
'body': review.get('body'),
|
||||
'created_at': review.get('createdAt'),
|
||||
'state': review.get('state'),
|
||||
'type': 'review',
|
||||
}
|
||||
review_comments.append(review_data)
|
||||
|
||||
# Add individual review comments
|
||||
review_comment_nodes = review.get('comments', {}).get('nodes', [])
|
||||
for review_comment in review_comment_nodes:
|
||||
review_comment_data = {
|
||||
'author': review_comment.get('author', {}).get('login'),
|
||||
'body': review_comment.get('body'),
|
||||
'created_at': review_comment.get('createdAt'),
|
||||
'type': 'review_comment',
|
||||
}
|
||||
review_comments.append(review_comment_data)
|
||||
|
||||
def _count_openhands_activity(
|
||||
self, commits: list, review_comments: list, pr_comments: list
|
||||
) -> tuple[int, int, int]:
|
||||
"""Count OpenHands commits, review comments, and general PR comments"""
|
||||
openhands_commit_count = 0
|
||||
openhands_review_comment_count = 0
|
||||
openhands_general_comment_count = 0
|
||||
|
||||
# Count commits by OpenHands (check both name and login)
|
||||
for commit in commits:
|
||||
author = commit.get('author', {})
|
||||
author_name = author.get('name', '').lower()
|
||||
author_login = (
|
||||
author.get('github_login', '').lower()
|
||||
if author.get('github_login')
|
||||
else ''
|
||||
)
|
||||
|
||||
if self._check_openhands_author(author_name, author_login):
|
||||
openhands_commit_count += 1
|
||||
|
||||
# Count review comments by OpenHands
|
||||
for review_comment in review_comments:
|
||||
author_login = (
|
||||
review_comment.get('author', '').lower()
|
||||
if review_comment.get('author')
|
||||
else ''
|
||||
)
|
||||
author_name = '' # Initialize to avoid reference before assignment
|
||||
if self._check_openhands_author(author_name, author_login):
|
||||
openhands_review_comment_count += 1
|
||||
|
||||
# Count general PR comments by OpenHands
|
||||
for pr_comment in pr_comments:
|
||||
author_login = (
|
||||
pr_comment.get('author', '').lower() if pr_comment.get('author') else ''
|
||||
)
|
||||
author_name = '' # Initialize to avoid reference before assignment
|
||||
if self._check_openhands_author(author_name, author_login):
|
||||
openhands_general_comment_count += 1
|
||||
|
||||
return (
|
||||
openhands_commit_count,
|
||||
openhands_review_comment_count,
|
||||
openhands_general_comment_count,
|
||||
)
|
||||
|
||||
def _build_final_data_structure(
|
||||
self,
|
||||
repo_data: dict,
|
||||
pr_data: dict,
|
||||
commits: list,
|
||||
pr_comments: list,
|
||||
review_comments: list,
|
||||
openhands_commit_count: int,
|
||||
openhands_review_comment_count: int,
|
||||
openhands_general_comment_count: int = 0,
|
||||
) -> dict:
|
||||
"""Build the final data structure for JSON storage"""
|
||||
|
||||
is_merged = pr_data['merged']
|
||||
merged_by = None
|
||||
merge_commit_sha = None
|
||||
if is_merged:
|
||||
merged_by = (pr_data.get('mergedBy') or {}).get('login')
|
||||
merge_commit_sha = (pr_data.get('mergeCommit') or {}).get('oid')
|
||||
|
||||
return {
|
||||
'repo_metadata': self._extract_repo_metadata(repo_data),
|
||||
'pr_metadata': {
|
||||
'username': (pr_data.get('author') or {}).get('login'),
|
||||
'number': pr_data.get('number'),
|
||||
'title': pr_data.get('title'),
|
||||
'body': pr_data.get('body'),
|
||||
'comments': pr_comments,
|
||||
},
|
||||
'commits': commits,
|
||||
'review_comments': review_comments,
|
||||
'merge_status': {
|
||||
'merged': pr_data.get('merged'),
|
||||
'merged_by': merged_by,
|
||||
'state': pr_data.get('state'),
|
||||
'merge_commit_sha': merge_commit_sha,
|
||||
},
|
||||
'openhands_stats': {
|
||||
'num_commits': openhands_commit_count,
|
||||
'num_review_comments': openhands_review_comment_count,
|
||||
'num_general_comments': openhands_general_comment_count,
|
||||
'helped_author': openhands_commit_count > 0,
|
||||
},
|
||||
}
|
||||
|
||||
async def save_full_pr(self, openhands_pr: OpenhandsPR) -> None:
|
||||
"""
|
||||
Save PR information including metadata and commit details using GraphQL
|
||||
|
||||
Saves:
|
||||
- Repo metadata (repo name, languages, contributors)
|
||||
- PR metadata (number, title, body, author, comments)
|
||||
- Commit information (sha, authors, message, stats)
|
||||
- Merge status
|
||||
- Num openhands commits
|
||||
- Num openhands review comments
|
||||
"""
|
||||
pr_number = openhands_pr.pr_number
|
||||
installation_id = openhands_pr.installation_id
|
||||
repo_id = openhands_pr.repo_id
|
||||
|
||||
# Get installation token and create Github client
|
||||
# This will fail if the user decides to revoke OpenHands' access to their repo
|
||||
# In this case, we will simply return when the exception occurs
|
||||
# This will not lead to infinite loops when processing PRs as we log number of attempts and cap max attempts independently from this
|
||||
try:
|
||||
installation_token = self._get_installation_access_token(installation_id)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f'Failed to generate token for {openhands_pr.repo_name}: {e}'
|
||||
)
|
||||
return
|
||||
|
||||
gh_client = GithubServiceImpl(token=SecretStr(installation_token))
|
||||
|
||||
# Get the new format GraphQL node ID
|
||||
node_id = await self._get_repo_node_id(repo_id, gh_client)
|
||||
|
||||
# Initialize data structures
|
||||
commits: list[dict] = []
|
||||
pr_comments: list[dict] = []
|
||||
review_comments: list[dict] = []
|
||||
pr_data = None
|
||||
repo_data = None
|
||||
|
||||
# Pagination cursors
|
||||
commits_after = None
|
||||
comments_after = None
|
||||
reviews_after = None
|
||||
|
||||
# Fetch all data with pagination
|
||||
while True:
|
||||
variables = {
|
||||
'nodeId': node_id,
|
||||
'pr_number': pr_number,
|
||||
'commits_after': commits_after,
|
||||
'comments_after': comments_after,
|
||||
'reviews_after': reviews_after,
|
||||
}
|
||||
|
||||
try:
|
||||
result = await gh_client.execute_graphql_query(
|
||||
PR_QUERY_BY_NODE_ID, variables
|
||||
)
|
||||
if not result.get('data', {}).get('node', {}).get('pullRequest'):
|
||||
break
|
||||
|
||||
pr_data = result['data']['node']['pullRequest']
|
||||
repo_data = result['data']['node']
|
||||
|
||||
# Process data from this page using modular methods
|
||||
self._process_commits_page(pr_data, commits)
|
||||
self._process_pr_comments_page(pr_data, pr_comments)
|
||||
self._process_review_comments_page(pr_data, review_comments)
|
||||
|
||||
# Check pagination for all three types
|
||||
has_more_commits = (
|
||||
pr_data.get('commits', {})
|
||||
.get('pageInfo', {})
|
||||
.get('hasNextPage', False)
|
||||
)
|
||||
has_more_comments = (
|
||||
pr_data.get('comments', {})
|
||||
.get('pageInfo', {})
|
||||
.get('hasNextPage', False)
|
||||
)
|
||||
has_more_reviews = (
|
||||
pr_data.get('reviews', {})
|
||||
.get('pageInfo', {})
|
||||
.get('hasNextPage', False)
|
||||
)
|
||||
|
||||
# Update cursors
|
||||
if has_more_commits:
|
||||
commits_after = (
|
||||
pr_data.get('commits', {}).get('pageInfo', {}).get('endCursor')
|
||||
)
|
||||
else:
|
||||
commits_after = None
|
||||
|
||||
if has_more_comments:
|
||||
comments_after = (
|
||||
pr_data.get('comments', {}).get('pageInfo', {}).get('endCursor')
|
||||
)
|
||||
else:
|
||||
comments_after = None
|
||||
|
||||
if has_more_reviews:
|
||||
reviews_after = (
|
||||
pr_data.get('reviews', {}).get('pageInfo', {}).get('endCursor')
|
||||
)
|
||||
else:
|
||||
reviews_after = None
|
||||
|
||||
# Continue if there's more data to fetch
|
||||
if not (has_more_commits or has_more_comments or has_more_reviews):
|
||||
break
|
||||
|
||||
except Exception:
|
||||
logger.warning('Error fetching PR data', exc_info=True)
|
||||
return
|
||||
|
||||
if not pr_data or not repo_data:
|
||||
return
|
||||
|
||||
# Count OpenHands activity using modular method
|
||||
(
|
||||
openhands_commit_count,
|
||||
openhands_review_comment_count,
|
||||
openhands_general_comment_count,
|
||||
) = self._count_openhands_activity(commits, review_comments, pr_comments)
|
||||
|
||||
logger.info(
|
||||
f'[Github]: PR #{pr_number} - OpenHands commits: {openhands_commit_count}, review comments: {openhands_review_comment_count}, general comments: {openhands_general_comment_count}'
|
||||
)
|
||||
logger.info(
|
||||
f'[Github]: PR #{pr_number} - Total collected: {len(commits)} commits, {len(pr_comments)} PR comments, {len(review_comments)} review comments'
|
||||
)
|
||||
|
||||
# Build final data structure using modular method
|
||||
data = self._build_final_data_structure(
|
||||
repo_data,
|
||||
pr_data or {},
|
||||
commits,
|
||||
pr_comments,
|
||||
review_comments,
|
||||
openhands_commit_count,
|
||||
openhands_review_comment_count,
|
||||
openhands_general_comment_count,
|
||||
)
|
||||
|
||||
# Update the OpenhandsPR object with OpenHands statistics
|
||||
store = OpenhandsPRStore.get_instance()
|
||||
openhands_helped_author = openhands_commit_count > 0
|
||||
|
||||
# Update the PR with OpenHands statistics
|
||||
update_success = store.update_pr_openhands_stats(
|
||||
repo_id=repo_id,
|
||||
pr_number=pr_number,
|
||||
original_updated_at=openhands_pr.updated_at,
|
||||
openhands_helped_author=openhands_helped_author,
|
||||
num_openhands_commits=openhands_commit_count,
|
||||
num_openhands_review_comments=openhands_review_comment_count,
|
||||
num_openhands_general_comments=openhands_general_comment_count,
|
||||
)
|
||||
|
||||
if not update_success:
|
||||
logger.warning(
|
||||
f'[Github]: Failed to update OpenHands stats for PR #{pr_number} in repo {repo_id} - PR may have been modified concurrently'
|
||||
)
|
||||
|
||||
# Save to file
|
||||
file_name = self._create_file_name(
|
||||
path=self.full_saved_pr_path,
|
||||
repo_id=repo_id,
|
||||
number=pr_number,
|
||||
conversation_id=None,
|
||||
)
|
||||
self._save_data(file_name, data)
|
||||
logger.info(
|
||||
f'[Github]: Saved full PR #{pr_number} for repo {repo_id} with OpenHands stats: commits={openhands_commit_count}, reviews={openhands_review_comment_count}, general_comments={openhands_general_comment_count}, helped={openhands_helped_author}'
|
||||
)
|
||||
|
||||
def _check_for_conversation_url(self, body):
|
||||
conversation_pattern = re.search(
|
||||
rf'https://{HOST}/conversations/([a-zA-Z0-9-]+)(?:\s|[.,;!?)]|$)', body
|
||||
)
|
||||
if conversation_pattern:
|
||||
return conversation_pattern.group(1)
|
||||
|
||||
return None
|
||||
|
||||
def _is_pr_closed_or_merged(self, payload):
|
||||
"""
|
||||
Check if PR was closed (regardless of conversation URL)
|
||||
"""
|
||||
action = payload.get('action', '')
|
||||
return action == 'closed' and 'pull_request' in payload
|
||||
|
||||
def _track_closed_or_merged_pr(self, payload):
|
||||
"""
|
||||
Track PR closed/merged event
|
||||
"""
|
||||
|
||||
repo_id = str(payload['repository']['id'])
|
||||
pr_number = payload['number']
|
||||
installation_id = str(payload['installation']['id'])
|
||||
private = payload['repository']['private']
|
||||
repo_name = payload['repository']['full_name']
|
||||
|
||||
pr_data = payload['pull_request']
|
||||
|
||||
# Extract PR metrics
|
||||
num_reviewers = len(pr_data.get('requested_reviewers', []))
|
||||
num_commits = pr_data.get('commits', 0)
|
||||
num_review_comments = pr_data.get('review_comments', 0)
|
||||
num_general_comments = pr_data.get('comments', 0)
|
||||
num_changed_files = pr_data.get('changed_files', 0)
|
||||
num_additions = pr_data.get('additions', 0)
|
||||
num_deletions = pr_data.get('deletions', 0)
|
||||
merged = pr_data.get('merged', False)
|
||||
|
||||
# Extract closed_at timestamp
|
||||
# Example: "closed_at":"2025-06-19T21:19:36Z"
|
||||
closed_at_str = pr_data.get('closed_at')
|
||||
created_at = pr_data.get('created_at')
|
||||
|
||||
closed_at = datetime.fromisoformat(closed_at_str.replace('Z', '+00:00'))
|
||||
|
||||
# Determine status based on whether it was merged
|
||||
status = PRStatus.MERGED if merged else PRStatus.CLOSED
|
||||
|
||||
store = OpenhandsPRStore.get_instance()
|
||||
|
||||
pr = OpenhandsPR(
|
||||
repo_name=repo_name,
|
||||
repo_id=repo_id,
|
||||
pr_number=pr_number,
|
||||
status=status,
|
||||
provider=ProviderType.GITHUB.value,
|
||||
installation_id=installation_id,
|
||||
private=private,
|
||||
num_reviewers=num_reviewers,
|
||||
num_commits=num_commits,
|
||||
num_review_comments=num_review_comments,
|
||||
num_changed_files=num_changed_files,
|
||||
num_additions=num_additions,
|
||||
num_deletions=num_deletions,
|
||||
merged=merged,
|
||||
created_at=created_at,
|
||||
closed_at=closed_at,
|
||||
# These properties will be enriched later
|
||||
openhands_helped_author=None,
|
||||
num_openhands_commits=None,
|
||||
num_openhands_review_comments=None,
|
||||
num_general_comments=num_general_comments,
|
||||
)
|
||||
|
||||
store.insert_pr(pr)
|
||||
logger.info(f'Tracked PR {status}: {repo_id}#{pr_number}')
|
||||
|
||||
def process_payload(self, message: Message):
|
||||
if not COLLECT_GITHUB_INTERACTIONS:
|
||||
return
|
||||
|
||||
raw_payload = message.message.get('payload', {})
|
||||
|
||||
if self._is_pr_closed_or_merged(raw_payload):
|
||||
self._track_closed_or_merged_pr(raw_payload)
|
||||
|
||||
async def save_data(self, github_view: ResolverViewInterface):
|
||||
if not COLLECT_GITHUB_INTERACTIONS:
|
||||
return
|
||||
|
||||
return
|
||||
|
||||
# TODO: track issue metadata in DB and save comments to filestore
|
||||
344
enterprise/integrations/github/github_manager.py
Normal file
344
enterprise/integrations/github/github_manager.py
Normal file
@@ -0,0 +1,344 @@
|
||||
from types import MappingProxyType
|
||||
|
||||
from github import Github, GithubIntegration
|
||||
from integrations.github.data_collector import GitHubDataCollector
|
||||
from integrations.github.github_solvability import summarize_issue_solvability
|
||||
from integrations.github.github_view import (
|
||||
GithubFactory,
|
||||
GithubFailingAction,
|
||||
GithubInlinePRComment,
|
||||
GithubIssue,
|
||||
GithubIssueComment,
|
||||
GithubPRComment,
|
||||
)
|
||||
from integrations.manager import Manager
|
||||
from integrations.models import (
|
||||
Message,
|
||||
SourceType,
|
||||
)
|
||||
from integrations.types import ResolverViewInterface
|
||||
from integrations.utils import (
|
||||
CONVERSATION_URL,
|
||||
HOST_URL,
|
||||
OPENHANDS_RESOLVER_TEMPLATES_DIR,
|
||||
)
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
from pydantic import SecretStr
|
||||
from server.auth.constants import GITHUB_APP_CLIENT_ID, GITHUB_APP_PRIVATE_KEY
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.utils.conversation_callback_utils import register_callback_processor
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.provider import ProviderToken, ProviderType
|
||||
from openhands.server.types import LLMAuthenticationError, MissingSettingsError
|
||||
from openhands.storage.data_models.user_secrets import UserSecrets
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
|
||||
|
||||
class GithubManager(Manager):
|
||||
def __init__(
|
||||
self, token_manager: TokenManager, data_collector: GitHubDataCollector
|
||||
):
|
||||
self.token_manager = token_manager
|
||||
self.data_collector = data_collector
|
||||
self.github_integration = GithubIntegration(
|
||||
GITHUB_APP_CLIENT_ID, GITHUB_APP_PRIVATE_KEY
|
||||
)
|
||||
|
||||
self.jinja_env = Environment(
|
||||
loader=FileSystemLoader(OPENHANDS_RESOLVER_TEMPLATES_DIR + 'github')
|
||||
)
|
||||
|
||||
def _confirm_incoming_source_type(self, message: Message):
|
||||
if message.source != SourceType.GITHUB:
|
||||
raise ValueError(f'Unexpected message source {message.source}')
|
||||
|
||||
def _get_full_repo_name(self, repo_obj: dict) -> str:
|
||||
owner = repo_obj['owner']['login']
|
||||
repo_name = repo_obj['name']
|
||||
|
||||
return f'{owner}/{repo_name}'
|
||||
|
||||
def _get_installation_access_token(self, installation_id: str) -> str:
|
||||
# get_access_token is typed to only accept int, but it can handle str.
|
||||
token_data = self.github_integration.get_access_token(
|
||||
installation_id # type: ignore[arg-type]
|
||||
)
|
||||
return token_data.token
|
||||
|
||||
def _add_reaction(
|
||||
self, github_view: ResolverViewInterface, reaction: str, installation_token: str
|
||||
):
|
||||
"""Add a reaction to the GitHub issue, PR, or comment.
|
||||
|
||||
Args:
|
||||
github_view: The GitHub view object containing issue/PR/comment info
|
||||
reaction: The reaction to add (e.g. "eyes", "+1", "-1", "laugh", "confused", "heart", "hooray", "rocket")
|
||||
installation_token: GitHub installation access token for API access
|
||||
"""
|
||||
with Github(installation_token) as github_client:
|
||||
repo = github_client.get_repo(github_view.full_repo_name)
|
||||
# Add reaction based on view type
|
||||
if isinstance(github_view, GithubInlinePRComment):
|
||||
pr = repo.get_pull(github_view.issue_number)
|
||||
inline_comment = pr.get_review_comment(github_view.comment_id)
|
||||
inline_comment.create_reaction(reaction)
|
||||
|
||||
elif isinstance(github_view, (GithubIssueComment, GithubPRComment)):
|
||||
issue = repo.get_issue(github_view.issue_number)
|
||||
comment = issue.get_comment(github_view.comment_id)
|
||||
comment.create_reaction(reaction)
|
||||
else:
|
||||
issue = repo.get_issue(github_view.issue_number)
|
||||
issue.create_reaction(reaction)
|
||||
|
||||
def _user_has_write_access_to_repo(
|
||||
self, installation_id: str, full_repo_name: str, username: str
|
||||
) -> bool:
|
||||
"""Check if the user is an owner, collaborator, or member of the repository."""
|
||||
with self.github_integration.get_github_for_installation(
|
||||
installation_id, # type: ignore[arg-type]
|
||||
{},
|
||||
) as repos:
|
||||
repository = repos.get_repo(full_repo_name)
|
||||
|
||||
# Check if the user is a collaborator
|
||||
try:
|
||||
collaborator = repository.get_collaborator_permission(username)
|
||||
if collaborator in ['admin', 'write']:
|
||||
return True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# If the above fails, check if the user is an owner or member
|
||||
org = repository.organization
|
||||
if org:
|
||||
user = org.get_members(username)
|
||||
return user is not None
|
||||
|
||||
return False
|
||||
|
||||
async def is_job_requested(self, message: Message) -> bool:
|
||||
self._confirm_incoming_source_type(message)
|
||||
|
||||
installation_id = message.message['installation']
|
||||
payload = message.message.get('payload', {})
|
||||
repo_obj = payload.get('repository')
|
||||
if not repo_obj:
|
||||
return False
|
||||
username = payload.get('sender', {}).get('login')
|
||||
repo_name = self._get_full_repo_name(repo_obj)
|
||||
|
||||
# Suggestions contain `@openhands` macro; avoid kicking off jobs for system recommendations
|
||||
if GithubFactory.is_pr_comment(
|
||||
message
|
||||
) and GithubFailingAction.unqiue_suggestions_header in payload.get(
|
||||
'comment', {}
|
||||
).get('body', ''):
|
||||
return False
|
||||
|
||||
if GithubFactory.is_eligible_for_conversation_starter(
|
||||
message
|
||||
) and self._user_has_write_access_to_repo(installation_id, repo_name, username):
|
||||
await GithubFactory.trigger_conversation_starter(message)
|
||||
|
||||
if not (
|
||||
GithubFactory.is_labeled_issue(message)
|
||||
or GithubFactory.is_issue_comment(message)
|
||||
or GithubFactory.is_pr_comment(message)
|
||||
or GithubFactory.is_inline_pr_comment(message)
|
||||
):
|
||||
return False
|
||||
|
||||
logger.info(f'[GitHub] Checking permissions for {username} in {repo_name}')
|
||||
|
||||
return self._user_has_write_access_to_repo(installation_id, repo_name, username)
|
||||
|
||||
async def receive_message(self, message: Message):
|
||||
self._confirm_incoming_source_type(message)
|
||||
try:
|
||||
await call_sync_from_async(self.data_collector.process_payload, message)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
'[Github]: Error processing payload for gh interaction', exc_info=True
|
||||
)
|
||||
|
||||
if await self.is_job_requested(message):
|
||||
github_view = await GithubFactory.create_github_view_from_payload(
|
||||
message, self.token_manager
|
||||
)
|
||||
logger.info(
|
||||
f'[GitHub] Creating job for {github_view.user_info.username} in {github_view.full_repo_name}#{github_view.issue_number}'
|
||||
)
|
||||
# Get the installation token
|
||||
installation_token = self._get_installation_access_token(
|
||||
github_view.installation_id
|
||||
)
|
||||
# Store the installation token
|
||||
self.token_manager.store_org_token(
|
||||
github_view.installation_id, installation_token
|
||||
)
|
||||
# Add eyes reaction to acknowledge we've read the request
|
||||
self._add_reaction(github_view, 'eyes', installation_token)
|
||||
await self.start_job(github_view)
|
||||
|
||||
async def send_message(self, message: Message, github_view: ResolverViewInterface):
|
||||
installation_token = self.token_manager.load_org_token(
|
||||
github_view.installation_id
|
||||
)
|
||||
if not installation_token:
|
||||
logger.warning('Missing installation token')
|
||||
return
|
||||
|
||||
outgoing_message = message.message
|
||||
|
||||
if isinstance(github_view, GithubInlinePRComment):
|
||||
with Github(installation_token) as github_client:
|
||||
repo = github_client.get_repo(github_view.full_repo_name)
|
||||
pr = repo.get_pull(github_view.issue_number)
|
||||
pr.create_review_comment_reply(
|
||||
comment_id=github_view.comment_id, body=outgoing_message
|
||||
)
|
||||
|
||||
elif (
|
||||
isinstance(github_view, GithubPRComment)
|
||||
or isinstance(github_view, GithubIssueComment)
|
||||
or isinstance(github_view, GithubIssue)
|
||||
):
|
||||
with Github(installation_token) as github_client:
|
||||
repo = github_client.get_repo(github_view.full_repo_name)
|
||||
issue = repo.get_issue(number=github_view.issue_number)
|
||||
issue.create_comment(outgoing_message)
|
||||
|
||||
else:
|
||||
logger.warning('Unsupported location')
|
||||
return
|
||||
|
||||
async def start_job(self, github_view: ResolverViewInterface):
|
||||
"""Kick off a job with openhands agent.
|
||||
|
||||
1. Get user credential
|
||||
2. Initialize new conversation with repo
|
||||
3. Save interaction data
|
||||
"""
|
||||
# Importing here prevents circular import
|
||||
from server.conversation_callback_processor.github_callback_processor import (
|
||||
GithubCallbackProcessor,
|
||||
)
|
||||
|
||||
try:
|
||||
msg_info = None
|
||||
|
||||
try:
|
||||
user_info = github_view.user_info
|
||||
logger.info(
|
||||
f'[GitHub] Starting job for user {user_info.username} (id={user_info.user_id})'
|
||||
)
|
||||
|
||||
# Create conversation
|
||||
user_token = await self.token_manager.get_idp_token_from_idp_user_id(
|
||||
str(user_info.user_id), ProviderType.GITHUB
|
||||
)
|
||||
|
||||
if not user_token:
|
||||
logger.warning(
|
||||
f'[GitHub] No token found for user {user_info.username} (id={user_info.user_id})'
|
||||
)
|
||||
raise MissingSettingsError('Missing settings')
|
||||
|
||||
logger.info(
|
||||
f'[GitHub] Creating new conversation for user {user_info.username}'
|
||||
)
|
||||
|
||||
secret_store = UserSecrets(
|
||||
provider_tokens=MappingProxyType(
|
||||
{
|
||||
ProviderType.GITHUB: ProviderToken(
|
||||
token=SecretStr(user_token),
|
||||
user_id=str(user_info.user_id),
|
||||
)
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
# We first initialize a conversation and generate the solvability report BEFORE starting the conversation runtime
|
||||
# This helps us accumulate llm spend without requiring a running runtime. This setups us up for
|
||||
# 1. If there is a problem starting the runtime we still have accumulated total conversation cost
|
||||
# 2. In the future, based on the report confidence we can conditionally start the conversation
|
||||
# 3. Once the conversation is started, its base cost will include the report's spend as well which allows us to control max budget per resolver task
|
||||
convo_metadata = await github_view.initialize_new_conversation()
|
||||
solvability_summary = None
|
||||
try:
|
||||
if user_token:
|
||||
solvability_summary = await summarize_issue_solvability(
|
||||
github_view, user_token
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
'[Github]: No user token available for solvability analysis'
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f'[Github]: Error summarizing issue solvability: {str(e)}'
|
||||
)
|
||||
|
||||
await github_view.create_new_conversation(
|
||||
self.jinja_env, secret_store.provider_tokens, convo_metadata
|
||||
)
|
||||
|
||||
conversation_id = github_view.conversation_id
|
||||
|
||||
logger.info(
|
||||
f'[GitHub] Created conversation {conversation_id} for user {user_info.username}'
|
||||
)
|
||||
|
||||
# Create a GithubCallbackProcessor
|
||||
processor = GithubCallbackProcessor(
|
||||
github_view=github_view,
|
||||
send_summary_instruction=True,
|
||||
)
|
||||
|
||||
# Register the callback processor
|
||||
register_callback_processor(conversation_id, processor)
|
||||
|
||||
logger.info(
|
||||
f'[Github] Registered callback processor for conversation {conversation_id}'
|
||||
)
|
||||
|
||||
# Send message with conversation link
|
||||
conversation_link = CONVERSATION_URL.format(conversation_id)
|
||||
base_msg = f"I'm on it! {user_info.username} can [track my progress at all-hands.dev]({conversation_link})"
|
||||
# Combine messages: include solvability report with "I'm on it!" if successful
|
||||
if solvability_summary:
|
||||
msg_info = f'{base_msg}\n\n{solvability_summary}'
|
||||
else:
|
||||
msg_info = base_msg
|
||||
|
||||
except MissingSettingsError as e:
|
||||
logger.warning(
|
||||
f'[GitHub] Missing settings error for user {user_info.username}: {str(e)}'
|
||||
)
|
||||
|
||||
msg_info = f'@{user_info.username} please re-login into [OpenHands Cloud]({HOST_URL}) before starting a job.'
|
||||
|
||||
except LLMAuthenticationError as e:
|
||||
logger.warning(
|
||||
f'[GitHub] LLM authentication error for user {user_info.username}: {str(e)}'
|
||||
)
|
||||
|
||||
msg_info = f'@{user_info.username} please set a valid LLM API key in [OpenHands Cloud]({HOST_URL}) before starting a job.'
|
||||
|
||||
msg = self.create_outgoing_message(msg_info)
|
||||
await self.send_message(msg, github_view)
|
||||
|
||||
except Exception:
|
||||
logger.exception('[Github]: Error starting job')
|
||||
msg = self.create_outgoing_message(
|
||||
msg='Uh oh! There was an unexpected error starting the job :('
|
||||
)
|
||||
await self.send_message(msg, github_view)
|
||||
|
||||
try:
|
||||
await self.data_collector.save_data(github_view)
|
||||
except Exception:
|
||||
logger.warning('[Github]: Error saving interaction data', exc_info=True)
|
||||
143
enterprise/integrations/github/github_service.py
Normal file
143
enterprise/integrations/github/github_service.py
Normal file
@@ -0,0 +1,143 @@
|
||||
import asyncio
|
||||
|
||||
from integrations.utils import store_repositories_in_db
|
||||
from pydantic import SecretStr
|
||||
from server.auth.token_manager import TokenManager
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.github.github_service import GitHubService
|
||||
from openhands.integrations.service_types import ProviderType, Repository
|
||||
from openhands.server.types import AppMode
|
||||
|
||||
|
||||
class SaaSGitHubService(GitHubService):
|
||||
def __init__(
|
||||
self,
|
||||
user_id: str | None = None,
|
||||
external_auth_token: SecretStr | None = None,
|
||||
external_auth_id: str | None = None,
|
||||
token: SecretStr | None = None,
|
||||
external_token_manager: bool = False,
|
||||
base_domain: str | None = None,
|
||||
):
|
||||
logger.debug(
|
||||
f'SaaSGitHubService created with user_id {user_id}, external_auth_id {external_auth_id}, external_auth_token {'set' if external_auth_token else 'None'}, github_token {'set' if token else 'None'}, external_token_manager {external_token_manager}'
|
||||
)
|
||||
super().__init__(
|
||||
user_id=user_id,
|
||||
external_auth_token=external_auth_token,
|
||||
external_auth_id=external_auth_id,
|
||||
token=token,
|
||||
external_token_manager=external_token_manager,
|
||||
base_domain=base_domain,
|
||||
)
|
||||
|
||||
self.external_auth_token = external_auth_token
|
||||
self.external_auth_id = external_auth_id
|
||||
self.token_manager = TokenManager(external=external_token_manager)
|
||||
|
||||
async def get_latest_token(self) -> SecretStr | None:
|
||||
github_token = None
|
||||
if self.external_auth_token:
|
||||
github_token = SecretStr(
|
||||
await self.token_manager.get_idp_token(
|
||||
self.external_auth_token.get_secret_value(), ProviderType.GITHUB
|
||||
)
|
||||
)
|
||||
logger.debug(
|
||||
f'Got GitHub token {github_token} from access token: {self.external_auth_token}'
|
||||
)
|
||||
elif self.external_auth_id:
|
||||
offline_token = await self.token_manager.load_offline_token(
|
||||
self.external_auth_id
|
||||
)
|
||||
github_token = SecretStr(
|
||||
await self.token_manager.get_idp_token_from_offline_token(
|
||||
offline_token, ProviderType.GITHUB
|
||||
)
|
||||
)
|
||||
logger.debug(
|
||||
f'Got GitHub token {github_token} from external auth user ID: {self.external_auth_id}'
|
||||
)
|
||||
elif self.user_id:
|
||||
github_token = SecretStr(
|
||||
await self.token_manager.get_idp_token_from_idp_user_id(
|
||||
self.user_id, ProviderType.GITHUB
|
||||
)
|
||||
)
|
||||
logger.debug(
|
||||
f'Got GitHub token {github_token} from user ID: {self.user_id}'
|
||||
)
|
||||
else:
|
||||
logger.warning('external_auth_token and user_id not set!')
|
||||
return github_token
|
||||
|
||||
async def get_pr_patches(
|
||||
self, owner: str, repo: str, pr_number: int, per_page: int = 30, page: int = 1
|
||||
):
|
||||
"""Get patches for files changed in a PR with pagination support.
|
||||
|
||||
Args:
|
||||
owner: Repository owner
|
||||
repo: Repository name
|
||||
pr_number: Pull request number
|
||||
per_page: Number of files per page (default: 30, max: 100)
|
||||
page: Page number to fetch (default: 1)
|
||||
"""
|
||||
url = f'https://api.github.com/repos/{owner}/{repo}/pulls/{pr_number}/files'
|
||||
params = {'per_page': min(per_page, 100), 'page': page} # GitHub max is 100
|
||||
response, headers = await self._make_request(url, params)
|
||||
|
||||
# Parse pagination info from headers
|
||||
has_next_page = 'next' in headers.get('link', '')
|
||||
total_count = int(headers.get('total', 0))
|
||||
|
||||
return {
|
||||
'files': response,
|
||||
'pagination': {
|
||||
'has_next_page': has_next_page,
|
||||
'total_count': total_count,
|
||||
'current_page': page,
|
||||
'per_page': per_page,
|
||||
},
|
||||
}
|
||||
|
||||
async def get_repository_node_id(self, repo_id: str) -> str:
|
||||
"""
|
||||
Get the new GitHub GraphQL node ID for a repository using REST API.
|
||||
|
||||
Args:
|
||||
repo_id: Numeric repository ID as string (e.g., "123456789")
|
||||
|
||||
Returns:
|
||||
New format node ID for GraphQL queries (e.g., "R_kgDOLfkiww")
|
||||
|
||||
Raises:
|
||||
Exception: If the API request fails or node_id is not found
|
||||
"""
|
||||
url = f'https://api.github.com/repositories/{repo_id}'
|
||||
response, _ = await self._make_request(url)
|
||||
node_id = response.get('node_id')
|
||||
if not node_id:
|
||||
raise Exception(f'No node_id found for repository {repo_id}')
|
||||
return node_id
|
||||
|
||||
async def get_paginated_repos(self, page, per_page, sort, installation_id):
|
||||
repositories = await super().get_paginated_repos(
|
||||
page, per_page, sort, installation_id
|
||||
)
|
||||
asyncio.create_task(
|
||||
store_repositories_in_db(repositories, self.external_auth_id)
|
||||
)
|
||||
return repositories
|
||||
|
||||
async def get_all_repositories(
|
||||
self, sort: str, app_mode: AppMode
|
||||
) -> list[Repository]:
|
||||
repositories = await super().get_all_repositories(sort, app_mode)
|
||||
# Schedule the background task without awaiting it
|
||||
asyncio.create_task(
|
||||
store_repositories_in_db(repositories, self.external_auth_id)
|
||||
)
|
||||
# Return repositories immediately
|
||||
return repositories
|
||||
183
enterprise/integrations/github/github_solvability.py
Normal file
183
enterprise/integrations/github/github_solvability.py
Normal file
@@ -0,0 +1,183 @@
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
from github import Github
|
||||
from integrations.github.github_view import (
|
||||
GithubInlinePRComment,
|
||||
GithubIssueComment,
|
||||
GithubPRComment,
|
||||
GithubViewType,
|
||||
)
|
||||
from integrations.solvability.data import load_classifier
|
||||
from integrations.solvability.models.report import SolvabilityReport
|
||||
from integrations.solvability.models.summary import SolvabilitySummary
|
||||
from integrations.utils import ENABLE_SOLVABILITY_ANALYSIS
|
||||
from pydantic import ValidationError
|
||||
from server.auth.token_manager import get_config
|
||||
from storage.database import session_maker
|
||||
from storage.saas_settings_store import SaasSettingsStore
|
||||
|
||||
from openhands.core.config import LLMConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
from openhands.utils.utils import create_registry_and_conversation_stats
|
||||
|
||||
|
||||
def fetch_github_issue_context(
|
||||
github_view: GithubViewType,
|
||||
user_token: str,
|
||||
) -> str:
|
||||
"""Fetch full GitHub issue/PR context including title, body, and comments.
|
||||
|
||||
Args:
|
||||
full_repo_name: Full repository name in the format 'owner/repo'
|
||||
issue_number: The issue or PR number
|
||||
user_token: GitHub user access token
|
||||
max_comments: Maximum number of comments to fetch (default: 10)
|
||||
max_comment_length: Maximum length of each comment to include in the context (default: 500)
|
||||
|
||||
Returns:
|
||||
A comprehensive string containing the issue/PR context
|
||||
"""
|
||||
|
||||
# Build context string
|
||||
context_parts = []
|
||||
|
||||
# Add title and body
|
||||
context_parts.append(f'Title: {github_view.title}')
|
||||
context_parts.append(f'Description:\n{github_view.description}')
|
||||
|
||||
with Github(user_token) as github_client:
|
||||
repo = github_client.get_repo(github_view.full_repo_name)
|
||||
issue = repo.get_issue(github_view.issue_number)
|
||||
if issue.labels:
|
||||
labels = [label.name for label in issue.labels]
|
||||
context_parts.append(f"Labels: {', '.join(labels)}")
|
||||
|
||||
for comment in github_view.previous_comments:
|
||||
context_parts.append(f'- {comment.author}: {comment.body}')
|
||||
|
||||
return '\n\n'.join(context_parts)
|
||||
|
||||
|
||||
async def summarize_issue_solvability(
|
||||
github_view: GithubViewType,
|
||||
user_token: str,
|
||||
timeout: float = 60.0 * 5,
|
||||
) -> str:
|
||||
"""Generate a solvability summary for an issue using the resolver view interface.
|
||||
|
||||
Args:
|
||||
resolver_view: A resolver view interface instance (e.g., GithubIssue, GithubPRComment)
|
||||
user_token: GitHub user access token for API access
|
||||
timeout: Maximum time in seconds to wait for the result (default: 60.0)
|
||||
|
||||
Returns:
|
||||
The solvability summary as a string
|
||||
|
||||
Raises:
|
||||
ValueError: If LLM settings cannot be found for the user
|
||||
asyncio.TimeoutError: If the operation exceeds the specified timeout
|
||||
"""
|
||||
if not ENABLE_SOLVABILITY_ANALYSIS:
|
||||
raise ValueError('Solvability report feature is disabled')
|
||||
|
||||
if github_view.user_info.keycloak_user_id is None:
|
||||
raise ValueError(
|
||||
f'[Solvability] No user ID found for user {github_view.user_info.username}'
|
||||
)
|
||||
|
||||
# Grab the user's information so we can load their LLM configuration
|
||||
store = SaasSettingsStore(
|
||||
user_id=github_view.user_info.keycloak_user_id,
|
||||
session_maker=session_maker,
|
||||
config=get_config(),
|
||||
)
|
||||
|
||||
user_settings = await store.load()
|
||||
|
||||
if user_settings is None:
|
||||
raise ValueError(
|
||||
f'[Solvability] No user settings found for user ID {github_view.user_info.user_id}'
|
||||
)
|
||||
|
||||
# Check if solvability analysis is enabled for this user, exit early if
|
||||
# needed
|
||||
if not getattr(user_settings, 'enable_solvability_analysis', False):
|
||||
raise ValueError(
|
||||
f'Solvability analysis disabled for user {github_view.user_info.user_id}'
|
||||
)
|
||||
|
||||
try:
|
||||
llm_config = LLMConfig(
|
||||
model=user_settings.llm_model,
|
||||
api_key=user_settings.llm_api_key.get_secret_value(),
|
||||
base_url=user_settings.llm_base_url,
|
||||
)
|
||||
except ValidationError as e:
|
||||
raise ValueError(
|
||||
f'[Solvability] Invalid LLM configuration for user {github_view.user_info.user_id}: {str(e)}'
|
||||
)
|
||||
|
||||
# Fetch the full GitHub issue/PR context using the GitHub API
|
||||
start_time = time.time()
|
||||
issue_context = fetch_github_issue_context(github_view, user_token)
|
||||
logger.info(
|
||||
f'[Solvability] Grabbed issue context for {github_view.conversation_id}',
|
||||
extra={
|
||||
'conversation_id': github_view.conversation_id,
|
||||
'response_latency': time.time() - start_time,
|
||||
'full_repo_name': github_view.full_repo_name,
|
||||
'issue_number': github_view.issue_number,
|
||||
},
|
||||
)
|
||||
|
||||
# For comment-based triggers, also include the specific comment that triggered the action
|
||||
if isinstance(
|
||||
github_view, (GithubIssueComment, GithubPRComment, GithubInlinePRComment)
|
||||
):
|
||||
issue_context += f'\n\nTriggering Comment:\n{github_view.comment_body}'
|
||||
|
||||
solvability_classifier = load_classifier('default-classifier')
|
||||
|
||||
async with asyncio.timeout(timeout):
|
||||
solvability_report: SolvabilityReport = await call_sync_from_async(
|
||||
lambda: solvability_classifier.solvability_report(
|
||||
issue_context, llm_config=llm_config
|
||||
)
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f'[Solvability] Generated report for {github_view.conversation_id}',
|
||||
extra={
|
||||
'conversation_id': github_view.conversation_id,
|
||||
'report': solvability_report.model_dump(exclude=['issue']),
|
||||
},
|
||||
)
|
||||
|
||||
llm_registry, conversation_stats, _ = create_registry_and_conversation_stats(
|
||||
get_config(),
|
||||
github_view.conversation_id,
|
||||
github_view.user_info.keycloak_user_id,
|
||||
None,
|
||||
)
|
||||
|
||||
solvability_summary = await call_sync_from_async(
|
||||
lambda: SolvabilitySummary.from_report(
|
||||
solvability_report,
|
||||
llm=llm_registry.get_llm(
|
||||
service_id='solvability_analysis', config=llm_config
|
||||
),
|
||||
)
|
||||
)
|
||||
conversation_stats.save_metrics()
|
||||
|
||||
logger.info(
|
||||
f'[Solvability] Generated summary for {github_view.conversation_id}',
|
||||
extra={
|
||||
'conversation_id': github_view.conversation_id,
|
||||
'summary': solvability_summary.model_dump(exclude=['content']),
|
||||
},
|
||||
)
|
||||
|
||||
return solvability_summary.format_as_markdown()
|
||||
26
enterprise/integrations/github/github_types.py
Normal file
26
enterprise/integrations/github/github_types.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class WorkflowRunStatus(Enum):
|
||||
FAILURE = 'failure'
|
||||
COMPLETED = 'completed'
|
||||
PENDING = 'pending'
|
||||
|
||||
def __eq__(self, other):
|
||||
if isinstance(other, str):
|
||||
return self.value == other
|
||||
return super().__eq__(other)
|
||||
|
||||
|
||||
class WorkflowRun(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
status: WorkflowRunStatus
|
||||
|
||||
model_config = {'use_enum_values': True}
|
||||
|
||||
|
||||
class WorkflowRunGroup(BaseModel):
|
||||
runs: dict[str, WorkflowRun]
|
||||
756
enterprise/integrations/github/github_view.py
Normal file
756
enterprise/integrations/github/github_view.py
Normal file
@@ -0,0 +1,756 @@
|
||||
from uuid import uuid4
|
||||
|
||||
from github import Github, GithubIntegration
|
||||
from github.Issue import Issue
|
||||
from integrations.github.github_types import (
|
||||
WorkflowRun,
|
||||
WorkflowRunGroup,
|
||||
WorkflowRunStatus,
|
||||
)
|
||||
from integrations.models import Message
|
||||
from integrations.types import ResolverViewInterface, UserData
|
||||
from integrations.utils import (
|
||||
ENABLE_PROACTIVE_CONVERSATION_STARTERS,
|
||||
HOST,
|
||||
HOST_URL,
|
||||
get_oh_labels,
|
||||
has_exact_mention,
|
||||
)
|
||||
from jinja2 import Environment
|
||||
from pydantic.dataclasses import dataclass
|
||||
from server.auth.constants import GITHUB_APP_CLIENT_ID, GITHUB_APP_PRIVATE_KEY
|
||||
from server.auth.token_manager import TokenManager, get_config
|
||||
from storage.database import session_maker
|
||||
from storage.proactive_conversation_store import ProactiveConversationStore
|
||||
from storage.saas_secrets_store import SaasSecretsStore
|
||||
from storage.user_settings import UserSettings
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.github.github_service import GithubServiceImpl
|
||||
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE, ProviderType
|
||||
from openhands.integrations.service_types import Comment
|
||||
from openhands.server.services.conversation_service import (
|
||||
initialize_conversation,
|
||||
start_conversation,
|
||||
)
|
||||
from openhands.storage.data_models.conversation_metadata import (
|
||||
ConversationMetadata,
|
||||
ConversationTrigger,
|
||||
)
|
||||
from openhands.utils.async_utils import call_sync_from_async
|
||||
|
||||
OH_LABEL, INLINE_OH_LABEL = get_oh_labels(HOST)
|
||||
|
||||
|
||||
async def get_user_proactive_conversation_setting(user_id: str | None) -> bool:
|
||||
"""Get the user's proactive conversation setting.
|
||||
|
||||
Args:
|
||||
user_id: The keycloak user ID
|
||||
|
||||
Returns:
|
||||
True if proactive conversations are enabled for this user, False otherwise
|
||||
|
||||
Note:
|
||||
This function checks both the global environment variable kill switch AND
|
||||
the user's individual setting. Both must be true for the function to return true.
|
||||
"""
|
||||
|
||||
# If no user ID is provided, we can't check user settings
|
||||
if not user_id:
|
||||
return False
|
||||
|
||||
def _get_setting():
|
||||
with session_maker() as session:
|
||||
settings = (
|
||||
session.query(UserSettings)
|
||||
.filter(UserSettings.keycloak_user_id == user_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not settings or settings.enable_proactive_conversation_starters is None:
|
||||
return False
|
||||
|
||||
return settings.enable_proactive_conversation_starters
|
||||
|
||||
return await call_sync_from_async(_get_setting)
|
||||
|
||||
|
||||
# =================================================
|
||||
# SECTION: Github view types
|
||||
# =================================================
|
||||
|
||||
|
||||
@dataclass
|
||||
class GithubIssue(ResolverViewInterface):
|
||||
issue_number: int
|
||||
installation_id: int
|
||||
full_repo_name: str
|
||||
is_public_repo: bool
|
||||
user_info: UserData
|
||||
raw_payload: Message
|
||||
conversation_id: str
|
||||
uuid: str | None
|
||||
should_extract: bool
|
||||
send_summary_instruction: bool
|
||||
title: str
|
||||
description: str
|
||||
previous_comments: list[Comment]
|
||||
|
||||
async def _load_resolver_context(self):
|
||||
github_service = GithubServiceImpl(
|
||||
external_auth_id=self.user_info.keycloak_user_id
|
||||
)
|
||||
|
||||
self.previous_comments = await github_service.get_issue_or_pr_comments(
|
||||
self.full_repo_name, self.issue_number
|
||||
)
|
||||
|
||||
(
|
||||
self.title,
|
||||
self.description,
|
||||
) = await github_service.get_issue_or_pr_title_and_body(
|
||||
self.full_repo_name, self.issue_number
|
||||
)
|
||||
|
||||
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
user_instructions_template = jinja_env.get_template('issue_prompt.j2')
|
||||
|
||||
user_instructions = user_instructions_template.render(
|
||||
issue_number=self.issue_number,
|
||||
)
|
||||
|
||||
await self._load_resolver_context()
|
||||
|
||||
conversation_instructions_template = jinja_env.get_template(
|
||||
'issue_conversation_instructions.j2'
|
||||
)
|
||||
conversation_instructions = conversation_instructions_template.render(
|
||||
issue_title=self.title,
|
||||
issue_body=self.description,
|
||||
previous_comments=self.previous_comments,
|
||||
)
|
||||
return user_instructions, conversation_instructions
|
||||
|
||||
async def _get_user_secrets(self):
|
||||
secrets_store = SaasSecretsStore(
|
||||
self.user_info.keycloak_user_id, session_maker, get_config()
|
||||
)
|
||||
user_secrets = await secrets_store.load()
|
||||
|
||||
return user_secrets.custom_secrets if user_secrets else None
|
||||
|
||||
async def initialize_new_conversation(self) -> ConversationMetadata:
|
||||
# FIXME: Handle if initialize_conversation returns None
|
||||
conversation_metadata: ConversationMetadata = await initialize_conversation( # type: ignore[assignment]
|
||||
user_id=self.user_info.keycloak_user_id,
|
||||
conversation_id=None,
|
||||
selected_repository=self.full_repo_name,
|
||||
selected_branch=None,
|
||||
conversation_trigger=ConversationTrigger.RESOLVER,
|
||||
git_provider=ProviderType.GITHUB,
|
||||
)
|
||||
self.conversation_id = conversation_metadata.conversation_id
|
||||
return conversation_metadata
|
||||
|
||||
async def create_new_conversation(
|
||||
self,
|
||||
jinja_env: Environment,
|
||||
git_provider_tokens: PROVIDER_TOKEN_TYPE,
|
||||
conversation_metadata: ConversationMetadata,
|
||||
):
|
||||
custom_secrets = await self._get_user_secrets()
|
||||
|
||||
user_instructions, conversation_instructions = await self._get_instructions(
|
||||
jinja_env
|
||||
)
|
||||
|
||||
await start_conversation(
|
||||
user_id=self.user_info.keycloak_user_id,
|
||||
git_provider_tokens=git_provider_tokens,
|
||||
custom_secrets=custom_secrets,
|
||||
initial_user_msg=user_instructions,
|
||||
image_urls=None,
|
||||
replay_json=None,
|
||||
conversation_id=conversation_metadata.conversation_id,
|
||||
conversation_metadata=conversation_metadata,
|
||||
conversation_instructions=conversation_instructions,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GithubIssueComment(GithubIssue):
|
||||
comment_body: str
|
||||
comment_id: int
|
||||
|
||||
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
user_instructions_template = jinja_env.get_template('issue_prompt.j2')
|
||||
|
||||
await self._load_resolver_context()
|
||||
|
||||
user_instructions = user_instructions_template.render(
|
||||
issue_comment=self.comment_body
|
||||
)
|
||||
|
||||
conversation_instructions_template = jinja_env.get_template(
|
||||
'issue_conversation_instructions.j2'
|
||||
)
|
||||
|
||||
conversation_instructions = conversation_instructions_template.render(
|
||||
issue_number=self.issue_number,
|
||||
issue_title=self.title,
|
||||
issue_body=self.description,
|
||||
previous_comments=self.previous_comments,
|
||||
)
|
||||
|
||||
return user_instructions, conversation_instructions
|
||||
|
||||
|
||||
@dataclass
|
||||
class GithubPRComment(GithubIssueComment):
|
||||
branch_name: str
|
||||
|
||||
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
user_instructions_template = jinja_env.get_template('pr_update_prompt.j2')
|
||||
await self._load_resolver_context()
|
||||
|
||||
user_instructions = user_instructions_template.render(
|
||||
pr_comment=self.comment_body,
|
||||
)
|
||||
|
||||
conversation_instructions_template = jinja_env.get_template(
|
||||
'pr_update_conversation_instructions.j2'
|
||||
)
|
||||
conversation_instructions = conversation_instructions_template.render(
|
||||
pr_number=self.issue_number,
|
||||
branch_name=self.branch_name,
|
||||
pr_title=self.title,
|
||||
pr_body=self.description,
|
||||
comments=self.previous_comments,
|
||||
)
|
||||
|
||||
return user_instructions, conversation_instructions
|
||||
|
||||
async def initialize_new_conversation(self) -> ConversationMetadata:
|
||||
# FIXME: Handle if initialize_conversation returns None
|
||||
conversation_metadata: ConversationMetadata = await initialize_conversation( # type: ignore[assignment]
|
||||
user_id=self.user_info.keycloak_user_id,
|
||||
conversation_id=None,
|
||||
selected_repository=self.full_repo_name,
|
||||
selected_branch=self.branch_name,
|
||||
conversation_trigger=ConversationTrigger.RESOLVER,
|
||||
git_provider=ProviderType.GITHUB,
|
||||
)
|
||||
|
||||
self.conversation_id = conversation_metadata.conversation_id
|
||||
return conversation_metadata
|
||||
|
||||
|
||||
@dataclass
|
||||
class GithubInlinePRComment(GithubPRComment):
|
||||
file_location: str
|
||||
line_number: int
|
||||
comment_node_id: str
|
||||
|
||||
async def _load_resolver_context(self):
|
||||
github_service = GithubServiceImpl(
|
||||
external_auth_id=self.user_info.keycloak_user_id
|
||||
)
|
||||
|
||||
(
|
||||
self.title,
|
||||
self.description,
|
||||
) = await github_service.get_issue_or_pr_title_and_body(
|
||||
self.full_repo_name, self.issue_number
|
||||
)
|
||||
|
||||
self.previous_comments = await github_service.get_review_thread_comments(
|
||||
self.comment_node_id, self.full_repo_name, self.issue_number
|
||||
)
|
||||
|
||||
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
user_instructions_template = jinja_env.get_template('pr_update_prompt.j2')
|
||||
await self._load_resolver_context()
|
||||
|
||||
user_instructions = user_instructions_template.render(
|
||||
pr_comment=self.comment_body,
|
||||
)
|
||||
|
||||
conversation_instructions_template = jinja_env.get_template(
|
||||
'pr_update_conversation_instructions.j2'
|
||||
)
|
||||
|
||||
conversation_instructions = conversation_instructions_template.render(
|
||||
pr_number=self.issue_number,
|
||||
pr_title=self.title,
|
||||
pr_body=self.description,
|
||||
branch_name=self.branch_name,
|
||||
file_location=self.file_location,
|
||||
line_number=self.line_number,
|
||||
comments=self.previous_comments,
|
||||
)
|
||||
|
||||
return user_instructions, conversation_instructions
|
||||
|
||||
|
||||
@dataclass
|
||||
class GithubFailingAction:
|
||||
unqiue_suggestions_header: str = (
|
||||
'Looks like there are a few issues preventing this PR from being merged!'
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_latest_sha(pr: Issue) -> str:
|
||||
pr_obj = pr.as_pull_request()
|
||||
return pr_obj.head.sha
|
||||
|
||||
@staticmethod
|
||||
def create_retrieve_workflows_callback(pr: Issue, head_sha: str):
|
||||
def get_all_workflows():
|
||||
repo = pr.repository
|
||||
workflows = repo.get_workflow_runs(head_sha=head_sha)
|
||||
|
||||
runs = {}
|
||||
|
||||
for workflow in workflows:
|
||||
conclusion = workflow.conclusion
|
||||
workflow_conclusion = WorkflowRunStatus.COMPLETED
|
||||
if conclusion is None:
|
||||
workflow_conclusion = WorkflowRunStatus.PENDING # type: ignore[unreachable]
|
||||
elif conclusion == WorkflowRunStatus.FAILURE.value:
|
||||
workflow_conclusion = WorkflowRunStatus.FAILURE
|
||||
|
||||
runs[str(workflow.id)] = WorkflowRun(
|
||||
id=str(workflow.id), name=workflow.name, status=workflow_conclusion
|
||||
)
|
||||
|
||||
return WorkflowRunGroup(runs=runs)
|
||||
|
||||
return get_all_workflows
|
||||
|
||||
@staticmethod
|
||||
def delete_old_comment_if_exists(pr: Issue):
|
||||
paginated_comments = pr.get_comments()
|
||||
for page in range(paginated_comments.totalCount):
|
||||
comments = paginated_comments.get_page(page)
|
||||
for comment in comments:
|
||||
if GithubFailingAction.unqiue_suggestions_header in comment.body:
|
||||
comment.delete()
|
||||
|
||||
@staticmethod
|
||||
def get_suggestions(
|
||||
failed_jobs: dict, pr_number: int, branch_name: str | None = None
|
||||
) -> str:
|
||||
issues = []
|
||||
|
||||
# Collect failing actions with their specific names
|
||||
if failed_jobs['actions']:
|
||||
failing_actions = failed_jobs['actions']
|
||||
issues.append(('GitHub Actions are failing:', False))
|
||||
for action in failing_actions:
|
||||
issues.append((action, True))
|
||||
|
||||
if any(failed_jobs['merge conflict']):
|
||||
issues.append(('There are merge conflicts', False))
|
||||
|
||||
# Format each line with proper indentation and dashes
|
||||
formatted_issues = []
|
||||
for issue, is_nested in issues:
|
||||
if is_nested:
|
||||
formatted_issues.append(f' - {issue}')
|
||||
else:
|
||||
formatted_issues.append(f'- {issue}')
|
||||
issues_text = '\n'.join(formatted_issues)
|
||||
|
||||
# Build list of possible suggestions based on actual issues
|
||||
suggestions = []
|
||||
branch_info = f' at branch `{branch_name}`' if branch_name else ''
|
||||
|
||||
if any(failed_jobs['merge conflict']):
|
||||
suggestions.append(
|
||||
f'@OpenHands please fix the merge conflicts on PR #{pr_number}{branch_info}'
|
||||
)
|
||||
if any(failed_jobs['actions']):
|
||||
suggestions.append(
|
||||
f'@OpenHands please fix the failing actions on PR #{pr_number}{branch_info}'
|
||||
)
|
||||
|
||||
# Take at most 2 suggestions
|
||||
suggestions = suggestions[:2]
|
||||
|
||||
help_text = """If you'd like me to help, just leave a comment, like
|
||||
|
||||
```
|
||||
{}
|
||||
```
|
||||
|
||||
Feel free to include any additional details that might help me get this PR into a better state.
|
||||
|
||||
<sub><sup>You can manage your notification [settings]({})</sup></sub>""".format(
|
||||
'\n```\n\nor\n\n```\n'.join(suggestions), f'{HOST_URL}/settings/app'
|
||||
)
|
||||
|
||||
return f'{GithubFailingAction.unqiue_suggestions_header}\n\n{issues_text}\n\n{help_text}'
|
||||
|
||||
@staticmethod
|
||||
def leave_requesting_comment(pr: Issue, failed_runs: WorkflowRunGroup):
|
||||
failed_jobs: dict = {'actions': [], 'merge conflict': []}
|
||||
|
||||
pr_obj = pr.as_pull_request()
|
||||
if not pr_obj.mergeable:
|
||||
failed_jobs['merge conflict'].append('Merge conflict detected')
|
||||
|
||||
for _, workflow_run in failed_runs.runs.items():
|
||||
if workflow_run.status == WorkflowRunStatus.FAILURE:
|
||||
failed_jobs['actions'].append(workflow_run.name)
|
||||
|
||||
logger.info(f'[GitHub] Found failing jobs for PR #{pr.number}: {failed_jobs}')
|
||||
|
||||
# Get the branch name
|
||||
branch_name = pr_obj.head.ref
|
||||
|
||||
# Get suggestions with branch name included
|
||||
suggestions = GithubFailingAction.get_suggestions(
|
||||
failed_jobs, pr.number, branch_name
|
||||
)
|
||||
|
||||
GithubFailingAction.delete_old_comment_if_exists(pr)
|
||||
pr.create_comment(suggestions)
|
||||
|
||||
|
||||
GithubViewType = (
|
||||
GithubInlinePRComment | GithubPRComment | GithubIssueComment | GithubIssue
|
||||
)
|
||||
|
||||
|
||||
# =================================================
|
||||
# SECTION: Factory to create appriorate Github view
|
||||
# =================================================
|
||||
|
||||
|
||||
class GithubFactory:
|
||||
@staticmethod
|
||||
def is_labeled_issue(message: Message):
|
||||
payload = message.message.get('payload', {})
|
||||
action = payload.get('action', '')
|
||||
|
||||
if action == 'labeled' and 'label' in payload and 'issue' in payload:
|
||||
label_name = payload['label'].get('name', '')
|
||||
if label_name == OH_LABEL:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def is_issue_comment(message: Message):
|
||||
payload = message.message.get('payload', {})
|
||||
action = payload.get('action', '')
|
||||
|
||||
if (
|
||||
action == 'created'
|
||||
and 'comment' in payload
|
||||
and 'issue' in payload
|
||||
and 'pull_request' not in payload['issue']
|
||||
):
|
||||
comment_body = payload['comment']['body']
|
||||
if has_exact_mention(comment_body, INLINE_OH_LABEL):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def is_pr_comment(message: Message):
|
||||
payload = message.message.get('payload', {})
|
||||
action = payload.get('action', '')
|
||||
|
||||
if (
|
||||
action == 'created'
|
||||
and 'comment' in payload
|
||||
and 'issue' in payload
|
||||
and 'pull_request' in payload['issue']
|
||||
):
|
||||
comment_body = payload['comment'].get('body', '')
|
||||
if has_exact_mention(comment_body, INLINE_OH_LABEL):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def is_inline_pr_comment(message: Message):
|
||||
payload = message.message.get('payload', {})
|
||||
action = payload.get('action', '')
|
||||
|
||||
if action == 'created' and 'comment' in payload and 'pull_request' in payload:
|
||||
comment_body = payload['comment'].get('body', '')
|
||||
if has_exact_mention(comment_body, INLINE_OH_LABEL):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def is_eligible_for_conversation_starter(message: Message):
|
||||
if not ENABLE_PROACTIVE_CONVERSATION_STARTERS:
|
||||
return False
|
||||
|
||||
payload = message.message.get('payload', {})
|
||||
action = payload.get('action', '')
|
||||
|
||||
if not (action == 'completed' and 'workflow_run' in payload):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
async def trigger_conversation_starter(message: Message):
|
||||
"""Trigger a conversation starter when a workflow fails.
|
||||
|
||||
This is the updated version that checks user settings.
|
||||
"""
|
||||
payload = message.message.get('payload', {})
|
||||
workflow_payload = payload['workflow_run']
|
||||
status = WorkflowRunStatus.COMPLETED
|
||||
|
||||
if workflow_payload['conclusion'] == 'failure':
|
||||
status = WorkflowRunStatus.FAILURE
|
||||
elif workflow_payload['conclusion'] is None:
|
||||
status = WorkflowRunStatus.PENDING
|
||||
|
||||
workflow_run = WorkflowRun(
|
||||
id=str(workflow_payload['id']), name=workflow_payload['name'], status=status
|
||||
)
|
||||
|
||||
selected_repo = GithubFactory.get_full_repo_name(payload['repository'])
|
||||
head_branch = payload['workflow_run']['head_branch']
|
||||
|
||||
# Get the user ID to check their settings
|
||||
user_id = None
|
||||
try:
|
||||
sender_id = payload['sender']['id']
|
||||
token_manager = TokenManager()
|
||||
user_id = await token_manager.get_user_id_from_idp_user_id(
|
||||
sender_id, ProviderType.GITHUB
|
||||
)
|
||||
except (KeyError, Exception) as e:
|
||||
logger.warning(
|
||||
f'Failed to get user ID for proactive conversation check: {str(e)}'
|
||||
)
|
||||
|
||||
# Check if proactive conversations are enabled for this user
|
||||
if not await get_user_proactive_conversation_setting(user_id):
|
||||
return False
|
||||
|
||||
def _interact_with_github() -> Issue | None:
|
||||
with GithubIntegration(
|
||||
GITHUB_APP_CLIENT_ID, GITHUB_APP_PRIVATE_KEY
|
||||
) as integration:
|
||||
access_token = integration.get_access_token(
|
||||
payload['installation']['id']
|
||||
).token
|
||||
|
||||
with Github(access_token) as gh:
|
||||
repo = gh.get_repo(selected_repo)
|
||||
login = (
|
||||
payload['organization']['login']
|
||||
if 'organization' in payload
|
||||
else payload['sender']['login']
|
||||
)
|
||||
|
||||
# See if a pull request is open
|
||||
open_pulls = repo.get_pulls(state='open', head=f'{login}:{head_branch}')
|
||||
if open_pulls.totalCount > 0:
|
||||
prs = open_pulls.get_page(0)
|
||||
relevant_pr = prs[0]
|
||||
issue = repo.get_issue(number=relevant_pr.number)
|
||||
return issue
|
||||
|
||||
return None
|
||||
|
||||
issue: Issue | None = await call_sync_from_async(_interact_with_github)
|
||||
if not issue:
|
||||
return False
|
||||
|
||||
incoming_commit = payload['workflow_run']['head_sha']
|
||||
latest_sha = GithubFailingAction.get_latest_sha(issue)
|
||||
if latest_sha != incoming_commit:
|
||||
# Return as this commit is not the latest
|
||||
return False
|
||||
|
||||
convo_store = ProactiveConversationStore()
|
||||
workflow_group = await convo_store.store_workflow_information(
|
||||
provider=ProviderType.GITHUB,
|
||||
repo_id=payload['repository']['id'],
|
||||
incoming_commit=incoming_commit,
|
||||
workflow=workflow_run,
|
||||
pr_number=issue.number,
|
||||
get_all_workflows=GithubFailingAction.create_retrieve_workflows_callback(
|
||||
issue, incoming_commit
|
||||
),
|
||||
)
|
||||
|
||||
if not workflow_group:
|
||||
return False
|
||||
|
||||
logger.info(
|
||||
f'[GitHub] Workflow completed for {selected_repo}#{issue.number} on branch {head_branch}'
|
||||
)
|
||||
GithubFailingAction.leave_requesting_comment(issue, workflow_group)
|
||||
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def get_full_repo_name(repo_obj: dict) -> str:
|
||||
owner = repo_obj['owner']['login']
|
||||
repo_name = repo_obj['name']
|
||||
return f'{owner}/{repo_name}'
|
||||
|
||||
@staticmethod
|
||||
async def create_github_view_from_payload(
|
||||
message: Message, token_manager: TokenManager
|
||||
) -> ResolverViewInterface:
|
||||
"""Create the appropriate class (GithubIssue or GithubPRComment) based on the payload.
|
||||
Also return metadata about the event (e.g., action type).
|
||||
"""
|
||||
payload = message.message.get('payload', {})
|
||||
repo_obj = payload['repository']
|
||||
user_id = payload['sender']['id']
|
||||
username = payload['sender']['login']
|
||||
|
||||
keyloak_user_id = await token_manager.get_user_id_from_idp_user_id(
|
||||
user_id, ProviderType.GITHUB
|
||||
)
|
||||
|
||||
if keyloak_user_id is None:
|
||||
logger.warning(f'Got invalid keyloak user id for GitHub User {user_id} ')
|
||||
|
||||
selected_repo = GithubFactory.get_full_repo_name(repo_obj)
|
||||
is_public_repo = not repo_obj.get('private', True)
|
||||
user_info = UserData(
|
||||
user_id=user_id, username=username, keycloak_user_id=keyloak_user_id
|
||||
)
|
||||
|
||||
installation_id = message.message['installation']
|
||||
|
||||
if GithubFactory.is_labeled_issue(message):
|
||||
issue_number = payload['issue']['number']
|
||||
logger.info(
|
||||
f'[GitHub] Creating view for labeled issue from {username} in {selected_repo}#{issue_number}'
|
||||
)
|
||||
return GithubIssue(
|
||||
issue_number=issue_number,
|
||||
installation_id=installation_id,
|
||||
full_repo_name=selected_repo,
|
||||
is_public_repo=is_public_repo,
|
||||
raw_payload=message,
|
||||
user_info=user_info,
|
||||
conversation_id='',
|
||||
uuid=str(uuid4()),
|
||||
should_extract=True,
|
||||
send_summary_instruction=True,
|
||||
title='',
|
||||
description='',
|
||||
previous_comments=[],
|
||||
)
|
||||
|
||||
elif GithubFactory.is_issue_comment(message):
|
||||
issue_number = payload['issue']['number']
|
||||
comment_body = payload['comment']['body']
|
||||
comment_id = payload['comment']['id']
|
||||
logger.info(
|
||||
f'[GitHub] Creating view for issue comment from {username} in {selected_repo}#{issue_number}'
|
||||
)
|
||||
return GithubIssueComment(
|
||||
issue_number=issue_number,
|
||||
comment_body=comment_body,
|
||||
comment_id=comment_id,
|
||||
installation_id=installation_id,
|
||||
full_repo_name=selected_repo,
|
||||
is_public_repo=is_public_repo,
|
||||
raw_payload=message,
|
||||
user_info=user_info,
|
||||
conversation_id='',
|
||||
uuid=None,
|
||||
should_extract=True,
|
||||
send_summary_instruction=True,
|
||||
title='',
|
||||
description='',
|
||||
previous_comments=[],
|
||||
)
|
||||
|
||||
elif GithubFactory.is_pr_comment(message):
|
||||
issue_number = payload['issue']['number']
|
||||
logger.info(
|
||||
f'[GitHub] Creating view for PR comment from {username} in {selected_repo}#{issue_number}'
|
||||
)
|
||||
|
||||
access_token = ''
|
||||
with GithubIntegration(
|
||||
GITHUB_APP_CLIENT_ID, GITHUB_APP_PRIVATE_KEY
|
||||
) as integration:
|
||||
access_token = integration.get_access_token(installation_id).token
|
||||
|
||||
head_ref = None
|
||||
with Github(access_token) as gh:
|
||||
repo = gh.get_repo(selected_repo)
|
||||
pull_request = repo.get_pull(issue_number)
|
||||
head_ref = pull_request.head.ref
|
||||
logger.info(
|
||||
f'[GitHub] Found PR branch {head_ref} for {selected_repo}#{issue_number}'
|
||||
)
|
||||
|
||||
comment_id = payload['comment']['id']
|
||||
return GithubPRComment(
|
||||
issue_number=issue_number,
|
||||
branch_name=head_ref,
|
||||
comment_body=payload['comment']['body'],
|
||||
comment_id=comment_id,
|
||||
installation_id=installation_id,
|
||||
full_repo_name=selected_repo,
|
||||
is_public_repo=is_public_repo,
|
||||
raw_payload=message,
|
||||
user_info=user_info,
|
||||
conversation_id='',
|
||||
uuid=None,
|
||||
should_extract=True,
|
||||
send_summary_instruction=True,
|
||||
title='',
|
||||
description='',
|
||||
previous_comments=[],
|
||||
)
|
||||
|
||||
elif GithubFactory.is_inline_pr_comment(message):
|
||||
pr_number = payload['pull_request']['number']
|
||||
branch_name = payload['pull_request']['head']['ref']
|
||||
comment_id = payload['comment']['id']
|
||||
comment_node_id = payload['comment']['node_id']
|
||||
file_path = payload['comment']['path']
|
||||
line_number = payload['comment']['line']
|
||||
logger.info(
|
||||
f'[GitHub] Creating view for inline PR comment from {username} in {selected_repo}#{pr_number} at {file_path}'
|
||||
)
|
||||
|
||||
return GithubInlinePRComment(
|
||||
issue_number=pr_number,
|
||||
branch_name=branch_name,
|
||||
comment_body=payload['comment']['body'],
|
||||
comment_node_id=comment_node_id,
|
||||
comment_id=comment_id,
|
||||
file_location=file_path,
|
||||
line_number=line_number,
|
||||
installation_id=installation_id,
|
||||
full_repo_name=selected_repo,
|
||||
is_public_repo=is_public_repo,
|
||||
raw_payload=message,
|
||||
user_info=user_info,
|
||||
conversation_id='',
|
||||
uuid=None,
|
||||
should_extract=True,
|
||||
send_summary_instruction=True,
|
||||
title='',
|
||||
description='',
|
||||
previous_comments=[],
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
"Invalid payload: must contain either 'issue' or 'pull_request'"
|
||||
)
|
||||
102
enterprise/integrations/github/queries.py
Normal file
102
enterprise/integrations/github/queries.py
Normal file
@@ -0,0 +1,102 @@
|
||||
PR_QUERY_BY_NODE_ID = """
|
||||
query($nodeId: ID!, $pr_number: Int!, $commits_after: String, $comments_after: String, $reviews_after: String) {
|
||||
node(id: $nodeId) {
|
||||
... on Repository {
|
||||
name
|
||||
owner {
|
||||
login
|
||||
}
|
||||
languages(first: 10, orderBy: {field: SIZE, direction: DESC}) {
|
||||
nodes {
|
||||
name
|
||||
}
|
||||
}
|
||||
pullRequest(number: $pr_number) {
|
||||
number
|
||||
title
|
||||
body
|
||||
author {
|
||||
login
|
||||
}
|
||||
merged
|
||||
mergedAt
|
||||
mergedBy {
|
||||
login
|
||||
}
|
||||
state
|
||||
mergeCommit {
|
||||
oid
|
||||
}
|
||||
comments(first: 50, after: $comments_after) {
|
||||
pageInfo {
|
||||
hasNextPage
|
||||
endCursor
|
||||
}
|
||||
nodes {
|
||||
author {
|
||||
login
|
||||
}
|
||||
body
|
||||
createdAt
|
||||
}
|
||||
}
|
||||
commits(first: 50, after: $commits_after) {
|
||||
pageInfo {
|
||||
hasNextPage
|
||||
endCursor
|
||||
}
|
||||
nodes {
|
||||
commit {
|
||||
oid
|
||||
message
|
||||
committedDate
|
||||
author {
|
||||
name
|
||||
email
|
||||
user {
|
||||
login
|
||||
}
|
||||
}
|
||||
additions
|
||||
deletions
|
||||
changedFiles
|
||||
}
|
||||
}
|
||||
}
|
||||
reviews(first: 50, after: $reviews_after) {
|
||||
pageInfo {
|
||||
hasNextPage
|
||||
endCursor
|
||||
}
|
||||
nodes {
|
||||
author {
|
||||
login
|
||||
}
|
||||
body
|
||||
state
|
||||
createdAt
|
||||
comments(first: 50) {
|
||||
pageInfo {
|
||||
hasNextPage
|
||||
endCursor
|
||||
}
|
||||
nodes {
|
||||
author {
|
||||
login
|
||||
}
|
||||
body
|
||||
createdAt
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
rateLimit {
|
||||
remaining
|
||||
limit
|
||||
resetAt
|
||||
}
|
||||
}
|
||||
"""
|
||||
261
enterprise/integrations/gitlab/gitlab_manager.py
Normal file
261
enterprise/integrations/gitlab/gitlab_manager.py
Normal file
@@ -0,0 +1,261 @@
|
||||
from types import MappingProxyType
|
||||
|
||||
from integrations.gitlab.gitlab_view import (
|
||||
GitlabFactory,
|
||||
GitlabInlineMRComment,
|
||||
GitlabIssue,
|
||||
GitlabIssueComment,
|
||||
GitlabMRComment,
|
||||
GitlabViewType,
|
||||
)
|
||||
from integrations.manager import Manager
|
||||
from integrations.models import Message, SourceType
|
||||
from integrations.types import ResolverViewInterface
|
||||
from integrations.utils import (
|
||||
CONVERSATION_URL,
|
||||
HOST_URL,
|
||||
OPENHANDS_RESOLVER_TEMPLATES_DIR,
|
||||
)
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
from pydantic import SecretStr
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.utils.conversation_callback_utils import register_callback_processor
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.gitlab.gitlab_service import GitLabServiceImpl
|
||||
from openhands.integrations.provider import ProviderToken, ProviderType
|
||||
from openhands.server.types import LLMAuthenticationError, MissingSettingsError
|
||||
from openhands.storage.data_models.user_secrets import UserSecrets
|
||||
|
||||
|
||||
class GitlabManager(Manager):
|
||||
def __init__(self, token_manager: TokenManager, data_collector: None = None):
|
||||
self.token_manager = token_manager
|
||||
|
||||
self.jinja_env = Environment(
|
||||
loader=FileSystemLoader(OPENHANDS_RESOLVER_TEMPLATES_DIR + 'gitlab')
|
||||
)
|
||||
|
||||
def _confirm_incoming_source_type(self, message: Message):
|
||||
if message.source != SourceType.GITLAB:
|
||||
raise ValueError(f'Unexpected message source {message.source}')
|
||||
|
||||
async def _user_has_write_access_to_repo(
|
||||
self, project_id: str, user_id: str
|
||||
) -> bool:
|
||||
"""
|
||||
Check if the user has write access to the repository (can pull/push changes and open merge requests).
|
||||
|
||||
Args:
|
||||
project_id: The ID of the GitLab project
|
||||
username: The username of the user
|
||||
user_id: The GitLab user ID
|
||||
|
||||
Returns:
|
||||
bool: True if the user has write access to the repository, False otherwise
|
||||
"""
|
||||
|
||||
keycloak_user_id = await self.token_manager.get_user_id_from_idp_user_id(
|
||||
user_id, ProviderType.GITLAB
|
||||
)
|
||||
if keycloak_user_id is None:
|
||||
logger.warning(f'Got invalid keyloak user id for GitLab User {user_id}')
|
||||
return False
|
||||
|
||||
# Importing here prevents circular import
|
||||
from integrations.gitlab.gitlab_service import SaaSGitLabService
|
||||
|
||||
gitlab_service: SaaSGitLabService = GitLabServiceImpl(
|
||||
external_auth_id=keycloak_user_id
|
||||
)
|
||||
|
||||
return await gitlab_service.user_has_write_access(project_id)
|
||||
|
||||
async def receive_message(self, message: Message):
|
||||
self._confirm_incoming_source_type(message)
|
||||
if await self.is_job_requested(message):
|
||||
gitlab_view = await GitlabFactory.create_gitlab_view_from_payload(
|
||||
message, self.token_manager
|
||||
)
|
||||
logger.info(
|
||||
f'[GitLab] Creating job for {gitlab_view.user_info.username} in {gitlab_view.full_repo_name}#{gitlab_view.issue_number}'
|
||||
)
|
||||
|
||||
await self.start_job(gitlab_view)
|
||||
|
||||
async def is_job_requested(self, message) -> bool:
|
||||
self._confirm_incoming_source_type(message)
|
||||
if not (
|
||||
GitlabFactory.is_labeled_issue(message)
|
||||
or GitlabFactory.is_issue_comment(message)
|
||||
or GitlabFactory.is_mr_comment(message)
|
||||
or GitlabFactory.is_mr_comment(message, inline=True)
|
||||
):
|
||||
return False
|
||||
|
||||
payload = message.message['payload']
|
||||
|
||||
repo_obj = payload['project']
|
||||
project_id = repo_obj['id']
|
||||
selected_project = repo_obj['path_with_namespace']
|
||||
user = payload['user']
|
||||
user_id = user['id']
|
||||
username = user['username']
|
||||
|
||||
logger.info(
|
||||
f'[GitLab] Checking permissions for {username} in {selected_project}'
|
||||
)
|
||||
|
||||
has_write_access = await self._user_has_write_access_to_repo(
|
||||
project_id=str(project_id), user_id=user_id
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f'[GitLab]: {username} access in {selected_project}: {has_write_access}'
|
||||
)
|
||||
# Check if the user has write access to the repository
|
||||
return has_write_access
|
||||
|
||||
async def send_message(self, message: Message, gitlab_view: ResolverViewInterface):
|
||||
"""
|
||||
Send a message to GitLab based on the view type.
|
||||
|
||||
Args:
|
||||
message: The message to send
|
||||
gitlab_view: The GitLab view object containing issue/PR/comment info
|
||||
"""
|
||||
keycloak_user_id = gitlab_view.user_info.keycloak_user_id
|
||||
|
||||
# Importing here prevents circular import
|
||||
from integrations.gitlab.gitlab_service import SaaSGitLabService
|
||||
|
||||
gitlab_service: SaaSGitLabService = GitLabServiceImpl(
|
||||
external_auth_id=keycloak_user_id
|
||||
)
|
||||
|
||||
outgoing_message = message.message
|
||||
|
||||
if isinstance(gitlab_view, GitlabInlineMRComment) or isinstance(
|
||||
gitlab_view, GitlabMRComment
|
||||
):
|
||||
await gitlab_service.reply_to_mr(
|
||||
gitlab_view.project_id,
|
||||
gitlab_view.issue_number,
|
||||
gitlab_view.discussion_id,
|
||||
message.message,
|
||||
)
|
||||
|
||||
elif isinstance(gitlab_view, GitlabIssueComment):
|
||||
await gitlab_service.reply_to_issue(
|
||||
gitlab_view.project_id,
|
||||
gitlab_view.issue_number,
|
||||
gitlab_view.discussion_id,
|
||||
outgoing_message,
|
||||
)
|
||||
elif isinstance(gitlab_view, GitlabIssue):
|
||||
await gitlab_service.reply_to_issue(
|
||||
gitlab_view.project_id,
|
||||
gitlab_view.issue_number,
|
||||
None, # no discussion id, issue is tagged
|
||||
outgoing_message,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f'[GitLab] Unsupported view type: {type(gitlab_view).__name__}'
|
||||
)
|
||||
|
||||
async def start_job(self, gitlab_view: GitlabViewType):
|
||||
"""
|
||||
Start a job for the GitLab view.
|
||||
|
||||
Args:
|
||||
gitlab_view: The GitLab view object containing issue/PR/comment info
|
||||
"""
|
||||
# Importing here prevents circular import
|
||||
from server.conversation_callback_processor.gitlab_callback_processor import (
|
||||
GitlabCallbackProcessor,
|
||||
)
|
||||
|
||||
try:
|
||||
try:
|
||||
user_info = gitlab_view.user_info
|
||||
|
||||
logger.info(
|
||||
f'[GitLab] Starting job for {user_info.username} in {gitlab_view.full_repo_name}#{gitlab_view.issue_number}'
|
||||
)
|
||||
|
||||
user_token = await self.token_manager.get_idp_token_from_idp_user_id(
|
||||
str(user_info.user_id), ProviderType.GITLAB
|
||||
)
|
||||
|
||||
if not user_token:
|
||||
logger.warning(
|
||||
f'[GitLab] No token found for user {user_info.username} (id={user_info.user_id})'
|
||||
)
|
||||
raise MissingSettingsError('Missing settings')
|
||||
|
||||
logger.info(
|
||||
f'[GitLab] Creating new conversation for user {user_info.username}'
|
||||
)
|
||||
|
||||
secret_store = UserSecrets(
|
||||
provider_tokens=MappingProxyType(
|
||||
{
|
||||
ProviderType.GITLAB: ProviderToken(
|
||||
token=SecretStr(user_token),
|
||||
user_id=str(user_info.user_id),
|
||||
)
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
await gitlab_view.create_new_conversation(
|
||||
self.jinja_env, secret_store.provider_tokens
|
||||
)
|
||||
|
||||
conversation_id = gitlab_view.conversation_id
|
||||
|
||||
logger.info(
|
||||
f'[GitLab] Created conversation {conversation_id} for user {user_info.username}'
|
||||
)
|
||||
|
||||
# Create a GitlabCallbackProcessor for this conversation
|
||||
processor = GitlabCallbackProcessor(
|
||||
gitlab_view=gitlab_view,
|
||||
send_summary_instruction=True,
|
||||
)
|
||||
|
||||
# Register the callback processor
|
||||
register_callback_processor(conversation_id, processor)
|
||||
|
||||
logger.info(
|
||||
f'[GitLab] Created callback processor for conversation {conversation_id}'
|
||||
)
|
||||
|
||||
conversation_link = CONVERSATION_URL.format(conversation_id)
|
||||
msg_info = f"I'm on it! {user_info.username} can [track my progress at all-hands.dev]({conversation_link})"
|
||||
|
||||
except MissingSettingsError as e:
|
||||
logger.warning(
|
||||
f'[GitLab] Missing settings error for user {user_info.username}: {str(e)}'
|
||||
)
|
||||
|
||||
msg_info = f'@{user_info.username} please re-login into [OpenHands Cloud]({HOST_URL}) before starting a job.'
|
||||
|
||||
except LLMAuthenticationError as e:
|
||||
logger.warning(
|
||||
f'[GitLab] LLM authentication error for user {user_info.username}: {str(e)}'
|
||||
)
|
||||
|
||||
msg_info = f'@{user_info.username} please set a valid LLM API key in [OpenHands Cloud]({HOST_URL}) before starting a job.'
|
||||
|
||||
# Send the acknowledgment message
|
||||
msg = self.create_outgoing_message(msg_info)
|
||||
await self.send_message(msg, gitlab_view)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f'[GitLab] Error starting job: {str(e)}')
|
||||
msg = self.create_outgoing_message(
|
||||
msg='Uh oh! There was an unexpected error starting the job :('
|
||||
)
|
||||
await self.send_message(msg, gitlab_view)
|
||||
529
enterprise/integrations/gitlab/gitlab_service.py
Normal file
529
enterprise/integrations/gitlab/gitlab_service.py
Normal file
@@ -0,0 +1,529 @@
|
||||
import asyncio
|
||||
|
||||
from integrations.types import GitLabResourceType
|
||||
from integrations.utils import store_repositories_in_db
|
||||
from pydantic import SecretStr
|
||||
from server.auth.token_manager import TokenManager
|
||||
from storage.gitlab_webhook import GitlabWebhook, WebhookStatus
|
||||
from storage.gitlab_webhook_store import GitlabWebhookStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.gitlab.gitlab_service import GitLabService
|
||||
from openhands.integrations.service_types import (
|
||||
ProviderType,
|
||||
RateLimitError,
|
||||
Repository,
|
||||
RequestMethod,
|
||||
)
|
||||
from openhands.server.types import AppMode
|
||||
|
||||
|
||||
class SaaSGitLabService(GitLabService):
|
||||
def __init__(
|
||||
self,
|
||||
user_id: str | None = None,
|
||||
external_auth_token: SecretStr | None = None,
|
||||
external_auth_id: str | None = None,
|
||||
token: SecretStr | None = None,
|
||||
external_token_manager: bool = False,
|
||||
base_domain: str | None = None,
|
||||
):
|
||||
logger.info(
|
||||
f'SaaSGitLabService created with user_id {user_id}, external_auth_id {external_auth_id}, external_auth_token {'set' if external_auth_token else 'None'}, gitlab_token {'set' if token else 'None'}, external_token_manager {external_token_manager}'
|
||||
)
|
||||
super().__init__(
|
||||
user_id=user_id,
|
||||
external_auth_token=external_auth_token,
|
||||
external_auth_id=external_auth_id,
|
||||
token=token,
|
||||
external_token_manager=external_token_manager,
|
||||
base_domain=base_domain,
|
||||
)
|
||||
|
||||
self.external_auth_token = external_auth_token
|
||||
self.external_auth_id = external_auth_id
|
||||
self.token_manager = TokenManager(external=external_token_manager)
|
||||
|
||||
async def get_latest_token(self) -> SecretStr | None:
|
||||
gitlab_token = None
|
||||
if self.external_auth_token:
|
||||
gitlab_token = SecretStr(
|
||||
await self.token_manager.get_idp_token(
|
||||
self.external_auth_token.get_secret_value(), idp=ProviderType.GITLAB
|
||||
)
|
||||
)
|
||||
logger.debug(
|
||||
f'Got GitLab token {gitlab_token} from access token: {self.external_auth_token}'
|
||||
)
|
||||
elif self.external_auth_id:
|
||||
offline_token = await self.token_manager.load_offline_token(
|
||||
self.external_auth_id
|
||||
)
|
||||
gitlab_token = SecretStr(
|
||||
await self.token_manager.get_idp_token_from_offline_token(
|
||||
offline_token, ProviderType.GITLAB
|
||||
)
|
||||
)
|
||||
logger.info(
|
||||
f'Got GitLab token {gitlab_token.get_secret_value()} from external auth user ID: {self.external_auth_id}'
|
||||
)
|
||||
elif self.user_id:
|
||||
gitlab_token = SecretStr(
|
||||
await self.token_manager.get_idp_token_from_idp_user_id(
|
||||
self.user_id, ProviderType.GITLAB
|
||||
)
|
||||
)
|
||||
logger.debug(
|
||||
f'Got Gitlab token {gitlab_token} from user ID: {self.user_id}'
|
||||
)
|
||||
else:
|
||||
logger.warning('external_auth_token and user_id not set!')
|
||||
return gitlab_token
|
||||
|
||||
async def get_owned_groups(self) -> list[dict]:
|
||||
"""
|
||||
Get all groups for which the current user is the owner.
|
||||
|
||||
Returns:
|
||||
list[dict]: A list of groups owned by the current user.
|
||||
"""
|
||||
url = f'{self.BASE_URL}/groups'
|
||||
params = {'owned': 'true', 'per_page': 100, 'top_level_only': 'true'}
|
||||
|
||||
try:
|
||||
response, headers = await self._make_request(url, params)
|
||||
return response
|
||||
except Exception:
|
||||
logger.warning('Error fetching owned groups', exc_info=True)
|
||||
return []
|
||||
|
||||
async def add_owned_projects_and_groups_to_db(self, owned_personal_projects):
|
||||
"""
|
||||
Add owned projects and groups to the database for webhook tracking.
|
||||
|
||||
Args:
|
||||
owned_personal_projects: List of personal projects owned by the user
|
||||
"""
|
||||
owned_groups = await self.get_owned_groups()
|
||||
webhooks = []
|
||||
|
||||
def build_group_webhook_entries(groups):
|
||||
return [
|
||||
GitlabWebhook(
|
||||
group_id=str(group['id']),
|
||||
project_id=None,
|
||||
user_id=self.external_auth_id,
|
||||
webhook_exists=False,
|
||||
)
|
||||
for group in groups
|
||||
]
|
||||
|
||||
def build_project_webhook_entries(projects):
|
||||
return [
|
||||
GitlabWebhook(
|
||||
group_id=None,
|
||||
project_id=str(project['id']),
|
||||
user_id=self.external_auth_id,
|
||||
webhook_exists=False,
|
||||
)
|
||||
for project in projects
|
||||
]
|
||||
|
||||
# Collect all webhook entries
|
||||
webhooks.extend(build_group_webhook_entries(owned_groups))
|
||||
webhooks.extend(build_project_webhook_entries(owned_personal_projects))
|
||||
|
||||
# Store webhooks in the database
|
||||
if webhooks:
|
||||
try:
|
||||
webhook_store = GitlabWebhookStore()
|
||||
await webhook_store.store_webhooks(webhooks)
|
||||
logger.info(
|
||||
f'Added GitLab webhooks to db for user {self.external_auth_id}'
|
||||
)
|
||||
except Exception:
|
||||
logger.warning('Failed to add Gitlab webhooks to db', exc_info=True)
|
||||
|
||||
async def store_repository_data(
|
||||
self, users_personal_projects: list[dict], repositories: list[Repository]
|
||||
) -> None:
|
||||
"""
|
||||
Store repository data in the database.
|
||||
This function combines the functionality of add_owned_projects_and_groups_to_db and store_repositories_in_db.
|
||||
|
||||
Args:
|
||||
users_personal_projects: List of personal projects owned by the user
|
||||
repositories: List of Repository objects to store
|
||||
"""
|
||||
try:
|
||||
# First, add owned projects and groups to the database
|
||||
await self.add_owned_projects_and_groups_to_db(users_personal_projects)
|
||||
|
||||
# Then, store repositories in the database
|
||||
await store_repositories_in_db(repositories, self.external_auth_id)
|
||||
|
||||
logger.info(
|
||||
f'Successfully stored repository data for user {self.external_auth_id}'
|
||||
)
|
||||
except Exception:
|
||||
logger.warning('Error storing repository data', exc_info=True)
|
||||
|
||||
async def get_all_repositories(
|
||||
self, sort: str, app_mode: AppMode, store_in_background: bool = True
|
||||
) -> list[Repository]:
|
||||
"""
|
||||
Get repositories for the authenticated user, including information about the kind of project.
|
||||
Also collects repositories where the kind is "user" and the user is the owner.
|
||||
|
||||
Args:
|
||||
sort: The field to sort repositories by
|
||||
app_mode: The application mode (OSS or SAAS)
|
||||
|
||||
Returns:
|
||||
List[Repository]: A list of repositories for the authenticated user
|
||||
"""
|
||||
MAX_REPOS = 1000
|
||||
PER_PAGE = 100 # Maximum allowed by GitLab API
|
||||
all_repos: list[dict] = []
|
||||
users_personal_projects: list[dict] = []
|
||||
page = 1
|
||||
|
||||
url = f'{self.BASE_URL}/projects'
|
||||
# Map GitHub's sort values to GitLab's order_by values
|
||||
order_by = {
|
||||
'pushed': 'last_activity_at',
|
||||
'updated': 'last_activity_at',
|
||||
'created': 'created_at',
|
||||
'full_name': 'name',
|
||||
}.get(sort, 'last_activity_at')
|
||||
|
||||
user_id = None
|
||||
try:
|
||||
user_info = await self.get_user()
|
||||
user_id = user_info.id
|
||||
except Exception as e:
|
||||
logger.warning(f'Could not fetch user id: {e}')
|
||||
|
||||
while len(all_repos) < MAX_REPOS:
|
||||
params = {
|
||||
'page': str(page),
|
||||
'per_page': str(PER_PAGE),
|
||||
'order_by': order_by,
|
||||
'sort': 'desc', # GitLab uses sort for direction (asc/desc)
|
||||
'membership': 1, # Use 1 instead of True
|
||||
}
|
||||
|
||||
try:
|
||||
response, headers = await self._make_request(url, params)
|
||||
|
||||
if not response: # No more repositories
|
||||
break
|
||||
|
||||
# Process each repository to identify user-owned ones
|
||||
for repo in response:
|
||||
namespace = repo.get('namespace', {})
|
||||
kind = namespace.get('kind')
|
||||
owner_id = repo.get('owner', {}).get('id')
|
||||
|
||||
# Collect user owned personal projects
|
||||
if kind == 'user' and str(owner_id) == str(user_id):
|
||||
users_personal_projects.append(repo)
|
||||
|
||||
# Add to all repos regardless
|
||||
all_repos.append(repo)
|
||||
|
||||
page += 1
|
||||
|
||||
# Check if we've reached the last page
|
||||
link_header = headers.get('Link', '')
|
||||
if 'rel="next"' not in link_header:
|
||||
break
|
||||
|
||||
except Exception:
|
||||
logger.warning(
|
||||
f'Error fetching repositories on page {page}', exc_info=True
|
||||
)
|
||||
break
|
||||
|
||||
# Trim to MAX_REPOS if needed and convert to Repository objects
|
||||
all_repos = all_repos[:MAX_REPOS]
|
||||
repositories = [
|
||||
Repository(
|
||||
id=str(repo.get('id')),
|
||||
full_name=str(repo.get('path_with_namespace')),
|
||||
stargazers_count=repo.get('star_count'),
|
||||
git_provider=ProviderType.GITLAB,
|
||||
is_public=repo.get('visibility') == 'public',
|
||||
)
|
||||
for repo in all_repos
|
||||
]
|
||||
|
||||
# Store webhook and repository info
|
||||
if store_in_background:
|
||||
asyncio.create_task(
|
||||
self.store_repository_data(users_personal_projects, repositories)
|
||||
)
|
||||
else:
|
||||
await self.store_repository_data(users_personal_projects, repositories)
|
||||
return repositories
|
||||
|
||||
async def check_resource_exists(
|
||||
self, resource_type: GitLabResourceType, resource_id: str
|
||||
) -> tuple[bool, WebhookStatus | None]:
|
||||
"""
|
||||
Check if resource exists and the user has access to it.
|
||||
|
||||
Args:
|
||||
resource_type: The type of resource
|
||||
resource_id: The ID of resource to check
|
||||
|
||||
Returns:
|
||||
tuple[bool, str]: A tuple containing:
|
||||
- bool: True if the resource exists and the user has access to it, False otherwise
|
||||
- str: A reason message explaining the result
|
||||
"""
|
||||
|
||||
if resource_type == GitLabResourceType.GROUP:
|
||||
url = f'{self.BASE_URL}/groups/{resource_id}'
|
||||
else:
|
||||
url = f'{self.BASE_URL}/projects/{resource_id}'
|
||||
|
||||
try:
|
||||
response, _ = await self._make_request(url)
|
||||
# If we get a response, the resource exists and the user has access to it
|
||||
return bool(response and 'id' in response), None
|
||||
except RateLimitError:
|
||||
return False, WebhookStatus.RATE_LIMITED
|
||||
except Exception:
|
||||
logger.warning('Resource existence check failed', exc_info=True)
|
||||
return False, WebhookStatus.INVALID
|
||||
|
||||
async def check_webhook_exists_on_resource(
|
||||
self, resource_type: GitLabResourceType, resource_id: str, webhook_url: str
|
||||
) -> tuple[bool, WebhookStatus | None]:
|
||||
"""
|
||||
Check if a webhook already exists for resource with a specific URL.
|
||||
|
||||
Args:
|
||||
resource_type: The type of resource
|
||||
resource_id: The ID of the resource to check
|
||||
webhook_url: The URL of the webhook to check for
|
||||
|
||||
Returns:
|
||||
tuple[bool, str]: A tuple containing:
|
||||
- bool: True if the webhook exists, False otherwise
|
||||
- str: A reason message explaining the result
|
||||
"""
|
||||
|
||||
# Construct the URL based on the resource type
|
||||
if resource_type == GitLabResourceType.GROUP:
|
||||
url = f'{self.BASE_URL}/groups/{resource_id}/hooks'
|
||||
else:
|
||||
url = f'{self.BASE_URL}/projects/{resource_id}/hooks'
|
||||
|
||||
try:
|
||||
# Get all webhooks for the resource
|
||||
response, _ = await self._make_request(url)
|
||||
|
||||
# Check if any webhook has the specified URL
|
||||
exists = False
|
||||
if response:
|
||||
for webhook in response:
|
||||
if webhook.get('url') == webhook_url:
|
||||
exists = True
|
||||
|
||||
return exists, None
|
||||
|
||||
except RateLimitError:
|
||||
return False, WebhookStatus.RATE_LIMITED
|
||||
except Exception:
|
||||
logger.warning('Webhook existence check failed', exc_info=True)
|
||||
return False, WebhookStatus.INVALID
|
||||
|
||||
async def check_user_has_admin_access_to_resource(
|
||||
self, resource_type: GitLabResourceType, resource_id: str
|
||||
) -> tuple[bool, WebhookStatus | None]:
|
||||
"""
|
||||
Check if the user has admin access to resource (is either an owner or maintainer)
|
||||
|
||||
Args:
|
||||
resource_type: The type of resource
|
||||
resource_id: The ID of the resource to check
|
||||
|
||||
Returns:
|
||||
tuple[bool, str]: A tuple containing:
|
||||
- bool: True if the user has admin access to the resource (owner or maintainer), False otherwise
|
||||
- str: A reason message explaining the result
|
||||
"""
|
||||
|
||||
# For groups, we need to check if the user is an owner or maintainer
|
||||
if resource_type == GitLabResourceType.GROUP:
|
||||
url = f'{self.BASE_URL}/groups/{resource_id}/members/all'
|
||||
try:
|
||||
response, _ = await self._make_request(url)
|
||||
# Check if the current user is in the members list with access level >= 40 (Maintainer or Owner)
|
||||
|
||||
exists = False
|
||||
if response:
|
||||
current_user = await self.get_user()
|
||||
user_id = current_user.id
|
||||
for member in response:
|
||||
if (
|
||||
str(member.get('id')) == str(user_id)
|
||||
and member.get('access_level', 0) >= 40
|
||||
):
|
||||
exists = True
|
||||
return exists, None
|
||||
except RateLimitError:
|
||||
return False, WebhookStatus.RATE_LIMITED
|
||||
except Exception:
|
||||
return False, WebhookStatus.INVALID
|
||||
|
||||
# For projects, we need to check if the user has maintainer or owner access
|
||||
else:
|
||||
url = f'{self.BASE_URL}/projects/{resource_id}/members/all'
|
||||
try:
|
||||
response, _ = await self._make_request(url)
|
||||
exists = False
|
||||
# Check if the current user is in the members list with access level >= 40 (Maintainer)
|
||||
if response:
|
||||
current_user = await self.get_user()
|
||||
user_id = current_user.id
|
||||
for member in response:
|
||||
if (
|
||||
str(member.get('id')) == str(user_id)
|
||||
and member.get('access_level', 0) >= 40
|
||||
):
|
||||
exists = True
|
||||
return exists, None
|
||||
except RateLimitError:
|
||||
return False, WebhookStatus.RATE_LIMITED
|
||||
except Exception:
|
||||
logger.warning('Admin access check failed', exc_info=True)
|
||||
return False, WebhookStatus.INVALID
|
||||
|
||||
async def install_webhook(
|
||||
self,
|
||||
resource_type: GitLabResourceType,
|
||||
resource_id: str,
|
||||
webhook_name: str,
|
||||
webhook_url: str,
|
||||
webhook_secret: str,
|
||||
webhook_uuid: str,
|
||||
scopes: list[str],
|
||||
) -> tuple[str | None, WebhookStatus | None]:
|
||||
"""
|
||||
Install webhook for user's group or project
|
||||
|
||||
Args:
|
||||
resource_type: The type of resource
|
||||
resource_id: The ID of the resource to check
|
||||
webhook_secret: Webhook secret that is used to verify payload
|
||||
webhook_name: Name of webhook
|
||||
webhook_url: Webhook URL
|
||||
scopes: activity webhook listens for
|
||||
|
||||
Returns:
|
||||
tuple[bool, str]: A tuple containing:
|
||||
- bool: True if installation was successful, False otherwise
|
||||
- str: A reason message explaining the result
|
||||
"""
|
||||
|
||||
description = 'Cloud OpenHands Resolver'
|
||||
|
||||
# Set up webhook parameters
|
||||
webhook_data = {
|
||||
'url': webhook_url,
|
||||
'name': webhook_name,
|
||||
'enable_ssl_verification': True,
|
||||
'token': webhook_secret,
|
||||
'description': description,
|
||||
}
|
||||
|
||||
for scope in scopes:
|
||||
webhook_data[scope] = True
|
||||
|
||||
# Add custom headers with user id
|
||||
if self.external_auth_id:
|
||||
webhook_data['custom_headers'] = [
|
||||
{'key': 'X-OpenHands-User-ID', 'value': self.external_auth_id},
|
||||
{'key': 'X-OpenHands-Webhook-ID', 'value': webhook_uuid},
|
||||
]
|
||||
|
||||
if resource_type == GitLabResourceType.GROUP:
|
||||
url = f'{self.BASE_URL}/groups/{resource_id}/hooks'
|
||||
else:
|
||||
url = f'{self.BASE_URL}/projects/{resource_id}/hooks'
|
||||
|
||||
try:
|
||||
# Make the API request
|
||||
response, _ = await self._make_request(
|
||||
url=url, params=webhook_data, method=RequestMethod.POST
|
||||
)
|
||||
|
||||
if response and 'id' in response:
|
||||
return str(response['id']), None
|
||||
|
||||
# Check if the webhook was created successfully
|
||||
return None, None
|
||||
|
||||
except RateLimitError:
|
||||
return None, WebhookStatus.RATE_LIMITED
|
||||
except Exception:
|
||||
logger.warning('Webhook installation failed', exc_info=True)
|
||||
return None, WebhookStatus.INVALID
|
||||
|
||||
async def user_has_write_access(self, project_id: str) -> bool:
|
||||
url = f'{self.BASE_URL}/projects/{project_id}'
|
||||
try:
|
||||
response, _ = await self._make_request(url)
|
||||
# Check if the current user is in the members list with access level >= 30 (Developer)
|
||||
|
||||
if 'permissions' not in response:
|
||||
logger.info('permissions not found', extra={'response': response})
|
||||
return False
|
||||
|
||||
permissions = response['permissions']
|
||||
if permissions['project_access']:
|
||||
logger.info('[GitLab]: Checking project access')
|
||||
return permissions['project_access']['access_level'] >= 30
|
||||
|
||||
if permissions['group_access']:
|
||||
logger.info('[GitLab]: Checking group access')
|
||||
return permissions['group_access']['access_level'] >= 30
|
||||
|
||||
return False
|
||||
except Exception:
|
||||
logger.warning('Access check failed', exc_info=True)
|
||||
return False
|
||||
|
||||
async def reply_to_issue(
|
||||
self, project_id: str, issue_number: str, discussion_id: str | None, body: str
|
||||
):
|
||||
"""
|
||||
Either create new comment thread, or reply to comment thread (depending on discussion_id param)
|
||||
"""
|
||||
try:
|
||||
if discussion_id:
|
||||
url = f'{self.BASE_URL}/projects/{project_id}/issues/{issue_number}/discussions/{discussion_id}/notes'
|
||||
else:
|
||||
url = f'{self.BASE_URL}/projects/{project_id}/issues/{issue_number}/discussions'
|
||||
params = {'body': body}
|
||||
|
||||
await self._make_request(url=url, params=params, method=RequestMethod.POST)
|
||||
except Exception as e:
|
||||
logger.exception(f'[GitLab]: Reply to issue failed {e}')
|
||||
|
||||
async def reply_to_mr(
|
||||
self, project_id: str, merge_request_iid: str, discussion_id: str, body: str
|
||||
):
|
||||
"""
|
||||
Reply to comment thread on MR
|
||||
"""
|
||||
try:
|
||||
url = f'{self.BASE_URL}/projects/{project_id}/merge_requests/{merge_request_iid}/discussions/{discussion_id}/notes'
|
||||
params = {'body': body}
|
||||
|
||||
await self._make_request(url=url, params=params, method=RequestMethod.POST)
|
||||
except Exception as e:
|
||||
logger.exception(f'[GitLab]: Reply to MR failed {e}')
|
||||
450
enterprise/integrations/gitlab/gitlab_view.py
Normal file
450
enterprise/integrations/gitlab/gitlab_view.py
Normal file
@@ -0,0 +1,450 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from integrations.models import Message
|
||||
from integrations.types import ResolverViewInterface, UserData
|
||||
from integrations.utils import HOST, get_oh_labels, has_exact_mention
|
||||
from jinja2 import Environment
|
||||
from server.auth.token_manager import TokenManager, get_config
|
||||
from storage.database import session_maker
|
||||
from storage.saas_secrets_store import SaasSecretsStore
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.gitlab.gitlab_service import GitLabServiceImpl
|
||||
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE, ProviderType
|
||||
from openhands.integrations.service_types import Comment
|
||||
from openhands.server.services.conversation_service import create_new_conversation
|
||||
from openhands.storage.data_models.conversation_metadata import ConversationTrigger
|
||||
|
||||
OH_LABEL, INLINE_OH_LABEL = get_oh_labels(HOST)
|
||||
CONFIDENTIAL_NOTE = 'confidential_note'
|
||||
NOTE_TYPES = ['note', CONFIDENTIAL_NOTE]
|
||||
|
||||
# =================================================
|
||||
# SECTION: Factory to create appriorate Gitlab view
|
||||
# =================================================
|
||||
|
||||
|
||||
@dataclass
|
||||
class GitlabIssue(ResolverViewInterface):
|
||||
installation_id: str # Webhook installation ID for Gitlab (comes from our DB)
|
||||
issue_number: int
|
||||
project_id: int
|
||||
full_repo_name: str
|
||||
is_public_repo: bool
|
||||
user_info: UserData
|
||||
raw_payload: Message
|
||||
conversation_id: str
|
||||
should_extract: bool
|
||||
send_summary_instruction: bool
|
||||
title: str
|
||||
description: str
|
||||
previous_comments: list[Comment]
|
||||
is_mr: bool
|
||||
|
||||
async def _load_resolver_context(self):
|
||||
gitlab_service = GitLabServiceImpl(
|
||||
external_auth_id=self.user_info.keycloak_user_id
|
||||
)
|
||||
|
||||
self.previous_comments = await gitlab_service.get_issue_or_mr_comments(
|
||||
str(self.project_id), self.issue_number, is_mr=self.is_mr
|
||||
)
|
||||
|
||||
(
|
||||
self.title,
|
||||
self.description,
|
||||
) = await gitlab_service.get_issue_or_mr_title_and_body(
|
||||
str(self.project_id), self.issue_number, is_mr=self.is_mr
|
||||
)
|
||||
|
||||
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
user_instructions_template = jinja_env.get_template('issue_prompt.j2')
|
||||
await self._load_resolver_context()
|
||||
|
||||
user_instructions = user_instructions_template.render(
|
||||
issue_number=self.issue_number,
|
||||
)
|
||||
|
||||
conversation_instructions_template = jinja_env.get_template(
|
||||
'issue_conversation_instructions.j2'
|
||||
)
|
||||
conversation_instructions = conversation_instructions_template.render(
|
||||
issue_title=self.title,
|
||||
issue_body=self.description,
|
||||
comments=self.previous_comments,
|
||||
)
|
||||
|
||||
return user_instructions, conversation_instructions
|
||||
|
||||
async def _get_user_secrets(self):
|
||||
secrets_store = SaasSecretsStore(
|
||||
self.user_info.keycloak_user_id, session_maker, get_config()
|
||||
)
|
||||
user_secrets = await secrets_store.load()
|
||||
|
||||
return user_secrets.custom_secrets if user_secrets else None
|
||||
|
||||
async def create_new_conversation(
|
||||
self, jinja_env: Environment, git_provider_tokens: PROVIDER_TOKEN_TYPE
|
||||
):
|
||||
custom_secrets = await self._get_user_secrets()
|
||||
|
||||
user_instructions, conversation_instructions = await self._get_instructions(
|
||||
jinja_env
|
||||
)
|
||||
agent_loop_info = await create_new_conversation(
|
||||
user_id=self.user_info.keycloak_user_id,
|
||||
git_provider_tokens=git_provider_tokens,
|
||||
custom_secrets=custom_secrets,
|
||||
selected_repository=self.full_repo_name,
|
||||
selected_branch=None,
|
||||
initial_user_msg=user_instructions,
|
||||
conversation_instructions=conversation_instructions,
|
||||
image_urls=None,
|
||||
conversation_trigger=ConversationTrigger.RESOLVER,
|
||||
replay_json=None,
|
||||
)
|
||||
self.conversation_id = agent_loop_info.conversation_id
|
||||
return self.conversation_id
|
||||
|
||||
|
||||
@dataclass
|
||||
class GitlabIssueComment(GitlabIssue):
|
||||
comment_body: str
|
||||
discussion_id: str
|
||||
confidential: bool
|
||||
|
||||
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
user_instructions_template = jinja_env.get_template('issue_prompt.j2')
|
||||
await self._load_resolver_context()
|
||||
|
||||
user_instructions = user_instructions_template.render(
|
||||
issue_comment=self.comment_body
|
||||
)
|
||||
|
||||
conversation_instructions_template = jinja_env.get_template(
|
||||
'issue_conversation_instructions.j2'
|
||||
)
|
||||
|
||||
conversation_instructions = conversation_instructions_template.render(
|
||||
issue_number=self.issue_number,
|
||||
issue_title=self.title,
|
||||
issue_body=self.description,
|
||||
comments=self.previous_comments,
|
||||
)
|
||||
|
||||
return user_instructions, conversation_instructions
|
||||
|
||||
|
||||
@dataclass
|
||||
class GitlabMRComment(GitlabIssueComment):
|
||||
branch_name: str
|
||||
|
||||
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
user_instructions_template = jinja_env.get_template('mr_update_prompt.j2')
|
||||
await self._load_resolver_context()
|
||||
|
||||
user_instructions = user_instructions_template.render(
|
||||
mr_comment=self.comment_body,
|
||||
)
|
||||
|
||||
conversation_instructions_template = jinja_env.get_template(
|
||||
'mr_update_conversation_instructions.j2'
|
||||
)
|
||||
conversation_instructions = conversation_instructions_template.render(
|
||||
mr_number=self.issue_number,
|
||||
branch_name=self.branch_name,
|
||||
mr_title=self.title,
|
||||
mr_body=self.description,
|
||||
comments=self.previous_comments,
|
||||
)
|
||||
|
||||
return user_instructions, conversation_instructions
|
||||
|
||||
async def create_new_conversation(
|
||||
self, jinja_env: Environment, git_provider_tokens: PROVIDER_TOKEN_TYPE
|
||||
):
|
||||
custom_secrets = await self._get_user_secrets()
|
||||
|
||||
user_instructions, conversation_instructions = await self._get_instructions(
|
||||
jinja_env
|
||||
)
|
||||
agent_loop_info = await create_new_conversation(
|
||||
user_id=self.user_info.keycloak_user_id,
|
||||
git_provider_tokens=git_provider_tokens,
|
||||
custom_secrets=custom_secrets,
|
||||
selected_repository=self.full_repo_name,
|
||||
selected_branch=self.branch_name,
|
||||
initial_user_msg=user_instructions,
|
||||
conversation_instructions=conversation_instructions,
|
||||
image_urls=None,
|
||||
conversation_trigger=ConversationTrigger.RESOLVER,
|
||||
replay_json=None,
|
||||
)
|
||||
self.conversation_id = agent_loop_info.conversation_id
|
||||
return self.conversation_id
|
||||
|
||||
|
||||
@dataclass
|
||||
class GitlabInlineMRComment(GitlabMRComment):
|
||||
file_location: str
|
||||
line_number: int
|
||||
|
||||
async def _load_resolver_context(self):
|
||||
gitlab_service = GitLabServiceImpl(
|
||||
external_auth_id=self.user_info.keycloak_user_id
|
||||
)
|
||||
|
||||
(
|
||||
self.title,
|
||||
self.description,
|
||||
) = await gitlab_service.get_issue_or_mr_title_and_body(
|
||||
str(self.project_id), self.issue_number, is_mr=self.is_mr
|
||||
)
|
||||
|
||||
self.previous_comments = await gitlab_service.get_review_thread_comments(
|
||||
str(self.project_id), self.issue_number, self.discussion_id
|
||||
)
|
||||
|
||||
async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
user_instructions_template = jinja_env.get_template('mr_update_prompt.j2')
|
||||
await self._load_resolver_context()
|
||||
|
||||
user_instructions = user_instructions_template.render(
|
||||
mr_comment=self.comment_body,
|
||||
)
|
||||
|
||||
conversation_instructions_template = jinja_env.get_template(
|
||||
'mr_update_conversation_instructions.j2'
|
||||
)
|
||||
|
||||
conversation_instructions = conversation_instructions_template.render(
|
||||
mr_number=self.issue_number,
|
||||
mr_title=self.title,
|
||||
mr_body=self.description,
|
||||
branch_name=self.branch_name,
|
||||
file_location=self.file_location,
|
||||
line_number=self.line_number,
|
||||
comments=self.previous_comments,
|
||||
)
|
||||
|
||||
return user_instructions, conversation_instructions
|
||||
|
||||
|
||||
GitlabViewType = (
|
||||
GitlabInlineMRComment | GitlabMRComment | GitlabIssueComment | GitlabIssue
|
||||
)
|
||||
|
||||
|
||||
class GitlabFactory:
|
||||
@staticmethod
|
||||
def is_labeled_issue(message: Message) -> bool:
|
||||
payload = message.message['payload']
|
||||
object_kind = payload.get('object_kind')
|
||||
event_type = payload.get('event_type')
|
||||
|
||||
if object_kind == 'issue' and event_type == 'issue':
|
||||
changes = payload.get('changes', {})
|
||||
labels = changes.get('labels', {})
|
||||
previous = labels.get('previous', [])
|
||||
current = labels.get('current', [])
|
||||
|
||||
previous_labels = [obj['title'] for obj in previous]
|
||||
current_labels = [obj['title'] for obj in current]
|
||||
|
||||
if OH_LABEL not in previous_labels and OH_LABEL in current_labels:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def is_issue_comment(message: Message) -> bool:
|
||||
payload = message.message['payload']
|
||||
object_kind = payload.get('object_kind')
|
||||
event_type = payload.get('event_type')
|
||||
issue = payload.get('issue')
|
||||
|
||||
if object_kind == 'note' and event_type in NOTE_TYPES and issue:
|
||||
comment_body = payload.get('object_attributes', {}).get('note', '')
|
||||
return has_exact_mention(comment_body, INLINE_OH_LABEL)
|
||||
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def is_mr_comment(message: Message, inline=False) -> bool:
|
||||
payload = message.message['payload']
|
||||
object_kind = payload.get('object_kind')
|
||||
event_type = payload.get('event_type')
|
||||
merge_request = payload.get('merge_request')
|
||||
|
||||
if not (object_kind == 'note' and event_type in NOTE_TYPES and merge_request):
|
||||
return False
|
||||
|
||||
# Check whether not belongs to MR
|
||||
object_attributes = payload.get('object_attributes', {})
|
||||
noteable_type = object_attributes.get('noteable_type')
|
||||
|
||||
if noteable_type != 'MergeRequest':
|
||||
return False
|
||||
|
||||
# Check whether comment is inline
|
||||
change_position = object_attributes.get('change_position')
|
||||
if inline and not change_position:
|
||||
return False
|
||||
if not inline and change_position:
|
||||
return False
|
||||
|
||||
# Check body
|
||||
comment_body = object_attributes.get('note', '')
|
||||
return has_exact_mention(comment_body, INLINE_OH_LABEL)
|
||||
|
||||
@staticmethod
|
||||
def determine_if_confidential(event_type: str):
|
||||
return event_type == CONFIDENTIAL_NOTE
|
||||
|
||||
@staticmethod
|
||||
async def create_gitlab_view_from_payload(
|
||||
message: Message, token_manager: TokenManager
|
||||
) -> ResolverViewInterface:
|
||||
payload = message.message['payload']
|
||||
installation_id = message.message['installation_id']
|
||||
user = payload['user']
|
||||
user_id = user['id']
|
||||
username = user['username']
|
||||
repo_obj = payload['project']
|
||||
selected_project = repo_obj['path_with_namespace']
|
||||
is_public_repo = repo_obj['visibility_level'] == 0
|
||||
project_id = payload['object_attributes']['project_id']
|
||||
|
||||
keycloak_user_id = await token_manager.get_user_id_from_idp_user_id(
|
||||
user_id, ProviderType.GITLAB
|
||||
)
|
||||
|
||||
user_info = UserData(
|
||||
user_id=user_id, username=username, keycloak_user_id=keycloak_user_id
|
||||
)
|
||||
|
||||
if GitlabFactory.is_labeled_issue(message):
|
||||
issue_iid = payload['object_attributes']['iid']
|
||||
|
||||
logger.info(
|
||||
f'[GitLab] Creating view for labeled issue from {username} in {selected_project}#{issue_iid}'
|
||||
)
|
||||
return GitlabIssue(
|
||||
installation_id=installation_id,
|
||||
issue_number=issue_iid,
|
||||
project_id=project_id,
|
||||
full_repo_name=selected_project,
|
||||
is_public_repo=is_public_repo,
|
||||
user_info=user_info,
|
||||
raw_payload=message,
|
||||
conversation_id='',
|
||||
should_extract=True,
|
||||
send_summary_instruction=True,
|
||||
title='',
|
||||
description='',
|
||||
previous_comments=[],
|
||||
is_mr=False,
|
||||
)
|
||||
|
||||
elif GitlabFactory.is_issue_comment(message):
|
||||
event_type = payload['event_type']
|
||||
issue_iid = payload['issue']['iid']
|
||||
object_attributes = payload['object_attributes']
|
||||
discussion_id = object_attributes['discussion_id']
|
||||
comment_body = object_attributes['note']
|
||||
logger.info(
|
||||
f'[GitLab] Creating view for issue comment from {username} in {selected_project}#{issue_iid}'
|
||||
)
|
||||
|
||||
return GitlabIssueComment(
|
||||
installation_id=installation_id,
|
||||
comment_body=comment_body,
|
||||
issue_number=issue_iid,
|
||||
discussion_id=discussion_id,
|
||||
project_id=project_id,
|
||||
confidential=GitlabFactory.determine_if_confidential(event_type),
|
||||
full_repo_name=selected_project,
|
||||
is_public_repo=is_public_repo,
|
||||
user_info=user_info,
|
||||
raw_payload=message,
|
||||
conversation_id='',
|
||||
should_extract=True,
|
||||
send_summary_instruction=True,
|
||||
title='',
|
||||
description='',
|
||||
previous_comments=[],
|
||||
is_mr=False,
|
||||
)
|
||||
|
||||
elif GitlabFactory.is_mr_comment(message):
|
||||
event_type = payload['event_type']
|
||||
merge_request_iid = payload['merge_request']['iid']
|
||||
branch_name = payload['merge_request']['source_branch']
|
||||
object_attributes = payload['object_attributes']
|
||||
discussion_id = object_attributes['discussion_id']
|
||||
comment_body = object_attributes['note']
|
||||
logger.info(
|
||||
f'[GitLab] Creating view for merge request comment from {username} in {selected_project}#{merge_request_iid}'
|
||||
)
|
||||
|
||||
return GitlabMRComment(
|
||||
installation_id=installation_id,
|
||||
comment_body=comment_body,
|
||||
issue_number=merge_request_iid, # Using issue_number as mr_number for compatibility
|
||||
discussion_id=discussion_id,
|
||||
project_id=project_id,
|
||||
full_repo_name=selected_project,
|
||||
is_public_repo=is_public_repo,
|
||||
user_info=user_info,
|
||||
raw_payload=message,
|
||||
conversation_id='',
|
||||
should_extract=True,
|
||||
send_summary_instruction=True,
|
||||
confidential=GitlabFactory.determine_if_confidential(event_type),
|
||||
branch_name=branch_name,
|
||||
title='',
|
||||
description='',
|
||||
previous_comments=[],
|
||||
is_mr=True,
|
||||
)
|
||||
|
||||
elif GitlabFactory.is_mr_comment(message, inline=True):
|
||||
event_type = payload['event_type']
|
||||
merge_request_iid = payload['merge_request']['iid']
|
||||
branch_name = payload['merge_request']['source_branch']
|
||||
object_attributes = payload['object_attributes']
|
||||
comment_body = object_attributes['note']
|
||||
position_info = object_attributes['position']
|
||||
discussion_id = object_attributes['discussion_id']
|
||||
file_location = object_attributes['position']['new_path']
|
||||
line_number = (
|
||||
position_info.get('new_line') or position_info.get('old_line') or 0
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f'[GitLab] Creating view for inline merge request comment from {username} in {selected_project}#{merge_request_iid}'
|
||||
)
|
||||
|
||||
return GitlabInlineMRComment(
|
||||
installation_id=installation_id,
|
||||
issue_number=merge_request_iid, # Using issue_number as mr_number for compatibility
|
||||
discussion_id=discussion_id,
|
||||
project_id=project_id,
|
||||
full_repo_name=selected_project,
|
||||
is_public_repo=is_public_repo,
|
||||
user_info=user_info,
|
||||
raw_payload=message,
|
||||
conversation_id='',
|
||||
should_extract=True,
|
||||
send_summary_instruction=True,
|
||||
confidential=GitlabFactory.determine_if_confidential(event_type),
|
||||
branch_name=branch_name,
|
||||
file_location=file_location,
|
||||
line_number=line_number,
|
||||
comment_body=comment_body,
|
||||
title='',
|
||||
description='',
|
||||
previous_comments=[],
|
||||
is_mr=True,
|
||||
)
|
||||
503
enterprise/integrations/jira/jira_manager.py
Normal file
503
enterprise/integrations/jira/jira_manager.py
Normal file
@@ -0,0 +1,503 @@
|
||||
import hashlib
|
||||
import hmac
|
||||
from typing import Dict, Optional, Tuple
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
from fastapi import Request
|
||||
from integrations.jira.jira_types import JiraViewInterface
|
||||
from integrations.jira.jira_view import (
|
||||
JiraExistingConversationView,
|
||||
JiraFactory,
|
||||
JiraNewConversationView,
|
||||
)
|
||||
from integrations.manager import Manager
|
||||
from integrations.models import JobContext, Message
|
||||
from integrations.utils import (
|
||||
HOST_URL,
|
||||
OPENHANDS_RESOLVER_TEMPLATES_DIR,
|
||||
filter_potential_repos_by_user_msg,
|
||||
)
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
from server.auth.saas_user_auth import get_user_auth_from_keycloak_id
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.utils.conversation_callback_utils import register_callback_processor
|
||||
from storage.jira_integration_store import JiraIntegrationStore
|
||||
from storage.jira_user import JiraUser
|
||||
from storage.jira_workspace import JiraWorkspace
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.provider import ProviderHandler
|
||||
from openhands.integrations.service_types import Repository
|
||||
from openhands.server.shared import server_config
|
||||
from openhands.server.types import LLMAuthenticationError, MissingSettingsError
|
||||
from openhands.server.user_auth.user_auth import UserAuth
|
||||
|
||||
JIRA_CLOUD_API_URL = 'https://api.atlassian.com/ex/jira'
|
||||
|
||||
|
||||
class JiraManager(Manager):
|
||||
def __init__(self, token_manager: TokenManager):
|
||||
self.token_manager = token_manager
|
||||
self.integration_store = JiraIntegrationStore.get_instance()
|
||||
self.jinja_env = Environment(
|
||||
loader=FileSystemLoader(OPENHANDS_RESOLVER_TEMPLATES_DIR + 'jira')
|
||||
)
|
||||
|
||||
async def authenticate_user(
|
||||
self, jira_user_id: str, workspace_id: int
|
||||
) -> tuple[JiraUser | None, UserAuth | None]:
|
||||
"""Authenticate Jira user and get their OpenHands user auth."""
|
||||
|
||||
# Find active Jira user by Keycloak user ID and workspace ID
|
||||
jira_user = await self.integration_store.get_active_user(
|
||||
jira_user_id, workspace_id
|
||||
)
|
||||
|
||||
if not jira_user:
|
||||
logger.warning(
|
||||
f'[Jira] No active Jira user found for {jira_user_id} in workspace {workspace_id}'
|
||||
)
|
||||
return None, None
|
||||
|
||||
saas_user_auth = await get_user_auth_from_keycloak_id(
|
||||
jira_user.keycloak_user_id
|
||||
)
|
||||
return jira_user, saas_user_auth
|
||||
|
||||
async def _get_repositories(self, user_auth: UserAuth) -> list[Repository]:
|
||||
"""Get repositories that the user has access to."""
|
||||
provider_tokens = await user_auth.get_provider_tokens()
|
||||
if provider_tokens is None:
|
||||
return []
|
||||
access_token = await user_auth.get_access_token()
|
||||
user_id = await user_auth.get_user_id()
|
||||
client = ProviderHandler(
|
||||
provider_tokens=provider_tokens,
|
||||
external_auth_token=access_token,
|
||||
external_auth_id=user_id,
|
||||
)
|
||||
repos: list[Repository] = await client.get_repositories(
|
||||
'pushed', server_config.app_mode, None, None, None, None
|
||||
)
|
||||
return repos
|
||||
|
||||
async def validate_request(
|
||||
self, request: Request
|
||||
) -> Tuple[bool, Optional[str], Optional[Dict]]:
|
||||
"""Verify Jira webhook signature."""
|
||||
signature_header = request.headers.get('x-hub-signature')
|
||||
signature = signature_header.split('=')[1] if signature_header else None
|
||||
body = await request.body()
|
||||
payload = await request.json()
|
||||
workspace_name = ''
|
||||
|
||||
if payload.get('webhookEvent') == 'comment_created':
|
||||
selfUrl = payload.get('comment', {}).get('author', {}).get('self')
|
||||
elif payload.get('webhookEvent') == 'jira:issue_updated':
|
||||
selfUrl = payload.get('user', {}).get('self')
|
||||
else:
|
||||
workspace_name = ''
|
||||
|
||||
parsedUrl = urlparse(selfUrl)
|
||||
if parsedUrl.hostname:
|
||||
workspace_name = parsedUrl.hostname
|
||||
|
||||
if not workspace_name:
|
||||
logger.warning('[Jira] No workspace name found in webhook payload')
|
||||
return False, None, None
|
||||
|
||||
if not signature:
|
||||
logger.warning('[Jira] No signature found in webhook headers')
|
||||
return False, None, None
|
||||
|
||||
workspace = await self.integration_store.get_workspace_by_name(workspace_name)
|
||||
|
||||
if not workspace:
|
||||
logger.warning('[Jira] Could not identify workspace for webhook')
|
||||
return False, None, None
|
||||
|
||||
if workspace.status != 'active':
|
||||
logger.warning(f'[Jira] Workspace {workspace.id} is not active')
|
||||
return False, None, None
|
||||
|
||||
webhook_secret = self.token_manager.decrypt_text(workspace.webhook_secret)
|
||||
digest = hmac.new(webhook_secret.encode(), body, hashlib.sha256).hexdigest()
|
||||
|
||||
if hmac.compare_digest(signature, digest):
|
||||
logger.info('[Jira] Webhook signature verified successfully')
|
||||
return True, signature, payload
|
||||
|
||||
return False, None, None
|
||||
|
||||
def parse_webhook(self, payload: Dict) -> JobContext | None:
|
||||
event_type = payload.get('webhookEvent')
|
||||
|
||||
if event_type == 'comment_created':
|
||||
comment_data = payload.get('comment', {})
|
||||
comment = comment_data.get('body', '')
|
||||
|
||||
if '@openhands' not in comment:
|
||||
return None
|
||||
|
||||
issue_data = payload.get('issue', {})
|
||||
issue_id = issue_data.get('id')
|
||||
issue_key = issue_data.get('key')
|
||||
base_api_url = issue_data.get('self', '').split('/rest/')[0]
|
||||
|
||||
user_data = comment_data.get('author', {})
|
||||
user_email = user_data.get('emailAddress')
|
||||
display_name = user_data.get('displayName')
|
||||
account_id = user_data.get('accountId')
|
||||
elif event_type == 'jira:issue_updated':
|
||||
changelog = payload.get('changelog', {})
|
||||
items = changelog.get('items', [])
|
||||
labels = [
|
||||
item.get('toString', '')
|
||||
for item in items
|
||||
if item.get('field') == 'labels' and 'toString' in item
|
||||
]
|
||||
|
||||
if 'openhands' not in labels:
|
||||
return None
|
||||
|
||||
issue_data = payload.get('issue', {})
|
||||
issue_id = issue_data.get('id')
|
||||
issue_key = issue_data.get('key')
|
||||
base_api_url = issue_data.get('self', '').split('/rest/')[0]
|
||||
|
||||
user_data = payload.get('user', {})
|
||||
user_email = user_data.get('emailAddress')
|
||||
display_name = user_data.get('displayName')
|
||||
account_id = user_data.get('accountId')
|
||||
comment = ''
|
||||
else:
|
||||
return None
|
||||
|
||||
workspace_name = ''
|
||||
|
||||
parsedUrl = urlparse(base_api_url)
|
||||
if parsedUrl.hostname:
|
||||
workspace_name = parsedUrl.hostname
|
||||
|
||||
if not all(
|
||||
[
|
||||
issue_id,
|
||||
issue_key,
|
||||
user_email,
|
||||
display_name,
|
||||
account_id,
|
||||
workspace_name,
|
||||
base_api_url,
|
||||
]
|
||||
):
|
||||
return None
|
||||
|
||||
return JobContext(
|
||||
issue_id=issue_id,
|
||||
issue_key=issue_key,
|
||||
user_msg=comment,
|
||||
user_email=user_email,
|
||||
display_name=display_name,
|
||||
platform_user_id=account_id,
|
||||
workspace_name=workspace_name,
|
||||
base_api_url=base_api_url,
|
||||
)
|
||||
|
||||
async def receive_message(self, message: Message):
|
||||
"""Process incoming Jira webhook message."""
|
||||
|
||||
payload = message.message.get('payload', {})
|
||||
job_context = self.parse_webhook(payload)
|
||||
|
||||
if not job_context:
|
||||
logger.info('[Jira] Webhook does not match trigger conditions')
|
||||
return
|
||||
|
||||
# Get workspace by user email domain
|
||||
workspace = await self.integration_store.get_workspace_by_name(
|
||||
job_context.workspace_name
|
||||
)
|
||||
if not workspace:
|
||||
logger.warning(
|
||||
f'[Jira] No workspace found for email domain: {job_context.user_email}'
|
||||
)
|
||||
await self._send_error_comment(
|
||||
job_context,
|
||||
'Your workspace is not configured with Jira integration.',
|
||||
None,
|
||||
)
|
||||
return
|
||||
|
||||
# Prevent any recursive triggers from the service account
|
||||
if job_context.user_email == workspace.svc_acc_email:
|
||||
return
|
||||
|
||||
if workspace.status != 'active':
|
||||
logger.warning(f'[Jira] Workspace {workspace.id} is not active')
|
||||
await self._send_error_comment(
|
||||
job_context,
|
||||
'Jira integration is not active for your workspace.',
|
||||
workspace,
|
||||
)
|
||||
return
|
||||
|
||||
# Authenticate user
|
||||
jira_user, saas_user_auth = await self.authenticate_user(
|
||||
job_context.platform_user_id, workspace.id
|
||||
)
|
||||
if not jira_user or not saas_user_auth:
|
||||
logger.warning(
|
||||
f'[Jira] User authentication failed for {job_context.user_email}'
|
||||
)
|
||||
await self._send_error_comment(
|
||||
job_context,
|
||||
f'User {job_context.user_email} is not authenticated or active in the Jira integration.',
|
||||
workspace,
|
||||
)
|
||||
return
|
||||
|
||||
# Get issue details
|
||||
try:
|
||||
api_key = self.token_manager.decrypt_text(workspace.svc_acc_api_key)
|
||||
issue_title, issue_description = await self.get_issue_details(
|
||||
job_context, workspace.jira_cloud_id, workspace.svc_acc_email, api_key
|
||||
)
|
||||
job_context.issue_title = issue_title
|
||||
job_context.issue_description = issue_description
|
||||
except Exception as e:
|
||||
logger.error(f'[Jira] Failed to get issue context: {str(e)}')
|
||||
await self._send_error_comment(
|
||||
job_context,
|
||||
'Failed to retrieve issue details. Please check the issue key and try again.',
|
||||
workspace,
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
# Create Jira view
|
||||
jira_view = await JiraFactory.create_jira_view_from_payload(
|
||||
job_context,
|
||||
saas_user_auth,
|
||||
jira_user,
|
||||
workspace,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f'[Jira] Failed to create jira view: {str(e)}', exc_info=True)
|
||||
await self._send_error_comment(
|
||||
job_context,
|
||||
'Failed to initialize conversation. Please try again.',
|
||||
workspace,
|
||||
)
|
||||
return
|
||||
|
||||
if not await self.is_job_requested(message, jira_view):
|
||||
return
|
||||
|
||||
await self.start_job(jira_view)
|
||||
|
||||
async def is_job_requested(
|
||||
self, message: Message, jira_view: JiraViewInterface
|
||||
) -> bool:
|
||||
"""
|
||||
Check if a job is requested and handle repository selection.
|
||||
"""
|
||||
|
||||
if isinstance(jira_view, JiraExistingConversationView):
|
||||
return True
|
||||
|
||||
try:
|
||||
# Get user repositories
|
||||
user_repos: list[Repository] = await self._get_repositories(
|
||||
jira_view.saas_user_auth
|
||||
)
|
||||
|
||||
target_str = f'{jira_view.job_context.issue_description}\n{jira_view.job_context.user_msg}'
|
||||
|
||||
# Try to infer repository from issue description
|
||||
match, repos = filter_potential_repos_by_user_msg(target_str, user_repos)
|
||||
|
||||
if match:
|
||||
# Found exact repository match
|
||||
jira_view.selected_repo = repos[0].full_name
|
||||
logger.info(f'[Jira] Inferred repository: {repos[0].full_name}')
|
||||
return True
|
||||
else:
|
||||
# No clear match - send repository selection comment
|
||||
await self._send_repo_selection_comment(jira_view)
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f'[Jira] Error in is_job_requested: {str(e)}')
|
||||
return False
|
||||
|
||||
async def start_job(self, jira_view: JiraViewInterface):
|
||||
"""Start a Jira job/conversation."""
|
||||
# Import here to prevent circular import
|
||||
from server.conversation_callback_processor.jira_callback_processor import (
|
||||
JiraCallbackProcessor,
|
||||
)
|
||||
|
||||
try:
|
||||
user_info: JiraUser = jira_view.jira_user
|
||||
logger.info(
|
||||
f'[Jira] Starting job for user {user_info.keycloak_user_id} '
|
||||
f'issue {jira_view.job_context.issue_key}',
|
||||
)
|
||||
|
||||
# Create conversation
|
||||
conversation_id = await jira_view.create_or_update_conversation(
|
||||
self.jinja_env
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f'[Jira] Created/Updated conversation {conversation_id} for issue {jira_view.job_context.issue_key}'
|
||||
)
|
||||
|
||||
# Register callback processor for updates
|
||||
if isinstance(jira_view, JiraNewConversationView):
|
||||
processor = JiraCallbackProcessor(
|
||||
issue_key=jira_view.job_context.issue_key,
|
||||
workspace_name=jira_view.jira_workspace.name,
|
||||
)
|
||||
|
||||
# Register the callback processor
|
||||
register_callback_processor(conversation_id, processor)
|
||||
|
||||
logger.info(
|
||||
f'[Jira] Created callback processor for conversation {conversation_id}'
|
||||
)
|
||||
|
||||
# Send initial response
|
||||
msg_info = jira_view.get_response_msg()
|
||||
|
||||
except MissingSettingsError as e:
|
||||
logger.warning(f'[Jira] Missing settings error: {str(e)}')
|
||||
msg_info = f'Please re-login into [OpenHands Cloud]({HOST_URL}) before starting a job.'
|
||||
|
||||
except LLMAuthenticationError as e:
|
||||
logger.warning(f'[Jira] LLM authentication error: {str(e)}')
|
||||
msg_info = f'Please set a valid LLM API key in [OpenHands Cloud]({HOST_URL}) before starting a job.'
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f'[Jira] Unexpected error starting job: {str(e)}', exc_info=True
|
||||
)
|
||||
msg_info = 'Sorry, there was an unexpected error starting the job. Please try again.'
|
||||
|
||||
# Send response comment
|
||||
try:
|
||||
api_key = self.token_manager.decrypt_text(
|
||||
jira_view.jira_workspace.svc_acc_api_key
|
||||
)
|
||||
await self.send_message(
|
||||
self.create_outgoing_message(msg=msg_info),
|
||||
issue_key=jira_view.job_context.issue_key,
|
||||
jira_cloud_id=jira_view.jira_workspace.jira_cloud_id,
|
||||
svc_acc_email=jira_view.jira_workspace.svc_acc_email,
|
||||
svc_acc_api_key=api_key,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f'[Jira] Failed to send response message: {str(e)}')
|
||||
|
||||
async def get_issue_details(
|
||||
self,
|
||||
job_context: JobContext,
|
||||
jira_cloud_id: str,
|
||||
svc_acc_email: str,
|
||||
svc_acc_api_key: str,
|
||||
) -> Tuple[str, str]:
|
||||
url = f'{JIRA_CLOUD_API_URL}/{jira_cloud_id}/rest/api/2/issue/{job_context.issue_key}'
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(url, auth=(svc_acc_email, svc_acc_api_key))
|
||||
response.raise_for_status()
|
||||
issue_payload = response.json()
|
||||
|
||||
if not issue_payload:
|
||||
raise ValueError(f'Issue with key {job_context.issue_key} not found.')
|
||||
|
||||
title = issue_payload.get('fields', {}).get('summary', '')
|
||||
description = issue_payload.get('fields', {}).get('description', '')
|
||||
|
||||
if not title:
|
||||
raise ValueError(
|
||||
f'Issue with key {job_context.issue_key} does not have a title.'
|
||||
)
|
||||
|
||||
if not description:
|
||||
raise ValueError(
|
||||
f'Issue with key {job_context.issue_key} does not have a description.'
|
||||
)
|
||||
|
||||
return title, description
|
||||
|
||||
async def send_message(
|
||||
self,
|
||||
message: Message,
|
||||
issue_key: str,
|
||||
jira_cloud_id: str,
|
||||
svc_acc_email: str,
|
||||
svc_acc_api_key: str,
|
||||
):
|
||||
url = (
|
||||
f'{JIRA_CLOUD_API_URL}/{jira_cloud_id}/rest/api/2/issue/{issue_key}/comment'
|
||||
)
|
||||
data = {'body': message.message}
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
url, auth=(svc_acc_email, svc_acc_api_key), json=data
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
async def _send_error_comment(
|
||||
self,
|
||||
job_context: JobContext,
|
||||
error_msg: str,
|
||||
workspace: JiraWorkspace | None,
|
||||
):
|
||||
"""Send error comment to Jira issue."""
|
||||
if not workspace:
|
||||
logger.error('[Jira] Cannot send error comment - no workspace available')
|
||||
return
|
||||
|
||||
try:
|
||||
api_key = self.token_manager.decrypt_text(workspace.svc_acc_api_key)
|
||||
await self.send_message(
|
||||
self.create_outgoing_message(msg=error_msg),
|
||||
issue_key=job_context.issue_key,
|
||||
jira_cloud_id=workspace.jira_cloud_id,
|
||||
svc_acc_email=workspace.svc_acc_email,
|
||||
svc_acc_api_key=api_key,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f'[Jira] Failed to send error comment: {str(e)}')
|
||||
|
||||
async def _send_repo_selection_comment(self, jira_view: JiraViewInterface):
|
||||
"""Send a comment with repository options for the user to choose."""
|
||||
try:
|
||||
comment_msg = (
|
||||
'I need to know which repository to work with. '
|
||||
'Please add it to your issue description or send a followup comment.'
|
||||
)
|
||||
|
||||
api_key = self.token_manager.decrypt_text(
|
||||
jira_view.jira_workspace.svc_acc_api_key
|
||||
)
|
||||
|
||||
await self.send_message(
|
||||
self.create_outgoing_message(msg=comment_msg),
|
||||
issue_key=jira_view.job_context.issue_key,
|
||||
jira_cloud_id=jira_view.jira_workspace.jira_cloud_id,
|
||||
svc_acc_email=jira_view.jira_workspace.svc_acc_email,
|
||||
svc_acc_api_key=api_key,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f'[Jira] Sent repository selection comment for issue {jira_view.job_context.issue_key}'
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f'[Jira] Failed to send repository selection comment: {str(e)}'
|
||||
)
|
||||
40
enterprise/integrations/jira/jira_types.py
Normal file
40
enterprise/integrations/jira/jira_types.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from integrations.models import JobContext
|
||||
from jinja2 import Environment
|
||||
from storage.jira_user import JiraUser
|
||||
from storage.jira_workspace import JiraWorkspace
|
||||
|
||||
from openhands.server.user_auth.user_auth import UserAuth
|
||||
|
||||
|
||||
class JiraViewInterface(ABC):
|
||||
"""Interface for Jira views that handle different types of Jira interactions."""
|
||||
|
||||
job_context: JobContext
|
||||
saas_user_auth: UserAuth
|
||||
jira_user: JiraUser
|
||||
jira_workspace: JiraWorkspace
|
||||
selected_repo: str | None
|
||||
conversation_id: str
|
||||
|
||||
@abstractmethod
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
"""Get initial instructions for the conversation."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def create_or_update_conversation(self, jinja_env: Environment) -> str:
|
||||
"""Create or update a conversation and return the conversation ID."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_response_msg(self) -> str:
|
||||
"""Get the response message to send back to Jira."""
|
||||
pass
|
||||
|
||||
|
||||
class StartingConvoException(Exception):
|
||||
"""Exception raised when starting a conversation fails."""
|
||||
|
||||
pass
|
||||
222
enterprise/integrations/jira/jira_view.py
Normal file
222
enterprise/integrations/jira/jira_view.py
Normal file
@@ -0,0 +1,222 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from integrations.jira.jira_types import JiraViewInterface, StartingConvoException
|
||||
from integrations.models import JobContext
|
||||
from integrations.utils import CONVERSATION_URL, get_final_agent_observation
|
||||
from jinja2 import Environment
|
||||
from storage.jira_conversation import JiraConversation
|
||||
from storage.jira_integration_store import JiraIntegrationStore
|
||||
from storage.jira_user import JiraUser
|
||||
from storage.jira_workspace import JiraWorkspace
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.schema.agent import AgentState
|
||||
from openhands.events.action import MessageAction
|
||||
from openhands.events.serialization.event import event_to_dict
|
||||
from openhands.server.services.conversation_service import (
|
||||
create_new_conversation,
|
||||
setup_init_conversation_settings,
|
||||
)
|
||||
from openhands.server.shared import ConversationStoreImpl, config, conversation_manager
|
||||
from openhands.server.user_auth.user_auth import UserAuth
|
||||
from openhands.storage.data_models.conversation_metadata import ConversationTrigger
|
||||
|
||||
integration_store = JiraIntegrationStore.get_instance()
|
||||
|
||||
|
||||
@dataclass
|
||||
class JiraNewConversationView(JiraViewInterface):
|
||||
job_context: JobContext
|
||||
saas_user_auth: UserAuth
|
||||
jira_user: JiraUser
|
||||
jira_workspace: JiraWorkspace
|
||||
selected_repo: str | None
|
||||
conversation_id: str
|
||||
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
"""Instructions passed when conversation is first initialized"""
|
||||
|
||||
instructions_template = jinja_env.get_template('jira_instructions.j2')
|
||||
instructions = instructions_template.render()
|
||||
|
||||
user_msg_template = jinja_env.get_template('jira_new_conversation.j2')
|
||||
|
||||
user_msg = user_msg_template.render(
|
||||
issue_key=self.job_context.issue_key,
|
||||
issue_title=self.job_context.issue_title,
|
||||
issue_description=self.job_context.issue_description,
|
||||
user_message=self.job_context.user_msg or '',
|
||||
)
|
||||
|
||||
return instructions, user_msg
|
||||
|
||||
async def create_or_update_conversation(self, jinja_env: Environment) -> str:
|
||||
"""Create a new Jira conversation"""
|
||||
|
||||
if not self.selected_repo:
|
||||
raise StartingConvoException('No repository selected for this conversation')
|
||||
|
||||
provider_tokens = await self.saas_user_auth.get_provider_tokens()
|
||||
user_secrets = await self.saas_user_auth.get_user_secrets()
|
||||
instructions, user_msg = self._get_instructions(jinja_env)
|
||||
|
||||
try:
|
||||
agent_loop_info = await create_new_conversation(
|
||||
user_id=self.jira_user.keycloak_user_id,
|
||||
git_provider_tokens=provider_tokens,
|
||||
selected_repository=self.selected_repo,
|
||||
selected_branch=None,
|
||||
initial_user_msg=user_msg,
|
||||
conversation_instructions=instructions,
|
||||
image_urls=None,
|
||||
replay_json=None,
|
||||
conversation_trigger=ConversationTrigger.JIRA,
|
||||
custom_secrets=user_secrets.custom_secrets if user_secrets else None,
|
||||
)
|
||||
|
||||
self.conversation_id = agent_loop_info.conversation_id
|
||||
|
||||
logger.info(f'[Jira] Created conversation {self.conversation_id}')
|
||||
|
||||
# Store Jira conversation mapping
|
||||
jira_conversation = JiraConversation(
|
||||
conversation_id=self.conversation_id,
|
||||
issue_id=self.job_context.issue_id,
|
||||
issue_key=self.job_context.issue_key,
|
||||
jira_user_id=self.jira_user.id,
|
||||
)
|
||||
|
||||
await integration_store.create_conversation(jira_conversation)
|
||||
|
||||
return self.conversation_id
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f'[Jira] Failed to create conversation: {str(e)}', exc_info=True
|
||||
)
|
||||
raise StartingConvoException(f'Failed to create conversation: {str(e)}')
|
||||
|
||||
def get_response_msg(self) -> str:
|
||||
"""Get the response message to send back to Jira"""
|
||||
conversation_link = CONVERSATION_URL.format(self.conversation_id)
|
||||
return f"I'm on it! {self.job_context.display_name} can [track my progress here|{conversation_link}]."
|
||||
|
||||
|
||||
@dataclass
|
||||
class JiraExistingConversationView(JiraViewInterface):
|
||||
job_context: JobContext
|
||||
saas_user_auth: UserAuth
|
||||
jira_user: JiraUser
|
||||
jira_workspace: JiraWorkspace
|
||||
selected_repo: str | None
|
||||
conversation_id: str
|
||||
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
"""Instructions passed when conversation is first initialized"""
|
||||
|
||||
user_msg_template = jinja_env.get_template('jira_existing_conversation.j2')
|
||||
user_msg = user_msg_template.render(
|
||||
issue_key=self.job_context.issue_key,
|
||||
user_message=self.job_context.user_msg or '',
|
||||
issue_title=self.job_context.issue_title,
|
||||
issue_description=self.job_context.issue_description,
|
||||
)
|
||||
|
||||
return '', user_msg
|
||||
|
||||
async def create_or_update_conversation(self, jinja_env: Environment) -> str:
|
||||
"""Update an existing Jira conversation"""
|
||||
|
||||
user_id = self.jira_user.keycloak_user_id
|
||||
|
||||
try:
|
||||
conversation_store = await ConversationStoreImpl.get_instance(
|
||||
config, user_id
|
||||
)
|
||||
metadata = await conversation_store.get_metadata(self.conversation_id)
|
||||
if not metadata:
|
||||
raise StartingConvoException('Conversation no longer exists.')
|
||||
|
||||
provider_tokens = await self.saas_user_auth.get_provider_tokens()
|
||||
# Should we raise here if there are no providers?
|
||||
providers_set = list(provider_tokens.keys()) if provider_tokens else []
|
||||
|
||||
conversation_init_data = await setup_init_conversation_settings(
|
||||
user_id, self.conversation_id, providers_set
|
||||
)
|
||||
|
||||
# Either join ongoing conversation, or restart the conversation
|
||||
agent_loop_info = await conversation_manager.maybe_start_agent_loop(
|
||||
self.conversation_id, conversation_init_data, user_id
|
||||
)
|
||||
|
||||
final_agent_observation = get_final_agent_observation(
|
||||
agent_loop_info.event_store
|
||||
)
|
||||
agent_state = (
|
||||
None
|
||||
if len(final_agent_observation) == 0
|
||||
else final_agent_observation[0].agent_state
|
||||
)
|
||||
|
||||
if not agent_state or agent_state == AgentState.LOADING:
|
||||
raise StartingConvoException('Conversation is still starting')
|
||||
|
||||
_, user_msg = self._get_instructions(jinja_env)
|
||||
user_message_event = MessageAction(content=user_msg)
|
||||
await conversation_manager.send_event_to_conversation(
|
||||
self.conversation_id, event_to_dict(user_message_event)
|
||||
)
|
||||
|
||||
return self.conversation_id
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f'[Jira] Failed to create conversation: {str(e)}', exc_info=True
|
||||
)
|
||||
raise StartingConvoException(f'Failed to create conversation: {str(e)}')
|
||||
|
||||
def get_response_msg(self) -> str:
|
||||
"""Get the response message to send back to Jira"""
|
||||
conversation_link = CONVERSATION_URL.format(self.conversation_id)
|
||||
return f"I'm on it! {self.job_context.display_name} can [continue tracking my progress here|{conversation_link}]."
|
||||
|
||||
|
||||
class JiraFactory:
|
||||
"""Factory for creating Jira views based on message content"""
|
||||
|
||||
@staticmethod
|
||||
async def create_jira_view_from_payload(
|
||||
job_context: JobContext,
|
||||
saas_user_auth: UserAuth,
|
||||
jira_user: JiraUser,
|
||||
jira_workspace: JiraWorkspace,
|
||||
) -> JiraViewInterface:
|
||||
"""Create appropriate Jira view based on the message and user state"""
|
||||
|
||||
if not jira_user or not saas_user_auth or not jira_workspace:
|
||||
raise StartingConvoException('User not authenticated with Jira integration')
|
||||
|
||||
conversation = await integration_store.get_user_conversations_by_issue_id(
|
||||
job_context.issue_id, jira_user.id
|
||||
)
|
||||
|
||||
if conversation:
|
||||
logger.info(
|
||||
f'[Jira] Found existing conversation for issue {job_context.issue_id}'
|
||||
)
|
||||
return JiraExistingConversationView(
|
||||
job_context=job_context,
|
||||
saas_user_auth=saas_user_auth,
|
||||
jira_user=jira_user,
|
||||
jira_workspace=jira_workspace,
|
||||
selected_repo=None,
|
||||
conversation_id=conversation.conversation_id,
|
||||
)
|
||||
|
||||
return JiraNewConversationView(
|
||||
job_context=job_context,
|
||||
saas_user_auth=saas_user_auth,
|
||||
jira_user=jira_user,
|
||||
jira_workspace=jira_workspace,
|
||||
selected_repo=None, # Will be set later after repo inference
|
||||
conversation_id='', # Will be set when conversation is created
|
||||
)
|
||||
508
enterprise/integrations/jira_dc/jira_dc_manager.py
Normal file
508
enterprise/integrations/jira_dc/jira_dc_manager.py
Normal file
@@ -0,0 +1,508 @@
|
||||
import hashlib
|
||||
import hmac
|
||||
from typing import Dict, Optional, Tuple
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
from fastapi import Request
|
||||
from integrations.jira_dc.jira_dc_types import (
|
||||
JiraDcViewInterface,
|
||||
)
|
||||
from integrations.jira_dc.jira_dc_view import (
|
||||
JiraDcExistingConversationView,
|
||||
JiraDcFactory,
|
||||
JiraDcNewConversationView,
|
||||
)
|
||||
from integrations.manager import Manager
|
||||
from integrations.models import JobContext, Message
|
||||
from integrations.utils import (
|
||||
HOST_URL,
|
||||
OPENHANDS_RESOLVER_TEMPLATES_DIR,
|
||||
filter_potential_repos_by_user_msg,
|
||||
)
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
from server.auth.saas_user_auth import get_user_auth_from_keycloak_id
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.utils.conversation_callback_utils import register_callback_processor
|
||||
from storage.jira_dc_integration_store import JiraDcIntegrationStore
|
||||
from storage.jira_dc_user import JiraDcUser
|
||||
from storage.jira_dc_workspace import JiraDcWorkspace
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.provider import ProviderHandler
|
||||
from openhands.integrations.service_types import Repository
|
||||
from openhands.server.shared import server_config
|
||||
from openhands.server.types import LLMAuthenticationError, MissingSettingsError
|
||||
from openhands.server.user_auth.user_auth import UserAuth
|
||||
|
||||
|
||||
class JiraDcManager(Manager):
|
||||
def __init__(self, token_manager: TokenManager):
|
||||
self.token_manager = token_manager
|
||||
self.integration_store = JiraDcIntegrationStore.get_instance()
|
||||
self.jinja_env = Environment(
|
||||
loader=FileSystemLoader(OPENHANDS_RESOLVER_TEMPLATES_DIR + 'jira_dc')
|
||||
)
|
||||
|
||||
async def authenticate_user(
|
||||
self, user_email: str, jira_dc_user_id: str, workspace_id: int
|
||||
) -> tuple[JiraDcUser | None, UserAuth | None]:
|
||||
"""Authenticate Jira DC user and get their OpenHands user auth."""
|
||||
|
||||
if not jira_dc_user_id or jira_dc_user_id == 'none':
|
||||
# Get Keycloak user ID from email
|
||||
keycloak_user_id = await self.token_manager.get_user_id_from_user_email(
|
||||
user_email
|
||||
)
|
||||
if not keycloak_user_id:
|
||||
logger.warning(
|
||||
f'[Jira DC] No Keycloak user found for email: {user_email}'
|
||||
)
|
||||
return None, None
|
||||
|
||||
# Find active Jira DC user by Keycloak user ID and organization
|
||||
jira_dc_user = await self.integration_store.get_active_user_by_keycloak_id_and_workspace(
|
||||
keycloak_user_id, workspace_id
|
||||
)
|
||||
else:
|
||||
jira_dc_user = await self.integration_store.get_active_user(
|
||||
jira_dc_user_id, workspace_id
|
||||
)
|
||||
|
||||
if not jira_dc_user:
|
||||
logger.warning(
|
||||
f'[Jira DC] No active Jira DC user found for {user_email} in workspace {workspace_id}'
|
||||
)
|
||||
return None, None
|
||||
|
||||
saas_user_auth = await get_user_auth_from_keycloak_id(
|
||||
jira_dc_user.keycloak_user_id
|
||||
)
|
||||
return jira_dc_user, saas_user_auth
|
||||
|
||||
async def _get_repositories(self, user_auth: UserAuth) -> list[Repository]:
|
||||
"""Get repositories that the user has access to."""
|
||||
provider_tokens = await user_auth.get_provider_tokens()
|
||||
if provider_tokens is None:
|
||||
return []
|
||||
access_token = await user_auth.get_access_token()
|
||||
user_id = await user_auth.get_user_id()
|
||||
client = ProviderHandler(
|
||||
provider_tokens=provider_tokens,
|
||||
external_auth_token=access_token,
|
||||
external_auth_id=user_id,
|
||||
)
|
||||
repos: list[Repository] = await client.get_repositories(
|
||||
'pushed', server_config.app_mode, None, None, None, None
|
||||
)
|
||||
return repos
|
||||
|
||||
async def validate_request(
|
||||
self, request: Request
|
||||
) -> Tuple[bool, Optional[str], Optional[Dict]]:
|
||||
"""Verify Jira DC webhook signature."""
|
||||
signature_header = request.headers.get('x-hub-signature')
|
||||
signature = signature_header.split('=')[1] if signature_header else None
|
||||
body = await request.body()
|
||||
payload = await request.json()
|
||||
workspace_name = ''
|
||||
|
||||
if payload.get('webhookEvent') == 'comment_created':
|
||||
selfUrl = payload.get('comment', {}).get('author', {}).get('self')
|
||||
elif payload.get('webhookEvent') == 'jira:issue_updated':
|
||||
selfUrl = payload.get('user', {}).get('self')
|
||||
else:
|
||||
workspace_name = ''
|
||||
|
||||
parsedUrl = urlparse(selfUrl)
|
||||
if parsedUrl.hostname:
|
||||
workspace_name = parsedUrl.hostname
|
||||
|
||||
if not workspace_name:
|
||||
logger.warning('[Jira DC] No workspace name found in webhook payload')
|
||||
return False, None, None
|
||||
|
||||
if not signature:
|
||||
logger.warning('[Jira DC] No signature found in webhook headers')
|
||||
return False, None, None
|
||||
|
||||
workspace = await self.integration_store.get_workspace_by_name(workspace_name)
|
||||
|
||||
if not workspace:
|
||||
logger.warning('[Jira DC] Could not identify workspace for webhook')
|
||||
return False, None, None
|
||||
|
||||
if workspace.status != 'active':
|
||||
logger.warning(f'[Jira DC] Workspace {workspace.id} is not active')
|
||||
return False, None, None
|
||||
|
||||
webhook_secret = self.token_manager.decrypt_text(workspace.webhook_secret)
|
||||
digest = hmac.new(webhook_secret.encode(), body, hashlib.sha256).hexdigest()
|
||||
|
||||
if hmac.compare_digest(signature, digest):
|
||||
logger.info('[Jira DC] Webhook signature verified successfully')
|
||||
return True, signature, payload
|
||||
|
||||
return False, None, None
|
||||
|
||||
def parse_webhook(self, payload: Dict) -> JobContext | None:
|
||||
event_type = payload.get('webhookEvent')
|
||||
|
||||
if event_type == 'comment_created':
|
||||
comment_data = payload.get('comment', {})
|
||||
comment = comment_data.get('body', '')
|
||||
|
||||
if '@openhands' not in comment:
|
||||
return None
|
||||
|
||||
issue_data = payload.get('issue', {})
|
||||
issue_id = issue_data.get('id')
|
||||
issue_key = issue_data.get('key')
|
||||
base_api_url = issue_data.get('self', '').split('/rest/')[0]
|
||||
|
||||
user_data = comment_data.get('author', {})
|
||||
user_email = user_data.get('emailAddress')
|
||||
display_name = user_data.get('displayName')
|
||||
user_key = user_data.get('key')
|
||||
elif event_type == 'jira:issue_updated':
|
||||
changelog = payload.get('changelog', {})
|
||||
items = changelog.get('items', [])
|
||||
labels = [
|
||||
item.get('toString', '')
|
||||
for item in items
|
||||
if item.get('field') == 'labels' and 'toString' in item
|
||||
]
|
||||
|
||||
if 'openhands' not in labels:
|
||||
return None
|
||||
|
||||
issue_data = payload.get('issue', {})
|
||||
issue_id = issue_data.get('id')
|
||||
issue_key = issue_data.get('key')
|
||||
base_api_url = issue_data.get('self', '').split('/rest/')[0]
|
||||
|
||||
user_data = payload.get('user', {})
|
||||
user_email = user_data.get('emailAddress')
|
||||
display_name = user_data.get('displayName')
|
||||
user_key = user_data.get('key')
|
||||
comment = ''
|
||||
else:
|
||||
return None
|
||||
|
||||
workspace_name = ''
|
||||
|
||||
parsedUrl = urlparse(base_api_url)
|
||||
if parsedUrl.hostname:
|
||||
workspace_name = parsedUrl.hostname
|
||||
|
||||
if not all(
|
||||
[
|
||||
issue_id,
|
||||
issue_key,
|
||||
user_email,
|
||||
display_name,
|
||||
user_key,
|
||||
workspace_name,
|
||||
base_api_url,
|
||||
]
|
||||
):
|
||||
return None
|
||||
|
||||
return JobContext(
|
||||
issue_id=issue_id,
|
||||
issue_key=issue_key,
|
||||
user_msg=comment,
|
||||
user_email=user_email,
|
||||
display_name=display_name,
|
||||
platform_user_id=user_key,
|
||||
workspace_name=workspace_name,
|
||||
base_api_url=base_api_url,
|
||||
)
|
||||
|
||||
async def receive_message(self, message: Message):
|
||||
"""Process incoming Jira DC webhook message."""
|
||||
|
||||
payload = message.message.get('payload', {})
|
||||
job_context = self.parse_webhook(payload)
|
||||
|
||||
if not job_context:
|
||||
logger.info('[Jira DC] Webhook does not match trigger conditions')
|
||||
return
|
||||
|
||||
workspace = await self.integration_store.get_workspace_by_name(
|
||||
job_context.workspace_name
|
||||
)
|
||||
if not workspace:
|
||||
logger.warning(
|
||||
f'[Jira DC] No workspace found for email domain: {job_context.user_email}'
|
||||
)
|
||||
await self._send_error_comment(
|
||||
job_context,
|
||||
'Your workspace is not configured with Jira DC integration.',
|
||||
None,
|
||||
)
|
||||
return
|
||||
|
||||
# Prevent any recursive triggers from the service account
|
||||
if job_context.user_email == workspace.svc_acc_email:
|
||||
return
|
||||
|
||||
if workspace.status != 'active':
|
||||
logger.warning(f'[Jira DC] Workspace {workspace.id} is not active')
|
||||
await self._send_error_comment(
|
||||
job_context,
|
||||
'Jira DC integration is not active for your workspace.',
|
||||
workspace,
|
||||
)
|
||||
return
|
||||
|
||||
# Authenticate user
|
||||
jira_dc_user, saas_user_auth = await self.authenticate_user(
|
||||
job_context.user_email, job_context.platform_user_id, workspace.id
|
||||
)
|
||||
if not jira_dc_user or not saas_user_auth:
|
||||
logger.warning(
|
||||
f'[Jira DC] User authentication failed for {job_context.user_email}'
|
||||
)
|
||||
await self._send_error_comment(
|
||||
job_context,
|
||||
f'User {job_context.user_email} is not authenticated or active in the Jira DC integration.',
|
||||
workspace,
|
||||
)
|
||||
return
|
||||
|
||||
# Get issue details
|
||||
try:
|
||||
api_key = self.token_manager.decrypt_text(workspace.svc_acc_api_key)
|
||||
issue_title, issue_description = await self.get_issue_details(
|
||||
job_context, api_key
|
||||
)
|
||||
job_context.issue_title = issue_title
|
||||
job_context.issue_description = issue_description
|
||||
except Exception as e:
|
||||
logger.error(f'[Jira DC] Failed to get issue context: {str(e)}')
|
||||
await self._send_error_comment(
|
||||
job_context,
|
||||
'Failed to retrieve issue details. Please check the issue key and try again.',
|
||||
workspace,
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
# Create Jira DC view
|
||||
jira_dc_view = await JiraDcFactory.create_jira_dc_view_from_payload(
|
||||
job_context,
|
||||
saas_user_auth,
|
||||
jira_dc_user,
|
||||
workspace,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f'[Jira DC] Failed to create jira dc view: {str(e)}', exc_info=True
|
||||
)
|
||||
await self._send_error_comment(
|
||||
job_context,
|
||||
'Failed to initialize conversation. Please try again.',
|
||||
workspace,
|
||||
)
|
||||
return
|
||||
|
||||
if not await self.is_job_requested(message, jira_dc_view):
|
||||
return
|
||||
|
||||
await self.start_job(jira_dc_view)
|
||||
|
||||
async def is_job_requested(
|
||||
self, message: Message, jira_dc_view: JiraDcViewInterface
|
||||
) -> bool:
|
||||
"""
|
||||
Check if a job is requested and handle repository selection.
|
||||
"""
|
||||
|
||||
if isinstance(jira_dc_view, JiraDcExistingConversationView):
|
||||
return True
|
||||
|
||||
try:
|
||||
# Get user repositories
|
||||
user_repos: list[Repository] = await self._get_repositories(
|
||||
jira_dc_view.saas_user_auth
|
||||
)
|
||||
|
||||
target_str = f'{jira_dc_view.job_context.issue_description}\n{jira_dc_view.job_context.user_msg}'
|
||||
|
||||
# Try to infer repository from issue description
|
||||
match, repos = filter_potential_repos_by_user_msg(target_str, user_repos)
|
||||
|
||||
if match:
|
||||
# Found exact repository match
|
||||
jira_dc_view.selected_repo = repos[0].full_name
|
||||
logger.info(f'[Jira DC] Inferred repository: {repos[0].full_name}')
|
||||
return True
|
||||
else:
|
||||
# No clear match - send repository selection comment
|
||||
await self._send_repo_selection_comment(jira_dc_view)
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f'[Jira DC] Error in is_job_requested: {str(e)}')
|
||||
return False
|
||||
|
||||
async def start_job(self, jira_dc_view: JiraDcViewInterface):
|
||||
"""Start a Jira DC job/conversation."""
|
||||
# Import here to prevent circular import
|
||||
from server.conversation_callback_processor.jira_dc_callback_processor import (
|
||||
JiraDcCallbackProcessor,
|
||||
)
|
||||
|
||||
try:
|
||||
user_info: JiraDcUser = jira_dc_view.jira_dc_user
|
||||
logger.info(
|
||||
f'[Jira DC] Starting job for user {user_info.keycloak_user_id} '
|
||||
f'issue {jira_dc_view.job_context.issue_key}',
|
||||
)
|
||||
|
||||
# Create conversation
|
||||
conversation_id = await jira_dc_view.create_or_update_conversation(
|
||||
self.jinja_env
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f'[Jira DC] Created/Updated conversation {conversation_id} for issue {jira_dc_view.job_context.issue_key}'
|
||||
)
|
||||
|
||||
if isinstance(jira_dc_view, JiraDcNewConversationView):
|
||||
# Register callback processor for updates
|
||||
processor = JiraDcCallbackProcessor(
|
||||
issue_key=jira_dc_view.job_context.issue_key,
|
||||
workspace_name=jira_dc_view.jira_dc_workspace.name,
|
||||
base_api_url=jira_dc_view.job_context.base_api_url,
|
||||
)
|
||||
|
||||
# Register the callback processor
|
||||
register_callback_processor(conversation_id, processor)
|
||||
|
||||
logger.info(
|
||||
f'[Jira DC] Created callback processor for conversation {conversation_id}'
|
||||
)
|
||||
|
||||
# Send initial response
|
||||
msg_info = jira_dc_view.get_response_msg()
|
||||
|
||||
except MissingSettingsError as e:
|
||||
logger.warning(f'[Jira DC] Missing settings error: {str(e)}')
|
||||
msg_info = f'Please re-login into [OpenHands Cloud]({HOST_URL}) before starting a job.'
|
||||
|
||||
except LLMAuthenticationError as e:
|
||||
logger.warning(f'[Jira DC] LLM authentication error: {str(e)}')
|
||||
msg_info = f'Please set a valid LLM API key in [OpenHands Cloud]({HOST_URL}) before starting a job.'
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f'[Jira DC] Unexpected error starting job: {str(e)}', exc_info=True
|
||||
)
|
||||
msg_info = 'Sorry, there was an unexpected error starting the job. Please try again.'
|
||||
|
||||
# Send response comment
|
||||
try:
|
||||
api_key = self.token_manager.decrypt_text(
|
||||
jira_dc_view.jira_dc_workspace.svc_acc_api_key
|
||||
)
|
||||
await self.send_message(
|
||||
self.create_outgoing_message(msg=msg_info),
|
||||
issue_key=jira_dc_view.job_context.issue_key,
|
||||
base_api_url=jira_dc_view.job_context.base_api_url,
|
||||
svc_acc_api_key=api_key,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f'[Jira] Failed to send response message: {str(e)}')
|
||||
|
||||
async def get_issue_details(
|
||||
self, job_context: JobContext, svc_acc_api_key: str
|
||||
) -> Tuple[str, str]:
|
||||
"""Get issue details from Jira DC API."""
|
||||
url = f'{job_context.base_api_url}/rest/api/2/issue/{job_context.issue_key}'
|
||||
headers = {'Authorization': f'Bearer {svc_acc_api_key}'}
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(url, headers=headers)
|
||||
response.raise_for_status()
|
||||
issue_payload = response.json()
|
||||
|
||||
if not issue_payload:
|
||||
raise ValueError(f'Issue with key {job_context.issue_key} not found.')
|
||||
|
||||
title = issue_payload.get('fields', {}).get('summary', '')
|
||||
description = issue_payload.get('fields', {}).get('description', '')
|
||||
|
||||
if not title:
|
||||
raise ValueError(
|
||||
f'Issue with key {job_context.issue_key} does not have a title.'
|
||||
)
|
||||
|
||||
if not description:
|
||||
raise ValueError(
|
||||
f'Issue with key {job_context.issue_key} does not have a description.'
|
||||
)
|
||||
|
||||
return title, description
|
||||
|
||||
async def send_message(
|
||||
self, message: Message, issue_key: str, base_api_url: str, svc_acc_api_key: str
|
||||
):
|
||||
"""Send message/comment to Jira DC issue."""
|
||||
url = f'{base_api_url}/rest/api/2/issue/{issue_key}/comment'
|
||||
headers = {'Authorization': f'Bearer {svc_acc_api_key}'}
|
||||
data = {'body': message.message}
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(url, headers=headers, json=data)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
async def _send_error_comment(
|
||||
self,
|
||||
job_context: JobContext,
|
||||
error_msg: str,
|
||||
workspace: JiraDcWorkspace | None,
|
||||
):
|
||||
"""Send error comment to Jira DC issue."""
|
||||
if not workspace:
|
||||
logger.error('[Jira DC] Cannot send error comment - no workspace available')
|
||||
return
|
||||
|
||||
try:
|
||||
api_key = self.token_manager.decrypt_text(workspace.svc_acc_api_key)
|
||||
await self.send_message(
|
||||
self.create_outgoing_message(msg=error_msg),
|
||||
issue_key=job_context.issue_key,
|
||||
base_api_url=job_context.base_api_url,
|
||||
svc_acc_api_key=api_key,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f'[Jira DC] Failed to send error comment: {str(e)}')
|
||||
|
||||
async def _send_repo_selection_comment(self, jira_dc_view: JiraDcViewInterface):
|
||||
"""Send a comment with repository options for the user to choose."""
|
||||
try:
|
||||
comment_msg = (
|
||||
'I need to know which repository to work with. '
|
||||
'Please add it to your issue description or send a followup comment.'
|
||||
)
|
||||
|
||||
api_key = self.token_manager.decrypt_text(
|
||||
jira_dc_view.jira_dc_workspace.svc_acc_api_key
|
||||
)
|
||||
|
||||
await self.send_message(
|
||||
self.create_outgoing_message(msg=comment_msg),
|
||||
issue_key=jira_dc_view.job_context.issue_key,
|
||||
base_api_url=jira_dc_view.job_context.base_api_url,
|
||||
svc_acc_api_key=api_key,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f'[Jira] Sent repository selection comment for issue {jira_dc_view.job_context.issue_key}'
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f'[Jira] Failed to send repository selection comment: {str(e)}'
|
||||
)
|
||||
40
enterprise/integrations/jira_dc/jira_dc_types.py
Normal file
40
enterprise/integrations/jira_dc/jira_dc_types.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from integrations.models import JobContext
|
||||
from jinja2 import Environment
|
||||
from storage.jira_dc_user import JiraDcUser
|
||||
from storage.jira_dc_workspace import JiraDcWorkspace
|
||||
|
||||
from openhands.server.user_auth.user_auth import UserAuth
|
||||
|
||||
|
||||
class JiraDcViewInterface(ABC):
|
||||
"""Interface for Jira DC views that handle different types of Jira DC interactions."""
|
||||
|
||||
job_context: JobContext
|
||||
saas_user_auth: UserAuth
|
||||
jira_dc_user: JiraDcUser
|
||||
jira_dc_workspace: JiraDcWorkspace
|
||||
selected_repo: str | None
|
||||
conversation_id: str
|
||||
|
||||
@abstractmethod
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
"""Get initial instructions for the conversation."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def create_or_update_conversation(self, jinja_env: Environment) -> str:
|
||||
"""Create or update a conversation and return the conversation ID."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_response_msg(self) -> str:
|
||||
"""Get the response message to send back to Jira DC."""
|
||||
pass
|
||||
|
||||
|
||||
class StartingConvoException(Exception):
|
||||
"""Exception raised when starting a conversation fails."""
|
||||
|
||||
pass
|
||||
223
enterprise/integrations/jira_dc/jira_dc_view.py
Normal file
223
enterprise/integrations/jira_dc/jira_dc_view.py
Normal file
@@ -0,0 +1,223 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from integrations.jira_dc.jira_dc_types import (
|
||||
JiraDcViewInterface,
|
||||
StartingConvoException,
|
||||
)
|
||||
from integrations.models import JobContext
|
||||
from integrations.utils import CONVERSATION_URL, get_final_agent_observation
|
||||
from jinja2 import Environment
|
||||
from storage.jira_dc_conversation import JiraDcConversation
|
||||
from storage.jira_dc_integration_store import JiraDcIntegrationStore
|
||||
from storage.jira_dc_user import JiraDcUser
|
||||
from storage.jira_dc_workspace import JiraDcWorkspace
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.schema.agent import AgentState
|
||||
from openhands.events.action import MessageAction
|
||||
from openhands.events.serialization.event import event_to_dict
|
||||
from openhands.server.services.conversation_service import (
|
||||
create_new_conversation,
|
||||
setup_init_conversation_settings,
|
||||
)
|
||||
from openhands.server.shared import ConversationStoreImpl, config, conversation_manager
|
||||
from openhands.server.user_auth.user_auth import UserAuth
|
||||
from openhands.storage.data_models.conversation_metadata import ConversationTrigger
|
||||
|
||||
integration_store = JiraDcIntegrationStore.get_instance()
|
||||
|
||||
|
||||
@dataclass
|
||||
class JiraDcNewConversationView(JiraDcViewInterface):
|
||||
job_context: JobContext
|
||||
saas_user_auth: UserAuth
|
||||
jira_dc_user: JiraDcUser
|
||||
jira_dc_workspace: JiraDcWorkspace
|
||||
selected_repo: str | None
|
||||
conversation_id: str
|
||||
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
"""Instructions passed when conversation is first initialized"""
|
||||
|
||||
instructions_template = jinja_env.get_template('jira_dc_instructions.j2')
|
||||
instructions = instructions_template.render()
|
||||
|
||||
user_msg_template = jinja_env.get_template('jira_dc_new_conversation.j2')
|
||||
|
||||
user_msg = user_msg_template.render(
|
||||
issue_key=self.job_context.issue_key,
|
||||
issue_title=self.job_context.issue_title,
|
||||
issue_description=self.job_context.issue_description,
|
||||
user_message=self.job_context.user_msg or '',
|
||||
)
|
||||
|
||||
return instructions, user_msg
|
||||
|
||||
async def create_or_update_conversation(self, jinja_env: Environment) -> str:
|
||||
"""Create a new Jira DC conversation"""
|
||||
|
||||
if not self.selected_repo:
|
||||
raise StartingConvoException('No repository selected for this conversation')
|
||||
|
||||
provider_tokens = await self.saas_user_auth.get_provider_tokens()
|
||||
user_secrets = await self.saas_user_auth.get_user_secrets()
|
||||
instructions, user_msg = self._get_instructions(jinja_env)
|
||||
|
||||
try:
|
||||
agent_loop_info = await create_new_conversation(
|
||||
user_id=self.jira_dc_user.keycloak_user_id,
|
||||
git_provider_tokens=provider_tokens,
|
||||
selected_repository=self.selected_repo,
|
||||
selected_branch=None,
|
||||
initial_user_msg=user_msg,
|
||||
conversation_instructions=instructions,
|
||||
image_urls=None,
|
||||
replay_json=None,
|
||||
conversation_trigger=ConversationTrigger.JIRA_DC,
|
||||
custom_secrets=user_secrets.custom_secrets if user_secrets else None,
|
||||
)
|
||||
|
||||
self.conversation_id = agent_loop_info.conversation_id
|
||||
|
||||
logger.info(f'[Jira DC] Created conversation {self.conversation_id}')
|
||||
|
||||
# Store Jira DC conversation mapping
|
||||
jira_dc_conversation = JiraDcConversation(
|
||||
conversation_id=self.conversation_id,
|
||||
issue_id=self.job_context.issue_id,
|
||||
issue_key=self.job_context.issue_key,
|
||||
jira_dc_user_id=self.jira_dc_user.id,
|
||||
)
|
||||
|
||||
await integration_store.create_conversation(jira_dc_conversation)
|
||||
|
||||
return self.conversation_id
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f'[Jira DC] Failed to create conversation: {str(e)}', exc_info=True
|
||||
)
|
||||
raise StartingConvoException(f'Failed to create conversation: {str(e)}')
|
||||
|
||||
def get_response_msg(self) -> str:
|
||||
"""Get the response message to send back to Jira DC"""
|
||||
conversation_link = CONVERSATION_URL.format(self.conversation_id)
|
||||
return f"I'm on it! {self.job_context.display_name} can [track my progress here|{conversation_link}]."
|
||||
|
||||
|
||||
@dataclass
|
||||
class JiraDcExistingConversationView(JiraDcViewInterface):
|
||||
job_context: JobContext
|
||||
saas_user_auth: UserAuth
|
||||
jira_dc_user: JiraDcUser
|
||||
jira_dc_workspace: JiraDcWorkspace
|
||||
selected_repo: str | None
|
||||
conversation_id: str
|
||||
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
"""Instructions passed when conversation is first initialized"""
|
||||
|
||||
user_msg_template = jinja_env.get_template('jira_dc_existing_conversation.j2')
|
||||
user_msg = user_msg_template.render(
|
||||
issue_key=self.job_context.issue_key,
|
||||
user_message=self.job_context.user_msg or '',
|
||||
issue_title=self.job_context.issue_title,
|
||||
issue_description=self.job_context.issue_description,
|
||||
)
|
||||
|
||||
return '', user_msg
|
||||
|
||||
async def create_or_update_conversation(self, jinja_env: Environment) -> str:
|
||||
"""Update an existing Jira conversation"""
|
||||
|
||||
user_id = self.jira_dc_user.keycloak_user_id
|
||||
|
||||
try:
|
||||
conversation_store = await ConversationStoreImpl.get_instance(
|
||||
config, user_id
|
||||
)
|
||||
metadata = await conversation_store.get_metadata(self.conversation_id)
|
||||
if not metadata:
|
||||
raise StartingConvoException('Conversation no longer exists.')
|
||||
|
||||
provider_tokens = await self.saas_user_auth.get_provider_tokens()
|
||||
if provider_tokens is None:
|
||||
raise ValueError('Could not load provider tokens')
|
||||
providers_set = list(provider_tokens.keys())
|
||||
|
||||
conversation_init_data = await setup_init_conversation_settings(
|
||||
user_id, self.conversation_id, providers_set
|
||||
)
|
||||
|
||||
# Either join ongoing conversation, or restart the conversation
|
||||
agent_loop_info = await conversation_manager.maybe_start_agent_loop(
|
||||
self.conversation_id, conversation_init_data, user_id
|
||||
)
|
||||
|
||||
final_agent_observation = get_final_agent_observation(
|
||||
agent_loop_info.event_store
|
||||
)
|
||||
agent_state = (
|
||||
None
|
||||
if len(final_agent_observation) == 0
|
||||
else final_agent_observation[0].agent_state
|
||||
)
|
||||
|
||||
if not agent_state or agent_state == AgentState.LOADING:
|
||||
raise StartingConvoException('Conversation is still starting')
|
||||
|
||||
_, user_msg = self._get_instructions(jinja_env)
|
||||
user_message_event = MessageAction(content=user_msg)
|
||||
await conversation_manager.send_event_to_conversation(
|
||||
self.conversation_id, event_to_dict(user_message_event)
|
||||
)
|
||||
|
||||
return self.conversation_id
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f'[Jira] Failed to create conversation: {str(e)}', exc_info=True
|
||||
)
|
||||
raise StartingConvoException(f'Failed to create conversation: {str(e)}')
|
||||
|
||||
def get_response_msg(self) -> str:
|
||||
"""Get the response message to send back to Jira"""
|
||||
conversation_link = CONVERSATION_URL.format(self.conversation_id)
|
||||
return f"I'm on it! {self.job_context.display_name} can [continue tracking my progress here|{conversation_link}]."
|
||||
|
||||
|
||||
class JiraDcFactory:
|
||||
"""Factory class for creating Jira DC views based on message type."""
|
||||
|
||||
@staticmethod
|
||||
async def create_jira_dc_view_from_payload(
|
||||
job_context: JobContext,
|
||||
saas_user_auth: UserAuth,
|
||||
jira_dc_user: JiraDcUser,
|
||||
jira_dc_workspace: JiraDcWorkspace,
|
||||
) -> JiraDcViewInterface:
|
||||
"""Create appropriate Jira DC view based on the payload."""
|
||||
|
||||
if not jira_dc_user or not saas_user_auth or not jira_dc_workspace:
|
||||
raise StartingConvoException('User not authenticated with Jira integration')
|
||||
|
||||
conversation = await integration_store.get_user_conversations_by_issue_id(
|
||||
job_context.issue_id, jira_dc_user.id
|
||||
)
|
||||
|
||||
if conversation:
|
||||
return JiraDcExistingConversationView(
|
||||
job_context=job_context,
|
||||
saas_user_auth=saas_user_auth,
|
||||
jira_dc_user=jira_dc_user,
|
||||
jira_dc_workspace=jira_dc_workspace,
|
||||
selected_repo=None,
|
||||
conversation_id=conversation.conversation_id,
|
||||
)
|
||||
|
||||
return JiraDcNewConversationView(
|
||||
job_context=job_context,
|
||||
saas_user_auth=saas_user_auth,
|
||||
jira_dc_user=jira_dc_user,
|
||||
jira_dc_workspace=jira_dc_workspace,
|
||||
selected_repo=None, # Will be set later after repo inference
|
||||
conversation_id='', # Will be set when conversation is created
|
||||
)
|
||||
522
enterprise/integrations/linear/linear_manager.py
Normal file
522
enterprise/integrations/linear/linear_manager.py
Normal file
@@ -0,0 +1,522 @@
|
||||
import hashlib
|
||||
import hmac
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
import httpx
|
||||
from fastapi import Request
|
||||
from integrations.linear.linear_types import LinearViewInterface
|
||||
from integrations.linear.linear_view import (
|
||||
LinearExistingConversationView,
|
||||
LinearFactory,
|
||||
LinearNewConversationView,
|
||||
)
|
||||
from integrations.manager import Manager
|
||||
from integrations.models import JobContext, Message
|
||||
from integrations.utils import (
|
||||
HOST_URL,
|
||||
OPENHANDS_RESOLVER_TEMPLATES_DIR,
|
||||
filter_potential_repos_by_user_msg,
|
||||
)
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
from server.auth.saas_user_auth import get_user_auth_from_keycloak_id
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.utils.conversation_callback_utils import register_callback_processor
|
||||
from storage.linear_integration_store import LinearIntegrationStore
|
||||
from storage.linear_user import LinearUser
|
||||
from storage.linear_workspace import LinearWorkspace
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.provider import ProviderHandler
|
||||
from openhands.integrations.service_types import Repository
|
||||
from openhands.server.shared import server_config
|
||||
from openhands.server.types import LLMAuthenticationError, MissingSettingsError
|
||||
from openhands.server.user_auth.user_auth import UserAuth
|
||||
|
||||
|
||||
class LinearManager(Manager):
|
||||
def __init__(self, token_manager: TokenManager):
|
||||
self.token_manager = token_manager
|
||||
self.integration_store = LinearIntegrationStore.get_instance()
|
||||
self.api_url = 'https://api.linear.app/graphql'
|
||||
self.jinja_env = Environment(
|
||||
loader=FileSystemLoader(OPENHANDS_RESOLVER_TEMPLATES_DIR + 'linear')
|
||||
)
|
||||
|
||||
async def authenticate_user(
|
||||
self, linear_user_id: str, workspace_id: int
|
||||
) -> tuple[LinearUser | None, UserAuth | None]:
|
||||
"""Authenticate Linear user and get their OpenHands user auth."""
|
||||
|
||||
# Find active Linear user by Linear user ID and workspace ID
|
||||
linear_user = await self.integration_store.get_active_user(
|
||||
linear_user_id, workspace_id
|
||||
)
|
||||
|
||||
if not linear_user:
|
||||
logger.warning(
|
||||
f'[Linear] No active Linear user found for {linear_user_id} in workspace {workspace_id}'
|
||||
)
|
||||
return None, None
|
||||
|
||||
saas_user_auth = await get_user_auth_from_keycloak_id(
|
||||
linear_user.keycloak_user_id
|
||||
)
|
||||
return linear_user, saas_user_auth
|
||||
|
||||
async def _get_repositories(self, user_auth: UserAuth) -> list[Repository]:
|
||||
"""Get repositories that the user has access to."""
|
||||
provider_tokens = await user_auth.get_provider_tokens()
|
||||
if provider_tokens is None:
|
||||
return []
|
||||
access_token = await user_auth.get_access_token()
|
||||
user_id = await user_auth.get_user_id()
|
||||
client = ProviderHandler(
|
||||
provider_tokens=provider_tokens,
|
||||
external_auth_token=access_token,
|
||||
external_auth_id=user_id,
|
||||
)
|
||||
repos: list[Repository] = await client.get_repositories(
|
||||
'pushed', server_config.app_mode, None, None, None, None
|
||||
)
|
||||
return repos
|
||||
|
||||
async def validate_request(
|
||||
self, request: Request
|
||||
) -> Tuple[bool, Optional[str], Optional[Dict]]:
|
||||
"""Verify Linear webhook signature."""
|
||||
signature = request.headers.get('linear-signature')
|
||||
body = await request.body()
|
||||
payload = await request.json()
|
||||
actor_url = payload.get('actor', {}).get('url', '')
|
||||
workspace_name = ''
|
||||
|
||||
# Extract workspace name from actor URL
|
||||
# Format: https://linear.app/{workspace}/profiles/{user}
|
||||
if actor_url.startswith('https://linear.app/'):
|
||||
url_parts = actor_url.split('/')
|
||||
if len(url_parts) >= 4:
|
||||
workspace_name = url_parts[3] # Extract workspace name
|
||||
else:
|
||||
logger.warning(f'[Linear] Invalid actor URL format: {actor_url}')
|
||||
return False, None, None
|
||||
else:
|
||||
logger.warning(
|
||||
f'[Linear] Actor URL does not match expected format: {actor_url}'
|
||||
)
|
||||
return False, None, None
|
||||
|
||||
if not workspace_name:
|
||||
logger.warning('[Linear] No workspace name found in webhook payload')
|
||||
return False, None, None
|
||||
|
||||
if not signature:
|
||||
logger.warning('[Linear] No signature found in webhook headers')
|
||||
return False, None, None
|
||||
|
||||
workspace = await self.integration_store.get_workspace_by_name(workspace_name)
|
||||
|
||||
if not workspace:
|
||||
logger.warning('[Linear] Could not identify workspace for webhook')
|
||||
return False, None, None
|
||||
|
||||
if workspace.status != 'active':
|
||||
logger.warning(f'[Linear] Workspace {workspace.id} is not active')
|
||||
return False, None, None
|
||||
|
||||
webhook_secret = self.token_manager.decrypt_text(workspace.webhook_secret)
|
||||
digest = hmac.new(webhook_secret.encode(), body, hashlib.sha256).hexdigest()
|
||||
|
||||
if hmac.compare_digest(signature, digest):
|
||||
logger.info('[Linear] Webhook signature verified successfully')
|
||||
return True, signature, payload
|
||||
|
||||
return False, None, None
|
||||
|
||||
def parse_webhook(self, payload: Dict) -> JobContext | None:
|
||||
action = payload.get('action')
|
||||
type = payload.get('type')
|
||||
|
||||
if action == 'create' and type == 'Comment':
|
||||
data = payload.get('data', {})
|
||||
comment = data.get('body', '')
|
||||
|
||||
if '@openhands' not in comment:
|
||||
return None
|
||||
|
||||
issue_data = data.get('issue', {})
|
||||
issue_id = issue_data.get('id', '')
|
||||
issue_key = issue_data.get('identifier', '')
|
||||
elif action == 'update' and type == 'Issue':
|
||||
data = payload.get('data', {})
|
||||
labels = data.get('labels', [])
|
||||
|
||||
has_openhands_label = False
|
||||
label_id = ''
|
||||
for label in labels:
|
||||
if label.get('name') == 'openhands':
|
||||
label_id = label.get('id', '')
|
||||
has_openhands_label = True
|
||||
break
|
||||
|
||||
if not has_openhands_label and not label_id:
|
||||
return None
|
||||
|
||||
labelIdChanges = data.get('updatedFrom', {}).get('labelIds', [])
|
||||
|
||||
if labelIdChanges and label_id in labelIdChanges:
|
||||
return None # Label was added previously, ignore this webhook
|
||||
|
||||
issue_id = data.get('id', '')
|
||||
issue_key = data.get('identifier', '')
|
||||
comment = ''
|
||||
|
||||
else:
|
||||
return None
|
||||
|
||||
actor = payload.get('actor', {})
|
||||
display_name = actor.get('name', '')
|
||||
user_email = actor.get('email', '')
|
||||
actor_url = actor.get('url', '')
|
||||
actor_id = actor.get('id', '')
|
||||
workspace_name = ''
|
||||
|
||||
if actor_url.startswith('https://linear.app/'):
|
||||
url_parts = actor_url.split('/')
|
||||
if len(url_parts) >= 4:
|
||||
workspace_name = url_parts[3] # Extract workspace name
|
||||
else:
|
||||
logger.warning(f'[Linear] Invalid actor URL format: {actor_url}')
|
||||
return None
|
||||
else:
|
||||
logger.warning(
|
||||
f'[Linear] Actor URL does not match expected format: {actor_url}'
|
||||
)
|
||||
return None
|
||||
|
||||
if not all(
|
||||
[issue_id, issue_key, display_name, user_email, actor_id, workspace_name]
|
||||
):
|
||||
logger.warning('[Linear] Missing required fields in webhook payload')
|
||||
return None
|
||||
|
||||
return JobContext(
|
||||
issue_id=issue_id,
|
||||
issue_key=issue_key,
|
||||
user_msg=comment,
|
||||
user_email=user_email,
|
||||
platform_user_id=actor_id,
|
||||
workspace_name=workspace_name,
|
||||
display_name=display_name,
|
||||
)
|
||||
|
||||
async def receive_message(self, message: Message):
|
||||
"""Process incoming Linear webhook message."""
|
||||
payload = message.message.get('payload', {})
|
||||
job_context = self.parse_webhook(payload)
|
||||
|
||||
if not job_context:
|
||||
logger.info('[Linear] Webhook does not match trigger conditions')
|
||||
return
|
||||
|
||||
# Get workspace by user email domain
|
||||
workspace = await self.integration_store.get_workspace_by_name(
|
||||
job_context.workspace_name
|
||||
)
|
||||
if not workspace:
|
||||
logger.warning(
|
||||
f'[Linear] No workspace found for email domain: {job_context.workspace_name}'
|
||||
)
|
||||
await self._send_error_comment(
|
||||
job_context.issue_id,
|
||||
'Your workspace is not configured with Linear integration.',
|
||||
None,
|
||||
)
|
||||
return
|
||||
|
||||
# Prevent any recursive triggers from the service account
|
||||
if job_context.user_email == workspace.svc_acc_email:
|
||||
return
|
||||
|
||||
if workspace.status != 'active':
|
||||
logger.warning(f'[Linear] Workspace {workspace.id} is not active')
|
||||
await self._send_error_comment(
|
||||
job_context.issue_id,
|
||||
'Linear integration is not active for your workspace.',
|
||||
workspace,
|
||||
)
|
||||
return
|
||||
|
||||
# Authenticate user
|
||||
linear_user, saas_user_auth = await self.authenticate_user(
|
||||
job_context.platform_user_id, workspace.id
|
||||
)
|
||||
if not linear_user or not saas_user_auth:
|
||||
logger.warning(
|
||||
f'[Linear] User authentication failed for {job_context.user_email}'
|
||||
)
|
||||
await self._send_error_comment(
|
||||
job_context.issue_id,
|
||||
f'User {job_context.user_email} is not authenticated or active in the Linear integration.',
|
||||
workspace,
|
||||
)
|
||||
return
|
||||
|
||||
# Get issue details
|
||||
try:
|
||||
api_key = self.token_manager.decrypt_text(workspace.svc_acc_api_key)
|
||||
issue_title, issue_description = await self.get_issue_details(
|
||||
job_context.issue_id, api_key
|
||||
)
|
||||
job_context.issue_title = issue_title
|
||||
job_context.issue_description = issue_description
|
||||
except Exception as e:
|
||||
logger.error(f'[Linear] Failed to get issue context: {str(e)}')
|
||||
await self._send_error_comment(
|
||||
job_context.issue_id,
|
||||
'Failed to retrieve issue details. Please check the issue ID and try again.',
|
||||
workspace,
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
# Create Linear view
|
||||
linear_view = await LinearFactory.create_linear_view_from_payload(
|
||||
job_context,
|
||||
saas_user_auth,
|
||||
linear_user,
|
||||
workspace,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f'[Linear] Failed to create linear view: {str(e)}', exc_info=True
|
||||
)
|
||||
await self._send_error_comment(
|
||||
job_context.issue_id,
|
||||
'Failed to initialize conversation. Please try again.',
|
||||
workspace,
|
||||
)
|
||||
return
|
||||
|
||||
if not await self.is_job_requested(message, linear_view):
|
||||
return
|
||||
|
||||
await self.start_job(linear_view)
|
||||
|
||||
async def is_job_requested(
|
||||
self, message: Message, linear_view: LinearViewInterface
|
||||
) -> bool:
|
||||
"""
|
||||
Check if a job is requested and handle repository selection.
|
||||
"""
|
||||
|
||||
if isinstance(linear_view, LinearExistingConversationView):
|
||||
return True
|
||||
|
||||
try:
|
||||
# Get user repositories
|
||||
user_repos: list[Repository] = await self._get_repositories(
|
||||
linear_view.saas_user_auth
|
||||
)
|
||||
|
||||
target_str = f'{linear_view.job_context.issue_description}\n{linear_view.job_context.user_msg}'
|
||||
|
||||
# Try to infer repository from issue description
|
||||
match, repos = filter_potential_repos_by_user_msg(target_str, user_repos)
|
||||
|
||||
if match:
|
||||
# Found exact repository match
|
||||
linear_view.selected_repo = repos[0].full_name
|
||||
logger.info(f'[Linear] Inferred repository: {repos[0].full_name}')
|
||||
return True
|
||||
else:
|
||||
# No clear match - send repository selection comment
|
||||
await self._send_repo_selection_comment(linear_view)
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f'[Linear] Error in is_job_requested: {str(e)}')
|
||||
return False
|
||||
|
||||
async def start_job(self, linear_view: LinearViewInterface):
|
||||
"""Start a Linear job/conversation."""
|
||||
# Import here to prevent circular import
|
||||
from server.conversation_callback_processor.linear_callback_processor import (
|
||||
LinearCallbackProcessor,
|
||||
)
|
||||
|
||||
try:
|
||||
user_info: LinearUser = linear_view.linear_user
|
||||
logger.info(
|
||||
f'[Linear] Starting job for user {user_info.keycloak_user_id} '
|
||||
f'issue {linear_view.job_context.issue_key}',
|
||||
)
|
||||
|
||||
# Create conversation
|
||||
conversation_id = await linear_view.create_or_update_conversation(
|
||||
self.jinja_env
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f'[Linear] Created/Updated conversation {conversation_id} for issue {linear_view.job_context.issue_key}'
|
||||
)
|
||||
|
||||
if isinstance(linear_view, LinearNewConversationView):
|
||||
# Register callback processor for updates
|
||||
processor = LinearCallbackProcessor(
|
||||
issue_id=linear_view.job_context.issue_id,
|
||||
issue_key=linear_view.job_context.issue_key,
|
||||
workspace_name=linear_view.linear_workspace.name,
|
||||
)
|
||||
|
||||
# Register the callback processor
|
||||
register_callback_processor(conversation_id, processor)
|
||||
|
||||
logger.info(
|
||||
f'[Linear] Created callback processor for conversation {conversation_id}'
|
||||
)
|
||||
|
||||
# Send initial response
|
||||
msg_info = linear_view.get_response_msg()
|
||||
|
||||
except MissingSettingsError as e:
|
||||
logger.warning(f'[Linear] Missing settings error: {str(e)}')
|
||||
msg_info = f'Please re-login into [OpenHands Cloud]({HOST_URL}) before starting a job.'
|
||||
|
||||
except LLMAuthenticationError as e:
|
||||
logger.warning(f'[Linear] LLM authentication error: {str(e)}')
|
||||
msg_info = f'Please set a valid LLM API key in [OpenHands Cloud]({HOST_URL}) before starting a job.'
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f'[Linear] Unexpected error starting job: {str(e)}', exc_info=True
|
||||
)
|
||||
msg_info = 'Sorry, there was an unexpected error starting the job. Please try again.'
|
||||
|
||||
# Send response comment
|
||||
try:
|
||||
api_key = self.token_manager.decrypt_text(
|
||||
linear_view.linear_workspace.svc_acc_api_key
|
||||
)
|
||||
await self.send_message(
|
||||
self.create_outgoing_message(msg=msg_info),
|
||||
linear_view.job_context.issue_id,
|
||||
api_key,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f'[Linear] Failed to send response message: {str(e)}')
|
||||
|
||||
async def _query_api(self, query: str, variables: Dict, api_key: str) -> Dict:
|
||||
"""Query Linear GraphQL API."""
|
||||
headers = {'Authorization': api_key}
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
self.api_url,
|
||||
headers=headers,
|
||||
json={'query': query, 'variables': variables},
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
async def get_issue_details(self, issue_id: str, api_key: str) -> Tuple[str, str]:
|
||||
"""Get issue details from Linear API."""
|
||||
query = """
|
||||
query Issue($issueId: String!) {
|
||||
issue(id: $issueId) {
|
||||
id
|
||||
identifier
|
||||
title
|
||||
description
|
||||
syncedWith {
|
||||
metadata {
|
||||
... on ExternalEntityInfoGithubMetadata {
|
||||
owner
|
||||
repo
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
issue_payload = await self._query_api(query, {'issueId': issue_id}, api_key)
|
||||
|
||||
if not issue_payload:
|
||||
raise ValueError(f'Issue with ID {issue_id} not found.')
|
||||
|
||||
issue_data = issue_payload.get('data', {}).get('issue', {})
|
||||
title = issue_data.get('title', '')
|
||||
description = issue_data.get('description', '')
|
||||
synced_with = issue_data.get('syncedWith', [])
|
||||
owner = ''
|
||||
repo = ''
|
||||
if synced_with:
|
||||
owner = synced_with[0].get('metadata', {}).get('owner', '')
|
||||
repo = synced_with[0].get('metadata', {}).get('repo', '')
|
||||
|
||||
if not title:
|
||||
raise ValueError(f'Issue with ID {issue_id} does not have a title.')
|
||||
|
||||
if not description:
|
||||
raise ValueError(f'Issue with ID {issue_id} does not have a description.')
|
||||
|
||||
if owner and repo:
|
||||
description += f'\n\nGit Repo: {owner}/{repo}'
|
||||
|
||||
return title, description
|
||||
|
||||
async def send_message(self, message: Message, issue_id: str, api_key: str):
|
||||
"""Send message/comment to Linear issue."""
|
||||
query = """
|
||||
mutation CommentCreate($input: CommentCreateInput!) {
|
||||
commentCreate(input: $input) {
|
||||
success
|
||||
comment {
|
||||
id
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
variables = {'input': {'issueId': issue_id, 'body': message.message}}
|
||||
return await self._query_api(query, variables, api_key)
|
||||
|
||||
async def _send_error_comment(
|
||||
self, issue_id: str, error_msg: str, workspace: LinearWorkspace | None
|
||||
):
|
||||
"""Send error comment to Linear issue."""
|
||||
if not workspace:
|
||||
logger.error('[Linear] Cannot send error comment - no workspace available')
|
||||
return
|
||||
|
||||
try:
|
||||
api_key = self.token_manager.decrypt_text(workspace.svc_acc_api_key)
|
||||
await self.send_message(
|
||||
self.create_outgoing_message(msg=error_msg), issue_id, api_key
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f'[Linear] Failed to send error comment: {str(e)}')
|
||||
|
||||
async def _send_repo_selection_comment(self, linear_view: LinearViewInterface):
|
||||
"""Send a comment with repository options for the user to choose."""
|
||||
try:
|
||||
comment_msg = (
|
||||
'I need to know which repository to work with. '
|
||||
'Please add it to your issue description or send a followup comment.'
|
||||
)
|
||||
|
||||
api_key = self.token_manager.decrypt_text(
|
||||
linear_view.linear_workspace.svc_acc_api_key
|
||||
)
|
||||
|
||||
await self.send_message(
|
||||
self.create_outgoing_message(msg=comment_msg),
|
||||
linear_view.job_context.issue_id,
|
||||
api_key,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f'[Linear] Sent repository selection comment for issue {linear_view.job_context.issue_key}'
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f'[Linear] Failed to send repository selection comment: {str(e)}'
|
||||
)
|
||||
40
enterprise/integrations/linear/linear_types.py
Normal file
40
enterprise/integrations/linear/linear_types.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from integrations.models import JobContext
|
||||
from jinja2 import Environment
|
||||
from storage.linear_user import LinearUser
|
||||
from storage.linear_workspace import LinearWorkspace
|
||||
|
||||
from openhands.server.user_auth.user_auth import UserAuth
|
||||
|
||||
|
||||
class LinearViewInterface(ABC):
|
||||
"""Interface for Linear views that handle different types of Linear interactions."""
|
||||
|
||||
job_context: JobContext
|
||||
saas_user_auth: UserAuth
|
||||
linear_user: LinearUser
|
||||
linear_workspace: LinearWorkspace
|
||||
selected_repo: str | None
|
||||
conversation_id: str
|
||||
|
||||
@abstractmethod
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
"""Get initial instructions for the conversation."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def create_or_update_conversation(self, jinja_env: Environment) -> str:
|
||||
"""Create or update a conversation and return the conversation ID."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_response_msg(self) -> str:
|
||||
"""Get the response message to send back to Linear."""
|
||||
pass
|
||||
|
||||
|
||||
class StartingConvoException(Exception):
|
||||
"""Exception raised when starting a conversation fails."""
|
||||
|
||||
pass
|
||||
224
enterprise/integrations/linear/linear_view.py
Normal file
224
enterprise/integrations/linear/linear_view.py
Normal file
@@ -0,0 +1,224 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from integrations.linear.linear_types import LinearViewInterface, StartingConvoException
|
||||
from integrations.models import JobContext
|
||||
from integrations.utils import CONVERSATION_URL, get_final_agent_observation
|
||||
from jinja2 import Environment
|
||||
from storage.linear_conversation import LinearConversation
|
||||
from storage.linear_integration_store import LinearIntegrationStore
|
||||
from storage.linear_user import LinearUser
|
||||
from storage.linear_workspace import LinearWorkspace
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.schema.agent import AgentState
|
||||
from openhands.events.action import MessageAction
|
||||
from openhands.events.serialization.event import event_to_dict
|
||||
from openhands.server.services.conversation_service import (
|
||||
create_new_conversation,
|
||||
setup_init_conversation_settings,
|
||||
)
|
||||
from openhands.server.shared import ConversationStoreImpl, config, conversation_manager
|
||||
from openhands.server.user_auth.user_auth import UserAuth
|
||||
from openhands.storage.data_models.conversation_metadata import ConversationTrigger
|
||||
|
||||
integration_store = LinearIntegrationStore.get_instance()
|
||||
|
||||
|
||||
@dataclass
|
||||
class LinearNewConversationView(LinearViewInterface):
|
||||
job_context: JobContext
|
||||
saas_user_auth: UserAuth
|
||||
linear_user: LinearUser
|
||||
linear_workspace: LinearWorkspace
|
||||
selected_repo: str | None
|
||||
conversation_id: str
|
||||
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
"""Instructions passed when conversation is first initialized"""
|
||||
|
||||
instructions_template = jinja_env.get_template('linear_instructions.j2')
|
||||
instructions = instructions_template.render()
|
||||
|
||||
user_msg_template = jinja_env.get_template('linear_new_conversation.j2')
|
||||
|
||||
user_msg = user_msg_template.render(
|
||||
issue_key=self.job_context.issue_key,
|
||||
issue_title=self.job_context.issue_title,
|
||||
issue_description=self.job_context.issue_description,
|
||||
user_message=self.job_context.user_msg or '',
|
||||
)
|
||||
|
||||
return instructions, user_msg
|
||||
|
||||
async def create_or_update_conversation(self, jinja_env: Environment) -> str:
|
||||
"""Create a new Linear conversation"""
|
||||
|
||||
if not self.selected_repo:
|
||||
raise StartingConvoException('No repository selected for this conversation')
|
||||
|
||||
provider_tokens = await self.saas_user_auth.get_provider_tokens()
|
||||
user_secrets = await self.saas_user_auth.get_user_secrets()
|
||||
instructions, user_msg = self._get_instructions(jinja_env)
|
||||
|
||||
try:
|
||||
agent_loop_info = await create_new_conversation(
|
||||
user_id=self.linear_user.keycloak_user_id,
|
||||
git_provider_tokens=provider_tokens,
|
||||
selected_repository=self.selected_repo,
|
||||
selected_branch=None,
|
||||
initial_user_msg=user_msg,
|
||||
conversation_instructions=instructions,
|
||||
image_urls=None,
|
||||
replay_json=None,
|
||||
conversation_trigger=ConversationTrigger.LINEAR,
|
||||
custom_secrets=user_secrets.custom_secrets if user_secrets else None,
|
||||
)
|
||||
|
||||
self.conversation_id = agent_loop_info.conversation_id
|
||||
|
||||
logger.info(f'[Linear] Created conversation {self.conversation_id}')
|
||||
|
||||
# Store Linear conversation mapping
|
||||
linear_conversation = LinearConversation(
|
||||
conversation_id=self.conversation_id,
|
||||
issue_id=self.job_context.issue_id,
|
||||
issue_key=self.job_context.issue_key,
|
||||
linear_user_id=self.linear_user.id,
|
||||
)
|
||||
|
||||
await integration_store.create_conversation(linear_conversation)
|
||||
|
||||
return self.conversation_id
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f'[Linear] Failed to create conversation: {str(e)}', exc_info=True
|
||||
)
|
||||
raise StartingConvoException(f'Failed to create conversation: {str(e)}')
|
||||
|
||||
def get_response_msg(self) -> str:
|
||||
"""Get the response message to send back to Linear"""
|
||||
conversation_link = CONVERSATION_URL.format(self.conversation_id)
|
||||
return f"I'm on it! {self.job_context.display_name} can [track my progress here]({conversation_link})."
|
||||
|
||||
|
||||
@dataclass
|
||||
class LinearExistingConversationView(LinearViewInterface):
|
||||
job_context: JobContext
|
||||
saas_user_auth: UserAuth
|
||||
linear_user: LinearUser
|
||||
linear_workspace: LinearWorkspace
|
||||
selected_repo: str | None
|
||||
conversation_id: str
|
||||
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
"""Instructions passed when conversation is first initialized"""
|
||||
|
||||
user_msg_template = jinja_env.get_template('linear_existing_conversation.j2')
|
||||
user_msg = user_msg_template.render(
|
||||
issue_key=self.job_context.issue_key,
|
||||
user_message=self.job_context.user_msg or '',
|
||||
issue_title=self.job_context.issue_title,
|
||||
issue_description=self.job_context.issue_description,
|
||||
)
|
||||
|
||||
return '', user_msg
|
||||
|
||||
async def create_or_update_conversation(self, jinja_env: Environment) -> str:
|
||||
"""Update an existing Linear conversation"""
|
||||
|
||||
user_id = self.linear_user.keycloak_user_id
|
||||
|
||||
try:
|
||||
conversation_store = await ConversationStoreImpl.get_instance(
|
||||
config, user_id
|
||||
)
|
||||
metadata = await conversation_store.get_metadata(self.conversation_id)
|
||||
if not metadata:
|
||||
raise StartingConvoException('Conversation no longer exists.')
|
||||
|
||||
provider_tokens = await self.saas_user_auth.get_provider_tokens()
|
||||
if provider_tokens is None:
|
||||
raise ValueError('Could not load provider tokens')
|
||||
providers_set = list(provider_tokens.keys())
|
||||
|
||||
conversation_init_data = await setup_init_conversation_settings(
|
||||
user_id, self.conversation_id, providers_set
|
||||
)
|
||||
|
||||
# Either join ongoing conversation, or restart the conversation
|
||||
agent_loop_info = await conversation_manager.maybe_start_agent_loop(
|
||||
self.conversation_id, conversation_init_data, user_id
|
||||
)
|
||||
|
||||
final_agent_observation = get_final_agent_observation(
|
||||
agent_loop_info.event_store
|
||||
)
|
||||
agent_state = (
|
||||
None
|
||||
if len(final_agent_observation) == 0
|
||||
else final_agent_observation[0].agent_state
|
||||
)
|
||||
|
||||
if not agent_state or agent_state == AgentState.LOADING:
|
||||
raise StartingConvoException('Conversation is still starting')
|
||||
|
||||
_, user_msg = self._get_instructions(jinja_env)
|
||||
user_message_event = MessageAction(content=user_msg)
|
||||
await conversation_manager.send_event_to_conversation(
|
||||
self.conversation_id, event_to_dict(user_message_event)
|
||||
)
|
||||
|
||||
return self.conversation_id
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f'[Linear] Failed to create conversation: {str(e)}', exc_info=True
|
||||
)
|
||||
raise StartingConvoException(f'Failed to create conversation: {str(e)}')
|
||||
|
||||
def get_response_msg(self) -> str:
|
||||
"""Get the response message to send back to Linear"""
|
||||
conversation_link = CONVERSATION_URL.format(self.conversation_id)
|
||||
return f"I'm on it! {self.job_context.display_name} can [continue tracking my progress here]({conversation_link})."
|
||||
|
||||
|
||||
class LinearFactory:
|
||||
"""Factory for creating Linear views based on message content"""
|
||||
|
||||
@staticmethod
|
||||
async def create_linear_view_from_payload(
|
||||
job_context: JobContext,
|
||||
saas_user_auth: UserAuth,
|
||||
linear_user: LinearUser,
|
||||
linear_workspace: LinearWorkspace,
|
||||
) -> LinearViewInterface:
|
||||
"""Create appropriate Linear view based on the message and user state"""
|
||||
|
||||
if not linear_user or not saas_user_auth or not linear_workspace:
|
||||
raise StartingConvoException(
|
||||
'User not authenticated with Linear integration'
|
||||
)
|
||||
|
||||
conversation = await integration_store.get_user_conversations_by_issue_id(
|
||||
job_context.issue_id, linear_user.id
|
||||
)
|
||||
if conversation:
|
||||
logger.info(
|
||||
f'[Linear] Found existing conversation for issue {job_context.issue_id}'
|
||||
)
|
||||
return LinearExistingConversationView(
|
||||
job_context=job_context,
|
||||
saas_user_auth=saas_user_auth,
|
||||
linear_user=linear_user,
|
||||
linear_workspace=linear_workspace,
|
||||
selected_repo=None,
|
||||
conversation_id=conversation.conversation_id,
|
||||
)
|
||||
|
||||
return LinearNewConversationView(
|
||||
job_context=job_context,
|
||||
saas_user_auth=saas_user_auth,
|
||||
linear_user=linear_user,
|
||||
linear_workspace=linear_workspace,
|
||||
selected_repo=None, # Will be set later after repo inference
|
||||
conversation_id='', # Will be set when conversation is created
|
||||
)
|
||||
30
enterprise/integrations/manager.py
Normal file
30
enterprise/integrations/manager.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from integrations.models import Message, SourceType
|
||||
|
||||
|
||||
class Manager(ABC):
|
||||
manager_type: SourceType
|
||||
|
||||
@abstractmethod
|
||||
async def receive_message(self, message: Message):
|
||||
"Receive message from integration"
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def send_message(self, message: Message):
|
||||
"Send message to integration from Openhands server"
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def is_job_requested(self, message: Message) -> bool:
|
||||
"Confirm that a job is being requested"
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def start_job(self):
|
||||
"Kick off a job with openhands agent"
|
||||
raise NotImplementedError
|
||||
|
||||
def create_outgoing_message(self, msg: str | dict, ephemeral: bool = False):
|
||||
return Message(source=SourceType.OPENHANDS, message=msg, ephemeral=ephemeral)
|
||||
52
enterprise/integrations/models.py
Normal file
52
enterprise/integrations/models.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from openhands.core.schema import AgentState
|
||||
|
||||
|
||||
class SourceType(str, Enum):
|
||||
GITHUB = 'github'
|
||||
GITLAB = 'gitlab'
|
||||
OPENHANDS = 'openhands'
|
||||
SLACK = 'slack'
|
||||
JIRA = 'jira'
|
||||
JIRA_DC = 'jira_dc'
|
||||
LINEAR = 'linear'
|
||||
|
||||
|
||||
class Message(BaseModel):
|
||||
source: SourceType
|
||||
message: str | dict
|
||||
ephemeral: bool = False
|
||||
|
||||
|
||||
class JobContext(BaseModel):
|
||||
issue_id: str
|
||||
issue_key: str
|
||||
user_msg: str
|
||||
user_email: str
|
||||
display_name: str
|
||||
platform_user_id: str = ''
|
||||
workspace_name: str
|
||||
base_api_url: str = ''
|
||||
issue_title: str = ''
|
||||
issue_description: str = ''
|
||||
|
||||
|
||||
class JobResult:
|
||||
result: str
|
||||
explanation: str
|
||||
|
||||
|
||||
class GithubResolverJob:
|
||||
type: SourceType
|
||||
status: AgentState
|
||||
result: JobResult
|
||||
owner: str
|
||||
repo: str
|
||||
installation_token: str
|
||||
issue_number: int
|
||||
runtime_id: int
|
||||
created_at: int
|
||||
completed_at: int
|
||||
363
enterprise/integrations/slack/slack_manager.py
Normal file
363
enterprise/integrations/slack/slack_manager.py
Normal file
@@ -0,0 +1,363 @@
|
||||
import re
|
||||
|
||||
import jwt
|
||||
from integrations.manager import Manager
|
||||
from integrations.models import Message, SourceType
|
||||
from integrations.slack.slack_types import SlackViewInterface, StartingConvoException
|
||||
from integrations.slack.slack_view import (
|
||||
SlackFactory,
|
||||
SlackNewConversationFromRepoFormView,
|
||||
SlackNewConversationView,
|
||||
SlackUnkownUserView,
|
||||
SlackUpdateExistingConversationView,
|
||||
)
|
||||
from integrations.utils import (
|
||||
HOST_URL,
|
||||
OPENHANDS_RESOLVER_TEMPLATES_DIR,
|
||||
)
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
from pydantic import SecretStr
|
||||
from server.auth.saas_user_auth import SaasUserAuth
|
||||
from server.constants import SLACK_CLIENT_ID
|
||||
from server.utils.conversation_callback_utils import register_callback_processor
|
||||
from slack_sdk.oauth import AuthorizeUrlGenerator
|
||||
from slack_sdk.web.async_client import AsyncWebClient
|
||||
from storage.database import session_maker
|
||||
from storage.slack_user import SlackUser
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.integrations.provider import ProviderHandler
|
||||
from openhands.integrations.service_types import Repository
|
||||
from openhands.server.shared import config, server_config
|
||||
from openhands.server.types import LLMAuthenticationError, MissingSettingsError
|
||||
from openhands.server.user_auth.user_auth import UserAuth
|
||||
|
||||
authorize_url_generator = AuthorizeUrlGenerator(
|
||||
client_id=SLACK_CLIENT_ID,
|
||||
scopes=['app_mentions:read', 'chat:write'],
|
||||
user_scopes=['search:read'],
|
||||
)
|
||||
|
||||
|
||||
class SlackManager(Manager):
|
||||
def __init__(self, token_manager):
|
||||
self.token_manager = token_manager
|
||||
self.login_link = (
|
||||
'User has not yet authenticated: [Click here to Login to OpenHands]({}).'
|
||||
)
|
||||
|
||||
self.jinja_env = Environment(
|
||||
loader=FileSystemLoader(OPENHANDS_RESOLVER_TEMPLATES_DIR + 'slack')
|
||||
)
|
||||
|
||||
def _confirm_incoming_source_type(self, message: Message):
|
||||
if message.source != SourceType.SLACK:
|
||||
raise ValueError(f'Unexpected message source {message.source}')
|
||||
|
||||
async def _get_user_auth(self, keycloak_user_id: str) -> UserAuth:
|
||||
offline_token = await self.token_manager.load_offline_token(keycloak_user_id)
|
||||
if offline_token is None:
|
||||
logger.info('no_offline_token_found')
|
||||
|
||||
user_auth = SaasUserAuth(
|
||||
user_id=keycloak_user_id,
|
||||
refresh_token=SecretStr(offline_token),
|
||||
)
|
||||
return user_auth
|
||||
|
||||
async def authenticate_user(
|
||||
self, slack_user_id: str
|
||||
) -> tuple[SlackUser | None, UserAuth | None]:
|
||||
# We get the user and correlate them back to a user in OpenHands - if we can
|
||||
slack_user = None
|
||||
with session_maker() as session:
|
||||
slack_user = (
|
||||
session.query(SlackUser)
|
||||
.filter(SlackUser.slack_user_id == slack_user_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
# slack_view.slack_to_openhands_user = slack_user # attach user auth info to view
|
||||
|
||||
saas_user_auth = None
|
||||
if slack_user:
|
||||
saas_user_auth = await self._get_user_auth(slack_user.keycloak_user_id)
|
||||
# slack_view.saas_user_auth = await self._get_user_auth(slack_view.slack_to_openhands_user.keycloak_user_id)
|
||||
|
||||
return slack_user, saas_user_auth
|
||||
|
||||
def _infer_repo_from_message(self, user_msg: str) -> str | None:
|
||||
# Regular expression to match patterns like "All-Hands-AI/OpenHands" or "deploy repo"
|
||||
pattern = r'([a-zA-Z0-9_-]+/[a-zA-Z0-9_-]+)|([a-zA-Z0-9_-]+)(?=\s+repo)'
|
||||
match = re.search(pattern, user_msg)
|
||||
|
||||
if match:
|
||||
repo = match.group(1) if match.group(1) else match.group(2)
|
||||
return repo
|
||||
|
||||
return None
|
||||
|
||||
async def _get_repositories(self, user_auth: UserAuth) -> list[Repository]:
|
||||
provider_tokens = await user_auth.get_provider_tokens()
|
||||
if provider_tokens is None:
|
||||
return []
|
||||
access_token = await user_auth.get_access_token()
|
||||
user_id = await user_auth.get_user_id()
|
||||
client = ProviderHandler(
|
||||
provider_tokens=provider_tokens,
|
||||
external_auth_token=access_token,
|
||||
external_auth_id=user_id,
|
||||
)
|
||||
repos: list[Repository] = await client.get_repositories(
|
||||
'pushed', server_config.app_mode, None, None, None, None
|
||||
)
|
||||
return repos
|
||||
|
||||
def _generate_repo_selection_form(
|
||||
self, repo_list: list[Repository], message_ts: str, thread_ts: str | None
|
||||
):
|
||||
options = [
|
||||
{
|
||||
'text': {'type': 'plain_text', 'text': 'No Repository'},
|
||||
'value': '-',
|
||||
}
|
||||
]
|
||||
options.extend(
|
||||
{
|
||||
'text': {
|
||||
'type': 'plain_text',
|
||||
'text': repo.full_name,
|
||||
},
|
||||
'value': repo.full_name,
|
||||
}
|
||||
for repo in repo_list
|
||||
)
|
||||
|
||||
return [
|
||||
{
|
||||
'type': 'header',
|
||||
'text': {
|
||||
'type': 'plain_text',
|
||||
'text': 'Choose a repository',
|
||||
'emoji': True,
|
||||
},
|
||||
},
|
||||
{
|
||||
'type': 'actions',
|
||||
'elements': [
|
||||
{
|
||||
'type': 'static_select',
|
||||
'action_id': f'repository_select:{message_ts}:{thread_ts}',
|
||||
'options': options,
|
||||
}
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
def filter_potential_repos_by_user_msg(
|
||||
self, user_msg: str, user_repos: list[Repository]
|
||||
) -> tuple[bool, list[Repository]]:
|
||||
inferred_repo = self._infer_repo_from_message(user_msg)
|
||||
if not inferred_repo:
|
||||
return False, user_repos[0:99]
|
||||
|
||||
final_repos = []
|
||||
for repo in user_repos:
|
||||
if inferred_repo.lower() in repo.full_name.lower():
|
||||
final_repos.append(repo)
|
||||
|
||||
# no repos matched, return original list
|
||||
if len(final_repos) == 0:
|
||||
return False, user_repos[0:99]
|
||||
|
||||
# Found exact match
|
||||
elif len(final_repos) == 1:
|
||||
return True, final_repos
|
||||
|
||||
# Found partial matches
|
||||
return False, final_repos[0:99]
|
||||
|
||||
async def receive_message(self, message: Message):
|
||||
self._confirm_incoming_source_type(message)
|
||||
|
||||
slack_user, saas_user_auth = await self.authenticate_user(
|
||||
slack_user_id=message.message['slack_user_id']
|
||||
)
|
||||
|
||||
try:
|
||||
slack_view = SlackFactory.create_slack_view_from_payload(
|
||||
message, slack_user, saas_user_auth
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f'[Slack]: Failed to create slack view: {e}',
|
||||
exc_info=True,
|
||||
stack_info=True,
|
||||
)
|
||||
return
|
||||
|
||||
if isinstance(slack_view, SlackUnkownUserView):
|
||||
jwt_secret = config.jwt_secret
|
||||
if not jwt_secret:
|
||||
raise ValueError('Must configure jwt_secret')
|
||||
state = jwt.encode(
|
||||
message.message, jwt_secret.get_secret_value(), algorithm='HS256'
|
||||
)
|
||||
link = authorize_url_generator.generate(state)
|
||||
msg = self.login_link.format(link)
|
||||
|
||||
logger.info('slack_not_yet_authenticated')
|
||||
await self.send_message(
|
||||
self.create_outgoing_message(msg, ephemeral=True), slack_view
|
||||
)
|
||||
return
|
||||
|
||||
if not await self.is_job_requested(message, slack_view):
|
||||
return
|
||||
|
||||
await self.start_job(slack_view)
|
||||
|
||||
async def send_message(self, message: Message, slack_view: SlackViewInterface):
|
||||
client = AsyncWebClient(token=slack_view.bot_access_token)
|
||||
if message.ephemeral and isinstance(message.message, str):
|
||||
await client.chat_postEphemeral(
|
||||
channel=slack_view.channel_id,
|
||||
markdown_text=message.message,
|
||||
user=slack_view.slack_user_id,
|
||||
thread_ts=slack_view.thread_ts,
|
||||
)
|
||||
elif message.ephemeral and isinstance(message.message, dict):
|
||||
await client.chat_postEphemeral(
|
||||
channel=slack_view.channel_id,
|
||||
user=slack_view.slack_user_id,
|
||||
thread_ts=slack_view.thread_ts,
|
||||
text=message.message['text'],
|
||||
blocks=message.message['blocks'],
|
||||
)
|
||||
else:
|
||||
await client.chat_postMessage(
|
||||
channel=slack_view.channel_id,
|
||||
markdown_text=message.message,
|
||||
thread_ts=slack_view.message_ts,
|
||||
)
|
||||
|
||||
async def is_job_requested(
|
||||
self, message: Message, slack_view: SlackViewInterface
|
||||
) -> bool:
|
||||
"""
|
||||
A job is always request we only receive webhooks for events associated with the slack bot
|
||||
This method really just checks
|
||||
1. Is the user is authenticated
|
||||
2. Do we have the necessary information to start a job (either by inferring the selected repo, otherwise asking the user)
|
||||
"""
|
||||
|
||||
# Infer repo from user message is not needed; user selected repo from the form or is updating existing convo
|
||||
if isinstance(slack_view, SlackUpdateExistingConversationView):
|
||||
return True
|
||||
elif isinstance(slack_view, SlackNewConversationFromRepoFormView):
|
||||
return True
|
||||
elif isinstance(slack_view, SlackNewConversationView):
|
||||
user = slack_view.slack_to_openhands_user
|
||||
user_repos: list[Repository] = await self._get_repositories(
|
||||
slack_view.saas_user_auth
|
||||
)
|
||||
match, repos = self.filter_potential_repos_by_user_msg(
|
||||
slack_view.user_msg, user_repos
|
||||
)
|
||||
|
||||
# User mentioned a matching repo is their message, start job without repo selection form
|
||||
if match:
|
||||
slack_view.selected_repo = repos[0].full_name
|
||||
return True
|
||||
|
||||
logger.info(
|
||||
'render_repository_selector',
|
||||
extra={
|
||||
'slack_user_id': user,
|
||||
'keycloak_user_id': user.keycloak_user_id,
|
||||
'message_ts': slack_view.message_ts,
|
||||
'thread_ts': slack_view.thread_ts,
|
||||
},
|
||||
)
|
||||
|
||||
repo_selection_msg = {
|
||||
'text': 'Choose a Repository:',
|
||||
'blocks': self._generate_repo_selection_form(
|
||||
repos, slack_view.message_ts, slack_view.thread_ts
|
||||
),
|
||||
}
|
||||
await self.send_message(
|
||||
self.create_outgoing_message(repo_selection_msg, ephemeral=True),
|
||||
slack_view,
|
||||
)
|
||||
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
async def start_job(self, slack_view: SlackViewInterface):
|
||||
# Importing here prevents circular import
|
||||
from server.conversation_callback_processor.slack_callback_processor import (
|
||||
SlackCallbackProcessor,
|
||||
)
|
||||
|
||||
try:
|
||||
msg_info = None
|
||||
user_info: SlackUser = slack_view.slack_to_openhands_user
|
||||
try:
|
||||
logger.info(
|
||||
f'[Slack] Starting job for user {user_info.slack_display_name} (id={user_info.slack_user_id})',
|
||||
extra={'keyloak_user_id': user_info.keycloak_user_id},
|
||||
)
|
||||
conversation_id = await slack_view.create_or_update_conversation(
|
||||
self.jinja_env
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f'[Slack] Created conversation {conversation_id} for user {user_info.slack_display_name}'
|
||||
)
|
||||
|
||||
if not isinstance(slack_view, SlackUpdateExistingConversationView):
|
||||
# We don't re-subscribe for follow up messages from slack.
|
||||
# Summaries are generated for every messages anyways, we only need to do
|
||||
# this subscription once for the event which kicked off the job.
|
||||
processor = SlackCallbackProcessor(
|
||||
slack_user_id=slack_view.slack_user_id,
|
||||
channel_id=slack_view.channel_id,
|
||||
message_ts=slack_view.message_ts,
|
||||
thread_ts=slack_view.thread_ts,
|
||||
team_id=slack_view.team_id,
|
||||
)
|
||||
|
||||
# Register the callback processor
|
||||
register_callback_processor(conversation_id, processor)
|
||||
|
||||
logger.info(
|
||||
f'[Slack] Created callback processor for conversation {conversation_id}'
|
||||
)
|
||||
|
||||
msg_info = slack_view.get_response_msg()
|
||||
|
||||
except MissingSettingsError as e:
|
||||
logger.warning(
|
||||
f'[Slack] Missing settings error for user {user_info.slack_display_name}: {str(e)}'
|
||||
)
|
||||
|
||||
msg_info = f'{user_info.slack_display_name} please re-login into [OpenHands Cloud]({HOST_URL}) before starting a job.'
|
||||
|
||||
except LLMAuthenticationError as e:
|
||||
logger.warning(
|
||||
f'[Slack] LLM authentication error for user {user_info.slack_display_name}: {str(e)}'
|
||||
)
|
||||
|
||||
msg_info = f'@{user_info.slack_display_name} please set a valid LLM API key in [OpenHands Cloud]({HOST_URL}) before starting a job.'
|
||||
|
||||
except StartingConvoException as e:
|
||||
msg_info = str(e)
|
||||
|
||||
await self.send_message(self.create_outgoing_message(msg_info), slack_view)
|
||||
|
||||
except Exception:
|
||||
logger.exception('[Slack]: Error starting job')
|
||||
msg = 'Uh oh! There was an unexpected error starting the job :('
|
||||
await self.send_message(self.create_outgoing_message(msg), slack_view)
|
||||
48
enterprise/integrations/slack/slack_types.py
Normal file
48
enterprise/integrations/slack/slack_types.py
Normal file
@@ -0,0 +1,48 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from integrations.types import SummaryExtractionTracker
|
||||
from jinja2 import Environment
|
||||
from storage.slack_user import SlackUser
|
||||
|
||||
from openhands.server.user_auth.user_auth import UserAuth
|
||||
|
||||
|
||||
class SlackViewInterface(SummaryExtractionTracker, ABC):
|
||||
bot_access_token: str
|
||||
user_msg: str | None
|
||||
slack_user_id: str
|
||||
slack_to_openhands_user: SlackUser | None
|
||||
saas_user_auth: UserAuth | None
|
||||
channel_id: str
|
||||
message_ts: str
|
||||
thread_ts: str | None
|
||||
selected_repo: str | None
|
||||
should_extract: bool
|
||||
send_summary_instruction: bool
|
||||
conversation_id: str
|
||||
team_id: str
|
||||
|
||||
@abstractmethod
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
"Instructions passed when conversation is first initialized"
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def create_or_update_conversation(self, jinja_env: Environment):
|
||||
"Create a new conversation"
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_callback_id(self) -> str:
|
||||
"Unique callback id for subscribription made to EventStream for fetching agent summary"
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_response_msg(self) -> str:
|
||||
pass
|
||||
|
||||
|
||||
class StartingConvoException(Exception):
|
||||
"""
|
||||
Raised when trying to send message to a conversation that's is still starting up
|
||||
"""
|
||||
435
enterprise/integrations/slack/slack_view.py
Normal file
435
enterprise/integrations/slack/slack_view.py
Normal file
@@ -0,0 +1,435 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from integrations.models import Message
|
||||
from integrations.slack.slack_types import SlackViewInterface, StartingConvoException
|
||||
from integrations.utils import CONVERSATION_URL, get_final_agent_observation
|
||||
from jinja2 import Environment
|
||||
from slack_sdk import WebClient
|
||||
from storage.slack_conversation import SlackConversation
|
||||
from storage.slack_conversation_store import SlackConversationStore
|
||||
from storage.slack_team_store import SlackTeamStore
|
||||
from storage.slack_user import SlackUser
|
||||
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.schema.agent import AgentState
|
||||
from openhands.events.action import MessageAction
|
||||
from openhands.events.serialization.event import event_to_dict
|
||||
from openhands.server.services.conversation_service import (
|
||||
create_new_conversation,
|
||||
setup_init_conversation_settings,
|
||||
)
|
||||
from openhands.server.shared import ConversationStoreImpl, config, conversation_manager
|
||||
from openhands.server.user_auth.user_auth import UserAuth
|
||||
from openhands.storage.data_models.conversation_metadata import ConversationTrigger
|
||||
from openhands.utils.async_utils import GENERAL_TIMEOUT, call_async_from_sync
|
||||
|
||||
# =================================================
|
||||
# SECTION: Github view types
|
||||
# =================================================
|
||||
|
||||
|
||||
CONTEXT_LIMIT = 21
|
||||
slack_conversation_store = SlackConversationStore.get_instance()
|
||||
slack_team_store = SlackTeamStore.get_instance()
|
||||
|
||||
|
||||
@dataclass
|
||||
class SlackUnkownUserView(SlackViewInterface):
|
||||
bot_access_token: str
|
||||
user_msg: str | None
|
||||
slack_user_id: str
|
||||
slack_to_openhands_user: SlackUser | None
|
||||
saas_user_auth: UserAuth | None
|
||||
channel_id: str
|
||||
message_ts: str
|
||||
thread_ts: str | None
|
||||
selected_repo: str | None
|
||||
should_extract: bool
|
||||
send_summary_instruction: bool
|
||||
conversation_id: str
|
||||
team_id: str
|
||||
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
raise NotImplementedError
|
||||
|
||||
async def create_or_update_conversation(self, jinja_env: Environment):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_callback_id(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_response_msg(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@dataclass
|
||||
class SlackNewConversationView(SlackViewInterface):
|
||||
bot_access_token: str
|
||||
user_msg: str | None
|
||||
slack_user_id: str
|
||||
slack_to_openhands_user: SlackUser
|
||||
saas_user_auth: UserAuth
|
||||
channel_id: str
|
||||
message_ts: str
|
||||
thread_ts: str | None
|
||||
selected_repo: str | None
|
||||
should_extract: bool
|
||||
send_summary_instruction: bool
|
||||
conversation_id: str
|
||||
team_id: str
|
||||
|
||||
def _get_initial_prompt(self, text: str, blocks: list[dict]):
|
||||
bot_id = self._get_bot_id(blocks)
|
||||
text = text.replace(f'<@{bot_id}>', '').strip()
|
||||
return text
|
||||
|
||||
def _get_bot_id(self, blocks: list[dict]) -> str:
|
||||
for block in blocks:
|
||||
type_ = block['type']
|
||||
if type_ in ('rich_text', 'rich_text_section'):
|
||||
bot_id = self._get_bot_id(block['elements'])
|
||||
if bot_id:
|
||||
return bot_id
|
||||
if type_ == 'user':
|
||||
return block['user_id']
|
||||
return ''
|
||||
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
"Instructions passed when conversation is first initialized"
|
||||
|
||||
user_info: SlackUser = self.slack_to_openhands_user
|
||||
|
||||
messages = []
|
||||
if self.thread_ts:
|
||||
client = WebClient(token=self.bot_access_token)
|
||||
result = client.conversations_replies(
|
||||
channel=self.channel_id,
|
||||
ts=self.thread_ts,
|
||||
inclusive=True,
|
||||
latest=self.message_ts,
|
||||
limit=CONTEXT_LIMIT, # We can be smarter about getting more context/condensing it even in the future
|
||||
)
|
||||
|
||||
messages = result['messages']
|
||||
|
||||
else:
|
||||
client = WebClient(token=self.bot_access_token)
|
||||
result = client.conversations_history(
|
||||
channel=self.channel_id,
|
||||
inclusive=True,
|
||||
latest=self.message_ts,
|
||||
limit=CONTEXT_LIMIT,
|
||||
)
|
||||
|
||||
messages = result['messages']
|
||||
messages.reverse()
|
||||
|
||||
if not messages:
|
||||
raise ValueError('Failed to fetch information from slack API')
|
||||
|
||||
logger.info('got_messages_from_slack', extra={'messages': messages})
|
||||
|
||||
trigger_msg = messages[-1]
|
||||
user_message = self._get_initial_prompt(
|
||||
trigger_msg['text'], trigger_msg['blocks']
|
||||
)
|
||||
|
||||
conversation_instructions = ''
|
||||
|
||||
if len(messages) > 1:
|
||||
messages.pop()
|
||||
text_messages = [m['text'] for m in messages if m.get('text')]
|
||||
conversation_instructions_template = jinja_env.get_template(
|
||||
'user_message_conversation_instructions.j2'
|
||||
)
|
||||
conversation_instructions = conversation_instructions_template.render(
|
||||
messages=text_messages,
|
||||
username=user_info.slack_display_name,
|
||||
conversation_url=CONVERSATION_URL,
|
||||
)
|
||||
|
||||
return user_message, conversation_instructions
|
||||
|
||||
def _verify_necessary_values_are_set(self):
|
||||
if not self.selected_repo:
|
||||
raise ValueError(
|
||||
'Attempting to start conversation without confirming selected repo from user'
|
||||
)
|
||||
|
||||
async def save_slack_convo(self):
|
||||
if self.slack_to_openhands_user:
|
||||
user_info: SlackUser = self.slack_to_openhands_user
|
||||
|
||||
logger.info(
|
||||
'Create slack conversation',
|
||||
extra={
|
||||
'channel_id': self.channel_id,
|
||||
'conversation_id': self.conversation_id,
|
||||
'keycloak_user_id': user_info.keycloak_user_id,
|
||||
'parent_id': self.thread_ts or self.message_ts,
|
||||
},
|
||||
)
|
||||
slack_conversation = SlackConversation(
|
||||
conversation_id=self.conversation_id,
|
||||
channel_id=self.channel_id,
|
||||
keycloak_user_id=user_info.keycloak_user_id,
|
||||
parent_id=self.thread_ts
|
||||
or self.message_ts, # conversations can start in a thread reply as well; we should always references the parent's (root level msg's) message ID
|
||||
)
|
||||
await slack_conversation_store.create_slack_conversation(slack_conversation)
|
||||
|
||||
async def create_or_update_conversation(self, jinja: Environment) -> str:
|
||||
"""
|
||||
Only creates a new conversation
|
||||
"""
|
||||
self._verify_necessary_values_are_set()
|
||||
|
||||
provider_tokens = await self.saas_user_auth.get_provider_tokens()
|
||||
user_secrets = await self.saas_user_auth.get_user_secrets()
|
||||
user_instructions, conversation_instructions = self._get_instructions(jinja)
|
||||
|
||||
agent_loop_info = await create_new_conversation(
|
||||
user_id=self.slack_to_openhands_user.keycloak_user_id,
|
||||
git_provider_tokens=provider_tokens,
|
||||
selected_repository=self.selected_repo,
|
||||
selected_branch=None,
|
||||
initial_user_msg=user_instructions,
|
||||
conversation_instructions=conversation_instructions
|
||||
if conversation_instructions
|
||||
else None,
|
||||
image_urls=None,
|
||||
replay_json=None,
|
||||
conversation_trigger=ConversationTrigger.SLACK,
|
||||
custom_secrets=user_secrets.custom_secrets if user_secrets else None,
|
||||
)
|
||||
|
||||
self.conversation_id = agent_loop_info.conversation_id
|
||||
await self.save_slack_convo()
|
||||
return self.conversation_id
|
||||
|
||||
def get_callback_id(self) -> str:
|
||||
return f'slack_{self.channel_id}_{self.message_ts}'
|
||||
|
||||
def get_response_msg(self) -> str:
|
||||
user_info: SlackUser = self.slack_to_openhands_user
|
||||
conversation_link = CONVERSATION_URL.format(self.conversation_id)
|
||||
return f"I'm on it! {user_info.slack_display_name} can [track my progress here]({conversation_link})."
|
||||
|
||||
|
||||
@dataclass
|
||||
class SlackNewConversationFromRepoFormView(SlackNewConversationView):
|
||||
def _verify_necessary_values_are_set(self):
|
||||
# Exclude selected repo check from parent
|
||||
# User can start conversations without a repo when specified via the repo selection form
|
||||
return
|
||||
|
||||
|
||||
@dataclass
|
||||
class SlackUpdateExistingConversationView(SlackNewConversationView):
|
||||
slack_conversation: SlackConversation
|
||||
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
client = WebClient(token=self.bot_access_token)
|
||||
result = client.conversations_replies(
|
||||
channel=self.channel_id,
|
||||
ts=self.message_ts,
|
||||
inclusive=True,
|
||||
latest=self.message_ts,
|
||||
limit=1, # Get exact user message, in future we can be smarter with collecting additional context
|
||||
)
|
||||
|
||||
user_message = result['messages'][0]
|
||||
user_message = self._get_initial_prompt(
|
||||
user_message['text'], user_message['blocks']
|
||||
)
|
||||
|
||||
return user_message, ''
|
||||
|
||||
async def create_or_update_conversation(self, jinja: Environment) -> str:
|
||||
"""
|
||||
Send new user message to converation
|
||||
"""
|
||||
user_info: SlackUser = self.slack_to_openhands_user
|
||||
saas_user_auth: UserAuth = self.saas_user_auth
|
||||
user_id = user_info.keycloak_user_id
|
||||
|
||||
# Org management in the future will get rid of this
|
||||
# For now, only user that created the conversation can send follow up messages to it
|
||||
if user_id != self.slack_conversation.keycloak_user_id:
|
||||
raise StartingConvoException(
|
||||
f'{user_info.slack_display_name} is not authorized to send messages to this conversation.'
|
||||
)
|
||||
|
||||
# Check if conversation has been deleted
|
||||
# Update logic when soft delete is implemented
|
||||
conversation_store = await ConversationStoreImpl.get_instance(config, user_id)
|
||||
metadata = await conversation_store.get_metadata(self.conversation_id)
|
||||
if not metadata:
|
||||
raise StartingConvoException('Conversation no longer exists.')
|
||||
|
||||
provider_tokens = await saas_user_auth.get_provider_tokens()
|
||||
|
||||
# Should we raise here if there are no provider tokens?
|
||||
providers_set = list(provider_tokens.keys()) if provider_tokens else []
|
||||
|
||||
conversation_init_data = await setup_init_conversation_settings(
|
||||
user_id, self.conversation_id, providers_set
|
||||
)
|
||||
|
||||
# Either join ongoing conversation, or restart the conversation
|
||||
agent_loop_info = await conversation_manager.maybe_start_agent_loop(
|
||||
self.conversation_id, conversation_init_data, user_id
|
||||
)
|
||||
|
||||
final_agent_observation = get_final_agent_observation(
|
||||
agent_loop_info.event_store
|
||||
)
|
||||
agent_state = (
|
||||
None
|
||||
if len(final_agent_observation) == 0
|
||||
else final_agent_observation[0].agent_state
|
||||
)
|
||||
|
||||
if not agent_state or agent_state == AgentState.LOADING:
|
||||
raise StartingConvoException('Conversation is still starting')
|
||||
|
||||
user_msg, _ = self._get_instructions(jinja)
|
||||
user_msg_action = MessageAction(content=user_msg)
|
||||
await conversation_manager.send_event_to_conversation(
|
||||
self.conversation_id, event_to_dict(user_msg_action)
|
||||
)
|
||||
|
||||
return self.conversation_id
|
||||
|
||||
def get_response_msg(self):
|
||||
user_info: SlackUser = self.slack_to_openhands_user
|
||||
conversation_link = CONVERSATION_URL.format(self.conversation_id)
|
||||
return f"I'm on it! {user_info.slack_display_name} can [continue tracking my progress here]({conversation_link})."
|
||||
|
||||
|
||||
class SlackFactory:
|
||||
@staticmethod
|
||||
def did_user_select_repo_from_form(message: Message):
|
||||
payload = message.message
|
||||
return 'selected_repo' in payload
|
||||
|
||||
@staticmethod
|
||||
async def determine_if_updating_existing_conversation(
|
||||
message: Message,
|
||||
) -> SlackConversation | None:
|
||||
payload = message.message
|
||||
channel_id = payload.get('channel_id')
|
||||
thread_ts = payload.get('thread_ts')
|
||||
|
||||
# Follow up conversations must be contained in-thread
|
||||
if not thread_ts:
|
||||
return None
|
||||
|
||||
# thread_ts in slack payloads in the parent's (root level msg's) message ID
|
||||
return await slack_conversation_store.get_slack_conversation(
|
||||
channel_id, thread_ts
|
||||
)
|
||||
|
||||
def create_slack_view_from_payload(
|
||||
message: Message, slack_user: SlackUser | None, saas_user_auth: UserAuth | None
|
||||
):
|
||||
payload = message.message
|
||||
slack_user_id = payload['slack_user_id']
|
||||
channel_id = payload.get('channel_id')
|
||||
message_ts = payload.get('message_ts')
|
||||
thread_ts = payload.get('thread_ts')
|
||||
team_id = payload['team_id']
|
||||
user_msg = payload.get('user_msg')
|
||||
|
||||
bot_access_token = slack_team_store.get_team_bot_token(team_id)
|
||||
if not bot_access_token:
|
||||
logger.error(
|
||||
'Did not find slack team',
|
||||
extra={
|
||||
'slack_user_id': slack_user_id,
|
||||
'channel_id': channel_id,
|
||||
},
|
||||
)
|
||||
raise Exception('Did not slack team')
|
||||
|
||||
# Determine if this is a known slack user by openhands
|
||||
if not slack_user or not saas_user_auth or not channel_id:
|
||||
return SlackUnkownUserView(
|
||||
bot_access_token=bot_access_token,
|
||||
user_msg=user_msg,
|
||||
slack_user_id=slack_user_id,
|
||||
slack_to_openhands_user=slack_user,
|
||||
saas_user_auth=saas_user_auth,
|
||||
channel_id=channel_id,
|
||||
message_ts=message_ts,
|
||||
thread_ts=thread_ts,
|
||||
selected_repo=None,
|
||||
should_extract=False,
|
||||
send_summary_instruction=False,
|
||||
conversation_id='',
|
||||
team_id=team_id,
|
||||
)
|
||||
|
||||
conversation: SlackConversation | None = call_async_from_sync(
|
||||
SlackFactory.determine_if_updating_existing_conversation,
|
||||
GENERAL_TIMEOUT,
|
||||
message,
|
||||
)
|
||||
if conversation:
|
||||
logger.info(
|
||||
'Found existing slack conversation',
|
||||
extra={
|
||||
'conversation_id': conversation.conversation_id,
|
||||
'parent_id': conversation.parent_id,
|
||||
},
|
||||
)
|
||||
return SlackUpdateExistingConversationView(
|
||||
bot_access_token=bot_access_token,
|
||||
user_msg=user_msg,
|
||||
slack_user_id=slack_user_id,
|
||||
slack_to_openhands_user=slack_user,
|
||||
saas_user_auth=saas_user_auth,
|
||||
channel_id=channel_id,
|
||||
message_ts=message_ts,
|
||||
thread_ts=thread_ts,
|
||||
selected_repo=None,
|
||||
should_extract=True,
|
||||
send_summary_instruction=True,
|
||||
conversation_id=conversation.conversation_id,
|
||||
slack_conversation=conversation,
|
||||
team_id=team_id,
|
||||
)
|
||||
|
||||
elif SlackFactory.did_user_select_repo_from_form(message):
|
||||
return SlackNewConversationFromRepoFormView(
|
||||
bot_access_token=bot_access_token,
|
||||
user_msg=user_msg,
|
||||
slack_user_id=slack_user_id,
|
||||
slack_to_openhands_user=slack_user,
|
||||
saas_user_auth=saas_user_auth,
|
||||
channel_id=channel_id,
|
||||
message_ts=message_ts,
|
||||
thread_ts=thread_ts,
|
||||
selected_repo=payload['selected_repo'],
|
||||
should_extract=True,
|
||||
send_summary_instruction=True,
|
||||
conversation_id='',
|
||||
team_id=team_id,
|
||||
)
|
||||
|
||||
else:
|
||||
return SlackNewConversationView(
|
||||
bot_access_token=bot_access_token,
|
||||
user_msg=user_msg,
|
||||
slack_user_id=slack_user_id,
|
||||
slack_to_openhands_user=slack_user,
|
||||
saas_user_auth=saas_user_auth,
|
||||
channel_id=channel_id,
|
||||
message_ts=message_ts,
|
||||
thread_ts=thread_ts,
|
||||
selected_repo=None,
|
||||
should_extract=True,
|
||||
send_summary_instruction=True,
|
||||
conversation_id='',
|
||||
team_id=team_id,
|
||||
)
|
||||
0
enterprise/integrations/solvability/__init__.py
Normal file
0
enterprise/integrations/solvability/__init__.py
Normal file
41
enterprise/integrations/solvability/data/__init__.py
Normal file
41
enterprise/integrations/solvability/data/__init__.py
Normal file
@@ -0,0 +1,41 @@
|
||||
"""
|
||||
Utilities for loading and managing pre-trained classifiers.
|
||||
|
||||
Assumes that classifiers are stored adjacent to this file in the `solvability/data` directory, using a simple
|
||||
`name + .json` pattern.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from integrations.solvability.models.classifier import SolvabilityClassifier
|
||||
|
||||
|
||||
def load_classifier(name: str) -> SolvabilityClassifier:
|
||||
"""
|
||||
Load a classifier by name.
|
||||
|
||||
Args:
|
||||
name (str): The name of the classifier to load.
|
||||
|
||||
Returns:
|
||||
SolvabilityClassifier: The loaded classifier instance.
|
||||
"""
|
||||
data_dir = Path(__file__).parent
|
||||
classifier_path = data_dir / f'{name}.json'
|
||||
|
||||
if not classifier_path.exists():
|
||||
raise FileNotFoundError(f"Classifier '{name}' not found at {classifier_path}")
|
||||
|
||||
with classifier_path.open('r') as f:
|
||||
return SolvabilityClassifier.model_validate_json(f.read())
|
||||
|
||||
|
||||
def available_classifiers() -> list[str]:
|
||||
"""
|
||||
List all available classifiers in the data directory.
|
||||
|
||||
Returns:
|
||||
list[str]: A list of classifier names (without the .json extension).
|
||||
"""
|
||||
data_dir = Path(__file__).parent
|
||||
return [f.stem for f in data_dir.glob('*.json') if f.is_file()]
|
||||
File diff suppressed because one or more lines are too long
38
enterprise/integrations/solvability/models/__init__.py
Normal file
38
enterprise/integrations/solvability/models/__init__.py
Normal file
@@ -0,0 +1,38 @@
|
||||
"""
|
||||
Solvability Models Package
|
||||
|
||||
This package contains the core machine learning models and components for predicting
|
||||
the solvability of GitHub issues and similar technical problems.
|
||||
|
||||
The solvability prediction system works by:
|
||||
1. Using a Featurizer to extract semantic features from issue descriptions via LLM calls
|
||||
2. Training a RandomForestClassifier on these features to predict solvability
|
||||
3. Generating detailed reports with feature importance analysis
|
||||
|
||||
Key Components:
|
||||
- Feature: Defines individual features that can be extracted from issues
|
||||
- Featurizer: Orchestrates LLM-based feature extraction with sampling and batching
|
||||
- SolvabilityClassifier: Main ML pipeline combining featurization and classification
|
||||
- SolvabilityReport: Comprehensive output with predictions, feature analysis, and metadata
|
||||
- ImportanceStrategy: Configurable methods for calculating feature importance (SHAP, permutation, impurity)
|
||||
"""
|
||||
|
||||
from integrations.solvability.models.classifier import SolvabilityClassifier
|
||||
from integrations.solvability.models.featurizer import (
|
||||
EmbeddingDimension,
|
||||
Feature,
|
||||
FeatureEmbedding,
|
||||
Featurizer,
|
||||
)
|
||||
from integrations.solvability.models.importance_strategy import ImportanceStrategy
|
||||
from integrations.solvability.models.report import SolvabilityReport
|
||||
|
||||
__all__ = [
|
||||
'Feature',
|
||||
'EmbeddingDimension',
|
||||
'FeatureEmbedding',
|
||||
'Featurizer',
|
||||
'ImportanceStrategy',
|
||||
'SolvabilityClassifier',
|
||||
'SolvabilityReport',
|
||||
]
|
||||
433
enterprise/integrations/solvability/models/classifier.py
Normal file
433
enterprise/integrations/solvability/models/classifier.py
Normal file
@@ -0,0 +1,433 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import pickle
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import shap
|
||||
from integrations.solvability.models.featurizer import Feature, Featurizer
|
||||
from integrations.solvability.models.importance_strategy import ImportanceStrategy
|
||||
from integrations.solvability.models.report import SolvabilityReport
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
PrivateAttr,
|
||||
field_serializer,
|
||||
field_validator,
|
||||
model_validator,
|
||||
)
|
||||
from sklearn.ensemble import RandomForestClassifier
|
||||
from sklearn.exceptions import NotFittedError
|
||||
from sklearn.inspection import permutation_importance
|
||||
from sklearn.utils.validation import check_is_fitted
|
||||
|
||||
from openhands.core.config import LLMConfig
|
||||
|
||||
|
||||
class SolvabilityClassifier(BaseModel):
|
||||
"""
|
||||
Machine learning pipeline for predicting the solvability of GitHub issues and similar problems.
|
||||
|
||||
This classifier combines LLM-based feature extraction with traditional ML classification:
|
||||
1. Uses a Featurizer to extract semantic boolean features from issue descriptions via LLM calls
|
||||
2. Trains a RandomForestClassifier on these features to predict solvability scores
|
||||
3. Provides feature importance analysis using configurable strategies (SHAP, permutation, impurity)
|
||||
4. Generates comprehensive reports with predictions, feature analysis, and cost metrics
|
||||
|
||||
The classifier supports both training on labeled data and inference on new issues, with built-in
|
||||
support for batch processing and concurrent feature extraction.
|
||||
"""
|
||||
|
||||
identifier: str
|
||||
"""
|
||||
The identifier for the classifier.
|
||||
"""
|
||||
|
||||
featurizer: Featurizer
|
||||
"""
|
||||
The featurizer to use for transforming the input data.
|
||||
"""
|
||||
|
||||
classifier: RandomForestClassifier
|
||||
"""
|
||||
The RandomForestClassifier used for predicting solvability from extracted features.
|
||||
|
||||
This ensemble model provides robust predictions and built-in feature importance metrics.
|
||||
"""
|
||||
|
||||
importance_strategy: ImportanceStrategy = ImportanceStrategy.IMPURITY
|
||||
"""
|
||||
Strategy to use for calculating feature importance.
|
||||
"""
|
||||
|
||||
samples: int = 10
|
||||
"""
|
||||
Number of samples to use for calculating feature embedding coefficients.
|
||||
"""
|
||||
|
||||
random_state: int | None = None
|
||||
"""
|
||||
Random state for reproducibility.
|
||||
"""
|
||||
|
||||
_classifier_attrs: dict[str, Any] = PrivateAttr(default_factory=dict)
|
||||
"""
|
||||
Private dictionary storing cached results from feature extraction and importance calculations.
|
||||
|
||||
Contains keys like 'features_', 'cost_', 'feature_importances_', and 'labels_' that are populated
|
||||
during transform(), fit(), and predict() operations. Access these via the corresponding properties.
|
||||
|
||||
This field is never serialized, so cached values will not persist across model save/load cycles.
|
||||
"""
|
||||
|
||||
model_config = {
|
||||
'arbitrary_types_allowed': True,
|
||||
}
|
||||
|
||||
@model_validator(mode='after')
|
||||
def validate_random_state(self) -> SolvabilityClassifier:
|
||||
"""
|
||||
Validate the random state configuration between this object and the classifier.
|
||||
"""
|
||||
# If both random states are set, they definitely need to agree.
|
||||
if self.random_state is not None and self.classifier.random_state is not None:
|
||||
if self.random_state != self.classifier.random_state:
|
||||
raise ValueError(
|
||||
'The random state of the classifier and the top-level classifier must agree.'
|
||||
)
|
||||
|
||||
# Otherwise, we'll always set the classifier's random state to the top-level one.
|
||||
self.classifier.random_state = self.random_state
|
||||
|
||||
return self
|
||||
|
||||
@property
|
||||
def features_(self) -> pd.DataFrame:
|
||||
"""
|
||||
Get the features used by the classifier for the most recent inputs.
|
||||
"""
|
||||
if 'features_' not in self._classifier_attrs:
|
||||
raise ValueError(
|
||||
'SolvabilityClassifier.transform() has not yet been called.'
|
||||
)
|
||||
return self._classifier_attrs['features_']
|
||||
|
||||
@property
|
||||
def cost_(self) -> pd.DataFrame:
|
||||
"""
|
||||
Get the cost of the classifier for the most recent inputs.
|
||||
"""
|
||||
if 'cost_' not in self._classifier_attrs:
|
||||
raise ValueError(
|
||||
'SolvabilityClassifier.transform() has not yet been called.'
|
||||
)
|
||||
return self._classifier_attrs['cost_']
|
||||
|
||||
@property
|
||||
def feature_importances_(self) -> np.ndarray:
|
||||
"""
|
||||
Get the feature importances for the most recent inputs.
|
||||
"""
|
||||
if 'feature_importances_' not in self._classifier_attrs:
|
||||
raise ValueError(
|
||||
'No SolvabilityClassifier methods that produce feature importances (.fit(), .predict_proba(), and '
|
||||
'.predict()) have been called.'
|
||||
)
|
||||
return self._classifier_attrs['feature_importances_'] # type: ignore[no-any-return]
|
||||
|
||||
@property
|
||||
def is_fitted(self) -> bool:
|
||||
"""
|
||||
Check if the classifier is fitted.
|
||||
"""
|
||||
try:
|
||||
check_is_fitted(self.classifier)
|
||||
return True
|
||||
except NotFittedError:
|
||||
return False
|
||||
|
||||
def transform(self, issues: pd.Series, llm_config: LLMConfig) -> pd.DataFrame:
|
||||
"""
|
||||
Transform the input issues using the featurizer to extract features.
|
||||
|
||||
This method orchestrates the feature extraction pipeline:
|
||||
1. Uses the featurizer to generate embeddings for all issues
|
||||
2. Converts embeddings to a structured DataFrame
|
||||
3. Separates feature columns from metadata columns
|
||||
4. Stores results for later access via properties
|
||||
|
||||
Args:
|
||||
issues: A pandas Series containing the issue descriptions.
|
||||
llm_config: LLM configuration to use for feature extraction.
|
||||
|
||||
Returns:
|
||||
pd.DataFrame: A DataFrame containing only the feature columns (no metadata).
|
||||
"""
|
||||
# Generate feature embeddings for all issues using batch processing
|
||||
feature_embeddings = self.featurizer.embed_batch(
|
||||
issues, samples=self.samples, llm_config=llm_config
|
||||
)
|
||||
df = pd.DataFrame(embedding.to_row() for embedding in feature_embeddings)
|
||||
|
||||
# Split into feature columns (used by classifier) and cost columns (metadata)
|
||||
feature_columns = [feature.identifier for feature in self.featurizer.features]
|
||||
cost_columns = [col for col in df.columns if col not in feature_columns]
|
||||
|
||||
# Store both sets for access via properties
|
||||
self._classifier_attrs['features_'] = df[feature_columns]
|
||||
self._classifier_attrs['cost_'] = df[cost_columns]
|
||||
|
||||
return self.features_
|
||||
|
||||
def fit(
|
||||
self, issues: pd.Series, labels: pd.Series, llm_config: LLMConfig
|
||||
) -> SolvabilityClassifier:
|
||||
"""
|
||||
Fit the classifier to the input issues and labels.
|
||||
|
||||
Args:
|
||||
issues: A pandas Series containing the issue descriptions.
|
||||
|
||||
labels: A pandas Series containing the labels (0 or 1) for each issue.
|
||||
|
||||
llm_config: LLM configuration to use for feature extraction.
|
||||
|
||||
Returns:
|
||||
SolvabilityClassifier: The fitted classifier.
|
||||
"""
|
||||
features = self.transform(issues, llm_config=llm_config)
|
||||
self.classifier.fit(features, labels)
|
||||
|
||||
# Store labels for permutation importance calculation
|
||||
self._classifier_attrs['labels_'] = labels
|
||||
self._classifier_attrs['feature_importances_'] = self._importance(
|
||||
features, self.classifier.predict_proba(features), labels
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
def predict_proba(self, issues: pd.Series, llm_config: LLMConfig) -> np.ndarray:
|
||||
"""
|
||||
Predict the solvability probabilities for the input issues.
|
||||
|
||||
Returns class probabilities where the second column represents the probability
|
||||
of the issue being solvable (positive class).
|
||||
|
||||
Args:
|
||||
issues: A pandas Series containing the issue descriptions.
|
||||
llm_config: LLM configuration to use for feature extraction.
|
||||
|
||||
Returns:
|
||||
np.ndarray: Array of shape (n_samples, 2) with probabilities for each class.
|
||||
Column 0: probability of not solvable, Column 1: probability of solvable.
|
||||
"""
|
||||
features = self.transform(issues, llm_config=llm_config)
|
||||
scores = self.classifier.predict_proba(features)
|
||||
|
||||
# Calculate feature importances based on the configured strategy
|
||||
# For permutation importance, we need ground truth labels if available
|
||||
labels = self._classifier_attrs.get('labels_')
|
||||
if (
|
||||
self.importance_strategy == ImportanceStrategy.PERMUTATION
|
||||
and labels is not None
|
||||
):
|
||||
self._classifier_attrs['feature_importances_'] = self._importance(
|
||||
features, scores, labels
|
||||
)
|
||||
else:
|
||||
self._classifier_attrs['feature_importances_'] = self._importance(
|
||||
features, scores
|
||||
)
|
||||
|
||||
return scores # type: ignore[no-any-return]
|
||||
|
||||
def predict(self, issues: pd.Series, llm_config: LLMConfig) -> np.ndarray:
|
||||
"""
|
||||
Predict the solvability of the input issues by returning binary labels.
|
||||
|
||||
Uses a 0.5 probability threshold to convert probabilities to binary predictions.
|
||||
|
||||
Args:
|
||||
issues: A pandas Series containing the issue descriptions.
|
||||
llm_config: LLM configuration to use for feature extraction.
|
||||
|
||||
Returns:
|
||||
np.ndarray: Boolean array where True indicates the issue is predicted as solvable.
|
||||
"""
|
||||
probabilities = self.predict_proba(issues, llm_config=llm_config)
|
||||
# Apply 0.5 threshold to convert probabilities to binary predictions
|
||||
labels = probabilities[:, 1] >= 0.5
|
||||
return labels
|
||||
|
||||
def _importance(
|
||||
self,
|
||||
features: pd.DataFrame,
|
||||
scores: np.ndarray,
|
||||
labels: np.ndarray | None = None,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Calculate feature importance scores using the configured strategy.
|
||||
|
||||
Different strategies provide different interpretations:
|
||||
- SHAP: Shapley values indicating contribution to individual predictions
|
||||
- PERMUTATION: Decrease in model performance when feature is shuffled
|
||||
- IMPURITY: Gini impurity decrease from splits on each feature
|
||||
|
||||
Args:
|
||||
features: Feature matrix used for predictions.
|
||||
scores: Model prediction scores (unused for some strategies).
|
||||
labels: Ground truth labels (required for permutation importance).
|
||||
|
||||
Returns:
|
||||
np.ndarray: Feature importance scores, one per feature.
|
||||
"""
|
||||
match self.importance_strategy:
|
||||
case ImportanceStrategy.SHAP:
|
||||
# Use SHAP TreeExplainer for tree-based models
|
||||
explainer = shap.TreeExplainer(self.classifier)
|
||||
shap_values = explainer.shap_values(features)
|
||||
# Return mean SHAP values for the positive class (solvable)
|
||||
return shap_values.mean(axis=0)[:, 1] # type: ignore[no-any-return]
|
||||
|
||||
case ImportanceStrategy.PERMUTATION:
|
||||
# Permutation importance requires ground truth labels
|
||||
if labels is None:
|
||||
raise ValueError('Labels are required for permutation importance')
|
||||
result = permutation_importance(
|
||||
self.classifier,
|
||||
features,
|
||||
labels,
|
||||
n_repeats=10, # Number of permutation rounds for stability
|
||||
random_state=self.random_state,
|
||||
)
|
||||
return result.importances_mean # type: ignore[no-any-return]
|
||||
|
||||
case ImportanceStrategy.IMPURITY:
|
||||
# Use built-in feature importances from RandomForest
|
||||
return self.classifier.feature_importances_ # type: ignore[no-any-return]
|
||||
|
||||
case _:
|
||||
raise ValueError(
|
||||
f'Unknown importance strategy: {self.importance_strategy}'
|
||||
)
|
||||
|
||||
def add_features(self, features: list[Feature]) -> SolvabilityClassifier:
|
||||
"""
|
||||
Add new features to the classifier's featurizer.
|
||||
|
||||
Note: Adding features after training requires retraining the classifier
|
||||
since the feature space will have changed.
|
||||
|
||||
Args:
|
||||
features: List of Feature objects to add.
|
||||
|
||||
Returns:
|
||||
SolvabilityClassifier: Self for method chaining.
|
||||
"""
|
||||
for feature in features:
|
||||
if feature not in self.featurizer.features:
|
||||
self.featurizer.features.append(feature)
|
||||
return self
|
||||
|
||||
def forget_features(self, features: list[Feature]) -> SolvabilityClassifier:
|
||||
"""
|
||||
Remove features from the classifier's featurizer.
|
||||
|
||||
Note: Removing features after training requires retraining the classifier
|
||||
since the feature space will have changed.
|
||||
|
||||
Args:
|
||||
features: List of Feature objects to remove.
|
||||
|
||||
Returns:
|
||||
SolvabilityClassifier: Self for method chaining.
|
||||
"""
|
||||
for feature in features:
|
||||
try:
|
||||
self.featurizer.features.remove(feature)
|
||||
except ValueError:
|
||||
# Feature not in list, continue with others
|
||||
continue
|
||||
return self
|
||||
|
||||
@field_serializer('classifier')
|
||||
@staticmethod
|
||||
def _rfc_to_json(rfc: RandomForestClassifier) -> str:
|
||||
"""
|
||||
Convert a RandomForestClassifier to a JSON-compatible value (a string).
|
||||
"""
|
||||
return base64.b64encode(pickle.dumps(rfc)).decode('utf-8')
|
||||
|
||||
@field_validator('classifier', mode='before')
|
||||
@staticmethod
|
||||
def _json_to_rfc(value: str | RandomForestClassifier) -> RandomForestClassifier:
|
||||
"""
|
||||
Convert a JSON-compatible value (a string) back to a RandomForestClassifier.
|
||||
"""
|
||||
if isinstance(value, RandomForestClassifier):
|
||||
return value
|
||||
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
model = pickle.loads(base64.b64decode(value))
|
||||
if isinstance(model, RandomForestClassifier):
|
||||
return model
|
||||
except Exception as e:
|
||||
raise ValueError(f'Failed to decode the classifier: {e}')
|
||||
|
||||
raise ValueError(
|
||||
'The classifier must be a RandomForestClassifier or a JSON-compatible dictionary.'
|
||||
)
|
||||
|
||||
def solvability_report(
|
||||
self, issue: str, llm_config: LLMConfig, **kwargs: Any
|
||||
) -> SolvabilityReport:
|
||||
"""
|
||||
Generate a solvability report for the given issue.
|
||||
|
||||
Args:
|
||||
issue: The issue description for which to generate the report.
|
||||
llm_config: Optional LLM configuration to use for feature extraction.
|
||||
kwargs: Additional metadata to include in the report.
|
||||
|
||||
Returns:
|
||||
SolvabilityReport: The generated solvability report.
|
||||
"""
|
||||
if not self.is_fitted:
|
||||
raise ValueError(
|
||||
'The classifier must be fitted before generating a report.'
|
||||
)
|
||||
|
||||
scores = self.predict_proba(pd.Series([issue]), llm_config=llm_config)
|
||||
|
||||
return SolvabilityReport(
|
||||
identifier=self.identifier,
|
||||
issue=issue,
|
||||
score=scores[0, 1],
|
||||
features=self.features_.iloc[0].to_dict(),
|
||||
samples=self.samples,
|
||||
importance_strategy=self.importance_strategy,
|
||||
# Unlike the features, the importances are just a series with no link
|
||||
# to the actual feature names. For that we have to recombine with the
|
||||
# feature identifiers.
|
||||
feature_importances=dict(
|
||||
zip(
|
||||
self.featurizer.feature_identifiers(),
|
||||
self.feature_importances_.tolist(),
|
||||
)
|
||||
),
|
||||
random_state=self.random_state,
|
||||
metadata=dict(kwargs) if kwargs else None,
|
||||
# Both cost and response_latency are columns in the cost_ DataFrame,
|
||||
# so we can get both by just unpacking the first row.
|
||||
**self.cost_.iloc[0].to_dict(),
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self, issue: str, llm_config: LLMConfig, **kwargs: Any
|
||||
) -> SolvabilityReport:
|
||||
"""
|
||||
Generate a solvability report for the given issue.
|
||||
"""
|
||||
return self.solvability_report(issue, llm_config=llm_config, **kwargs)
|
||||
@@ -0,0 +1,38 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class DifficultyLevel(Enum):
|
||||
"""Enum representing the difficulty level based on solvability score."""
|
||||
|
||||
EASY = ('EASY', 0.7, '🟢')
|
||||
MEDIUM = ('MEDIUM', 0.4, '🟡')
|
||||
HARD = ('HARD', 0.0, '🔴')
|
||||
|
||||
def __init__(self, label: str, threshold: float, emoji: str):
|
||||
self.label = label
|
||||
self.threshold = threshold
|
||||
self.emoji = emoji
|
||||
|
||||
@classmethod
|
||||
def from_score(cls, score: float) -> DifficultyLevel:
|
||||
"""Get difficulty level from a solvability score.
|
||||
|
||||
Returns the difficulty level with the highest threshold that is less than or equal to the given score.
|
||||
"""
|
||||
# Sort enum values by threshold in descending order
|
||||
sorted_levels = sorted(cls, key=lambda x: x.threshold, reverse=True)
|
||||
|
||||
# Find the first level where score meets the threshold
|
||||
for level in sorted_levels:
|
||||
if score >= level.threshold:
|
||||
return level
|
||||
|
||||
# This should never happen if thresholds are set correctly,
|
||||
# but return the lowest threshold level as fallback
|
||||
return sorted_levels[-1]
|
||||
|
||||
def format_display(self) -> str:
|
||||
"""Format the difficulty level for display."""
|
||||
return f'{self.emoji} **Solvability: {self.label}**'
|
||||
368
enterprise/integrations/solvability/models/featurizer.py
Normal file
368
enterprise/integrations/solvability/models/featurizer.py
Normal file
@@ -0,0 +1,368 @@
|
||||
import json
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from openhands.core.config import LLMConfig
|
||||
from openhands.llm.llm import LLM
|
||||
|
||||
|
||||
class Feature(BaseModel):
|
||||
"""
|
||||
Represents a single boolean feature that can be extracted from issue descriptions.
|
||||
|
||||
Features are semantic properties of issues (e.g., "has_code_example", "requires_debugging")
|
||||
that are evaluated by LLMs and used as input to the solvability classifier.
|
||||
"""
|
||||
|
||||
identifier: str
|
||||
"""Unique identifier for the feature, used as column name in feature matrices."""
|
||||
|
||||
description: str
|
||||
"""Human-readable description of what the feature represents, used in LLM prompts."""
|
||||
|
||||
@property
|
||||
def to_tool_description_field(self) -> dict[str, Any]:
|
||||
"""
|
||||
Convert this feature to a JSON schema field for LLM tool calling.
|
||||
|
||||
Returns:
|
||||
dict: JSON schema field definition for this feature.
|
||||
"""
|
||||
return {
|
||||
'type': 'boolean',
|
||||
'description': self.description,
|
||||
}
|
||||
|
||||
|
||||
class EmbeddingDimension(BaseModel):
|
||||
"""
|
||||
Represents a single dimension (feature evaluation) within a feature embedding sample.
|
||||
|
||||
Each dimension corresponds to one feature being evaluated as true/false for a given issue.
|
||||
"""
|
||||
|
||||
feature_id: str
|
||||
"""Identifier of the feature being evaluated."""
|
||||
|
||||
result: bool
|
||||
"""Boolean result of the feature evaluation for this sample."""
|
||||
|
||||
|
||||
# Type alias for a single embedding sample - maps feature identifiers to boolean values
|
||||
EmbeddingSample = dict[str, bool]
|
||||
"""
|
||||
A single sample from the LLM evaluation of features for an issue.
|
||||
Maps feature identifiers to their boolean evaluations.
|
||||
"""
|
||||
|
||||
|
||||
class FeatureEmbedding(BaseModel):
|
||||
"""
|
||||
Represents the complete feature embedding for a single issue, including multiple samples
|
||||
and associated metadata about the LLM calls used to generate it.
|
||||
|
||||
Multiple samples are collected to account for LLM variability and provide more robust
|
||||
feature estimates through averaging.
|
||||
"""
|
||||
|
||||
samples: list[EmbeddingSample]
|
||||
"""List of individual feature evaluation samples from the LLM."""
|
||||
|
||||
prompt_tokens: int | None = None
|
||||
"""Total prompt tokens consumed across all LLM calls for this embedding."""
|
||||
|
||||
completion_tokens: int | None = None
|
||||
"""Total completion tokens generated across all LLM calls for this embedding."""
|
||||
|
||||
response_latency: float | None = None
|
||||
"""Total response latency (seconds) across all LLM calls for this embedding."""
|
||||
|
||||
@property
|
||||
def dimensions(self) -> list[str]:
|
||||
"""
|
||||
Get all unique feature identifiers present across all samples.
|
||||
|
||||
Returns:
|
||||
list[str]: List of feature identifiers that appear in at least one sample.
|
||||
"""
|
||||
dims: set[str] = set()
|
||||
for sample in self.samples:
|
||||
dims.update(sample.keys())
|
||||
return list(dims)
|
||||
|
||||
def coefficient(self, dimension: str) -> float | None:
|
||||
"""
|
||||
Calculate the average coefficient (0-1) for a specific feature dimension.
|
||||
|
||||
This computes the proportion of samples where the feature was evaluated as True,
|
||||
providing a continuous feature value for the classifier.
|
||||
|
||||
Args:
|
||||
dimension: Feature identifier to calculate coefficient for.
|
||||
|
||||
Returns:
|
||||
float | None: Average coefficient (0.0-1.0), or None if dimension not found.
|
||||
"""
|
||||
# Extract boolean values for this dimension, converting to 0/1
|
||||
values = [
|
||||
1 if v else 0
|
||||
for v in [sample.get(dimension) for sample in self.samples]
|
||||
if v is not None
|
||||
]
|
||||
if values:
|
||||
return sum(values) / len(values)
|
||||
return None
|
||||
|
||||
def to_row(self) -> dict[str, Any]:
|
||||
"""
|
||||
Convert the embedding to a flat dictionary suitable for DataFrame construction.
|
||||
|
||||
Returns:
|
||||
dict[str, Any]: Dictionary with metadata fields and feature coefficients.
|
||||
"""
|
||||
return {
|
||||
'response_latency': self.response_latency,
|
||||
'prompt_tokens': self.prompt_tokens,
|
||||
'completion_tokens': self.completion_tokens,
|
||||
**{dimension: self.coefficient(dimension) for dimension in self.dimensions},
|
||||
}
|
||||
|
||||
def sample_entropy(self) -> dict[str, float]:
|
||||
"""
|
||||
Calculate the Shannon entropy of feature evaluations across samples.
|
||||
|
||||
Higher entropy indicates more variability in LLM responses for a feature,
|
||||
which may suggest ambiguity in the feature definition or issue description.
|
||||
|
||||
Returns:
|
||||
dict[str, float]: Mapping of feature identifiers to their entropy values (0-1).
|
||||
"""
|
||||
from collections import Counter
|
||||
from math import log2
|
||||
|
||||
entropy = {}
|
||||
for dimension in self.dimensions:
|
||||
# Count True/False occurrences for this feature across samples
|
||||
counts = Counter(sample.get(dimension, False) for sample in self.samples)
|
||||
total = sum(counts.values())
|
||||
if total == 0:
|
||||
entropy[dimension] = 0.0
|
||||
continue
|
||||
# Calculate Shannon entropy: -Σ(p * log2(p))
|
||||
entropy_value = -sum(
|
||||
(count / total) * log2(count / total)
|
||||
for count in counts.values()
|
||||
if count > 0
|
||||
)
|
||||
entropy[dimension] = entropy_value
|
||||
return entropy
|
||||
|
||||
|
||||
class Featurizer(BaseModel):
|
||||
"""
|
||||
Orchestrates LLM-based feature extraction from issue descriptions.
|
||||
|
||||
The Featurizer uses structured LLM tool calling to evaluate boolean features
|
||||
for issue descriptions. It handles prompt construction, tool schema generation,
|
||||
and batch processing with concurrency.
|
||||
"""
|
||||
|
||||
system_prompt: str
|
||||
"""System prompt that provides context and instructions to the LLM."""
|
||||
|
||||
message_prefix: str
|
||||
"""Prefix added to user messages before the issue description."""
|
||||
|
||||
features: list[Feature]
|
||||
"""List of features to extract from each issue description."""
|
||||
|
||||
def system_message(self) -> dict[str, Any]:
|
||||
"""
|
||||
Construct the system message for LLM conversations.
|
||||
|
||||
Returns:
|
||||
dict[str, Any]: System message dictionary for LLM API calls.
|
||||
"""
|
||||
return {
|
||||
'role': 'system',
|
||||
'content': self.system_prompt,
|
||||
}
|
||||
|
||||
def user_message(
|
||||
self, issue_description: str, set_cache: bool = True
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Construct the user message containing the issue description.
|
||||
|
||||
Args:
|
||||
issue_description: The description of the issue to analyze.
|
||||
set_cache: Whether to enable ephemeral caching for this message.
|
||||
Should be False for single samples to avoid cache overhead.
|
||||
|
||||
Returns:
|
||||
dict[str, Any]: User message dictionary for LLM API calls.
|
||||
"""
|
||||
message: dict[str, Any] = {
|
||||
'role': 'user',
|
||||
'content': f'{self.message_prefix}{issue_description}',
|
||||
}
|
||||
if set_cache:
|
||||
message['cache_control'] = {'type': 'ephemeral'}
|
||||
return message
|
||||
|
||||
@property
|
||||
def tool_choice(self) -> dict[str, Any]:
|
||||
"""
|
||||
Get the tool choice configuration for forcing LLM to use the featurizer tool.
|
||||
|
||||
Returns:
|
||||
dict[str, Any]: Tool choice configuration for LLM API calls.
|
||||
"""
|
||||
return {
|
||||
'type': 'function',
|
||||
'function': {'name': 'call_featurizer'},
|
||||
}
|
||||
|
||||
@property
|
||||
def tool_description(self) -> dict[str, Any]:
|
||||
"""
|
||||
Generate the tool schema for the featurizer function.
|
||||
|
||||
Creates a JSON schema that describes the featurizer tool with all configured
|
||||
features as boolean parameters.
|
||||
|
||||
Returns:
|
||||
dict[str, Any]: Complete tool description for LLM API calls.
|
||||
"""
|
||||
return {
|
||||
'type': 'function',
|
||||
'function': {
|
||||
'name': 'call_featurizer',
|
||||
'description': 'Record the features present in the issue.',
|
||||
'parameters': {
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
feature.identifier: feature.to_tool_description_field
|
||||
for feature in self.features
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
def embed(
|
||||
self,
|
||||
issue_description: str,
|
||||
llm_config: LLMConfig,
|
||||
temperature: float = 1.0,
|
||||
samples: int = 10,
|
||||
) -> FeatureEmbedding:
|
||||
"""
|
||||
Generate a feature embedding for a single issue description.
|
||||
|
||||
Makes multiple LLM calls to collect samples and reduce variance in feature evaluations.
|
||||
Each call uses tool calling to extract structured boolean feature values.
|
||||
|
||||
Args:
|
||||
issue_description: The description of the issue to analyze.
|
||||
llm_config: Configuration for the LLM to use.
|
||||
temperature: Sampling temperature for the model. Higher values increase randomness.
|
||||
samples: Number of samples to generate for averaging.
|
||||
|
||||
Returns:
|
||||
FeatureEmbedding: Complete embedding with samples and metadata.
|
||||
"""
|
||||
embedding_samples: list[dict[str, Any]] = []
|
||||
response_latency: float = 0.0
|
||||
prompt_tokens: int = 0
|
||||
completion_tokens: int = 0
|
||||
|
||||
# TODO: use llm registry
|
||||
llm = LLM(llm_config, service_id='solvability')
|
||||
|
||||
# Generate multiple samples to account for LLM variability
|
||||
for _ in range(samples):
|
||||
start_time = time.time()
|
||||
response = llm.completion(
|
||||
messages=[
|
||||
self.system_message(),
|
||||
self.user_message(issue_description, set_cache=(samples > 1)),
|
||||
],
|
||||
tools=[self.tool_description],
|
||||
tool_choice=self.tool_choice,
|
||||
temperature=temperature,
|
||||
)
|
||||
stop_time = time.time()
|
||||
|
||||
# Extract timing and token usage metrics
|
||||
latency = stop_time - start_time
|
||||
# Parse the structured tool call response containing feature evaluations
|
||||
features = response.choices[0].message.tool_calls[0].function.arguments # type: ignore[index, union-attr]
|
||||
embedding = json.loads(features)
|
||||
|
||||
# Accumulate results and metrics
|
||||
embedding_samples.append(embedding)
|
||||
prompt_tokens += response.usage.prompt_tokens # type: ignore[union-attr, attr-defined]
|
||||
completion_tokens += response.usage.completion_tokens # type: ignore[union-attr, attr-defined]
|
||||
response_latency += latency
|
||||
|
||||
return FeatureEmbedding(
|
||||
samples=embedding_samples,
|
||||
response_latency=response_latency,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
)
|
||||
|
||||
def embed_batch(
|
||||
self,
|
||||
issue_descriptions: list[str],
|
||||
llm_config: LLMConfig,
|
||||
temperature: float = 1.0,
|
||||
samples: int = 10,
|
||||
) -> list[FeatureEmbedding]:
|
||||
"""
|
||||
Generate embeddings for a batch of issue descriptions using concurrent processing.
|
||||
|
||||
Processes multiple issues in parallel to improve throughput while maintaining
|
||||
result ordering.
|
||||
|
||||
Args:
|
||||
issue_descriptions: List of issue descriptions to analyze.
|
||||
llm_config: Configuration for the LLM to use.
|
||||
temperature: Sampling temperature for the model.
|
||||
samples: Number of samples to generate per issue.
|
||||
|
||||
Returns:
|
||||
list[FeatureEmbedding]: List of embeddings in the same order as input.
|
||||
"""
|
||||
with ThreadPoolExecutor() as executor:
|
||||
# Submit all embedding tasks concurrently
|
||||
future_to_desc = {
|
||||
executor.submit(
|
||||
self.embed,
|
||||
desc,
|
||||
llm_config,
|
||||
temperature=temperature,
|
||||
samples=samples,
|
||||
): i
|
||||
for i, desc in enumerate(issue_descriptions)
|
||||
}
|
||||
|
||||
# Collect results in original order to maintain consistency
|
||||
results: list[FeatureEmbedding] = [None] * len(issue_descriptions) # type: ignore[list-item]
|
||||
for future in as_completed(future_to_desc):
|
||||
index = future_to_desc[future]
|
||||
results[index] = future.result()
|
||||
|
||||
return results
|
||||
|
||||
def feature_identifiers(self) -> list[str]:
|
||||
"""
|
||||
Get the identifiers of all configured features.
|
||||
|
||||
Returns:
|
||||
list[str]: List of feature identifiers in the order they were defined.
|
||||
"""
|
||||
return [feature.identifier for feature in self.features]
|
||||
@@ -0,0 +1,23 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class ImportanceStrategy(str, Enum):
|
||||
"""
|
||||
Strategy to use for calculating feature importances, which are used to estimate the predictive power of each feature
|
||||
in training loops and explanations.
|
||||
"""
|
||||
|
||||
SHAP = 'shap'
|
||||
"""
|
||||
Use SHAP (SHapley Additive exPlanations) to calculate feature importances.
|
||||
"""
|
||||
|
||||
PERMUTATION = 'permutation'
|
||||
"""
|
||||
Use the permutation-based feature importances.
|
||||
"""
|
||||
|
||||
IMPURITY = 'impurity'
|
||||
"""
|
||||
Use the impurity-based feature importances from the RandomForestClassifier.
|
||||
"""
|
||||
87
enterprise/integrations/solvability/models/report.py
Normal file
87
enterprise/integrations/solvability/models/report.py
Normal file
@@ -0,0 +1,87 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from integrations.solvability.models.importance_strategy import ImportanceStrategy
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class SolvabilityReport(BaseModel):
|
||||
"""
|
||||
Comprehensive report containing solvability predictions and analysis for a single issue.
|
||||
|
||||
This report includes the solvability score, extracted feature values, feature importance analysis,
|
||||
cost metrics (tokens and latency), and metadata about the prediction process. It serves as the
|
||||
primary output format for solvability analysis and can be used for logging, debugging, and
|
||||
generating human-readable summaries.
|
||||
"""
|
||||
|
||||
identifier: str
|
||||
"""
|
||||
The identifier of the solvability model used to generate the report.
|
||||
"""
|
||||
|
||||
issue: str
|
||||
"""
|
||||
The issue description for which the solvability is predicted.
|
||||
|
||||
This field is exactly the input to the solvability model.
|
||||
"""
|
||||
|
||||
score: float
|
||||
"""
|
||||
[0, 1]-valued score indicating the likelihood of the issue being solvable.
|
||||
"""
|
||||
|
||||
prompt_tokens: int
|
||||
"""
|
||||
Total number of prompt tokens used in API calls made to generate the features.
|
||||
"""
|
||||
|
||||
completion_tokens: int
|
||||
"""
|
||||
Total number of completion tokens used in API calls made to generate the features.
|
||||
"""
|
||||
|
||||
response_latency: float
|
||||
"""
|
||||
Total response latency of API calls made to generate the features.
|
||||
"""
|
||||
|
||||
features: dict[str, float]
|
||||
"""
|
||||
[0, 1]-valued scores for each feature in the model.
|
||||
|
||||
These are the values fed to the random forest classifier to generate the solvability score.
|
||||
"""
|
||||
|
||||
samples: int
|
||||
"""
|
||||
Number of samples used to compute the feature embedding coefficients.
|
||||
"""
|
||||
|
||||
importance_strategy: ImportanceStrategy
|
||||
"""
|
||||
Strategy used to calculate feature importances.
|
||||
"""
|
||||
|
||||
feature_importances: dict[str, float]
|
||||
"""
|
||||
Importance scores for each feature in the model.
|
||||
|
||||
Interpretation of these scores depends on the importance strategy used.
|
||||
"""
|
||||
|
||||
created_at: datetime = Field(default_factory=datetime.now)
|
||||
"""
|
||||
Datetime when the report was created.
|
||||
"""
|
||||
|
||||
random_state: int | None = None
|
||||
"""
|
||||
Classifier random state used when generating this report.
|
||||
"""
|
||||
|
||||
metadata: dict[str, Any] | None = None
|
||||
"""
|
||||
Metadata for logging and debugging purposes.
|
||||
"""
|
||||
172
enterprise/integrations/solvability/models/summary.py
Normal file
172
enterprise/integrations/solvability/models/summary.py
Normal file
@@ -0,0 +1,172 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from integrations.solvability.models.difficulty_level import DifficultyLevel
|
||||
from integrations.solvability.models.report import SolvabilityReport
|
||||
from integrations.solvability.prompts import load_prompt
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from openhands.llm import LLM
|
||||
|
||||
|
||||
class SolvabilitySummary(BaseModel):
|
||||
"""Summary of the solvability analysis in human-readable format."""
|
||||
|
||||
score: float
|
||||
"""
|
||||
Solvability score indicating the likelihood of the issue being solvable.
|
||||
"""
|
||||
|
||||
summary: str
|
||||
"""
|
||||
The executive summary content generated by the LLM.
|
||||
"""
|
||||
|
||||
actionable_feedback: str
|
||||
"""
|
||||
Actionable feedback content generated by the LLM.
|
||||
"""
|
||||
|
||||
positive_feedback: str
|
||||
"""
|
||||
Positive feedback content generated by the LLM, highlighting what is good about the issue.
|
||||
"""
|
||||
|
||||
prompt_tokens: int
|
||||
"""
|
||||
Number of prompt tokens used in the API call to generate the summary.
|
||||
"""
|
||||
|
||||
completion_tokens: int
|
||||
"""
|
||||
Number of completion tokens used in the API call to generate the summary.
|
||||
"""
|
||||
|
||||
response_latency: float
|
||||
"""
|
||||
Response latency of the API call to generate the summary.
|
||||
"""
|
||||
|
||||
created_at: datetime = Field(default_factory=datetime.now)
|
||||
"""
|
||||
Datetime when the summary was created.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def tool_description() -> dict[str, Any]:
|
||||
"""Get the tool description for the LLM."""
|
||||
return {
|
||||
'type': 'function',
|
||||
'function': {
|
||||
'name': 'solvability_summary',
|
||||
'description': 'Generate a human-readable summary of the solvability analysis.',
|
||||
'parameters': {
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'summary': {
|
||||
'type': 'string',
|
||||
'description': 'A high-level (at most two sentences) summary of the solvability report.',
|
||||
},
|
||||
'actionable_feedback': {
|
||||
'type': 'string',
|
||||
'description': (
|
||||
'Bullet list of 1-3 pieces of actionable feedback on how the user can address the lowest scoring relevant features.'
|
||||
),
|
||||
},
|
||||
'positive_feedback': {
|
||||
'type': 'string',
|
||||
'description': (
|
||||
'Bullet list of 1-3 pieces of positive feedback on the issue, highlighting what is good about it.'
|
||||
),
|
||||
},
|
||||
},
|
||||
'required': ['summary', 'actionable_feedback'],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def tool_choice() -> dict[str, Any]:
|
||||
"""Get the tool choice for the LLM."""
|
||||
return {
|
||||
'type': 'function',
|
||||
'function': {
|
||||
'name': 'solvability_summary',
|
||||
},
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def system_message() -> dict[str, Any]:
|
||||
"""Get the system message for the LLM."""
|
||||
return {
|
||||
'role': 'system',
|
||||
'content': load_prompt('summary_system_message'),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def user_message(report: SolvabilityReport) -> dict[str, Any]:
|
||||
"""Get the user message for the LLM."""
|
||||
return {
|
||||
'role': 'user',
|
||||
'content': load_prompt(
|
||||
'summary_user_message',
|
||||
report=report.model_dump(),
|
||||
difficulty_level=DifficultyLevel.from_score(report.score).value[0],
|
||||
),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def from_report(report: SolvabilityReport, llm: LLM) -> SolvabilitySummary:
|
||||
"""Create a SolvabilitySummary from a SolvabilityReport."""
|
||||
import time
|
||||
|
||||
start_time = time.time()
|
||||
response = llm.completion(
|
||||
messages=[
|
||||
SolvabilitySummary.system_message(),
|
||||
SolvabilitySummary.user_message(report),
|
||||
],
|
||||
tools=[SolvabilitySummary.tool_description()],
|
||||
tool_choice=SolvabilitySummary.tool_choice(),
|
||||
)
|
||||
response_latency = time.time() - start_time
|
||||
|
||||
# Grab the arguments from the forced function call
|
||||
arguments = json.loads(
|
||||
response.choices[0].message.tool_calls[0].function.arguments
|
||||
)
|
||||
|
||||
return SolvabilitySummary(
|
||||
# The score is copied directly from the report
|
||||
score=report.score,
|
||||
# Performance and usage metrics are pulled from the response
|
||||
prompt_tokens=response.usage.prompt_tokens,
|
||||
completion_tokens=response.usage.completion_tokens,
|
||||
response_latency=response_latency,
|
||||
# Every other field should be taken from the forced function call
|
||||
**arguments,
|
||||
)
|
||||
|
||||
def format_as_markdown(self) -> str:
|
||||
"""Format the summary content as Markdown."""
|
||||
# Convert score to difficulty level enum
|
||||
difficulty_level = DifficultyLevel.from_score(self.score)
|
||||
|
||||
# Create the main difficulty display
|
||||
result = f'{difficulty_level.format_display()}\n\n{self.summary}'
|
||||
|
||||
# If not easy, show the three features with lowest importance scores
|
||||
if difficulty_level != DifficultyLevel.EASY:
|
||||
# Add dropdown with lowest importance features
|
||||
result += '\n\nYou can make the issue easier to resolve by addressing these concerns in the conversation:\n\n'
|
||||
result += self.actionable_feedback
|
||||
|
||||
# If the difficulty isn't hard, add some positive feedback
|
||||
if difficulty_level != DifficultyLevel.HARD:
|
||||
result += '\n\nPositive feedback:\n\n'
|
||||
result += self.positive_feedback
|
||||
|
||||
return result
|
||||
13
enterprise/integrations/solvability/prompts/__init__.py
Normal file
13
enterprise/integrations/solvability/prompts/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from pathlib import Path
|
||||
|
||||
import jinja2
|
||||
|
||||
|
||||
def load_prompt(prompt: str, **kwargs) -> str:
|
||||
"""Load a prompt by name. Passes all the keyword arguments to the prompt template."""
|
||||
env = jinja2.Environment(loader=jinja2.FileSystemLoader(Path(__file__).parent))
|
||||
template = env.get_template(f'{prompt}.j2')
|
||||
return template.render(**kwargs)
|
||||
|
||||
|
||||
__all__ = ['load_prompt']
|
||||
@@ -0,0 +1,10 @@
|
||||
You are a helpful assistant that generates human-readable summaries of solvability reports.
|
||||
The report predicts how likely it is that the issue can be resolved, and is produced purely based on the information provided in the issue description and comments.
|
||||
The report explains which features are present in the issue and how impactful they are to the solvability score (using SHAP values).
|
||||
Your task is to create a concise, high-level summary of the solvability analysis,
|
||||
with an emphasis on the key factors that make the issue easy or hard to resolve.
|
||||
Focus on the features with extreme scores, BUT ONLY if they are related to the issue at hand after careful consideration.
|
||||
You should NEVER mention: SHAP, scores, feature names, or technical metrics.
|
||||
You will also be given the expected difficulty of the issue, as EASY/MEDIUM/HARD.
|
||||
Be sure to frame your responses with that difficulty in mind.
|
||||
For example, if the issue is HARD you should not describe it as "straightforward".
|
||||
@@ -0,0 +1,9 @@
|
||||
Generate a high-level summary of the solvability report:
|
||||
|
||||
{{ report }}
|
||||
|
||||
We estimate the issue is {{ difficulty_level }}.
|
||||
The summary should be concise (at most two sentences) and describe the primary characteristics of this issue.
|
||||
Focus on what information is present and what factors are most relevant to resolution.
|
||||
Actionable feedback should be something that can be addressed by the user purely by providing more information.
|
||||
Positive feedback should explain the features that are positively contributing to the solvability score.
|
||||
0
enterprise/integrations/solvability/py.typed
Normal file
0
enterprise/integrations/solvability/py.typed
Normal file
73
enterprise/integrations/stripe_service.py
Normal file
73
enterprise/integrations/stripe_service.py
Normal file
@@ -0,0 +1,73 @@
|
||||
import stripe
|
||||
from server.auth.token_manager import TokenManager
|
||||
from server.constants import STRIPE_API_KEY
|
||||
from server.logger import logger
|
||||
from storage.database import session_maker
|
||||
from storage.stripe_customer import StripeCustomer
|
||||
|
||||
stripe.api_key = STRIPE_API_KEY
|
||||
|
||||
|
||||
async def find_customer_id_by_user_id(user_id: str) -> str | None:
|
||||
# First search our own DB...
|
||||
with session_maker() as session:
|
||||
stripe_customer = (
|
||||
session.query(StripeCustomer)
|
||||
.filter(StripeCustomer.keycloak_user_id == user_id)
|
||||
.first()
|
||||
)
|
||||
if stripe_customer:
|
||||
return stripe_customer.stripe_customer_id
|
||||
|
||||
# If that fails, fallback to stripe
|
||||
search_result = await stripe.Customer.search_async(
|
||||
query=f"metadata['user_id']:'{user_id}'",
|
||||
)
|
||||
data = search_result.data
|
||||
if not data:
|
||||
logger.info('no_customer_for_user_id', extra={'user_id': user_id})
|
||||
return None
|
||||
return data[0].id # type: ignore [attr-defined]
|
||||
|
||||
|
||||
async def find_or_create_customer(user_id: str) -> str:
|
||||
customer_id = await find_customer_id_by_user_id(user_id)
|
||||
if customer_id:
|
||||
return customer_id
|
||||
logger.info('creating_customer', extra={'user_id': user_id})
|
||||
|
||||
# Get the user info from keycloak
|
||||
token_manager = TokenManager()
|
||||
user_info = await token_manager.get_user_info_from_user_id(user_id) or {}
|
||||
|
||||
# Create the customer in stripe
|
||||
customer = await stripe.Customer.create_async(
|
||||
email=str(user_info.get('email', '')),
|
||||
metadata={'user_id': user_id},
|
||||
)
|
||||
|
||||
# Save the stripe customer in the local db
|
||||
with session_maker() as session:
|
||||
session.add(
|
||||
StripeCustomer(keycloak_user_id=user_id, stripe_customer_id=customer.id)
|
||||
)
|
||||
session.commit()
|
||||
|
||||
logger.info(
|
||||
'created_customer',
|
||||
extra={'user_id': user_id, 'stripe_customer_id': customer.id},
|
||||
)
|
||||
return customer.id
|
||||
|
||||
|
||||
async def has_payment_method(user_id: str) -> bool:
|
||||
customer_id = await find_customer_id_by_user_id(user_id)
|
||||
if customer_id is None:
|
||||
return False
|
||||
payment_methods = await stripe.Customer.list_payment_methods_async(
|
||||
customer_id,
|
||||
)
|
||||
logger.info(
|
||||
f'has_payment_method:{user_id}:{customer_id}:{bool(payment_methods.data)}'
|
||||
)
|
||||
return bool(payment_methods.data)
|
||||
51
enterprise/integrations/types.py
Normal file
51
enterprise/integrations/types.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
from jinja2 import Environment
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class GitLabResourceType(Enum):
|
||||
GROUP = 'group'
|
||||
SUBGROUP = 'subgroup'
|
||||
PROJECT = 'project'
|
||||
|
||||
|
||||
class PRStatus(Enum):
|
||||
CLOSED = 'CLOSED'
|
||||
MERGED = 'MERGED'
|
||||
|
||||
|
||||
class UserData(BaseModel):
|
||||
user_id: int
|
||||
username: str
|
||||
keycloak_user_id: str | None
|
||||
|
||||
|
||||
@dataclass
|
||||
class SummaryExtractionTracker:
|
||||
conversation_id: str
|
||||
should_extract: bool
|
||||
send_summary_instruction: bool
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResolverViewInterface(SummaryExtractionTracker):
|
||||
installation_id: int
|
||||
user_info: UserData
|
||||
issue_number: int
|
||||
full_repo_name: str
|
||||
is_public_repo: bool
|
||||
raw_payload: dict
|
||||
|
||||
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||
"Instructions passed when conversation is first initialized"
|
||||
raise NotImplementedError()
|
||||
|
||||
async def create_new_conversation(self, jinja_env: Environment, token: str):
|
||||
"Create a new conversation"
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_callback_id(self) -> str:
|
||||
"Unique callback id for subscribription made to EventStream for fetching agent summary"
|
||||
raise NotImplementedError()
|
||||
557
enterprise/integrations/utils.py
Normal file
557
enterprise/integrations/utils.py
Normal file
@@ -0,0 +1,557 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
from server.constants import WEB_HOST
|
||||
from storage.repository_store import RepositoryStore
|
||||
from storage.stored_repository import StoredRepository
|
||||
from storage.user_repo_map import UserRepositoryMap
|
||||
from storage.user_repo_map_store import UserRepositoryMapStore
|
||||
|
||||
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||
from openhands.core.logger import openhands_logger as logger
|
||||
from openhands.core.schema.agent import AgentState
|
||||
from openhands.events import Event, EventSource
|
||||
from openhands.events.action import (
|
||||
AgentFinishAction,
|
||||
MessageAction,
|
||||
)
|
||||
from openhands.events.event_store_abc import EventStoreABC
|
||||
from openhands.events.observation.agent import AgentStateChangedObservation
|
||||
from openhands.integrations.service_types import Repository
|
||||
from openhands.storage.data_models.conversation_status import ConversationStatus
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from openhands.server.conversation_manager.conversation_manager import (
|
||||
ConversationManager,
|
||||
)
|
||||
|
||||
# ---- DO NOT REMOVE ----
|
||||
# WARNING: Langfuse depends on the WEB_HOST environment variable being set to track events.
|
||||
HOST = WEB_HOST
|
||||
# ---- DO NOT REMOVE ----
|
||||
|
||||
HOST_URL = f'https://{HOST}'
|
||||
GITHUB_WEBHOOK_URL = f'{HOST_URL}/integration/github/events'
|
||||
GITLAB_WEBHOOK_URL = f'{HOST_URL}/integration/gitlab/events'
|
||||
conversation_prefix = 'conversations/{}'
|
||||
CONVERSATION_URL = f'{HOST_URL}/{conversation_prefix}'
|
||||
|
||||
# Toggle for auto-response feature that proactively starts conversations with users when workflow tests fail
|
||||
ENABLE_PROACTIVE_CONVERSATION_STARTERS = (
|
||||
os.getenv('ENABLE_PROACTIVE_CONVERSATION_STARTERS', 'false').lower() == 'true'
|
||||
)
|
||||
|
||||
# Toggle for solvability report feature
|
||||
ENABLE_SOLVABILITY_ANALYSIS = (
|
||||
os.getenv('ENABLE_SOLVABILITY_ANALYSIS', 'false').lower() == 'true'
|
||||
)
|
||||
|
||||
|
||||
OPENHANDS_RESOLVER_TEMPLATES_DIR = 'openhands/integrations/templates/resolver/'
|
||||
jinja_env = Environment(loader=FileSystemLoader(OPENHANDS_RESOLVER_TEMPLATES_DIR))
|
||||
|
||||
|
||||
def get_oh_labels(web_host: str) -> tuple[str, str]:
|
||||
"""Get the OpenHands labels based on the web host.
|
||||
|
||||
Args:
|
||||
web_host: The web host string to check
|
||||
|
||||
Returns:
|
||||
A tuple of (oh_label, inline_oh_label) where:
|
||||
- oh_label is 'openhands-exp' for staging/local hosts, 'openhands' otherwise
|
||||
- inline_oh_label is '@openhands-exp' for staging/local hosts, '@openhands' otherwise
|
||||
"""
|
||||
web_host = web_host.strip()
|
||||
is_staging_or_local = 'staging' in web_host or 'local' in web_host
|
||||
oh_label = 'openhands-exp' if is_staging_or_local else 'openhands'
|
||||
inline_oh_label = '@openhands-exp' if is_staging_or_local else '@openhands'
|
||||
return oh_label, inline_oh_label
|
||||
|
||||
|
||||
def get_summary_instruction():
|
||||
summary_instruction_template = jinja_env.get_template('summary_prompt.j2')
|
||||
summary_instruction = summary_instruction_template.render()
|
||||
return summary_instruction
|
||||
|
||||
|
||||
def has_exact_mention(text: str, mention: str) -> bool:
|
||||
"""Check if the text contains an exact mention (not part of a larger word).
|
||||
|
||||
Args:
|
||||
text: The text to check for mentions
|
||||
mention: The mention to look for (e.g. "@openhands")
|
||||
|
||||
Returns:
|
||||
bool: True if the exact mention is found, False otherwise
|
||||
|
||||
Example:
|
||||
>>> has_exact_mention("Hello @openhands!", "@openhands") # True
|
||||
>>> has_exact_mention("Hello @openhands-agent!", "@openhands") # False
|
||||
>>> has_exact_mention("(@openhands)", "@openhands") # True
|
||||
>>> has_exact_mention("user@openhands.com", "@openhands") # False
|
||||
>>> has_exact_mention("Hello @OpenHands!", "@openhands") # True (case-insensitive)
|
||||
"""
|
||||
# Convert both text and mention to lowercase for case-insensitive matching
|
||||
text_lower = text.lower()
|
||||
mention_lower = mention.lower()
|
||||
|
||||
pattern = re.escape(mention_lower)
|
||||
# Match mention that is not part of a larger word
|
||||
return bool(re.search(rf'(?:^|[^\w@]){pattern}(?![\w-])', text_lower))
|
||||
|
||||
|
||||
def confirm_event_type(event: Event):
|
||||
return isinstance(event, AgentStateChangedObservation) and not (
|
||||
event.agent_state == AgentState.REJECTED
|
||||
or event.agent_state == AgentState.USER_CONFIRMED
|
||||
or event.agent_state == AgentState.USER_REJECTED
|
||||
or event.agent_state == AgentState.LOADING
|
||||
or event.agent_state == AgentState.RUNNING
|
||||
)
|
||||
|
||||
|
||||
def get_readable_error_reason(reason: str):
|
||||
if reason == 'STATUS$ERROR_LLM_AUTHENTICATION':
|
||||
reason = 'Authentication with the LLM provider failed. Please check your API key or credentials'
|
||||
elif reason == 'STATUS$ERROR_LLM_SERVICE_UNAVAILABLE':
|
||||
reason = 'The LLM service is temporarily unavailable. Please try again later'
|
||||
elif reason == 'STATUS$ERROR_LLM_INTERNAL_SERVER_ERROR':
|
||||
reason = 'The LLM provider encountered an internal error. Please try again soon'
|
||||
elif reason == 'STATUS$ERROR_LLM_OUT_OF_CREDITS':
|
||||
reason = "You've run out of credits. Please top up to continue"
|
||||
elif reason == 'STATUS$ERROR_LLM_CONTENT_POLICY_VIOLATION':
|
||||
reason = 'Content policy violation. The output was blocked by content filtering policy'
|
||||
return reason
|
||||
|
||||
|
||||
def get_summary_for_agent_state(
|
||||
observations: list[AgentStateChangedObservation], conversation_link: str
|
||||
) -> str:
|
||||
unknown_error_msg = f'OpenHands encountered an unknown error. [See the conversation]({conversation_link}) for more information, or try again'
|
||||
|
||||
if len(observations) == 0:
|
||||
logger.error(
|
||||
'Unknown error: No agent state observations found',
|
||||
extra={'conversation_link': conversation_link},
|
||||
)
|
||||
return unknown_error_msg
|
||||
|
||||
observation: AgentStateChangedObservation = observations[0]
|
||||
state = observation.agent_state
|
||||
|
||||
if state == AgentState.RATE_LIMITED:
|
||||
logger.warning(
|
||||
'Agent was rate limited',
|
||||
extra={
|
||||
'agent_state': state.value,
|
||||
'conversation_link': conversation_link,
|
||||
'observation_reason': getattr(observation, 'reason', None),
|
||||
},
|
||||
)
|
||||
return 'OpenHands was rate limited by the LLM provider. Please try again later.'
|
||||
|
||||
if state == AgentState.ERROR:
|
||||
reason = observation.reason
|
||||
reason = get_readable_error_reason(reason)
|
||||
|
||||
logger.error(
|
||||
'Agent encountered an error',
|
||||
extra={
|
||||
'agent_state': state.value,
|
||||
'conversation_link': conversation_link,
|
||||
'observation_reason': observation.reason,
|
||||
'readable_reason': reason,
|
||||
},
|
||||
)
|
||||
|
||||
return f'OpenHands encountered an error: **{reason}**.\n\n[See the conversation]({conversation_link}) for more information.'
|
||||
|
||||
if state == AgentState.AWAITING_USER_INPUT:
|
||||
logger.info(
|
||||
'Agent is awaiting user input',
|
||||
extra={
|
||||
'agent_state': state.value,
|
||||
'conversation_link': conversation_link,
|
||||
'observation_reason': getattr(observation, 'reason', None),
|
||||
},
|
||||
)
|
||||
return f'OpenHands is waiting for your input. [Continue the conversation]({conversation_link}) to provide additional instructions.'
|
||||
|
||||
# Log unknown agent state as error
|
||||
logger.error(
|
||||
'Unknown error: Unhandled agent state',
|
||||
extra={
|
||||
'agent_state': state.value if hasattr(state, 'value') else str(state),
|
||||
'conversation_link': conversation_link,
|
||||
'observation_reason': getattr(observation, 'reason', None),
|
||||
},
|
||||
)
|
||||
return unknown_error_msg
|
||||
|
||||
|
||||
def get_final_agent_observation(
|
||||
event_store: EventStoreABC,
|
||||
) -> list[AgentStateChangedObservation]:
|
||||
return event_store.get_matching_events(
|
||||
source=EventSource.ENVIRONMENT,
|
||||
event_types=(AgentStateChangedObservation,),
|
||||
limit=1,
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
|
||||
def get_last_user_msg(event_store: EventStoreABC) -> list[MessageAction]:
|
||||
return event_store.get_matching_events(
|
||||
source=EventSource.USER, event_types=(MessageAction,), limit=1, reverse='true'
|
||||
)
|
||||
|
||||
|
||||
def extract_summary_from_event_store(
|
||||
event_store: EventStoreABC, conversation_id: str
|
||||
) -> str:
|
||||
"""
|
||||
Get agent summary or alternative message depending on current AgentState
|
||||
"""
|
||||
conversation_link = CONVERSATION_URL.format(conversation_id)
|
||||
summary_instruction = get_summary_instruction()
|
||||
|
||||
instruction_event: list[MessageAction] = event_store.get_matching_events(
|
||||
query=json.dumps(summary_instruction),
|
||||
source=EventSource.USER,
|
||||
event_types=(MessageAction,),
|
||||
limit=1,
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
final_agent_observation = get_final_agent_observation(event_store)
|
||||
|
||||
# Find summary instruction event ID
|
||||
if len(instruction_event) == 0:
|
||||
logger.warning(
|
||||
'no_instruction_event_found', extra={'conversation_id': conversation_id}
|
||||
)
|
||||
return get_summary_for_agent_state(
|
||||
final_agent_observation, conversation_link
|
||||
) # Agent did not receive summary instruction
|
||||
|
||||
event_id: int = instruction_event[0].id
|
||||
|
||||
agent_messages: list[MessageAction | AgentFinishAction] = (
|
||||
event_store.get_matching_events(
|
||||
start_id=event_id,
|
||||
source=EventSource.AGENT,
|
||||
event_types=(MessageAction, AgentFinishAction),
|
||||
reverse=True,
|
||||
limit=1,
|
||||
)
|
||||
)
|
||||
|
||||
if len(agent_messages) == 0:
|
||||
logger.warning(
|
||||
'no_agent_messages_found', extra={'conversation_id': conversation_id}
|
||||
)
|
||||
return get_summary_for_agent_state(
|
||||
final_agent_observation, conversation_link
|
||||
) # Agent failed to generate summary
|
||||
|
||||
summary_event: MessageAction | AgentFinishAction = agent_messages[0]
|
||||
if isinstance(summary_event, MessageAction):
|
||||
return summary_event.content
|
||||
|
||||
return summary_event.final_thought
|
||||
|
||||
|
||||
async def get_event_store_from_conversation_manager(
|
||||
conversation_manager: ConversationManager, conversation_id: str
|
||||
) -> EventStoreABC:
|
||||
agent_loop_infos = await conversation_manager.get_agent_loop_info(
|
||||
filter_to_sids={conversation_id}
|
||||
)
|
||||
if not agent_loop_infos or agent_loop_infos[0].status != ConversationStatus.RUNNING:
|
||||
raise RuntimeError(f'conversation_not_running:{conversation_id}')
|
||||
event_store = agent_loop_infos[0].event_store
|
||||
if not event_store:
|
||||
raise RuntimeError(f'event_store_missing:{conversation_id}')
|
||||
return event_store
|
||||
|
||||
|
||||
async def get_last_user_msg_from_conversation_manager(
|
||||
conversation_manager: ConversationManager, conversation_id: str
|
||||
):
|
||||
event_store = await get_event_store_from_conversation_manager(
|
||||
conversation_manager, conversation_id
|
||||
)
|
||||
return get_last_user_msg(event_store)
|
||||
|
||||
|
||||
async def extract_summary_from_conversation_manager(
|
||||
conversation_manager: ConversationManager, conversation_id: str
|
||||
) -> str:
|
||||
"""
|
||||
Get agent summary or alternative message depending on current AgentState
|
||||
"""
|
||||
|
||||
event_store = await get_event_store_from_conversation_manager(
|
||||
conversation_manager, conversation_id
|
||||
)
|
||||
summary = extract_summary_from_event_store(event_store, conversation_id)
|
||||
return append_conversation_footer(summary, conversation_id)
|
||||
|
||||
|
||||
def append_conversation_footer(message: str, conversation_id: str) -> str:
|
||||
"""
|
||||
Append a small footer with the conversation URL to a message.
|
||||
|
||||
Args:
|
||||
message: The original message content
|
||||
conversation_id: The conversation ID to link to
|
||||
|
||||
Returns:
|
||||
The message with the conversation footer appended
|
||||
"""
|
||||
conversation_link = CONVERSATION_URL.format(conversation_id)
|
||||
footer = f'\n\n<sub>[View full conversation]({conversation_link})</sub>'
|
||||
return message + footer
|
||||
|
||||
|
||||
async def store_repositories_in_db(repos: list[Repository], user_id: str) -> None:
|
||||
"""
|
||||
Store repositories in DB and create user-repository mappings
|
||||
|
||||
Args:
|
||||
repos: List of Repository objects to store
|
||||
user_id: User ID associated with these repositories
|
||||
"""
|
||||
|
||||
# Convert Repository objects to StoredRepository objects
|
||||
# Convert Repository objects to UserRepositoryMap objects
|
||||
stored_repos = []
|
||||
user_repos = []
|
||||
for repo in repos:
|
||||
repo_id = f'{repo.git_provider.value}##{str(repo.id)}'
|
||||
stored_repo = StoredRepository(
|
||||
repo_name=repo.full_name,
|
||||
repo_id=repo_id,
|
||||
is_public=repo.is_public,
|
||||
# Optional fields set to None by default
|
||||
has_microagent=None,
|
||||
has_setup_script=None,
|
||||
)
|
||||
stored_repos.append(stored_repo)
|
||||
user_repo_map = UserRepositoryMap(user_id=user_id, repo_id=repo_id, admin=None)
|
||||
|
||||
user_repos.append(user_repo_map)
|
||||
|
||||
# Get config instance
|
||||
config = OpenHandsConfig()
|
||||
|
||||
try:
|
||||
# Store repositories in the repos table
|
||||
repo_store = RepositoryStore.get_instance(config)
|
||||
repo_store.store_projects(stored_repos)
|
||||
|
||||
# Store user-repository mappings in the user-repos table
|
||||
user_repo_store = UserRepositoryMapStore.get_instance(config)
|
||||
user_repo_store.store_user_repo_mappings(user_repos)
|
||||
|
||||
logger.info(f'Saved repos for user {user_id}')
|
||||
except Exception:
|
||||
logger.warning('Failed to save repos', exc_info=True)
|
||||
|
||||
|
||||
def infer_repo_from_message(user_msg: str) -> list[str]:
|
||||
"""
|
||||
Extract all repository names in the format 'owner/repo' from various Git provider URLs
|
||||
and direct mentions in text. Supports GitHub, GitLab, and BitBucket.
|
||||
Args:
|
||||
user_msg: Input message that may contain repository references
|
||||
Returns:
|
||||
List of repository names in 'owner/repo' format, empty list if none found
|
||||
"""
|
||||
# Normalize the message by removing extra whitespace and newlines
|
||||
normalized_msg = re.sub(r'\s+', ' ', user_msg.strip())
|
||||
|
||||
# Pattern to match Git URLs from GitHub, GitLab, and BitBucket
|
||||
# Captures: protocol, domain, owner, repo (with optional .git extension)
|
||||
git_url_pattern = r'https?://(?:github\.com|gitlab\.com|bitbucket\.org)/([a-zA-Z0-9_.-]+)/([a-zA-Z0-9_.-]+?)(?:\.git)?(?:[/?#].*?)?(?=\s|$|[^\w.-])'
|
||||
|
||||
# Pattern to match direct owner/repo mentions (e.g., "All-Hands-AI/OpenHands")
|
||||
# Must be surrounded by word boundaries or specific characters to avoid false positives
|
||||
direct_pattern = (
|
||||
r'(?:^|\s|[\[\(\'"])([a-zA-Z0-9_.-]+)/([a-zA-Z0-9_.-]+)(?=\s|$|[\]\)\'",.])'
|
||||
)
|
||||
|
||||
matches = []
|
||||
|
||||
# First, find all Git URLs (highest priority)
|
||||
git_matches = re.findall(git_url_pattern, normalized_msg)
|
||||
for owner, repo in git_matches:
|
||||
# Remove .git extension if present
|
||||
repo = re.sub(r'\.git$', '', repo)
|
||||
matches.append(f'{owner}/{repo}')
|
||||
|
||||
# Second, find all direct owner/repo mentions
|
||||
direct_matches = re.findall(direct_pattern, normalized_msg)
|
||||
for owner, repo in direct_matches:
|
||||
full_match = f'{owner}/{repo}'
|
||||
|
||||
# Skip if it looks like a version number, date, or file path
|
||||
if (
|
||||
re.match(r'^\d+\.\d+/\d+\.\d+$', full_match) # version numbers
|
||||
or re.match(r'^\d{1,2}/\d{1,2}$', full_match) # dates
|
||||
or re.match(r'^[A-Z]/[A-Z]$', full_match) # single letters
|
||||
or repo.endswith('.txt')
|
||||
or repo.endswith('.md') # file extensions
|
||||
or repo.endswith('.py')
|
||||
or repo.endswith('.js')
|
||||
or '.' in repo
|
||||
and len(repo.split('.')) > 2
|
||||
): # complex file paths
|
||||
continue
|
||||
|
||||
# Avoid duplicates from Git URLs already found
|
||||
if full_match not in matches:
|
||||
matches.append(full_match)
|
||||
|
||||
return matches
|
||||
|
||||
|
||||
def filter_potential_repos_by_user_msg(
|
||||
user_msg: str, user_repos: list[Repository]
|
||||
) -> tuple[bool, list[Repository]]:
|
||||
"""Filter repositories based on user message inference."""
|
||||
inferred_repos = infer_repo_from_message(user_msg)
|
||||
if not inferred_repos:
|
||||
return False, user_repos[0:99]
|
||||
|
||||
final_repos = []
|
||||
for repo in user_repos:
|
||||
# Check if the repo matches any of the inferred repositories
|
||||
for inferred_repo in inferred_repos:
|
||||
if inferred_repo.lower() in repo.full_name.lower():
|
||||
final_repos.append(repo)
|
||||
break # Avoid adding the same repo multiple times
|
||||
|
||||
# no repos matched, return original list
|
||||
if len(final_repos) == 0:
|
||||
return False, user_repos[0:99]
|
||||
|
||||
# Found exact match
|
||||
elif len(final_repos) == 1:
|
||||
return True, final_repos
|
||||
|
||||
# Found partial matches
|
||||
return False, final_repos[0:99]
|
||||
|
||||
|
||||
def markdown_to_jira_markup(markdown_text: str) -> str:
|
||||
"""
|
||||
Convert markdown text to Jira Wiki Markup format.
|
||||
This function handles common markdown elements and converts them to their
|
||||
Jira Wiki Markup equivalents. It's designed to be exception-safe.
|
||||
Args:
|
||||
markdown_text: The markdown text to convert
|
||||
Returns:
|
||||
str: The converted Jira Wiki Markup text
|
||||
"""
|
||||
if not markdown_text or not isinstance(markdown_text, str):
|
||||
return ''
|
||||
|
||||
try:
|
||||
# Work with a copy to avoid modifying the original
|
||||
text = markdown_text
|
||||
|
||||
# Convert headers (# ## ### #### ##### ######)
|
||||
text = re.sub(r'^#{6}\s+(.*?)$', r'h6. \1', text, flags=re.MULTILINE)
|
||||
text = re.sub(r'^#{5}\s+(.*?)$', r'h5. \1', text, flags=re.MULTILINE)
|
||||
text = re.sub(r'^#{4}\s+(.*?)$', r'h4. \1', text, flags=re.MULTILINE)
|
||||
text = re.sub(r'^#{3}\s+(.*?)$', r'h3. \1', text, flags=re.MULTILINE)
|
||||
text = re.sub(r'^#{2}\s+(.*?)$', r'h2. \1', text, flags=re.MULTILINE)
|
||||
text = re.sub(r'^#{1}\s+(.*?)$', r'h1. \1', text, flags=re.MULTILINE)
|
||||
|
||||
# Convert code blocks first (before other formatting)
|
||||
text = re.sub(
|
||||
r'```(\w+)\n(.*?)\n```', r'{code:\1}\n\2\n{code}', text, flags=re.DOTALL
|
||||
)
|
||||
text = re.sub(r'```\n(.*?)\n```', r'{code}\n\1\n{code}', text, flags=re.DOTALL)
|
||||
|
||||
# Convert inline code (`code`)
|
||||
text = re.sub(r'`([^`]+)`', r'{{\1}}', text)
|
||||
|
||||
# Convert markdown formatting to Jira formatting
|
||||
# Use temporary placeholders to avoid conflicts between bold and italic conversion
|
||||
|
||||
# First convert bold (double markers) to temporary placeholders
|
||||
text = re.sub(r'\*\*(.*?)\*\*', r'JIRA_BOLD_START\1JIRA_BOLD_END', text)
|
||||
text = re.sub(r'__(.*?)__', r'JIRA_BOLD_START\1JIRA_BOLD_END', text)
|
||||
|
||||
# Now convert single asterisk italics
|
||||
text = re.sub(r'\*([^*]+?)\*', r'_\1_', text)
|
||||
|
||||
# Convert underscore italics
|
||||
text = re.sub(r'(?<!_)_([^_]+?)_(?!_)', r'_\1_', text)
|
||||
|
||||
# Finally, restore bold markers
|
||||
text = text.replace('JIRA_BOLD_START', '*')
|
||||
text = text.replace('JIRA_BOLD_END', '*')
|
||||
|
||||
# Convert links [text](url)
|
||||
text = re.sub(r'\[([^\]]+)\]\(([^)]+)\)', r'[\1|\2]', text)
|
||||
|
||||
# Convert unordered lists (- or * or +)
|
||||
text = re.sub(r'^[\s]*[-*+]\s+(.*?)$', r'* \1', text, flags=re.MULTILINE)
|
||||
|
||||
# Convert ordered lists (1. 2. etc.)
|
||||
text = re.sub(r'^[\s]*\d+\.\s+(.*?)$', r'# \1', text, flags=re.MULTILINE)
|
||||
|
||||
# Convert strikethrough (~~text~~)
|
||||
text = re.sub(r'~~(.*?)~~', r'-\1-', text)
|
||||
|
||||
# Convert horizontal rules (---, ***, ___)
|
||||
text = re.sub(r'^[\s]*[-*_]{3,}[\s]*$', r'----', text, flags=re.MULTILINE)
|
||||
|
||||
# Convert blockquotes (> text)
|
||||
text = re.sub(r'^>\s+(.*?)$', r'bq. \1', text, flags=re.MULTILINE)
|
||||
|
||||
# Convert tables (basic support)
|
||||
# This is a simplified table conversion - Jira tables are quite different
|
||||
lines = text.split('\n')
|
||||
in_table = False
|
||||
converted_lines = []
|
||||
|
||||
for line in lines:
|
||||
if (
|
||||
'|' in line
|
||||
and line.strip().startswith('|')
|
||||
and line.strip().endswith('|')
|
||||
):
|
||||
# Skip markdown table separator lines (contain ---)
|
||||
if '---' in line:
|
||||
continue
|
||||
if not in_table:
|
||||
in_table = True
|
||||
# Convert markdown table row to Jira table row
|
||||
cells = [cell.strip() for cell in line.split('|')[1:-1]]
|
||||
converted_line = '|' + '|'.join(cells) + '|'
|
||||
converted_lines.append(converted_line)
|
||||
elif in_table and line.strip() and '|' not in line:
|
||||
in_table = False
|
||||
converted_lines.append(line)
|
||||
else:
|
||||
in_table = False
|
||||
converted_lines.append(line)
|
||||
|
||||
text = '\n'.join(converted_lines)
|
||||
|
||||
return text
|
||||
|
||||
except Exception as e:
|
||||
# Log the error but don't raise it - return original text as fallback
|
||||
print(f'Error converting markdown to Jira markup: {str(e)}')
|
||||
return markdown_text or ''
|
||||
114
enterprise/migrations/env.py
Normal file
114
enterprise/migrations/env.py
Normal file
@@ -0,0 +1,114 @@
|
||||
import os
|
||||
from logging.config import fileConfig
|
||||
|
||||
from alembic import context
|
||||
from google.cloud.sql.connector import Connector
|
||||
from sqlalchemy import create_engine
|
||||
from storage.base import Base
|
||||
|
||||
target_metadata = Base.metadata
|
||||
|
||||
DB_USER = os.getenv('DB_USER', 'postgres')
|
||||
DB_PASS = os.getenv('DB_PASS', 'postgres')
|
||||
DB_HOST = os.getenv('DB_HOST', 'localhost')
|
||||
DB_PORT = os.getenv('DB_PORT', '5432')
|
||||
DB_NAME = os.getenv('DB_NAME', 'openhands')
|
||||
|
||||
GCP_DB_INSTANCE = os.getenv('GCP_DB_INSTANCE')
|
||||
GCP_PROJECT = os.getenv('GCP_PROJECT')
|
||||
GCP_REGION = os.getenv('GCP_REGION')
|
||||
|
||||
POOL_SIZE = int(os.getenv('DB_POOL_SIZE', '25'))
|
||||
MAX_OVERFLOW = int(os.getenv('DB_MAX_OVERFLOW', '10'))
|
||||
|
||||
|
||||
def get_engine(database_name=DB_NAME):
|
||||
"""Create SQLAlchemy engine with optional database name."""
|
||||
if GCP_DB_INSTANCE:
|
||||
|
||||
def get_db_connection():
|
||||
connector = Connector()
|
||||
instance_string = f'{GCP_PROJECT}:{GCP_REGION}:{GCP_DB_INSTANCE}'
|
||||
return connector.connect(
|
||||
instance_string,
|
||||
'pg8000',
|
||||
user=DB_USER,
|
||||
password=DB_PASS.strip(),
|
||||
db=database_name,
|
||||
)
|
||||
|
||||
return create_engine(
|
||||
'postgresql+pg8000://',
|
||||
creator=get_db_connection,
|
||||
pool_size=POOL_SIZE,
|
||||
max_overflow=MAX_OVERFLOW,
|
||||
pool_pre_ping=True,
|
||||
)
|
||||
else:
|
||||
url = f'postgresql://{DB_USER}:{DB_PASS}@{DB_HOST}:{DB_PORT}/{database_name}'
|
||||
return create_engine(
|
||||
url,
|
||||
pool_size=POOL_SIZE,
|
||||
max_overflow=MAX_OVERFLOW,
|
||||
pool_pre_ping=True,
|
||||
)
|
||||
|
||||
|
||||
engine = get_engine()
|
||||
|
||||
# this is the Alembic Config object, which provides
|
||||
# access to the values within the .ini file in use.
|
||||
config = context.config
|
||||
|
||||
# Interpret the config file for Python logging.
|
||||
# This line sets up loggers basically.
|
||||
if config.config_file_name is not None:
|
||||
fileConfig(config.config_file_name)
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
|
||||
"""Run migrations in 'offline' mode.
|
||||
|
||||
This configures the context with just a URL
|
||||
and not an Engine, though an Engine is acceptable
|
||||
here as well. By skipping the Engine creation
|
||||
we don't even need a DBAPI to be available.
|
||||
|
||||
Calls to context.execute() here emit the given string to the
|
||||
script output.
|
||||
"""
|
||||
url = config.get_main_option('sqlalchemy.url')
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata,
|
||||
literal_binds=True,
|
||||
dialect_opts={'paramstyle': 'named'},
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
def run_migrations_online() -> None:
|
||||
"""Run migrations in 'online' mode.
|
||||
|
||||
In this scenario we need to create an Engine
|
||||
and associate a connection with the context.
|
||||
"""
|
||||
connectable = engine
|
||||
|
||||
with connectable.connect() as connection:
|
||||
context.configure(
|
||||
connection=connection,
|
||||
target_metadata=target_metadata,
|
||||
version_table_schema=target_metadata.schema,
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
if context.is_offline_mode():
|
||||
run_migrations_offline()
|
||||
else:
|
||||
run_migrations_online()
|
||||
26
enterprise/migrations/script.py.mako
Normal file
26
enterprise/migrations/script.py.mako
Normal file
@@ -0,0 +1,26 @@
|
||||
"""${message}
|
||||
|
||||
Revision ID: ${up_revision}
|
||||
Revises: ${down_revision | comma,n}
|
||||
Create Date: ${create_date}
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
${imports if imports else ""}
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = ${repr(up_revision)}
|
||||
down_revision: Union[str, None] = ${repr(down_revision)}
|
||||
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
|
||||
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
${upgrades if upgrades else "pass"}
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
${downgrades if downgrades else "pass"}
|
||||
45
enterprise/migrations/versions/001_create_feedback_table.py
Normal file
45
enterprise/migrations/versions/001_create_feedback_table.py
Normal file
@@ -0,0 +1,45 @@
|
||||
"""Create feedback table
|
||||
|
||||
Revision ID: 001
|
||||
Revises:
|
||||
Create Date: 2024-03-19 10:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '001'
|
||||
down_revision: Union[str, None] = None
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
'feedback',
|
||||
sa.Column('id', sa.String(), nullable=False),
|
||||
sa.Column('version', sa.String(), nullable=False),
|
||||
sa.Column('email', sa.String(), nullable=False),
|
||||
sa.Column(
|
||||
'polarity',
|
||||
sa.Enum('positive', 'negative', name='polarity_enum'),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
'permissions',
|
||||
sa.Enum('public', 'private', name='permissions_enum'),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column('trajectory', sa.JSON(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table('feedback')
|
||||
op.execute('DROP TYPE polarity_enum')
|
||||
op.execute('DROP TYPE permissions_enum')
|
||||
@@ -0,0 +1,45 @@
|
||||
"""create saas settings table
|
||||
|
||||
Revision ID: 002
|
||||
Revises: 001
|
||||
Create Date: 2025-01-27 20:08:58.360566
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '002'
|
||||
down_revision: Union[str, None] = '001'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# This was created to match the settings object - in future some of these strings should probabyl
|
||||
# be replaced with enum types.
|
||||
op.create_table(
|
||||
'settings',
|
||||
sa.Column('id', sa.String(), nullable=False, primary_key=True),
|
||||
sa.Column('language', sa.String(), nullable=True),
|
||||
sa.Column('agent', sa.String(), nullable=True),
|
||||
sa.Column('max_iterations', sa.Integer(), nullable=True),
|
||||
sa.Column('security_analyzer', sa.String(), nullable=True),
|
||||
sa.Column('confirmation_mode', sa.Boolean(), nullable=True, default=False),
|
||||
sa.Column('llm_model', sa.String(), nullable=True),
|
||||
sa.Column('llm_api_key', sa.String(), nullable=True),
|
||||
sa.Column('llm_base_url', sa.String(), nullable=True),
|
||||
sa.Column('remote_runtime_resource_factor', sa.Integer(), nullable=True),
|
||||
sa.Column('github_token', sa.String(), nullable=True),
|
||||
sa.Column(
|
||||
'enable_default_condenser', sa.Boolean(), nullable=False, default=False
|
||||
),
|
||||
sa.Column('user_consents_to_analytics', sa.Boolean(), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table('settings')
|
||||
@@ -0,0 +1,35 @@
|
||||
"""create saas conversations table
|
||||
|
||||
Revision ID: 003
|
||||
Revises: 002
|
||||
Create Date: 2025-01-29 09:36:49.475467
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '003'
|
||||
down_revision: Union[str, None] = '002'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
'conversation_metadata',
|
||||
sa.Column('conversation_id', sa.String(), nullable=False),
|
||||
sa.Column('user_id', sa.String(), nullable=False, index=True),
|
||||
sa.Column('selected_repository', sa.String(), nullable=True),
|
||||
sa.Column('title', sa.String(), nullable=True),
|
||||
sa.Column('last_updated_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=False, index=True),
|
||||
sa.PrimaryKeyConstraint('conversation_id'),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table('conversation_metadata')
|
||||
@@ -0,0 +1,47 @@
|
||||
"""create saas conversations table
|
||||
|
||||
Revision ID: 004
|
||||
Revises: 003
|
||||
Create Date: 2025-01-29 09:36:49.475467
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '004'
|
||||
down_revision: Union[str, None] = '003'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
'billing_sessions',
|
||||
sa.Column('id', sa.String(), nullable=False, primary_key=True),
|
||||
sa.Column('user_id', sa.String(), nullable=False),
|
||||
sa.Column(
|
||||
'status',
|
||||
sa.Enum(
|
||||
'in_progress',
|
||||
'completed',
|
||||
'cancelled',
|
||||
'error',
|
||||
name='billing_session_status_enum',
|
||||
),
|
||||
nullable=False,
|
||||
default='in_progress',
|
||||
),
|
||||
sa.Column('price', sa.DECIMAL(19, 4), nullable=False),
|
||||
sa.Column('price_code', sa.String(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=False),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table('billing_sessions')
|
||||
op.execute('DROP TYPE billing_session_status_enum')
|
||||
26
enterprise/migrations/versions/005_add_margin_column.py
Normal file
26
enterprise/migrations/versions/005_add_margin_column.py
Normal file
@@ -0,0 +1,26 @@
|
||||
"""add margin column
|
||||
|
||||
Revision ID: 005
|
||||
Revises: 004
|
||||
Create Date: 2025-02-10 08:36:49.475467
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '005'
|
||||
down_revision: Union[str, None] = '004'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column('settings', sa.Column('margin', sa.Float(), nullable=True))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column('settings', 'margin')
|
||||
@@ -0,0 +1,29 @@
|
||||
"""add branch column to convo metadata table
|
||||
|
||||
Revision ID: 006
|
||||
Revises: 005
|
||||
Create Date: 2025-02-11 14:59:09.415
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '006'
|
||||
down_revision: Union[str, None] = '005'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
'conversation_metadata',
|
||||
sa.Column('selected_branch', sa.String(), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column('conversation_metadata', 'selected_branch')
|
||||
@@ -0,0 +1,31 @@
|
||||
"""add enable_sound_notifications column to settings table
|
||||
|
||||
Revision ID: 007
|
||||
Revises: 006
|
||||
Create Date: 2025-05-01 10:00:00.000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '007'
|
||||
down_revision: Union[str, None] = '006'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
'settings',
|
||||
sa.Column(
|
||||
'enable_sound_notifications', sa.Boolean(), nullable=True, default=False
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column('settings', 'enable_sound_notifications')
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user