mirror of https://github.com/All-Hands-AI/OpenHands.git
synced 2026-01-09 14:57:59 -05:00

Enterprise code and docker build (#10770)
.github/workflows/ghcr-build.yml (vendored, 95 changed lines)

@@ -10,14 +10,14 @@ on:
     branches:
       - main
     tags:
-      - '*'
+      - "*"
   pull_request:
   workflow_dispatch:
     inputs:
       reason:
-        description: 'Reason for manual trigger'
+        description: "Reason for manual trigger"
         required: true
-        default: ''
+        default: ""
 
 # If triggered by a PR, it will be in the same group. However, each commit on main will be in its own unique group
 concurrency:
@@ -120,7 +120,7 @@ jobs:
       - name: Set up Python
         uses: useblacksmith/setup-python@v6
         with:
-          python-version: '3.12'
+          python-version: "3.12"
           cache: poetry
       - name: Install Python dependencies using Poetry
         run: make install-python-dependencies POETRY_GROUP=main INSTALL_PLAYWRIGHT=0
@@ -166,6 +166,89 @@ jobs:
           name: runtime-src-${{ matrix.base_image.tag }}
           path: containers/runtime
 
+  ghcr_build_enterprise:
+    name: Push Enterprise Image
+    runs-on: blacksmith-8vcpu-ubuntu-2204
+    permissions:
+      contents: read
+      packages: write
+    needs: [define-matrix, ghcr_build_app]
+    # Do not build enterprise in forks
+    if: github.event.pull_request.head.repo.fork != true
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@v4
+
+      # Set up Docker Buildx for better performance
+      - name: Set up Docker Buildx
+        uses: docker/setup-buildx-action@v3
+        with:
+          driver-opts: network=host
+
+      - name: Login to GHCR
+        uses: docker/login-action@v3
+        with:
+          registry: ghcr.io
+          username: ${{ github.repository_owner }}
+          password: ${{ secrets.GITHUB_TOKEN }}
+
+      - name: Extract metadata (tags, labels) for Docker
+        id: meta
+        uses: docker/metadata-action@v5
+        with:
+          images: ghcr.io/all-hands-ai/enterprise-server
+          tags: |
+            type=ref,event=branch
+            type=ref,event=pr
+            type=sha
+            type=sha,format=long
+            type=semver,pattern={{version}}
+            type=semver,pattern={{major}}.{{minor}}
+            type=semver,pattern={{major}}
+          flavor: |
+            latest=auto
+            prefix=
+            suffix=
+      - name: Determine app image tag
+        shell: bash
+        run: |
+          # Duplicated with build.sh
+          sanitized_ref_name=$(echo "$GITHUB_REF_NAME" | sed 's/[^a-zA-Z0-9.-]\+/-/g')
+          OPENHANDS_BUILD_VERSION=$sanitized_ref_name
+          sanitized_ref_name=$(echo "$sanitized_ref_name" | tr '[:upper:]' '[:lower:]') # lower case is required in tagging
+          echo "OPENHANDS_DOCKER_TAG=${sanitized_ref_name}" >> $GITHUB_ENV
+      - name: Build and push Docker image
+        uses: useblacksmith/build-push-action@v1
+        with:
+          context: .
+          file: enterprise/Dockerfile
+          push: true
+          tags: ${{ steps.meta.outputs.tags }}
+          labels: ${{ steps.meta.outputs.labels }}
+          build-args: |
+            OPENHANDS_VERSION=${{ env.OPENHANDS_DOCKER_TAG }}
+          platforms: linux/amd64
+          # Add build provenance
+          provenance: true
+          # Add build attestations for better security
+          sbom: true
+
+  enterprise-preview:
+    name: Enterprise preview
+    if: |
+      (github.event_name == 'pull_request' && github.event.action == 'labeled' && github.event.label.name == 'deploy') ||
+      (github.event_name == 'pull_request' && github.event.action != 'labeled' && contains(github.event.pull_request.labels.*.name, 'deploy'))
+    runs-on: blacksmith-4vcpu-ubuntu-2204
+    needs: [ghcr_build_enterprise]
+    steps:
+      - name: Trigger remote job
+        run: |
+          curl --fail-with-body -sS -X POST \
+            -H "Authorization: Bearer ${{ secrets.PAT_TOKEN }}" \
+            -H "Accept: application/vnd.github+json" \
+            -d "{\"ref\": \"main\", \"inputs\": {\"openhandsPrNumber\": \"${{ github.event.pull_request.number }}\", \"deployEnvironment\": \"feature\", \"enterpriseImageTag\": \"pr-${{ github.event.pull_request.number }}\" }}" \
+            https://api.github.com/repos/All-Hands-AI/deploy/actions/workflows/deploy.yaml/dispatches
+
   # Run unit tests with the Docker runtime Docker images as root
   test_runtime_root:
     name: RT Unit Tests (Root)
@@ -202,7 +285,7 @@ jobs:
       - name: Set up Python
         uses: useblacksmith/setup-python@v6
         with:
-          python-version: '3.12'
+          python-version: "3.12"
           cache: poetry
       - name: Install Python dependencies using Poetry
         run: make install-python-dependencies INSTALL_PLAYWRIGHT=0
@@ -264,7 +347,7 @@ jobs:
       - name: Set up Python
         uses: useblacksmith/setup-python@v6
         with:
-          python-version: '3.12'
+          python-version: "3.12"
           cache: poetry
       - name: Install Python dependencies using Poetry
         run: make install-python-dependencies POETRY_GROUP=main,test,runtime INSTALL_PLAYWRIGHT=0
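
For reference, the "Determine app image tag" step above sanitizes the Git ref the same way build.sh does: every run of characters outside [a-zA-Z0-9.-] collapses to a single '-', and the result is lowercased because Docker tags must be lower case. A minimal Python sketch of that logic (illustrative only, not part of the commit):

# Python equivalent of the sed/tr pipeline in "Determine app image tag".
import re


def sanitize_ref_name(ref_name: str) -> str:
    # Replace every run of characters outside [a-zA-Z0-9.-] with a single '-'
    sanitized = re.sub(r'[^a-zA-Z0-9.-]+', '-', ref_name)
    # Docker tags must be lower case
    return sanitized.lower()


assert sanitize_ref_name('feature/My Branch') == 'feature-my-branch'
assert sanitize_ref_name('v1.2.3') == 'v1.2.3'
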
.github/workflows/lint.yml (vendored, 18 changed lines)

@@ -55,6 +55,24 @@ jobs:
       - name: Run pre-commit hooks
         run: pre-commit run --all-files --show-diff-on-failure --config ./dev_config/python/.pre-commit-config.yaml
 
+  lint-enterprise-python:
+    name: Lint enterprise python
+    runs-on: blacksmith-4vcpu-ubuntu-2204
+    steps:
+      - uses: actions/checkout@v4
+        with:
+          fetch-depth: 0
+      - name: Set up python
+        uses: useblacksmith/setup-python@v6
+        with:
+          python-version: 3.12
+          cache: "pip"
+      - name: Install pre-commit
+        run: pip install pre-commit==4.2.0
+      - name: Run pre-commit hooks
+        working-directory: ./enterprise
+        run: pre-commit run --all-files --config ./dev_config/python/.pre-commit-config.yaml
+
   # Check version consistency across documentation
   check-version-consistency:
     name: Check version consistency
.github/workflows/py-tests.yml (vendored, 33 changed lines)

@@ -21,10 +21,10 @@ jobs:
     name: Python Tests on Linux
     runs-on: blacksmith-4vcpu-ubuntu-2204
     env:
-      INSTALL_DOCKER: '0' # Set to '0' to skip Docker installation
+      INSTALL_DOCKER: "0" # Set to '0' to skip Docker installation
     strategy:
       matrix:
-        python-version: ['3.12']
+        python-version: ["3.12"]
     steps:
       - uses: actions/checkout@v4
       - name: Set up Docker Buildx
@@ -35,14 +35,14 @@ jobs:
       - name: Setup Node.js
         uses: useblacksmith/setup-node@v5
         with:
-          node-version: '22.x'
+          node-version: "22.x"
       - name: Install poetry via pipx
         run: pipx install poetry
       - name: Set up Python
         uses: useblacksmith/setup-python@v6
         with:
           python-version: ${{ matrix.python-version }}
-          cache: 'poetry'
+          cache: "poetry"
       - name: Install Python dependencies using Poetry
         run: poetry install --with dev,test,runtime
       - name: Build Environment
@@ -58,7 +58,7 @@ jobs:
     runs-on: windows-latest
     strategy:
       matrix:
-        python-version: ['3.12']
+        python-version: ["3.12"]
     steps:
       - uses: actions/checkout@v4
       - name: Install pipx
@@ -69,7 +69,7 @@ jobs:
         uses: actions/setup-python@v5
         with:
           python-version: ${{ matrix.python-version }}
-          cache: 'poetry'
+          cache: "poetry"
       - name: Install Python dependencies using Poetry
         run: poetry install --with dev,test,runtime
       - name: Run Windows unit tests
@@ -83,3 +83,24 @@ jobs:
           PYTHONPATH: ".;$env:PYTHONPATH"
           TEST_RUNTIME: local
           DEBUG: "1"
+
+  test-enterprise:
+    name: Enterprise Python Unit Tests
+    runs-on: blacksmith-4vcpu-ubuntu-2204
+    strategy:
+      matrix:
+        python-version: ["3.12"]
+    steps:
+      - uses: actions/checkout@v4
+      - name: Install poetry via pipx
+        run: pipx install poetry
+      - name: Set up Python
+        uses: useblacksmith/setup-python@v6
+        with:
+          python-version: ${{ matrix.python-version }}
+          cache: "poetry"
+      - name: Install Python dependencies using Poetry
+        working-directory: ./enterprise
+        run: poetry install --with dev,test
+      - name: Run Unit Tests
+        working-directory: ./enterprise
+        run: PYTHONPATH=".:$PYTHONPATH" poetry run pytest --forked -n auto -svv -p no:ddtrace -p no:ddtrace.pytest_bdd -p no:ddtrace.pytest_benchmark ./tests/unit
dev_config/python/.pre-commit-config.yaml

@@ -3,9 +3,9 @@ repos:
     rev: v5.0.0
     hooks:
       - id: trailing-whitespace
-        exclude: ^(docs/|modules/|python/|openhands-ui/|third_party/)
+        exclude: ^(docs/|modules/|python/|openhands-ui/|third_party/|enterprise/)
       - id: end-of-file-fixer
-        exclude: ^(docs/|modules/|python/|openhands-ui/|third_party/)
+        exclude: ^(docs/|modules/|python/|openhands-ui/|third_party/|enterprise/)
       - id: check-yaml
         args: ["--allow-multiple-documents"]
       - id: debug-statements
@@ -28,19 +28,28 @@ repos:
         entry: ruff check --config dev_config/python/ruff.toml
         types_or: [python, pyi, jupyter]
         args: [--fix, --unsafe-fixes]
-        exclude: third_party/
+        exclude: ^(third_party/|enterprise/)
       # Run the formatter.
       - id: ruff-format
         entry: ruff format --config dev_config/python/ruff.toml
         types_or: [python, pyi, jupyter]
-        exclude: third_party/
+        exclude: ^(third_party/|enterprise/)
 
   - repo: https://github.com/pre-commit/mirrors-mypy
     rev: v1.15.0
     hooks:
       - id: mypy
         additional_dependencies:
-          [types-requests, types-setuptools, types-pyyaml, types-toml, types-docker, types-Markdown, pydantic, lxml]
+          [
+            types-requests,
+            types-setuptools,
+            types-pyyaml,
+            types-toml,
+            types-docker,
+            types-Markdown,
+            pydantic,
+            lxml,
+          ]
         # To see gaps add `--html-report mypy-report/`
         entry: mypy --config-file dev_config/python/mypy.ini openhands/
         always_run: true
dev_config/python/mypy.ini

@@ -9,7 +9,7 @@ no_implicit_optional = True
 strict_optional = True
 
 # Exclude third-party runtime directory from type checking
-exclude = third_party/
+exclude = (third_party/|enterprise/)
 
 [mypy-openhands.memory.condenser.impl.*]
 disable_error_code = override
dev_config/python/ruff.toml

@@ -1,5 +1,5 @@
 # Exclude third-party runtime directory from linting
-exclude = ["third_party/"]
+exclude = ["third_party/", "enterprise/"]
 
 [lint]
 select = [
enterprise/Dockerfile (new file, 26 lines)

@@ -0,0 +1,26 @@
ARG OPENHANDS_VERSION=latest
ARG BASE="ghcr.io/all-hands-ai/openhands"
FROM ${BASE}:${OPENHANDS_VERSION}

# Datadog labels
LABEL com.datadoghq.tags.service="deploy"
LABEL com.datadoghq.tags.env="${DD_ENV}"

# Install Node.js v20+ and npm (which includes npx)
RUN apt-get update && \
    apt-get install -y curl && \
    curl -fsSL https://deb.nodesource.com/setup_20.x | bash - && \
    apt-get install -y nodejs && \
    apt-get install -y jq gettext && \
    apt-get clean

RUN pip install alembic psycopg2-binary cloud-sql-python-connector pg8000 gspread stripe python-keycloak asyncpg sqlalchemy[asyncio] resend tenacity slack-sdk ddtrace posthog "limits==5.2.0" coredis prometheus-client shap scikit-learn pandas numpy

WORKDIR /app
COPY enterprise .

RUN chown -R openhands:openhands /app && chmod -R 770 /app
USER openhands

# Command will be overridden by Kubernetes deployment template
CMD ["uvicorn", "saas_server:app", "--host", "0.0.0.0", "--port", "3000"]
enterprise/Makefile (new file, 42 lines)

@@ -0,0 +1,42 @@
BACKEND_HOST ?= "127.0.0.1"
BACKEND_PORT = 3000
BACKEND_HOST_PORT = "$(BACKEND_HOST):$(BACKEND_PORT)"
FRONTEND_PORT = 3001
OPENHANDS_PATH ?= "../../OpenHands"
OPENHANDS := $(OPENHANDS_PATH)
OPENHANDS_FRONTEND_PATH = $(OPENHANDS)/frontend/build

# ANSI color codes
GREEN=$(shell tput -Txterm setaf 2)
YELLOW=$(shell tput -Txterm setaf 3)
RED=$(shell tput -Txterm setaf 1)
BLUE=$(shell tput -Txterm setaf 6)
RESET=$(shell tput -Txterm sgr0)

build:
	@poetry install
	@cd $(OPENHANDS) && $(MAKE) build


_run_setup:
	@echo "$(YELLOW)Starting backend server...$(RESET)"
	@cd app && FRONTEND_DIRECTORY=$(OPENHANDS_FRONTEND_PATH) poetry run uvicorn saas_server:app --host $(BACKEND_HOST) --port $(BACKEND_PORT) &
	@echo "$(YELLOW)Waiting for the backend to start...$(RESET)"
	@until nc -z localhost $(BACKEND_PORT); do sleep 0.1; done
	@echo "$(GREEN)Backend started successfully.$(RESET)"

run:
	@echo "$(YELLOW)Running the app...$(RESET)"
	@$(MAKE) -s _run_setup
	@cd $(OPENHANDS) && $(MAKE) -s start-frontend
	@echo "$(GREEN)Application started successfully.$(RESET)"

# Start backend
start-backend:
	@echo "$(YELLOW)Starting backend...$(RESET)"
	@echo "$(OPENHANDS_FRONTEND_PATH)"
	@cd app && FRONTEND_DIRECTORY=$(OPENHANDS_FRONTEND_PATH) poetry run uvicorn saas_server:app --host $(BACKEND_HOST) --port $(BACKEND_PORT) --reload-dir $(OPENHANDS_PATH) --reload --reload-dir ./ --reload-exclude "./workspace"


lint:
	@poetry run pre-commit run --all-files --show-diff-on-failure --config ./dev_config/python/.pre-commit-config.yaml
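
The `_run_setup` target above starts uvicorn in the background and then polls with `nc -z` until the backend port accepts connections. A Python sketch of the same wait-for-port loop (an illustrative equivalent, not part of the commit):

# Python equivalent of `until nc -z localhost $(BACKEND_PORT); do sleep 0.1; done`.
import socket
import time


def wait_for_port(host: str, port: int, poll_interval: float = 0.1) -> None:
    """Block until a TCP connection to (host, port) succeeds."""
    while True:
        try:
            with socket.create_connection((host, port), timeout=1):
                return
        except OSError:
            time.sleep(poll_interval)


wait_for_port('127.0.0.1', 3000)  # backend port from the Makefile
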
enterprise/README.md (new file, 44 lines)

@@ -0,0 +1,44 @@
# Closed Source extension of OpenHands proper (OSS)

The closed source (CSS) code in the `/app` directory builds on top of the open source (OSS) code, extending its functionality. The CSS code is entangled with the OSS code in two ways:

- CSS stacks on top of OSS. For example, the middleware in CSS is stacked right on top of the middleware in OSS. In `SAAS`, the middleware from BOTH repos will be present and running (which can sometimes cause conflicts).
- CSS overrides the implementation in OSS (only one is present at a time). For example, the server config [`SaasServerConfig`](https://github.com/All-Hands-AI/deploy/blob/main/app/server/config.py#L43) overrides [`ServerConfig`](https://github.com/All-Hands-AI/OpenHands/blob/main/openhands/server/config/server_config.py#L8) in OSS. This is done through dynamic imports ([see here](https://github.com/All-Hands-AI/OpenHands/blob/main/openhands/server/config/server_config.py#L37-#L45)).

Key areas that change on `SAAS` are:

- Authentication
- User settings
- etc.

## Authentication

| Aspect | OSS | CSS |
| --- | --- | --- |
| **Authentication Method** | User adds a personal access token (PAT) through the UI | User performs OAuth through the UI. The GitHub app provides a short-lived access token and refresh token |
| **Token Storage** | PAT is stored in **Settings** | Token is stored in **GithubTokenManager** (a file store in our backend) |
| **Authenticated status** | We simply check whether a token exists in `Settings` | We issue a signed cookie with `github_user_id` during OAuth, so subsequent requests with the cookie can be considered authenticated |

Note that in the future, authentication will happen via Keycloak. All modifications for authentication will happen in CSS.

## GitHub Service

The GitHub service is responsible for interacting with GitHub APIs. As a consequence, it uses the user's token and refreshes it if need be.

| Aspect | OSS | CSS |
| --- | --- | --- |
| **Class used** | `GitHubService` | `SaaSGitHubService` |
| **Token used** | User's PAT fetched from `Settings` | User's token fetched from `GitHubTokenManager` |
| **Refresh functionality** | **N/A**; user provides PAT for the app | Uses the `GitHubTokenManager` to refresh |

NOTE: in the future we will simply replace the `GithubTokenManager` with Keycloak. The `SaaSGithubService` should interact with Keycloak instead.

# Areas that are BRITTLE!

## User ID vs User Token

- On OSS, the entire APP revolves around the GitHub token the user sets. `openhands/server` uses `request.state.github_token` for the entire app.
- On CSS, the entire APP revolves around the GitHub User ID. This is because the cookie sets it, so `openhands/server` AND `deploy/app/server` depend on it and completely ignore `request.state.github_token` (the token is fetched from `GithubTokenManager` instead).

Note that introducing the GitHub User ID on OSS, for instance, will cause large breakages.
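
To make the override mechanism concrete, here is a hedged sketch of the kind of dynamic import the README points at: a dotted class path is resolved at runtime, so `SaasServerConfig` can replace `ServerConfig` without OSS importing CSS directly. The helper name and the CSS class path below are illustrative, not the exact OpenHands API:

# Illustrative dynamic-import override; names are simplified assumptions.
import importlib


def load_class(dotted_path: str):
    """Resolve 'package.module.ClassName' to the class object at runtime."""
    module_name, _, class_name = dotted_path.rpartition('.')
    module = importlib.import_module(module_name)
    return getattr(module, class_name)


# OSS default; a SAAS deployment would point this at the CSS override,
# e.g. 'server.config.SaasServerConfig' (hypothetical path).
server_config_cls = load_class(
    'openhands.server.config.server_config.ServerConfig'
)
server_config = server_config_cls()
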
enterprise/__init__.py (new file, 1 line)

@@ -0,0 +1 @@
# App package for OpenHands
enterprise/alembic.ini (new file, 79 lines)

@@ -0,0 +1,79 @@
# A generic, single database configuration.

[alembic]
# path to migration scripts
script_location = migrations

# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s

# sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory.
prepend_sys_path = .

# timezone to use when rendering the date within the migration file
# as well as the filename.
# If specified, requires the python>=3.9 or backports.zoneinfo library.
# timezone =

# max length of characters to apply to the "slug" field
# truncate_slug_length = 40

# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false

# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false

# version path separator; As mentioned above, this is the character used to split
# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
version_path_separator = os  # Use os.pathsep. Default configuration used for new projects.

# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8

sqlalchemy.url = driver://user:pass@localhost/dbname

[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts. See the documentation for further
# detail and examples

# Logging configuration
[loggers]
keys = root,sqlalchemy,alembic

[handlers]
keys = console

[formatters]
keys = generic

[logger_root]
level = DEBUG
handlers = console
qualname =

[logger_sqlalchemy]
level = DEBUG
handlers =
qualname = sqlalchemy.engine

[logger_alembic]
level = DEBUG
handlers =
qualname = alembic

[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic

[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S
enterprise/allhands-realm-github-provider.json.tmpl (new file, 2770 lines)
File diff suppressed because it is too large.
enterprise/dev_config/python/.pre-commit-config.yaml (new file, 56 lines)

@@ -0,0 +1,56 @@
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.5.0
    hooks:
      - id: trailing-whitespace
        exclude: docs/modules/python
        files: ^enterprise/
      - id: end-of-file-fixer
        exclude: docs/modules/python
        files: ^enterprise/
      - id: check-yaml
        files: ^enterprise/
      - id: debug-statements
        files: ^enterprise/
  - repo: https://github.com/abravalheri/validate-pyproject
    rev: v0.16
    hooks:
      - id: validate-pyproject
        types: [toml]
        files: ^enterprise/pyproject\.toml$

  - repo: https://github.com/astral-sh/ruff-pre-commit
    # Ruff version.
    rev: v0.4.1
    hooks:
      # Run the linter.
      - id: ruff
        entry: ruff check --config enterprise/dev_config/python/ruff.toml
        types_or: [python, pyi, jupyter]
        args: [--fix]
        files: ^enterprise/
      # Run the formatter.
      - id: ruff-format
        entry: ruff format --config enterprise/dev_config/python/ruff.toml
        types_or: [python, pyi, jupyter]
        files: ^enterprise/

  - repo: https://github.com/pre-commit/mirrors-mypy
    rev: v1.9.0
    hooks:
      - id: mypy
        additional_dependencies:
          - types-requests
          - types-setuptools
          - types-pyyaml
          - types-toml
          - types-redis
          - lxml
          # TODO: Add OpenHands in parent
          - stripe==11.5.0
          - pygithub==2.6.1
        # To see gaps add `--html-report mypy-report/`
        entry: mypy --config-file enterprise/dev_config/python/mypy.ini enterprise/
        always_run: true
        pass_filenames: false
        files: ^enterprise/
enterprise/dev_config/python/mypy.ini (new file, 21 lines)

@@ -0,0 +1,21 @@
[mypy]
warn_unused_configs = True
ignore_missing_imports = True
check_untyped_defs = True
explicit_package_bases = True
warn_unreachable = True
warn_redundant_casts = True
no_implicit_optional = True
strict_optional = True
exclude = (^enterprise/migrations/.*|^openhands/.*)

[mypy-enterprise.tests.unit.test_auth_routes.*]
disable_error_code = union-attr

[mypy-enterprise.sync.install_gitlab_webhooks.*]
disable_error_code = redundant-cast

# Let the other config check base openhands packages
[mypy-openhands.*]
follow_imports = skip
ignore_missing_imports = True
enterprise/dev_config/python/ruff.toml (new file, 31 lines)

@@ -0,0 +1,31 @@
[lint]
select = [
    "E",
    "W",
    "F",
    "I",
    "Q",
    "B",
]

ignore = [
    "E501",
    "B003",
    "B007",
    "B008",  # Allow function calls in argument defaults (FastAPI Query pattern)
    "B009",
    "B010",
    "B904",
    "B018",
]

exclude = [
    "app/migrations/*"
]

[lint.flake8-quotes]
docstring-quotes = "double"
inline-quotes = "single"

[format]
quote-style = "single"
enterprise/experiments/__init__.py (new, empty file)
enterprise/experiments/constants.py (new file, 47 lines)

@@ -0,0 +1,47 @@
import os

import posthog

from openhands.core.logger import openhands_logger as logger

# Initialize PostHog
posthog.api_key = os.environ.get('POSTHOG_CLIENT_KEY', 'phc_placeholder')
posthog.host = os.environ.get('POSTHOG_HOST', 'https://us.i.posthog.com')

# Log PostHog configuration with masked API key for security
api_key = posthog.api_key
if api_key and len(api_key) > 8:
    masked_key = f'{api_key[:4]}...{api_key[-4:]}'
else:
    masked_key = 'not_set_or_too_short'
logger.info('posthog_configuration', extra={'posthog_api_key_masked': masked_key})

# Global toggle for the experiment manager
ENABLE_EXPERIMENT_MANAGER = (
    os.environ.get('ENABLE_EXPERIMENT_MANAGER', 'false').lower() == 'true'
)

# Get the current experiment type from environment variable
# If None, no experiment is running
EXPERIMENT_LITELLM_DEFAULT_MODEL_EXPERIMENT = os.environ.get(
    'EXPERIMENT_LITELLM_DEFAULT_MODEL_EXPERIMENT', ''
)
# System prompt experiment toggle
EXPERIMENT_SYSTEM_PROMPT_EXPERIMENT = os.environ.get(
    'EXPERIMENT_SYSTEM_PROMPT_EXPERIMENT', ''
)

EXPERIMENT_CLAUDE4_VS_GPT5 = os.environ.get('EXPERIMENT_CLAUDE4_VS_GPT5', '')

EXPERIMENT_CONDENSER_MAX_STEP = os.environ.get('EXPERIMENT_CONDENSER_MAX_STEP', '')

logger.info(
    'experiment_manager:run_conversation_variant_test:experiment_config',
    extra={
        'enable_experiment_manager': ENABLE_EXPERIMENT_MANAGER,
        'experiment_litellm_default_model_experiment': EXPERIMENT_LITELLM_DEFAULT_MODEL_EXPERIMENT,
        'experiment_system_prompt_experiment': EXPERIMENT_SYSTEM_PROMPT_EXPERIMENT,
        'experiment_claude4_vs_gpt5_experiment': EXPERIMENT_CLAUDE4_VS_GPT5,
        'experiment_condenser_max_step': EXPERIMENT_CONDENSER_MAX_STEP,
    },
)
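
Note that constants.py reads all of these environment variables at import time, so they must be set before the module is first imported. A hedged example of enabling the manager for local testing (the flag-key value is hypothetical):

import os

# Must happen before the first import of experiments.constants, because the
# module captures the env vars at import time.
os.environ['ENABLE_EXPERIMENT_MANAGER'] = 'true'
os.environ['EXPERIMENT_CLAUDE4_VS_GPT5'] = 'claude4-vs-gpt5'  # hypothetical PostHog flag key

from experiments import constants

print(constants.ENABLE_EXPERIMENT_MANAGER)   # True
print(constants.EXPERIMENT_CLAUDE4_VS_GPT5)  # 'claude4-vs-gpt5'
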
enterprise/experiments/experiment_manager.py (new file, 90 lines)

@@ -0,0 +1,90 @@
from experiments.constants import (
    ENABLE_EXPERIMENT_MANAGER,
)
from experiments.experiment_versions import (
    handle_claude4_vs_gpt5_experiment,
    handle_condenser_max_step_experiment,
    handle_system_prompt_experiment,
)

from openhands.core.config.openhands_config import OpenHandsConfig
from openhands.core.logger import openhands_logger as logger
from openhands.experiments.experiment_manager import ExperimentManager


class SaaSExperimentManager(ExperimentManager):
    @staticmethod
    def run_conversation_variant_test(user_id, conversation_id, conversation_settings):
        """
        Run conversation variant test and potentially modify the conversation settings
        based on the PostHog feature flags.

        Args:
            user_id: The user ID
            conversation_id: The conversation ID
            conversation_settings: The conversation settings that may include convo_id and llm_model

        Returns:
            The modified conversation settings
        """
        logger.debug(
            'experiment_manager:run_conversation_variant_test:started',
            extra={'user_id': user_id},
        )

        # Skip all experiment processing if the experiment manager is disabled
        if not ENABLE_EXPERIMENT_MANAGER:
            logger.info(
                'experiment_manager:run_conversation_variant_test:skipped',
                extra={'reason': 'experiment_manager_disabled'},
            )
            return conversation_settings

        # Apply conversation-scoped experiments
        conversation_settings = handle_claude4_vs_gpt5_experiment(
            user_id, conversation_id, conversation_settings
        )
        conversation_settings = handle_condenser_max_step_experiment(
            user_id, conversation_id, conversation_settings
        )

        return conversation_settings

    @staticmethod
    def run_config_variant_test(
        user_id: str, conversation_id: str, config: OpenHandsConfig
    ):
        """
        Run agent config variant test and potentially modify the OpenHands config
        based on the current experiment type and PostHog feature flags.

        Args:
            user_id: The user ID
            conversation_id: The conversation ID
            config: The OpenHands configuration

        Returns:
            The modified OpenHands configuration
        """
        logger.info(
            'experiment_manager:run_config_variant_test:started',
            extra={'user_id': user_id},
        )

        # Skip all experiment processing if the experiment manager is disabled
        if not ENABLE_EXPERIMENT_MANAGER:
            logger.info(
                'experiment_manager:run_config_variant_test:skipped',
                extra={'reason': 'experiment_manager_disabled'},
            )
            return config

        # Pass the entire OpenHands config to the system prompt experiment
        # Let the experiment handler directly modify the config as needed
        modified_config = handle_system_prompt_experiment(
            user_id, conversation_id, config
        )

        # Condenser max step experiment is applied via conversation variant test,
        # not config variant test. Return modified config from system prompt only.
        return modified_config
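
A hedged caller-side sketch of how the manager is used: the handlers mutate fields such as llm_model and condenser_max_size on the settings object and return it, so the caller assigns the result back. ConversationSettings below is a stand-in for the real server-side settings type, and the default model string is hypothetical:

from dataclasses import dataclass

from experiments.experiment_manager import SaaSExperimentManager


@dataclass
class ConversationSettings:  # stand-in for the real settings object
    llm_model: str = 'litellm_proxy/default-model'  # hypothetical default
    condenser_max_size: int = 120


settings = ConversationSettings()
# Handlers may rewrite fields in place; the result is assigned back.
settings = SaaSExperimentManager.run_conversation_variant_test(
    'user-123', 'convo-456', settings
)
print(settings.llm_model, settings.condenser_max_size)
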
enterprise/experiments/experiment_versions/_001_litellm_default_model_experiment.py (new file, 107 lines)

@@ -0,0 +1,107 @@
"""
LiteLLM model experiment handler.

This module contains the handler for the LiteLLM model experiment.
"""

import posthog
from experiments.constants import EXPERIMENT_LITELLM_DEFAULT_MODEL_EXPERIMENT
from server.constants import (
    IS_FEATURE_ENV,
    build_litellm_proxy_model_path,
    get_default_litellm_model,
)

from openhands.core.logger import openhands_logger as logger


def handle_litellm_default_model_experiment(
    user_id, conversation_id, conversation_settings
):
    """
    Handle the LiteLLM model experiment.

    Args:
        user_id: The user ID
        conversation_id: The conversation ID
        conversation_settings: The conversation settings

    Returns:
        Modified conversation settings
    """
    # No-op if the specific experiment is not enabled
    if not EXPERIMENT_LITELLM_DEFAULT_MODEL_EXPERIMENT:
        logger.info(
            'experiment_manager:ab_testing:skipped',
            extra={
                'convo_id': conversation_id,
                'reason': 'experiment_not_enabled',
                'experiment': EXPERIMENT_LITELLM_DEFAULT_MODEL_EXPERIMENT,
            },
        )
        return conversation_settings

    # Use experiment name as the flag key
    try:
        enabled_variant = posthog.get_feature_flag(
            EXPERIMENT_LITELLM_DEFAULT_MODEL_EXPERIMENT, conversation_id
        )
    except Exception as e:
        logger.error(
            'experiment_manager:get_feature_flag:failed',
            extra={
                'convo_id': conversation_id,
                'experiment': EXPERIMENT_LITELLM_DEFAULT_MODEL_EXPERIMENT,
                'error': str(e),
            },
        )
        return conversation_settings

    # Log the experiment event
    # If this is a feature environment, add "FEATURE_" prefix to user_id for PostHog
    posthog_user_id = f'FEATURE_{user_id}' if IS_FEATURE_ENV else user_id

    try:
        posthog.capture(
            distinct_id=posthog_user_id,
            event='model_set',
            properties={
                'conversation_id': conversation_id,
                'variant': enabled_variant,
                'original_user_id': user_id,
                'is_feature_env': IS_FEATURE_ENV,
            },
        )
    except Exception as e:
        logger.error(
            'experiment_manager:posthog_capture:failed',
            extra={
                'convo_id': conversation_id,
                'experiment': EXPERIMENT_LITELLM_DEFAULT_MODEL_EXPERIMENT,
                'error': str(e),
            },
        )
        # Continue execution as this is not critical

    logger.info(
        'posthog_capture',
        extra={
            'event': 'model_set',
            'posthog_user_id': posthog_user_id,
            'is_feature_env': IS_FEATURE_ENV,
            'conversation_id': conversation_id,
            'variant': enabled_variant,
        },
    )

    # Set the model based on the feature flag variant
    if enabled_variant == 'claude37':
        # Use the shared utility to construct the LiteLLM proxy model path
        model = build_litellm_proxy_model_path('claude-3-7-sonnet-20250219')
        # Update the conversation settings with the selected model
        conversation_settings.llm_model = model
    else:
        # Update the conversation settings with the default model for the current version
        conversation_settings.llm_model = get_default_litellm_model()

    return conversation_settings
enterprise/experiments/experiment_versions/_002_system_prompt_experiment.py (new file, 181 lines)

@@ -0,0 +1,181 @@
"""
System prompt experiment handler.

This module contains the handler for the system prompt experiment that uses
the PostHog variant as the system prompt filename.
"""

import copy

import posthog
from experiments.constants import EXPERIMENT_SYSTEM_PROMPT_EXPERIMENT
from server.constants import IS_FEATURE_ENV
from storage.experiment_assignment_store import ExperimentAssignmentStore

from openhands.core.config.openhands_config import OpenHandsConfig
from openhands.core.logger import openhands_logger as logger


def _get_system_prompt_variant(user_id, conversation_id):
    """
    Get the system prompt variant for the experiment.

    Args:
        user_id: The user ID
        conversation_id: The conversation ID

    Returns:
        str or None: The PostHog variant name or None if experiment is not enabled or error occurs
    """
    # No-op if the specific experiment is not enabled
    if not EXPERIMENT_SYSTEM_PROMPT_EXPERIMENT:
        logger.info(
            'experiment_manager_002:ab_testing:skipped',
            extra={
                'convo_id': conversation_id,
                'reason': 'experiment_not_enabled',
                'experiment': EXPERIMENT_SYSTEM_PROMPT_EXPERIMENT,
            },
        )
        return None

    # Use experiment name as the flag key
    try:
        enabled_variant = posthog.get_feature_flag(
            EXPERIMENT_SYSTEM_PROMPT_EXPERIMENT, conversation_id
        )
    except Exception as e:
        logger.error(
            'experiment_manager:get_feature_flag:failed',
            extra={
                'convo_id': conversation_id,
                'experiment': EXPERIMENT_SYSTEM_PROMPT_EXPERIMENT,
                'error': str(e),
            },
        )
        return None

    # Store the experiment assignment in the database
    try:
        experiment_store = ExperimentAssignmentStore()
        experiment_store.update_experiment_variant(
            conversation_id=conversation_id,
            experiment_name='system_prompt_experiment',
            variant=enabled_variant,
        )
    except Exception as e:
        logger.error(
            'experiment_manager:store_assignment:failed',
            extra={
                'convo_id': conversation_id,
                'experiment': EXPERIMENT_SYSTEM_PROMPT_EXPERIMENT,
                'variant': enabled_variant,
                'error': str(e),
            },
        )
        # Fail the experiment if we cannot track the splits - results would not be explainable
        return None

    # Log the experiment event
    # If this is a feature environment, add "FEATURE_" prefix to user_id for PostHog
    posthog_user_id = f'FEATURE_{user_id}' if IS_FEATURE_ENV else user_id

    try:
        posthog.capture(
            distinct_id=posthog_user_id,
            event='system_prompt_set',
            properties={
                'conversation_id': conversation_id,
                'variant': enabled_variant,
                'original_user_id': user_id,
                'is_feature_env': IS_FEATURE_ENV,
            },
        )
    except Exception as e:
        logger.error(
            'experiment_manager:posthog_capture:failed',
            extra={
                'convo_id': conversation_id,
                'experiment': EXPERIMENT_SYSTEM_PROMPT_EXPERIMENT,
                'error': str(e),
            },
        )
        # Continue execution as this is not critical

    logger.info(
        'posthog_capture',
        extra={
            'event': 'system_prompt_set',
            'posthog_user_id': posthog_user_id,
            'is_feature_env': IS_FEATURE_ENV,
            'conversation_id': conversation_id,
            'variant': enabled_variant,
        },
    )

    return enabled_variant


def handle_system_prompt_experiment(
    user_id, conversation_id, config: OpenHandsConfig
) -> OpenHandsConfig:
    """
    Handle the system prompt experiment for OpenHands config.

    Args:
        user_id: The user ID
        conversation_id: The conversation ID
        config: The OpenHands configuration

    Returns:
        Modified OpenHands configuration
    """
    enabled_variant = _get_system_prompt_variant(user_id, conversation_id)

    # If variant is None, experiment is not enabled or there was an error
    if enabled_variant is None:
        return config

    # Deep copy the config to avoid modifying the original
    modified_config = copy.deepcopy(config)

    # Set the system prompt filename based on the variant
    if enabled_variant == 'control':
        # Use the long-horizon system prompt for the control variant
        agent_config = modified_config.get_agent_config(modified_config.default_agent)
        agent_config.system_prompt_filename = 'system_prompt_long_horizon.j2'
        agent_config.enable_plan_mode = True
    elif enabled_variant == 'interactive':
        modified_config.get_agent_config(
            modified_config.default_agent
        ).system_prompt_filename = 'system_prompt_interactive.j2'
    elif enabled_variant == 'no_tools':
        modified_config.get_agent_config(
            modified_config.default_agent
        ).system_prompt_filename = 'system_prompt.j2'
    else:
        logger.error(
            'system_prompt_experiment:unknown_variant',
            extra={
                'user_id': user_id,
                'convo_id': conversation_id,
                'variant': enabled_variant,
                'reason': 'no explicit mapping; returning original config',
            },
        )
        return config

    # Log which prompt is being used
    logger.info(
        'system_prompt_experiment:prompt_selected',
        extra={
            'user_id': user_id,
            'convo_id': conversation_id,
            'system_prompt_filename': modified_config.get_agent_config(
                modified_config.default_agent
            ).system_prompt_filename,
            'variant': enabled_variant,
        },
    )

    return modified_config
enterprise/experiments/experiment_versions/_003_llm_claude4_vs_gpt5_experiment.py (new file, 132 lines)

@@ -0,0 +1,132 @@
"""
Claude 4 vs GPT-5 model experiment handler.

This module contains the handler for the Claude 4 vs GPT-5 model experiment.
"""

import posthog
from experiments.constants import EXPERIMENT_CLAUDE4_VS_GPT5
from server.constants import (
    IS_FEATURE_ENV,
    build_litellm_proxy_model_path,
    get_default_litellm_model,
)
from storage.experiment_assignment_store import ExperimentAssignmentStore

from openhands.core.logger import openhands_logger as logger


def _get_model_variant(user_id, conversation_id) -> str | None:
    if not EXPERIMENT_CLAUDE4_VS_GPT5:
        logger.info(
            'experiment_manager:ab_testing:skipped',
            extra={
                'convo_id': conversation_id,
                'reason': 'experiment_not_enabled',
                'experiment': EXPERIMENT_CLAUDE4_VS_GPT5,
            },
        )
        return None

    try:
        enabled_variant = posthog.get_feature_flag(
            EXPERIMENT_CLAUDE4_VS_GPT5, conversation_id
        )
    except Exception as e:
        logger.error(
            'experiment_manager:get_feature_flag:failed',
            extra={
                'convo_id': conversation_id,
                'experiment': EXPERIMENT_CLAUDE4_VS_GPT5,
                'error': str(e),
            },
        )
        return None

    # Store the experiment assignment in the database
    try:
        experiment_store = ExperimentAssignmentStore()
        experiment_store.update_experiment_variant(
            conversation_id=conversation_id,
            experiment_name='claude4_vs_gpt5_experiment',
            variant=enabled_variant,
        )
    except Exception as e:
        logger.error(
            'experiment_manager:store_assignment:failed',
            extra={
                'convo_id': conversation_id,
                'experiment': EXPERIMENT_CLAUDE4_VS_GPT5,
                'variant': enabled_variant,
                'error': str(e),
            },
        )
        # Fail the experiment if we cannot track the splits - results would not be explainable
        return None

    # Log the experiment event
    # If this is a feature environment, add "FEATURE_" prefix to user_id for PostHog
    posthog_user_id = f'FEATURE_{user_id}' if IS_FEATURE_ENV else user_id

    try:
        posthog.capture(
            distinct_id=posthog_user_id,
            event='claude4_or_gpt5_set',
            properties={
                'conversation_id': conversation_id,
                'variant': enabled_variant,
                'original_user_id': user_id,
                'is_feature_env': IS_FEATURE_ENV,
            },
        )
    except Exception as e:
        logger.error(
            'experiment_manager:posthog_capture:failed',
            extra={
                'convo_id': conversation_id,
                'experiment': EXPERIMENT_CLAUDE4_VS_GPT5,
                'error': str(e),
            },
        )
        # Continue execution as this is not critical

    logger.info(
        'posthog_capture',
        extra={
            'event': 'claude4_or_gpt5_set',
            'posthog_user_id': posthog_user_id,
            'is_feature_env': IS_FEATURE_ENV,
            'conversation_id': conversation_id,
            'variant': enabled_variant,
        },
    )

    return enabled_variant


def handle_claude4_vs_gpt5_experiment(user_id, conversation_id, conversation_settings):
    """
    Handle the Claude 4 vs GPT-5 model experiment.

    Args:
        user_id: The user ID
        conversation_id: The conversation ID
        conversation_settings: The conversation settings

    Returns:
        Modified conversation settings
    """

    enabled_variant = _get_model_variant(user_id, conversation_id)

    if not enabled_variant:
        # Return the settings unchanged; returning None here would drop the
        # caller's settings object, since the result is assigned back upstream.
        return conversation_settings

    # Set the model based on the feature flag variant
    if enabled_variant == 'gpt5':
        model = build_litellm_proxy_model_path('gpt-5-2025-08-07')
        conversation_settings.llm_model = model
    else:
        conversation_settings.llm_model = get_default_litellm_model()

    return conversation_settings
@@ -0,0 +1,189 @@
|
|||||||
|
"""
|
||||||
|
Condenser max step experiment handler.
|
||||||
|
|
||||||
|
This module contains the handler for the condenser max step experiment that tests
|
||||||
|
different max_size values for the condenser configuration.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import posthog
|
||||||
|
from experiments.constants import EXPERIMENT_CONDENSER_MAX_STEP
|
||||||
|
from server.constants import IS_FEATURE_ENV
|
||||||
|
from storage.experiment_assignment_store import ExperimentAssignmentStore
|
||||||
|
|
||||||
|
from openhands.core.logger import openhands_logger as logger
|
||||||
|
|
||||||
|
|
||||||
|
def _get_condenser_max_step_variant(user_id, conversation_id):
|
||||||
|
"""
|
||||||
|
Get the condenser max step variant for the experiment.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user ID
|
||||||
|
conversation_id: The conversation ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str or None: The PostHog variant name or None if experiment is not enabled or error occurs
|
||||||
|
"""
|
||||||
|
# No-op if the specific experiment is not enabled
|
||||||
|
if not EXPERIMENT_CONDENSER_MAX_STEP:
|
||||||
|
logger.info(
|
||||||
|
'experiment_manager_004:ab_testing:skipped',
|
||||||
|
extra={
|
||||||
|
'convo_id': conversation_id,
|
||||||
|
'reason': 'experiment_not_enabled',
|
||||||
|
'experiment': EXPERIMENT_CONDENSER_MAX_STEP,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Use experiment name as the flag key
|
||||||
|
try:
|
||||||
|
enabled_variant = posthog.get_feature_flag(
|
||||||
|
EXPERIMENT_CONDENSER_MAX_STEP, conversation_id
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
'experiment_manager:get_feature_flag:failed',
|
||||||
|
extra={
|
||||||
|
'convo_id': conversation_id,
|
||||||
|
'experiment': EXPERIMENT_CONDENSER_MAX_STEP,
|
||||||
|
'error': str(e),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Store the experiment assignment in the database
|
||||||
|
try:
|
||||||
|
experiment_store = ExperimentAssignmentStore()
|
||||||
|
experiment_store.update_experiment_variant(
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
experiment_name='condenser_max_step_experiment',
|
||||||
|
variant=enabled_variant,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
'experiment_manager:store_assignment:failed',
|
||||||
|
extra={
|
||||||
|
'convo_id': conversation_id,
|
||||||
|
'experiment': EXPERIMENT_CONDENSER_MAX_STEP,
|
||||||
|
'variant': enabled_variant,
|
||||||
|
'error': str(e),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
# Fail the experiment if we cannot track the splits - results would not be explainable
|
||||||
|
        return None

    # Log the experiment event

    # If this is a feature environment, add "FEATURE_" prefix to user_id for PostHog
    posthog_user_id = f'FEATURE_{user_id}' if IS_FEATURE_ENV else user_id

    try:
        posthog.capture(
            distinct_id=posthog_user_id,
            event='condenser_max_step_set',
            properties={
                'conversation_id': conversation_id,
                'variant': enabled_variant,
                'original_user_id': user_id,
                'is_feature_env': IS_FEATURE_ENV,
            },
        )
    except Exception as e:
        logger.error(
            'experiment_manager:posthog_capture:failed',
            extra={
                'convo_id': conversation_id,
                'experiment': EXPERIMENT_CONDENSER_MAX_STEP,
                'error': str(e),
            },
        )
        # Continue execution as this is not critical

    logger.info(
        'posthog_capture',
        extra={
            'event': 'condenser_max_step_set',
            'posthog_user_id': posthog_user_id,
            'is_feature_env': IS_FEATURE_ENV,
            'conversation_id': conversation_id,
            'variant': enabled_variant,
        },
    )

    return enabled_variant


def handle_condenser_max_step_experiment(
    user_id: str, conversation_id: str, conversation_settings
):
    """
    Handle the condenser max step experiment for conversation settings.

    We should not modify persistent user settings. Instead, apply the experiment
    variant to the conversation's in-memory settings object for this session only.

    Variants:
    - control -> condenser_max_size = 120
    - treatment -> condenser_max_size = 80

    Returns the (potentially) modified conversation_settings.
    """
    enabled_variant = _get_condenser_max_step_variant(user_id, conversation_id)

    if enabled_variant is None:
        return conversation_settings

    if enabled_variant == 'control':
        condenser_max_size = 120
    elif enabled_variant == 'treatment':
        condenser_max_size = 80
    else:
        logger.error(
            'condenser_max_step_experiment:unknown_variant',
            extra={
                'user_id': user_id,
                'convo_id': conversation_id,
                'variant': enabled_variant,
                'reason': 'unknown variant; returning original conversation settings',
            },
        )
        return conversation_settings

    try:
        # Apply the variant to this conversation only; do not persist to DB.
        # Not all OpenHands versions expose `condenser_max_size` on settings.
        if hasattr(conversation_settings, 'condenser_max_size'):
            conversation_settings.condenser_max_size = condenser_max_size
            logger.info(
                'condenser_max_step_experiment:conversation_settings_applied',
                extra={
                    'user_id': user_id,
                    'convo_id': conversation_id,
                    'variant': enabled_variant,
                    'condenser_max_size': condenser_max_size,
                },
            )
        else:
            logger.warning(
                'condenser_max_step_experiment:field_missing_on_settings',
                extra={
                    'user_id': user_id,
                    'convo_id': conversation_id,
                    'variant': enabled_variant,
                    'reason': 'condenser_max_size not present on ConversationInitData',
                },
            )
    except Exception as e:
        logger.error(
            'condenser_max_step_experiment:apply_failed',
            extra={
                'user_id': user_id,
                'convo_id': conversation_id,
                'variant': enabled_variant,
                'error': str(e),
            },
        )
        return conversation_settings

    return conversation_settings
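For orientation, a minimal usage sketch of the handler above; the caller shape and the `settings` object are assumptions for illustration, not part of this diff:

# Apply the experiment to the in-memory settings before starting a conversation.
settings = handle_condenser_max_step_experiment(
    user_id='user-123',              # hypothetical IDs for illustration
    conversation_id='convo-abc',
    conversation_settings=settings,  # e.g. a ConversationInitData instance
)
# settings.condenser_max_size is now 120 (control) or 80 (treatment),
# or unchanged when the experiment is off for this user.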
25
enterprise/experiments/experiment_versions/__init__.py
Normal file
@@ -0,0 +1,25 @@
"""
Experiment versions package.

This package contains handlers for different experiment versions.
"""

from experiments.experiment_versions._001_litellm_default_model_experiment import (
    handle_litellm_default_model_experiment,
)
from experiments.experiment_versions._002_system_prompt_experiment import (
    handle_system_prompt_experiment,
)
from experiments.experiment_versions._003_llm_claude4_vs_gpt5_experiment import (
    handle_claude4_vs_gpt5_experiment,
)
from experiments.experiment_versions._004_condenser_max_step_experiment import (
    handle_condenser_max_step_experiment,
)

__all__ = [
    'handle_litellm_default_model_experiment',
    'handle_system_prompt_experiment',
    'handle_claude4_vs_gpt5_experiment',
    'handle_condenser_max_step_experiment',
]
0
enterprise/integrations/__init__.py
Normal file
70
enterprise/integrations/bitbucket/bitbucket_service.py
Normal file
@@ -0,0 +1,70 @@
from pydantic import SecretStr
from server.auth.token_manager import TokenManager

from openhands.core.logger import openhands_logger as logger
from openhands.integrations.bitbucket.bitbucket_service import BitBucketService
from openhands.integrations.service_types import ProviderType


class SaaSBitBucketService(BitBucketService):
    def __init__(
        self,
        user_id: str | None = None,
        external_auth_token: SecretStr | None = None,
        external_auth_id: str | None = None,
        token: SecretStr | None = None,
        external_token_manager: bool = False,
        base_domain: str | None = None,
    ):
        logger.info(
            f'SaaSBitBucketService created with user_id {user_id}, external_auth_id {external_auth_id}, external_auth_token {"set" if external_auth_token else "None"}, bitbucket_token {"set" if token else "None"}, external_token_manager {external_token_manager}'
        )
        super().__init__(
            user_id=user_id,
            external_auth_token=external_auth_token,
            external_auth_id=external_auth_id,
            token=token,
            external_token_manager=external_token_manager,
            base_domain=base_domain,
        )

        self.external_auth_token = external_auth_token
        self.external_auth_id = external_auth_id
        self.token_manager = TokenManager(external=external_token_manager)

    async def get_latest_token(self) -> SecretStr | None:
        bitbucket_token = None
        if self.external_auth_token:
            bitbucket_token = SecretStr(
                await self.token_manager.get_idp_token(
                    self.external_auth_token.get_secret_value(),
                    idp=ProviderType.BITBUCKET,
                )
            )
            logger.debug(
                f'Got BitBucket token {bitbucket_token} from access token: {self.external_auth_token}'
            )
        elif self.external_auth_id:
            offline_token = await self.token_manager.load_offline_token(
                self.external_auth_id
            )
            bitbucket_token = SecretStr(
                await self.token_manager.get_idp_token_from_offline_token(
                    offline_token, ProviderType.BITBUCKET
                )
            )
            # Log the masked SecretStr; never log the raw secret value.
            logger.info(
                f'Got BitBucket token {bitbucket_token} from external auth user ID: {self.external_auth_id}'
            )
        elif self.user_id:
            bitbucket_token = SecretStr(
                await self.token_manager.get_idp_token_from_idp_user_id(
                    self.user_id, ProviderType.BITBUCKET
                )
            )
            logger.debug(
                f'Got BitBucket token {bitbucket_token} from user ID: {self.user_id}'
            )
        else:
            logger.warning(
                'external_auth_token, external_auth_id and user_id not set!'
            )
        return bitbucket_token
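A sketch of the credential fallback order implemented in get_latest_token above; the inputs are hypothetical and the call must run inside an event loop:

# Resolution order: external_auth_token -> external_auth_id -> user_id.
svc = SaaSBitBucketService(user_id='42')  # hypothetical IdP user id, so branch 3 applies
token = await svc.get_latest_token()
print(token)  # SecretStr('**********') -- the raw value stays masked when printed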
692
enterprise/integrations/github/data_collector.py
Normal file
@@ -0,0 +1,692 @@
import base64
import json
import os
import re
from datetime import datetime
from enum import Enum
from typing import Any

from github import Github, GithubIntegration
from integrations.github.github_view import (
    GithubIssue,
)
from integrations.github.queries import PR_QUERY_BY_NODE_ID
from integrations.models import Message
from integrations.types import PRStatus, ResolverViewInterface
from integrations.utils import HOST
from pydantic import SecretStr
from server.auth.constants import GITHUB_APP_CLIENT_ID, GITHUB_APP_PRIVATE_KEY
from storage.openhands_pr import OpenhandsPR
from storage.openhands_pr_store import OpenhandsPRStore

from openhands.core.config import load_openhands_config
from openhands.core.logger import openhands_logger as logger
from openhands.integrations.github.github_service import GithubServiceImpl
from openhands.integrations.service_types import ProviderType
from openhands.storage import get_file_store
from openhands.storage.locations import get_conversation_dir

config = load_openhands_config()
file_store = get_file_store(config.file_store, config.file_store_path)


COLLECT_GITHUB_INTERACTIONS = (
    os.getenv('COLLECT_GITHUB_INTERACTIONS', 'false') == 'true'
)

class TriggerType(str, Enum):
    ISSUE_LABEL = 'issue-label'
    ISSUE_COMMENT = 'issue-comment'
    PR_COMMENT_MACRO = 'label'
    INLINE_PR_COMMENT_MACRO = 'inline-label'

class GitHubDataCollector:
    """
    Saves data on Cloud Resolver Interactions

    1. We always save
        - Resolver trigger (comment or label)
        - Metadata (who started the job, repo name, issue number)

    2. We save data for the type of interaction
        a. For labelled issues, we save
            - {conversation_dir}/{conversation_id}/github_data/issue__{repo_name}_{issue_number}.json
                - issue number
                - trigger
                - metadata
                - body
                - title
                - comments

            - {conversation_dir}/{conversation_id}/github_data/pr__{repo_name}_{pr_number}.json
                - pr_number
                - metadata
                - body
                - title
                - commits/authors

    3. For all PRs that were opened with the resolver, we save
        - github_data/prs/{repo_name}_{pr_number}/data.json
            - pr_number
            - title
            - body
            - commits/authors
            - code diffs
            - merge status (either merged/closed)
    """

    def __init__(self):
        self.file_store = file_store
        self.issues_path = 'github_data/issue-{}-{}/data.json'
        self.matching_pr_path = 'github_data/pr-{}-{}/data.json'
        # self.full_saved_pr_path = 'github_data/prs/{}-{}/data.json'
        self.full_saved_pr_path = 'prs/github/{}-{}/data.json'
        self.github_integration = GithubIntegration(
            GITHUB_APP_CLIENT_ID, GITHUB_APP_PRIVATE_KEY
        )
        self.conversation_id = None

    async def _get_repo_node_id(self, repo_id: str, gh_client) -> str:
        """
        Get the new GitHub GraphQL node ID for a repository using the GitHub client.

        Args:
            repo_id: Numeric repository ID as string (e.g., "123456789")
            gh_client: SaaSGitHubService client with authentication

        Returns:
            New format node ID for GraphQL queries (e.g., "R_kgDOLfkiww")
        """
        try:
            return await gh_client.get_repository_node_id(repo_id)
        except Exception:
            # Fallback to old format if REST API fails
            node_string = f'010:Repository{repo_id}'
            return base64.b64encode(node_string.encode()).decode()
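    # Illustration of the legacy fallback (hypothetical repo id): GitHub's old
    # node IDs are the base64 of '0{len}:Repository{databaseId}', e.g.
    #     base64.b64encode(b'010:Repository123456789').decode()
    #     == 'MDEwOlJlcG9zaXRvcnkxMjM0NTY3ODk='
    # GraphQL still resolves these legacy IDs for repositories that predate the
    # newer 'R_kgDO...' format.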
    def _create_file_name(
        self, path: str, repo_id: str, number: int, conversation_id: str | None
    ):
        suffix = path.format(repo_id, number)

        if conversation_id:
            return f'{get_conversation_dir(conversation_id)}{suffix}'

        return suffix

    def _get_installation_access_token(self, installation_id: str) -> str:
        token_data = self.github_integration.get_access_token(
            installation_id  # type: ignore[arg-type]
        )
        return token_data.token

    def _check_openhands_author(self, name, login) -> bool:
        return (
            name == 'openhands'
            or login == 'openhands'
            or login == 'openhands-agent'
            or login == 'openhands-ai'
            or login == 'openhands-staging'
            or login == 'openhands-exp'
            or (login and 'openhands' in login.lower())
        )
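    # Illustration (hypothetical values): _check_openhands_author('openhands', None)
    # and _check_openhands_author('', 'openhands-agent') are both True, while
    # _check_openhands_author('alice', 'alice') is False. The final clause also
    # matches any login merely containing 'openhands', e.g. 'my-openhands-bot'.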
    def _get_issue_comments(
        self, installation_id: str, repo_name: str, issue_number: int, conversation_id
    ) -> list[dict[str, Any]]:
        """
        Retrieve all comments from an issue until a comment with conversation_id is found
        """

        try:
            installation_token = self._get_installation_access_token(installation_id)

            with Github(installation_token) as github_client:
                repo = github_client.get_repo(repo_name)
                issue = repo.get_issue(issue_number)
                comments = []

                for comment in issue.get_comments():
                    comment_data = {
                        'id': comment.id,
                        'body': comment.body,
                        'created_at': comment.created_at.isoformat(),
                        'user': comment.user.login,
                    }

                    # If we find a comment containing conversation_id, stop collecting comments
                    if conversation_id in comment.body:
                        break

                    comments.append(comment_data)

                return comments
        except Exception:
            return []

    def _save_data(self, path: str, data: dict[str, Any]):
        """Save data to a path"""
        self.file_store.write(path, json.dumps(data))
    def _save_issue(
        self,
        github_view: GithubIssue,
        trigger_type: TriggerType,
    ) -> None:
        """
        Save issue data when it's labeled with openhands

        1. Save under {conversation_dir}/{conversation_id}/github_data/issue_{issue_number}.json
        2. Save issue snapshot (title, body, comments)
        3. Save trigger type (label)
        4. Save PR opened (if exists, this information comes later when agent has finished its task)
            - Save commit shas
            - Save author info
        5. Was PR merged or closed
        """

        conversation_id = github_view.conversation_id

        if not conversation_id:
            return

        issue_number = github_view.issue_number
        file_name = self._create_file_name(
            path=self.issues_path,
            repo_id=github_view.full_repo_name,
            number=issue_number,
            conversation_id=conversation_id,
        )

        payload_data = github_view.raw_payload.message.get('payload', {})
        issue_details = payload_data.get('issue', {})
        is_repo_private = payload_data.get('repository', {}).get('private', True)
        title = issue_details.get('title', '')
        body = issue_details.get('body', '')

        # Get comments for the issue
        comments = self._get_issue_comments(
            github_view.installation_id,
            github_view.full_repo_name,
            issue_number,
            conversation_id,
        )

        data = {
            'trigger': trigger_type,
            'metadata': {
                'user': github_view.user_info.username,
                'repo_name': github_view.full_repo_name,
                'is_repo_private': is_repo_private,
                'number': issue_number,
            },
            'contents': {
                'title': title,
                'body': body,
                'comments': comments,
            },
        }

        self._save_data(file_name, data)
        logger.info(
            f'[Github]: Saved issue #{issue_number} for {github_view.full_repo_name}'
        )
    def _get_pr_commits(self, installation_id: str, repo_name: str, pr_number: int):
        commits = []
        installation_token = self._get_installation_access_token(installation_id)
        with Github(installation_token) as github_client:
            repo = github_client.get_repo(repo_name)
            pr = repo.get_pull(pr_number)

            for commit in pr.get_commits():
                commit_data = {
                    'sha': commit.sha,
                    'authors': commit.author.login if commit.author else None,
                    'committed_date': commit.commit.committer.date.isoformat()
                    if commit.commit and commit.commit.committer
                    else None,
                }
                commits.append(commit_data)

        return commits

    def _extract_repo_metadata(self, repo_data: dict) -> dict:
        """Extract repository metadata from GraphQL response"""
        return {
            'name': repo_data.get('name'),
            'owner': repo_data.get('owner', {}).get('login'),
            'languages': [
                lang['name'] for lang in repo_data.get('languages', {}).get('nodes', [])
            ],
        }

    def _process_commits_page(self, pr_data: dict, commits: list) -> None:
        """Process commits from a single GraphQL page"""
        commit_nodes = pr_data.get('commits', {}).get('nodes', [])
        for commit_node in commit_nodes:
            commit = commit_node['commit']
            author_info = commit.get('author', {})
            commit_data = {
                'sha': commit['oid'],
                'message': commit['message'],
                'committed_date': commit.get('committedDate'),
                'author': {
                    'name': author_info.get('name'),
                    'email': author_info.get('email'),
                    'github_login': author_info.get('user', {}).get('login'),
                },
                'stats': {
                    'additions': commit.get('additions', 0),
                    'deletions': commit.get('deletions', 0),
                    'changed_files': commit.get('changedFiles', 0),
                },
            }
            commits.append(commit_data)

    def _process_pr_comments_page(self, pr_data: dict, pr_comments: list) -> None:
        """Process PR comments from a single GraphQL page"""
        comment_nodes = pr_data.get('comments', {}).get('nodes', [])
        for comment in comment_nodes:
            comment_data = {
                'author': comment.get('author', {}).get('login'),
                'body': comment.get('body'),
                'created_at': comment.get('createdAt'),
                'type': 'pr_comment',
            }
            pr_comments.append(comment_data)

    def _process_review_comments_page(
        self, pr_data: dict, review_comments: list
    ) -> None:
        """Process reviews and review comments from a single GraphQL page"""
        review_nodes = pr_data.get('reviews', {}).get('nodes', [])
        for review in review_nodes:
            # Add the review itself if it has a body
            if review.get('body', '').strip():
                review_data = {
                    'author': review.get('author', {}).get('login'),
                    'body': review.get('body'),
                    'created_at': review.get('createdAt'),
                    'state': review.get('state'),
                    'type': 'review',
                }
                review_comments.append(review_data)

            # Add individual review comments
            review_comment_nodes = review.get('comments', {}).get('nodes', [])
            for review_comment in review_comment_nodes:
                review_comment_data = {
                    'author': review_comment.get('author', {}).get('login'),
                    'body': review_comment.get('body'),
                    'created_at': review_comment.get('createdAt'),
                    'type': 'review_comment',
                }
                review_comments.append(review_comment_data)

    def _count_openhands_activity(
        self, commits: list, review_comments: list, pr_comments: list
    ) -> tuple[int, int, int]:
        """Count OpenHands commits, review comments, and general PR comments"""
        openhands_commit_count = 0
        openhands_review_comment_count = 0
        openhands_general_comment_count = 0

        # Count commits by OpenHands (check both name and login)
        for commit in commits:
            author = commit.get('author', {})
            author_name = author.get('name', '').lower()
            author_login = (
                author.get('github_login', '').lower()
                if author.get('github_login')
                else ''
            )

            if self._check_openhands_author(author_name, author_login):
                openhands_commit_count += 1

        # Count review comments by OpenHands
        for review_comment in review_comments:
            author_login = (
                review_comment.get('author', '').lower()
                if review_comment.get('author')
                else ''
            )
            author_name = ''  # Initialize to avoid reference before assignment
            if self._check_openhands_author(author_name, author_login):
                openhands_review_comment_count += 1

        # Count general PR comments by OpenHands
        for pr_comment in pr_comments:
            author_login = (
                pr_comment.get('author', '').lower() if pr_comment.get('author') else ''
            )
            author_name = ''  # Initialize to avoid reference before assignment
            if self._check_openhands_author(author_name, author_login):
                openhands_general_comment_count += 1

        return (
            openhands_commit_count,
            openhands_review_comment_count,
            openhands_general_comment_count,
        )
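    # Illustration (hypothetical data): one commit authored by 'openhands-agent'
    # plus two review comments by 'alice' yields (1, 0, 0) -- only the commit
    # matches an OpenHands author, so openhands_stats.helped_author ends up True.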
    def _build_final_data_structure(
        self,
        repo_data: dict,
        pr_data: dict,
        commits: list,
        pr_comments: list,
        review_comments: list,
        openhands_commit_count: int,
        openhands_review_comment_count: int,
        openhands_general_comment_count: int = 0,
    ) -> dict:
        """Build the final data structure for JSON storage"""

        is_merged = pr_data['merged']
        merged_by = None
        merge_commit_sha = None
        if is_merged:
            merged_by = pr_data.get('mergedBy', {}).get('login')
            merge_commit_sha = pr_data.get('mergeCommit', {}).get('oid')

        return {
            'repo_metadata': self._extract_repo_metadata(repo_data),
            'pr_metadata': {
                'username': pr_data.get('author', {}).get('login'),
                'number': pr_data['number'],
                'title': pr_data['title'],
                'body': pr_data['body'],
                'comments': pr_comments,
            },
            'commits': commits,
            'review_comments': review_comments,
            'merge_status': {
                'merged': pr_data['merged'],
                'merged_by': merged_by,
                'state': pr_data['state'],
                'merge_commit_sha': merge_commit_sha,
            },
            'openhands_stats': {
                'num_commits': openhands_commit_count,
                'num_review_comments': openhands_review_comment_count,
                'num_general_comments': openhands_general_comment_count,
                'helped_author': openhands_commit_count > 0,
            },
        }
    async def save_full_pr(self, openhands_pr: OpenhandsPR) -> None:
        """
        Save PR information including metadata and commit details using GraphQL

        Saves:
        - Repo metadata (repo name, languages, contributors)
        - PR metadata (number, title, body, author, comments)
        - Commit information (sha, authors, message, stats)
        - Merge status
        - Num openhands commits
        - Num openhands review comments
        """
        pr_number = openhands_pr.pr_number
        installation_id = openhands_pr.installation_id
        repo_id = openhands_pr.repo_id

        # Get installation token and create Github client
        # This will fail if the user decides to revoke OpenHands' access to their repo
        # In this case, we will simply return when the exception occurs
        # This will not lead to infinite loops when processing PRs as we log number of attempts and cap max attempts independently from this
        try:
            installation_token = self._get_installation_access_token(installation_id)
        except Exception as e:
            logger.warning(
                f'Failed to generate token for {openhands_pr.repo_name}: {e}'
            )
            return

        gh_client = GithubServiceImpl(token=SecretStr(installation_token))

        # Get the new format GraphQL node ID
        node_id = await self._get_repo_node_id(repo_id, gh_client)

        # Initialize data structures
        commits: list[dict] = []
        pr_comments: list[dict] = []
        review_comments: list[dict] = []
        pr_data = None
        repo_data = None

        # Pagination cursors
        commits_after = None
        comments_after = None
        reviews_after = None

        # Fetch all data with pagination
        while True:
            variables = {
                'nodeId': node_id,
                'pr_number': pr_number,
                'commits_after': commits_after,
                'comments_after': comments_after,
                'reviews_after': reviews_after,
            }

            try:
                result = await gh_client.execute_graphql_query(
                    PR_QUERY_BY_NODE_ID, variables
                )
                if not result.get('data', {}).get('node', {}).get('pullRequest'):
                    break

                pr_data = result['data']['node']['pullRequest']
                repo_data = result['data']['node']

                # Process data from this page using modular methods
                self._process_commits_page(pr_data, commits)
                self._process_pr_comments_page(pr_data, pr_comments)
                self._process_review_comments_page(pr_data, review_comments)

                # Check pagination for all three types
                has_more_commits = (
                    pr_data.get('commits', {})
                    .get('pageInfo', {})
                    .get('hasNextPage', False)
                )
                has_more_comments = (
                    pr_data.get('comments', {})
                    .get('pageInfo', {})
                    .get('hasNextPage', False)
                )
                has_more_reviews = (
                    pr_data.get('reviews', {})
                    .get('pageInfo', {})
                    .get('hasNextPage', False)
                )

                # Update cursors
                if has_more_commits:
                    commits_after = (
                        pr_data.get('commits', {}).get('pageInfo', {}).get('endCursor')
                    )
                else:
                    commits_after = None

                if has_more_comments:
                    comments_after = (
                        pr_data.get('comments', {}).get('pageInfo', {}).get('endCursor')
                    )
                else:
                    comments_after = None

                if has_more_reviews:
                    reviews_after = (
                        pr_data.get('reviews', {}).get('pageInfo', {}).get('endCursor')
                    )
                else:
                    reviews_after = None

                # Continue if there's more data to fetch
                if not (has_more_commits or has_more_comments or has_more_reviews):
                    break

            except Exception:
                logger.warning('Error fetching PR data', exc_info=True)
                return

        if not pr_data or not repo_data:
            return

        # Count OpenHands activity using modular method
        (
            openhands_commit_count,
            openhands_review_comment_count,
            openhands_general_comment_count,
        ) = self._count_openhands_activity(commits, review_comments, pr_comments)

        logger.info(
            f'[Github]: PR #{pr_number} - OpenHands commits: {openhands_commit_count}, review comments: {openhands_review_comment_count}, general comments: {openhands_general_comment_count}'
        )
        logger.info(
            f'[Github]: PR #{pr_number} - Total collected: {len(commits)} commits, {len(pr_comments)} PR comments, {len(review_comments)} review comments'
        )

        # Build final data structure using modular method
        data = self._build_final_data_structure(
            repo_data,
            pr_data or {},
            commits,
            pr_comments,
            review_comments,
            openhands_commit_count,
            openhands_review_comment_count,
            openhands_general_comment_count,
        )

        # Update the OpenhandsPR object with OpenHands statistics
        store = OpenhandsPRStore.get_instance()
        openhands_helped_author = openhands_commit_count > 0

        # Update the PR with OpenHands statistics
        update_success = store.update_pr_openhands_stats(
            repo_id=repo_id,
            pr_number=pr_number,
            original_updated_at=openhands_pr.updated_at,
            openhands_helped_author=openhands_helped_author,
            num_openhands_commits=openhands_commit_count,
            num_openhands_review_comments=openhands_review_comment_count,
            num_openhands_general_comments=openhands_general_comment_count,
        )

        if not update_success:
            logger.warning(
                f'[Github]: Failed to update OpenHands stats for PR #{pr_number} in repo {repo_id} - PR may have been modified concurrently'
            )

        # Save to file
        file_name = self._create_file_name(
            path=self.full_saved_pr_path,
            repo_id=repo_id,
            number=pr_number,
            conversation_id=None,
        )
        self._save_data(file_name, data)
        logger.info(
            f'[Github]: Saved full PR #{pr_number} for repo {repo_id} with OpenHands stats: commits={openhands_commit_count}, reviews={openhands_review_comment_count}, general_comments={openhands_general_comment_count}, helped={openhands_helped_author}'
        )
    def _check_for_conversation_url(self, body):
        conversation_pattern = re.search(
            rf'https://{HOST}/conversations/([a-zA-Z0-9-]+)(?:\s|[.,;!?)]|$)', body
        )
        if conversation_pattern:
            return conversation_pattern.group(1)

        return None
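    # Illustration (hypothetical host): with HOST = 'app.all-hands.dev', a body
    # like 'track it at https://app.all-hands.dev/conversations/abc-123.' yields
    # 'abc-123'; bodies without a conversation URL return None.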
    def _is_pr_closed_or_merged(self, payload):
        """
        Check if PR was closed (regardless of conversation URL)
        """
        action = payload.get('action', '')
        return action == 'closed' and 'pull_request' in payload

    def _track_closed_or_merged_pr(self, payload):
        """
        Track PR closed/merged event
        """

        repo_id = str(payload['repository']['id'])
        pr_number = payload['number']
        installation_id = str(payload['installation']['id'])
        private = payload['repository']['private']
        repo_name = payload['repository']['full_name']

        pr_data = payload['pull_request']

        # Extract PR metrics
        num_reviewers = len(pr_data.get('requested_reviewers', []))
        num_commits = pr_data.get('commits', 0)
        num_review_comments = pr_data.get('review_comments', 0)
        num_general_comments = pr_data.get('comments', 0)
        num_changed_files = pr_data.get('changed_files', 0)
        num_additions = pr_data.get('additions', 0)
        num_deletions = pr_data.get('deletions', 0)
        merged = pr_data.get('merged', False)

        # Extract closed_at timestamp
        # Example: "closed_at":"2025-06-19T21:19:36Z"
        closed_at_str = pr_data.get('closed_at')
        created_at = pr_data.get('created_at')

        closed_at = datetime.fromisoformat(closed_at_str.replace('Z', '+00:00'))

        # Determine status based on whether it was merged
        status = PRStatus.MERGED if merged else PRStatus.CLOSED

        store = OpenhandsPRStore.get_instance()

        pr = OpenhandsPR(
            repo_name=repo_name,
            repo_id=repo_id,
            pr_number=pr_number,
            status=status,
            provider=ProviderType.GITHUB.value,
            installation_id=installation_id,
            private=private,
            num_reviewers=num_reviewers,
            num_commits=num_commits,
            num_review_comments=num_review_comments,
            num_changed_files=num_changed_files,
            num_additions=num_additions,
            num_deletions=num_deletions,
            merged=merged,
            created_at=created_at,
            closed_at=closed_at,
            # These properties will be enriched later
            openhands_helped_author=None,
            num_openhands_commits=None,
            num_openhands_review_comments=None,
            num_general_comments=num_general_comments,
        )

        store.insert_pr(pr)
        logger.info(f'Tracked PR {status}: {repo_id}#{pr_number}')

    def process_payload(self, message: Message):
        if not COLLECT_GITHUB_INTERACTIONS:
            return

        raw_payload = message.message.get('payload', {})

        if self._is_pr_closed_or_merged(raw_payload):
            self._track_closed_or_merged_pr(raw_payload)

    async def save_data(self, github_view: ResolverViewInterface):
        if not COLLECT_GITHUB_INTERACTIONS:
            return

        return

        # TODO: track issue metadata in DB and save comments to filestore
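A minimal wiring sketch for the collector above; the webhook payload and the Message construction are assumptions for illustration, not part of this diff:

# No-op unless COLLECT_GITHUB_INTERACTIONS=true in the environment.
collector = GitHubDataCollector()
message = Message(source=SourceType.GITHUB, message={'payload': webhook_payload})  # hypothetical payload
collector.process_payload(message)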
344
enterprise/integrations/github/github_manager.py
Normal file
@@ -0,0 +1,344 @@
from types import MappingProxyType

from github import Github, GithubIntegration
from integrations.github.data_collector import GitHubDataCollector
from integrations.github.github_solvability import summarize_issue_solvability
from integrations.github.github_view import (
    GithubFactory,
    GithubFailingAction,
    GithubInlinePRComment,
    GithubIssue,
    GithubIssueComment,
    GithubPRComment,
)
from integrations.manager import Manager
from integrations.models import (
    Message,
    SourceType,
)
from integrations.types import ResolverViewInterface
from integrations.utils import (
    CONVERSATION_URL,
    HOST_URL,
    OPENHANDS_RESOLVER_TEMPLATES_DIR,
)
from jinja2 import Environment, FileSystemLoader
from pydantic import SecretStr
from server.auth.constants import GITHUB_APP_CLIENT_ID, GITHUB_APP_PRIVATE_KEY
from server.auth.token_manager import TokenManager
from server.utils.conversation_callback_utils import register_callback_processor

from openhands.core.logger import openhands_logger as logger
from openhands.integrations.provider import ProviderToken, ProviderType
from openhands.server.types import LLMAuthenticationError, MissingSettingsError
from openhands.storage.data_models.user_secrets import UserSecrets
from openhands.utils.async_utils import call_sync_from_async


class GithubManager(Manager):
    def __init__(
        self, token_manager: TokenManager, data_collector: GitHubDataCollector
    ):
        self.token_manager = token_manager
        self.data_collector = data_collector
        self.github_integration = GithubIntegration(
            GITHUB_APP_CLIENT_ID, GITHUB_APP_PRIVATE_KEY
        )

        self.jinja_env = Environment(
            loader=FileSystemLoader(OPENHANDS_RESOLVER_TEMPLATES_DIR + 'github')
        )

    def _confirm_incoming_source_type(self, message: Message):
        if message.source != SourceType.GITHUB:
            raise ValueError(f'Unexpected message source {message.source}')

    def _get_full_repo_name(self, repo_obj: dict) -> str:
        owner = repo_obj['owner']['login']
        repo_name = repo_obj['name']

        return f'{owner}/{repo_name}'

    def _get_installation_access_token(self, installation_id: str) -> str:
        # get_access_token is typed to only accept int, but it can handle str.
        token_data = self.github_integration.get_access_token(
            installation_id  # type: ignore[arg-type]
        )
        return token_data.token
    def _add_reaction(
        self, github_view: ResolverViewInterface, reaction: str, installation_token: str
    ):
        """Add a reaction to the GitHub issue, PR, or comment.

        Args:
            github_view: The GitHub view object containing issue/PR/comment info
            reaction: The reaction to add (e.g. "eyes", "+1", "-1", "laugh", "confused", "heart", "hooray", "rocket")
            installation_token: GitHub installation access token for API access
        """
        with Github(installation_token) as github_client:
            repo = github_client.get_repo(github_view.full_repo_name)
            # Add reaction based on view type
            if isinstance(github_view, GithubInlinePRComment):
                pr = repo.get_pull(github_view.issue_number)
                inline_comment = pr.get_review_comment(github_view.comment_id)
                inline_comment.create_reaction(reaction)

            elif isinstance(github_view, (GithubIssueComment, GithubPRComment)):
                issue = repo.get_issue(github_view.issue_number)
                comment = issue.get_comment(github_view.comment_id)
                comment.create_reaction(reaction)
            else:
                issue = repo.get_issue(github_view.issue_number)
                issue.create_reaction(reaction)
    def _user_has_write_access_to_repo(
        self, installation_id: str, full_repo_name: str, username: str
    ) -> bool:
        """Check if the user is an owner, collaborator, or member of the repository."""
        with self.github_integration.get_github_for_installation(
            installation_id,  # type: ignore[arg-type]
            {},
        ) as repos:
            repository = repos.get_repo(full_repo_name)

            # Check if the user is a collaborator
            try:
                collaborator = repository.get_collaborator_permission(username)
                if collaborator in ['admin', 'write']:
                    return True
            except Exception:
                pass

            # If the above fails, check if the user is an owner or member
            org = repository.organization
            if org:
                try:
                    # has_in_members performs the actual membership check;
                    # get_members only lists members and cannot filter by login.
                    return org.has_in_members(repos.get_user(username))
                except Exception:
                    return False

            return False
    async def is_job_requested(self, message: Message) -> bool:
        self._confirm_incoming_source_type(message)

        installation_id = message.message['installation']
        payload = message.message.get('payload', {})
        repo_obj = payload.get('repository')
        if not repo_obj:
            return False
        username = payload.get('sender', {}).get('login')
        repo_name = self._get_full_repo_name(repo_obj)

        # Suggestions contain `@openhands` macro; avoid kicking off jobs for system recommendations
        if GithubFactory.is_pr_comment(
            message
        ) and GithubFailingAction.unqiue_suggestions_header in payload.get(
            'comment', {}
        ).get('body', ''):
            return False

        if GithubFactory.is_eligible_for_conversation_starter(
            message
        ) and self._user_has_write_access_to_repo(installation_id, repo_name, username):
            await GithubFactory.trigger_conversation_starter(message)

        if not (
            GithubFactory.is_labeled_issue(message)
            or GithubFactory.is_issue_comment(message)
            or GithubFactory.is_pr_comment(message)
            or GithubFactory.is_inline_pr_comment(message)
        ):
            return False

        logger.info(f'[GitHub] Checking permissions for {username} in {repo_name}')

        return self._user_has_write_access_to_repo(installation_id, repo_name, username)
    async def receive_message(self, message: Message):
        self._confirm_incoming_source_type(message)
        try:
            await call_sync_from_async(self.data_collector.process_payload, message)
        except Exception:
            logger.warning(
                '[Github]: Error processing payload for gh interaction', exc_info=True
            )

        if await self.is_job_requested(message):
            github_view = await GithubFactory.create_github_view_from_payload(
                message, self.token_manager
            )
            logger.info(
                f'[GitHub] Creating job for {github_view.user_info.username} in {github_view.full_repo_name}#{github_view.issue_number}'
            )
            # Get the installation token
            installation_token = self._get_installation_access_token(
                github_view.installation_id
            )
            # Store the installation token
            self.token_manager.store_org_token(
                github_view.installation_id, installation_token
            )
            # Add eyes reaction to acknowledge we've read the request
            self._add_reaction(github_view, 'eyes', installation_token)
            await self.start_job(github_view)
    async def send_message(self, message: Message, github_view: ResolverViewInterface):
        installation_token = self.token_manager.load_org_token(
            github_view.installation_id
        )
        if not installation_token:
            logger.warning('Missing installation token')
            return

        outgoing_message = message.message

        if isinstance(github_view, GithubInlinePRComment):
            with Github(installation_token) as github_client:
                repo = github_client.get_repo(github_view.full_repo_name)
                pr = repo.get_pull(github_view.issue_number)
                pr.create_review_comment_reply(
                    comment_id=github_view.comment_id, body=outgoing_message
                )

        elif (
            isinstance(github_view, GithubPRComment)
            or isinstance(github_view, GithubIssueComment)
            or isinstance(github_view, GithubIssue)
        ):
            with Github(installation_token) as github_client:
                repo = github_client.get_repo(github_view.full_repo_name)
                issue = repo.get_issue(number=github_view.issue_number)
                issue.create_comment(outgoing_message)

        else:
            logger.warning('Unsupported location')
            return
    async def start_job(self, github_view: ResolverViewInterface):
        """Kick off a job with openhands agent.

        1. Get user credential
        2. Initialize new conversation with repo
        3. Save interaction data
        """
        # Importing here prevents circular import
        from server.conversation_callback_processor.github_callback_processor import (
            GithubCallbackProcessor,
        )

        try:
            msg_info = None

            try:
                user_info = github_view.user_info
                logger.info(
                    f'[GitHub] Starting job for user {user_info.username} (id={user_info.user_id})'
                )

                # Create conversation
                user_token = await self.token_manager.get_idp_token_from_idp_user_id(
                    str(user_info.user_id), ProviderType.GITHUB
                )

                if not user_token:
                    logger.warning(
                        f'[GitHub] No token found for user {user_info.username} (id={user_info.user_id})'
                    )
                    raise MissingSettingsError('Missing settings')

                logger.info(
                    f'[GitHub] Creating new conversation for user {user_info.username}'
                )

                secret_store = UserSecrets(
                    provider_tokens=MappingProxyType(
                        {
                            ProviderType.GITHUB: ProviderToken(
                                token=SecretStr(user_token),
                                user_id=str(user_info.user_id),
                            )
                        }
                    )
                )

                # We first initialize a conversation and generate the solvability report BEFORE starting the conversation runtime.
                # This helps us accumulate llm spend without requiring a running runtime. This sets us up for:
                # 1. If there is a problem starting the runtime we still have the accumulated total conversation cost
                # 2. In the future, based on the report confidence we can conditionally start the conversation
                # 3. Once the conversation is started, its base cost will include the report's spend as well, which allows us to control max budget per resolver task
                convo_metadata = await github_view.initialize_new_conversation()
                solvability_summary = None
                try:
                    if user_token:
                        solvability_summary = await summarize_issue_solvability(
                            github_view, user_token
                        )
                    else:
                        logger.warning(
                            '[Github]: No user token available for solvability analysis'
                        )
                except Exception as e:
                    logger.warning(
                        f'[Github]: Error summarizing issue solvability: {str(e)}'
                    )

                await github_view.create_new_conversation(
                    self.jinja_env, secret_store.provider_tokens, convo_metadata
                )

                conversation_id = github_view.conversation_id

                logger.info(
                    f'[GitHub] Created conversation {conversation_id} for user {user_info.username}'
                )

                # Create a GithubCallbackProcessor
                processor = GithubCallbackProcessor(
                    github_view=github_view,
                    send_summary_instruction=True,
                )

                # Register the callback processor
                register_callback_processor(conversation_id, processor)

                logger.info(
                    f'[Github] Registered callback processor for conversation {conversation_id}'
                )

                # Send message with conversation link
                conversation_link = CONVERSATION_URL.format(conversation_id)
                base_msg = f"I'm on it! {user_info.username} can [track my progress at all-hands.dev]({conversation_link})"
                # Combine messages: include solvability report with "I'm on it!" if successful
                if solvability_summary:
                    msg_info = f'{base_msg}\n\n{solvability_summary}'
                else:
                    msg_info = base_msg

            except MissingSettingsError as e:
                logger.warning(
                    f'[GitHub] Missing settings error for user {user_info.username}: {str(e)}'
                )

                msg_info = f'@{user_info.username} please re-login into [OpenHands Cloud]({HOST_URL}) before starting a job.'

            except LLMAuthenticationError as e:
                logger.warning(
                    f'[GitHub] LLM authentication error for user {user_info.username}: {str(e)}'
                )

                msg_info = f'@{user_info.username} please set a valid LLM API key in [OpenHands Cloud]({HOST_URL}) before starting a job.'

            msg = self.create_outgoing_message(msg_info)
            await self.send_message(msg, github_view)

        except Exception:
            logger.exception('[Github]: Error starting job')
            msg = self.create_outgoing_message(
                msg='Uh oh! There was an unexpected error starting the job :('
            )
            await self.send_message(msg, github_view)

        try:
            await self.data_collector.save_data(github_view)
        except Exception:
            logger.warning('[Github]: Error saving interaction data', exc_info=True)
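A minimal wiring sketch for the manager above; the constructor arguments and the incoming message are assumptions for illustration:

# Routes a webhook through the manager: it validates the source, records the
# interaction, checks write access, reacts with 'eyes', and starts the job.
manager = GithubManager(TokenManager(), GitHubDataCollector())
await manager.receive_message(incoming_message)  # Message with source=SourceType.GITHUB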
143
enterprise/integrations/github/github_service.py
Normal file
@@ -0,0 +1,143 @@
import asyncio

from integrations.utils import store_repositories_in_db
from pydantic import SecretStr
from server.auth.token_manager import TokenManager

from openhands.core.logger import openhands_logger as logger
from openhands.integrations.github.github_service import GitHubService
from openhands.integrations.service_types import ProviderType, Repository
from openhands.server.types import AppMode


class SaaSGitHubService(GitHubService):
    def __init__(
        self,
        user_id: str | None = None,
        external_auth_token: SecretStr | None = None,
        external_auth_id: str | None = None,
        token: SecretStr | None = None,
        external_token_manager: bool = False,
        base_domain: str | None = None,
    ):
        logger.debug(
            f'SaaSGitHubService created with user_id {user_id}, external_auth_id {external_auth_id}, external_auth_token {"set" if external_auth_token else "None"}, github_token {"set" if token else "None"}, external_token_manager {external_token_manager}'
        )
        super().__init__(
            user_id=user_id,
            external_auth_token=external_auth_token,
            external_auth_id=external_auth_id,
            token=token,
            external_token_manager=external_token_manager,
            base_domain=base_domain,
        )

        self.external_auth_token = external_auth_token
        self.external_auth_id = external_auth_id
        self.token_manager = TokenManager(external=external_token_manager)

    async def get_latest_token(self) -> SecretStr | None:
        github_token = None
        if self.external_auth_token:
            github_token = SecretStr(
                await self.token_manager.get_idp_token(
                    self.external_auth_token.get_secret_value(), ProviderType.GITHUB
                )
            )
            logger.debug(
                f'Got GitHub token {github_token} from access token: {self.external_auth_token}'
            )
        elif self.external_auth_id:
            offline_token = await self.token_manager.load_offline_token(
                self.external_auth_id
            )
            github_token = SecretStr(
                await self.token_manager.get_idp_token_from_offline_token(
                    offline_token, ProviderType.GITHUB
                )
            )
            logger.debug(
                f'Got GitHub token {github_token} from external auth user ID: {self.external_auth_id}'
            )
        elif self.user_id:
            github_token = SecretStr(
                await self.token_manager.get_idp_token_from_idp_user_id(
                    self.user_id, ProviderType.GITHUB
                )
            )
            logger.debug(
                f'Got GitHub token {github_token} from user ID: {self.user_id}'
            )
        else:
            logger.warning(
                'external_auth_token, external_auth_id and user_id not set!'
            )
        return github_token
    async def get_pr_patches(
        self, owner: str, repo: str, pr_number: int, per_page: int = 30, page: int = 1
    ):
        """Get patches for files changed in a PR with pagination support.

        Args:
            owner: Repository owner
            repo: Repository name
            pr_number: Pull request number
            per_page: Number of files per page (default: 30, max: 100)
            page: Page number to fetch (default: 1)
        """
        url = f'https://api.github.com/repos/{owner}/{repo}/pulls/{pr_number}/files'
        params = {'per_page': min(per_page, 100), 'page': page}  # GitHub max is 100
        response, headers = await self._make_request(url, params)

        # Parse pagination info from headers
        has_next_page = 'next' in headers.get('link', '')
        total_count = int(headers.get('total', 0))

        return {
            'files': response,
            'pagination': {
                'has_next_page': has_next_page,
                'total_count': total_count,
                'current_page': page,
                'per_page': per_page,
            },
        }
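    # Example pagination loop (hypothetical owner/repo), sketching intended use:
    #     page = 1
    #     while True:
    #         result = await service.get_pr_patches('octocat', 'hello-world', 42, per_page=100, page=page)
    #         ...  # process result['files']
    #         if not result['pagination']['has_next_page']:
    #             break
    #         page += 1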
    async def get_repository_node_id(self, repo_id: str) -> str:
        """
        Get the new GitHub GraphQL node ID for a repository using REST API.

        Args:
            repo_id: Numeric repository ID as string (e.g., "123456789")

        Returns:
            New format node ID for GraphQL queries (e.g., "R_kgDOLfkiww")

        Raises:
            Exception: If the API request fails or node_id is not found
        """
        url = f'https://api.github.com/repositories/{repo_id}'
        response, _ = await self._make_request(url)
        node_id = response.get('node_id')
        if not node_id:
            raise Exception(f'No node_id found for repository {repo_id}')
        return node_id

    async def get_paginated_repos(self, page, per_page, sort, installation_id):
        repositories = await super().get_paginated_repos(
            page, per_page, sort, installation_id
        )
        asyncio.create_task(
            store_repositories_in_db(repositories, self.external_auth_id)
        )
        return repositories

    async def get_all_repositories(
        self, sort: str, app_mode: AppMode
    ) -> list[Repository]:
        repositories = await super().get_all_repositories(sort, app_mode)
        # Schedule the background task without awaiting it
        asyncio.create_task(
            store_repositories_in_db(repositories, self.external_auth_id)
        )
        # Return repositories immediately
        return repositories
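The two repository listers above use a fire-and-forget pattern: asyncio.create_task schedules the DB write on the running event loop so the caller gets its repositories back immediately. A sketch of the trade-off:

# Exceptions inside a fire-and-forget task never reach the caller, so
# store_repositories_in_db should log its own failures.
asyncio.create_task(store_repositories_in_db(repositories, external_auth_id))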
183
enterprise/integrations/github/github_solvability.py
Normal file
@@ -0,0 +1,183 @@
import asyncio
import time

from github import Github
from integrations.github.github_view import (
    GithubInlinePRComment,
    GithubIssueComment,
    GithubPRComment,
    GithubViewType,
)
from integrations.solvability.data import load_classifier
from integrations.solvability.models.report import SolvabilityReport
from integrations.solvability.models.summary import SolvabilitySummary
from integrations.utils import ENABLE_SOLVABILITY_ANALYSIS
from pydantic import ValidationError
from server.auth.token_manager import get_config
from storage.database import session_maker
from storage.saas_settings_store import SaasSettingsStore

from openhands.core.config import LLMConfig
from openhands.core.logger import openhands_logger as logger
from openhands.utils.async_utils import call_sync_from_async
from openhands.utils.utils import create_registry_and_conversation_stats

def fetch_github_issue_context(
    github_view: GithubViewType,
    user_token: str,
) -> str:
    """Fetch full GitHub issue/PR context including title, body, and comments.

    Args:
        github_view: The GitHub view for the issue or PR, providing the title,
            description, repo name, issue number, and previous comments
        user_token: GitHub user access token

    Returns:
        A comprehensive string containing the issue/PR context
    """

    # Build context string
    context_parts = []

    # Add title and body
    context_parts.append(f'Title: {github_view.title}')
    context_parts.append(f'Description:\n{github_view.description}')

    with Github(user_token) as github_client:
        repo = github_client.get_repo(github_view.full_repo_name)
        issue = repo.get_issue(github_view.issue_number)
        if issue.labels:
            labels = [label.name for label in issue.labels]
            context_parts.append(f"Labels: {', '.join(labels)}")

    for comment in github_view.previous_comments:
        context_parts.append(f'- {comment.author}: {comment.body}')

    return '\n\n'.join(context_parts)

async def summarize_issue_solvability(
    github_view: GithubViewType,
    user_token: str,
    timeout: float = 60.0 * 5,
) -> str:
    """Generate a solvability summary for an issue using the resolver view interface.

    Args:
        github_view: A resolver view interface instance (e.g., GithubIssue, GithubPRComment)
        user_token: GitHub user access token for API access
        timeout: Maximum time in seconds to wait for the result (default: 300.0)

    Returns:
        The solvability summary as a string

    Raises:
        ValueError: If LLM settings cannot be found for the user
        asyncio.TimeoutError: If the operation exceeds the specified timeout
    """
    if not ENABLE_SOLVABILITY_ANALYSIS:
        raise ValueError('Solvability report feature is disabled')

    if github_view.user_info.keycloak_user_id is None:
        raise ValueError(
            f'[Solvability] No user ID found for user {github_view.user_info.username}'
        )

    # Grab the user's information so we can load their LLM configuration
    store = SaasSettingsStore(
        user_id=github_view.user_info.keycloak_user_id,
        session_maker=session_maker,
        config=get_config(),
    )

    user_settings = await store.load()

    if user_settings is None:
        raise ValueError(
            f'[Solvability] No user settings found for user ID {github_view.user_info.user_id}'
        )

    # Check if solvability analysis is enabled for this user, exit early if needed
    if not getattr(user_settings, 'enable_solvability_analysis', False):
        raise ValueError(
            f'Solvability analysis disabled for user {github_view.user_info.user_id}'
        )

    try:
        llm_config = LLMConfig(
            model=user_settings.llm_model,
            api_key=user_settings.llm_api_key.get_secret_value(),
            base_url=user_settings.llm_base_url,
        )
    except ValidationError as e:
        raise ValueError(
            f'[Solvability] Invalid LLM configuration for user {github_view.user_info.user_id}: {str(e)}'
        )

    # Fetch the full GitHub issue/PR context using the GitHub API
    start_time = time.time()
    issue_context = fetch_github_issue_context(github_view, user_token)
    logger.info(
        f'[Solvability] Grabbed issue context for {github_view.conversation_id}',
        extra={
            'conversation_id': github_view.conversation_id,
            'response_latency': time.time() - start_time,
            'full_repo_name': github_view.full_repo_name,
            'issue_number': github_view.issue_number,
        },
    )

    # For comment-based triggers, also include the specific comment that triggered the action
    if isinstance(
        github_view, (GithubIssueComment, GithubPRComment, GithubInlinePRComment)
    ):
        issue_context += f'\n\nTriggering Comment:\n{github_view.comment_body}'

    solvability_classifier = load_classifier('default-classifier')

    async with asyncio.timeout(timeout):
        solvability_report: SolvabilityReport = await call_sync_from_async(
            lambda: solvability_classifier.solvability_report(
                issue_context, llm_config=llm_config
            )
        )

        logger.info(
            f'[Solvability] Generated report for {github_view.conversation_id}',
            extra={
                'conversation_id': github_view.conversation_id,
                'report': solvability_report.model_dump(exclude=['issue']),
            },
        )

        llm_registry, conversation_stats, _ = create_registry_and_conversation_stats(
            get_config(),
            github_view.conversation_id,
            github_view.user_info.keycloak_user_id,
            None,
        )

        solvability_summary = await call_sync_from_async(
            lambda: SolvabilitySummary.from_report(
                solvability_report,
                llm=llm_registry.get_llm(
                    service_id='solvability_analysis', config=llm_config
                ),
            )
        )
        conversation_stats.save_metrics()

        logger.info(
            f'[Solvability] Generated summary for {github_view.conversation_id}',
            extra={
                'conversation_id': github_view.conversation_id,
                'summary': solvability_summary.model_dump(exclude=['content']),
            },
        )

        return solvability_summary.format_as_markdown()
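A minimal sketch of how a caller might drive this coroutine, assuming it already holds a populated view and token; the handler name and prints are illustrative, not part of this commit:

async def handle_resolver_event(github_view, user_token: str) -> None:
    # Illustrative caller; summarize_issue_solvability raises ValueError when
    # the feature is disabled or user/LLM settings are missing, and
    # asyncio.timeout() surfaces TimeoutError past the five-minute budget.
    try:
        summary_md = await summarize_issue_solvability(github_view, user_token)
        print(summary_md)  # markdown, ready to post back as an issue/PR comment
    except ValueError as e:
        print(f'Skipping solvability summary: {e}')
    except TimeoutError:
        print('Solvability analysis timed out')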
26
enterprise/integrations/github/github_types.py
Normal file
@@ -0,0 +1,26 @@
from enum import Enum

from pydantic import BaseModel


class WorkflowRunStatus(Enum):
    FAILURE = 'failure'
    COMPLETED = 'completed'
    PENDING = 'pending'

    def __eq__(self, other):
        # Allow direct comparison against the raw strings GitHub sends in
        # webhook payloads (and that pydantic stores when use_enum_values=True).
        if isinstance(other, str):
            return self.value == other
        return super().__eq__(other)


class WorkflowRun(BaseModel):
    id: str
    name: str
    status: WorkflowRunStatus

    model_config = {'use_enum_values': True}


class WorkflowRunGroup(BaseModel):
    runs: dict[str, WorkflowRun]
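A short sketch of the equality behavior the __eq__ override buys (standard pydantic/enum semantics; the sketch itself is not part of the commit):

from integrations.github.github_types import WorkflowRun, WorkflowRunStatus

run = WorkflowRun(id='1', name='ci', status=WorkflowRunStatus.FAILURE)

# use_enum_values=True makes pydantic store the raw value, so the field holds
# the plain string 'failure'...
assert run.status == 'failure'

# ...and the custom __eq__ lets that string still compare equal to the enum
# member, whichever side of == it lands on.
assert run.status == WorkflowRunStatus.FAILURE
assert WorkflowRunStatus.FAILURE == 'failure'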
756
enterprise/integrations/github/github_view.py
Normal file
@@ -0,0 +1,756 @@
from uuid import uuid4

from github import Github, GithubIntegration
from github.Issue import Issue
from integrations.github.github_types import (
    WorkflowRun,
    WorkflowRunGroup,
    WorkflowRunStatus,
)
from integrations.models import Message
from integrations.types import ResolverViewInterface, UserData
from integrations.utils import (
    ENABLE_PROACTIVE_CONVERSATION_STARTERS,
    HOST,
    HOST_URL,
    get_oh_labels,
    has_exact_mention,
)
from jinja2 import Environment
from pydantic.dataclasses import dataclass
from server.auth.constants import GITHUB_APP_CLIENT_ID, GITHUB_APP_PRIVATE_KEY
from server.auth.token_manager import TokenManager, get_config
from storage.database import session_maker
from storage.proactive_conversation_store import ProactiveConversationStore
from storage.saas_secrets_store import SaasSecretsStore
from storage.user_settings import UserSettings

from openhands.core.logger import openhands_logger as logger
from openhands.integrations.github.github_service import GithubServiceImpl
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE, ProviderType
from openhands.integrations.service_types import Comment
from openhands.server.services.conversation_service import (
    initialize_conversation,
    start_conversation,
)
from openhands.storage.data_models.conversation_metadata import (
    ConversationMetadata,
    ConversationTrigger,
)
from openhands.utils.async_utils import call_sync_from_async

OH_LABEL, INLINE_OH_LABEL = get_oh_labels(HOST)


async def get_user_proactive_conversation_setting(user_id: str | None) -> bool:
    """Get the user's proactive conversation setting.

    Args:
        user_id: The keycloak user ID

    Returns:
        True if proactive conversations are enabled for this user, False otherwise

    Note:
        The global ENABLE_PROACTIVE_CONVERSATION_STARTERS kill switch is checked
        separately (see GithubFactory.is_eligible_for_conversation_starter); both
        it and the user's individual setting must be enabled for a proactive
        conversation to start.
    """
    # If no user ID is provided, we can't check user settings
    if not user_id:
        return False

    def _get_setting():
        with session_maker() as session:
            settings = (
                session.query(UserSettings)
                .filter(UserSettings.keycloak_user_id == user_id)
                .first()
            )

            if not settings or settings.enable_proactive_conversation_starters is None:
                return False

            return settings.enable_proactive_conversation_starters

    return await call_sync_from_async(_get_setting)
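The DB lookup above is synchronous SQLAlchemy code, so it is wrapped with call_sync_from_async rather than run directly on the event loop. A rough standard-library equivalent of that pattern, as a sketch only:

import asyncio

def _blocking_db_read() -> bool:
    # Stand-in for the synchronous session_maker() query above.
    return False

async def caller() -> bool:
    # Run the blocking call in a worker thread so the event loop stays
    # responsive; call_sync_from_async in OpenHands serves the same purpose.
    return await asyncio.to_thread(_blocking_db_read)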
# =================================================
# SECTION: Github view types
# =================================================


@dataclass
class GithubIssue(ResolverViewInterface):
    issue_number: int
    installation_id: int
    full_repo_name: str
    is_public_repo: bool
    user_info: UserData
    raw_payload: Message
    conversation_id: str
    uuid: str | None
    should_extract: bool
    send_summary_instruction: bool
    title: str
    description: str
    previous_comments: list[Comment]

    async def _load_resolver_context(self):
        github_service = GithubServiceImpl(
            external_auth_id=self.user_info.keycloak_user_id
        )

        self.previous_comments = await github_service.get_issue_or_pr_comments(
            self.full_repo_name, self.issue_number
        )

        (
            self.title,
            self.description,
        ) = await github_service.get_issue_or_pr_title_and_body(
            self.full_repo_name, self.issue_number
        )

    async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
        user_instructions_template = jinja_env.get_template('issue_prompt.j2')

        user_instructions = user_instructions_template.render(
            issue_number=self.issue_number,
        )

        await self._load_resolver_context()

        conversation_instructions_template = jinja_env.get_template(
            'issue_conversation_instructions.j2'
        )
        conversation_instructions = conversation_instructions_template.render(
            issue_title=self.title,
            issue_body=self.description,
            previous_comments=self.previous_comments,
        )
        return user_instructions, conversation_instructions

    async def _get_user_secrets(self):
        secrets_store = SaasSecretsStore(
            self.user_info.keycloak_user_id, session_maker, get_config()
        )
        user_secrets = await secrets_store.load()

        return user_secrets.custom_secrets if user_secrets else None

    async def initialize_new_conversation(self) -> ConversationMetadata:
        # FIXME: Handle if initialize_conversation returns None
        conversation_metadata: ConversationMetadata = await initialize_conversation(  # type: ignore[assignment]
            user_id=self.user_info.keycloak_user_id,
            conversation_id=None,
            selected_repository=self.full_repo_name,
            selected_branch=None,
            conversation_trigger=ConversationTrigger.RESOLVER,
            git_provider=ProviderType.GITHUB,
        )
        self.conversation_id = conversation_metadata.conversation_id
        return conversation_metadata

    async def create_new_conversation(
        self,
        jinja_env: Environment,
        git_provider_tokens: PROVIDER_TOKEN_TYPE,
        conversation_metadata: ConversationMetadata,
    ):
        custom_secrets = await self._get_user_secrets()

        user_instructions, conversation_instructions = await self._get_instructions(
            jinja_env
        )

        await start_conversation(
            user_id=self.user_info.keycloak_user_id,
            git_provider_tokens=git_provider_tokens,
            custom_secrets=custom_secrets,
            initial_user_msg=user_instructions,
            image_urls=None,
            replay_json=None,
            conversation_id=conversation_metadata.conversation_id,
            conversation_metadata=conversation_metadata,
            conversation_instructions=conversation_instructions,
        )


@dataclass
class GithubIssueComment(GithubIssue):
    comment_body: str
    comment_id: int

    async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
        user_instructions_template = jinja_env.get_template('issue_prompt.j2')

        await self._load_resolver_context()

        user_instructions = user_instructions_template.render(
            issue_comment=self.comment_body
        )

        conversation_instructions_template = jinja_env.get_template(
            'issue_conversation_instructions.j2'
        )

        conversation_instructions = conversation_instructions_template.render(
            issue_number=self.issue_number,
            issue_title=self.title,
            issue_body=self.description,
            previous_comments=self.previous_comments,
        )

        return user_instructions, conversation_instructions


@dataclass
class GithubPRComment(GithubIssueComment):
    branch_name: str

    async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
        user_instructions_template = jinja_env.get_template('pr_update_prompt.j2')
        await self._load_resolver_context()

        user_instructions = user_instructions_template.render(
            pr_comment=self.comment_body,
        )

        conversation_instructions_template = jinja_env.get_template(
            'pr_update_conversation_instructions.j2'
        )
        conversation_instructions = conversation_instructions_template.render(
            pr_number=self.issue_number,
            branch_name=self.branch_name,
            pr_title=self.title,
            pr_body=self.description,
            comments=self.previous_comments,
        )

        return user_instructions, conversation_instructions

    async def initialize_new_conversation(self) -> ConversationMetadata:
        # FIXME: Handle if initialize_conversation returns None
        conversation_metadata: ConversationMetadata = await initialize_conversation(  # type: ignore[assignment]
            user_id=self.user_info.keycloak_user_id,
            conversation_id=None,
            selected_repository=self.full_repo_name,
            selected_branch=self.branch_name,
            conversation_trigger=ConversationTrigger.RESOLVER,
            git_provider=ProviderType.GITHUB,
        )

        self.conversation_id = conversation_metadata.conversation_id
        return conversation_metadata


@dataclass
class GithubInlinePRComment(GithubPRComment):
    file_location: str
    line_number: int
    comment_node_id: str

    async def _load_resolver_context(self):
        github_service = GithubServiceImpl(
            external_auth_id=self.user_info.keycloak_user_id
        )

        (
            self.title,
            self.description,
        ) = await github_service.get_issue_or_pr_title_and_body(
            self.full_repo_name, self.issue_number
        )

        self.previous_comments = await github_service.get_review_thread_comments(
            self.comment_node_id, self.full_repo_name, self.issue_number
        )

    async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
        user_instructions_template = jinja_env.get_template('pr_update_prompt.j2')
        await self._load_resolver_context()

        user_instructions = user_instructions_template.render(
            pr_comment=self.comment_body,
        )

        conversation_instructions_template = jinja_env.get_template(
            'pr_update_conversation_instructions.j2'
        )

        conversation_instructions = conversation_instructions_template.render(
            pr_number=self.issue_number,
            pr_title=self.title,
            pr_body=self.description,
            branch_name=self.branch_name,
            file_location=self.file_location,
            line_number=self.line_number,
            comments=self.previous_comments,
        )

        return user_instructions, conversation_instructions


@dataclass
class GithubFailingAction:
    # (sic) the misspelled attribute name is kept as-is; its string value doubles
    # as the marker delete_old_comment_if_exists scans for when removing stale
    # comments.
    unqiue_suggestions_header: str = (
        'Looks like there are a few issues preventing this PR from being merged!'
    )

    @staticmethod
    def get_latest_sha(pr: Issue) -> str:
        pr_obj = pr.as_pull_request()
        return pr_obj.head.sha

    @staticmethod
    def create_retrieve_workflows_callback(pr: Issue, head_sha: str):
        def get_all_workflows():
            repo = pr.repository
            workflows = repo.get_workflow_runs(head_sha=head_sha)

            runs = {}

            for workflow in workflows:
                conclusion = workflow.conclusion
                workflow_conclusion = WorkflowRunStatus.COMPLETED
                if conclusion is None:
                    workflow_conclusion = WorkflowRunStatus.PENDING  # type: ignore[unreachable]
                elif conclusion == WorkflowRunStatus.FAILURE.value:
                    workflow_conclusion = WorkflowRunStatus.FAILURE

                runs[str(workflow.id)] = WorkflowRun(
                    id=str(workflow.id), name=workflow.name, status=workflow_conclusion
                )

            return WorkflowRunGroup(runs=runs)

        return get_all_workflows

    @staticmethod
    def delete_old_comment_if_exists(pr: Issue):
        paginated_comments = pr.get_comments()
        for page in range(paginated_comments.totalCount):
            comments = paginated_comments.get_page(page)
            for comment in comments:
                if GithubFailingAction.unqiue_suggestions_header in comment.body:
                    comment.delete()

    @staticmethod
    def get_suggestions(
        failed_jobs: dict, pr_number: int, branch_name: str | None = None
    ) -> str:
        issues = []

        # Collect failing actions with their specific names
        if failed_jobs['actions']:
            failing_actions = failed_jobs['actions']
            issues.append(('GitHub Actions are failing:', False))
            for action in failing_actions:
                issues.append((action, True))

        if any(failed_jobs['merge conflict']):
            issues.append(('There are merge conflicts', False))

        # Format each line with proper indentation and dashes
        formatted_issues = []
        for issue, is_nested in issues:
            if is_nested:
                formatted_issues.append(f'  - {issue}')
            else:
                formatted_issues.append(f'- {issue}')
        issues_text = '\n'.join(formatted_issues)

        # Build list of possible suggestions based on actual issues
        suggestions = []
        branch_info = f' at branch `{branch_name}`' if branch_name else ''

        if any(failed_jobs['merge conflict']):
            suggestions.append(
                f'@OpenHands please fix the merge conflicts on PR #{pr_number}{branch_info}'
            )
        if any(failed_jobs['actions']):
            suggestions.append(
                f'@OpenHands please fix the failing actions on PR #{pr_number}{branch_info}'
            )

        # Take at most 2 suggestions
        suggestions = suggestions[:2]

        help_text = """If you'd like me to help, just leave a comment, like

```
{}
```

Feel free to include any additional details that might help me get this PR into a better state.

<sub><sup>You can manage your notification [settings]({})</sup></sub>""".format(
            '\n```\n\nor\n\n```\n'.join(suggestions), f'{HOST_URL}/settings/app'
        )

        return f'{GithubFailingAction.unqiue_suggestions_header}\n\n{issues_text}\n\n{help_text}'

    @staticmethod
    def leave_requesting_comment(pr: Issue, failed_runs: WorkflowRunGroup):
        failed_jobs: dict = {'actions': [], 'merge conflict': []}

        pr_obj = pr.as_pull_request()
        if not pr_obj.mergeable:
            failed_jobs['merge conflict'].append('Merge conflict detected')

        for _, workflow_run in failed_runs.runs.items():
            if workflow_run.status == WorkflowRunStatus.FAILURE:
                failed_jobs['actions'].append(workflow_run.name)

        logger.info(f'[GitHub] Found failing jobs for PR #{pr.number}: {failed_jobs}')

        # Get the branch name
        branch_name = pr_obj.head.ref

        # Get suggestions with branch name included
        suggestions = GithubFailingAction.get_suggestions(
            failed_jobs, pr.number, branch_name
        )

        GithubFailingAction.delete_old_comment_if_exists(pr)
        pr.create_comment(suggestions)


GithubViewType = (
    GithubInlinePRComment | GithubPRComment | GithubIssueComment | GithubIssue
)
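To make the generated comment concrete, a hedged example of driving get_suggestions directly; the job names, PR number, and branch are invented:

failed_jobs = {'actions': ['lint', 'unit-tests'], 'merge conflict': ['Merge conflict detected']}
body = GithubFailingAction.get_suggestions(failed_jobs, pr_number=42, branch_name='fix/config')
# body starts with the unqiue_suggestions_header sentinel, then the bulleted
# issue list, then up to two copy-pasteable '@OpenHands please fix ...'
# suggestions; the sentinel is what delete_old_comment_if_exists scans for.
print(body)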
# =================================================
# SECTION: Factory to create appropriate Github view
# =================================================


class GithubFactory:
    @staticmethod
    def is_labeled_issue(message: Message):
        payload = message.message.get('payload', {})
        action = payload.get('action', '')

        if action == 'labeled' and 'label' in payload and 'issue' in payload:
            label_name = payload['label'].get('name', '')
            if label_name == OH_LABEL:
                return True

        return False

    @staticmethod
    def is_issue_comment(message: Message):
        payload = message.message.get('payload', {})
        action = payload.get('action', '')

        if (
            action == 'created'
            and 'comment' in payload
            and 'issue' in payload
            and 'pull_request' not in payload['issue']
        ):
            comment_body = payload['comment']['body']
            if has_exact_mention(comment_body, INLINE_OH_LABEL):
                return True

        return False

    @staticmethod
    def is_pr_comment(message: Message):
        payload = message.message.get('payload', {})
        action = payload.get('action', '')

        if (
            action == 'created'
            and 'comment' in payload
            and 'issue' in payload
            and 'pull_request' in payload['issue']
        ):
            comment_body = payload['comment'].get('body', '')
            if has_exact_mention(comment_body, INLINE_OH_LABEL):
                return True

        return False

    @staticmethod
    def is_inline_pr_comment(message: Message):
        payload = message.message.get('payload', {})
        action = payload.get('action', '')

        if action == 'created' and 'comment' in payload and 'pull_request' in payload:
            comment_body = payload['comment'].get('body', '')
            if has_exact_mention(comment_body, INLINE_OH_LABEL):
                return True

        return False

    @staticmethod
    def is_eligible_for_conversation_starter(message: Message):
        if not ENABLE_PROACTIVE_CONVERSATION_STARTERS:
            return False

        payload = message.message.get('payload', {})
        action = payload.get('action', '')

        if not (action == 'completed' and 'workflow_run' in payload):
            return False

        return True

    @staticmethod
    async def trigger_conversation_starter(message: Message):
        """Trigger a conversation starter when a workflow fails.

        This is the updated version that checks user settings.
        """
        payload = message.message.get('payload', {})
        workflow_payload = payload['workflow_run']
        status = WorkflowRunStatus.COMPLETED

        if workflow_payload['conclusion'] == 'failure':
            status = WorkflowRunStatus.FAILURE
        elif workflow_payload['conclusion'] is None:
            status = WorkflowRunStatus.PENDING

        workflow_run = WorkflowRun(
            id=str(workflow_payload['id']), name=workflow_payload['name'], status=status
        )

        selected_repo = GithubFactory.get_full_repo_name(payload['repository'])
        head_branch = payload['workflow_run']['head_branch']

        # Get the user ID to check their settings
        user_id = None
        try:
            sender_id = payload['sender']['id']
            token_manager = TokenManager()
            user_id = await token_manager.get_user_id_from_idp_user_id(
                sender_id, ProviderType.GITHUB
            )
        except Exception as e:
            logger.warning(
                f'Failed to get user ID for proactive conversation check: {str(e)}'
            )

        # Check if proactive conversations are enabled for this user
        if not await get_user_proactive_conversation_setting(user_id):
            return False

        def _interact_with_github() -> Issue | None:
            with GithubIntegration(
                GITHUB_APP_CLIENT_ID, GITHUB_APP_PRIVATE_KEY
            ) as integration:
                access_token = integration.get_access_token(
                    payload['installation']['id']
                ).token

            with Github(access_token) as gh:
                repo = gh.get_repo(selected_repo)
                login = (
                    payload['organization']['login']
                    if 'organization' in payload
                    else payload['sender']['login']
                )

                # See if a pull request is open
                open_pulls = repo.get_pulls(state='open', head=f'{login}:{head_branch}')
                if open_pulls.totalCount > 0:
                    prs = open_pulls.get_page(0)
                    relevant_pr = prs[0]
                    issue = repo.get_issue(number=relevant_pr.number)
                    return issue

            return None

        issue: Issue | None = await call_sync_from_async(_interact_with_github)
        if not issue:
            return False

        incoming_commit = payload['workflow_run']['head_sha']
        latest_sha = GithubFailingAction.get_latest_sha(issue)
        if latest_sha != incoming_commit:
            # Return as this commit is not the latest
            return False

        convo_store = ProactiveConversationStore()
        workflow_group = await convo_store.store_workflow_information(
            provider=ProviderType.GITHUB,
            repo_id=payload['repository']['id'],
            incoming_commit=incoming_commit,
            workflow=workflow_run,
            pr_number=issue.number,
            get_all_workflows=GithubFailingAction.create_retrieve_workflows_callback(
                issue, incoming_commit
            ),
        )

        if not workflow_group:
            return False

        logger.info(
            f'[GitHub] Workflow completed for {selected_repo}#{issue.number} on branch {head_branch}'
        )
        GithubFailingAction.leave_requesting_comment(issue, workflow_group)

        return False

    @staticmethod
    def get_full_repo_name(repo_obj: dict) -> str:
        owner = repo_obj['owner']['login']
        repo_name = repo_obj['name']
        return f'{owner}/{repo_name}'

    @staticmethod
    async def create_github_view_from_payload(
        message: Message, token_manager: TokenManager
    ) -> ResolverViewInterface:
        """Create the appropriate view class (GithubIssue, GithubIssueComment,
        GithubPRComment, or GithubInlinePRComment) based on the payload.
        """
        payload = message.message.get('payload', {})
        repo_obj = payload['repository']
        user_id = payload['sender']['id']
        username = payload['sender']['login']

        keycloak_user_id = await token_manager.get_user_id_from_idp_user_id(
            user_id, ProviderType.GITHUB
        )

        if keycloak_user_id is None:
            logger.warning(f'Got invalid keycloak user id for GitHub User {user_id}')

        selected_repo = GithubFactory.get_full_repo_name(repo_obj)
        is_public_repo = not repo_obj.get('private', True)
        user_info = UserData(
            user_id=user_id, username=username, keycloak_user_id=keycloak_user_id
        )

        installation_id = message.message['installation']

        if GithubFactory.is_labeled_issue(message):
            issue_number = payload['issue']['number']
            logger.info(
                f'[GitHub] Creating view for labeled issue from {username} in {selected_repo}#{issue_number}'
            )
            return GithubIssue(
                issue_number=issue_number,
                installation_id=installation_id,
                full_repo_name=selected_repo,
                is_public_repo=is_public_repo,
                raw_payload=message,
                user_info=user_info,
                conversation_id='',
                uuid=str(uuid4()),
                should_extract=True,
                send_summary_instruction=True,
                title='',
                description='',
                previous_comments=[],
            )

        elif GithubFactory.is_issue_comment(message):
            issue_number = payload['issue']['number']
            comment_body = payload['comment']['body']
            comment_id = payload['comment']['id']
            logger.info(
                f'[GitHub] Creating view for issue comment from {username} in {selected_repo}#{issue_number}'
            )
            return GithubIssueComment(
                issue_number=issue_number,
                comment_body=comment_body,
                comment_id=comment_id,
                installation_id=installation_id,
                full_repo_name=selected_repo,
                is_public_repo=is_public_repo,
                raw_payload=message,
                user_info=user_info,
                conversation_id='',
                uuid=None,
                should_extract=True,
                send_summary_instruction=True,
                title='',
                description='',
                previous_comments=[],
            )

        elif GithubFactory.is_pr_comment(message):
            issue_number = payload['issue']['number']
            logger.info(
                f'[GitHub] Creating view for PR comment from {username} in {selected_repo}#{issue_number}'
            )

            access_token = ''
            with GithubIntegration(
                GITHUB_APP_CLIENT_ID, GITHUB_APP_PRIVATE_KEY
            ) as integration:
                access_token = integration.get_access_token(installation_id).token

            head_ref = None
            with Github(access_token) as gh:
                repo = gh.get_repo(selected_repo)
                pull_request = repo.get_pull(issue_number)
                head_ref = pull_request.head.ref
                logger.info(
                    f'[GitHub] Found PR branch {head_ref} for {selected_repo}#{issue_number}'
                )

            comment_id = payload['comment']['id']
            return GithubPRComment(
                issue_number=issue_number,
                branch_name=head_ref,
                comment_body=payload['comment']['body'],
                comment_id=comment_id,
                installation_id=installation_id,
                full_repo_name=selected_repo,
                is_public_repo=is_public_repo,
                raw_payload=message,
                user_info=user_info,
                conversation_id='',
                uuid=None,
                should_extract=True,
                send_summary_instruction=True,
                title='',
                description='',
                previous_comments=[],
            )

        elif GithubFactory.is_inline_pr_comment(message):
            pr_number = payload['pull_request']['number']
            branch_name = payload['pull_request']['head']['ref']
            comment_id = payload['comment']['id']
            comment_node_id = payload['comment']['node_id']
            file_path = payload['comment']['path']
            line_number = payload['comment']['line']
            logger.info(
                f'[GitHub] Creating view for inline PR comment from {username} in {selected_repo}#{pr_number} at {file_path}'
            )

            return GithubInlinePRComment(
                issue_number=pr_number,
                branch_name=branch_name,
                comment_body=payload['comment']['body'],
                comment_node_id=comment_node_id,
                comment_id=comment_id,
                file_location=file_path,
                line_number=line_number,
                installation_id=installation_id,
                full_repo_name=selected_repo,
                is_public_repo=is_public_repo,
                raw_payload=message,
                user_info=user_info,
                conversation_id='',
                uuid=None,
                should_extract=True,
                send_summary_instruction=True,
                title='',
                description='',
                previous_comments=[],
            )

        else:
            raise ValueError(
                "Invalid payload: must contain either 'issue' or 'pull_request'"
            )
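The mention check is what gates the three comment-based predicates above. A hedged local sanity check (the comment text is invented; the real mention label comes from get_oh_labels(HOST), not the literal below):

comment_body = '@openhands-agent please take a look'  # invented example text
if has_exact_mention(comment_body, INLINE_OH_LABEL):
    # In production this flows into is_issue_comment / is_pr_comment /
    # is_inline_pr_comment and then create_github_view_from_payload.
    print('comment would trigger the resolver')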
102
enterprise/integrations/github/queries.py
Normal file
@@ -0,0 +1,102 @@
PR_QUERY_BY_NODE_ID = """
query($nodeId: ID!, $pr_number: Int!, $commits_after: String, $comments_after: String, $reviews_after: String) {
  node(id: $nodeId) {
    ... on Repository {
      name
      owner {
        login
      }
      languages(first: 10, orderBy: {field: SIZE, direction: DESC}) {
        nodes {
          name
        }
      }
      pullRequest(number: $pr_number) {
        number
        title
        body
        author {
          login
        }
        merged
        mergedAt
        mergedBy {
          login
        }
        state
        mergeCommit {
          oid
        }
        comments(first: 50, after: $comments_after) {
          pageInfo {
            hasNextPage
            endCursor
          }
          nodes {
            author {
              login
            }
            body
            createdAt
          }
        }
        commits(first: 50, after: $commits_after) {
          pageInfo {
            hasNextPage
            endCursor
          }
          nodes {
            commit {
              oid
              message
              committedDate
              author {
                name
                email
                user {
                  login
                }
              }
              additions
              deletions
              changedFiles
            }
          }
        }
        reviews(first: 50, after: $reviews_after) {
          pageInfo {
            hasNextPage
            endCursor
          }
          nodes {
            author {
              login
            }
            body
            state
            createdAt
            comments(first: 50) {
              pageInfo {
                hasNextPage
                endCursor
              }
              nodes {
                author {
                  login
                }
                body
                createdAt
              }
            }
          }
        }
      }
    }
  }
  rateLimit {
    remaining
    limit
    resetAt
  }
}
"""
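A hedged sketch of executing this query against the standard GitHub GraphQL endpoint; the node id, PR number, and token are placeholders, and the helper name is invented:

import requests

def fetch_pr_page(token: str, node_id: str, pr_number: int, comments_after: str | None = None) -> dict:
    # POST the query with its variables; cursors default to None for page one.
    resp = requests.post(
        'https://api.github.com/graphql',
        json={
            'query': PR_QUERY_BY_NODE_ID,
            'variables': {
                'nodeId': node_id,
                'pr_number': pr_number,
                'comments_after': comments_after,
                'commits_after': None,
                'reviews_after': None,
            },
        },
        headers={'Authorization': f'Bearer {token}'},
        timeout=30,
    )
    resp.raise_for_status()
    return resp.json()

Each paginated connection exposes pageInfo { hasNextPage endCursor }; feeding endCursor back in as the matching *_after variable walks comments, commits, and reviews 50 at a time.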
249
enterprise/integrations/gitlab/gitlab_manager.py
Normal file
@@ -0,0 +1,249 @@
from types import MappingProxyType

from integrations.gitlab.gitlab_view import (
    GitlabFactory,
    GitlabInlineMRComment,
    GitlabIssue,
    GitlabIssueComment,
    GitlabMRComment,
    GitlabViewType,
)
from integrations.manager import Manager
from integrations.models import Message, SourceType
from integrations.types import ResolverViewInterface
from integrations.utils import (
    CONVERSATION_URL,
    HOST_URL,
    OPENHANDS_RESOLVER_TEMPLATES_DIR,
)
from jinja2 import Environment, FileSystemLoader
from pydantic import SecretStr
from server.auth.token_manager import TokenManager
from server.utils.conversation_callback_utils import register_callback_processor

from openhands.core.logger import openhands_logger as logger
from openhands.integrations.gitlab.gitlab_service import GitLabServiceImpl
from openhands.integrations.provider import ProviderToken, ProviderType
from openhands.server.types import LLMAuthenticationError, MissingSettingsError
from openhands.storage.data_models.user_secrets import UserSecrets


class GitlabManager(Manager):
    def __init__(self, token_manager: TokenManager, data_collector: None = None):
        self.token_manager = token_manager

        self.jinja_env = Environment(
            loader=FileSystemLoader(OPENHANDS_RESOLVER_TEMPLATES_DIR + 'gitlab')
        )

    def _confirm_incoming_source_type(self, message: Message):
        if message.source != SourceType.GITLAB:
            raise ValueError(f'Unexpected message source {message.source}')

    async def _user_has_write_access_to_repo(
        self, project_id: str, user_id: str
    ) -> bool:
        """
        Check if the user has write access to the repository (can pull/push changes and open merge requests).

        Args:
            project_id: The ID of the GitLab project
            user_id: The GitLab user ID

        Returns:
            bool: True if the user has write access to the repository, False otherwise
        """
        keycloak_user_id = await self.token_manager.get_user_id_from_idp_user_id(
            user_id, ProviderType.GITLAB
        )
        if keycloak_user_id is None:
            logger.warning(f'Got invalid keycloak user id for GitLab User {user_id}')
            return False

        gitlab_service = GitLabServiceImpl(external_auth_id=keycloak_user_id)
        return await gitlab_service.user_has_write_access(project_id)

    async def receive_message(self, message: Message):
        self._confirm_incoming_source_type(message)
        if await self.is_job_requested(message):
            gitlab_view = await GitlabFactory.create_gitlab_view_from_payload(
                message, self.token_manager
            )
            logger.info(
                f'[GitLab] Creating job for {gitlab_view.user_info.username} in {gitlab_view.full_repo_name}#{gitlab_view.issue_number}'
            )

            await self.start_job(gitlab_view)

    async def is_job_requested(self, message) -> bool:
        self._confirm_incoming_source_type(message)
        if not (
            GitlabFactory.is_labeled_issue(message)
            or GitlabFactory.is_issue_comment(message)
            or GitlabFactory.is_mr_comment(message)
            or GitlabFactory.is_mr_comment(message, inline=True)
        ):
            return False

        payload = message.message['payload']

        repo_obj = payload['project']
        project_id = repo_obj['id']
        selected_project = repo_obj['path_with_namespace']
        user = payload['user']
        user_id = user['id']
        username = user['username']

        logger.info(
            f'[GitLab] Checking permissions for {username} in {selected_project}'
        )

        has_write_access = await self._user_has_write_access_to_repo(
            project_id=str(project_id), user_id=user_id
        )

        logger.info(
            f'[GitLab]: {username} access in {selected_project}: {has_write_access}'
        )
        # Only start jobs for users with write access to the repository
        return has_write_access

    async def send_message(self, message: Message, gitlab_view: ResolverViewInterface):
        """
        Send a message to GitLab based on the view type.

        Args:
            message: The message to send
            gitlab_view: The GitLab view object containing issue/MR/comment info
        """
        keycloak_user_id = gitlab_view.user_info.keycloak_user_id
        gitlab_service = GitLabServiceImpl(external_auth_id=keycloak_user_id)

        outgoing_message = message.message

        if isinstance(gitlab_view, (GitlabInlineMRComment, GitlabMRComment)):
            await gitlab_service.reply_to_mr(
                gitlab_view.project_id,
                gitlab_view.issue_number,
                gitlab_view.discussion_id,
                outgoing_message,
            )

        elif isinstance(gitlab_view, GitlabIssueComment):
            await gitlab_service.reply_to_issue(
                gitlab_view.project_id,
                gitlab_view.issue_number,
                gitlab_view.discussion_id,
                outgoing_message,
            )
        elif isinstance(gitlab_view, GitlabIssue):
            await gitlab_service.reply_to_issue(
                gitlab_view.project_id,
                gitlab_view.issue_number,
                None,  # no discussion id, issue is tagged
                outgoing_message,
            )
        else:
            logger.warning(
                f'[GitLab] Unsupported view type: {type(gitlab_view).__name__}'
            )

    async def start_job(self, gitlab_view: GitlabViewType):
        """
        Start a job for the GitLab view.

        Args:
            gitlab_view: The GitLab view object containing issue/MR/comment info
        """
        # Importing here prevents circular import
        from server.conversation_callback_processor.gitlab_callback_processor import (
            GitlabCallbackProcessor,
        )

        try:
            try:
                user_info = gitlab_view.user_info

                logger.info(
                    f'[GitLab] Starting job for {user_info.username} in {gitlab_view.full_repo_name}#{gitlab_view.issue_number}'
                )

                user_token = await self.token_manager.get_idp_token_from_idp_user_id(
                    str(user_info.user_id), ProviderType.GITLAB
                )

                if not user_token:
                    logger.warning(
                        f'[GitLab] No token found for user {user_info.username} (id={user_info.user_id})'
                    )
                    raise MissingSettingsError('Missing settings')

                logger.info(
                    f'[GitLab] Creating new conversation for user {user_info.username}'
                )

                secret_store = UserSecrets(
                    provider_tokens=MappingProxyType(
                        {
                            ProviderType.GITLAB: ProviderToken(
                                token=SecretStr(user_token),
                                user_id=str(user_info.user_id),
                            )
                        }
                    )
                )

                await gitlab_view.create_new_conversation(
                    self.jinja_env, secret_store.provider_tokens
                )

                conversation_id = gitlab_view.conversation_id

                logger.info(
                    f'[GitLab] Created conversation {conversation_id} for user {user_info.username}'
                )

                # Create a GitlabCallbackProcessor for this conversation
                processor = GitlabCallbackProcessor(
                    gitlab_view=gitlab_view,
                    send_summary_instruction=True,
                )

                # Register the callback processor
                register_callback_processor(conversation_id, processor)

                logger.info(
                    f'[GitLab] Created callback processor for conversation {conversation_id}'
                )

                conversation_link = CONVERSATION_URL.format(conversation_id)
                msg_info = f"I'm on it! {user_info.username} can [track my progress at all-hands.dev]({conversation_link})"

            except MissingSettingsError as e:
                logger.warning(
                    f'[GitLab] Missing settings error for user {user_info.username}: {str(e)}'
                )

                msg_info = f'@{user_info.username} please log in to [OpenHands Cloud]({HOST_URL}) again before starting a job.'

            except LLMAuthenticationError as e:
                logger.warning(
                    f'[GitLab] LLM authentication error for user {user_info.username}: {str(e)}'
                )

                msg_info = f'@{user_info.username} please set a valid LLM API key in [OpenHands Cloud]({HOST_URL}) before starting a job.'

            # Send the acknowledgment message
            msg = self.create_outgoing_message(msg_info)
            await self.send_message(msg, gitlab_view)

        except Exception as e:
            logger.exception(f'[GitLab] Error starting job: {str(e)}')
            msg = self.create_outgoing_message(
                msg='Uh oh! There was an unexpected error starting the job :('
            )
            await self.send_message(msg, gitlab_view)
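A rough sketch of the intended entry-point wiring, assuming Message accepts source and message as constructor fields (which its attribute use above suggests); the webhook handler itself is not part of this commit:

from integrations.models import Message, SourceType

async def on_gitlab_webhook(raw_event: dict) -> None:
    manager = GitlabManager(token_manager=TokenManager())
    message = Message(source=SourceType.GITLAB, message=raw_event)
    # receive_message runs the write-access gate (is_job_requested) and, if it
    # passes, builds the view and calls start_job.
    await manager.receive_message(message)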
529
enterprise/integrations/gitlab/gitlab_service.py
Normal file
@@ -0,0 +1,529 @@
|
|||||||
|
import asyncio
|
||||||
|
|
||||||
|
from integrations.types import GitLabResourceType
|
||||||
|
from integrations.utils import store_repositories_in_db
|
||||||
|
from pydantic import SecretStr
|
||||||
|
from server.auth.token_manager import TokenManager
|
||||||
|
from storage.gitlab_webhook import GitlabWebhook, WebhookStatus
|
||||||
|
from storage.gitlab_webhook_store import GitlabWebhookStore
|
||||||
|
|
||||||
|
from openhands.core.logger import openhands_logger as logger
|
||||||
|
from openhands.integrations.gitlab.gitlab_service import GitLabService
|
||||||
|
from openhands.integrations.service_types import (
|
||||||
|
ProviderType,
|
||||||
|
RateLimitError,
|
||||||
|
Repository,
|
||||||
|
RequestMethod,
|
||||||
|
)
|
||||||
|
from openhands.server.types import AppMode
|
||||||
|
|
||||||
|
|
||||||
|
class SaaSGitLabService(GitLabService):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
user_id: str | None = None,
|
||||||
|
external_auth_token: SecretStr | None = None,
|
||||||
|
external_auth_id: str | None = None,
|
||||||
|
token: SecretStr | None = None,
|
||||||
|
external_token_manager: bool = False,
|
||||||
|
base_domain: str | None = None,
|
||||||
|
):
|
||||||
|
logger.info(
|
||||||
|
f'SaaSGitLabService created with user_id {user_id}, external_auth_id {external_auth_id}, external_auth_token {'set' if external_auth_token else 'None'}, gitlab_token {'set' if token else 'None'}, external_token_manager {external_token_manager}'
|
||||||
|
)
|
||||||
|
super().__init__(
|
||||||
|
user_id=user_id,
|
||||||
|
external_auth_token=external_auth_token,
|
||||||
|
external_auth_id=external_auth_id,
|
||||||
|
token=token,
|
||||||
|
external_token_manager=external_token_manager,
|
||||||
|
base_domain=base_domain,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.external_auth_token = external_auth_token
|
||||||
|
self.external_auth_id = external_auth_id
|
||||||
|
self.token_manager = TokenManager(external=external_token_manager)
|
||||||
|
|
||||||
|
async def get_latest_token(self) -> SecretStr | None:
|
||||||
|
gitlab_token = None
|
||||||
|
if self.external_auth_token:
|
||||||
|
gitlab_token = SecretStr(
|
||||||
|
await self.token_manager.get_idp_token(
|
||||||
|
self.external_auth_token.get_secret_value(), idp=ProviderType.GITLAB
|
||||||
|
)
|
||||||
|
)
|
||||||
|
logger.debug(
|
||||||
|
f'Got GitLab token {gitlab_token} from access token: {self.external_auth_token}'
|
||||||
|
)
|
||||||
|
elif self.external_auth_id:
|
||||||
|
offline_token = await self.token_manager.load_offline_token(
|
||||||
|
self.external_auth_id
|
||||||
|
)
|
||||||
|
gitlab_token = SecretStr(
|
||||||
|
await self.token_manager.get_idp_token_from_offline_token(
|
||||||
|
offline_token, ProviderType.GITLAB
|
||||||
|
)
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f'Got GitLab token {gitlab_token.get_secret_value()} from external auth user ID: {self.external_auth_id}'
|
||||||
|
)
|
||||||
|
elif self.user_id:
|
||||||
|
gitlab_token = SecretStr(
|
||||||
|
await self.token_manager.get_idp_token_from_idp_user_id(
|
||||||
|
self.user_id, ProviderType.GITLAB
|
||||||
|
)
|
||||||
|
)
|
||||||
|
logger.debug(
|
||||||
|
f'Got Gitlab token {gitlab_token} from user ID: {self.user_id}'
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warning('external_auth_token and user_id not set!')
|
||||||
|
return gitlab_token
|
||||||
|
|
||||||
|
async def get_owned_groups(self) -> list[dict]:
|
||||||
|
"""
|
||||||
|
Get all groups for which the current user is the owner.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[dict]: A list of groups owned by the current user.
|
||||||
|
"""
|
||||||
|
url = f'{self.BASE_URL}/groups'
|
||||||
|
params = {'owned': 'true', 'per_page': 100, 'top_level_only': 'true'}
|
||||||
|
|
||||||
|
try:
|
||||||
|
response, headers = await self._make_request(url, params)
|
||||||
|
return response
|
||||||
|
except Exception:
|
||||||
|
logger.warning('Error fetching owned groups', exc_info=True)
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def add_owned_projects_and_groups_to_db(self, owned_personal_projects):
|
||||||
|
"""
|
||||||
|
Add owned projects and groups to the database for webhook tracking.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
owned_personal_projects: List of personal projects owned by the user
|
||||||
|
"""
|
||||||
|
owned_groups = await self.get_owned_groups()
|
||||||
|
webhooks = []
|
||||||
|
|
||||||
|
def build_group_webhook_entries(groups):
|
||||||
|
return [
|
||||||
|
GitlabWebhook(
|
||||||
|
group_id=str(group['id']),
|
||||||
|
project_id=None,
|
||||||
|
user_id=self.external_auth_id,
|
||||||
|
webhook_exists=False,
|
||||||
|
)
|
||||||
|
for group in groups
|
||||||
|
]
|
||||||
|
|
||||||
|
def build_project_webhook_entries(projects):
|
||||||
|
return [
|
||||||
|
GitlabWebhook(
|
||||||
|
group_id=None,
|
||||||
|
project_id=str(project['id']),
|
||||||
|
user_id=self.external_auth_id,
|
||||||
|
webhook_exists=False,
|
||||||
|
)
|
||||||
|
for project in projects
|
||||||
|
]
|
||||||
|
|
||||||
|
# Collect all webhook entries
|
||||||
|
webhooks.extend(build_group_webhook_entries(owned_groups))
|
||||||
|
webhooks.extend(build_project_webhook_entries(owned_personal_projects))
|
||||||
|
|
||||||
|
# Store webhooks in the database
|
||||||
|
if webhooks:
|
||||||
|
try:
|
||||||
|
webhook_store = GitlabWebhookStore()
|
||||||
|
await webhook_store.store_webhooks(webhooks)
|
||||||
|
logger.info(
|
||||||
|
f'Added GitLab webhooks to db for user {self.external_auth_id}'
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.warning('Failed to add Gitlab webhooks to db', exc_info=True)
|
||||||
|
|
||||||
|
async def store_repository_data(
|
||||||
|
self, users_personal_projects: list[dict], repositories: list[Repository]
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Store repository data in the database.
|
||||||
|
This function combines the functionality of add_owned_projects_and_groups_to_db and store_repositories_in_db.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
users_personal_projects: List of personal projects owned by the user
|
||||||
|
repositories: List of Repository objects to store
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# First, add owned projects and groups to the database
|
||||||
|
await self.add_owned_projects_and_groups_to_db(users_personal_projects)
|
||||||
|
|
||||||
|
# Then, store repositories in the database
|
||||||
|
await store_repositories_in_db(repositories, self.external_auth_id)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f'Successfully stored repository data for user {self.external_auth_id}'
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.warning('Error storing repository data', exc_info=True)
|
||||||
|
|
||||||
|
async def get_all_repositories(
|
||||||
|
        self, sort: str, app_mode: AppMode, store_in_background: bool = True
    ) -> list[Repository]:
        """
        Get repositories for the authenticated user, including information about the kind of project.
        Also collects repositories where the namespace kind is "user" and the user is the owner.

        Args:
            sort: The field to sort repositories by
            app_mode: The application mode (OSS or SAAS)
            store_in_background: If True, persist repository data in a background task

        Returns:
            list[Repository]: A list of repositories for the authenticated user
        """
        MAX_REPOS = 1000
        PER_PAGE = 100  # Maximum allowed by GitLab API
        all_repos: list[dict] = []
        users_personal_projects: list[dict] = []
        page = 1

        url = f'{self.BASE_URL}/projects'
        # Map GitHub's sort values to GitLab's order_by values
        order_by = {
            'pushed': 'last_activity_at',
            'updated': 'last_activity_at',
            'created': 'created_at',
            'full_name': 'name',
        }.get(sort, 'last_activity_at')

        user_id = None
        try:
            user_info = await self.get_user()
            user_id = user_info.id
        except Exception as e:
            logger.warning(f'Could not fetch user id: {e}')

        while len(all_repos) < MAX_REPOS:
            params = {
                'page': str(page),
                'per_page': str(PER_PAGE),
                'order_by': order_by,
                'sort': 'desc',  # GitLab uses `sort` for direction (asc/desc)
                'membership': 1,  # Use 1 instead of True
            }

            try:
                response, headers = await self._make_request(url, params)

                if not response:  # No more repositories
                    break

                # Process each repository to identify user-owned ones
                for repo in response:
                    namespace = repo.get('namespace', {})
                    kind = namespace.get('kind')
                    owner_id = repo.get('owner', {}).get('id')

                    # Collect user-owned personal projects
                    if kind == 'user' and str(owner_id) == str(user_id):
                        users_personal_projects.append(repo)

                    # Add to all repos regardless
                    all_repos.append(repo)

                page += 1

                # Check if we've reached the last page
                link_header = headers.get('Link', '')
                if 'rel="next"' not in link_header:
                    break

            except Exception:
                logger.warning(
                    f'Error fetching repositories on page {page}', exc_info=True
                )
                break

        # Trim to MAX_REPOS if needed and convert to Repository objects
        all_repos = all_repos[:MAX_REPOS]
        repositories = [
            Repository(
                id=str(repo.get('id')),
                full_name=str(repo.get('path_with_namespace')),
                stargazers_count=repo.get('star_count'),
                git_provider=ProviderType.GITLAB,
                is_public=repo.get('visibility') == 'public',
            )
            for repo in all_repos
        ]

        # Store webhook and repository info
        if store_in_background:
            asyncio.create_task(
                self.store_repository_data(users_personal_projects, repositories)
            )
        else:
            await self.store_repository_data(users_personal_projects, repositories)
        return repositories
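    # Illustrative sketch (editor's addition, not part of the change): the
    # pagination loop above stops when GitLab's `Link` response header no
    # longer advertises a `rel="next"` page. A standalone version of that
    # check, with made-up header values:
    #
    #     def _has_next_page(link_header: str) -> bool:
    #         return 'rel="next"' in link_header
    #
    #     _has_next_page('<https://gitlab.example.com/api/v4/projects?page=2>; rel="next"')  # True
    #     _has_next_page('<https://gitlab.example.com/api/v4/projects?page=9>; rel="prev"')  # False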
    async def check_resource_exists(
        self, resource_type: GitLabResourceType, resource_id: str
    ) -> tuple[bool, WebhookStatus | None]:
        """
        Check if a resource exists and the user has access to it.

        Args:
            resource_type: The type of resource
            resource_id: The ID of the resource to check

        Returns:
            tuple[bool, WebhookStatus | None]: A tuple containing:
                - bool: True if the resource exists and the user has access to it, False otherwise
                - WebhookStatus | None: A failure status, or None on success
        """
        if resource_type == GitLabResourceType.GROUP:
            url = f'{self.BASE_URL}/groups/{resource_id}'
        else:
            url = f'{self.BASE_URL}/projects/{resource_id}'

        try:
            response, _ = await self._make_request(url)
            # If we get a response, the resource exists and the user has access to it
            return bool(response and 'id' in response), None
        except RateLimitError:
            return False, WebhookStatus.RATE_LIMITED
        except Exception:
            logger.warning('Resource existence check failed', exc_info=True)
            return False, WebhookStatus.INVALID

    async def check_webhook_exists_on_resource(
        self, resource_type: GitLabResourceType, resource_id: str, webhook_url: str
    ) -> tuple[bool, WebhookStatus | None]:
        """
        Check if a webhook with a specific URL already exists on a resource.

        Args:
            resource_type: The type of resource
            resource_id: The ID of the resource to check
            webhook_url: The URL of the webhook to check for

        Returns:
            tuple[bool, WebhookStatus | None]: A tuple containing:
                - bool: True if the webhook exists, False otherwise
                - WebhookStatus | None: A failure status, or None on success
        """
        # Construct the URL based on the resource type
        if resource_type == GitLabResourceType.GROUP:
            url = f'{self.BASE_URL}/groups/{resource_id}/hooks'
        else:
            url = f'{self.BASE_URL}/projects/{resource_id}/hooks'

        try:
            # Get all webhooks for the resource
            response, _ = await self._make_request(url)

            # Check if any webhook has the specified URL
            exists = False
            if response:
                for webhook in response:
                    if webhook.get('url') == webhook_url:
                        exists = True

            return exists, None

        except RateLimitError:
            return False, WebhookStatus.RATE_LIMITED
        except Exception:
            logger.warning('Webhook existence check failed', exc_info=True)
            return False, WebhookStatus.INVALID

    async def check_user_has_admin_access_to_resource(
        self, resource_type: GitLabResourceType, resource_id: str
    ) -> tuple[bool, WebhookStatus | None]:
        """
        Check if the user has admin access to a resource (is either an owner or maintainer).

        Args:
            resource_type: The type of resource
            resource_id: The ID of the resource to check

        Returns:
            tuple[bool, WebhookStatus | None]: A tuple containing:
                - bool: True if the user has admin access to the resource (owner or maintainer), False otherwise
                - WebhookStatus | None: A failure status, or None on success
        """
        # For groups, check if the user is an owner or maintainer
        if resource_type == GitLabResourceType.GROUP:
            url = f'{self.BASE_URL}/groups/{resource_id}/members/all'
            try:
                response, _ = await self._make_request(url)
                # Check if the current user is in the members list with access level >= 40 (Maintainer or Owner)
                exists = False
                if response:
                    current_user = await self.get_user()
                    user_id = current_user.id
                    for member in response:
                        if (
                            str(member.get('id')) == str(user_id)
                            and member.get('access_level', 0) >= 40
                        ):
                            exists = True
                return exists, None
            except RateLimitError:
                return False, WebhookStatus.RATE_LIMITED
            except Exception:
                logger.warning('Admin access check failed', exc_info=True)
                return False, WebhookStatus.INVALID

        # For projects, check if the user has maintainer or owner access
        else:
            url = f'{self.BASE_URL}/projects/{resource_id}/members/all'
            try:
                response, _ = await self._make_request(url)
                exists = False
                # Check if the current user is in the members list with access level >= 40 (Maintainer or Owner)
                if response:
                    current_user = await self.get_user()
                    user_id = current_user.id
                    for member in response:
                        if (
                            str(member.get('id')) == str(user_id)
                            and member.get('access_level', 0) >= 40
                        ):
                            exists = True
                return exists, None
            except RateLimitError:
                return False, WebhookStatus.RATE_LIMITED
            except Exception:
                logger.warning('Admin access check failed', exc_info=True)
                return False, WebhookStatus.INVALID
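    # Editor's note, for reference: GitLab's documented member access levels,
    # which is why the checks above treat `access_level >= 40` as admin access
    # (Maintainer or Owner) and `user_has_write_access` below treats >= 30 as
    # write access:
    #
    #     10 Guest, 20 Reporter, 30 Developer, 40 Maintainer, 50 Owner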
    async def install_webhook(
        self,
        resource_type: GitLabResourceType,
        resource_id: str,
        webhook_name: str,
        webhook_url: str,
        webhook_secret: str,
        webhook_uuid: str,
        scopes: list[str],
    ) -> tuple[str | None, WebhookStatus | None]:
        """
        Install a webhook on the user's group or project.

        Args:
            resource_type: The type of resource
            resource_id: The ID of the resource to install the webhook on
            webhook_name: Name of the webhook
            webhook_url: Webhook URL
            webhook_secret: Webhook secret used to verify payloads
            webhook_uuid: Unique webhook ID sent back to us in a custom header
            scopes: Activity the webhook listens for

        Returns:
            tuple[str | None, WebhookStatus | None]: A tuple containing:
                - str | None: The ID of the created webhook if installation succeeded, None otherwise
                - WebhookStatus | None: A failure status, or None on success
        """
        description = 'Cloud OpenHands Resolver'

        # Set up webhook parameters
        webhook_data = {
            'url': webhook_url,
            'name': webhook_name,
            'enable_ssl_verification': True,
            'token': webhook_secret,
            'description': description,
        }

        for scope in scopes:
            webhook_data[scope] = True

        # Add custom headers with the user and webhook IDs
        if self.external_auth_id:
            webhook_data['custom_headers'] = [
                {'key': 'X-OpenHands-User-ID', 'value': self.external_auth_id},
                {'key': 'X-OpenHands-Webhook-ID', 'value': webhook_uuid},
            ]

        if resource_type == GitLabResourceType.GROUP:
            url = f'{self.BASE_URL}/groups/{resource_id}/hooks'
        else:
            url = f'{self.BASE_URL}/projects/{resource_id}/hooks'

        try:
            # Make the API request
            response, _ = await self._make_request(
                url=url, params=webhook_data, method=RequestMethod.POST
            )

            # Check if the webhook was created successfully
            if response and 'id' in response:
                return str(response['id']), None

            return None, None

        except RateLimitError:
            return None, WebhookStatus.RATE_LIMITED
        except Exception:
            logger.warning('Webhook installation failed', exc_info=True)
            return None, WebhookStatus.INVALID
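    # Illustrative sketch (editor's addition): the shape of `webhook_data`
    # POSTed above for a project hook. The event flags shown are assumptions
    # about what `scopes` contains; GitLab's hooks API accepts boolean flags
    # such as `issues_events` and `note_events`.
    #
    #     {
    #         'url': 'https://app.example.com/integration/gitlab/events',
    #         'name': 'OpenHands',
    #         'enable_ssl_verification': True,
    #         'token': '<webhook secret>',
    #         'description': 'Cloud OpenHands Resolver',
    #         'issues_events': True,
    #         'note_events': True,
    #         'custom_headers': [
    #             {'key': 'X-OpenHands-User-ID', 'value': '<keycloak user id>'},
    #             {'key': 'X-OpenHands-Webhook-ID', 'value': '<webhook uuid>'},
    #         ],
    #     }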
    async def user_has_write_access(self, project_id: str) -> bool:
        url = f'{self.BASE_URL}/projects/{project_id}'
        try:
            response, _ = await self._make_request(url)
            # Check whether the current user's project or group access level is >= 30 (Developer)
            if 'permissions' not in response:
                logger.info('permissions not found', extra={'response': response})
                return False

            permissions = response['permissions']
            if permissions['project_access']:
                logger.info('[GitLab]: Checking project access')
                return permissions['project_access']['access_level'] >= 30

            if permissions['group_access']:
                logger.info('[GitLab]: Checking group access')
                return permissions['group_access']['access_level'] >= 30

            return False
        except Exception:
            logger.warning('Access check failed', exc_info=True)
            return False

    async def reply_to_issue(
        self, project_id: str, issue_number: str, discussion_id: str | None, body: str
    ):
        """
        Either create a new comment thread, or reply to an existing one (depending on the discussion_id param).
        """
        try:
            if discussion_id:
                url = f'{self.BASE_URL}/projects/{project_id}/issues/{issue_number}/discussions/{discussion_id}/notes'
            else:
                url = f'{self.BASE_URL}/projects/{project_id}/issues/{issue_number}/discussions'
            params = {'body': body}

            await self._make_request(url=url, params=params, method=RequestMethod.POST)
        except Exception as e:
            logger.exception(f'[GitLab]: Reply to issue failed {e}')

    async def reply_to_mr(
        self, project_id: str, merge_request_iid: str, discussion_id: str, body: str
    ):
        """
        Reply to a comment thread on an MR.
        """
        try:
            url = f'{self.BASE_URL}/projects/{project_id}/merge_requests/{merge_request_iid}/discussions/{discussion_id}/notes'
            params = {'body': body}

            await self._make_request(url=url, params=params, method=RequestMethod.POST)
        except Exception as e:
            logger.exception(f'[GitLab]: Reply to MR failed {e}')
450 enterprise/integrations/gitlab/gitlab_view.py Normal file
@@ -0,0 +1,450 @@
from dataclasses import dataclass

from integrations.models import Message
from integrations.types import ResolverViewInterface, UserData
from integrations.utils import HOST, get_oh_labels, has_exact_mention
from jinja2 import Environment
from server.auth.token_manager import TokenManager, get_config
from storage.database import session_maker
from storage.saas_secrets_store import SaasSecretsStore

from openhands.core.logger import openhands_logger as logger
from openhands.integrations.gitlab.gitlab_service import GitLabServiceImpl
from openhands.integrations.provider import PROVIDER_TOKEN_TYPE, ProviderType
from openhands.integrations.service_types import Comment
from openhands.server.services.conversation_service import create_new_conversation
from openhands.storage.data_models.conversation_metadata import ConversationTrigger

OH_LABEL, INLINE_OH_LABEL = get_oh_labels(HOST)
CONFIDENTIAL_NOTE = 'confidential_note'
NOTE_TYPES = ['note', CONFIDENTIAL_NOTE]

# ======================================================
# SECTION: Factory to create the appropriate GitLab view
# ======================================================


@dataclass
class GitlabIssue(ResolverViewInterface):
    installation_id: str  # Webhook installation ID for GitLab (comes from our DB)
    issue_number: int
    project_id: int
    full_repo_name: str
    is_public_repo: bool
    user_info: UserData
    raw_payload: Message
    conversation_id: str
    should_extract: bool
    send_summary_instruction: bool
    title: str
    description: str
    previous_comments: list[Comment]
    is_mr: bool

    async def _load_resolver_context(self):
        gitlab_service = GitLabServiceImpl(
            external_auth_id=self.user_info.keycloak_user_id
        )

        self.previous_comments = await gitlab_service.get_issue_or_mr_comments(
            self.project_id, self.issue_number, is_mr=self.is_mr
        )

        (
            self.title,
            self.description,
        ) = await gitlab_service.get_issue_or_mr_title_and_body(
            self.project_id, self.issue_number, is_mr=self.is_mr
        )

    async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
        user_instructions_template = jinja_env.get_template('issue_prompt.j2')
        await self._load_resolver_context()

        user_instructions = user_instructions_template.render(
            issue_number=self.issue_number,
        )

        conversation_instructions_template = jinja_env.get_template(
            'issue_conversation_instructions.j2'
        )
        conversation_instructions = conversation_instructions_template.render(
            issue_title=self.title,
            issue_body=self.description,
            comments=self.previous_comments,
        )

        return user_instructions, conversation_instructions

    async def _get_user_secrets(self):
        secrets_store = SaasSecretsStore(
            self.user_info.keycloak_user_id, session_maker, get_config()
        )
        user_secrets = await secrets_store.load()

        return user_secrets.custom_secrets if user_secrets else None

    async def create_new_conversation(
        self, jinja_env: Environment, git_provider_tokens: PROVIDER_TOKEN_TYPE
    ):
        custom_secrets = await self._get_user_secrets()

        user_instructions, conversation_instructions = await self._get_instructions(
            jinja_env
        )
        agent_loop_info = await create_new_conversation(
            user_id=self.user_info.keycloak_user_id,
            git_provider_tokens=git_provider_tokens,
            custom_secrets=custom_secrets,
            selected_repository=self.full_repo_name,
            selected_branch=None,
            initial_user_msg=user_instructions,
            conversation_instructions=conversation_instructions,
            image_urls=None,
            conversation_trigger=ConversationTrigger.RESOLVER,
            replay_json=None,
        )
        self.conversation_id = agent_loop_info.conversation_id
        return self.conversation_id


@dataclass
class GitlabIssueComment(GitlabIssue):
    comment_body: str
    discussion_id: str
    confidential: bool

    async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
        user_instructions_template = jinja_env.get_template('issue_prompt.j2')
        await self._load_resolver_context()

        user_instructions = user_instructions_template.render(
            issue_comment=self.comment_body
        )

        conversation_instructions_template = jinja_env.get_template(
            'issue_conversation_instructions.j2'
        )

        conversation_instructions = conversation_instructions_template.render(
            issue_number=self.issue_number,
            issue_title=self.title,
            issue_body=self.description,
            comments=self.previous_comments,
        )

        return user_instructions, conversation_instructions


@dataclass
class GitlabMRComment(GitlabIssueComment):
    branch_name: str

    async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
        user_instructions_template = jinja_env.get_template('mr_update_prompt.j2')
        await self._load_resolver_context()

        user_instructions = user_instructions_template.render(
            mr_comment=self.comment_body,
        )

        conversation_instructions_template = jinja_env.get_template(
            'mr_update_conversation_instructions.j2'
        )
        conversation_instructions = conversation_instructions_template.render(
            mr_number=self.issue_number,
            branch_name=self.branch_name,
            mr_title=self.title,
            mr_body=self.description,
            comments=self.previous_comments,
        )

        return user_instructions, conversation_instructions

    async def create_new_conversation(
        self, jinja_env: Environment, git_provider_tokens: PROVIDER_TOKEN_TYPE
    ):
        custom_secrets = await self._get_user_secrets()

        user_instructions, conversation_instructions = await self._get_instructions(
            jinja_env
        )
        agent_loop_info = await create_new_conversation(
            user_id=self.user_info.keycloak_user_id,
            git_provider_tokens=git_provider_tokens,
            custom_secrets=custom_secrets,
            selected_repository=self.full_repo_name,
            selected_branch=self.branch_name,
            initial_user_msg=user_instructions,
            conversation_instructions=conversation_instructions,
            image_urls=None,
            conversation_trigger=ConversationTrigger.RESOLVER,
            replay_json=None,
        )
        self.conversation_id = agent_loop_info.conversation_id
        return self.conversation_id


@dataclass
class GitlabInlineMRComment(GitlabMRComment):
    file_location: str
    line_number: int

    async def _load_resolver_context(self):
        gitlab_service = GitLabServiceImpl(
            external_auth_id=self.user_info.keycloak_user_id
        )

        (
            self.title,
            self.description,
        ) = await gitlab_service.get_issue_or_mr_title_and_body(
            self.project_id, self.issue_number, is_mr=self.is_mr
        )

        self.previous_comments = await gitlab_service.get_review_thread_comments(
            self.project_id, self.issue_number, self.discussion_id
        )

    async def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
        user_instructions_template = jinja_env.get_template('mr_update_prompt.j2')
        await self._load_resolver_context()

        user_instructions = user_instructions_template.render(
            mr_comment=self.comment_body,
        )

        conversation_instructions_template = jinja_env.get_template(
            'mr_update_conversation_instructions.j2'
        )

        conversation_instructions = conversation_instructions_template.render(
            mr_number=self.issue_number,
            mr_title=self.title,
            mr_body=self.description,
            branch_name=self.branch_name,
            file_location=self.file_location,
            line_number=self.line_number,
            comments=self.previous_comments,
        )

        return user_instructions, conversation_instructions


GitlabViewType = (
    GitlabInlineMRComment | GitlabMRComment | GitlabIssueComment | GitlabIssue
)


class GitlabFactory:
    @staticmethod
    def is_labeled_issue(message: Message) -> bool:
        payload = message.message['payload']
        object_kind = payload.get('object_kind')
        event_type = payload.get('event_type')

        if object_kind == 'issue' and event_type == 'issue':
            changes = payload.get('changes', {})
            labels = changes.get('labels', {})
            previous = labels.get('previous', [])
            current = labels.get('current', [])

            previous_labels = [obj['title'] for obj in previous]
            current_labels = [obj['title'] for obj in current]

            if OH_LABEL not in previous_labels and OH_LABEL in current_labels:
                return True

        return False

    @staticmethod
    def is_issue_comment(message: Message) -> bool:
        payload = message.message['payload']
        object_kind = payload.get('object_kind')
        event_type = payload.get('event_type')
        issue = payload.get('issue')

        if object_kind == 'note' and event_type in NOTE_TYPES and issue:
            comment_body = payload.get('object_attributes', {}).get('note', '')
            return has_exact_mention(comment_body, INLINE_OH_LABEL)

        return False

    @staticmethod
    def is_mr_comment(message: Message, inline=False) -> bool:
        payload = message.message['payload']
        object_kind = payload.get('object_kind')
        event_type = payload.get('event_type')
        merge_request = payload.get('merge_request')

        if not (object_kind == 'note' and event_type in NOTE_TYPES and merge_request):
            return False

        # Check that the note belongs to a merge request
        object_attributes = payload.get('object_attributes', {})
        noteable_type = object_attributes.get('noteable_type')

        if noteable_type != 'MergeRequest':
            return False

        # Check whether the comment is inline (inline notes carry a change position)
        change_position = object_attributes.get('change_position')
        if inline and not change_position:
            return False
        if not inline and change_position:
            return False

        # Check the body for the trigger mention
        comment_body = object_attributes.get('note', '')
        return has_exact_mention(comment_body, INLINE_OH_LABEL)
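    # Illustrative sketch (editor's addition): the payload field the inline
    # check above keys on. An inline MR note carries `change_position`; a
    # top-level MR note does not. Trimmed `object_attributes`, values made up
    # (the actual trigger mention depends on the configured HOST):
    #
    #     inline = {
    #         'noteable_type': 'MergeRequest',
    #         'note': '@openhands please fix this',
    #         'change_position': {'new_path': 'src/app.py', 'new_line': 42},
    #     }
    #     top_level = {
    #         'noteable_type': 'MergeRequest',
    #         'note': '@openhands please fix this',
    #         # no 'change_position' key
    #     }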
    @staticmethod
    def determine_if_confidential(event_type: str) -> bool:
        return event_type == CONFIDENTIAL_NOTE

    @staticmethod
    async def create_gitlab_view_from_payload(
        message: Message, token_manager: TokenManager
    ) -> ResolverViewInterface:
        payload = message.message['payload']
        installation_id = message.message['installation_id']
        user = payload['user']
        user_id = user['id']
        username = user['username']
        repo_obj = payload['project']
        selected_project = repo_obj['path_with_namespace']
        # GitLab visibility levels: 0 = private, 10 = internal, 20 = public
        is_public_repo = repo_obj['visibility_level'] == 20
        project_id = payload['object_attributes']['project_id']

        keycloak_user_id = await token_manager.get_user_id_from_idp_user_id(
            user_id, ProviderType.GITLAB
        )

        user_info = UserData(
            user_id=user_id, username=username, keycloak_user_id=keycloak_user_id
        )

        if GitlabFactory.is_labeled_issue(message):
            issue_iid = payload['object_attributes']['iid']

            logger.info(
                f'[GitLab] Creating view for labeled issue from {username} in {selected_project}#{issue_iid}'
            )
            return GitlabIssue(
                installation_id=installation_id,
                issue_number=issue_iid,
                project_id=project_id,
                full_repo_name=selected_project,
                is_public_repo=is_public_repo,
                user_info=user_info,
                raw_payload=message,
                conversation_id='',
                should_extract=True,
                send_summary_instruction=True,
                title='',
                description='',
                previous_comments=[],
                is_mr=False,
            )

        elif GitlabFactory.is_issue_comment(message):
            event_type = payload['event_type']
            issue_iid = payload['issue']['iid']
            object_attributes = payload['object_attributes']
            discussion_id = object_attributes['discussion_id']
            comment_body = object_attributes['note']
            logger.info(
                f'[GitLab] Creating view for issue comment from {username} in {selected_project}#{issue_iid}'
            )

            return GitlabIssueComment(
                installation_id=installation_id,
                comment_body=comment_body,
                issue_number=issue_iid,
                discussion_id=discussion_id,
                project_id=project_id,
                confidential=GitlabFactory.determine_if_confidential(event_type),
                full_repo_name=selected_project,
                is_public_repo=is_public_repo,
                user_info=user_info,
                raw_payload=message,
                conversation_id='',
                should_extract=True,
                send_summary_instruction=True,
                title='',
                description='',
                previous_comments=[],
                is_mr=False,
            )

        elif GitlabFactory.is_mr_comment(message):
            event_type = payload['event_type']
            merge_request_iid = payload['merge_request']['iid']
            branch_name = payload['merge_request']['source_branch']
            object_attributes = payload['object_attributes']
            discussion_id = object_attributes['discussion_id']
            comment_body = object_attributes['note']
            logger.info(
                f'[GitLab] Creating view for merge request comment from {username} in {selected_project}#{merge_request_iid}'
            )

            return GitlabMRComment(
                installation_id=installation_id,
                comment_body=comment_body,
                issue_number=merge_request_iid,  # Using issue_number as mr_number for compatibility
                discussion_id=discussion_id,
                project_id=project_id,
                full_repo_name=selected_project,
                is_public_repo=is_public_repo,
                user_info=user_info,
                raw_payload=message,
                conversation_id='',
                should_extract=True,
                send_summary_instruction=True,
                confidential=GitlabFactory.determine_if_confidential(event_type),
                branch_name=branch_name,
                title='',
                description='',
                previous_comments=[],
                is_mr=True,
            )

        elif GitlabFactory.is_mr_comment(message, inline=True):
            event_type = payload['event_type']
            merge_request_iid = payload['merge_request']['iid']
            branch_name = payload['merge_request']['source_branch']
            object_attributes = payload['object_attributes']
            comment_body = object_attributes['note']
            position_info = object_attributes['position']
            discussion_id = object_attributes['discussion_id']
            file_location = object_attributes['position']['new_path']
            line_number = (
                position_info.get('new_line') or position_info.get('old_line') or 0
            )

            logger.info(
                f'[GitLab] Creating view for inline merge request comment from {username} in {selected_project}#{merge_request_iid}'
            )

            return GitlabInlineMRComment(
                installation_id=installation_id,
                issue_number=merge_request_iid,  # Using issue_number as mr_number for compatibility
                discussion_id=discussion_id,
                project_id=project_id,
                full_repo_name=selected_project,
                is_public_repo=is_public_repo,
                user_info=user_info,
                raw_payload=message,
                conversation_id='',
                should_extract=True,
                send_summary_instruction=True,
                confidential=GitlabFactory.determine_if_confidential(event_type),
                branch_name=branch_name,
                file_location=file_location,
                line_number=line_number,
                comment_body=comment_body,
                title='',
                description='',
                previous_comments=[],
                is_mr=True,
            )
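# Illustrative sketch (editor's addition): how a webhook consumer might drive
# the factory above. `message`, `token_manager`, `jinja_env`, and `tokens` are
# assumed to be supplied by the surrounding webhook handler. Every view type
# shares the same `create_new_conversation` entry point, so the caller never
# needs to branch on the concrete GitlabViewType member:
#
#     view = await GitlabFactory.create_gitlab_view_from_payload(message, token_manager)
#     conversation_id = await view.create_new_conversation(jinja_env, tokens)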
503 enterprise/integrations/jira/jira_manager.py Normal file
@@ -0,0 +1,503 @@
import hashlib
import hmac
from typing import Dict, Optional, Tuple
from urllib.parse import urlparse

import httpx
from fastapi import Request
from integrations.jira.jira_types import JiraViewInterface
from integrations.jira.jira_view import (
    JiraExistingConversationView,
    JiraFactory,
    JiraNewConversationView,
)
from integrations.manager import Manager
from integrations.models import JobContext, Message
from integrations.utils import (
    HOST_URL,
    OPENHANDS_RESOLVER_TEMPLATES_DIR,
    filter_potential_repos_by_user_msg,
)
from jinja2 import Environment, FileSystemLoader
from server.auth.saas_user_auth import get_user_auth_from_keycloak_id
from server.auth.token_manager import TokenManager
from server.utils.conversation_callback_utils import register_callback_processor
from storage.jira_integration_store import JiraIntegrationStore
from storage.jira_user import JiraUser
from storage.jira_workspace import JiraWorkspace

from openhands.core.logger import openhands_logger as logger
from openhands.integrations.provider import ProviderHandler
from openhands.integrations.service_types import Repository
from openhands.server.shared import server_config
from openhands.server.types import LLMAuthenticationError, MissingSettingsError
from openhands.server.user_auth.user_auth import UserAuth

JIRA_CLOUD_API_URL = 'https://api.atlassian.com/ex/jira'


class JiraManager(Manager):
    def __init__(self, token_manager: TokenManager):
        self.token_manager = token_manager
        self.integration_store = JiraIntegrationStore.get_instance()
        self.jinja_env = Environment(
            loader=FileSystemLoader(OPENHANDS_RESOLVER_TEMPLATES_DIR + 'jira')
        )

    async def authenticate_user(
        self, jira_user_id: str, workspace_id: int
    ) -> tuple[JiraUser | None, UserAuth | None]:
        """Authenticate the Jira user and get their OpenHands user auth."""

        # Find an active Jira user by platform user ID and workspace ID
        jira_user = await self.integration_store.get_active_user(
            jira_user_id, workspace_id
        )

        if not jira_user:
            logger.warning(
                f'[Jira] No active Jira user found for {jira_user_id} in workspace {workspace_id}'
            )
            return None, None

        saas_user_auth = await get_user_auth_from_keycloak_id(
            jira_user.keycloak_user_id
        )
        return jira_user, saas_user_auth

    async def _get_repositories(self, user_auth: UserAuth) -> list[Repository]:
        """Get repositories that the user has access to."""
        provider_tokens = await user_auth.get_provider_tokens()
        if provider_tokens is None:
            return []
        access_token = await user_auth.get_access_token()
        user_id = await user_auth.get_user_id()
        client = ProviderHandler(
            provider_tokens=provider_tokens,
            external_auth_token=access_token,
            external_auth_id=user_id,
        )
        repos: list[Repository] = await client.get_repositories(
            'pushed', server_config.app_mode, None, None, None, None
        )
        return repos

    async def validate_request(
        self, request: Request
    ) -> Tuple[bool, Optional[str], Optional[Dict]]:
        """Verify the Jira webhook signature."""
        signature_header = request.headers.get('x-hub-signature')
        signature = signature_header.split('=')[1] if signature_header else None
        body = await request.body()
        payload = await request.json()
        workspace_name = ''
        self_url = ''

        if payload.get('webhookEvent') == 'comment_created':
            self_url = payload.get('comment', {}).get('author', {}).get('self', '')
        elif payload.get('webhookEvent') == 'jira:issue_updated':
            self_url = payload.get('user', {}).get('self', '')

        parsed_url = urlparse(self_url)
        if parsed_url.hostname:
            workspace_name = parsed_url.hostname

        if not workspace_name:
            logger.warning('[Jira] No workspace name found in webhook payload')
            return False, None, None

        if not signature:
            logger.warning('[Jira] No signature found in webhook headers')
            return False, None, None

        workspace = await self.integration_store.get_workspace_by_name(workspace_name)

        if not workspace:
            logger.warning('[Jira] Could not identify workspace for webhook')
            return False, None, None

        if workspace.status != 'active':
            logger.warning(f'[Jira] Workspace {workspace.id} is not active')
            return False, None, None

        webhook_secret = self.token_manager.decrypt_text(workspace.webhook_secret)
        digest = hmac.new(webhook_secret.encode(), body, hashlib.sha256).hexdigest()

        if hmac.compare_digest(signature, digest):
            logger.info('[Jira] Webhook signature verified successfully')
            return True, signature, payload

        return False, None, None
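    # Illustrative sketch (editor's addition): how a sender would produce the
    # `x-hub-signature` header that `validate_request` verifies -- an
    # HMAC-SHA256 over the raw request body, hex-encoded. The `sha256=` prefix
    # is an assumption inferred from the `split('=')` above.
    #
    #     digest = hmac.new(secret.encode(), raw_body, hashlib.sha256).hexdigest()
    #     headers = {'x-hub-signature': f'sha256={digest}'}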
    def parse_webhook(self, payload: Dict) -> JobContext | None:
        event_type = payload.get('webhookEvent')

        if event_type == 'comment_created':
            comment_data = payload.get('comment', {})
            comment = comment_data.get('body', '')

            if '@openhands' not in comment:
                return None

            issue_data = payload.get('issue', {})
            issue_id = issue_data.get('id')
            issue_key = issue_data.get('key')
            base_api_url = issue_data.get('self', '').split('/rest/')[0]

            user_data = comment_data.get('author', {})
            user_email = user_data.get('emailAddress')
            display_name = user_data.get('displayName')
            account_id = user_data.get('accountId')
        elif event_type == 'jira:issue_updated':
            changelog = payload.get('changelog', {})
            items = changelog.get('items', [])
            labels = [
                item.get('toString', '')
                for item in items
                if item.get('field') == 'labels' and 'toString' in item
            ]

            if 'openhands' not in labels:
                return None

            issue_data = payload.get('issue', {})
            issue_id = issue_data.get('id')
            issue_key = issue_data.get('key')
            base_api_url = issue_data.get('self', '').split('/rest/')[0]

            user_data = payload.get('user', {})
            user_email = user_data.get('emailAddress')
            display_name = user_data.get('displayName')
            account_id = user_data.get('accountId')
            comment = ''
        else:
            return None

        workspace_name = ''

        parsed_url = urlparse(base_api_url)
        if parsed_url.hostname:
            workspace_name = parsed_url.hostname

        if not all(
            [
                issue_id,
                issue_key,
                user_email,
                display_name,
                account_id,
                workspace_name,
                base_api_url,
            ]
        ):
            return None

        return JobContext(
            issue_id=issue_id,
            issue_key=issue_key,
            user_msg=comment,
            user_email=user_email,
            display_name=display_name,
            platform_user_id=account_id,
            workspace_name=workspace_name,
            base_api_url=base_api_url,
        )
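    # Illustrative sketch (editor's addition): a trimmed `comment_created`
    # payload that would satisfy `parse_webhook` above (all identifiers are
    # made up). The workspace name is taken from the hostname of the issue's
    # `self` URL, here 'example.atlassian.net'.
    #
    #     {
    #         'webhookEvent': 'comment_created',
    #         'comment': {
    #             'body': '@openhands please take a look',
    #             'author': {
    #                 'emailAddress': 'dev@example.com',
    #                 'displayName': 'Dev Example',
    #                 'accountId': 'abc123',
    #             },
    #         },
    #         'issue': {
    #             'id': '10001',
    #             'key': 'PROJ-1',
    #             'self': 'https://example.atlassian.net/rest/api/2/issue/10001',
    #         },
    #     }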
    async def receive_message(self, message: Message):
        """Process an incoming Jira webhook message."""

        payload = message.message.get('payload', {})
        job_context = self.parse_webhook(payload)

        if not job_context:
            logger.info('[Jira] Webhook does not match trigger conditions')
            return

        # Look up the workspace by the name parsed from the webhook payload
        workspace = await self.integration_store.get_workspace_by_name(
            job_context.workspace_name
        )
        if not workspace:
            logger.warning(
                f'[Jira] No workspace found for name: {job_context.workspace_name}'
            )
            await self._send_error_comment(
                job_context,
                'Your workspace is not configured with Jira integration.',
                None,
            )
            return

        # Prevent any recursive triggers from the service account
        if job_context.user_email == workspace.svc_acc_email:
            return

        if workspace.status != 'active':
            logger.warning(f'[Jira] Workspace {workspace.id} is not active')
            await self._send_error_comment(
                job_context,
                'Jira integration is not active for your workspace.',
                workspace,
            )
            return

        # Authenticate user
        jira_user, saas_user_auth = await self.authenticate_user(
            job_context.platform_user_id, workspace.id
        )
        if not jira_user or not saas_user_auth:
            logger.warning(
                f'[Jira] User authentication failed for {job_context.user_email}'
            )
            await self._send_error_comment(
                job_context,
                f'User {job_context.user_email} is not authenticated or active in the Jira integration.',
                workspace,
            )
            return

        # Get issue details
        try:
            api_key = self.token_manager.decrypt_text(workspace.svc_acc_api_key)
            issue_title, issue_description = await self.get_issue_details(
                job_context, workspace.jira_cloud_id, workspace.svc_acc_email, api_key
            )
            job_context.issue_title = issue_title
            job_context.issue_description = issue_description
        except Exception as e:
            logger.error(f'[Jira] Failed to get issue context: {str(e)}')
            await self._send_error_comment(
                job_context,
                'Failed to retrieve issue details. Please check the issue key and try again.',
                workspace,
            )
            return

        try:
            # Create Jira view
            jira_view = await JiraFactory.create_jira_view_from_payload(
                job_context,
                saas_user_auth,
                jira_user,
                workspace,
            )
        except Exception as e:
            logger.error(f'[Jira] Failed to create jira view: {str(e)}', exc_info=True)
            await self._send_error_comment(
                job_context,
                'Failed to initialize conversation. Please try again.',
                workspace,
            )
            return

        if not await self.is_job_requested(message, jira_view):
            return

        await self.start_job(jira_view)
    async def is_job_requested(
        self, message: Message, jira_view: JiraViewInterface
    ) -> bool:
        """
        Check if a job is requested and handle repository selection.
        """

        if isinstance(jira_view, JiraExistingConversationView):
            return True

        try:
            # Get user repositories
            user_repos: list[Repository] = await self._get_repositories(
                jira_view.saas_user_auth
            )

            target_str = f'{jira_view.job_context.issue_description}\n{jira_view.job_context.user_msg}'

            # Try to infer the repository from the issue description
            match, repos = filter_potential_repos_by_user_msg(target_str, user_repos)

            if match:
                # Found an exact repository match
                jira_view.selected_repo = repos[0].full_name
                logger.info(f'[Jira] Inferred repository: {repos[0].full_name}')
                return True
            else:
                # No clear match - send a repository selection comment
                await self._send_repo_selection_comment(jira_view)
                return False

        except Exception as e:
            logger.error(f'[Jira] Error in is_job_requested: {str(e)}')
            return False
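    # Editor's note on the contract assumed above:
    # `filter_potential_repos_by_user_msg` is understood to scan the combined
    # issue text for mentions of the user's repositories and return
    # (True, [the single match]) on an unambiguous hit, or (False, candidates)
    # otherwise, in which case the user is asked to name the repository.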
    async def start_job(self, jira_view: JiraViewInterface):
        """Start a Jira job/conversation."""
        # Import here to prevent a circular import
        from server.conversation_callback_processor.jira_callback_processor import (
            JiraCallbackProcessor,
        )

        try:
            user_info: JiraUser = jira_view.jira_user
            logger.info(
                f'[Jira] Starting job for user {user_info.keycloak_user_id} '
                f'issue {jira_view.job_context.issue_key}',
            )

            # Create conversation
            conversation_id = await jira_view.create_or_update_conversation(
                self.jinja_env
            )

            logger.info(
                f'[Jira] Created/Updated conversation {conversation_id} for issue {jira_view.job_context.issue_key}'
            )

            # Register callback processor for updates
            if isinstance(jira_view, JiraNewConversationView):
                processor = JiraCallbackProcessor(
                    issue_key=jira_view.job_context.issue_key,
                    workspace_name=jira_view.jira_workspace.name,
                )

                # Register the callback processor
                register_callback_processor(conversation_id, processor)

                logger.info(
                    f'[Jira] Created callback processor for conversation {conversation_id}'
                )

            # Send initial response
            msg_info = jira_view.get_response_msg()

        except MissingSettingsError as e:
            logger.warning(f'[Jira] Missing settings error: {str(e)}')
            msg_info = f'Please re-login into [OpenHands Cloud]({HOST_URL}) before starting a job.'

        except LLMAuthenticationError as e:
            logger.warning(f'[Jira] LLM authentication error: {str(e)}')
            msg_info = f'Please set a valid LLM API key in [OpenHands Cloud]({HOST_URL}) before starting a job.'

        except Exception as e:
            logger.error(
                f'[Jira] Unexpected error starting job: {str(e)}', exc_info=True
            )
            msg_info = 'Sorry, there was an unexpected error starting the job. Please try again.'

        # Send response comment
        try:
            api_key = self.token_manager.decrypt_text(
                jira_view.jira_workspace.svc_acc_api_key
            )
            await self.send_message(
                self.create_outgoing_message(msg=msg_info),
                issue_key=jira_view.job_context.issue_key,
                jira_cloud_id=jira_view.jira_workspace.jira_cloud_id,
                svc_acc_email=jira_view.jira_workspace.svc_acc_email,
                svc_acc_api_key=api_key,
            )
        except Exception as e:
            logger.error(f'[Jira] Failed to send response message: {str(e)}')

    async def get_issue_details(
        self,
        job_context: JobContext,
        jira_cloud_id: str,
        svc_acc_email: str,
        svc_acc_api_key: str,
    ) -> Tuple[str, str]:
        url = f'{JIRA_CLOUD_API_URL}/{jira_cloud_id}/rest/api/2/issue/{job_context.issue_key}'
        async with httpx.AsyncClient() as client:
            response = await client.get(url, auth=(svc_acc_email, svc_acc_api_key))
            response.raise_for_status()
            issue_payload = response.json()

        if not issue_payload:
            raise ValueError(f'Issue with key {job_context.issue_key} not found.')

        title = issue_payload.get('fields', {}).get('summary', '')
        description = issue_payload.get('fields', {}).get('description', '')

        if not title:
            raise ValueError(
                f'Issue with key {job_context.issue_key} does not have a title.'
            )

        if not description:
            raise ValueError(
                f'Issue with key {job_context.issue_key} does not have a description.'
            )

        return title, description
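    # Illustrative sketch (editor's addition): the slice of the Jira issue
    # payload that `get_issue_details` reads -- the summary becomes the title
    # and the description is used as-is (field values made up).
    #
    #     {
    #         'fields': {
    #             'summary': 'Fix login redirect loop',
    #             'description': 'Steps to reproduce: ...',
    #         },
    #     }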
    async def send_message(
        self,
        message: Message,
        issue_key: str,
        jira_cloud_id: str,
        svc_acc_email: str,
        svc_acc_api_key: str,
    ):
        url = (
            f'{JIRA_CLOUD_API_URL}/{jira_cloud_id}/rest/api/2/issue/{issue_key}/comment'
        )
        data = {'body': message.message}
        async with httpx.AsyncClient() as client:
            response = await client.post(
                url, auth=(svc_acc_email, svc_acc_api_key), json=data
            )
            response.raise_for_status()
            return response.json()

    async def _send_error_comment(
        self,
        job_context: JobContext,
        error_msg: str,
        workspace: JiraWorkspace | None,
    ):
        """Send an error comment to the Jira issue."""
        if not workspace:
            logger.error('[Jira] Cannot send error comment - no workspace available')
            return

        try:
            api_key = self.token_manager.decrypt_text(workspace.svc_acc_api_key)
            await self.send_message(
                self.create_outgoing_message(msg=error_msg),
                issue_key=job_context.issue_key,
                jira_cloud_id=workspace.jira_cloud_id,
                svc_acc_email=workspace.svc_acc_email,
                svc_acc_api_key=api_key,
            )
        except Exception as e:
            logger.error(f'[Jira] Failed to send error comment: {str(e)}')

    async def _send_repo_selection_comment(self, jira_view: JiraViewInterface):
        """Send a comment asking the user to specify the target repository."""
        try:
            comment_msg = (
                'I need to know which repository to work with. '
                'Please add it to your issue description or send a followup comment.'
            )

            api_key = self.token_manager.decrypt_text(
                jira_view.jira_workspace.svc_acc_api_key
            )

            await self.send_message(
                self.create_outgoing_message(msg=comment_msg),
                issue_key=jira_view.job_context.issue_key,
                jira_cloud_id=jira_view.jira_workspace.jira_cloud_id,
                svc_acc_email=jira_view.jira_workspace.svc_acc_email,
                svc_acc_api_key=api_key,
            )

            logger.info(
                f'[Jira] Sent repository selection comment for issue {jira_view.job_context.issue_key}'
            )

        except Exception as e:
            logger.error(
                f'[Jira] Failed to send repository selection comment: {str(e)}'
            )
40 enterprise/integrations/jira/jira_types.py Normal file
@@ -0,0 +1,40 @@
from abc import ABC, abstractmethod

from integrations.models import JobContext
from jinja2 import Environment
from storage.jira_user import JiraUser
from storage.jira_workspace import JiraWorkspace

from openhands.server.user_auth.user_auth import UserAuth


class JiraViewInterface(ABC):
    """Interface for Jira views that handle different types of Jira interactions."""

    job_context: JobContext
    saas_user_auth: UserAuth
    jira_user: JiraUser
    jira_workspace: JiraWorkspace
    selected_repo: str | None
    conversation_id: str

    @abstractmethod
    def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
        """Get initial instructions for the conversation."""
        pass

    @abstractmethod
    async def create_or_update_conversation(self, jinja_env: Environment) -> str:
        """Create or update a conversation and return the conversation ID."""
        pass

    @abstractmethod
    def get_response_msg(self) -> str:
        """Get the response message to send back to Jira."""
        pass


class StartingConvoException(Exception):
    """Exception raised when starting a conversation fails."""

    pass
222 enterprise/integrations/jira/jira_view.py Normal file
@@ -0,0 +1,222 @@
from dataclasses import dataclass

from integrations.jira.jira_types import JiraViewInterface, StartingConvoException
from integrations.models import JobContext
from integrations.utils import CONVERSATION_URL, get_final_agent_observation
from jinja2 import Environment
from storage.jira_conversation import JiraConversation
from storage.jira_integration_store import JiraIntegrationStore
from storage.jira_user import JiraUser
from storage.jira_workspace import JiraWorkspace

from openhands.core.logger import openhands_logger as logger
from openhands.core.schema.agent import AgentState
from openhands.events.action import MessageAction
from openhands.events.serialization.event import event_to_dict
from openhands.server.services.conversation_service import (
    create_new_conversation,
    setup_init_conversation_settings,
)
from openhands.server.shared import ConversationStoreImpl, config, conversation_manager
from openhands.server.user_auth.user_auth import UserAuth
from openhands.storage.data_models.conversation_metadata import ConversationTrigger

integration_store = JiraIntegrationStore.get_instance()


@dataclass
class JiraNewConversationView(JiraViewInterface):
    job_context: JobContext
    saas_user_auth: UserAuth
    jira_user: JiraUser
    jira_workspace: JiraWorkspace
    selected_repo: str | None
    conversation_id: str

    def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
        """Instructions passed when the conversation is first initialized."""

        instructions_template = jinja_env.get_template('jira_instructions.j2')
        instructions = instructions_template.render()

        user_msg_template = jinja_env.get_template('jira_new_conversation.j2')

        user_msg = user_msg_template.render(
            issue_key=self.job_context.issue_key,
            issue_title=self.job_context.issue_title,
            issue_description=self.job_context.issue_description,
            user_message=self.job_context.user_msg or '',
        )

        return instructions, user_msg

    async def create_or_update_conversation(self, jinja_env: Environment) -> str:
        """Create a new Jira conversation."""

        if not self.selected_repo:
            raise StartingConvoException('No repository selected for this conversation')

        provider_tokens = await self.saas_user_auth.get_provider_tokens()
        user_secrets = await self.saas_user_auth.get_user_secrets()
        instructions, user_msg = self._get_instructions(jinja_env)

        try:
            agent_loop_info = await create_new_conversation(
                user_id=self.jira_user.keycloak_user_id,
                git_provider_tokens=provider_tokens,
                selected_repository=self.selected_repo,
                selected_branch=None,
                initial_user_msg=user_msg,
                conversation_instructions=instructions,
                image_urls=None,
                replay_json=None,
                conversation_trigger=ConversationTrigger.JIRA,
                custom_secrets=user_secrets.custom_secrets if user_secrets else None,
            )

            self.conversation_id = agent_loop_info.conversation_id

            logger.info(f'[Jira] Created conversation {self.conversation_id}')

            # Store the Jira conversation mapping
            jira_conversation = JiraConversation(
                conversation_id=self.conversation_id,
                issue_id=self.job_context.issue_id,
                issue_key=self.job_context.issue_key,
                jira_user_id=self.jira_user.id,
            )

            await integration_store.create_conversation(jira_conversation)

            return self.conversation_id
        except Exception as e:
            logger.error(
                f'[Jira] Failed to create conversation: {str(e)}', exc_info=True
            )
            raise StartingConvoException(f'Failed to create conversation: {str(e)}')

    def get_response_msg(self) -> str:
        """Get the response message to send back to Jira."""
        conversation_link = CONVERSATION_URL.format(self.conversation_id)
        return f"I'm on it! {self.job_context.display_name} can [track my progress here|{conversation_link}]."
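# Editor's note, for reference: Jira comments posted via the v2 REST API are
# rendered as Jira wiki markup, so the links produced by `get_response_msg`
# use the [text|url] form rather than Markdown's [text](url). Example output
# (name and URL made up):
#
#     I'm on it! Jane Doe can [track my progress here|https://app.example.com/conversations/abc123].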
@dataclass
|
||||||
|
class JiraExistingConversationView(JiraViewInterface):
|
||||||
|
job_context: JobContext
|
||||||
|
saas_user_auth: UserAuth
|
||||||
|
jira_user: JiraUser
|
||||||
|
jira_workspace: JiraWorkspace
|
||||||
|
selected_repo: str | None
|
||||||
|
conversation_id: str
|
||||||
|
|
||||||
|
def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
|
||||||
|
"""Instructions passed when conversation is first initialized"""
|
||||||
|
|
||||||
|
user_msg_template = jinja_env.get_template('jira_existing_conversation.j2')
|
||||||
|
user_msg = user_msg_template.render(
|
||||||
|
issue_key=self.job_context.issue_key,
|
||||||
|
user_message=self.job_context.user_msg or '',
|
||||||
|
issue_title=self.job_context.issue_title,
|
||||||
|
issue_description=self.job_context.issue_description,
|
||||||
|
)
|
||||||
|
|
||||||
|
return '', user_msg
|
||||||
|
|
||||||
|
async def create_or_update_conversation(self, jinja_env: Environment) -> str:
|
||||||
|
"""Update an existing Jira conversation"""
|
||||||
|
|
||||||
|
user_id = self.jira_user.keycloak_user_id
|
||||||
|
|
||||||
|
try:
|
||||||
|
conversation_store = await ConversationStoreImpl.get_instance(
|
||||||
|
config, user_id
|
||||||
|
)
|
||||||
|
metadata = await conversation_store.get_metadata(self.conversation_id)
|
||||||
|
if not metadata:
|
||||||
|
raise StartingConvoException('Conversation no longer exists.')
|
||||||
|
|
||||||
|
provider_tokens = await self.saas_user_auth.get_provider_tokens()
|
||||||
|
# Should we raise here if there are no providers?
|
||||||
|
providers_set = list(provider_tokens.keys()) if provider_tokens else []
|
||||||
|
|
||||||
|
conversation_init_data = await setup_init_conversation_settings(
|
||||||
|
user_id, self.conversation_id, providers_set
|
||||||
|
)
|
||||||
|
|
||||||
|
# Either join ongoing conversation, or restart the conversation
|
||||||
|
agent_loop_info = await conversation_manager.maybe_start_agent_loop(
|
||||||
|
self.conversation_id, conversation_init_data, user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
final_agent_observation = get_final_agent_observation(
|
||||||
|
agent_loop_info.event_store
|
||||||
|
)
|
||||||
|
agent_state = (
|
||||||
|
None
|
||||||
|
if len(final_agent_observation) == 0
|
||||||
|
else final_agent_observation[0].agent_state
|
||||||
|
)
|
||||||
|
|
||||||
|
if not agent_state or agent_state == AgentState.LOADING:
|
||||||
|
raise StartingConvoException('Conversation is still starting')
|
||||||
|
|
||||||
|
_, user_msg = self._get_instructions(jinja_env)
|
||||||
|
user_message_event = MessageAction(content=user_msg)
|
||||||
|
await conversation_manager.send_event_to_conversation(
|
||||||
|
self.conversation_id, event_to_dict(user_message_event)
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.conversation_id
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f'[Jira] Failed to create conversation: {str(e)}', exc_info=True
|
||||||
|
)
|
||||||
|
raise StartingConvoException(f'Failed to create conversation: {str(e)}')
|
||||||
|
|
||||||
|
def get_response_msg(self) -> str:
|
||||||
|
"""Get the response message to send back to Jira"""
|
||||||
|
conversation_link = CONVERSATION_URL.format(self.conversation_id)
|
||||||
|
return f"I'm on it! {self.job_context.display_name} can [continue tracking my progress here|{conversation_link}]."
|
||||||
|
|
||||||
|
|
||||||
|
class JiraFactory:
|
||||||
|
"""Factory for creating Jira views based on message content"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def create_jira_view_from_payload(
|
||||||
|
job_context: JobContext,
|
||||||
|
saas_user_auth: UserAuth,
|
||||||
|
jira_user: JiraUser,
|
||||||
|
jira_workspace: JiraWorkspace,
|
||||||
|
) -> JiraViewInterface:
|
||||||
|
"""Create appropriate Jira view based on the message and user state"""
|
||||||
|
|
||||||
|
if not jira_user or not saas_user_auth or not jira_workspace:
|
||||||
|
raise StartingConvoException('User not authenticated with Jira integration')
|
||||||
|
|
||||||
|
conversation = await integration_store.get_user_conversations_by_issue_id(
|
||||||
|
job_context.issue_id, jira_user.id
|
||||||
|
)
|
||||||
|
|
||||||
|
if conversation:
|
||||||
|
logger.info(
|
||||||
|
f'[Jira] Found existing conversation for issue {job_context.issue_id}'
|
||||||
|
)
|
||||||
|
return JiraExistingConversationView(
|
||||||
|
job_context=job_context,
|
||||||
|
saas_user_auth=saas_user_auth,
|
||||||
|
jira_user=jira_user,
|
||||||
|
jira_workspace=jira_workspace,
|
||||||
|
selected_repo=None,
|
||||||
|
conversation_id=conversation.conversation_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
return JiraNewConversationView(
|
||||||
|
job_context=job_context,
|
||||||
|
saas_user_auth=saas_user_auth,
|
||||||
|
jira_user=jira_user,
|
||||||
|
jira_workspace=jira_workspace,
|
||||||
|
selected_repo=None, # Will be set later after repo inference
|
||||||
|
conversation_id='', # Will be set when conversation is created
|
||||||
|
)
|
||||||
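The factory above completes the Jira Cloud view module: it returns an existing-conversation view when the issue already maps to a stored conversation, otherwise a new-conversation view whose repository and conversation ID are filled in later. A minimal consumption sketch using the names from this diff (the call site and values are illustrative, not part of the commit):

    # Illustrative only: how a manager-level caller would drive the factory output.
    view = await JiraFactory.create_jira_view_from_payload(
        job_context, saas_user_auth, jira_user, jira_workspace
    )
    conversation_id = await view.create_or_update_conversation(jinja_env)
    response = view.get_response_msg()  # posted back to the Jira issue as a comment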
508
enterprise/integrations/jira_dc/jira_dc_manager.py
Normal file
@@ -0,0 +1,508 @@
import hashlib
import hmac
from typing import Dict, Optional, Tuple
from urllib.parse import urlparse

import httpx
from fastapi import Request
from integrations.jira_dc.jira_dc_types import (
    JiraDcViewInterface,
)
from integrations.jira_dc.jira_dc_view import (
    JiraDcExistingConversationView,
    JiraDcFactory,
    JiraDcNewConversationView,
)
from integrations.manager import Manager
from integrations.models import JobContext, Message
from integrations.utils import (
    HOST_URL,
    OPENHANDS_RESOLVER_TEMPLATES_DIR,
    filter_potential_repos_by_user_msg,
)
from jinja2 import Environment, FileSystemLoader
from server.auth.saas_user_auth import get_user_auth_from_keycloak_id
from server.auth.token_manager import TokenManager
from server.utils.conversation_callback_utils import register_callback_processor
from storage.jira_dc_integration_store import JiraDcIntegrationStore
from storage.jira_dc_user import JiraDcUser
from storage.jira_dc_workspace import JiraDcWorkspace

from openhands.core.logger import openhands_logger as logger
from openhands.integrations.provider import ProviderHandler
from openhands.integrations.service_types import Repository
from openhands.server.shared import server_config
from openhands.server.types import LLMAuthenticationError, MissingSettingsError
from openhands.server.user_auth.user_auth import UserAuth


class JiraDcManager(Manager):
    def __init__(self, token_manager: TokenManager):
        self.token_manager = token_manager
        self.integration_store = JiraDcIntegrationStore.get_instance()
        self.jinja_env = Environment(
            loader=FileSystemLoader(OPENHANDS_RESOLVER_TEMPLATES_DIR + 'jira_dc')
        )

    async def authenticate_user(
        self, user_email: str, jira_dc_user_id: str, workspace_id: int
    ) -> tuple[JiraDcUser | None, UserAuth | None]:
        """Authenticate Jira DC user and get their OpenHands user auth."""

        if not jira_dc_user_id or jira_dc_user_id == 'none':
            # Get Keycloak user ID from email
            keycloak_user_id = await self.token_manager.get_user_id_from_user_email(
                user_email
            )
            if not keycloak_user_id:
                logger.warning(
                    f'[Jira DC] No Keycloak user found for email: {user_email}'
                )
                return None, None

            # Find active Jira DC user by Keycloak user ID and workspace
            jira_dc_user = await self.integration_store.get_active_user_by_keycloak_id_and_workspace(
                keycloak_user_id, workspace_id
            )
        else:
            jira_dc_user = await self.integration_store.get_active_user(
                jira_dc_user_id, workspace_id
            )

        if not jira_dc_user:
            logger.warning(
                f'[Jira DC] No active Jira DC user found for {user_email} in workspace {workspace_id}'
            )
            return None, None

        saas_user_auth = await get_user_auth_from_keycloak_id(
            jira_dc_user.keycloak_user_id
        )
        return jira_dc_user, saas_user_auth

    async def _get_repositories(self, user_auth: UserAuth) -> list[Repository]:
        """Get repositories that the user has access to."""
        provider_tokens = await user_auth.get_provider_tokens()
        if provider_tokens is None:
            return []
        access_token = await user_auth.get_access_token()
        user_id = await user_auth.get_user_id()
        client = ProviderHandler(
            provider_tokens=provider_tokens,
            external_auth_token=access_token,
            external_auth_id=user_id,
        )
        repos: list[Repository] = await client.get_repositories(
            'pushed', server_config.app_mode, None, None, None, None
        )
        return repos

    async def validate_request(
        self, request: Request
    ) -> Tuple[bool, Optional[str], Optional[Dict]]:
        """Verify Jira DC webhook signature."""
        signature_header = request.headers.get('x-hub-signature')
        signature = signature_header.split('=')[1] if signature_header else None
        body = await request.body()
        payload = await request.json()
        workspace_name = ''

        # Resolve the author's self URL so the workspace hostname can be derived.
        # self_url must always be assigned; leaving it unbound for other event
        # types would make urlparse() below raise a NameError.
        if payload.get('webhookEvent') == 'comment_created':
            self_url = payload.get('comment', {}).get('author', {}).get('self', '')
        elif payload.get('webhookEvent') == 'jira:issue_updated':
            self_url = payload.get('user', {}).get('self', '')
        else:
            self_url = ''

        parsed_url = urlparse(self_url)
        if parsed_url.hostname:
            workspace_name = parsed_url.hostname

        if not workspace_name:
            logger.warning('[Jira DC] No workspace name found in webhook payload')
            return False, None, None

        if not signature:
            logger.warning('[Jira DC] No signature found in webhook headers')
            return False, None, None

        workspace = await self.integration_store.get_workspace_by_name(workspace_name)

        if not workspace:
            logger.warning('[Jira DC] Could not identify workspace for webhook')
            return False, None, None

        if workspace.status != 'active':
            logger.warning(f'[Jira DC] Workspace {workspace.id} is not active')
            return False, None, None

        webhook_secret = self.token_manager.decrypt_text(workspace.webhook_secret)
        digest = hmac.new(webhook_secret.encode(), body, hashlib.sha256).hexdigest()

        if hmac.compare_digest(signature, digest):
            logger.info('[Jira DC] Webhook signature verified successfully')
            return True, signature, payload

        return False, None, None

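    # NOTE (editor sketch, illustrative only): the x-hub-signature header checked
    # above is expected to carry 'sha256=<hex>', where <hex> is the SHA-256 HMAC
    # of the raw request body under the workspace's webhook secret. A sender
    # would compute it roughly as:
    #
    #     header_value = 'sha256=' + hmac.new(
    #         webhook_secret.encode(), body, hashlib.sha256
    #     ).hexdigest()
    #
    # validate_request() takes the part after '=' and compares digests in
    # constant time via hmac.compare_digest().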
    def parse_webhook(self, payload: Dict) -> JobContext | None:
        event_type = payload.get('webhookEvent')

        if event_type == 'comment_created':
            comment_data = payload.get('comment', {})
            comment = comment_data.get('body', '')

            if '@openhands' not in comment:
                return None

            issue_data = payload.get('issue', {})
            issue_id = issue_data.get('id')
            issue_key = issue_data.get('key')
            base_api_url = issue_data.get('self', '').split('/rest/')[0]

            user_data = comment_data.get('author', {})
            user_email = user_data.get('emailAddress')
            display_name = user_data.get('displayName')
            user_key = user_data.get('key')
        elif event_type == 'jira:issue_updated':
            changelog = payload.get('changelog', {})
            items = changelog.get('items', [])
            labels = [
                item.get('toString', '')
                for item in items
                if item.get('field') == 'labels' and 'toString' in item
            ]

            if 'openhands' not in labels:
                return None

            issue_data = payload.get('issue', {})
            issue_id = issue_data.get('id')
            issue_key = issue_data.get('key')
            base_api_url = issue_data.get('self', '').split('/rest/')[0]

            user_data = payload.get('user', {})
            user_email = user_data.get('emailAddress')
            display_name = user_data.get('displayName')
            user_key = user_data.get('key')
            comment = ''
        else:
            return None

        workspace_name = ''

        parsed_url = urlparse(base_api_url)
        if parsed_url.hostname:
            workspace_name = parsed_url.hostname

        if not all(
            [
                issue_id,
                issue_key,
                user_email,
                display_name,
                user_key,
                workspace_name,
                base_api_url,
            ]
        ):
            return None

        return JobContext(
            issue_id=issue_id,
            issue_key=issue_key,
            user_msg=comment,
            user_email=user_email,
            display_name=display_name,
            platform_user_id=user_key,
            workspace_name=workspace_name,
            base_api_url=base_api_url,
        )

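    # NOTE (editor sketch, illustrative only): a minimal comment_created payload
    # that would pass the trigger check above; all field values are made up:
    #
    #     payload = {
    #         'webhookEvent': 'comment_created',
    #         'comment': {
    #             'body': '@openhands please fix the flaky test',
    #             'author': {
    #                 'emailAddress': 'dev@example.com',
    #                 'displayName': 'Dev Example',
    #                 'key': 'dev-key',
    #             },
    #         },
    #         'issue': {
    #             'id': '10001',
    #             'key': 'PROJ-42',
    #             'self': 'https://jira.example.com/rest/api/2/issue/10001',
    #         },
    #     }
    #
    # base_api_url becomes 'https://jira.example.com' via self.split('/rest/')[0],
    # and workspace_name becomes 'jira.example.com' via urlparse().hostname.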
    async def receive_message(self, message: Message):
        """Process incoming Jira DC webhook message."""

        payload = message.message.get('payload', {})
        job_context = self.parse_webhook(payload)

        if not job_context:
            logger.info('[Jira DC] Webhook does not match trigger conditions')
            return

        workspace = await self.integration_store.get_workspace_by_name(
            job_context.workspace_name
        )
        if not workspace:
            logger.warning(
                f'[Jira DC] No workspace found for name: {job_context.workspace_name}'
            )
            await self._send_error_comment(
                job_context,
                'Your workspace is not configured with Jira DC integration.',
                None,
            )
            return

        # Prevent any recursive triggers from the service account
        if job_context.user_email == workspace.svc_acc_email:
            return

        if workspace.status != 'active':
            logger.warning(f'[Jira DC] Workspace {workspace.id} is not active')
            await self._send_error_comment(
                job_context,
                'Jira DC integration is not active for your workspace.',
                workspace,
            )
            return

        # Authenticate user
        jira_dc_user, saas_user_auth = await self.authenticate_user(
            job_context.user_email, job_context.platform_user_id, workspace.id
        )
        if not jira_dc_user or not saas_user_auth:
            logger.warning(
                f'[Jira DC] User authentication failed for {job_context.user_email}'
            )
            await self._send_error_comment(
                job_context,
                f'User {job_context.user_email} is not authenticated or active in the Jira DC integration.',
                workspace,
            )
            return

        # Get issue details
        try:
            api_key = self.token_manager.decrypt_text(workspace.svc_acc_api_key)
            issue_title, issue_description = await self.get_issue_details(
                job_context, api_key
            )
            job_context.issue_title = issue_title
            job_context.issue_description = issue_description
        except Exception as e:
            logger.error(f'[Jira DC] Failed to get issue context: {str(e)}')
            await self._send_error_comment(
                job_context,
                'Failed to retrieve issue details. Please check the issue key and try again.',
                workspace,
            )
            return

        try:
            # Create Jira DC view
            jira_dc_view = await JiraDcFactory.create_jira_dc_view_from_payload(
                job_context,
                saas_user_auth,
                jira_dc_user,
                workspace,
            )
        except Exception as e:
            logger.error(
                f'[Jira DC] Failed to create jira dc view: {str(e)}', exc_info=True
            )
            await self._send_error_comment(
                job_context,
                'Failed to initialize conversation. Please try again.',
                workspace,
            )
            return

        if not await self.is_job_requested(message, jira_dc_view):
            return

        await self.start_job(jira_dc_view)

    async def is_job_requested(
        self, message: Message, jira_dc_view: JiraDcViewInterface
    ) -> bool:
        """
        Check if a job is requested and handle repository selection.
        """

        if isinstance(jira_dc_view, JiraDcExistingConversationView):
            return True

        try:
            # Get user repositories
            user_repos: list[Repository] = await self._get_repositories(
                jira_dc_view.saas_user_auth
            )

            target_str = f'{jira_dc_view.job_context.issue_description}\n{jira_dc_view.job_context.user_msg}'

            # Try to infer repository from issue description
            match, repos = filter_potential_repos_by_user_msg(target_str, user_repos)

            if match:
                # Found exact repository match
                jira_dc_view.selected_repo = repos[0].full_name
                logger.info(f'[Jira DC] Inferred repository: {repos[0].full_name}')
                return True
            else:
                # No clear match - send repository selection comment
                await self._send_repo_selection_comment(jira_dc_view)
                return False

        except Exception as e:
            logger.error(f'[Jira DC] Error in is_job_requested: {str(e)}')
            return False

    async def start_job(self, jira_dc_view: JiraDcViewInterface):
        """Start a Jira DC job/conversation."""
        # Import here to prevent circular import
        from server.conversation_callback_processor.jira_dc_callback_processor import (
            JiraDcCallbackProcessor,
        )

        try:
            user_info: JiraDcUser = jira_dc_view.jira_dc_user
            logger.info(
                f'[Jira DC] Starting job for user {user_info.keycloak_user_id} '
                f'issue {jira_dc_view.job_context.issue_key}',
            )

            # Create conversation
            conversation_id = await jira_dc_view.create_or_update_conversation(
                self.jinja_env
            )

            logger.info(
                f'[Jira DC] Created/Updated conversation {conversation_id} for issue {jira_dc_view.job_context.issue_key}'
            )

            if isinstance(jira_dc_view, JiraDcNewConversationView):
                # Register callback processor for updates
                processor = JiraDcCallbackProcessor(
                    issue_key=jira_dc_view.job_context.issue_key,
                    workspace_name=jira_dc_view.jira_dc_workspace.name,
                    base_api_url=jira_dc_view.job_context.base_api_url,
                )

                # Register the callback processor
                register_callback_processor(conversation_id, processor)

                logger.info(
                    f'[Jira DC] Created callback processor for conversation {conversation_id}'
                )

            # Send initial response
            msg_info = jira_dc_view.get_response_msg()

        except MissingSettingsError as e:
            logger.warning(f'[Jira DC] Missing settings error: {str(e)}')
            msg_info = f'Please re-login into [OpenHands Cloud]({HOST_URL}) before starting a job.'

        except LLMAuthenticationError as e:
            logger.warning(f'[Jira DC] LLM authentication error: {str(e)}')
            msg_info = f'Please set a valid LLM API key in [OpenHands Cloud]({HOST_URL}) before starting a job.'

        except Exception as e:
            logger.error(
                f'[Jira DC] Unexpected error starting job: {str(e)}', exc_info=True
            )
            msg_info = 'Sorry, there was an unexpected error starting the job. Please try again.'

        # Send response comment
        try:
            api_key = self.token_manager.decrypt_text(
                jira_dc_view.jira_dc_workspace.svc_acc_api_key
            )
            await self.send_message(
                self.create_outgoing_message(msg=msg_info),
                issue_key=jira_dc_view.job_context.issue_key,
                base_api_url=jira_dc_view.job_context.base_api_url,
                svc_acc_api_key=api_key,
            )
        except Exception as e:
            logger.error(f'[Jira DC] Failed to send response message: {str(e)}')

    async def get_issue_details(
        self, job_context: JobContext, svc_acc_api_key: str
    ) -> Tuple[str, str]:
        """Get issue details from Jira DC API."""
        url = f'{job_context.base_api_url}/rest/api/2/issue/{job_context.issue_key}'
        headers = {'Authorization': f'Bearer {svc_acc_api_key}'}
        async with httpx.AsyncClient() as client:
            response = await client.get(url, headers=headers)
            response.raise_for_status()
            issue_payload = response.json()

        if not issue_payload:
            raise ValueError(f'Issue with key {job_context.issue_key} not found.')

        title = issue_payload.get('fields', {}).get('summary', '')
        description = issue_payload.get('fields', {}).get('description', '')

        if not title:
            raise ValueError(
                f'Issue with key {job_context.issue_key} does not have a title.'
            )

        if not description:
            raise ValueError(
                f'Issue with key {job_context.issue_key} does not have a description.'
            )

        return title, description

    async def send_message(
        self, message: Message, issue_key: str, base_api_url: str, svc_acc_api_key: str
    ):
        """Send message/comment to Jira DC issue."""
        url = f'{base_api_url}/rest/api/2/issue/{issue_key}/comment'
        headers = {'Authorization': f'Bearer {svc_acc_api_key}'}
        data = {'body': message.message}
        async with httpx.AsyncClient() as client:
            response = await client.post(url, headers=headers, json=data)
            response.raise_for_status()
            return response.json()

    async def _send_error_comment(
        self,
        job_context: JobContext,
        error_msg: str,
        workspace: JiraDcWorkspace | None,
    ):
        """Send error comment to Jira DC issue."""
        if not workspace:
            logger.error('[Jira DC] Cannot send error comment - no workspace available')
            return

        try:
            api_key = self.token_manager.decrypt_text(workspace.svc_acc_api_key)
            await self.send_message(
                self.create_outgoing_message(msg=error_msg),
                issue_key=job_context.issue_key,
                base_api_url=job_context.base_api_url,
                svc_acc_api_key=api_key,
            )
        except Exception as e:
            logger.error(f'[Jira DC] Failed to send error comment: {str(e)}')

    async def _send_repo_selection_comment(self, jira_dc_view: JiraDcViewInterface):
        """Send a comment with repository options for the user to choose."""
        try:
            comment_msg = (
                'I need to know which repository to work with. '
                'Please add it to your issue description or send a followup comment.'
            )

            api_key = self.token_manager.decrypt_text(
                jira_dc_view.jira_dc_workspace.svc_acc_api_key
            )

            await self.send_message(
                self.create_outgoing_message(msg=comment_msg),
                issue_key=jira_dc_view.job_context.issue_key,
                base_api_url=jira_dc_view.job_context.base_api_url,
                svc_acc_api_key=api_key,
            )

            logger.info(
                f'[Jira DC] Sent repository selection comment for issue {jira_dc_view.job_context.issue_key}'
            )

        except Exception as e:
            logger.error(
                f'[Jira DC] Failed to send repository selection comment: {str(e)}'
            )
40
enterprise/integrations/jira_dc/jira_dc_types.py
Normal file
@@ -0,0 +1,40 @@
from abc import ABC, abstractmethod

from integrations.models import JobContext
from jinja2 import Environment
from storage.jira_dc_user import JiraDcUser
from storage.jira_dc_workspace import JiraDcWorkspace

from openhands.server.user_auth.user_auth import UserAuth


class JiraDcViewInterface(ABC):
    """Interface for Jira DC views that handle different types of Jira DC interactions."""

    job_context: JobContext
    saas_user_auth: UserAuth
    jira_dc_user: JiraDcUser
    jira_dc_workspace: JiraDcWorkspace
    selected_repo: str | None
    conversation_id: str

    @abstractmethod
    def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
        """Get initial instructions for the conversation."""
        pass

    @abstractmethod
    async def create_or_update_conversation(self, jinja_env: Environment) -> str:
        """Create or update a conversation and return the conversation ID."""
        pass

    @abstractmethod
    def get_response_msg(self) -> str:
        """Get the response message to send back to Jira DC."""
        pass


class StartingConvoException(Exception):
    """Exception raised when starting a conversation fails."""

    pass
223
enterprise/integrations/jira_dc/jira_dc_view.py
Normal file
@@ -0,0 +1,223 @@
from dataclasses import dataclass

from integrations.jira_dc.jira_dc_types import (
    JiraDcViewInterface,
    StartingConvoException,
)
from integrations.models import JobContext
from integrations.utils import CONVERSATION_URL, get_final_agent_observation
from jinja2 import Environment
from storage.jira_dc_conversation import JiraDcConversation
from storage.jira_dc_integration_store import JiraDcIntegrationStore
from storage.jira_dc_user import JiraDcUser
from storage.jira_dc_workspace import JiraDcWorkspace

from openhands.core.logger import openhands_logger as logger
from openhands.core.schema.agent import AgentState
from openhands.events.action import MessageAction
from openhands.events.serialization.event import event_to_dict
from openhands.server.services.conversation_service import (
    create_new_conversation,
    setup_init_conversation_settings,
)
from openhands.server.shared import ConversationStoreImpl, config, conversation_manager
from openhands.server.user_auth.user_auth import UserAuth
from openhands.storage.data_models.conversation_metadata import ConversationTrigger

integration_store = JiraDcIntegrationStore.get_instance()


@dataclass
class JiraDcNewConversationView(JiraDcViewInterface):
    job_context: JobContext
    saas_user_auth: UserAuth
    jira_dc_user: JiraDcUser
    jira_dc_workspace: JiraDcWorkspace
    selected_repo: str | None
    conversation_id: str

    def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
        """Instructions passed when conversation is first initialized"""

        instructions_template = jinja_env.get_template('jira_dc_instructions.j2')
        instructions = instructions_template.render()

        user_msg_template = jinja_env.get_template('jira_dc_new_conversation.j2')

        user_msg = user_msg_template.render(
            issue_key=self.job_context.issue_key,
            issue_title=self.job_context.issue_title,
            issue_description=self.job_context.issue_description,
            user_message=self.job_context.user_msg or '',
        )

        return instructions, user_msg

    async def create_or_update_conversation(self, jinja_env: Environment) -> str:
        """Create a new Jira DC conversation"""

        if not self.selected_repo:
            raise StartingConvoException('No repository selected for this conversation')

        provider_tokens = await self.saas_user_auth.get_provider_tokens()
        user_secrets = await self.saas_user_auth.get_user_secrets()
        instructions, user_msg = self._get_instructions(jinja_env)

        try:
            agent_loop_info = await create_new_conversation(
                user_id=self.jira_dc_user.keycloak_user_id,
                git_provider_tokens=provider_tokens,
                selected_repository=self.selected_repo,
                selected_branch=None,
                initial_user_msg=user_msg,
                conversation_instructions=instructions,
                image_urls=None,
                replay_json=None,
                conversation_trigger=ConversationTrigger.JIRA_DC,
                custom_secrets=user_secrets.custom_secrets if user_secrets else None,
            )

            self.conversation_id = agent_loop_info.conversation_id

            logger.info(f'[Jira DC] Created conversation {self.conversation_id}')

            # Store Jira DC conversation mapping
            jira_dc_conversation = JiraDcConversation(
                conversation_id=self.conversation_id,
                issue_id=self.job_context.issue_id,
                issue_key=self.job_context.issue_key,
                jira_dc_user_id=self.jira_dc_user.id,
            )

            await integration_store.create_conversation(jira_dc_conversation)

            return self.conversation_id
        except Exception as e:
            logger.error(
                f'[Jira DC] Failed to create conversation: {str(e)}', exc_info=True
            )
            raise StartingConvoException(f'Failed to create conversation: {str(e)}')

    def get_response_msg(self) -> str:
        """Get the response message to send back to Jira DC"""
        conversation_link = CONVERSATION_URL.format(self.conversation_id)
        return f"I'm on it! {self.job_context.display_name} can [track my progress here|{conversation_link}]."


@dataclass
class JiraDcExistingConversationView(JiraDcViewInterface):
    job_context: JobContext
    saas_user_auth: UserAuth
    jira_dc_user: JiraDcUser
    jira_dc_workspace: JiraDcWorkspace
    selected_repo: str | None
    conversation_id: str

    def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
        """Instructions passed when conversation is first initialized"""

        user_msg_template = jinja_env.get_template('jira_dc_existing_conversation.j2')
        user_msg = user_msg_template.render(
            issue_key=self.job_context.issue_key,
            user_message=self.job_context.user_msg or '',
            issue_title=self.job_context.issue_title,
            issue_description=self.job_context.issue_description,
        )

        return '', user_msg

    async def create_or_update_conversation(self, jinja_env: Environment) -> str:
        """Update an existing Jira DC conversation"""

        user_id = self.jira_dc_user.keycloak_user_id

        try:
            conversation_store = await ConversationStoreImpl.get_instance(
                config, user_id
            )
            metadata = await conversation_store.get_metadata(self.conversation_id)
            if not metadata:
                raise StartingConvoException('Conversation no longer exists.')

            provider_tokens = await self.saas_user_auth.get_provider_tokens()
            if provider_tokens is None:
                raise ValueError('Could not load provider tokens')
            providers_set = list(provider_tokens.keys())

            conversation_init_data = await setup_init_conversation_settings(
                user_id, self.conversation_id, providers_set
            )

            # Either join ongoing conversation, or restart the conversation
            agent_loop_info = await conversation_manager.maybe_start_agent_loop(
                self.conversation_id, conversation_init_data, user_id
            )

            final_agent_observation = get_final_agent_observation(
                agent_loop_info.event_store
            )
            agent_state = (
                None
                if len(final_agent_observation) == 0
                else final_agent_observation[0].agent_state
            )

            if not agent_state or agent_state == AgentState.LOADING:
                raise StartingConvoException('Conversation is still starting')

            _, user_msg = self._get_instructions(jinja_env)
            user_message_event = MessageAction(content=user_msg)
            await conversation_manager.send_event_to_conversation(
                self.conversation_id, event_to_dict(user_message_event)
            )

            return self.conversation_id
        except Exception as e:
            logger.error(
                f'[Jira DC] Failed to create conversation: {str(e)}', exc_info=True
            )
            raise StartingConvoException(f'Failed to create conversation: {str(e)}')

    def get_response_msg(self) -> str:
        """Get the response message to send back to Jira DC"""
        conversation_link = CONVERSATION_URL.format(self.conversation_id)
        return f"I'm on it! {self.job_context.display_name} can [continue tracking my progress here|{conversation_link}]."


class JiraDcFactory:
    """Factory class for creating Jira DC views based on message type."""

    @staticmethod
    async def create_jira_dc_view_from_payload(
        job_context: JobContext,
        saas_user_auth: UserAuth,
        jira_dc_user: JiraDcUser,
        jira_dc_workspace: JiraDcWorkspace,
    ) -> JiraDcViewInterface:
        """Create appropriate Jira DC view based on the payload."""

        if not jira_dc_user or not saas_user_auth or not jira_dc_workspace:
            raise StartingConvoException(
                'User not authenticated with Jira DC integration'
            )

        conversation = await integration_store.get_user_conversations_by_issue_id(
            job_context.issue_id, jira_dc_user.id
        )

        if conversation:
            return JiraDcExistingConversationView(
                job_context=job_context,
                saas_user_auth=saas_user_auth,
                jira_dc_user=jira_dc_user,
                jira_dc_workspace=jira_dc_workspace,
                selected_repo=None,
                conversation_id=conversation.conversation_id,
            )

        return JiraDcNewConversationView(
            job_context=job_context,
            saas_user_auth=saas_user_auth,
            jira_dc_user=jira_dc_user,
            jira_dc_workspace=jira_dc_workspace,
            selected_repo=None,  # Will be set later after repo inference
            conversation_id='',  # Will be set when conversation is created
        )
522
enterprise/integrations/linear/linear_manager.py
Normal file
@@ -0,0 +1,522 @@
import hashlib
import hmac
from typing import Dict, Optional, Tuple

import httpx
from fastapi import Request
from integrations.linear.linear_types import LinearViewInterface
from integrations.linear.linear_view import (
    LinearExistingConversationView,
    LinearFactory,
    LinearNewConversationView,
)
from integrations.manager import Manager
from integrations.models import JobContext, Message
from integrations.utils import (
    HOST_URL,
    OPENHANDS_RESOLVER_TEMPLATES_DIR,
    filter_potential_repos_by_user_msg,
)
from jinja2 import Environment, FileSystemLoader
from server.auth.saas_user_auth import get_user_auth_from_keycloak_id
from server.auth.token_manager import TokenManager
from server.utils.conversation_callback_utils import register_callback_processor
from storage.linear_integration_store import LinearIntegrationStore
from storage.linear_user import LinearUser
from storage.linear_workspace import LinearWorkspace

from openhands.core.logger import openhands_logger as logger
from openhands.integrations.provider import ProviderHandler
from openhands.integrations.service_types import Repository
from openhands.server.shared import server_config
from openhands.server.types import LLMAuthenticationError, MissingSettingsError
from openhands.server.user_auth.user_auth import UserAuth


class LinearManager(Manager):
    def __init__(self, token_manager: TokenManager):
        self.token_manager = token_manager
        self.integration_store = LinearIntegrationStore.get_instance()
        self.api_url = 'https://api.linear.app/graphql'
        self.jinja_env = Environment(
            loader=FileSystemLoader(OPENHANDS_RESOLVER_TEMPLATES_DIR + 'linear')
        )

    async def authenticate_user(
        self, linear_user_id: str, workspace_id: int
    ) -> tuple[LinearUser | None, UserAuth | None]:
        """Authenticate Linear user and get their OpenHands user auth."""

        # Find active Linear user by Linear user ID and workspace ID
        linear_user = await self.integration_store.get_active_user(
            linear_user_id, workspace_id
        )

        if not linear_user:
            logger.warning(
                f'[Linear] No active Linear user found for {linear_user_id} in workspace {workspace_id}'
            )
            return None, None

        saas_user_auth = await get_user_auth_from_keycloak_id(
            linear_user.keycloak_user_id
        )
        return linear_user, saas_user_auth

    async def _get_repositories(self, user_auth: UserAuth) -> list[Repository]:
        """Get repositories that the user has access to."""
        provider_tokens = await user_auth.get_provider_tokens()
        if provider_tokens is None:
            return []
        access_token = await user_auth.get_access_token()
        user_id = await user_auth.get_user_id()
        client = ProviderHandler(
            provider_tokens=provider_tokens,
            external_auth_token=access_token,
            external_auth_id=user_id,
        )
        repos: list[Repository] = await client.get_repositories(
            'pushed', server_config.app_mode, None, None, None, None
        )
        return repos

    async def validate_request(
        self, request: Request
    ) -> Tuple[bool, Optional[str], Optional[Dict]]:
        """Verify Linear webhook signature."""
        signature = request.headers.get('linear-signature')
        body = await request.body()
        payload = await request.json()
        actor_url = payload.get('actor', {}).get('url', '')
        workspace_name = ''

        # Extract workspace name from actor URL
        # Format: https://linear.app/{workspace}/profiles/{user}
        if actor_url.startswith('https://linear.app/'):
            url_parts = actor_url.split('/')
            if len(url_parts) >= 4:
                workspace_name = url_parts[3]  # Extract workspace name
            else:
                logger.warning(f'[Linear] Invalid actor URL format: {actor_url}')
                return False, None, None
        else:
            logger.warning(
                f'[Linear] Actor URL does not match expected format: {actor_url}'
            )
            return False, None, None

        if not workspace_name:
            logger.warning('[Linear] No workspace name found in webhook payload')
            return False, None, None

        if not signature:
            logger.warning('[Linear] No signature found in webhook headers')
            return False, None, None

        workspace = await self.integration_store.get_workspace_by_name(workspace_name)

        if not workspace:
            logger.warning('[Linear] Could not identify workspace for webhook')
            return False, None, None

        if workspace.status != 'active':
            logger.warning(f'[Linear] Workspace {workspace.id} is not active')
            return False, None, None

        webhook_secret = self.token_manager.decrypt_text(workspace.webhook_secret)
        digest = hmac.new(webhook_secret.encode(), body, hashlib.sha256).hexdigest()

        if hmac.compare_digest(signature, digest):
            logger.info('[Linear] Webhook signature verified successfully')
            return True, signature, payload

        return False, None, None

    def parse_webhook(self, payload: Dict) -> JobContext | None:
        action = payload.get('action')
        payload_type = payload.get('type')

        if action == 'create' and payload_type == 'Comment':
            data = payload.get('data', {})
            comment = data.get('body', '')

            if '@openhands' not in comment:
                return None

            issue_data = data.get('issue', {})
            issue_id = issue_data.get('id', '')
            issue_key = issue_data.get('identifier', '')
        elif action == 'update' and payload_type == 'Issue':
            data = payload.get('data', {})
            labels = data.get('labels', [])

            has_openhands_label = False
            label_id = ''
            for label in labels:
                if label.get('name') == 'openhands':
                    label_id = label.get('id', '')
                    has_openhands_label = True
                    break

            if not has_openhands_label and not label_id:
                return None

            label_id_changes = data.get('updatedFrom', {}).get('labelIds', [])

            if label_id_changes and label_id in label_id_changes:
                return None  # Label was added previously, ignore this webhook

            issue_id = data.get('id', '')
            issue_key = data.get('identifier', '')
            comment = ''

        else:
            return None

        actor = payload.get('actor', {})
        display_name = actor.get('name', '')
        user_email = actor.get('email', '')
        actor_url = actor.get('url', '')
        actor_id = actor.get('id', '')
        workspace_name = ''

        if actor_url.startswith('https://linear.app/'):
            url_parts = actor_url.split('/')
            if len(url_parts) >= 4:
                workspace_name = url_parts[3]  # Extract workspace name
            else:
                logger.warning(f'[Linear] Invalid actor URL format: {actor_url}')
                return None
        else:
            logger.warning(
                f'[Linear] Actor URL does not match expected format: {actor_url}'
            )
            return None

        if not all(
            [issue_id, issue_key, display_name, user_email, actor_id, workspace_name]
        ):
            logger.warning('[Linear] Missing required fields in webhook payload')
            return None

        return JobContext(
            issue_id=issue_id,
            issue_key=issue_key,
            user_msg=comment,
            user_email=user_email,
            platform_user_id=actor_id,
            workspace_name=workspace_name,
            display_name=display_name,
        )

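    # NOTE (editor sketch, illustrative only): a minimal Comment-created payload
    # that satisfies the checks above; all field values are made up:
    #
    #     payload = {
    #         'action': 'create',
    #         'type': 'Comment',
    #         'data': {
    #             'body': '@openhands take a look at this',
    #             'issue': {'id': 'a1b2c3', 'identifier': 'ENG-101'},
    #         },
    #         'actor': {
    #             'name': 'Dev Example',
    #             'email': 'dev@example.com',
    #             'url': 'https://linear.app/acme/profiles/dev',
    #             'id': 'actor-uuid',
    #         },
    #     }
    #
    # workspace_name is parsed as 'acme' from the actor URL's third path segment.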
    async def receive_message(self, message: Message):
        """Process incoming Linear webhook message."""
        payload = message.message.get('payload', {})
        job_context = self.parse_webhook(payload)

        if not job_context:
            logger.info('[Linear] Webhook does not match trigger conditions')
            return

        # Get workspace by name
        workspace = await self.integration_store.get_workspace_by_name(
            job_context.workspace_name
        )
        if not workspace:
            logger.warning(
                f'[Linear] No workspace found for name: {job_context.workspace_name}'
            )
            await self._send_error_comment(
                job_context.issue_id,
                'Your workspace is not configured with Linear integration.',
                None,
            )
            return

        # Prevent any recursive triggers from the service account
        if job_context.user_email == workspace.svc_acc_email:
            return

        if workspace.status != 'active':
            logger.warning(f'[Linear] Workspace {workspace.id} is not active')
            await self._send_error_comment(
                job_context.issue_id,
                'Linear integration is not active for your workspace.',
                workspace,
            )
            return

        # Authenticate user
        linear_user, saas_user_auth = await self.authenticate_user(
            job_context.platform_user_id, workspace.id
        )
        if not linear_user or not saas_user_auth:
            logger.warning(
                f'[Linear] User authentication failed for {job_context.user_email}'
            )
            await self._send_error_comment(
                job_context.issue_id,
                f'User {job_context.user_email} is not authenticated or active in the Linear integration.',
                workspace,
            )
            return

        # Get issue details
        try:
            api_key = self.token_manager.decrypt_text(workspace.svc_acc_api_key)
            issue_title, issue_description = await self.get_issue_details(
                job_context.issue_id, api_key
            )
            job_context.issue_title = issue_title
            job_context.issue_description = issue_description
        except Exception as e:
            logger.error(f'[Linear] Failed to get issue context: {str(e)}')
            await self._send_error_comment(
                job_context.issue_id,
                'Failed to retrieve issue details. Please check the issue ID and try again.',
                workspace,
            )
            return

        try:
            # Create Linear view
            linear_view = await LinearFactory.create_linear_view_from_payload(
                job_context,
                saas_user_auth,
                linear_user,
                workspace,
            )
        except Exception as e:
            logger.error(
                f'[Linear] Failed to create linear view: {str(e)}', exc_info=True
            )
            await self._send_error_comment(
                job_context.issue_id,
                'Failed to initialize conversation. Please try again.',
                workspace,
            )
            return

        if not await self.is_job_requested(message, linear_view):
            return

        await self.start_job(linear_view)

    async def is_job_requested(
        self, message: Message, linear_view: LinearViewInterface
    ) -> bool:
        """
        Check if a job is requested and handle repository selection.
        """

        if isinstance(linear_view, LinearExistingConversationView):
            return True

        try:
            # Get user repositories
            user_repos: list[Repository] = await self._get_repositories(
                linear_view.saas_user_auth
            )

            target_str = f'{linear_view.job_context.issue_description}\n{linear_view.job_context.user_msg}'

            # Try to infer repository from issue description
            match, repos = filter_potential_repos_by_user_msg(target_str, user_repos)

            if match:
                # Found exact repository match
                linear_view.selected_repo = repos[0].full_name
                logger.info(f'[Linear] Inferred repository: {repos[0].full_name}')
                return True
            else:
                # No clear match - send repository selection comment
                await self._send_repo_selection_comment(linear_view)
                return False

        except Exception as e:
            logger.error(f'[Linear] Error in is_job_requested: {str(e)}')
            return False

    async def start_job(self, linear_view: LinearViewInterface):
        """Start a Linear job/conversation."""
        # Import here to prevent circular import
        from server.conversation_callback_processor.linear_callback_processor import (
            LinearCallbackProcessor,
        )

        try:
            user_info: LinearUser = linear_view.linear_user
            logger.info(
                f'[Linear] Starting job for user {user_info.keycloak_user_id} '
                f'issue {linear_view.job_context.issue_key}',
            )

            # Create conversation
            conversation_id = await linear_view.create_or_update_conversation(
                self.jinja_env
            )

            logger.info(
                f'[Linear] Created/Updated conversation {conversation_id} for issue {linear_view.job_context.issue_key}'
            )

            if isinstance(linear_view, LinearNewConversationView):
                # Register callback processor for updates
                processor = LinearCallbackProcessor(
                    issue_id=linear_view.job_context.issue_id,
                    issue_key=linear_view.job_context.issue_key,
                    workspace_name=linear_view.linear_workspace.name,
                )

                # Register the callback processor
                register_callback_processor(conversation_id, processor)

                logger.info(
                    f'[Linear] Created callback processor for conversation {conversation_id}'
                )

            # Send initial response
            msg_info = linear_view.get_response_msg()

        except MissingSettingsError as e:
            logger.warning(f'[Linear] Missing settings error: {str(e)}')
            msg_info = f'Please re-login into [OpenHands Cloud]({HOST_URL}) before starting a job.'

        except LLMAuthenticationError as e:
            logger.warning(f'[Linear] LLM authentication error: {str(e)}')
            msg_info = f'Please set a valid LLM API key in [OpenHands Cloud]({HOST_URL}) before starting a job.'

        except Exception as e:
            logger.error(
                f'[Linear] Unexpected error starting job: {str(e)}', exc_info=True
            )
            msg_info = 'Sorry, there was an unexpected error starting the job. Please try again.'

        # Send response comment
        try:
            api_key = self.token_manager.decrypt_text(
                linear_view.linear_workspace.svc_acc_api_key
            )
            await self.send_message(
                self.create_outgoing_message(msg=msg_info),
                linear_view.job_context.issue_id,
                api_key,
            )
        except Exception as e:
            logger.error(f'[Linear] Failed to send response message: {str(e)}')

    async def _query_api(self, query: str, variables: Dict, api_key: str) -> Dict:
        """Query Linear GraphQL API."""
        headers = {'Authorization': api_key}
        async with httpx.AsyncClient() as client:
            response = await client.post(
                self.api_url,
                headers=headers,
                json={'query': query, 'variables': variables},
            )
            response.raise_for_status()
            return response.json()

    async def get_issue_details(self, issue_id: str, api_key: str) -> Tuple[str, str]:
        """Get issue details from Linear API."""
        query = """
            query Issue($issueId: String!) {
                issue(id: $issueId) {
                    id
                    identifier
                    title
                    description
                    syncedWith {
                        metadata {
                            ... on ExternalEntityInfoGithubMetadata {
                                owner
                                repo
                            }
                        }
                    }
                }
            }
        """
        issue_payload = await self._query_api(query, {'issueId': issue_id}, api_key)

        if not issue_payload:
            raise ValueError(f'Issue with ID {issue_id} not found.')

        issue_data = issue_payload.get('data', {}).get('issue', {})
        title = issue_data.get('title', '')
        description = issue_data.get('description', '')
        synced_with = issue_data.get('syncedWith', [])
        owner = ''
        repo = ''
        if synced_with:
            owner = synced_with[0].get('metadata', {}).get('owner', '')
            repo = synced_with[0].get('metadata', {}).get('repo', '')

        if not title:
            raise ValueError(f'Issue with ID {issue_id} does not have a title.')

        if not description:
            raise ValueError(f'Issue with ID {issue_id} does not have a description.')

        if owner and repo:
            description += f'\n\nGit Repo: {owner}/{repo}'

        return title, description

    async def send_message(self, message: Message, issue_id: str, api_key: str):
        """Send message/comment to Linear issue."""
        query = """
            mutation CommentCreate($input: CommentCreateInput!) {
                commentCreate(input: $input) {
                    success
                    comment {
                        id
                    }
                }
            }
        """
        variables = {'input': {'issueId': issue_id, 'body': message.message}}
        return await self._query_api(query, variables, api_key)

    async def _send_error_comment(
        self, issue_id: str, error_msg: str, workspace: LinearWorkspace | None
    ):
        """Send error comment to Linear issue."""
        if not workspace:
            logger.error('[Linear] Cannot send error comment - no workspace available')
            return

        try:
            api_key = self.token_manager.decrypt_text(workspace.svc_acc_api_key)
            await self.send_message(
                self.create_outgoing_message(msg=error_msg), issue_id, api_key
            )
        except Exception as e:
            logger.error(f'[Linear] Failed to send error comment: {str(e)}')

    async def _send_repo_selection_comment(self, linear_view: LinearViewInterface):
        """Send a comment with repository options for the user to choose."""
        try:
            comment_msg = (
                'I need to know which repository to work with. '
                'Please add it to your issue description or send a followup comment.'
            )

            api_key = self.token_manager.decrypt_text(
                linear_view.linear_workspace.svc_acc_api_key
            )

            await self.send_message(
                self.create_outgoing_message(msg=comment_msg),
                linear_view.job_context.issue_id,
                api_key,
            )

            logger.info(
                f'[Linear] Sent repository selection comment for issue {linear_view.job_context.issue_key}'
            )

        except Exception as e:
            logger.error(
                f'[Linear] Failed to send repository selection comment: {str(e)}'
            )
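A hedged usage sketch of the GraphQL helpers above; `manager` stands for a constructed LinearManager, and the issue ID and API key are placeholders, not values from this commit:

    # Illustrative only: fetching issue context and replying on the issue.
    title, description = await manager.get_issue_details(
        'a1b2c3-issue-uuid', 'lin_api_placeholder'
    )
    await manager.send_message(
        manager.create_outgoing_message(msg="I'm on it!"),
        'a1b2c3-issue-uuid',
        'lin_api_placeholder',
    )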
40
enterprise/integrations/linear/linear_types.py
Normal file
@@ -0,0 +1,40 @@
from abc import ABC, abstractmethod

from integrations.models import JobContext
from jinja2 import Environment
from storage.linear_user import LinearUser
from storage.linear_workspace import LinearWorkspace

from openhands.server.user_auth.user_auth import UserAuth


class LinearViewInterface(ABC):
    """Interface for Linear views that handle different types of Linear interactions."""

    job_context: JobContext
    saas_user_auth: UserAuth
    linear_user: LinearUser
    linear_workspace: LinearWorkspace
    selected_repo: str | None
    conversation_id: str

    @abstractmethod
    def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
        """Get initial instructions for the conversation."""
        pass

    @abstractmethod
    async def create_or_update_conversation(self, jinja_env: Environment) -> str:
        """Create or update a conversation and return the conversation ID."""
        pass

    @abstractmethod
    def get_response_msg(self) -> str:
        """Get the response message to send back to Linear."""
        pass


class StartingConvoException(Exception):
    """Exception raised when starting a conversation fails."""

    pass
enterprise/integrations/linear/linear_view.py (new file, 224 lines)
@@ -0,0 +1,224 @@
from dataclasses import dataclass

from integrations.linear.linear_types import LinearViewInterface, StartingConvoException
from integrations.models import JobContext
from integrations.utils import CONVERSATION_URL, get_final_agent_observation
from jinja2 import Environment
from storage.linear_conversation import LinearConversation
from storage.linear_integration_store import LinearIntegrationStore
from storage.linear_user import LinearUser
from storage.linear_workspace import LinearWorkspace

from openhands.core.logger import openhands_logger as logger
from openhands.core.schema.agent import AgentState
from openhands.events.action import MessageAction
from openhands.events.serialization.event import event_to_dict
from openhands.server.services.conversation_service import (
    create_new_conversation,
    setup_init_conversation_settings,
)
from openhands.server.shared import ConversationStoreImpl, config, conversation_manager
from openhands.server.user_auth.user_auth import UserAuth
from openhands.storage.data_models.conversation_metadata import ConversationTrigger

integration_store = LinearIntegrationStore.get_instance()


@dataclass
class LinearNewConversationView(LinearViewInterface):
    job_context: JobContext
    saas_user_auth: UserAuth
    linear_user: LinearUser
    linear_workspace: LinearWorkspace
    selected_repo: str | None
    conversation_id: str

    def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
        """Instructions passed when conversation is first initialized"""

        instructions_template = jinja_env.get_template('linear_instructions.j2')
        instructions = instructions_template.render()

        user_msg_template = jinja_env.get_template('linear_new_conversation.j2')

        user_msg = user_msg_template.render(
            issue_key=self.job_context.issue_key,
            issue_title=self.job_context.issue_title,
            issue_description=self.job_context.issue_description,
            user_message=self.job_context.user_msg or '',
        )

        return instructions, user_msg

    async def create_or_update_conversation(self, jinja_env: Environment) -> str:
        """Create a new Linear conversation"""

        if not self.selected_repo:
            raise StartingConvoException('No repository selected for this conversation')

        provider_tokens = await self.saas_user_auth.get_provider_tokens()
        user_secrets = await self.saas_user_auth.get_user_secrets()
        instructions, user_msg = self._get_instructions(jinja_env)

        try:
            agent_loop_info = await create_new_conversation(
                user_id=self.linear_user.keycloak_user_id,
                git_provider_tokens=provider_tokens,
                selected_repository=self.selected_repo,
                selected_branch=None,
                initial_user_msg=user_msg,
                conversation_instructions=instructions,
                image_urls=None,
                replay_json=None,
                conversation_trigger=ConversationTrigger.LINEAR,
                custom_secrets=user_secrets.custom_secrets if user_secrets else None,
            )

            self.conversation_id = agent_loop_info.conversation_id

            logger.info(f'[Linear] Created conversation {self.conversation_id}')

            # Store Linear conversation mapping
            linear_conversation = LinearConversation(
                conversation_id=self.conversation_id,
                issue_id=self.job_context.issue_id,
                issue_key=self.job_context.issue_key,
                linear_user_id=self.linear_user.id,
            )

            await integration_store.create_conversation(linear_conversation)

            return self.conversation_id
        except Exception as e:
            logger.error(
                f'[Linear] Failed to create conversation: {str(e)}', exc_info=True
            )
            raise StartingConvoException(f'Failed to create conversation: {str(e)}')

    def get_response_msg(self) -> str:
        """Get the response message to send back to Linear"""
        conversation_link = CONVERSATION_URL.format(self.conversation_id)
        return f"I'm on it! {self.job_context.display_name} can [track my progress here]({conversation_link})."


@dataclass
class LinearExistingConversationView(LinearViewInterface):
    job_context: JobContext
    saas_user_auth: UserAuth
    linear_user: LinearUser
    linear_workspace: LinearWorkspace
    selected_repo: str | None
    conversation_id: str

    def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
        """Instructions passed when conversation is first initialized"""

        user_msg_template = jinja_env.get_template('linear_existing_conversation.j2')
        user_msg = user_msg_template.render(
            issue_key=self.job_context.issue_key,
            user_message=self.job_context.user_msg or '',
            issue_title=self.job_context.issue_title,
            issue_description=self.job_context.issue_description,
        )

        return '', user_msg

    async def create_or_update_conversation(self, jinja_env: Environment) -> str:
        """Update an existing Linear conversation"""

        user_id = self.linear_user.keycloak_user_id

        try:
            conversation_store = await ConversationStoreImpl.get_instance(
                config, user_id
            )
            metadata = await conversation_store.get_metadata(self.conversation_id)
            if not metadata:
                raise StartingConvoException('Conversation no longer exists.')

            provider_tokens = await self.saas_user_auth.get_provider_tokens()
            if provider_tokens is None:
                raise ValueError('Could not load provider tokens')
            providers_set = list(provider_tokens.keys())

            conversation_init_data = await setup_init_conversation_settings(
                user_id, self.conversation_id, providers_set
            )

            # Either join ongoing conversation, or restart the conversation
            agent_loop_info = await conversation_manager.maybe_start_agent_loop(
                self.conversation_id, conversation_init_data, user_id
            )

            final_agent_observation = get_final_agent_observation(
                agent_loop_info.event_store
            )
            agent_state = (
                None
                if len(final_agent_observation) == 0
                else final_agent_observation[0].agent_state
            )

            if not agent_state or agent_state == AgentState.LOADING:
                raise StartingConvoException('Conversation is still starting')

            _, user_msg = self._get_instructions(jinja_env)
            user_message_event = MessageAction(content=user_msg)
            await conversation_manager.send_event_to_conversation(
                self.conversation_id, event_to_dict(user_message_event)
            )

            return self.conversation_id
        except Exception as e:
            logger.error(
                f'[Linear] Failed to create conversation: {str(e)}', exc_info=True
            )
            raise StartingConvoException(f'Failed to create conversation: {str(e)}')

    def get_response_msg(self) -> str:
        """Get the response message to send back to Linear"""
        conversation_link = CONVERSATION_URL.format(self.conversation_id)
        return f"I'm on it! {self.job_context.display_name} can [continue tracking my progress here]({conversation_link})."


class LinearFactory:
    """Factory for creating Linear views based on message content"""

    @staticmethod
    async def create_linear_view_from_payload(
        job_context: JobContext,
        saas_user_auth: UserAuth,
        linear_user: LinearUser,
        linear_workspace: LinearWorkspace,
    ) -> LinearViewInterface:
        """Create appropriate Linear view based on the message and user state"""

        if not linear_user or not saas_user_auth or not linear_workspace:
            raise StartingConvoException(
                'User not authenticated with Linear integration'
            )

        conversation = await integration_store.get_user_conversations_by_issue_id(
            job_context.issue_id, linear_user.id
        )
        if conversation:
            logger.info(
                f'[Linear] Found existing conversation for issue {job_context.issue_id}'
            )
            return LinearExistingConversationView(
                job_context=job_context,
                saas_user_auth=saas_user_auth,
                linear_user=linear_user,
                linear_workspace=linear_workspace,
                selected_repo=None,
                conversation_id=conversation.conversation_id,
            )

        return LinearNewConversationView(
            job_context=job_context,
            saas_user_auth=saas_user_auth,
            linear_user=linear_user,
            linear_workspace=linear_workspace,
            selected_repo=None,  # Will be set later after repo inference
            conversation_id='',  # Will be set when conversation is created
        )
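
The factory is the intended entry point for building these views. A hedged sketch of the dispatch flow (the surrounding handler and its argument values are assumed context, not part of this file):

    from jinja2 import Environment

    async def handle_linear_event(job_context, saas_user_auth, linear_user, linear_workspace):
        jinja_env = Environment()  # in practice, loaded with the Linear .j2 templates
        view = await LinearFactory.create_linear_view_from_payload(
            job_context, saas_user_auth, linear_user, linear_workspace
        )
        # New issues get a fresh conversation; follow-ups reuse the stored mapping.
        conversation_id = await view.create_or_update_conversation(jinja_env)
        return view.get_response_msg()
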
enterprise/integrations/manager.py (new file, 30 lines)
@@ -0,0 +1,30 @@
from abc import ABC, abstractmethod

from integrations.models import Message, SourceType


class Manager(ABC):
    manager_type: SourceType

    @abstractmethod
    async def receive_message(self, message: Message):
        "Receive a message from the integration"
        raise NotImplementedError

    @abstractmethod
    def send_message(self, message: Message):
        "Send a message to the integration from the OpenHands server"
        raise NotImplementedError

    @abstractmethod
    async def is_job_requested(self, message: Message) -> bool:
        "Confirm that a job is being requested"
        raise NotImplementedError

    @abstractmethod
    def start_job(self):
        "Kick off a job with the OpenHands agent"
        raise NotImplementedError

    def create_outgoing_message(self, msg: str | dict, ephemeral: bool = False):
        return Message(source=SourceType.OPENHANDS, message=msg, ephemeral=ephemeral)
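
Every new integration implements this contract. A minimal illustrative subclass (hypothetical, for exposition only; not part of the diff):

    class EchoManager(Manager):
        manager_type = SourceType.OPENHANDS

        async def receive_message(self, message: Message):
            # Route inbound messages into a job when one is requested.
            if await self.is_job_requested(message):
                self.start_job()

        def send_message(self, message: Message):
            print(f'echo: {message.message}')

        async def is_job_requested(self, message: Message) -> bool:
            return isinstance(message.message, str) and 'start' in message.message

        def start_job(self):
            print('job started')
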
enterprise/integrations/models.py (new file, 52 lines)
@@ -0,0 +1,52 @@
from enum import Enum

from pydantic import BaseModel

from openhands.core.schema import AgentState


class SourceType(str, Enum):
    GITHUB = 'github'
    GITLAB = 'gitlab'
    OPENHANDS = 'openhands'
    SLACK = 'slack'
    JIRA = 'jira'
    JIRA_DC = 'jira_dc'
    LINEAR = 'linear'


class Message(BaseModel):
    source: SourceType
    message: str | dict
    ephemeral: bool = False


class JobContext(BaseModel):
    issue_id: str
    issue_key: str
    user_msg: str
    user_email: str
    display_name: str
    platform_user_id: str = ''
    workspace_name: str
    base_api_url: str = ''
    issue_title: str = ''
    issue_description: str = ''


class JobResult:
    result: str
    explanation: str


class GithubResolverJob:
    type: SourceType
    status: AgentState
    result: JobResult
    owner: str
    repo: str
    installation_token: str
    issue_number: int
    runtime_id: int
    created_at: int
    completed_at: int
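
A quick illustration of how these models are typically populated (all values invented for the example; only fields without defaults are required):

    message = Message(
        source=SourceType.LINEAR,
        message={'issue_id': 'abc-123', 'user_msg': 'Fix the login bug'},
    )

    job_context = JobContext(
        issue_id='abc-123',
        issue_key='ENG-42',
        user_msg='Fix the login bug',
        user_email='dev@example.com',
        display_name='Dev',
        workspace_name='acme',
        issue_title='Login fails on Safari',
    )
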
enterprise/integrations/slack/slack_manager.py (new file, 363 lines)
@@ -0,0 +1,363 @@
import re

import jwt
from integrations.manager import Manager
from integrations.models import Message, SourceType
from integrations.slack.slack_types import SlackViewInterface, StartingConvoException
from integrations.slack.slack_view import (
    SlackFactory,
    SlackNewConversationFromRepoFormView,
    SlackNewConversationView,
    SlackUnkownUserView,
    SlackUpdateExistingConversationView,
)
from integrations.utils import (
    HOST_URL,
    OPENHANDS_RESOLVER_TEMPLATES_DIR,
)
from jinja2 import Environment, FileSystemLoader
from pydantic import SecretStr
from server.auth.saas_user_auth import SaasUserAuth
from server.constants import SLACK_CLIENT_ID
from server.utils.conversation_callback_utils import register_callback_processor
from slack_sdk.oauth import AuthorizeUrlGenerator
from slack_sdk.web.async_client import AsyncWebClient
from storage.database import session_maker
from storage.slack_user import SlackUser

from openhands.core.logger import openhands_logger as logger
from openhands.integrations.provider import ProviderHandler
from openhands.integrations.service_types import Repository
from openhands.server.shared import config, server_config
from openhands.server.types import LLMAuthenticationError, MissingSettingsError
from openhands.server.user_auth.user_auth import UserAuth

authorize_url_generator = AuthorizeUrlGenerator(
    client_id=SLACK_CLIENT_ID,
    scopes=['app_mentions:read', 'chat:write'],
    user_scopes=['search:read'],
)


class SlackManager(Manager):
    def __init__(self, token_manager):
        self.token_manager = token_manager
        self.login_link = (
            'User has not yet authenticated: [Click here to Login to OpenHands]({}).'
        )

        self.jinja_env = Environment(
            loader=FileSystemLoader(OPENHANDS_RESOLVER_TEMPLATES_DIR + 'slack')
        )

    def _confirm_incoming_source_type(self, message: Message):
        if message.source != SourceType.SLACK:
            raise ValueError(f'Unexpected message source {message.source}')

    async def _get_user_auth(self, keycloak_user_id: str) -> UserAuth:
        offline_token = await self.token_manager.load_offline_token(keycloak_user_id)
        if offline_token is None:
            logger.info('no_offline_token_found')

        user_auth = SaasUserAuth(
            user_id=keycloak_user_id,
            refresh_token=SecretStr(offline_token),
        )
        return user_auth

    async def authenticate_user(
        self, slack_user_id: str
    ) -> tuple[SlackUser | None, UserAuth | None]:
        # We get the user and correlate them back to a user in OpenHands - if we can
        slack_user = None
        with session_maker() as session:
            slack_user = (
                session.query(SlackUser)
                .filter(SlackUser.slack_user_id == slack_user_id)
                .first()
            )

        # slack_view.slack_to_openhands_user = slack_user # attach user auth info to view

        saas_user_auth = None
        if slack_user:
            saas_user_auth = await self._get_user_auth(slack_user.keycloak_user_id)
            # slack_view.saas_user_auth = await self._get_user_auth(slack_view.slack_to_openhands_user.keycloak_user_id)

        return slack_user, saas_user_auth

    def _infer_repo_from_message(self, user_msg: str) -> str | None:
        # Regular expression to match patterns like "All-Hands-AI/OpenHands" or "deploy repo"
        pattern = r'([a-zA-Z0-9_-]+/[a-zA-Z0-9_-]+)|([a-zA-Z0-9_-]+)(?=\s+repo)'
        match = re.search(pattern, user_msg)

        if match:
            repo = match.group(1) if match.group(1) else match.group(2)
            return repo

        return None

    async def _get_repositories(self, user_auth: UserAuth) -> list[Repository]:
        provider_tokens = await user_auth.get_provider_tokens()
        if provider_tokens is None:
            return []
        access_token = await user_auth.get_access_token()
        user_id = await user_auth.get_user_id()
        client = ProviderHandler(
            provider_tokens=provider_tokens,
            external_auth_token=access_token,
            external_auth_id=user_id,
        )
        repos: list[Repository] = await client.get_repositories(
            'pushed', server_config.app_mode, None, None, None, None
        )
        return repos

    def _generate_repo_selection_form(
        self, repo_list: list[Repository], message_ts: str, thread_ts: str | None
    ):
        options = [
            {
                'text': {'type': 'plain_text', 'text': 'No Repository'},
                'value': '-',
            }
        ]
        options.extend(
            {
                'text': {
                    'type': 'plain_text',
                    'text': repo.full_name,
                },
                'value': repo.full_name,
            }
            for repo in repo_list
        )

        return [
            {
                'type': 'header',
                'text': {
                    'type': 'plain_text',
                    'text': 'Choose a repository',
                    'emoji': True,
                },
            },
            {
                'type': 'actions',
                'elements': [
                    {
                        'type': 'static_select',
                        'action_id': f'repository_select:{message_ts}:{thread_ts}',
                        'options': options,
                    }
                ],
            },
        ]

    def filter_potential_repos_by_user_msg(
        self, user_msg: str, user_repos: list[Repository]
    ) -> tuple[bool, list[Repository]]:
        inferred_repo = self._infer_repo_from_message(user_msg)
        if not inferred_repo:
            return False, user_repos[0:99]

        final_repos = []
        for repo in user_repos:
            if inferred_repo.lower() in repo.full_name.lower():
                final_repos.append(repo)

        # No repos matched; return the original list
        if len(final_repos) == 0:
            return False, user_repos[0:99]

        # Found an exact match
        elif len(final_repos) == 1:
            return True, final_repos

        # Found partial matches
        return False, final_repos[0:99]

    async def receive_message(self, message: Message):
        self._confirm_incoming_source_type(message)

        slack_user, saas_user_auth = await self.authenticate_user(
            slack_user_id=message.message['slack_user_id']
        )

        try:
            slack_view = SlackFactory.create_slack_view_from_payload(
                message, slack_user, saas_user_auth
            )
        except Exception as e:
            logger.error(
                f'[Slack]: Failed to create slack view: {e}',
                exc_info=True,
                stack_info=True,
            )
            return

        if isinstance(slack_view, SlackUnkownUserView):
            jwt_secret = config.jwt_secret
            if not jwt_secret:
                raise ValueError('Must configure jwt_secret')
            state = jwt.encode(
                message.message, jwt_secret.get_secret_value(), algorithm='HS256'
            )
            link = authorize_url_generator.generate(state)
            msg = self.login_link.format(link)

            logger.info('slack_not_yet_authenticated')
            await self.send_message(
                self.create_outgoing_message(msg, ephemeral=True), slack_view
            )
            return

        if not await self.is_job_requested(message, slack_view):
            return

        await self.start_job(slack_view)

    async def send_message(self, message: Message, slack_view: SlackViewInterface):
        client = AsyncWebClient(token=slack_view.bot_access_token)
        if message.ephemeral and isinstance(message.message, str):
            await client.chat_postEphemeral(
                channel=slack_view.channel_id,
                markdown_text=message.message,
                user=slack_view.slack_user_id,
                thread_ts=slack_view.thread_ts,
            )
        elif message.ephemeral and isinstance(message.message, dict):
            await client.chat_postEphemeral(
                channel=slack_view.channel_id,
                user=slack_view.slack_user_id,
                thread_ts=slack_view.thread_ts,
                text=message.message['text'],
                blocks=message.message['blocks'],
            )
        else:
            await client.chat_postMessage(
                channel=slack_view.channel_id,
                markdown_text=message.message,
                thread_ts=slack_view.message_ts,
            )

    async def is_job_requested(
        self, message: Message, slack_view: SlackViewInterface
    ) -> bool:
        """
        A job is always requested; we only receive webhooks for events associated with the Slack bot.
        This method really just checks:
        1. Is the user authenticated?
        2. Do we have the necessary information to start a job (either by inferring the selected repo, or otherwise asking the user)?
        """

        # Inferring the repo from the user message is not needed; the user selected a repo from the form or is updating an existing convo
        if isinstance(slack_view, SlackUpdateExistingConversationView):
            return True
        elif isinstance(slack_view, SlackNewConversationFromRepoFormView):
            return True
        elif isinstance(slack_view, SlackNewConversationView):
            user = slack_view.slack_to_openhands_user
            user_repos: list[Repository] = await self._get_repositories(
                slack_view.saas_user_auth
            )
            match, repos = self.filter_potential_repos_by_user_msg(
                slack_view.user_msg, user_repos
            )

            # User mentioned a matching repo in their message; start the job without the repo selection form
            if match:
                slack_view.selected_repo = repos[0].full_name
                return True

            logger.info(
                'render_repository_selector',
                extra={
                    'slack_user_id': user,
                    'keycloak_user_id': user.keycloak_user_id,
                    'message_ts': slack_view.message_ts,
                    'thread_ts': slack_view.thread_ts,
                },
            )

            repo_selection_msg = {
                'text': 'Choose a Repository:',
                'blocks': self._generate_repo_selection_form(
                    repos, slack_view.message_ts, slack_view.thread_ts
                ),
            }
            await self.send_message(
                self.create_outgoing_message(repo_selection_msg, ephemeral=True),
                slack_view,
            )

            return False

        return True

    async def start_job(self, slack_view: SlackViewInterface):
        # Importing here prevents a circular import
        from server.conversation_callback_processor.slack_callback_processor import (
            SlackCallbackProcessor,
        )

        try:
            msg_info = None
            user_info: SlackUser = slack_view.slack_to_openhands_user
            try:
                logger.info(
                    f'[Slack] Starting job for user {user_info.slack_display_name} (id={user_info.slack_user_id})',
                    extra={'keycloak_user_id': user_info.keycloak_user_id},
                )
                conversation_id = await slack_view.create_or_update_conversation(
                    self.jinja_env
                )

                logger.info(
                    f'[Slack] Created conversation {conversation_id} for user {user_info.slack_display_name}'
                )

                if not isinstance(slack_view, SlackUpdateExistingConversationView):
                    # We don't re-subscribe for follow-up messages from Slack.
                    # Summaries are generated for every message anyway; we only need to do
                    # this subscription once for the event which kicked off the job.
                    processor = SlackCallbackProcessor(
                        slack_user_id=slack_view.slack_user_id,
                        channel_id=slack_view.channel_id,
                        message_ts=slack_view.message_ts,
                        thread_ts=slack_view.thread_ts,
                        team_id=slack_view.team_id,
                    )

                    # Register the callback processor
                    register_callback_processor(conversation_id, processor)

                    logger.info(
                        f'[Slack] Created callback processor for conversation {conversation_id}'
                    )

                msg_info = slack_view.get_response_msg()

            except MissingSettingsError as e:
                logger.warning(
                    f'[Slack] Missing settings error for user {user_info.slack_display_name}: {str(e)}'
                )

                msg_info = f'{user_info.slack_display_name} please re-login into [OpenHands Cloud]({HOST_URL}) before starting a job.'

            except LLMAuthenticationError as e:
                logger.warning(
                    f'[Slack] LLM authentication error for user {user_info.slack_display_name}: {str(e)}'
                )

                msg_info = f'@{user_info.slack_display_name} please set a valid LLM API key in [OpenHands Cloud]({HOST_URL}) before starting a job.'

            except StartingConvoException as e:
                msg_info = str(e)

            await self.send_message(self.create_outgoing_message(msg_info), slack_view)

        except Exception:
            logger.exception('[Slack]: Error starting job')
            msg = 'Uh oh! There was an unexpected error starting the job :('
            await self.send_message(self.create_outgoing_message(msg), slack_view)
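
The repo-inference regex is worth a worked example. Running the pattern from _infer_repo_from_message as written (illustrative strings only):

    import re

    pattern = r'([a-zA-Z0-9_-]+/[a-zA-Z0-9_-]+)|([a-zA-Z0-9_-]+)(?=\s+repo)'

    # "owner/name" mentions match group 1:
    m = re.search(pattern, 'take a look at All-Hands-AI/OpenHands please')
    assert m.group(1) == 'All-Hands-AI/OpenHands'

    # "<name> repo" mentions match group 2 via the lookahead:
    m = re.search(pattern, 'check the deploy repo')
    assert m.group(2) == 'deploy'

filter_potential_repos_by_user_msg then narrows the user's repository list with the inferred name, falling back to the first 99 repos when nothing matches.
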
enterprise/integrations/slack/slack_types.py (new file, 48 lines)
@@ -0,0 +1,48 @@
from abc import ABC, abstractmethod

from integrations.types import SummaryExtractionTracker
from jinja2 import Environment
from storage.slack_user import SlackUser

from openhands.server.user_auth.user_auth import UserAuth


class SlackViewInterface(SummaryExtractionTracker, ABC):
    bot_access_token: str
    user_msg: str | None
    slack_user_id: str
    slack_to_openhands_user: SlackUser | None
    saas_user_auth: UserAuth | None
    channel_id: str
    message_ts: str
    thread_ts: str | None
    selected_repo: str | None
    should_extract: bool
    send_summary_instruction: bool
    conversation_id: str
    team_id: str

    @abstractmethod
    def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
        "Instructions passed when conversation is first initialized"
        pass

    @abstractmethod
    async def create_or_update_conversation(self, jinja_env: Environment):
        "Create a new conversation"
        pass

    @abstractmethod
    def get_callback_id(self) -> str:
        "Unique callback id for the subscription made to the EventStream for fetching the agent summary"
        pass

    @abstractmethod
    def get_response_msg(self) -> str:
        pass


class StartingConvoException(Exception):
    """
    Raised when trying to send a message to a conversation that is still starting up
    """
enterprise/integrations/slack/slack_view.py (new file, 435 lines)
@@ -0,0 +1,435 @@
from dataclasses import dataclass

from integrations.models import Message
from integrations.slack.slack_types import SlackViewInterface, StartingConvoException
from integrations.utils import CONVERSATION_URL, get_final_agent_observation
from jinja2 import Environment
from slack_sdk import WebClient
from storage.slack_conversation import SlackConversation
from storage.slack_conversation_store import SlackConversationStore
from storage.slack_team_store import SlackTeamStore
from storage.slack_user import SlackUser

from openhands.core.logger import openhands_logger as logger
from openhands.core.schema.agent import AgentState
from openhands.events.action import MessageAction
from openhands.events.serialization.event import event_to_dict
from openhands.server.services.conversation_service import (
    create_new_conversation,
    setup_init_conversation_settings,
)
from openhands.server.shared import ConversationStoreImpl, config, conversation_manager
from openhands.server.user_auth.user_auth import UserAuth
from openhands.storage.data_models.conversation_metadata import ConversationTrigger
from openhands.utils.async_utils import GENERAL_TIMEOUT, call_async_from_sync

# =================================================
# SECTION: Slack view types
# =================================================


CONTEXT_LIMIT = 21
slack_conversation_store = SlackConversationStore.get_instance()
slack_team_store = SlackTeamStore.get_instance()


@dataclass
class SlackUnkownUserView(SlackViewInterface):
    bot_access_token: str
    user_msg: str | None
    slack_user_id: str
    slack_to_openhands_user: SlackUser | None
    saas_user_auth: UserAuth | None
    channel_id: str
    message_ts: str
    thread_ts: str | None
    selected_repo: str | None
    should_extract: bool
    send_summary_instruction: bool
    conversation_id: str
    team_id: str

    def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
        raise NotImplementedError

    async def create_or_update_conversation(self, jinja_env: Environment):
        raise NotImplementedError

    def get_callback_id(self) -> str:
        raise NotImplementedError

    def get_response_msg(self) -> str:
        raise NotImplementedError


@dataclass
class SlackNewConversationView(SlackViewInterface):
    bot_access_token: str
    user_msg: str | None
    slack_user_id: str
    slack_to_openhands_user: SlackUser
    saas_user_auth: UserAuth
    channel_id: str
    message_ts: str
    thread_ts: str | None
    selected_repo: str | None
    should_extract: bool
    send_summary_instruction: bool
    conversation_id: str
    team_id: str

    def _get_initial_prompt(self, text: str, blocks: list[dict]):
        bot_id = self._get_bot_id(blocks)
        text = text.replace(f'<@{bot_id}>', '').strip()
        return text

    def _get_bot_id(self, blocks: list[dict]) -> str:
        for block in blocks:
            type_ = block['type']
            if type_ in ('rich_text', 'rich_text_section'):
                bot_id = self._get_bot_id(block['elements'])
                if bot_id:
                    return bot_id
            if type_ == 'user':
                return block['user_id']
        return ''

    def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
        "Instructions passed when conversation is first initialized"

        user_info: SlackUser = self.slack_to_openhands_user

        messages = []
        if self.thread_ts:
            client = WebClient(token=self.bot_access_token)
            result = client.conversations_replies(
                channel=self.channel_id,
                ts=self.thread_ts,
                inclusive=True,
                latest=self.message_ts,
                limit=CONTEXT_LIMIT,  # We can be smarter about getting more context/condensing it in the future
            )

            messages = result['messages']

        else:
            client = WebClient(token=self.bot_access_token)
            result = client.conversations_history(
                channel=self.channel_id,
                inclusive=True,
                latest=self.message_ts,
                limit=CONTEXT_LIMIT,
            )

            messages = result['messages']
            messages.reverse()

        if not messages:
            raise ValueError('Failed to fetch information from slack API')

        logger.info('got_messages_from_slack', extra={'messages': messages})

        trigger_msg = messages[-1]
        user_message = self._get_initial_prompt(
            trigger_msg['text'], trigger_msg['blocks']
        )

        conversation_instructions = ''

        if len(messages) > 1:
            messages.pop()
            text_messages = [m['text'] for m in messages if m.get('text')]
            conversation_instructions_template = jinja_env.get_template(
                'user_message_conversation_instructions.j2'
            )
            conversation_instructions = conversation_instructions_template.render(
                messages=text_messages,
                username=user_info.slack_display_name,
                conversation_url=CONVERSATION_URL,
            )

        return user_message, conversation_instructions

    def _verify_necessary_values_are_set(self):
        if not self.selected_repo:
            raise ValueError(
                'Attempting to start conversation without confirming selected repo from user'
            )

    async def save_slack_convo(self):
        if self.slack_to_openhands_user:
            user_info: SlackUser = self.slack_to_openhands_user

            logger.info(
                'Create slack conversation',
                extra={
                    'channel_id': self.channel_id,
                    'conversation_id': self.conversation_id,
                    'keycloak_user_id': user_info.keycloak_user_id,
                    'parent_id': self.thread_ts or self.message_ts,
                },
            )
            slack_conversation = SlackConversation(
                conversation_id=self.conversation_id,
                channel_id=self.channel_id,
                keycloak_user_id=user_info.keycloak_user_id,
                parent_id=self.thread_ts
                or self.message_ts,  # conversations can start in a thread reply as well; we should always reference the parent's (root-level msg's) message ID
            )
            await slack_conversation_store.create_slack_conversation(slack_conversation)

    async def create_or_update_conversation(self, jinja: Environment) -> str:
        """
        Only creates a new conversation
        """
        self._verify_necessary_values_are_set()

        provider_tokens = await self.saas_user_auth.get_provider_tokens()
        user_secrets = await self.saas_user_auth.get_user_secrets()
        user_instructions, conversation_instructions = self._get_instructions(jinja)

        agent_loop_info = await create_new_conversation(
            user_id=self.slack_to_openhands_user.keycloak_user_id,
            git_provider_tokens=provider_tokens,
            selected_repository=self.selected_repo,
            selected_branch=None,
            initial_user_msg=user_instructions,
            conversation_instructions=conversation_instructions
            if conversation_instructions
            else None,
            image_urls=None,
            replay_json=None,
            conversation_trigger=ConversationTrigger.SLACK,
            custom_secrets=user_secrets.custom_secrets if user_secrets else None,
        )

        self.conversation_id = agent_loop_info.conversation_id
        await self.save_slack_convo()
        return self.conversation_id

    def get_callback_id(self) -> str:
        return f'slack_{self.channel_id}_{self.message_ts}'

    def get_response_msg(self) -> str:
        user_info: SlackUser = self.slack_to_openhands_user
        conversation_link = CONVERSATION_URL.format(self.conversation_id)
        return f"I'm on it! {user_info.slack_display_name} can [track my progress here]({conversation_link})."


@dataclass
class SlackNewConversationFromRepoFormView(SlackNewConversationView):
    def _verify_necessary_values_are_set(self):
        # Exclude the selected-repo check from the parent:
        # the user can start conversations without a repo when it is specified via the repo selection form
        return


@dataclass
class SlackUpdateExistingConversationView(SlackNewConversationView):
    slack_conversation: SlackConversation

    def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
        client = WebClient(token=self.bot_access_token)
        result = client.conversations_replies(
            channel=self.channel_id,
            ts=self.message_ts,
            inclusive=True,
            latest=self.message_ts,
            limit=1,  # Get the exact user message; in the future we can be smarter about collecting additional context
        )

        user_message = result['messages'][0]
        user_message = self._get_initial_prompt(
            user_message['text'], user_message['blocks']
        )

        return user_message, ''

    async def create_or_update_conversation(self, jinja: Environment) -> str:
        """
        Send a new user message to the conversation
        """
        user_info: SlackUser = self.slack_to_openhands_user
        saas_user_auth: UserAuth = self.saas_user_auth
        user_id = user_info.keycloak_user_id

        # Org management in the future will get rid of this
        # For now, only the user that created the conversation can send follow-up messages to it
        if user_id != self.slack_conversation.keycloak_user_id:
            raise StartingConvoException(
                f'{user_info.slack_display_name} is not authorized to send messages to this conversation.'
            )

        # Check if conversation has been deleted
        # Update logic when soft delete is implemented
        conversation_store = await ConversationStoreImpl.get_instance(config, user_id)
        metadata = await conversation_store.get_metadata(self.conversation_id)
        if not metadata:
            raise StartingConvoException('Conversation no longer exists.')

        provider_tokens = await saas_user_auth.get_provider_tokens()

        # Should we raise here if there are no provider tokens?
        providers_set = list(provider_tokens.keys()) if provider_tokens else []

        conversation_init_data = await setup_init_conversation_settings(
            user_id, self.conversation_id, providers_set
        )

        # Either join the ongoing conversation, or restart the conversation
        agent_loop_info = await conversation_manager.maybe_start_agent_loop(
            self.conversation_id, conversation_init_data, user_id
        )

        final_agent_observation = get_final_agent_observation(
            agent_loop_info.event_store
        )
        agent_state = (
            None
            if len(final_agent_observation) == 0
            else final_agent_observation[0].agent_state
        )

        if not agent_state or agent_state == AgentState.LOADING:
            raise StartingConvoException('Conversation is still starting')

        user_msg, _ = self._get_instructions(jinja)
        user_msg_action = MessageAction(content=user_msg)
        await conversation_manager.send_event_to_conversation(
            self.conversation_id, event_to_dict(user_msg_action)
        )

        return self.conversation_id

    def get_response_msg(self):
        user_info: SlackUser = self.slack_to_openhands_user
        conversation_link = CONVERSATION_URL.format(self.conversation_id)
        return f"I'm on it! {user_info.slack_display_name} can [continue tracking my progress here]({conversation_link})."


class SlackFactory:
    @staticmethod
    def did_user_select_repo_from_form(message: Message):
        payload = message.message
        return 'selected_repo' in payload

    @staticmethod
    async def determine_if_updating_existing_conversation(
        message: Message,
    ) -> SlackConversation | None:
        payload = message.message
        channel_id = payload.get('channel_id')
        thread_ts = payload.get('thread_ts')

        # Follow-up conversations must be contained in-thread
        if not thread_ts:
            return None

        # thread_ts in slack payloads is the parent's (root-level msg's) message ID
        return await slack_conversation_store.get_slack_conversation(
            channel_id, thread_ts
        )

    @staticmethod
    def create_slack_view_from_payload(
        message: Message, slack_user: SlackUser | None, saas_user_auth: UserAuth | None
    ):
        payload = message.message
        slack_user_id = payload['slack_user_id']
        channel_id = payload.get('channel_id')
        message_ts = payload.get('message_ts')
        thread_ts = payload.get('thread_ts')
        team_id = payload['team_id']
        user_msg = payload.get('user_msg')

        bot_access_token = slack_team_store.get_team_bot_token(team_id)
        if not bot_access_token:
            logger.error(
                'Did not find slack team',
                extra={
                    'slack_user_id': slack_user_id,
                    'channel_id': channel_id,
                },
            )
            raise Exception('Did not find slack team')

        # Determine if this is a slack user known to openhands
        if not slack_user or not saas_user_auth or not channel_id:
            return SlackUnkownUserView(
                bot_access_token=bot_access_token,
                user_msg=user_msg,
                slack_user_id=slack_user_id,
                slack_to_openhands_user=slack_user,
                saas_user_auth=saas_user_auth,
                channel_id=channel_id,
                message_ts=message_ts,
                thread_ts=thread_ts,
                selected_repo=None,
                should_extract=False,
                send_summary_instruction=False,
                conversation_id='',
                team_id=team_id,
            )

        conversation: SlackConversation | None = call_async_from_sync(
            SlackFactory.determine_if_updating_existing_conversation,
            GENERAL_TIMEOUT,
            message,
        )
        if conversation:
            logger.info(
                'Found existing slack conversation',
                extra={
                    'conversation_id': conversation.conversation_id,
                    'parent_id': conversation.parent_id,
                },
            )
            return SlackUpdateExistingConversationView(
                bot_access_token=bot_access_token,
                user_msg=user_msg,
                slack_user_id=slack_user_id,
                slack_to_openhands_user=slack_user,
                saas_user_auth=saas_user_auth,
                channel_id=channel_id,
                message_ts=message_ts,
                thread_ts=thread_ts,
                selected_repo=None,
                should_extract=True,
                send_summary_instruction=True,
                conversation_id=conversation.conversation_id,
                slack_conversation=conversation,
                team_id=team_id,
            )

        elif SlackFactory.did_user_select_repo_from_form(message):
            return SlackNewConversationFromRepoFormView(
                bot_access_token=bot_access_token,
                user_msg=user_msg,
                slack_user_id=slack_user_id,
                slack_to_openhands_user=slack_user,
                saas_user_auth=saas_user_auth,
                channel_id=channel_id,
                message_ts=message_ts,
                thread_ts=thread_ts,
                selected_repo=payload['selected_repo'],
                should_extract=True,
                send_summary_instruction=True,
                conversation_id='',
                team_id=team_id,
            )

        else:
            return SlackNewConversationView(
                bot_access_token=bot_access_token,
                user_msg=user_msg,
                slack_user_id=slack_user_id,
                slack_to_openhands_user=slack_user,
                saas_user_auth=saas_user_auth,
                channel_id=channel_id,
                message_ts=message_ts,
                thread_ts=thread_ts,
                selected_repo=None,
                should_extract=True,
                send_summary_instruction=True,
                conversation_id='',
                team_id=team_id,
            )
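
The mention-stripping helpers are easiest to follow with a concrete payload. A sketch of what _get_bot_id and _get_initial_prompt do to a typical app_mention event (the blocks below are abridged and invented for illustration):

    blocks = [
        {
            'type': 'rich_text',
            'elements': [
                {
                    'type': 'rich_text_section',
                    'elements': [
                        {'type': 'user', 'user_id': 'U0BOT'},
                        {'type': 'text', 'text': ' fix the flaky test'},
                    ],
                }
            ],
        }
    ]

    # _get_bot_id recurses through the nested rich_text blocks and returns 'U0BOT';
    # _get_initial_prompt then strips '<@U0BOT>' from the raw text, leaving
    # 'fix the flaky test' as the initial prompt.
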
enterprise/integrations/solvability/__init__.py (new file, empty)
enterprise/integrations/solvability/data/__init__.py (new file, 41 lines)
@@ -0,0 +1,41 @@
"""
Utilities for loading and managing pre-trained classifiers.

Assumes that classifiers are stored adjacent to this file in the `solvability/data` directory, using a simple
`name + .json` pattern.
"""

from pathlib import Path

from integrations.solvability.models.classifier import SolvabilityClassifier


def load_classifier(name: str) -> SolvabilityClassifier:
    """
    Load a classifier by name.

    Args:
        name (str): The name of the classifier to load.

    Returns:
        SolvabilityClassifier: The loaded classifier instance.
    """
    data_dir = Path(__file__).parent
    classifier_path = data_dir / f'{name}.json'

    if not classifier_path.exists():
        raise FileNotFoundError(f"Classifier '{name}' not found at {classifier_path}")

    with classifier_path.open('r') as f:
        return SolvabilityClassifier.model_validate_json(f.read())


def available_classifiers() -> list[str]:
    """
    List all available classifiers in the data directory.

    Returns:
        list[str]: A list of classifier names (without the .json extension).
    """
    data_dir = Path(__file__).parent
    return [f.stem for f in data_dir.glob('*.json') if f.is_file()]
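
Loading then looks like this (a sketch; 'default' is a placeholder name, and whichever JSON files actually ship in the data directory determine the valid names):

    from integrations.solvability.data import available_classifiers, load_classifier

    names = available_classifiers()  # e.g. ['default'] if default.json is present
    clf = load_classifier(names[0])  # parsed and validated as a SolvabilityClassifier
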
File diff suppressed because one or more lines are too long
enterprise/integrations/solvability/models/__init__.py (new file, 38 lines)
@@ -0,0 +1,38 @@
"""
Solvability Models Package

This package contains the core machine learning models and components for predicting
the solvability of GitHub issues and similar technical problems.

The solvability prediction system works by:
1. Using a Featurizer to extract semantic features from issue descriptions via LLM calls
2. Training a RandomForestClassifier on these features to predict solvability
3. Generating detailed reports with feature importance analysis

Key Components:
- Feature: Defines individual features that can be extracted from issues
- Featurizer: Orchestrates LLM-based feature extraction with sampling and batching
- SolvabilityClassifier: Main ML pipeline combining featurization and classification
- SolvabilityReport: Comprehensive output with predictions, feature analysis, and metadata
- ImportanceStrategy: Configurable methods for calculating feature importance (SHAP, permutation, impurity)
"""

from integrations.solvability.models.classifier import SolvabilityClassifier
from integrations.solvability.models.featurizer import (
    EmbeddingDimension,
    Feature,
    FeatureEmbedding,
    Featurizer,
)
from integrations.solvability.models.importance_strategy import ImportanceStrategy
from integrations.solvability.models.report import SolvabilityReport

__all__ = [
    'Feature',
    'EmbeddingDimension',
    'FeatureEmbedding',
    'Featurizer',
    'ImportanceStrategy',
    'SolvabilityClassifier',
    'SolvabilityReport',
]
enterprise/integrations/solvability/models/classifier.py (new file, 433 lines; listing truncated below)
@@ -0,0 +1,433 @@
from __future__ import annotations

import base64
import pickle
from typing import Any

import numpy as np
import pandas as pd
import shap
from integrations.solvability.models.featurizer import Feature, Featurizer
from integrations.solvability.models.importance_strategy import ImportanceStrategy
from integrations.solvability.models.report import SolvabilityReport
from pydantic import (
    BaseModel,
    PrivateAttr,
    field_serializer,
    field_validator,
    model_validator,
)
from sklearn.ensemble import RandomForestClassifier
from sklearn.exceptions import NotFittedError
from sklearn.inspection import permutation_importance
from sklearn.utils.validation import check_is_fitted

from openhands.core.config import LLMConfig


class SolvabilityClassifier(BaseModel):
    """
    Machine learning pipeline for predicting the solvability of GitHub issues and similar problems.

    This classifier combines LLM-based feature extraction with traditional ML classification:
    1. Uses a Featurizer to extract semantic boolean features from issue descriptions via LLM calls
    2. Trains a RandomForestClassifier on these features to predict solvability scores
    3. Provides feature importance analysis using configurable strategies (SHAP, permutation, impurity)
    4. Generates comprehensive reports with predictions, feature analysis, and cost metrics

    The classifier supports both training on labeled data and inference on new issues, with built-in
    support for batch processing and concurrent feature extraction.
    """

    identifier: str
    """
    The identifier for the classifier.
    """

    featurizer: Featurizer
    """
    The featurizer to use for transforming the input data.
    """

    classifier: RandomForestClassifier
    """
    The RandomForestClassifier used for predicting solvability from extracted features.

    This ensemble model provides robust predictions and built-in feature importance metrics.
    """

    importance_strategy: ImportanceStrategy = ImportanceStrategy.IMPURITY
    """
    Strategy to use for calculating feature importance.
    """

    samples: int = 10
    """
    Number of samples to use for calculating feature embedding coefficients.
    """

    random_state: int | None = None
    """
    Random state for reproducibility.
    """

    _classifier_attrs: dict[str, Any] = PrivateAttr(default_factory=dict)
    """
    Private dictionary storing cached results from feature extraction and importance calculations.

    Contains keys like 'features_', 'cost_', 'feature_importances_', and 'labels_' that are populated
    during transform(), fit(), and predict() operations. Access these via the corresponding properties.

    This field is never serialized, so cached values will not persist across model save/load cycles.
    """

    model_config = {
        'arbitrary_types_allowed': True,
    }

    @model_validator(mode='after')
    def validate_random_state(self) -> SolvabilityClassifier:
        """
        Validate the random state configuration between this object and the classifier.
        """
        # If both random states are set, they definitely need to agree.
        if self.random_state is not None and self.classifier.random_state is not None:
            if self.random_state != self.classifier.random_state:
                raise ValueError(
                    'The random state of the classifier and the top-level classifier must agree.'
                )

        # Otherwise, we'll always set the classifier's random state to the top-level one.
        self.classifier.random_state = self.random_state

        return self

    @property
    def features_(self) -> pd.DataFrame:
        """
        Get the features used by the classifier for the most recent inputs.
        """
        if 'features_' not in self._classifier_attrs:
            raise ValueError(
                'SolvabilityClassifier.transform() has not yet been called.'
            )
        return self._classifier_attrs['features_']

    @property
    def cost_(self) -> pd.DataFrame:
        """
        Get the cost of the classifier for the most recent inputs.
        """
        if 'cost_' not in self._classifier_attrs:
            raise ValueError(
                'SolvabilityClassifier.transform() has not yet been called.'
            )
        return self._classifier_attrs['cost_']

    @property
    def feature_importances_(self) -> np.ndarray:
        """
        Get the feature importances for the most recent inputs.
        """
        if 'feature_importances_' not in self._classifier_attrs:
            raise ValueError(
                'No SolvabilityClassifier methods that produce feature importances (.fit(), .predict_proba(), and '
                '.predict()) have been called.'
            )
        return self._classifier_attrs['feature_importances_']  # type: ignore[no-any-return]

    @property
    def is_fitted(self) -> bool:
        """
        Check if the classifier is fitted.
        """
        try:
            check_is_fitted(self.classifier)
            return True
        except NotFittedError:
            return False

    def transform(self, issues: pd.Series, llm_config: LLMConfig) -> pd.DataFrame:
        """
        Transform the input issues using the featurizer to extract features.

        This method orchestrates the feature extraction pipeline:
        1. Uses the featurizer to generate embeddings for all issues
        2. Converts embeddings to a structured DataFrame
        3. Separates feature columns from metadata columns
        4. Stores results for later access via properties

        Args:
            issues: A pandas Series containing the issue descriptions.
            llm_config: LLM configuration to use for feature extraction.

        Returns:
            pd.DataFrame: A DataFrame containing only the feature columns (no metadata).
        """
        # Generate feature embeddings for all issues using batch processing
        feature_embeddings = self.featurizer.embed_batch(
            issues, samples=self.samples, llm_config=llm_config
        )
        df = pd.DataFrame(embedding.to_row() for embedding in feature_embeddings)

        # Split into feature columns (used by classifier) and cost columns (metadata)
        feature_columns = [feature.identifier for feature in self.featurizer.features]
        cost_columns = [col for col in df.columns if col not in feature_columns]

        # Store both sets for access via properties
        self._classifier_attrs['features_'] = df[feature_columns]
        self._classifier_attrs['cost_'] = df[cost_columns]

        return self.features_

    def fit(
        self, issues: pd.Series, labels: pd.Series, llm_config: LLMConfig
    ) -> SolvabilityClassifier:
        """
        Fit the classifier to the input issues and labels.

        Args:
            issues: A pandas Series containing the issue descriptions.
            labels: A pandas Series containing the labels (0 or 1) for each issue.
            llm_config: LLM configuration to use for feature extraction.

        Returns:
            SolvabilityClassifier: The fitted classifier.
        """
        features = self.transform(issues, llm_config=llm_config)
        self.classifier.fit(features, labels)

        # Store labels for permutation importance calculation
        self._classifier_attrs['labels_'] = labels
        self._classifier_attrs['feature_importances_'] = self._importance(
            features, self.classifier.predict_proba(features), labels
        )

        return self

    def predict_proba(self, issues: pd.Series, llm_config: LLMConfig) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Predict the solvability probabilities for the input issues.
|
||||||
|
|
||||||
|
Returns class probabilities where the second column represents the probability
|
||||||
|
of the issue being solvable (positive class).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
issues: A pandas Series containing the issue descriptions.
|
||||||
|
llm_config: LLM configuration to use for feature extraction.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
np.ndarray: Array of shape (n_samples, 2) with probabilities for each class.
|
||||||
|
Column 0: probability of not solvable, Column 1: probability of solvable.
|
||||||
|
"""
|
||||||
|
features = self.transform(issues, llm_config=llm_config)
|
||||||
|
scores = self.classifier.predict_proba(features)
|
||||||
|
|
||||||
|
# Calculate feature importances based on the configured strategy
|
||||||
|
# For permutation importance, we need ground truth labels if available
|
||||||
|
labels = self._classifier_attrs.get('labels_')
|
||||||
|
if (
|
||||||
|
self.importance_strategy == ImportanceStrategy.PERMUTATION
|
||||||
|
and labels is not None
|
||||||
|
):
|
||||||
|
self._classifier_attrs['feature_importances_'] = self._importance(
|
||||||
|
features, scores, labels
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self._classifier_attrs['feature_importances_'] = self._importance(
|
||||||
|
features, scores
|
||||||
|
)
|
||||||
|
|
||||||
|
return scores # type: ignore[no-any-return]
|
||||||
|
|
||||||
|
def predict(self, issues: pd.Series, llm_config: LLMConfig) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Predict the solvability of the input issues by returning binary labels.
|
||||||
|
|
||||||
|
Uses a 0.5 probability threshold to convert probabilities to binary predictions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
issues: A pandas Series containing the issue descriptions.
|
||||||
|
llm_config: LLM configuration to use for feature extraction.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
np.ndarray: Boolean array where True indicates the issue is predicted as solvable.
|
||||||
|
"""
|
||||||
|
probabilities = self.predict_proba(issues, llm_config=llm_config)
|
||||||
|
# Apply 0.5 threshold to convert probabilities to binary predictions
|
||||||
|
labels = probabilities[:, 1] >= 0.5
|
||||||
|
return labels
|
||||||
|
|
||||||
|
def _importance(
|
||||||
|
self,
|
||||||
|
features: pd.DataFrame,
|
||||||
|
scores: np.ndarray,
|
||||||
|
labels: np.ndarray | None = None,
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Calculate feature importance scores using the configured strategy.
|
||||||
|
|
||||||
|
Different strategies provide different interpretations:
|
||||||
|
- SHAP: Shapley values indicating contribution to individual predictions
|
||||||
|
- PERMUTATION: Decrease in model performance when feature is shuffled
|
||||||
|
- IMPURITY: Gini impurity decrease from splits on each feature
|
||||||
|
|
||||||
|
Args:
|
||||||
|
features: Feature matrix used for predictions.
|
||||||
|
scores: Model prediction scores (unused for some strategies).
|
||||||
|
labels: Ground truth labels (required for permutation importance).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
np.ndarray: Feature importance scores, one per feature.
|
||||||
|
"""
|
||||||
|
match self.importance_strategy:
|
||||||
|
case ImportanceStrategy.SHAP:
|
||||||
|
# Use SHAP TreeExplainer for tree-based models
|
||||||
|
explainer = shap.TreeExplainer(self.classifier)
|
||||||
|
shap_values = explainer.shap_values(features)
|
||||||
|
# Return mean SHAP values for the positive class (solvable)
|
||||||
|
return shap_values.mean(axis=0)[:, 1] # type: ignore[no-any-return]
|
||||||
|
|
||||||
|
case ImportanceStrategy.PERMUTATION:
|
||||||
|
# Permutation importance requires ground truth labels
|
||||||
|
if labels is None:
|
||||||
|
raise ValueError('Labels are required for permutation importance')
|
||||||
|
result = permutation_importance(
|
||||||
|
self.classifier,
|
||||||
|
features,
|
||||||
|
labels,
|
||||||
|
n_repeats=10, # Number of permutation rounds for stability
|
||||||
|
random_state=self.random_state,
|
||||||
|
)
|
||||||
|
return result.importances_mean # type: ignore[no-any-return]
|
||||||
|
|
||||||
|
case ImportanceStrategy.IMPURITY:
|
||||||
|
# Use built-in feature importances from RandomForest
|
||||||
|
return self.classifier.feature_importances_ # type: ignore[no-any-return]
|
||||||
|
|
||||||
|
case _:
|
||||||
|
raise ValueError(
|
||||||
|
f'Unknown importance strategy: {self.importance_strategy}'
|
||||||
|
)
|
||||||
|
|
||||||
|
def add_features(self, features: list[Feature]) -> SolvabilityClassifier:
|
||||||
|
"""
|
||||||
|
Add new features to the classifier's featurizer.
|
||||||
|
|
||||||
|
Note: Adding features after training requires retraining the classifier
|
||||||
|
since the feature space will have changed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
features: List of Feature objects to add.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
SolvabilityClassifier: Self for method chaining.
|
||||||
|
"""
|
||||||
|
for feature in features:
|
||||||
|
if feature not in self.featurizer.features:
|
||||||
|
self.featurizer.features.append(feature)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def forget_features(self, features: list[Feature]) -> SolvabilityClassifier:
|
||||||
|
"""
|
||||||
|
Remove features from the classifier's featurizer.
|
||||||
|
|
||||||
|
Note: Removing features after training requires retraining the classifier
|
||||||
|
since the feature space will have changed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
features: List of Feature objects to remove.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
SolvabilityClassifier: Self for method chaining.
|
||||||
|
"""
|
||||||
|
for feature in features:
|
||||||
|
try:
|
||||||
|
self.featurizer.features.remove(feature)
|
||||||
|
except ValueError:
|
||||||
|
# Feature not in list, continue with others
|
||||||
|
continue
|
||||||
|
return self
|
||||||
|
|
||||||
|
@field_serializer('classifier')
|
||||||
|
@staticmethod
|
||||||
|
def _rfc_to_json(rfc: RandomForestClassifier) -> str:
|
||||||
|
"""
|
||||||
|
Convert a RandomForestClassifier to a JSON-compatible value (a string).
|
||||||
|
"""
|
||||||
|
return base64.b64encode(pickle.dumps(rfc)).decode('utf-8')
|
||||||
|
|
||||||
|
@field_validator('classifier', mode='before')
|
||||||
|
@staticmethod
|
||||||
|
def _json_to_rfc(value: str | RandomForestClassifier) -> RandomForestClassifier:
|
||||||
|
"""
|
||||||
|
Convert a JSON-compatible value (a string) back to a RandomForestClassifier.
|
||||||
|
"""
|
||||||
|
if isinstance(value, RandomForestClassifier):
|
||||||
|
return value
|
||||||
|
|
||||||
|
if isinstance(value, str):
|
||||||
|
try:
|
||||||
|
model = pickle.loads(base64.b64decode(value))
|
||||||
|
if isinstance(model, RandomForestClassifier):
|
||||||
|
return model
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f'Failed to decode the classifier: {e}')
|
||||||
|
|
||||||
|
raise ValueError(
|
||||||
|
'The classifier must be a RandomForestClassifier or a JSON-compatible dictionary.'
|
||||||
|
)
|
||||||
|
|
||||||
|
def solvability_report(
|
||||||
|
self, issue: str, llm_config: LLMConfig, **kwargs: Any
|
||||||
|
) -> SolvabilityReport:
|
||||||
|
"""
|
||||||
|
Generate a solvability report for the given issue.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
issue: The issue description for which to generate the report.
|
||||||
|
llm_config: Optional LLM configuration to use for feature extraction.
|
||||||
|
kwargs: Additional metadata to include in the report.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
SolvabilityReport: The generated solvability report.
|
||||||
|
"""
|
||||||
|
if not self.is_fitted:
|
||||||
|
raise ValueError(
|
||||||
|
'The classifier must be fitted before generating a report.'
|
||||||
|
)
|
||||||
|
|
||||||
|
scores = self.predict_proba(pd.Series([issue]), llm_config=llm_config)
|
||||||
|
|
||||||
|
return SolvabilityReport(
|
||||||
|
identifier=self.identifier,
|
||||||
|
issue=issue,
|
||||||
|
score=scores[0, 1],
|
||||||
|
features=self.features_.iloc[0].to_dict(),
|
||||||
|
samples=self.samples,
|
||||||
|
importance_strategy=self.importance_strategy,
|
||||||
|
# Unlike the features, the importances are just a series with no link
|
||||||
|
# to the actual feature names. For that we have to recombine with the
|
||||||
|
# feature identifiers.
|
||||||
|
feature_importances=dict(
|
||||||
|
zip(
|
||||||
|
self.featurizer.feature_identifiers(),
|
||||||
|
self.feature_importances_.tolist(),
|
||||||
|
)
|
||||||
|
),
|
||||||
|
random_state=self.random_state,
|
||||||
|
metadata=dict(kwargs) if kwargs else None,
|
||||||
|
# Both cost and response_latency are columns in the cost_ DataFrame,
|
||||||
|
# so we can get both by just unpacking the first row.
|
||||||
|
**self.cost_.iloc[0].to_dict(),
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self, issue: str, llm_config: LLMConfig, **kwargs: Any
|
||||||
|
) -> SolvabilityReport:
|
||||||
|
"""
|
||||||
|
Generate a solvability report for the given issue.
|
||||||
|
"""
|
||||||
|
return self.solvability_report(issue, llm_config=llm_config, **kwargs)
|
||||||
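A minimal usage sketch of the classifier above, assuming a configured Featurizer and LLMConfig already exist; the identifier, issues, and labels are illustrative and not part of the change:

import pandas as pd
from sklearn.ensemble import RandomForestClassifier

# `featurizer` and `llm_config` are assumed to come from elsewhere.
clf = SolvabilityClassifier(
    identifier='solvability-demo',  # hypothetical identifier
    featurizer=featurizer,
    classifier=RandomForestClassifier(),
    random_state=42,
)

issues = pd.Series(['Login redirects in a loop after SSO', 'App feels slow sometimes'])
labels = pd.Series([1, 0])

clf.fit(issues, labels, llm_config=llm_config)
report = clf('Crash when uploading files over 2 GB', llm_config=llm_config)
print(report.score, report.feature_importances)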
enterprise/integrations/solvability/models/difficulty_level.py
@@ -0,0 +1,38 @@
from __future__ import annotations

from enum import Enum


class DifficultyLevel(Enum):
    """Enum representing the difficulty level based on solvability score."""

    EASY = ('EASY', 0.7, '🟢')
    MEDIUM = ('MEDIUM', 0.4, '🟡')
    HARD = ('HARD', 0.0, '🔴')

    def __init__(self, label: str, threshold: float, emoji: str):
        self.label = label
        self.threshold = threshold
        self.emoji = emoji

    @classmethod
    def from_score(cls, score: float) -> DifficultyLevel:
        """Get difficulty level from a solvability score.

        Returns the difficulty level with the highest threshold that is less than or equal to the given score.
        """
        # Sort enum values by threshold in descending order
        sorted_levels = sorted(cls, key=lambda x: x.threshold, reverse=True)

        # Find the first level where score meets the threshold
        for level in sorted_levels:
            if score >= level.threshold:
                return level

        # This should never happen if thresholds are set correctly,
        # but return the lowest-threshold level as a fallback
        return sorted_levels[-1]

    def format_display(self) -> str:
        """Format the difficulty level for display."""
        return f'{self.emoji} **Solvability: {self.label}**'
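Concretely, from_score picks the highest threshold at or below the score:

assert DifficultyLevel.from_score(0.85) is DifficultyLevel.EASY    # 0.85 >= 0.7
assert DifficultyLevel.from_score(0.55) is DifficultyLevel.MEDIUM  # 0.55 >= 0.4
assert DifficultyLevel.from_score(0.10) is DifficultyLevel.HARD    # 0.10 >= 0.0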
368
enterprise/integrations/solvability/models/featurizer.py
Normal file
@@ -0,0 +1,368 @@
import json
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any

from pydantic import BaseModel

from openhands.core.config import LLMConfig
from openhands.llm.llm import LLM


class Feature(BaseModel):
    """
    Represents a single boolean feature that can be extracted from issue descriptions.

    Features are semantic properties of issues (e.g., "has_code_example", "requires_debugging")
    that are evaluated by LLMs and used as input to the solvability classifier.
    """

    identifier: str
    """Unique identifier for the feature, used as column name in feature matrices."""

    description: str
    """Human-readable description of what the feature represents, used in LLM prompts."""

    @property
    def to_tool_description_field(self) -> dict[str, Any]:
        """
        Convert this feature to a JSON schema field for LLM tool calling.

        Returns:
            dict: JSON schema field definition for this feature.
        """
        return {
            'type': 'boolean',
            'description': self.description,
        }


class EmbeddingDimension(BaseModel):
    """
    Represents a single dimension (feature evaluation) within a feature embedding sample.

    Each dimension corresponds to one feature being evaluated as true/false for a given issue.
    """

    feature_id: str
    """Identifier of the feature being evaluated."""

    result: bool
    """Boolean result of the feature evaluation for this sample."""


# Type alias for a single embedding sample - maps feature identifiers to boolean values
EmbeddingSample = dict[str, bool]
"""
A single sample from the LLM evaluation of features for an issue.
Maps feature identifiers to their boolean evaluations.
"""


class FeatureEmbedding(BaseModel):
    """
    Represents the complete feature embedding for a single issue, including multiple samples
    and associated metadata about the LLM calls used to generate it.

    Multiple samples are collected to account for LLM variability and provide more robust
    feature estimates through averaging.
    """

    samples: list[EmbeddingSample]
    """List of individual feature evaluation samples from the LLM."""

    prompt_tokens: int | None = None
    """Total prompt tokens consumed across all LLM calls for this embedding."""

    completion_tokens: int | None = None
    """Total completion tokens generated across all LLM calls for this embedding."""

    response_latency: float | None = None
    """Total response latency (seconds) across all LLM calls for this embedding."""

    @property
    def dimensions(self) -> list[str]:
        """
        Get all unique feature identifiers present across all samples.

        Returns:
            list[str]: List of feature identifiers that appear in at least one sample.
        """
        dims: set[str] = set()
        for sample in self.samples:
            dims.update(sample.keys())
        return list(dims)

    def coefficient(self, dimension: str) -> float | None:
        """
        Calculate the average coefficient (0-1) for a specific feature dimension.

        This computes the proportion of samples where the feature was evaluated as True,
        providing a continuous feature value for the classifier.

        Args:
            dimension: Feature identifier to calculate the coefficient for.

        Returns:
            float | None: Average coefficient (0.0-1.0), or None if dimension not found.
        """
        # Extract boolean values for this dimension, converting to 0/1
        values = [
            1 if v else 0
            for v in [sample.get(dimension) for sample in self.samples]
            if v is not None
        ]
        if values:
            return sum(values) / len(values)
        return None

    def to_row(self) -> dict[str, Any]:
        """
        Convert the embedding to a flat dictionary suitable for DataFrame construction.

        Returns:
            dict[str, Any]: Dictionary with metadata fields and feature coefficients.
        """
        return {
            'response_latency': self.response_latency,
            'prompt_tokens': self.prompt_tokens,
            'completion_tokens': self.completion_tokens,
            **{dimension: self.coefficient(dimension) for dimension in self.dimensions},
        }

    def sample_entropy(self) -> dict[str, float]:
        """
        Calculate the Shannon entropy of feature evaluations across samples.

        Higher entropy indicates more variability in LLM responses for a feature,
        which may suggest ambiguity in the feature definition or issue description.

        Returns:
            dict[str, float]: Mapping of feature identifiers to their entropy values (0-1).
        """
        from collections import Counter
        from math import log2

        entropy = {}
        for dimension in self.dimensions:
            # Count True/False occurrences for this feature across samples
            counts = Counter(sample.get(dimension, False) for sample in self.samples)
            total = sum(counts.values())
            if total == 0:
                entropy[dimension] = 0.0
                continue
            # Calculate Shannon entropy: -Σ(p * log2(p))
            entropy_value = -sum(
                (count / total) * log2(count / total)
                for count in counts.values()
                if count > 0
            )
            entropy[dimension] = entropy_value
        return entropy


class Featurizer(BaseModel):
    """
    Orchestrates LLM-based feature extraction from issue descriptions.

    The Featurizer uses structured LLM tool calling to evaluate boolean features
    for issue descriptions. It handles prompt construction, tool schema generation,
    and batch processing with concurrency.
    """

    system_prompt: str
    """System prompt that provides context and instructions to the LLM."""

    message_prefix: str
    """Prefix added to user messages before the issue description."""

    features: list[Feature]
    """List of features to extract from each issue description."""

    def system_message(self) -> dict[str, Any]:
        """
        Construct the system message for LLM conversations.

        Returns:
            dict[str, Any]: System message dictionary for LLM API calls.
        """
        return {
            'role': 'system',
            'content': self.system_prompt,
        }

    def user_message(
        self, issue_description: str, set_cache: bool = True
    ) -> dict[str, Any]:
        """
        Construct the user message containing the issue description.

        Args:
            issue_description: The description of the issue to analyze.
            set_cache: Whether to enable ephemeral caching for this message.
                Should be False for single samples to avoid cache overhead.

        Returns:
            dict[str, Any]: User message dictionary for LLM API calls.
        """
        message: dict[str, Any] = {
            'role': 'user',
            'content': f'{self.message_prefix}{issue_description}',
        }
        if set_cache:
            message['cache_control'] = {'type': 'ephemeral'}
        return message

    @property
    def tool_choice(self) -> dict[str, Any]:
        """
        Get the tool choice configuration for forcing the LLM to use the featurizer tool.

        Returns:
            dict[str, Any]: Tool choice configuration for LLM API calls.
        """
        return {
            'type': 'function',
            'function': {'name': 'call_featurizer'},
        }

    @property
    def tool_description(self) -> dict[str, Any]:
        """
        Generate the tool schema for the featurizer function.

        Creates a JSON schema that describes the featurizer tool with all configured
        features as boolean parameters.

        Returns:
            dict[str, Any]: Complete tool description for LLM API calls.
        """
        return {
            'type': 'function',
            'function': {
                'name': 'call_featurizer',
                'description': 'Record the features present in the issue.',
                'parameters': {
                    'type': 'object',
                    'properties': {
                        feature.identifier: feature.to_tool_description_field
                        for feature in self.features
                    },
                },
            },
        }

    def embed(
        self,
        issue_description: str,
        llm_config: LLMConfig,
        temperature: float = 1.0,
        samples: int = 10,
    ) -> FeatureEmbedding:
        """
        Generate a feature embedding for a single issue description.

        Makes multiple LLM calls to collect samples and reduce variance in feature evaluations.
        Each call uses tool calling to extract structured boolean feature values.

        Args:
            issue_description: The description of the issue to analyze.
            llm_config: Configuration for the LLM to use.
            temperature: Sampling temperature for the model. Higher values increase randomness.
            samples: Number of samples to generate for averaging.

        Returns:
            FeatureEmbedding: Complete embedding with samples and metadata.
        """
        embedding_samples: list[dict[str, Any]] = []
        response_latency: float = 0.0
        prompt_tokens: int = 0
        completion_tokens: int = 0

        # TODO: use llm registry
        llm = LLM(llm_config, service_id='solvability')

        # Generate multiple samples to account for LLM variability
        for _ in range(samples):
            start_time = time.time()
            response = llm.completion(
                messages=[
                    self.system_message(),
                    self.user_message(issue_description, set_cache=(samples > 1)),
                ],
                tools=[self.tool_description],
                tool_choice=self.tool_choice,
                temperature=temperature,
            )
            stop_time = time.time()

            # Extract timing and token usage metrics
            latency = stop_time - start_time
            # Parse the structured tool call response containing feature evaluations
            features = response.choices[0].message.tool_calls[0].function.arguments  # type: ignore[index, union-attr]
            embedding = json.loads(features)

            # Accumulate results and metrics
            embedding_samples.append(embedding)
            prompt_tokens += response.usage.prompt_tokens  # type: ignore[union-attr, attr-defined]
            completion_tokens += response.usage.completion_tokens  # type: ignore[union-attr, attr-defined]
            response_latency += latency

        return FeatureEmbedding(
            samples=embedding_samples,
            response_latency=response_latency,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
        )

    def embed_batch(
        self,
        issue_descriptions: list[str],
        llm_config: LLMConfig,
        temperature: float = 1.0,
        samples: int = 10,
    ) -> list[FeatureEmbedding]:
        """
        Generate embeddings for a batch of issue descriptions using concurrent processing.

        Processes multiple issues in parallel to improve throughput while maintaining
        result ordering.

        Args:
            issue_descriptions: List of issue descriptions to analyze.
            llm_config: Configuration for the LLM to use.
            temperature: Sampling temperature for the model.
            samples: Number of samples to generate per issue.

        Returns:
            list[FeatureEmbedding]: List of embeddings in the same order as input.
        """
        with ThreadPoolExecutor() as executor:
            # Submit all embedding tasks concurrently
            future_to_desc = {
                executor.submit(
                    self.embed,
                    desc,
                    llm_config,
                    temperature=temperature,
                    samples=samples,
                ): i
                for i, desc in enumerate(issue_descriptions)
            }

            # Collect results in original order to maintain consistency
            results: list[FeatureEmbedding] = [None] * len(issue_descriptions)  # type: ignore[list-item]
            for future in as_completed(future_to_desc):
                index = future_to_desc[future]
                results[index] = future.result()

        return results

    def feature_identifiers(self) -> list[str]:
        """
        Get the identifiers of all configured features.

        Returns:
            list[str]: List of feature identifiers in the order they were defined.
        """
        return [feature.identifier for feature in self.features]
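To make the coefficient and entropy computations concrete, here is a small hand-built embedding; the feature names are hypothetical:

emb = FeatureEmbedding(
    samples=[
        {'has_code_example': True, 'requires_debugging': True},
        {'has_code_example': True, 'requires_debugging': False},
        {'has_code_example': True, 'requires_debugging': False},
        {'has_code_example': True, 'requires_debugging': True},
    ]
)
print(emb.coefficient('has_code_example'))    # 1.0 -> True in 4/4 samples
print(emb.coefficient('requires_debugging'))  # 0.5 -> True in 2/4 samples
entropy = emb.sample_entropy()
print(entropy['has_code_example'])    # 0.0: the samples fully agree
print(entropy['requires_debugging'])  # 1.0: maximal disagreement (50/50 split)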
enterprise/integrations/solvability/models/importance_strategy.py
@@ -0,0 +1,23 @@
from enum import Enum


class ImportanceStrategy(str, Enum):
    """
    Strategy to use for calculating feature importances, which are used to estimate the predictive power of each feature
    in training loops and explanations.
    """

    SHAP = 'shap'
    """
    Use SHAP (SHapley Additive exPlanations) to calculate feature importances.
    """

    PERMUTATION = 'permutation'
    """
    Use permutation-based feature importances.
    """

    IMPURITY = 'impurity'
    """
    Use the impurity-based feature importances from the RandomForestClassifier.
    """
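Because ImportanceStrategy subclasses str, values round-trip cleanly through JSON and plain-string configs; a quick sketch:

strategy = ImportanceStrategy('shap')  # lookup by value
assert strategy is ImportanceStrategy.SHAP
assert strategy == 'shap'  # a str subclass compares equal to its value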
87
enterprise/integrations/solvability/models/report.py
Normal file
@@ -0,0 +1,87 @@
from datetime import datetime
from typing import Any

from integrations.solvability.models.importance_strategy import ImportanceStrategy
from pydantic import BaseModel, Field


class SolvabilityReport(BaseModel):
    """
    Comprehensive report containing solvability predictions and analysis for a single issue.

    This report includes the solvability score, extracted feature values, feature importance analysis,
    cost metrics (tokens and latency), and metadata about the prediction process. It serves as the
    primary output format for solvability analysis and can be used for logging, debugging, and
    generating human-readable summaries.
    """

    identifier: str
    """
    The identifier of the solvability model used to generate the report.
    """

    issue: str
    """
    The issue description for which the solvability is predicted.

    This field is exactly the input to the solvability model.
    """

    score: float
    """
    [0, 1]-valued score indicating the likelihood of the issue being solvable.
    """

    prompt_tokens: int
    """
    Total number of prompt tokens used in API calls made to generate the features.
    """

    completion_tokens: int
    """
    Total number of completion tokens used in API calls made to generate the features.
    """

    response_latency: float
    """
    Total response latency of API calls made to generate the features.
    """

    features: dict[str, float]
    """
    [0, 1]-valued scores for each feature in the model.

    These are the values fed to the random forest classifier to generate the solvability score.
    """

    samples: int
    """
    Number of samples used to compute the feature embedding coefficients.
    """

    importance_strategy: ImportanceStrategy
    """
    Strategy used to calculate feature importances.
    """

    feature_importances: dict[str, float]
    """
    Importance scores for each feature in the model.

    Interpretation of these scores depends on the importance strategy used.
    """

    created_at: datetime = Field(default_factory=datetime.now)
    """
    Datetime when the report was created.
    """

    random_state: int | None = None
    """
    Classifier random state used when generating this report.
    """

    metadata: dict[str, Any] | None = None
    """
    Metadata for logging and debugging purposes.
    """
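Since SolvabilityReport is a Pydantic model, it serializes directly, which is how it can be logged or handed to the summary step; a sketch with illustrative field values:

report = SolvabilityReport(
    identifier='solvability-demo',  # hypothetical model identifier
    issue='Crash when uploading files over 2 GB',
    score=0.62,
    prompt_tokens=1200,
    completion_tokens=90,
    response_latency=8.4,
    features={'has_code_example': 0.0, 'requires_debugging': 0.8},
    samples=10,
    importance_strategy=ImportanceStrategy.IMPURITY,
    feature_importances={'has_code_example': 0.3, 'requires_debugging': 0.7},
)
print(report.model_dump_json(indent=2))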
172
enterprise/integrations/solvability/models/summary.py
Normal file
@@ -0,0 +1,172 @@
from __future__ import annotations

import json
from datetime import datetime
from typing import Any

from integrations.solvability.models.difficulty_level import DifficultyLevel
from integrations.solvability.models.report import SolvabilityReport
from integrations.solvability.prompts import load_prompt
from pydantic import BaseModel, Field

from openhands.llm import LLM


class SolvabilitySummary(BaseModel):
    """Summary of the solvability analysis in human-readable format."""

    score: float
    """
    Solvability score indicating the likelihood of the issue being solvable.
    """

    summary: str
    """
    The executive summary content generated by the LLM.
    """

    actionable_feedback: str
    """
    Actionable feedback content generated by the LLM.
    """

    positive_feedback: str
    """
    Positive feedback content generated by the LLM, highlighting what is good about the issue.
    """

    prompt_tokens: int
    """
    Number of prompt tokens used in the API call to generate the summary.
    """

    completion_tokens: int
    """
    Number of completion tokens used in the API call to generate the summary.
    """

    response_latency: float
    """
    Response latency of the API call to generate the summary.
    """

    created_at: datetime = Field(default_factory=datetime.now)
    """
    Datetime when the summary was created.
    """

    @staticmethod
    def tool_description() -> dict[str, Any]:
        """Get the tool description for the LLM."""
        return {
            'type': 'function',
            'function': {
                'name': 'solvability_summary',
                'description': 'Generate a human-readable summary of the solvability analysis.',
                'parameters': {
                    'type': 'object',
                    'properties': {
                        'summary': {
                            'type': 'string',
                            'description': 'A high-level (at most two sentences) summary of the solvability report.',
                        },
                        'actionable_feedback': {
                            'type': 'string',
                            'description': (
                                'Bullet list of 1-3 pieces of actionable feedback on how the user can address the lowest scoring relevant features.'
                            ),
                        },
                        'positive_feedback': {
                            'type': 'string',
                            'description': (
                                'Bullet list of 1-3 pieces of positive feedback on the issue, highlighting what is good about it.'
                            ),
                        },
                    },
                    'required': ['summary', 'actionable_feedback'],
                },
            },
        }

    @staticmethod
    def tool_choice() -> dict[str, Any]:
        """Get the tool choice for the LLM."""
        return {
            'type': 'function',
            'function': {
                'name': 'solvability_summary',
            },
        }

    @staticmethod
    def system_message() -> dict[str, Any]:
        """Get the system message for the LLM."""
        return {
            'role': 'system',
            'content': load_prompt('summary_system_message'),
        }

    @staticmethod
    def user_message(report: SolvabilityReport) -> dict[str, Any]:
        """Get the user message for the LLM."""
        return {
            'role': 'user',
            'content': load_prompt(
                'summary_user_message',
                report=report.model_dump(),
                difficulty_level=DifficultyLevel.from_score(report.score).value[0],
            ),
        }

    @staticmethod
    def from_report(report: SolvabilityReport, llm: LLM) -> SolvabilitySummary:
        """Create a SolvabilitySummary from a SolvabilityReport."""
        import time

        start_time = time.time()
        response = llm.completion(
            messages=[
                SolvabilitySummary.system_message(),
                SolvabilitySummary.user_message(report),
            ],
            tools=[SolvabilitySummary.tool_description()],
            tool_choice=SolvabilitySummary.tool_choice(),
        )
        response_latency = time.time() - start_time

        # Grab the arguments from the forced function call
        arguments = json.loads(
            response.choices[0].message.tool_calls[0].function.arguments
        )

        return SolvabilitySummary(
            # The score is copied directly from the report
            score=report.score,
            # Performance and usage metrics are pulled from the response
            prompt_tokens=response.usage.prompt_tokens,
            completion_tokens=response.usage.completion_tokens,
            response_latency=response_latency,
            # Every other field should be taken from the forced function call
            **arguments,
        )

    def format_as_markdown(self) -> str:
        """Format the summary content as Markdown."""
        # Convert score to difficulty level enum
        difficulty_level = DifficultyLevel.from_score(self.score)

        # Create the main difficulty display
        result = f'{difficulty_level.format_display()}\n\n{self.summary}'

        # If the issue isn't easy, surface the actionable feedback
        if difficulty_level != DifficultyLevel.EASY:
            result += '\n\nYou can make the issue easier to resolve by addressing these concerns in the conversation:\n\n'
            result += self.actionable_feedback

        # If the difficulty isn't hard, add some positive feedback
        if difficulty_level != DifficultyLevel.HARD:
            result += '\n\nPositive feedback:\n\n'
            result += self.positive_feedback

        return result
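End to end, a summary is produced from a report and rendered for posting back to the issue; a sketch assuming a report and a configured llm already exist:

summary = SolvabilitySummary.from_report(report, llm)
print(summary.format_as_markdown())
# e.g. '🟡 **Solvability: MEDIUM**' followed by the summary text and, depending
# on the difficulty, the actionable and positive feedback sections.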
13
enterprise/integrations/solvability/prompts/__init__.py
Normal file
@@ -0,0 +1,13 @@
from pathlib import Path

import jinja2


def load_prompt(prompt: str, **kwargs) -> str:
    """Load a prompt by name. Passes all the keyword arguments to the prompt template."""
    env = jinja2.Environment(loader=jinja2.FileSystemLoader(Path(__file__).parent))
    template = env.get_template(f'{prompt}.j2')
    return template.render(**kwargs)


__all__ = ['load_prompt']
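The loader resolves templates relative to the prompts package, so rendering the user-message template looks like this; the argument values are illustrative (normally report.model_dump() and a DifficultyLevel label):

text = load_prompt(
    'summary_user_message',
    report={'score': 0.62},
    difficulty_level='MEDIUM',
)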
enterprise/integrations/solvability/prompts/summary_system_message.j2
@@ -0,0 +1,10 @@
You are a helpful assistant that generates human-readable summaries of solvability reports.
The report predicts how likely it is that the issue can be resolved, and is produced purely based on the information provided in the issue description and comments.
The report explains which features are present in the issue and how impactful they are to the solvability score (using SHAP values).
Your task is to create a concise, high-level summary of the solvability analysis,
with an emphasis on the key factors that make the issue easy or hard to resolve.
Focus on the features with extreme scores, BUT ONLY if they are related to the issue at hand after careful consideration.
You should NEVER mention: SHAP, scores, feature names, or technical metrics.
You will also be given the expected difficulty of the issue, as EASY/MEDIUM/HARD.
Be sure to frame your responses with that difficulty in mind.
For example, if the issue is HARD you should not describe it as "straightforward".
enterprise/integrations/solvability/prompts/summary_user_message.j2
@@ -0,0 +1,9 @@
Generate a high-level summary of the solvability report:

{{ report }}

We estimate the issue is {{ difficulty_level }}.
The summary should be concise (at most two sentences) and describe the primary characteristics of this issue.
Focus on what information is present and what factors are most relevant to resolution.
Actionable feedback should be something that can be addressed by the user purely by providing more information.
Positive feedback should explain the features that are positively contributing to the solvability score.
0
enterprise/integrations/solvability/py.typed
Normal file
73
enterprise/integrations/stripe_service.py
Normal file
@@ -0,0 +1,73 @@
import stripe
from server.auth.token_manager import TokenManager
from server.constants import STRIPE_API_KEY
from server.logger import logger
from storage.database import session_maker
from storage.stripe_customer import StripeCustomer

stripe.api_key = STRIPE_API_KEY


async def find_customer_id_by_user_id(user_id: str) -> str | None:
    # First search our own DB...
    with session_maker() as session:
        stripe_customer = (
            session.query(StripeCustomer)
            .filter(StripeCustomer.keycloak_user_id == user_id)
            .first()
        )
        if stripe_customer:
            return stripe_customer.stripe_customer_id

    # If that fails, fall back to Stripe
    search_result = await stripe.Customer.search_async(
        query=f"metadata['user_id']:'{user_id}'",
    )
    data = search_result.data
    if not data:
        logger.info('no_customer_for_user_id', extra={'user_id': user_id})
        return None
    return data[0].id  # type: ignore [attr-defined]


async def find_or_create_customer(user_id: str) -> str:
    customer_id = await find_customer_id_by_user_id(user_id)
    if customer_id:
        return customer_id
    logger.info('creating_customer', extra={'user_id': user_id})

    # Get the user info from keycloak
    token_manager = TokenManager()
    user_info = await token_manager.get_user_info_from_user_id(user_id) or {}

    # Create the customer in stripe
    customer = await stripe.Customer.create_async(
        email=str(user_info.get('email', '')),
        metadata={'user_id': user_id},
    )

    # Save the stripe customer in the local db
    with session_maker() as session:
        session.add(
            StripeCustomer(keycloak_user_id=user_id, stripe_customer_id=customer.id)
        )
        session.commit()

    logger.info(
        'created_customer',
        extra={'user_id': user_id, 'stripe_customer_id': customer.id},
    )
    return customer.id


async def has_payment_method(user_id: str) -> bool:
    customer_id = await find_customer_id_by_user_id(user_id)
    if customer_id is None:
        return False
    payment_methods = await stripe.Customer.list_payment_methods_async(
        customer_id,
    )
    logger.info(
        f'has_payment_method:{user_id}:{customer_id}:{bool(payment_methods.data)}'
    )
    return bool(payment_methods.data)
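The Stripe helpers are coroutines, so callers await them from an event loop; a minimal sketch with a hypothetical user id:

import asyncio

async def main() -> None:
    customer_id = await find_or_create_customer('keycloak-user-123')
    if await has_payment_method('keycloak-user-123'):
        print(f'Customer {customer_id} has a payment method on file.')

asyncio.run(main())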
51
enterprise/integrations/types.py
Normal file
@@ -0,0 +1,51 @@
from dataclasses import dataclass
from enum import Enum

from jinja2 import Environment
from pydantic import BaseModel


class GitLabResourceType(Enum):
    GROUP = 'group'
    SUBGROUP = 'subgroup'
    PROJECT = 'project'


class PRStatus(Enum):
    CLOSED = 'CLOSED'
    MERGED = 'MERGED'


class UserData(BaseModel):
    user_id: int
    username: str
    keycloak_user_id: str | None


@dataclass
class SummaryExtractionTracker:
    conversation_id: str
    should_extract: bool
    send_summary_instruction: bool


@dataclass
class ResolverViewInterface(SummaryExtractionTracker):
    installation_id: int
    user_info: UserData
    issue_number: int
    full_repo_name: str
    is_public_repo: bool
    raw_payload: dict

    def _get_instructions(self, jinja_env: Environment) -> tuple[str, str]:
        """Instructions passed when the conversation is first initialized."""
        raise NotImplementedError()

    async def create_new_conversation(self, jinja_env: Environment, token: str):
        """Create a new conversation."""
        raise NotImplementedError()

    def get_callback_id(self) -> str:
        """Unique callback id for the subscription made to EventStream for fetching the agent summary."""
        raise NotImplementedError()
546
enterprise/integrations/utils.py
Normal file
546
enterprise/integrations/utils.py
Normal file
@@ -0,0 +1,546 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from jinja2 import Environment, FileSystemLoader
|
||||||
|
from server.constants import WEB_HOST
|
||||||
|
from storage.repository_store import RepositoryStore
|
||||||
|
from storage.stored_repository import StoredRepository
|
||||||
|
from storage.user_repo_map import UserRepositoryMap
|
||||||
|
from storage.user_repo_map_store import UserRepositoryMapStore
|
||||||
|
|
||||||
|
from openhands.core.config.openhands_config import OpenHandsConfig
|
||||||
|
from openhands.core.logger import openhands_logger as logger
|
||||||
|
from openhands.core.schema.agent import AgentState
|
||||||
|
from openhands.events import Event, EventSource
|
||||||
|
from openhands.events.action import (
|
||||||
|
AgentFinishAction,
|
||||||
|
MessageAction,
|
||||||
|
)
|
||||||
|
from openhands.events.event_store_abc import EventStoreABC
|
||||||
|
from openhands.events.observation.agent import AgentStateChangedObservation
|
||||||
|
from openhands.integrations.service_types import Repository
|
||||||
|
from openhands.storage.data_models.conversation_status import ConversationStatus
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from openhands.server.conversation_manager.conversation_manager import (
|
||||||
|
ConversationManager,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ---- DO NOT REMOVE ----
|
||||||
|
# WARNING: Langfuse depends on the WEB_HOST environment variable being set to track events.
|
||||||
|
HOST = WEB_HOST
|
||||||
|
# ---- DO NOT REMOVE ----
|
||||||
|
|
||||||
|
HOST_URL = f'https://{HOST}'
|
||||||
|
GITHUB_WEBHOOK_URL = f'{HOST_URL}/integration/github/events'
|
||||||
|
GITLAB_WEBHOOK_URL = f'{HOST_URL}/integration/gitlab/events'
|
||||||
|
conversation_prefix = 'conversations/{}'
|
||||||
|
CONVERSATION_URL = f'{HOST_URL}/{conversation_prefix}'
|
||||||
|
|
||||||
|
# Toggle for auto-response feature that proactively starts conversations with users when workflow tests fail
|
||||||
|
ENABLE_PROACTIVE_CONVERSATION_STARTERS = (
|
||||||
|
os.getenv('ENABLE_PROACTIVE_CONVERSATION_STARTERS', 'false').lower() == 'true'
|
||||||
|
)
|
||||||
|
|
||||||
|
# Toggle for solvability report feature
|
||||||
|
ENABLE_SOLVABILITY_ANALYSIS = (
|
||||||
|
os.getenv('ENABLE_SOLVABILITY_ANALYSIS', 'false').lower() == 'true'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
OPENHANDS_RESOLVER_TEMPLATES_DIR = 'openhands/integrations/templates/resolver/'
|
||||||
|
jinja_env = Environment(loader=FileSystemLoader(OPENHANDS_RESOLVER_TEMPLATES_DIR))
|
||||||
|
|
||||||
|
|
||||||
|
def get_oh_labels(web_host: str) -> tuple[str, str]:
|
||||||
|
"""Get the OpenHands labels based on the web host.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
web_host: The web host string to check
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tuple of (oh_label, inline_oh_label) where:
|
||||||
|
- oh_label is 'openhands-exp' for staging/local hosts, 'openhands' otherwise
|
||||||
|
- inline_oh_label is '@openhands-exp' for staging/local hosts, '@openhands' otherwise
|
||||||
|
"""
|
||||||
|
web_host = web_host.strip()
|
||||||
|
is_staging_or_local = 'staging' in web_host or 'local' in web_host
|
||||||
|
oh_label = 'openhands-exp' if is_staging_or_local else 'openhands'
|
||||||
|
inline_oh_label = '@openhands-exp' if is_staging_or_local else '@openhands'
|
||||||
|
return oh_label, inline_oh_label
|
||||||
|
|
||||||
|
|
||||||
|
def get_summary_instruction():
|
||||||
|
summary_instruction_template = jinja_env.get_template('summary_prompt.j2')
|
||||||
|
summary_instruction = summary_instruction_template.render()
|
||||||
|
return summary_instruction
|
||||||
|
|
||||||
|
|
||||||
|
def has_exact_mention(text: str, mention: str) -> bool:
|
||||||
|
"""Check if the text contains an exact mention (not part of a larger word).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: The text to check for mentions
|
||||||
|
mention: The mention to look for (e.g. "@openhands")
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if the exact mention is found, False otherwise
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> has_exact_mention("Hello @openhands!", "@openhands") # True
|
||||||
|
>>> has_exact_mention("Hello @openhands-agent!", "@openhands") # False
|
||||||
|
>>> has_exact_mention("(@openhands)", "@openhands") # True
|
||||||
|
>>> has_exact_mention("user@openhands.com", "@openhands") # False
|
||||||
|
>>> has_exact_mention("Hello @OpenHands!", "@openhands") # True (case-insensitive)
|
||||||
|
"""
|
||||||
|
# Convert both text and mention to lowercase for case-insensitive matching
|
||||||
|
text_lower = text.lower()
|
||||||
|
mention_lower = mention.lower()
|
||||||
|
|
||||||
|
pattern = re.escape(mention_lower)
|
||||||
|
# Match mention that is not part of a larger word
|
||||||
|
return bool(re.search(rf'(?:^|[^\w@]){pattern}(?![\w-])', text_lower))
|
||||||
|
|
||||||
|
|
||||||
|
def confirm_event_type(event: Event):
|
||||||
|
return isinstance(event, AgentStateChangedObservation) and not (
|
||||||
|
event.agent_state == AgentState.REJECTED
|
||||||
|
or event.agent_state == AgentState.USER_CONFIRMED
|
||||||
|
or event.agent_state == AgentState.USER_REJECTED
|
||||||
|
or event.agent_state == AgentState.LOADING
|
||||||
|
or event.agent_state == AgentState.RUNNING
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_readable_error_reason(reason: str):
|
||||||
|
if reason == 'STATUS$ERROR_LLM_AUTHENTICATION':
|
||||||
|
reason = 'Authentication with the LLM provider failed. Please check your API key or credentials'
|
||||||
|
elif reason == 'STATUS$ERROR_LLM_SERVICE_UNAVAILABLE':
|
||||||
|
reason = 'The LLM service is temporarily unavailable. Please try again later'
|
||||||
|
elif reason == 'STATUS$ERROR_LLM_INTERNAL_SERVER_ERROR':
|
||||||
|
reason = 'The LLM provider encountered an internal error. Please try again soon'
|
||||||
|
elif reason == 'STATUS$ERROR_LLM_OUT_OF_CREDITS':
|
||||||
|
reason = "You've run out of credits. Please top up to continue"
|
||||||
|
elif reason == 'STATUS$ERROR_LLM_CONTENT_POLICY_VIOLATION':
|
||||||
|
reason = 'Content policy violation. The output was blocked by content filtering policy'
|
||||||
|
return reason
|
||||||
|
|
||||||
|
|
||||||
|
def get_summary_for_agent_state(
|
||||||
|
observations: list[AgentStateChangedObservation], conversation_link: str
|
||||||
|
) -> str:
|
||||||
|
unknown_error_msg = f'OpenHands encountered an unknown error. [See the conversation]({conversation_link}) for more information, or try again'
|
||||||
|
|
||||||
|
if len(observations) == 0:
|
||||||
|
logger.error(
|
||||||
|
'Unknown error: No agent state observations found',
|
||||||
|
extra={'conversation_link': conversation_link},
|
||||||
|
)
|
||||||
|
return unknown_error_msg
|
||||||
|
|
||||||
|
observation: AgentStateChangedObservation = observations[0]
|
||||||
|
state = observation.agent_state
|
||||||
|
|
||||||
|
if state == AgentState.RATE_LIMITED:
|
||||||
|
logger.warning(
|
||||||
|
'Agent was rate limited',
|
||||||
|
extra={
|
||||||
|
'agent_state': state.value,
|
||||||
|
'conversation_link': conversation_link,
|
||||||
|
'observation_reason': getattr(observation, 'reason', None),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return 'OpenHands was rate limited by the LLM provider. Please try again later.'
|
||||||
|
|
||||||
|
if state == AgentState.ERROR:
|
||||||
|
reason = observation.reason
|
||||||
|
reason = get_readable_error_reason(reason)
|
||||||
|
|
||||||
|
logger.error(
|
||||||
|
'Agent encountered an error',
|
||||||
|
extra={
|
||||||
|
'agent_state': state.value,
|
||||||
|
'conversation_link': conversation_link,
|
||||||
|
'observation_reason': observation.reason,
|
||||||
|
'readable_reason': reason,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
return f'OpenHands encountered an error: **{reason}**.\n\n[See the conversation]({conversation_link}) for more information.'
|
||||||
|
|
||||||
|
# Log unknown agent state as error
|
||||||
|
logger.error(
|
||||||
|
'Unknown error: Unhandled agent state',
|
||||||
|
extra={
|
||||||
|
'agent_state': state.value if hasattr(state, 'value') else str(state),
|
||||||
|
'conversation_link': conversation_link,
|
||||||
|
'observation_reason': getattr(observation, 'reason', None),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return unknown_error_msg
|
||||||
|
|
||||||
|
|
||||||
|
def get_final_agent_observation(
|
||||||
|
event_store: EventStoreABC,
|
||||||
|
) -> list[AgentStateChangedObservation]:
|
||||||
|
return event_store.get_matching_events(
|
||||||
|
source=EventSource.ENVIRONMENT,
|
||||||
|
event_types=(AgentStateChangedObservation,),
|
||||||
|
limit=1,
|
||||||
|
reverse=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_last_user_msg(event_store: EventStoreABC) -> list[MessageAction]:
|
||||||
|
return event_store.get_matching_events(
|
||||||
|
source=EventSource.USER, event_types=(MessageAction,), limit=1, reverse='true'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_summary_from_event_store(
    event_store: EventStoreABC, conversation_id: str
) -> str:
    """
    Get agent summary or alternative message depending on current AgentState
    """
    conversation_link = CONVERSATION_URL.format(conversation_id)
    summary_instruction = get_summary_instruction()

    instruction_event: list[MessageAction] = event_store.get_matching_events(
        query=json.dumps(summary_instruction),
        source=EventSource.USER,
        event_types=(MessageAction,),
        limit=1,
        reverse=True,
    )

    final_agent_observation = get_final_agent_observation(event_store)

    # Find summary instruction event ID
    if len(instruction_event) == 0:
        logger.warning(
            'no_instruction_event_found', extra={'conversation_id': conversation_id}
        )
        return get_summary_for_agent_state(
            final_agent_observation, conversation_link
        )  # Agent did not receive summary instruction

    event_id: int = instruction_event[0].id

    agent_messages: list[MessageAction | AgentFinishAction] = (
        event_store.get_matching_events(
            start_id=event_id,
            source=EventSource.AGENT,
            event_types=(MessageAction, AgentFinishAction),
            reverse=True,
            limit=1,
        )
    )

    if len(agent_messages) == 0:
        logger.warning(
            'no_agent_messages_found', extra={'conversation_id': conversation_id}
        )
        return get_summary_for_agent_state(
            final_agent_observation, conversation_link
        )  # Agent failed to generate summary

    summary_event: MessageAction | AgentFinishAction = agent_messages[0]
    if isinstance(summary_event, MessageAction):
        return summary_event.content

    return summary_event.final_thought

async def get_event_store_from_conversation_manager(
    conversation_manager: ConversationManager, conversation_id: str
) -> EventStoreABC:
    agent_loop_infos = await conversation_manager.get_agent_loop_info(
        filter_to_sids={conversation_id}
    )
    if not agent_loop_infos or agent_loop_infos[0].status != ConversationStatus.RUNNING:
        raise RuntimeError(f'conversation_not_running:{conversation_id}')
    event_store = agent_loop_infos[0].event_store
    if not event_store:
        raise RuntimeError(f'event_store_missing:{conversation_id}')
    return event_store

async def get_last_user_msg_from_conversation_manager(
    conversation_manager: ConversationManager, conversation_id: str
):
    event_store = await get_event_store_from_conversation_manager(
        conversation_manager, conversation_id
    )
    return get_last_user_msg(event_store)

async def extract_summary_from_conversation_manager(
    conversation_manager: ConversationManager, conversation_id: str
) -> str:
    """
    Get agent summary or alternative message depending on current AgentState
    """

    event_store = await get_event_store_from_conversation_manager(
        conversation_manager, conversation_id
    )
    summary = extract_summary_from_event_store(event_store, conversation_id)
    return append_conversation_footer(summary, conversation_id)

def append_conversation_footer(message: str, conversation_id: str) -> str:
    """
    Append a small footer with the conversation URL to a message.

    Args:
        message: The original message content
        conversation_id: The conversation ID to link to

    Returns:
        The message with the conversation footer appended
    """
    conversation_link = CONVERSATION_URL.format(conversation_id)
    footer = f'\n\n<sub>[View full conversation]({conversation_link})</sub>'
    return message + footer

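# Illustrative usage (not part of the commit), assuming CONVERSATION_URL is a
# format string along the lines of 'https://app.all-hands.dev/conversations/{}':
#
#   >>> append_conversation_footer('Task finished.', 'abc123')
#   'Task finished.\n\n<sub>[View full conversation](https://app.all-hands.dev/conversations/abc123)</sub>'
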
async def store_repositories_in_db(repos: list[Repository], user_id: str) -> None:
    """
    Store repositories in DB and create user-repository mappings

    Args:
        repos: List of Repository objects to store
        user_id: User ID associated with these repositories
    """

    # Convert Repository objects to StoredRepository and UserRepositoryMap rows
    stored_repos = []
    user_repos = []
    for repo in repos:
        repo_id = f'{repo.git_provider.value}##{str(repo.id)}'
        stored_repo = StoredRepository(
            repo_name=repo.full_name,
            repo_id=repo_id,
            is_public=repo.is_public,
            # Optional fields set to None by default
            has_microagent=None,
            has_setup_script=None,
        )
        stored_repos.append(stored_repo)
        user_repo_map = UserRepositoryMap(user_id=user_id, repo_id=repo_id, admin=None)
        user_repos.append(user_repo_map)

    # Get config instance
    config = OpenHandsConfig()

    try:
        # Store repositories in the repos table
        repo_store = RepositoryStore.get_instance(config)
        repo_store.store_projects(stored_repos)

        # Store user-repository mappings in the user-repos table
        user_repo_store = UserRepositoryMapStore.get_instance(config)
        user_repo_store.store_user_repo_mappings(user_repos)

        logger.info(f'Saved repos for user {user_id}')
    except Exception:
        logger.warning('Failed to save repos', exc_info=True)

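# Illustrative note (not part of the commit): the composite repo_id built above
# namespaces the provider's numeric ID so that IDs from different Git providers
# cannot collide. Hypothetical values:
provider, numeric_id = 'github', 12345
example_repo_id = f'{provider}##{numeric_id}'  # -> 'github##12345'
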
def infer_repo_from_message(user_msg: str) -> list[str]:
    """
    Extract all repository names in the format 'owner/repo' from various Git provider URLs
    and direct mentions in text. Supports GitHub, GitLab, and BitBucket.

    Args:
        user_msg: Input message that may contain repository references

    Returns:
        List of repository names in 'owner/repo' format, empty list if none found
    """
    # Normalize the message by removing extra whitespace and newlines
    normalized_msg = re.sub(r'\s+', ' ', user_msg.strip())

    # Pattern to match Git URLs from GitHub, GitLab, and BitBucket
    # Captures: protocol, domain, owner, repo (with optional .git extension)
    git_url_pattern = r'https?://(?:github\.com|gitlab\.com|bitbucket\.org)/([a-zA-Z0-9_.-]+)/([a-zA-Z0-9_.-]+?)(?:\.git)?(?:[/?#].*?)?(?=\s|$|[^\w.-])'

    # Pattern to match direct owner/repo mentions (e.g., "All-Hands-AI/OpenHands")
    # Must be surrounded by word boundaries or specific characters to avoid false positives
    direct_pattern = (
        r'(?:^|\s|[\[\(\'"])([a-zA-Z0-9_.-]+)/([a-zA-Z0-9_.-]+)(?=\s|$|[\]\)\'",.])'
    )

    matches = []

    # First, find all Git URLs (highest priority)
    git_matches = re.findall(git_url_pattern, normalized_msg)
    for owner, repo in git_matches:
        # Remove .git extension if present
        repo = re.sub(r'\.git$', '', repo)
        matches.append(f'{owner}/{repo}')

    # Second, find all direct owner/repo mentions
    direct_matches = re.findall(direct_pattern, normalized_msg)
    for owner, repo in direct_matches:
        full_match = f'{owner}/{repo}'

        # Skip if it looks like a version number, date, or file path
        if (
            re.match(r'^\d+\.\d+/\d+\.\d+$', full_match)  # version numbers
            or re.match(r'^\d{1,2}/\d{1,2}$', full_match)  # dates
            or re.match(r'^[A-Z]/[A-Z]$', full_match)  # single letters
            or repo.endswith('.txt')
            or repo.endswith('.md')  # file extensions
            or repo.endswith('.py')
            or repo.endswith('.js')
            or ('.' in repo and len(repo.split('.')) > 2)  # complex file paths
        ):
            continue

        # Avoid duplicates from Git URLs already found
        if full_match not in matches:
            matches.append(full_match)

    return matches

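# Illustrative examples (not part of the commit) of what the heuristics above
# accept and reject:
assert infer_repo_from_message(
    'Please clone https://github.com/All-Hands-AI/OpenHands.git and fix the bug'
) == ['All-Hands-AI/OpenHands']
assert infer_repo_from_message('see octocat/hello-world for details') == [
    'octocat/hello-world'
]
assert infer_repo_from_message('upgrade from 1.2/3.4 today') == []  # version-like, skipped
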
def filter_potential_repos_by_user_msg(
    user_msg: str, user_repos: list[Repository]
) -> tuple[bool, list[Repository]]:
    """Filter repositories based on user message inference."""
    inferred_repos = infer_repo_from_message(user_msg)
    if not inferred_repos:
        return False, user_repos[0:99]

    final_repos = []
    for repo in user_repos:
        # Check if the repo matches any of the inferred repositories
        for inferred_repo in inferred_repos:
            if inferred_repo.lower() in repo.full_name.lower():
                final_repos.append(repo)
                break  # Avoid adding the same repo multiple times

    # No repos matched, return original list
    if len(final_repos) == 0:
        return False, user_repos[0:99]

    # Found exact match
    elif len(final_repos) == 1:
        return True, final_repos

    # Found partial matches
    return False, final_repos[0:99]

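# Illustrative sketch (not part of the commit): `SimpleNamespace` stands in for
# the real Repository type, since the filter only reads `full_name`.
from types import SimpleNamespace

_repos = [
    SimpleNamespace(full_name='All-Hands-AI/OpenHands'),
    SimpleNamespace(full_name='All-Hands-AI/openhands-aci'),
]

# A message naming the longer repo matches exactly one entry:
exact, candidates = filter_potential_repos_by_user_msg(
    'fix the CI in All-Hands-AI/openhands-aci', _repos
)
assert exact is True and candidates[0].full_name == 'All-Hands-AI/openhands-aci'

# A message naming the shorter repo substring-matches both, so no exact match:
exact, candidates = filter_potential_repos_by_user_msg(
    'fix the CI in All-Hands-AI/OpenHands', _repos
)
assert exact is False and len(candidates) == 2
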
def markdown_to_jira_markup(markdown_text: str) -> str:
    """
    Convert markdown text to Jira Wiki Markup format.

    This function handles common markdown elements and converts them to their
    Jira Wiki Markup equivalents. It's designed to be exception-safe.

    Args:
        markdown_text: The markdown text to convert

    Returns:
        str: The converted Jira Wiki Markup text
    """
    if not markdown_text or not isinstance(markdown_text, str):
        return ''

    try:
        # Work with a copy to avoid modifying the original
        text = markdown_text

        # Convert headers (# ## ### #### ##### ######)
        text = re.sub(r'^#{6}\s+(.*?)$', r'h6. \1', text, flags=re.MULTILINE)
        text = re.sub(r'^#{5}\s+(.*?)$', r'h5. \1', text, flags=re.MULTILINE)
        text = re.sub(r'^#{4}\s+(.*?)$', r'h4. \1', text, flags=re.MULTILINE)
        text = re.sub(r'^#{3}\s+(.*?)$', r'h3. \1', text, flags=re.MULTILINE)
        text = re.sub(r'^#{2}\s+(.*?)$', r'h2. \1', text, flags=re.MULTILINE)
        text = re.sub(r'^#{1}\s+(.*?)$', r'h1. \1', text, flags=re.MULTILINE)

        # Convert code blocks first (before other formatting)
        text = re.sub(
            r'```(\w+)\n(.*?)\n```', r'{code:\1}\n\2\n{code}', text, flags=re.DOTALL
        )
        text = re.sub(r'```\n(.*?)\n```', r'{code}\n\1\n{code}', text, flags=re.DOTALL)

        # Convert inline code (`code`)
        text = re.sub(r'`([^`]+)`', r'{{\1}}', text)

        # Convert markdown formatting to Jira formatting
        # Use temporary placeholders to avoid conflicts between bold and italic conversion

        # First convert bold (double markers) to temporary placeholders
        text = re.sub(r'\*\*(.*?)\*\*', r'JIRA_BOLD_START\1JIRA_BOLD_END', text)
        text = re.sub(r'__(.*?)__', r'JIRA_BOLD_START\1JIRA_BOLD_END', text)

        # Now convert single asterisk italics
        text = re.sub(r'\*([^*]+?)\*', r'_\1_', text)

        # Convert underscore italics
        text = re.sub(r'(?<!_)_([^_]+?)_(?!_)', r'_\1_', text)

        # Finally, restore bold markers
        text = text.replace('JIRA_BOLD_START', '*')
        text = text.replace('JIRA_BOLD_END', '*')

        # Convert links [text](url)
        text = re.sub(r'\[([^\]]+)\]\(([^)]+)\)', r'[\1|\2]', text)

        # Convert unordered lists (- or * or +)
        text = re.sub(r'^[\s]*[-*+]\s+(.*?)$', r'* \1', text, flags=re.MULTILINE)

        # Convert ordered lists (1. 2. etc.)
        text = re.sub(r'^[\s]*\d+\.\s+(.*?)$', r'# \1', text, flags=re.MULTILINE)

        # Convert strikethrough (~~text~~)
        text = re.sub(r'~~(.*?)~~', r'-\1-', text)

        # Convert horizontal rules (---, ***, ___)
        text = re.sub(r'^[\s]*[-*_]{3,}[\s]*$', r'----', text, flags=re.MULTILINE)

        # Convert blockquotes (> text)
        text = re.sub(r'^>\s+(.*?)$', r'bq. \1', text, flags=re.MULTILINE)

        # Convert tables (basic support)
        # This is a simplified table conversion - Jira tables are quite different
        lines = text.split('\n')
        in_table = False
        converted_lines = []

        for line in lines:
            if (
                '|' in line
                and line.strip().startswith('|')
                and line.strip().endswith('|')
            ):
                # Skip markdown table separator lines (contain ---)
                if '---' in line:
                    continue
                if not in_table:
                    in_table = True
                # Convert markdown table row to Jira table row
                cells = [cell.strip() for cell in line.split('|')[1:-1]]
                converted_line = '|' + '|'.join(cells) + '|'
                converted_lines.append(converted_line)
            elif in_table and line.strip() and '|' not in line:
                in_table = False
                converted_lines.append(line)
            else:
                in_table = False
                converted_lines.append(line)

        text = '\n'.join(converted_lines)

        return text

    except Exception as e:
        # Log the error but don't raise it - return original text as fallback
        print(f'Error converting markdown to Jira markup: {str(e)}')
        return markdown_text or ''

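# Illustrative example (not part of the commit) of the conversion above:
sample_md = '## Status\n**Done** - see [the PR](https://example.com/pr/1)'
print(markdown_to_jira_markup(sample_md))
# h2. Status
# *Done* - see [the PR|https://example.com/pr/1]
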
114
enterprise/migrations/env.py
Normal file
@@ -0,0 +1,114 @@
import os
from logging.config import fileConfig

from alembic import context
from google.cloud.sql.connector import Connector
from sqlalchemy import create_engine
from storage.base import Base

target_metadata = Base.metadata

DB_USER = os.getenv('DB_USER', 'postgres')
DB_PASS = os.getenv('DB_PASS', 'postgres')
DB_HOST = os.getenv('DB_HOST', 'localhost')
DB_PORT = os.getenv('DB_PORT', '5432')
DB_NAME = os.getenv('DB_NAME', 'openhands')

GCP_DB_INSTANCE = os.getenv('GCP_DB_INSTANCE')
GCP_PROJECT = os.getenv('GCP_PROJECT')
GCP_REGION = os.getenv('GCP_REGION')

POOL_SIZE = int(os.getenv('DB_POOL_SIZE', '25'))
MAX_OVERFLOW = int(os.getenv('DB_MAX_OVERFLOW', '10'))


def get_engine(database_name=DB_NAME):
    """Create SQLAlchemy engine with optional database name."""
    if GCP_DB_INSTANCE:

        def get_db_connection():
            connector = Connector()
            instance_string = f'{GCP_PROJECT}:{GCP_REGION}:{GCP_DB_INSTANCE}'
            return connector.connect(
                instance_string,
                'pg8000',
                user=DB_USER,
                password=DB_PASS.strip(),
                db=database_name,
            )

        return create_engine(
            'postgresql+pg8000://',
            creator=get_db_connection,
            pool_size=POOL_SIZE,
            max_overflow=MAX_OVERFLOW,
            pool_pre_ping=True,
        )
    else:
        url = f'postgresql://{DB_USER}:{DB_PASS}@{DB_HOST}:{DB_PORT}/{database_name}'
        return create_engine(
            url,
            pool_size=POOL_SIZE,
            max_overflow=MAX_OVERFLOW,
            pool_pre_ping=True,
        )


engine = get_engine()

# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config

# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None:
    fileConfig(config.config_file_name)


def run_migrations_offline() -> None:
    """Run migrations in 'offline' mode.

    This configures the context with just a URL
    and not an Engine, though an Engine is acceptable
    here as well. By skipping the Engine creation
    we don't even need a DBAPI to be available.

    Calls to context.execute() here emit the given string to the
    script output.
    """
    url = config.get_main_option('sqlalchemy.url')
    context.configure(
        url=url,
        target_metadata=target_metadata,
        literal_binds=True,
        dialect_opts={'paramstyle': 'named'},
    )

    with context.begin_transaction():
        context.run_migrations()


def run_migrations_online() -> None:
    """Run migrations in 'online' mode.

    In this scenario we need to create an Engine
    and associate a connection with the context.
    """
    connectable = engine

    with connectable.connect() as connection:
        context.configure(
            connection=connection,
            target_metadata=target_metadata,
            version_table_schema=target_metadata.schema,
        )

        with context.begin_transaction():
            context.run_migrations()


if context.is_offline_mode():
    run_migrations_offline()
else:
    run_migrations_online()
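# Illustrative sketch (not part of the commit): with an env.py like the one
# above, pending migrations are usually applied via the Alembic CLI
# (`alembic upgrade head`) or programmatically through Alembic's command API.
# The 'alembic.ini' path below is an assumption about the deployment layout.
from alembic import command
from alembic.config import Config

alembic_cfg = Config('alembic.ini')  # hypothetical config location
command.upgrade(alembic_cfg, 'head')  # apply all pending migrations
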
26
enterprise/migrations/script.py.mako
Normal file
@@ -0,0 +1,26 @@
"""${message}

Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa
${imports if imports else ""}

# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}


def upgrade() -> None:
    ${upgrades if upgrades else "pass"}


def downgrade() -> None:
    ${downgrades if downgrades else "pass"}
45
enterprise/migrations/versions/001_create_feedback_table.py
Normal file
@@ -0,0 +1,45 @@
"""Create feedback table

Revision ID: 001
Revises:
Create Date: 2024-03-19 10:00:00.000000

"""

from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = '001'
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    op.create_table(
        'feedback',
        sa.Column('id', sa.String(), nullable=False),
        sa.Column('version', sa.String(), nullable=False),
        sa.Column('email', sa.String(), nullable=False),
        sa.Column(
            'polarity',
            sa.Enum('positive', 'negative', name='polarity_enum'),
            nullable=False,
        ),
        sa.Column(
            'permissions',
            sa.Enum('public', 'private', name='permissions_enum'),
            nullable=False,
        ),
        sa.Column('trajectory', sa.JSON(), nullable=True),
        sa.PrimaryKeyConstraint('id'),
    )


def downgrade() -> None:
    op.drop_table('feedback')
    op.execute('DROP TYPE polarity_enum')
    op.execute('DROP TYPE permissions_enum')
@@ -0,0 +1,45 @@
"""create saas settings table

Revision ID: 002
Revises: 001
Create Date: 2025-01-27 20:08:58.360566

"""

from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = '002'
down_revision: Union[str, None] = '001'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    # This was created to match the settings object - in future some of these
    # strings should probably be replaced with enum types.
    op.create_table(
        'settings',
        sa.Column('id', sa.String(), nullable=False, primary_key=True),
        sa.Column('language', sa.String(), nullable=True),
        sa.Column('agent', sa.String(), nullable=True),
        sa.Column('max_iterations', sa.Integer(), nullable=True),
        sa.Column('security_analyzer', sa.String(), nullable=True),
        sa.Column('confirmation_mode', sa.Boolean(), nullable=True, default=False),
        sa.Column('llm_model', sa.String(), nullable=True),
        sa.Column('llm_api_key', sa.String(), nullable=True),
        sa.Column('llm_base_url', sa.String(), nullable=True),
        sa.Column('remote_runtime_resource_factor', sa.Integer(), nullable=True),
        sa.Column('github_token', sa.String(), nullable=True),
        sa.Column(
            'enable_default_condenser', sa.Boolean(), nullable=False, default=False
        ),
        sa.Column('user_consents_to_analytics', sa.Boolean(), nullable=True),
    )


def downgrade() -> None:
    op.drop_table('settings')
@@ -0,0 +1,35 @@
"""create saas conversations table

Revision ID: 003
Revises: 002
Create Date: 2025-01-29 09:36:49.475467

"""

from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = '003'
down_revision: Union[str, None] = '002'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    op.create_table(
        'conversation_metadata',
        sa.Column('conversation_id', sa.String(), nullable=False),
        sa.Column('user_id', sa.String(), nullable=False, index=True),
        sa.Column('selected_repository', sa.String(), nullable=True),
        sa.Column('title', sa.String(), nullable=True),
        sa.Column('last_updated_at', sa.DateTime(), nullable=False),
        sa.Column('created_at', sa.DateTime(), nullable=False, index=True),
        sa.PrimaryKeyConstraint('conversation_id'),
    )


def downgrade() -> None:
    op.drop_table('conversation_metadata')
@@ -0,0 +1,47 @@
"""create billing sessions table

Revision ID: 004
Revises: 003
Create Date: 2025-01-29 09:36:49.475467

"""

from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = '004'
down_revision: Union[str, None] = '003'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    op.create_table(
        'billing_sessions',
        sa.Column('id', sa.String(), nullable=False, primary_key=True),
        sa.Column('user_id', sa.String(), nullable=False),
        sa.Column(
            'status',
            sa.Enum(
                'in_progress',
                'completed',
                'cancelled',
                'error',
                name='billing_session_status_enum',
            ),
            nullable=False,
            default='in_progress',
        ),
        sa.Column('price', sa.DECIMAL(19, 4), nullable=False),
        sa.Column('price_code', sa.String(), nullable=False),
        sa.Column('created_at', sa.DateTime(), nullable=False),
        sa.Column('updated_at', sa.DateTime(), nullable=False),
    )


def downgrade() -> None:
    op.drop_table('billing_sessions')
    op.execute('DROP TYPE billing_session_status_enum')
26
enterprise/migrations/versions/005_add_margin_column.py
Normal file
@@ -0,0 +1,26 @@
"""add margin column

Revision ID: 005
Revises: 004
Create Date: 2025-02-10 08:36:49.475467

"""

from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = '005'
down_revision: Union[str, None] = '004'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    op.add_column('settings', sa.Column('margin', sa.Float(), nullable=True))


def downgrade() -> None:
    op.drop_column('settings', 'margin')
@@ -0,0 +1,29 @@
"""add branch column to convo metadata table

Revision ID: 006
Revises: 005
Create Date: 2025-02-11 14:59:09.415

"""

from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = '006'
down_revision: Union[str, None] = '005'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    op.add_column(
        'conversation_metadata',
        sa.Column('selected_branch', sa.String(), nullable=True),
    )


def downgrade() -> None:
    op.drop_column('conversation_metadata', 'selected_branch')
@@ -0,0 +1,31 @@
"""add enable_sound_notifications column to settings table

Revision ID: 007
Revises: 006
Create Date: 2025-05-01 10:00:00.000

"""

from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = '007'
down_revision: Union[str, None] = '006'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    op.add_column(
        'settings',
        sa.Column(
            'enable_sound_notifications', sa.Boolean(), nullable=True, default=False
        ),
    )


def downgrade() -> None:
    op.drop_column('settings', 'enable_sound_notifications')
@@ -0,0 +1,36 @@
"""fix enable_sound_notifications settings to not be nullable

Revision ID: 008
Revises: 007
Create Date: 2025-02-28 18:28:00.000

"""

from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = '008'
down_revision: Union[str, None] = '007'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    op.alter_column(
        'settings',
        sa.Column(
            'enable_sound_notifications', sa.Boolean(), nullable=False, default=False
        ),
    )


def downgrade() -> None:
    op.alter_column(
        'settings',
        sa.Column(
            'enable_sound_notifications', sa.Boolean(), nullable=True, default=False
        ),
    )
@@ -0,0 +1,39 @@
"""fix enable_sound_notifications settings to not be nullable

Revision ID: 009
Revises: 008
Create Date: 2025-02-28 18:28:00.000

"""

from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = '009'
down_revision: Union[str, None] = '008'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    op.execute(
        'UPDATE settings SET enable_sound_notifications=FALSE where enable_sound_notifications IS NULL'
    )
    op.alter_column(
        'settings',
        sa.Column(
            'enable_sound_notifications', sa.Boolean(), nullable=False, default=False
        ),
    )


def downgrade() -> None:
    op.alter_column(
        'settings',
        sa.Column(
            'enable_sound_notifications', sa.Boolean(), nullable=True, default=False
        ),
    )
@@ -0,0 +1,40 @@
"""create offline tokens table.

Revision ID: 010
Revises: 009_fix_enable_sound_notifications_column
Create Date: 2024-03-11

"""

from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = '010'
down_revision: Union[str, None] = '009'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    op.create_table(
        'offline_tokens',
        sa.Column('user_id', sa.String(length=255), primary_key=True),
        sa.Column('offline_token', sa.String(), nullable=False),
        sa.Column(
            'created_at', sa.DateTime(), server_default=sa.text('now()'), nullable=False
        ),
        sa.Column(
            'updated_at',
            sa.DateTime(),
            server_default=sa.text('now()'),
            onupdate=sa.text('now()'),
            nullable=False,
        ),
    )


def downgrade() -> None:
    op.drop_table('offline_tokens')
@@ -0,0 +1,50 @@
"""create user settings table

Revision ID: 011
Revises: 010
Create Date: 2024-03-11 23:39:00.000000

"""

from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = '011'
down_revision: Union[str, None] = '010'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    op.create_table(
        'user_settings',
        sa.Column('id', sa.Integer(), sa.Identity(), nullable=False, primary_key=True),
        sa.Column('keycloak_user_id', sa.String(), nullable=True),
        sa.Column('language', sa.String(), nullable=True),
        sa.Column('agent', sa.String(), nullable=True),
        sa.Column('max_iterations', sa.Integer(), nullable=True),
        sa.Column('security_analyzer', sa.String(), nullable=True),
        sa.Column('confirmation_mode', sa.Boolean(), nullable=True, default=False),
        sa.Column('llm_model', sa.String(), nullable=True),
        sa.Column('llm_api_key', sa.String(), nullable=True),
        sa.Column('llm_base_url', sa.String(), nullable=True),
        sa.Column('remote_runtime_resource_factor', sa.Integer(), nullable=True),
        sa.Column(
            'enable_default_condenser', sa.Boolean(), nullable=False, default=False
        ),
        sa.Column('user_consents_to_analytics', sa.Boolean(), nullable=True),
        sa.Column('billing_margin', sa.Float(), nullable=True),
        sa.Column(
            'enable_sound_notifications', sa.Boolean(), nullable=True, default=False
        ),
    )
    # Create indexes for faster lookups
    op.create_index('idx_keycloak_user_id', 'user_settings', ['keycloak_user_id'])


def downgrade() -> None:
    op.drop_index('idx_keycloak_user_id', 'user_settings')
    op.drop_table('user_settings')
@@ -0,0 +1,26 @@
"""add secrets_store column to settings table

Revision ID: 012
Revises: 011
Create Date: 2025-05-01 10:00:00.000
"""

from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = '012'
down_revision: Union[str, None] = '011'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    op.add_column(
        'settings', sa.Column('secrets_store', sa.JSON(), nullable=True, default=False)
    )


def downgrade() -> None:
    op.drop_column('settings', 'secrets_store')
@@ -0,0 +1,53 @@
"""create github app installations table

Revision ID: 013
Revises: 012
Create Date: 2024-03-12 23:39:00.000000

"""

from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = '013'
down_revision: Union[str, None] = '012'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    op.create_table(
        'github_app_installations',
        sa.Column('id', sa.Integer(), sa.Identity(), primary_key=True),
        sa.Column('installation_id', sa.String(), nullable=False),
        sa.Column('encrypted_token', sa.String(), nullable=False),
        sa.Column(
            'created_at',
            sa.DateTime(),
            server_default=sa.text('now()'),
            onupdate=sa.text('now()'),
            nullable=False,
        ),
        sa.Column(
            'updated_at',
            sa.DateTime(),
            server_default=sa.text('now()'),
            onupdate=sa.text('now()'),
            nullable=False,
        ),
    )
    # Create indexes for faster lookups
    op.create_index(
        'idx_installation_id',
        'github_app_installations',
        ['installation_id'],
        unique=True,
    )


def downgrade() -> None:
    op.drop_index('idx_installation_id', 'github_app_installations')
    op.drop_table('github_app_installations')
40
enterprise/migrations/versions/014_add_github_user_id.py
Normal file
@@ -0,0 +1,40 @@
"""Add github_user_id field and rename user_id to github_user_id.

This migration:
1. Renames the existing user_id column to github_user_id
2. Creates a new user_id column
"""

from typing import Sequence, Union

from alembic import op
from sqlalchemy import Column, String

# revision identifiers, used by Alembic.
revision: str = '014'
down_revision: Union[str, None] = '013'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade():
    # First rename the existing user_id column to github_user_id
    op.alter_column(
        'conversation_metadata',
        'user_id',
        nullable=True,
        new_column_name='github_user_id',
    )

    # Then add the new user_id column
    op.add_column('conversation_metadata', Column('user_id', String, nullable=True))


def downgrade():
    # Drop the new user_id column
    op.drop_column('conversation_metadata', 'user_id')

    # Rename github_user_id back to user_id
    op.alter_column(
        'conversation_metadata', 'github_user_id', new_column_name='user_id'
    )
@@ -0,0 +1,50 @@
"""add sandbox_base_container_image and sandbox_runtime_container_image columns

Revision ID: 015
Revises: 014
Create Date: 2025-03-19 19:30:00.000

"""

from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = '015'
down_revision: Union[str, None] = '014'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    # Add columns to settings table
    op.add_column(
        'settings',
        sa.Column('sandbox_base_container_image', sa.String(), nullable=True),
    )
    op.add_column(
        'settings',
        sa.Column('sandbox_runtime_container_image', sa.String(), nullable=True),
    )

    # Add columns to user_settings table
    op.add_column(
        'user_settings',
        sa.Column('sandbox_base_container_image', sa.String(), nullable=True),
    )
    op.add_column(
        'user_settings',
        sa.Column('sandbox_runtime_container_image', sa.String(), nullable=True),
    )


def downgrade() -> None:
    # Drop columns from settings table
    op.drop_column('settings', 'sandbox_base_container_image')
    op.drop_column('settings', 'sandbox_runtime_container_image')

    # Drop columns from user_settings table
    op.drop_column('user_settings', 'sandbox_base_container_image')
    op.drop_column('user_settings', 'sandbox_runtime_container_image')
@@ -0,0 +1,29 @@
"""Add user settings version which acts as a hint of external db state

Revision ID: 016
Revises: 015
Create Date: 2025-03-20 16:30:00.000

"""

from typing import Union

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = '016'
down_revision: Union[str, None] = '015'
branch_labels: Union[str, sa.Sequence[str], None] = None
depends_on: Union[str, sa.Sequence[str], None] = None


def upgrade() -> None:
    op.add_column(
        'user_settings',
        sa.Column('user_version', sa.Integer(), nullable=False, server_default='0'),
    )


def downgrade() -> None:
    op.drop_column('user_settings', 'user_version')
@@ -0,0 +1,55 @@
"""Add a stripe customers table

Revision ID: 017
Revises: 016
Create Date: 2025-03-20 16:30:00.000

"""

from typing import Union

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = '017'
down_revision: Union[str, None] = '016'
branch_labels: Union[str, sa.Sequence[str], None] = None
depends_on: Union[str, sa.Sequence[str], None] = None


def upgrade() -> None:
    op.create_table(
        'stripe_customers',
        sa.Column('id', sa.Integer(), sa.Identity(), nullable=False, primary_key=True),
        sa.Column('keycloak_user_id', sa.String(), nullable=False),
        sa.Column('stripe_customer_id', sa.String(), nullable=False),
        sa.Column(
            'created_at',
            sa.DateTime(),
            server_default=sa.text('now()'),
            nullable=False,
        ),
        sa.Column(
            'updated_at',
            sa.DateTime(),
            server_default=sa.text('now()'),
            onupdate=sa.text('now()'),
            nullable=False,
        ),
    )
    # Create indexes for faster lookups
    op.create_index(
        'idx_stripe_customers_keycloak_user_id',
        'stripe_customers',
        ['keycloak_user_id'],
    )
    op.create_index(
        'idx_stripe_customers_stripe_customer_id',
        'stripe_customers',
        ['stripe_customer_id'],
    )


def downgrade() -> None:
    op.drop_table('stripe_customers')
@@ -0,0 +1,36 @@
"""Add a table for tracking output from maintenance scripts. These are basically migrations that are not SQL-centric.

Revision ID: 018
Revises: 017
Create Date: 2025-03-26 19:45:00.000

"""

from typing import Union

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = '018'
down_revision: Union[str, None] = '017'
branch_labels: Union[str, sa.Sequence[str], None] = None
depends_on: Union[str, sa.Sequence[str], None] = None


def upgrade() -> None:
    op.create_table(
        'script_results',
        sa.Column('id', sa.Integer(), sa.Identity(), nullable=False, primary_key=True),
        sa.Column('revision', sa.String(), nullable=False, index=True),
        sa.Column('data', sa.JSON()),
        sa.Column(
            'created_at',
            sa.DateTime(),
            server_default=sa.text('CURRENT_TIMESTAMP'),
            nullable=False,
        ),
    )


def downgrade() -> None:
    op.drop_table('script_results')
@@ -0,0 +1,93 @@
"""Remove duplicates from stripe. This is a non-standard Alembic migration for non-SQL resources.

Revision ID: 019
Revises: 018
Create Date: 2025-03-20 16:30:00.000

"""

import json
import os
from collections import defaultdict
from typing import Union

import sqlalchemy as sa
import stripe
from alembic import op
from sqlalchemy.sql import text

# revision identifiers, used by Alembic.
revision: str = '019'
down_revision: Union[str, None] = '018'
branch_labels: Union[str, sa.Sequence[str], None] = None
depends_on: Union[str, sa.Sequence[str], None] = None


def upgrade() -> None:
    # Skip migration if STRIPE_API_KEY is not set
    if 'STRIPE_API_KEY' not in os.environ:
        print('Skipping migration 019: STRIPE_API_KEY not set')
        return

    stripe.api_key = os.environ['STRIPE_API_KEY']

    # Get all users from stripe
    user_id_to_customer_ids = defaultdict(list)
    customers = stripe.Customer.list()
    for customer in customers.auto_paging_iter():
        user_id = customer.metadata.get('user_id')
        if user_id:
            user_id_to_customer_ids[user_id].append(customer.id)

    # Canonical customer IDs already recorded in the database
    stripe_customers = {
        row[0]: row[1]
        for row in op.get_bind().execute(
            text('SELECT keycloak_user_id, stripe_customer_id FROM stripe_customers')
        )
    }

    to_delete = []
    for user_id, customer_ids in user_id_to_customer_ids.items():
        if len(customer_ids) == 1:
            continue
        canonical_customer_id = stripe_customers.get(user_id)
        if canonical_customer_id:
            for customer_id in customer_ids:
                if customer_id != canonical_customer_id:
                    to_delete.append({'user_id': user_id, 'customer_id': customer_id})
        else:
            # Prioritize deletion of items that don't have payment methods
            to_delete_for_customer = []
            for customer_id in customer_ids:
                payment_methods = stripe.Customer.list_payment_methods(customer_id)
                to_delete_for_customer.append(
                    {
                        'user_id': user_id,
                        'customer_id': customer_id,
                        'num_payment_methods': len(payment_methods),
                    }
                )
            to_delete_for_customer.sort(
                key=lambda c: c['num_payment_methods'], reverse=True
            )
            to_delete.extend(to_delete_for_customer[1:])

    for item in to_delete:
        op.get_bind().execute(
            text(
                'INSERT INTO script_results (revision, data) VALUES (:revision, :data)'
            ),
            {
                'revision': revision,
                'data': json.dumps(item),
            },
        )
        stripe.Customer.delete(item['customer_id'])


def downgrade() -> None:
    op.get_bind().execute(
        text('DELETE FROM script_results WHERE revision=:revision'),
        {'revision': revision},
    )
40
enterprise/migrations/versions/020_set_condenser_to_false.py
Normal file
@@ -0,0 +1,40 @@
"""set condenser to false for all users

Revision ID: 020
Revises: 019
Create Date: 2025-04-02 12:45:00.000000

"""

from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op
from sqlalchemy.sql import column, table

# revision identifiers, used by Alembic.
revision: str = '020'
down_revision: Union[str, None] = '019'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    # Define tables for update operations
    settings_table = table('settings', column('enable_default_condenser', sa.Boolean))

    user_settings_table = table(
        'user_settings', column('enable_default_condenser', sa.Boolean)
    )

    # Update the enable_default_condenser column to False for all users in the settings table
    op.execute(settings_table.update().values(enable_default_condenser=False))

    # Update the enable_default_condenser column to False for all users in the user_settings table
    op.execute(user_settings_table.update().values(enable_default_condenser=False))


def downgrade() -> None:
    # No downgrade operation needed as we're just setting a value
    # and not changing schema structure
    pass
@@ -0,0 +1,46 @@
"""create auth tokens table

Revision ID: 021
Revises: 020
Create Date: 2025-03-30 20:15:00.000000

"""

from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = '021'
down_revision: Union[str, None] = '020'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    op.create_table(
        'auth_tokens',
        sa.Column('id', sa.Integer(), sa.Identity(), nullable=False, primary_key=True),
        sa.Column('keycloak_user_id', sa.String(), nullable=False),
        sa.Column('identity_provider', sa.String(), nullable=False),
        sa.Column('access_token', sa.String(), nullable=False),
        sa.Column('refresh_token', sa.String(), nullable=False),
        sa.Column('access_token_expires_at', sa.BigInteger(), nullable=False),
        sa.Column('refresh_token_expires_at', sa.BigInteger(), nullable=False),
    )
    op.create_index(
        'idx_auth_tokens_keycloak_user_id', 'auth_tokens', ['keycloak_user_id']
    )
    op.create_index(
        'idx_auth_tokens_keycloak_user_identity_provider',
        'auth_tokens',
        ['keycloak_user_id', 'identity_provider'],
        unique=True,
    )


def downgrade() -> None:
    op.drop_index('idx_auth_tokens_keycloak_user_identity_provider', 'auth_tokens')
    op.drop_index('idx_auth_tokens_keycloak_user_id', 'auth_tokens')
    op.drop_table('auth_tokens')
44
enterprise/migrations/versions/022_create_api_keys_table.py
Normal file
@@ -0,0 +1,44 @@
"""Create API keys table

Revision ID: 022
Revises: 021
Create Date: 2025-04-03

"""

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = '022'
down_revision = '021'
branch_labels = None
depends_on = None


def upgrade():
    op.create_table(
        'api_keys',
        sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
        sa.Column('key', sa.String(length=255), nullable=False),
        sa.Column('user_id', sa.String(length=255), nullable=False),
        sa.Column('name', sa.String(length=255), nullable=True),
        sa.Column(
            'created_at',
            sa.DateTime(),
            server_default=sa.text('CURRENT_TIMESTAMP'),
            nullable=False,
        ),
        sa.Column('last_used_at', sa.DateTime(), nullable=True),
        sa.Column('expires_at', sa.DateTime(), nullable=True),
        sa.PrimaryKeyConstraint('id'),
        sa.UniqueConstraint('key'),
    )
    op.create_index(op.f('ix_api_keys_key'), 'api_keys', ['key'], unique=True)
    op.create_index(op.f('ix_api_keys_user_id'), 'api_keys', ['user_id'], unique=False)


def downgrade():
    op.drop_index(op.f('ix_api_keys_user_id'), table_name='api_keys')
    op.drop_index(op.f('ix_api_keys_key'), table_name='api_keys')
    op.drop_table('api_keys')
@@ -0,0 +1,44 @@
"""Add cost and token metrics columns to conversation_metadata table.

Revision ID: 023
Revises: 022
Create Date: 2025-04-07

"""

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = '023'
down_revision = '022'
branch_labels = None
depends_on = None


def upgrade():
    # Add cost and token metrics columns to conversation_metadata table
    op.add_column(
        'conversation_metadata',
        sa.Column('accumulated_cost', sa.Float(), nullable=True, server_default='0.0'),
    )
    op.add_column(
        'conversation_metadata',
        sa.Column('prompt_tokens', sa.Integer(), nullable=True, server_default='0'),
    )
    op.add_column(
        'conversation_metadata',
        sa.Column('completion_tokens', sa.Integer(), nullable=True, server_default='0'),
    )
    op.add_column(
        'conversation_metadata',
        sa.Column('total_tokens', sa.Integer(), nullable=True, server_default='0'),
    )


def downgrade():
    # Remove cost and token metrics columns from conversation_metadata table
    op.drop_column('conversation_metadata', 'accumulated_cost')
    op.drop_column('conversation_metadata', 'prompt_tokens')
    op.drop_column('conversation_metadata', 'completion_tokens')
    op.drop_column('conversation_metadata', 'total_tokens')
@@ -0,0 +1,72 @@
"""update enable_default_condenser default to True

Revision ID: 024
Revises: 023
Create Date: 2024-04-08 15:30:00.000000

"""

from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op
from sqlalchemy.sql import column, table

# revision identifiers, used by Alembic.
revision: str = '024'
down_revision: Union[str, None] = '023'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    # Update existing rows in settings table
    settings_table = table('settings', column('enable_default_condenser', sa.Boolean))
    op.execute(settings_table.update().values(enable_default_condenser=True))

    # Update existing rows in user_settings table
    user_settings_table = table(
        'user_settings', column('enable_default_condenser', sa.Boolean)
    )
    op.execute(user_settings_table.update().values(enable_default_condenser=True))

    # Alter the default value for settings table
    op.alter_column(
        'settings',
        'enable_default_condenser',
        existing_type=sa.Boolean(),
        server_default=sa.true(),
        existing_nullable=False,
    )

    # Alter the default value for user_settings table
    op.alter_column(
        'user_settings',
        'enable_default_condenser',
        existing_type=sa.Boolean(),
        server_default=sa.true(),
        existing_nullable=False,
    )


def downgrade() -> None:
    # Revert the default value for settings table
    op.alter_column(
        'settings',
        'enable_default_condenser',
        existing_type=sa.Boolean(),
        server_default=sa.false(),
        existing_nullable=False,
    )

    # Revert the default value for user_settings table
    op.alter_column(
        'user_settings',
        'enable_default_condenser',
        existing_type=sa.Boolean(),
        server_default=sa.false(),
        existing_nullable=False,
    )

    # Note: We don't revert the data changes in the downgrade function
    # as it would be arbitrary which rows to change back
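A note on the pattern above: the migration builds throwaway table()/column() constructs instead of importing application models, so it keeps working even after the models drift; it backfills existing rows first, then changes the server default so future inserts agree. A hedged sketch of the same two-step pattern applied to a hypothetical notifications_enabled flag (the column name is an assumption for illustration):

import sqlalchemy as sa
from alembic import op
from sqlalchemy.sql import column, table


def upgrade() -> None:
    # Hypothetical column used for illustration; not part of this commit.
    settings = table('settings', column('notifications_enabled', sa.Boolean))

    # Step 1: backfill rows that existed before the default changed.
    op.execute(settings.update().values(notifications_enabled=True))

    # Step 2: change the server-side default so new rows match.
    op.alter_column(
        'settings',
        'notifications_enabled',
        existing_type=sa.Boolean(),
        server_default=sa.true(),
        existing_nullable=False,
    )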
@@ -0,0 +1,40 @@
"""Revert user_version from 3 to 2

Revision ID: 025
Revises: 024
Create Date: 2025-04-09

"""

from typing import Sequence, Union

from alembic import op

# revision identifiers, used by Alembic.
revision: str = '025'
down_revision: Union[str, None] = '024'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    # Update user_version from 3 to 2 for all users who have version 3
    op.execute(
        """
        UPDATE user_settings
        SET user_version = 2
        WHERE user_version = 3
        """
    )


def downgrade() -> None:
    # Revert back to version 3 for users who have version 2
    # Note: This is not a perfect downgrade as we can't know which users originally had version 3
    op.execute(
        """
        UPDATE user_settings
        SET user_version = 3
        WHERE user_version = 2
        """
    )
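op.execute() also accepts a text() construct, which allows bound parameters instead of values interpolated into the SQL string. An equivalent, slightly more defensive form of the update above, sketched under the assumption of the same user_settings table:

import sqlalchemy as sa
from alembic import op


def upgrade() -> None:
    # Same effect as the raw string above, with bound parameters.
    op.execute(
        sa.text(
            'UPDATE user_settings SET user_version = :new WHERE user_version = :old'
        ).bindparams(new=2, old=3)
    )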
@@ -0,0 +1,29 @@
"""add trigger column to convo metadata table

Revision ID: 026
Revises: 025
Create Date: 2025-04-16 14:59:09.415

"""

from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = '026'
down_revision: Union[str, None] = '025'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    op.add_column(
        'conversation_metadata',
        sa.Column('trigger', sa.String(), nullable=True),
    )


def downgrade() -> None:
    op.drop_column('conversation_metadata', 'trigger')
@@ -0,0 +1,60 @@
"""create saas settings table

Revision ID: 027
Revises: 026
Create Date: 2025-01-27 20:08:58.360566

"""

from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = '027'
down_revision: Union[str, None] = '026'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    # This was created to match the settings object - in future some of these strings should probably
    # be replaced with enum types.
    op.create_table(
        'gitlab-webhook',
        sa.Column(
            'id', sa.Integer(), nullable=False, primary_key=True, autoincrement=True
        ),
        sa.Column('group_id', sa.String(), nullable=True),
        sa.Column('project_id', sa.String(), nullable=True),
        sa.Column('user_id', sa.String(), nullable=False),
        sa.Column('webhook_exists', sa.Boolean(), nullable=False),
        sa.Column('webhook_name', sa.Boolean(), nullable=True),
        sa.Column('webhook_url', sa.String(), nullable=True),
        sa.Column('webhook_secret', sa.String(), nullable=True),
        sa.Column('scopes', sa.String(), nullable=True),
    )

    # Create indexes for faster lookups
    op.create_index('ix_gitlab_webhook_user_id', 'gitlab-webhook', ['user_id'])
    op.create_index('ix_gitlab_webhook_group_id', 'gitlab-webhook', ['group_id'])
    op.create_index('ix_gitlab_webhook_project_id', 'gitlab-webhook', ['project_id'])

    # Add unique constraints on group_id and project_id to support UPSERT operations
    op.create_unique_constraint(
        'uq_gitlab_webhook_group_id', 'gitlab-webhook', ['group_id']
    )
    op.create_unique_constraint(
        'uq_gitlab_webhook_project_id', 'gitlab-webhook', ['project_id']
    )


def downgrade() -> None:
    # Drop the constraints and indexes first before dropping the table
    op.drop_constraint('uq_gitlab_webhook_group_id', 'gitlab-webhook', type_='unique')
    op.drop_constraint('uq_gitlab_webhook_project_id', 'gitlab-webhook', type_='unique')
    op.drop_index('ix_gitlab_webhook_user_id', table_name='gitlab-webhook')
    op.drop_index('ix_gitlab_webhook_group_id', table_name='gitlab-webhook')
    op.drop_index('ix_gitlab_webhook_project_id', table_name='gitlab-webhook')
    op.drop_table('gitlab-webhook')
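The named unique constraints are what make INSERT ... ON CONFLICT upserts possible from application code. A minimal sketch using SQLAlchemy's PostgreSQL dialect; the helper name, the passed-in reflected table, and the values are assumptions for illustration, not code from this commit:

from sqlalchemy.dialects.postgresql import insert


def upsert_project_webhook(conn, webhook_table, project_id, user_id):
    # Hypothetical helper; not part of this commit.
    stmt = insert(webhook_table).values(
        project_id=project_id, user_id=user_id, webhook_exists=True
    )
    # On a project_id collision, update the row instead of failing.
    stmt = stmt.on_conflict_do_update(
        constraint='uq_gitlab_webhook_project_id',
        set_={'user_id': stmt.excluded.user_id, 'webhook_exists': True},
    )
    conn.execute(stmt)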
@@ -0,0 +1,63 @@
"""create user-repos table

Revision ID: 028
Revises: 027
Create Date: 2025-04-14

"""

from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = '028'
down_revision: Union[str, None] = '027'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    op.create_table(
        'repos',
        sa.Column('id', sa.Integer(), sa.Identity(), primary_key=True),
        sa.Column('repo_name', sa.String(), nullable=False),
        sa.Column('repo_id', sa.String(), nullable=False),
        sa.Column('is_public', sa.Boolean(), nullable=False),
        sa.Column('has_microagent', sa.Boolean(), nullable=True),
        sa.Column('has_setup_script', sa.Boolean(), nullable=True),
    )

    op.create_index(
        'idx_repos_repo_id',
        'repos',
        ['repo_id'],
    )

    op.create_table(
        'user-repos',
        sa.Column('id', sa.Integer(), sa.Identity(), primary_key=True),
        sa.Column('user_id', sa.String(), nullable=False),
        sa.Column('repo_id', sa.String(), nullable=False),
        sa.Column('admin', sa.Boolean(), nullable=True),
    )

    op.create_index(
        'idx_user_repos_repo_id',
        'user-repos',
        ['repo_id'],
    )
    op.create_index(
        'idx_user_repos_user_id',
        'user-repos',
        ['user_id'],
    )


def downgrade() -> None:
    op.drop_index('idx_repos_repo_id', 'repos')
    op.drop_index('idx_user_repos_repo_id', 'user-repos')
    op.drop_index('idx_user_repos_user_id', 'user-repos')
    op.drop_table('repos')
    op.drop_table('user-repos')
@@ -0,0 +1,28 @@
"""add accepted_tos to user_settings

Revision ID: 029
Revises: 028
Create Date: 2025-04-23

"""

from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = '029'
down_revision: Union[str, None] = '028'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    op.add_column(
        'user_settings', sa.Column('accepted_tos', sa.DateTime(), nullable=True)
    )


def downgrade() -> None:
    op.drop_column('user_settings', 'accepted_tos')
@@ -0,0 +1,33 @@
"""add proactive conversation starters column

Revision ID: 030
Revises: 029
Create Date: 2025-04-30

"""

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = '030'
down_revision = '029'
branch_labels = None
depends_on = None


def upgrade():
    op.add_column(
        'user_settings',
        sa.Column(
            'enable_proactive_conversation_starters',
            sa.Boolean(),
            nullable=False,
            default=True,
            server_default='TRUE',
        ),
    )


def downgrade():
    op.drop_column('user_settings', 'enable_proactive_conversation_starters')
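Worth noting in the column definition above: default=True is applied client-side by SQLAlchemy at INSERT time, while server_default='TRUE' is baked into the table DDL and is what populates rows that already exist. A small sketch that prints the DDL this combination produces (the table shape is trimmed for illustration):

import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from sqlalchemy.schema import CreateTable

metadata = sa.MetaData()
user_settings = sa.Table(
    'user_settings',
    metadata,
    sa.Column('id', sa.Integer, primary_key=True),
    sa.Column(
        'enable_proactive_conversation_starters',
        sa.Boolean(),
        nullable=False,
        default=True,           # Python-side, used when the ORM inserts rows
        server_default='TRUE',  # database-side, visible in the DDL below
    ),
)

# DEFAULT true appears in the emitted CREATE TABLE; default=True does not.
print(CreateTable(user_settings).compile(dialect=postgresql.dialect()))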
36
enterprise/migrations/versions/031_add_user_secrets_store.py
Normal file
@@ -0,0 +1,36 @@
"""create user secrets table

Revision ID: 031
Revises: 030
Create Date: 2024-03-11 23:39:00.000000

"""

from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = '031'
down_revision: Union[str, None] = '030'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    op.create_table(
        'user_secrets',
        sa.Column('id', sa.Integer(), sa.Identity(), nullable=False, primary_key=True),
        sa.Column('keycloak_user_id', sa.String(), nullable=True),
        sa.Column('custom_secrets', sa.JSON(), nullable=True),
    )
    # Create indexes for faster lookups
    op.create_index(
        'idx_user_secrets_keycloak_user_id', 'user_secrets', ['keycloak_user_id']
    )


def downgrade() -> None:
    op.drop_index('idx_user_secrets_keycloak_user_id', 'user_secrets')
    op.drop_table('user_secrets')
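Since custom_secrets is a generic sa.JSON column, the shape of each secret is decided by application code rather than by the schema. A self-contained round-trip sketch; the in-memory database, simplified primary key, key names, and values are illustrative assumptions:

import sqlalchemy as sa

metadata = sa.MetaData()
# Simplified shape of the table above (identity column reduced to a plain PK).
user_secrets = sa.Table(
    'user_secrets',
    metadata,
    sa.Column('id', sa.Integer, primary_key=True),
    sa.Column('keycloak_user_id', sa.String()),
    sa.Column('custom_secrets', sa.JSON()),
)

engine = sa.create_engine('sqlite://')  # throwaway in-memory DB for the demo
metadata.create_all(engine)

with engine.begin() as conn:
    conn.execute(
        user_secrets.insert().values(
            keycloak_user_id='user-123',
            custom_secrets={'MY_API_KEY': {'value': 'shhh', 'description': 'demo'}},
        )
    )
    stored = conn.execute(sa.select(user_secrets.c.custom_secrets)).scalar_one()
    print(stored['MY_API_KEY']['description'])  # -> demo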
@@ -0,0 +1,56 @@
"""add status column to gitlab-webhook table

Revision ID: 032
Revises: 031
Create Date: 2025-04-21

"""

from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = '032'
down_revision: Union[str, None] = '031'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    op.rename_table('gitlab-webhook', 'gitlab_webhook')

    op.add_column(
        'gitlab_webhook',
        sa.Column(
            'last_synced',
            sa.DateTime(),
            server_default=sa.text('now()'),
            onupdate=sa.text('now()'),
            nullable=True,
        ),
    )

    op.drop_column('gitlab_webhook', 'webhook_name')

    op.alter_column(
        'gitlab_webhook',
        'scopes',
        existing_type=sa.String(),
        type_=sa.ARRAY(sa.Text()),
        existing_nullable=True,
        postgresql_using='ARRAY[]::text[]',
    )


def downgrade() -> None:
    op.add_column(
        'gitlab_webhook', sa.Column('webhook_name', sa.Boolean(), nullable=True)
    )

    # Drop the new column from the renamed table
    op.drop_column('gitlab_webhook', 'last_synced')

    # Rename the table back
    op.rename_table('gitlab_webhook', 'gitlab-webhook')
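The postgresql_using argument becomes the USING clause of ALTER COLUMN ... TYPE; here it is ARRAY[]::text[], so any existing string values in scopes are deliberately discarded and every row is reset to an empty text array. If the old values were comma-separated scope lists that had to be preserved instead, a variant like the following might work (hypothetical, not what this commit does):

import sqlalchemy as sa
from alembic import op


def upgrade() -> None:
    # Hypothetical data-preserving variant of the scopes type change.
    op.alter_column(
        'gitlab_webhook',
        'scopes',
        existing_type=sa.String(),
        type_=sa.ARRAY(sa.Text()),
        existing_nullable=True,
        # string_to_array is a built-in PostgreSQL function; this splits the
        # old comma-separated string into a text[] instead of discarding it.
        postgresql_using="string_to_array(scopes, ',')",
    )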
Some files were not shown because too many files have changed in this diff.